Commit 56787cb0 authored by kurumuz's avatar kurumuz

always use ddim for stage 2

parent c7f34c27
......@@ -217,7 +217,7 @@ class StableDiffusionModel(nn.Module):
with torch.autocast("cuda", enabled=self.config.amp):
with self.model.ema_scope():
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
sampler.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
t_enc = int(request.strength * request.steps)
print("init latent shape:")
......@@ -233,9 +233,9 @@ class StableDiffusionModel(nn.Module):
# encode (scaled latent)
start_code_terped=None
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
# decode it
samples = sampler.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
unconditional_conditioning=uc,)
x_samples_ddim = self.model.decode_first_stage(samples)
......
......@@ -111,10 +111,6 @@ def sanitize_input(config, request):
if config.model_name == 'stable-diffusion':
return sanitize_stable_diffusion(request)
elif config.model_name == 'dalle-mini':
return sanitize_dalle_mini(request)
\ No newline at end of file
return sanitize_dalle_mini(request)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment