Commit 7ae5e142 authored by kurumuz's avatar kurumuz

multi-knn

parent fb3c04d0
...@@ -709,14 +709,18 @@ class EmbedderModel(nn.Module): ...@@ -709,14 +709,18 @@ class EmbedderModel(nn.Module):
import requests import requests
knn_folder = config.knn_folder knn_folder = config.knn_folder
self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda() self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
self.index = pickle.load(requests.get(f"{knn_folder}/index.pkl", stream='True').raw) self.indexes = {}
self.tag_count = pickle.load(requests.get(f"{knn_folder}/all_tags.pkl", stream='True').raw) for folder in knn_folder.split(","):
r = requests.get(f"{knn_folder}/knn.index", stream='True') name, url = folder.split(":")
with open("knn.index", "wb") as f: index = pickle.load(requests.get(f"{url}/index.pkl", stream='True').raw)
f.write(r.content) tag_count = pickle.load(requests.get(f"{url}/all_tags.pkl", stream='True').raw)
tag_count_sorted = sorted(tag_count.items(), key=lambda x: x[1], reverse=True)
self.knn = faiss.read_index("knn.index") r = requests.get(f"{url}/knn.index", stream='True')
self.tag_count_sorted = sorted(self.tag_count.items(), key=lambda x: x[1], reverse=True) with open("knn.index", "wb") as f:
f.write(r.content)
knn = faiss.read_index("knn.index")
self.indexes[name] = [index, tag_count, tag_count_sorted, knn]
def __call__(self, sentences): def __call__(self, sentences):
with torch.no_grad(): with torch.no_grad():
...@@ -725,9 +729,11 @@ class EmbedderModel(nn.Module): ...@@ -725,9 +729,11 @@ class EmbedderModel(nn.Module):
def get_top_k(self, request): def get_top_k(self, request):
text = request.prompt text = request.prompt
model = request.model
index, tag_count, tag_count_sorted, knn = self.indexes[model]
#check if text is a substring in tag_count.keys() #check if text is a substring in tag_count.keys()
found = [] found = []
for tag, count in self.tag_count_sorted: for tag, count in tag_count_sorted:
if len(tag) >= len(text) and tag.startswith(text): if len(tag) >= len(text) and tag.startswith(text):
found.append([tag, count, 0]) found.append([tag, count, 0])
...@@ -735,11 +741,11 @@ class EmbedderModel(nn.Module): ...@@ -735,11 +741,11 @@ class EmbedderModel(nn.Module):
embedding = self([text]) embedding = self([text])
#print(embedding.dtype) #print(embedding.dtype)
k = 15 k = 15
D, I = self.knn.search(embedding, k) D, I = knn.search(embedding, k)
D, I = D.squeeze(), I.squeeze() D, I = D.squeeze(), I.squeeze()
for i, id in enumerate(I): for i, id in enumerate(I):
tag = self.index[id] tag = index[id]
count = self.tag_count[tag] count = tag_count[tag]
prob = D[i] prob = D[i]
results.append([tag, count, prob]) results.append([tag, count, prob])
...@@ -753,8 +759,6 @@ class EmbedderModel(nn.Module): ...@@ -753,8 +759,6 @@ class EmbedderModel(nn.Module):
#filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4 #filter results for >0.5 confidence unless it has the search text in it and confidence is >0.4
results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])] results = [x for x in results if x[2] > 0.5 or (x[2] > 0.4 and text in x[0])]
found = sorted(found, key=lambda x: x[1], reverse=True) found = sorted(found, key=lambda x: x[1], reverse=True)
print(found)
print(results)
if len(found) > 0: if len(found) > 0:
#results = results[:-len(found)] #results = results[:-len(found)]
results = found + results results = found + results
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment