Commit 49bf6828 authored by kurumuz's avatar kurumuz

fix batch img2img

parent 16890c43
...@@ -211,11 +211,12 @@ class StableDiffusionModel(nn.Module): ...@@ -211,11 +211,12 @@ class StableDiffusionModel(nn.Module):
if request.sampler == "plms": if request.sampler == "plms":
request.sampler = "k_lms" request.sampler = "k_lms"
if request.sampler == "ddim": if request.sampler == "ddim":
request.sampler = "ddim_img2img" request.sampler = "k_lms"
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False) self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
start_code = encode_image(request.image, self.model.first_stage_model).to(self.device) start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
start_code = self.model.get_first_stage_encoding(start_code) start_code = self.model.get_first_stage_encoding(start_code)
start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0)
start_code = start_code + (torch.randn_like(start_code) * request.noise) start_code = start_code + (torch.randn_like(start_code) * request.noise)
t_enc = int(request.strength * request.steps) t_enc = int(request.strength * request.steps)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment