Commit 2abc4178 authored by CodeHatchling's avatar CodeHatchling

Re-implemented soft inpainting via a script. Also fixed some mistakes with the...

Re-implemented soft inpainting via a script. Also fixed some mistakes with the previous hooks, removed unnecessary formatting changes, removed code that I had forgotten to.
parent ac457891
...@@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim) ps = scripts.PostSampleArgs(samples_ddim)
p.scripts.post_sample(p, ps) p.scripts.post_sample(p, ps)
samples_ddim = pp.samples samples_ddim = ps.samples
if getattr(samples_ddim, 'already_decoded', False): if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim x_samples_ddim = samples_ddim
else: else:
if opts.sd_vae_decode_method != 'Full': if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.stack(x_samples_ddim).float()
...@@ -944,7 +943,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -944,7 +943,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
p.scripts.postprocess_maskoverlay(p, ppmo) p.scripts.postprocess_maskoverlay(p, ppmo)
mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
if p.color_corrections is not None and i < len(p.color_corrections): if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction: if save_samples and opts.save_images_before_color_correction:
...@@ -959,7 +958,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -959,7 +958,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
original_denoised_image = image.copy() original_denoised_image = image.copy()
if p.paste_to is not None: if p.paste_to is not None:
original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to) original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
image = apply_overlay(image, p.paste_to, overlay_image) image = apply_overlay(image, p.paste_to, overlay_image)
...@@ -1512,9 +1511,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -1512,9 +1511,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.overlay_images is not None: if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size self.overlay_images = self.overlay_images * self.batch_size
if self.masks_for_overlay is not None:
self.masks_for_overlay = self.masks_for_overlay * self.batch_size
if self.color_corrections is not None and len(self.color_corrections) == 1: if self.color_corrections is not None and len(self.color_corrections) == 1:
self.color_corrections = self.color_corrections * self.batch_size self.color_corrections = self.color_corrections * self.batch_size
...@@ -1565,14 +1561,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -1565,14 +1561,15 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
blended_samples = samples * self.nmask + self.init_latent * self.mask if self.mask is not None:
blended_samples = samples * self.nmask + self.init_latent * self.mask
if self.scripts is not None: if self.scripts is not None:
mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True) mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
self.scripts.on_mask_blend(self, mba) self.scripts.on_mask_blend(self, mba)
blended_samples = mba.blended_latent blended_samples = mba.blended_latent
samples = blended_samples samples = blended_samples
del x del x
devices.torch_gc() devices.torch_gc()
......
...@@ -12,12 +12,12 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, ...@@ -12,12 +12,12 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object() AlwaysVisible = object()
class MaskBlendArgs: class MaskBlendArgs:
def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None): def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
self.current_latent = current_latent self.current_latent = current_latent
self.nmask = nmask self.nmask = nmask
self.init_latent = init_latent self.init_latent = init_latent
self.mask = mask self.mask = mask
self.blended_samples = blended_samples self.blended_latent = blended_latent
self.denoiser = denoiser self.denoiser = denoiser
self.is_final_blend = denoiser is None self.is_final_blend = denoiser is None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment