Commit ad4eaba6 authored by kurumuz

n, n+1... seeding of noise

parent 85d64b84
...@@ -59,6 +59,14 @@ def prompt_mixing(model, prompt_body, batch_size): ...@@ -59,6 +59,14 @@ def prompt_mixing(model, prompt_body, batch_size):
else: else:
return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size) return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)
def sample_start_noise(seed, C, H, W, f, device="cuda"):
    """Sample one batch entry of standard-normal noise at latent resolution.

    Args:
        seed: Integer seed, or ``None`` to draw from the current global RNG
            state without reseeding. ``0`` is a valid seed.
        C: Number of channels of the noise tensor.
        H: Height before downsampling; the tensor height is ``H // f``.
        W: Width before downsampling; the tensor width is ``W // f``.
        f: Downsampling factor applied to ``H`` and ``W``.
        device: Device on which the noise tensor is created.

    Returns:
        A tensor of shape ``(1, C, H // f, W // f)``.

    Side effects:
        When ``seed`` is not ``None``, reseeds the global torch and NumPy
        random number generators.
    """
    # Compare against None explicitly: a plain truthiness test would treat
    # seed 0 as "no seed" and silently produce non-reproducible noise.
    if seed is not None:
        torch.manual_seed(seed)
        # np.random.seed only accepts values in [0, 2**32 - 1]; callers derive
        # seeds as base + index, which can step past that bound.
        np.random.seed(seed % 2**32)
    noise = torch.randn([C, H // f, W // f], device=device).unsqueeze(0)
    return noise
@torch.no_grad() @torch.no_grad()
#@torch.autocast("cuda", enabled=True, dtype=torch.float16) #@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def encode_image(image, model): def encode_image(image, model):
...@@ -217,7 +225,17 @@ class StableDiffusionModel(nn.Module): ...@@ -217,7 +225,17 @@ class StableDiffusionModel(nn.Module):
start_code = encode_image(request.image, self.model.first_stage_model).to(self.device) start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
start_code = self.model.get_first_stage_encoding(start_code) start_code = self.model.get_first_stage_encoding(start_code)
start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0) start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0)
start_code = start_code + (torch.randn_like(start_code) * request.noise)
main_noise = []
start_noise = []
for seed in range(request.seed, request.seed+request.n_samples):
main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
start_noise.append(sample_start_noise(None, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
main_noise = torch.cat(main_noise, dim=0)
start_noise = torch.cat(start_noise, dim=0)
start_code = start_code + (start_noise * request.noise)
t_enc = int(request.strength * request.steps) t_enc = int(request.strength * request.steps)
if request.sampler.startswith("k_"): if request.sampler.startswith("k_"):
...@@ -230,14 +248,12 @@ class StableDiffusionModel(nn.Module): ...@@ -230,14 +248,12 @@ class StableDiffusionModel(nn.Module):
sampler = "normal" sampler = "normal"
if request.image is None: if request.image is None:
start_code = None main_noise = []
if request.fixed_code or sampler == "k-diffusion": for seed in range(request.seed, request.seed+request.n_samples):
start_code = torch.randn([ main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
request.n_samples,
request.latent_channels, main_noise = torch.cat(main_noise, dim=0)
request.height // request.downsampling_factor, start_code = main_noise
request.width // request.downsampling_factor,
], device=self.device)
prompt = [request.prompt] * request.n_samples prompt = [request.prompt] * request.n_samples
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples) prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
...@@ -268,16 +284,11 @@ class StableDiffusionModel(nn.Module): ...@@ -268,16 +284,11 @@ class StableDiffusionModel(nn.Module):
x_T=start_code, x_T=start_code,
) )
elif sampler == 'img2img':
with self.model.ema_scope():
start_code = self.ddim.stochastic_encode(start_code, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=None)
samples = self.ddim.decode(start_code, prompt_condition, t_enc, unconditional_guidance_scale=request.scale, unconditional_conditioning=uc)
elif sampler == "k-diffusion": elif sampler == "k-diffusion":
with self.model.ema_scope(): with self.model.ema_scope():
sigmas = self.k_model.get_sigmas(request.steps) sigmas = self.k_model.get_sigmas(request.steps)
if request.image is not None: if request.image is not None:
noise = torch.randn_like(start_code) * sigmas[request.steps - t_enc - 1] noise = main_noise * sigmas[request.steps - t_enc - 1]
start_code = start_code + noise start_code = start_code + noise
sigmas = sigmas[request.steps - t_enc - 1:] sigmas = sigmas[request.steps - t_enc - 1:]
......
...@@ -3,7 +3,7 @@ from dotmap import DotMap ...@@ -3,7 +3,7 @@ from dotmap import DotMap
import math import math
from io import BytesIO from io import BytesIO
import base64 import base64
traceback import random
v1pp_defaults = { v1pp_defaults = {
'steps': 50, 'steps': 50,
...@@ -96,6 +96,11 @@ def sanitize_stable_diffusion(request): ...@@ -96,6 +96,11 @@ def sanitize_stable_diffusion(request):
if request.sampler not in samplers: if request.sampler not in samplers:
return False, "sampler should be one of {}".format(samplers) return False, "sampler should be one of {}".format(samplers)
if request.seed is None:
state = random.getstate
request.seed = random.randint(0, 2**32)
random.setstate(state)
if request.image is not None: if request.image is not None:
#decode from base64 #decode from base64
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment