Commit 44328813 authored by kurumuz's avatar kurumuz

fix seeding maybe

parent cd1aeb01
...@@ -17,6 +17,7 @@ import time ...@@ -17,6 +17,7 @@ import time
from PIL import Image from PIL import Image
import k_diffusion as K import k_diffusion as K
import contextlib import contextlib
import random
def pil_upscale(image, scale=1): def pil_upscale(image, scale=1):
device = image.device device = image.device
...@@ -162,6 +163,11 @@ class StableInterface(nn.Module): ...@@ -162,6 +163,11 @@ class StableInterface(nn.Module):
return x_0 return x_0
def seed_everything(seed: int):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
class StableDiffusionModel(nn.Module): class StableDiffusionModel(nn.Module):
def __init__(self, config): def __init__(self, config):
nn.Module.__init__(self) nn.Module.__init__(self)
...@@ -282,8 +288,7 @@ class StableDiffusionModel(nn.Module): ...@@ -282,8 +288,7 @@ class StableDiffusionModel(nn.Module):
CrossAttention.set_hypernetwork(module) CrossAttention.set_hypernetwork(module)
if request.seed is not None: if request.seed is not None:
torch.manual_seed(request.seed) seed_everything(request.seed)
np.random.seed(request.seed)
if request.image is not None: if request.image is not None:
request.steps = 50 request.steps = 50
...@@ -302,7 +307,7 @@ class StableDiffusionModel(nn.Module): ...@@ -302,7 +307,7 @@ class StableDiffusionModel(nn.Module):
start_noise = [] start_noise = []
for seed in range(request.seed, request.seed+request.n_samples): for seed in range(request.seed, request.seed+request.n_samples):
main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device)) main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
start_noise.append(sample_start_noise(None, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device)) start_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
main_noise = torch.cat(main_noise, dim=0) main_noise = torch.cat(main_noise, dim=0)
start_noise = torch.cat(start_noise, dim=0) start_noise = torch.cat(start_noise, dim=0)
...@@ -424,8 +429,7 @@ class StableDiffusionModel(nn.Module): ...@@ -424,8 +429,7 @@ class StableDiffusionModel(nn.Module):
def sample_two_stages(self, request): def sample_two_stages(self, request):
request = DotMap(request) request = DotMap(request)
if request.seed is not None: if request.seed is not None:
torch.manual_seed(request.seed) seed_everything(request.seed)
np.random.seed(request.seed)
if request.plms: if request.plms:
sampler = self.plms sampler = self.plms
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment