Commit 9e1603d8 authored by kurumuz's avatar kurumuz

fix knn search

parent 40839379
...@@ -700,7 +700,7 @@ class EmbedderModel(nn.Module): ...@@ -700,7 +700,7 @@ class EmbedderModel(nn.Module):
found = [] found = []
for tag, count in self.tag_count_sorted: for tag, count in self.tag_count_sorted:
if len(tag) > len(text) and tag.startswith(text): if len(tag) > len(text) and tag.startswith(text):
found.append([tag, count]) found.append([tag, count, 0])
results = [] results = []
embedding = self([text]) embedding = self([text])
...@@ -712,7 +712,9 @@ class EmbedderModel(nn.Module): ...@@ -712,7 +712,9 @@ class EmbedderModel(nn.Module):
tag = self.index[id] tag = self.index[id]
count = self.tag_count[tag] count = self.tag_count[tag]
prob = D[i] prob = D[i]
results.append([tag, count]) results.append([tag, count, prob])
print(results)
#sort results by count and prob after #sort results by count and prob after
results = sorted(results, key=lambda x: x[1], reverse=True) results = sorted(results, key=lambda x: x[1], reverse=True)
...@@ -721,6 +723,7 @@ class EmbedderModel(nn.Module): ...@@ -721,6 +723,7 @@ class EmbedderModel(nn.Module):
if result[0] in results: if result[0] in results:
results.remove(result) results.remove(result)
results = results[:-len(found)] if len(found) > 0:
results = found + results results = results[:-len(found)]
results = found + results
return results return results
...@@ -76,6 +76,7 @@ class Masker(TypedDict): ...@@ -76,6 +76,7 @@ class Masker(TypedDict):
class Tags(TypedDict): class Tags(TypedDict):
tag: str tag: str
count: int count: int
confidence: float
class GenerationRequest(BaseModel): class GenerationRequest(BaseModel):
prompt: str prompt: str
...@@ -271,7 +272,7 @@ async def predict_tags(request: TextRequest, authorized: bool = Depends(verify_t ...@@ -271,7 +272,7 @@ async def predict_tags(request: TextRequest, authorized: bool = Depends(verify_t
process_time = time.perf_counter() - t process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds") logger.info(f"Request took {process_time:0.3f} seconds")
return TagOutput(tags=[Tags(tag=tag, count=count) for tag, count in tags]) return TagOutput(tags=[Tags(tag=tag, count=count, confidence=confidence) for tag, count, confidence in tags])
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
......
export MODEL="embedder"
export DEV="True"
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment