Commit 9e1603d8 authored by kurumuz's avatar kurumuz

fix knn search

parent 40839379
......@@ -700,7 +700,7 @@ class EmbedderModel(nn.Module):
found = []
for tag, count in self.tag_count_sorted:
if len(tag) > len(text) and tag.startswith(text):
found.append([tag, count])
found.append([tag, count, 0])
results = []
embedding = self([text])
......@@ -712,7 +712,9 @@ class EmbedderModel(nn.Module):
tag = self.index[id]
count = self.tag_count[tag]
prob = D[i]
results.append([tag, count])
results.append([tag, count, prob])
print(results)
#sort results by count and prob after
results = sorted(results, key=lambda x: x[1], reverse=True)
......@@ -721,6 +723,7 @@ class EmbedderModel(nn.Module):
if result[0] in results:
results.remove(result)
results = results[:-len(found)]
results = found + results
if len(found) > 0:
results = results[:-len(found)]
results = found + results
return results
......@@ -76,6 +76,7 @@ class Masker(TypedDict):
class Tags(TypedDict):
tag: str
count: int
confidence: float
class GenerationRequest(BaseModel):
prompt: str
......@@ -271,7 +272,7 @@ async def predict_tags(request: TextRequest, authorized: bool = Depends(verify_t
process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds")
return TagOutput(tags=[Tags(tag=tag, count=count) for tag, count in tags])
return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
except Exception as e:
traceback.print_exc()
......
export MODEL="embedder"
export DEV="True"
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment