Commit c0bd8b53 authored by kurumuz

infilling maybe works?

parent ce219434
...@@ -68,6 +68,14 @@ def sample_start_noise(seed, C, H, W, f, device="cuda"): ...@@ -68,6 +68,14 @@ def sample_start_noise(seed, C, H, W, f, device="cuda"):
noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0) noise = torch.randn([C, (H) // f, (W) // f], device=device).unsqueeze(0)
return noise return noise
def sample_start_noise_special(seed, request, device="cuda"):
    """Draw the initial latent noise for one sample, sized from ``request``.

    Args:
        seed: Seed applied to the torch and NumPy global RNGs before
            sampling. ``None`` leaves both RNGs untouched.
        request: Object exposing ``latent_channels``, ``height``, ``width``
            and ``downsampling_factor``.
        device: Device on which the noise tensor is created.

    Returns:
        A tensor of shape ``(1, latent_channels,
        height // downsampling_factor, width // downsampling_factor)``.
    """
    # Seed from the ``seed`` argument rather than ``request.seed``: callers
    # pass a different seed per sample and per mask region, and ignoring it
    # would make every draw return the same noise.
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
    factor = request.downsampling_factor
    latent_shape = [
        request.latent_channels,
        request.height // factor,
        request.width // factor,
    ]
    # unsqueeze(0) adds the leading batch dimension of size 1.
    return torch.randn(latent_shape, device=device).unsqueeze(0)
@torch.no_grad() @torch.no_grad()
#@torch.autocast("cuda", enabled=True, dtype=torch.float16) #@torch.autocast("cuda", enabled=True, dtype=torch.float16)
def encode_image(image, model): def encode_image(image, model):
...@@ -259,8 +267,24 @@ class StableDiffusionModel(nn.Module): ...@@ -259,8 +267,24 @@ class StableDiffusionModel(nn.Module):
if request.image is None: if request.image is None:
main_noise = [] main_noise = []
for seed in range(request.seed, request.seed+request.n_samples): for seed_offset in range(request.n_samples):
main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device)) noise_x = sample_start_noise_special(request.seed+seed_offset, request, self.device)
if request.masks is not None:
for maskobj in request.masks:
mask_seed = maskobj["seed"]
mask = maskobj["mask"]
mask = np.asarray(mask)
mask = torch.from_numpy(mask).clone().to(self.device)
mask = mask.float() / 255.0
# convert RGB or grayscale image into 4-channel
mask = mask[0]
mask = torch.repeat_interleave(mask, request.latent_channels, dim=0).unsqueeze(0)
mask = (mask > 0.5).float()
# interpolate start noise
noise_x = (noise_x * (1-mask)) + (sample_start_noise_special(mask_seed+seed_offset, request, self.device) * mask)
main_noise.append(noise_x)
main_noise = torch.cat(main_noise, dim=0) main_noise = torch.cat(main_noise, dim=0)
start_code = main_noise start_code = main_noise
......
...@@ -136,7 +136,6 @@ def sanitize_stable_diffusion(request, config): ...@@ -136,7 +136,6 @@ def sanitize_stable_diffusion(request, config):
if request.masks is not None: if request.masks is not None:
masks = request.masks masks = request.masks
images = []
for x in range(len(masks)): for x in range(len(masks)):
image = masks[x]["mask"] image = masks[x]["mask"]
try: try:
...@@ -160,6 +159,7 @@ def sanitize_stable_diffusion(request, config): ...@@ -160,6 +159,7 @@ def sanitize_stable_diffusion(request, config):
image = Image.open(BytesIO(image)) image = Image.open(BytesIO(image))
#image = image.convert('RGB') #image = image.convert('RGB')
image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS) image = image.resize((request.width//request.downsampling_factor, request.height//request.downsampling_factor), resample=Image.Resampling.LANCZOS)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
return False, "Error while opening and cleaning image" return False, "Error while opening and cleaning image"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment