Commit a6b7f0ac authored by kurumuz

tagsearcher

parent 2db2cafd
@@ -31,6 +31,7 @@ RUN pip3 install min-dalle
RUN pip3 install https://www.dropbox.com/s/8ozhhbo1g7y5dsz/basedformer-f7c8c4fe12f8a0acf6588d8d09a8b9b0481895e3.zip?dl=1
#built DS
RUN pip3 install https://www.dropbox.com/s/euzpgpfrs9isf1z/deepspeed-0.7.3%2B55b7b9e0-cp38-cp38-linux_x86_64.whl?dl=1
RUN pip3 install faiss-cpu sentence_transformers
#Open ports
EXPOSE 8080
...
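A quick sanity check for the two dependencies added above (an editor's sketch, not part of the commit; assumes the image built successfully):

```python
# Verify faiss-cpu and sentence_transformers interoperate:
# all-MiniLM-L6-v2 produces 384-dimensional float32 embeddings.
import faiss
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb = model.encode(["hello world"])   # shape (1, 384), float32
index = faiss.IndexFlatL2(emb.shape[1])
index.add(emb)
print(index.ntotal)                   # -> 1
```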
@@ -234,6 +234,28 @@ class StableDiffusionModel(nn.Module):
        if config.prior_path:
            self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)

    def fuse_model(self):
        # Freeze the UNet, then trace its apply_model call with TorchScript.
        for param in self.model.model.parameters():
            param.requires_grad = False
        c = self.model.get_learned_conditioning(["what the hell is wrong with you!"]).float()
        uc = self.model.get_learned_conditioning([""]).float()
        sigmas = self.k_model.get_sigmas(30)
        start_code = torch.randn([1, 4, 64, 64], device="cuda").float()
        x_0 = start_code * sigmas[0]
        test_sigma = sigmas[1] * x_0.new_ones([x_0.shape[0]])
        with torch.autocast("cuda", torch.float16):
            # Example inputs: a batch of two (unconditional + conditional),
            # matching the classifier-free-guidance call pattern used at sampling time.
            x_two = torch.cat([x_0] * 2)
            cnd = torch.cat([uc, c])
            sigma_two = torch.cat([test_sigma] * 2)
            inputs = {'apply_model': (x_two, sigma_two, cnd)}
            traced_model = torch.jit.trace_module(self.model, inputs)
            #traced_model = traced_model.half()
        # Rewrap the traced module with the k-diffusion interfaces.
        self.k_model = K.external.CompVisDenoiser(traced_model)
        self.k_model = K.external.StableInterface(self.k_model)

    def from_folder(self, folder):
        folder = Path(folder)
        model_config = OmegaConf.load(folder / "config.yaml")
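For reference, `torch.jit.trace_module` traces arbitrary named methods (here `apply_model`) rather than only `forward`. A minimal, self-contained illustration with a toy module (hypothetical, not from this repo):

```python
import torch
import torch.nn as nn

class ToyDenoiser(nn.Module):
    # Stand-in for the UNet's apply_model(x, sigma, cond) call.
    def apply_model(self, x, sigma, cond):
        return x + sigma[:, None, None, None] * cond.mean()

toy = ToyDenoiser()
x = torch.randn(2, 4, 8, 8)
sigma = torch.ones(2)
cond = torch.randn(2, 77, 768)

# Map method name -> example inputs, exactly as fuse_model does above.
traced = torch.jit.trace_module(toy, {"apply_model": (x, sigma, cond)})
print(traced.apply_model(x, sigma, cond).shape)  # torch.Size([2, 4, 8, 8])
```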
@@ -649,3 +671,56 @@ class BasedformerModel(nn.Module):
        prompt = torch.cat([prompt, torch.tensor([[49527]], dtype=torch.long).cuda()], dim=1)
        is_safe, corrected = generate(self.model.module, prompt, self.tokenizer, tokens_to_generate=150, ds=True)
        return is_safe, corrected
class EmbedderModel(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        from sentence_transformers import SentenceTransformer
        import faiss
        import pickle
        import requests
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2').cuda()
        # id -> tag mapping and tag -> occurrence count, pickled remotely.
        self.index = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/index.pkl", stream=True).raw)
        self.tag_count = pickle.load(requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/all_tags.pkl", stream=True).raw)
        # faiss reads the k-NN index from disk, so download it to a file first.
        r = requests.get("https://f004.backblazeb2.com/file/naipublicbucketxyz/knn.index", stream=True)
        with open("knn.index", "wb") as f:
            f.write(r.content)
        self.knn = faiss.read_index("knn.index")
        self.tag_count_sorted = sorted(self.tag_count.items(), key=lambda x: x[1], reverse=True)

    def __call__(self, sentences):
        with torch.no_grad():
            sentence_embeddings = self.model.encode(sentences)
        return sentence_embeddings

    def get_top_k(self, request):
        text = request.prompt
        # Exact prefix matches against the known tags, most frequent first.
        found = []
        for tag, count in self.tag_count_sorted:
            if tag.startswith(text):
                found.append([tag, count])
        # Semantic nearest neighbours from the faiss index.
        results = []
        embedding = self([text])
        k = 10
        D, I = self.knn.search(embedding, k)
        D, I = D.squeeze(), I.squeeze()
        for i, idx in enumerate(I):
            tag = self.index[idx]
            count = self.tag_count[tag]
            results.append([tag, count])
        # Sort k-NN hits by tag count; the distances in D are currently unused.
        results = sorted(results, key=lambda x: x[1], reverse=True)
        # Put up to five prefix matches first, dropping duplicates and trimming
        # the tail so the list does not grow past k entries.
        found = found[:5]
        for result in found:
            if result in results:
                results.remove(result)
        if found:
            results = results[:-len(found)]
        results = found + results
        return results
\ No newline at end of file
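How the new tag searcher behaves end to end, as a hypothetical local test (assumes the Backblaze files above are still reachable and a CUDA device is present):

```python
from types import SimpleNamespace

embedder = EmbedderModel(config=None)   # config is currently unused
req = SimpleNamespace(prompt="blue")    # mimics TextRequest.prompt
for tag, count in embedder.get_top_k(req)[:3]:
    print(tag, count)
# Expected: prefix matches such as "blue ..." tags first, followed by
# semantically similar tags pulled from the k-NN index.
```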
@@ -179,6 +179,9 @@ def sanitize_dalle_mini(request):
def sanitize_basedformer(request):
    return True, request

def sanitize_embedder(request):
    return True, request

def sanitize_input(config, request):
    """
    Sanitize the input data and set defaults
...
@@ -200,3 +203,6 @@ def sanitize_input(config, request):
    elif config.model_name == 'basedformer':
        return sanitize_basedformer(request)
    elif config.model_name == "embedder":
        return sanitize_embedder(request)
\ No newline at end of file
@@ -103,6 +103,9 @@ class GenerationRequest(BaseModel):
class TextRequest(BaseModel):
    prompt: str

class TagOutput(BaseModel):
    # (tag, count) pairs; List[List[str, int]] is not a valid typing form,
    # so use Tuple here (assumes Tuple is imported from typing).
    tags: List[Tuple[str, int]]

class TextOutput(BaseModel):
    is_safe: str
    corrected_text: str
@@ -250,5 +253,33 @@ def generate_text(request: TextRequest, authorized: bool = Depends(verify_token)
            os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))

@app.post('/predict-tags', response_model=Union[TagOutput, ErrorOutput])
async def predict_tags(request: TextRequest, authorized: bool = Depends(verify_token)):
    t = time.perf_counter()
    try:
        output = sanitize_input(config, request)
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        tags = model.get_top_k(request)
        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return TagOutput(tags=tags)
    except Exception as e:
        traceback.print_exc()
        capture_exception(e)
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        # Kill the worker on any CUDA failure so the supervisor restarts it.
        if "CUDA out of memory" in e_s or \
           "an illegal memory access" in e_s or "CUDA" in e_s:
            logger.error("GPU error, committing seppuku.")
            os.kill(mainpid, signal.SIGTERM)
        return ErrorOutput(error=str(e))

if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=80, log_level="info")
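A hypothetical client call against the new endpoint (the auth header format depends on verify_token and is assumed here, as is the localhost address):

```python
import requests

r = requests.post(
    "http://localhost/predict-tags",
    json={"prompt": "blue"},
    headers={"Authorization": "Bearer <token>"},  # assumption: bearer auth
)
print(r.json())  # e.g. {"tags": [["blue ...", 1234567], ...]}
```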