Commit 84cfe3b5 authored by novelailab's avatar novelailab

prompting should work

parent c6954d2d
...@@ -34,6 +34,7 @@ def fix_batch(tensor, bs): ...@@ -34,6 +34,7 @@ def fix_batch(tensor, bs):
return torch.stack([tensor.squeeze(0)]*bs, dim=0) return torch.stack([tensor.squeeze(0)]*bs, dim=0)
# mix conditioning vectors for prompts # mix conditioning vectors for prompts
# @aero
def prompt_mixing(model, prompt_body, batch_size): def prompt_mixing(model, prompt_body, batch_size):
if "|" in prompt_body: if "|" in prompt_body:
prompt_parts = prompt_body.split("|") prompt_parts = prompt_body.split("|")
...@@ -183,7 +184,7 @@ class StableDiffusionModel(nn.Module): ...@@ -183,7 +184,7 @@ class StableDiffusionModel(nn.Module):
], device=self.device) ], device=self.device)
prompt = [request.prompt] * request.n_samples prompt = [request.prompt] * request.n_samples
prompt_condition = self.model.get_learned_conditioning(prompt) prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
uc = None uc = None
if request.scale != 1.0: if request.scale != 1.0:
...@@ -227,13 +228,13 @@ class StableDiffusionModel(nn.Module): ...@@ -227,13 +228,13 @@ class StableDiffusionModel(nn.Module):
if request.scale != 1.0: if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""]) uc = self.model.get_learned_conditioning(request.n_samples * [""])
c = prompt_mixing(self.model, prompt[0], request.n_samples)#(model.get_learned_conditioning(prompts) + model.get_learned_conditioning(["taken at night"])) / 2 prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
# encode (scaled latent) # encode (scaled latent)
start_code_terped=None start_code_terped=None
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
# decode it # decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=request.scale, samples = sampler.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,) unconditional_conditioning=uc,)
x_samples_ddim = self.model.decode_first_stage(samples) x_samples_ddim = self.model.decode_first_stage(samples)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment