Commit 2abc4178 authored by CodeHatchling's avatar CodeHatchling

Re-implemented soft inpainting via a script. Also fixed some mistakes with the...

Re-implemented soft inpainting via a script. Also fixed some mistakes with the previous hooks, removed unnecessary formatting changes, removed code that I had forgotten to.
parent ac457891
......@@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim)
p.scripts.post_sample(p, ps)
samples_ddim = pp.samples
samples_ddim = ps.samples
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
......@@ -944,7 +943,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
p.scripts.postprocess_maskoverlay(p, ppmo)
mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image
mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image
if p.color_corrections is not None and i < len(p.color_corrections):
if save_samples and opts.save_images_before_color_correction:
......@@ -959,7 +958,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
original_denoised_image = image.copy()
if p.paste_to is not None:
original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to)
original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)
image = apply_overlay(image, p.paste_to, overlay_image)
......@@ -1512,9 +1511,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
if self.masks_for_overlay is not None:
self.masks_for_overlay = self.masks_for_overlay * self.batch_size
if self.color_corrections is not None and len(self.color_corrections) == 1:
self.color_corrections = self.color_corrections * self.batch_size
......@@ -1565,10 +1561,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
blended_samples = samples * self.nmask + self.init_latent * self.mask
if self.scripts is not None:
mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True)
mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
self.scripts.on_mask_blend(self, mba)
blended_samples = mba.blended_latent
......
......@@ -12,12 +12,12 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()
class MaskBlendArgs:
def __init__(self, current_latent, nmask, init_latent, mask, blended_samples, denoiser=None, sigma=None):
def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
self.current_latent = current_latent
self.nmask = nmask
self.init_latent = init_latent
self.mask = mask
self.blended_samples = blended_samples
self.blended_latent = blended_latent
self.denoiser = denoiser
self.is_final_blend = denoiser is None
......
import gradio as gr
from modules.ui_components import InputAccordion
import modules.scripts as scripts
class SoftInpaintingSettings:
def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation):
self.mask_blend_power = mask_blend_power
......@@ -46,8 +51,10 @@ def latent_blend(soft_inpainting, a, b, t):
current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3
b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
soft_inpainting.inpaint_detail_preservation) * one_minus_t3
b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
soft_inpainting.inpaint_detail_preservation) * t3
desired_magnitude = a_magnitude
desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
del a_magnitude, b_magnitude, t3, one_minus_t3
......@@ -84,15 +91,13 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
NOTE: "mask" is not used
"""
import torch
# todo: Why is sigma 2D? Both values are the same.
return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
def apply_adaptive_masks(
latent_orig,
latent_processed,
overlay_images,
masks_for_overlay,
width, height,
paste_to):
import torch
......@@ -112,6 +117,8 @@ def apply_adaptive_masks(
kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
masks_for_overlay = []
for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
converted_mask = distance_map.float().cpu().numpy()
converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
......@@ -137,11 +144,11 @@ def apply_adaptive_masks(
# Expand the mask to fit the whole image if needed.
if paste_to is not None:
converted_mask = proc. uncrop(converted_mask,
converted_mask = proc.uncrop(converted_mask,
(overlay_image.width, overlay_image.height),
paste_to)
masks_for_overlay[i] = converted_mask
masks_for_overlay.append(converted_mask)
image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
......@@ -149,11 +156,13 @@ def apply_adaptive_masks(
overlay_images[i] = image_masked.convert('RGBA')
return masks_for_overlay
def apply_masks(
soft_inpainting,
nmask,
overlay_images,
masks_for_overlay,
width, height,
paste_to):
import torch
......@@ -179,6 +188,8 @@ def apply_masks(
(width, height),
paste_to)
masks_for_overlay = []
for i, overlay_image in enumerate(overlay_images):
masks_for_overlay[i] = converted_mask
......@@ -188,6 +199,8 @@ def apply_masks(
overlay_images[i] = image_masked.convert('RGBA')
return masks_for_overlay
# ------------------- Constants -------------------
......@@ -219,12 +232,21 @@ el_ids = SoftInpaintingSettings(
"inpaint_detail_preservation")
# ------------------- UI -------------------
class Script(scripts.Script):
def __init__(self):
self.masks_for_overlay = None
self.overlay_images = None
def title(self):
return "Soft Inpainting"
def gradio_ui():
import gradio as gr
from modules.ui_components import InputAccordion
def show(self, is_img2img):
return scripts.AlwaysVisible if is_img2img else False
def ui(self, is_img2img):
if not is_img2img:
return
with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
with gr.Group():
......@@ -292,17 +314,88 @@ def gradio_ui():
- **High values**: Stronger contrast, may over-saturate colors.
""")
return (
[
soft_inpainting_enabled,
result.mask_blend_power,
result.mask_blend_scale,
result.inpaint_detail_preservation
],
[
(soft_inpainting_enabled, enabled_gen_param_label),
self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
(result.mask_blend_power, gen_param_labels.mask_blend_power),
(result.mask_blend_scale, gen_param_labels.mask_blend_scale),
(result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)
]
)
(result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)]
self.paste_field_names = []
for _, field_name in self.infotext_fields:
self.paste_field_names.append(field_name)
return [soft_inpainting_enabled,
result.mask_blend_power,
result.mask_blend_scale,
result.inpaint_detail_preservation]
def process(self, p, enabled, power, scale, detail_preservation):
if not enabled:
return
# Shut off the rounding it normally does.
p.mask_round = False
settings = SoftInpaintingSettings(power, scale, detail_preservation)
# p.extra_generation_params["Mask rounding"] = False
settings.add_generation_params(p.extra_generation_params)
def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation):
if not enabled:
return
if mba.sigma is None:
mba.blended_latent = mba.current_latent
return
settings = SoftInpaintingSettings(power, scale, detail_preservation)
# todo: Why is sigma 2D? Both values are the same.
mba.blended_latent = latent_blend(settings,
mba.init_latent,
mba.current_latent,
get_modified_nmask(settings, mba.nmask, mba.sigma[0]))
def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation):
if not enabled:
return
settings = SoftInpaintingSettings(power, scale, detail_preservation)
from modules import images
from modules.shared import opts
# since the original code puts holes in the existing overlay images,
# we have to rebuild them.
self.overlay_images = []
for img in p.init_images:
image = images.flatten(img, opts.img2img_background_color)
if p.paste_to is None and p.resize_mode != 3:
image = images.resize_image(p.resize_mode, image, p.width, p.height)
self.overlay_images.append(image.convert('RGBA'))
if getattr(ps.samples, 'already_decoded', False):
self.masks_for_overlay = apply_masks(soft_inpainting=settings,
nmask=p.nmask,
overlay_images=self.overlay_images,
width=p.width,
height=p.height,
paste_to=p.paste_to)
else:
self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent,
latent_processed=ps.samples,
overlay_images=self.overlay_images,
width=p.width,
height=p.height,
paste_to=p.paste_to)
def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation):
if not enabled:
return
ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]
ppmo.overlay_image = self.overlay_images[ppmo.index]
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment