Commit 8c32594d authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #14208 from CodeHatchling/soft-inpainting

Soft Inpainting
parents f3cc5f83 f1ff932c
...@@ -791,3 +791,4 @@ def flatten(img, bgcolor): ...@@ -791,3 +791,4 @@ def flatten(img, bgcolor):
img = background img = background
return img.convert('RGB') return img.convert('RGB')
This diff is collapsed.
...@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading, ...@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object() AlwaysVisible = object()
class MaskBlendArgs:
def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
self.current_latent = current_latent
self.nmask = nmask
self.init_latent = init_latent
self.mask = mask
self.blended_latent = blended_latent
self.denoiser = denoiser
self.is_final_blend = denoiser is None
self.sigma = sigma
class PostSampleArgs:
def __init__(self, samples):
self.samples = samples
class PostprocessImageArgs: class PostprocessImageArgs:
def __init__(self, image): def __init__(self, image):
self.image = image self.image = image
class PostProcessMaskOverlayArgs:
def __init__(self, index, mask_for_overlay, overlay_image):
self.index = index
self.mask_for_overlay = mask_for_overlay
self.overlay_image = overlay_image
class PostprocessBatchListArgs: class PostprocessBatchListArgs:
def __init__(self, images): def __init__(self, images):
...@@ -206,6 +226,25 @@ class Script: ...@@ -206,6 +226,25 @@ class Script:
pass pass
def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
"""
Called in inpainting mode when the original content is blended with the inpainted content.
This is called at every step in the denoising process and once at the end.
If is_final_blend is true, this is called for the final blending stage.
Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
"""
pass
def post_sample(self, p, ps: PostSampleArgs, *args):
"""
Called after the samples have been generated,
but before they have been decoded by the VAE, if applicable.
Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
"""
pass
def postprocess_image(self, p, pp: PostprocessImageArgs, *args): def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
""" """
Called for every image after it has been generated. Called for every image after it has been generated.
...@@ -213,6 +252,13 @@ class Script: ...@@ -213,6 +252,13 @@ class Script:
pass pass
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
"""
Called for every image after it has been generated.
"""
pass
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
""" """
This function is called after processing ends for AlwaysVisible scripts. This function is called after processing ends for AlwaysVisible scripts.
...@@ -767,6 +813,22 @@ class ScriptRunner: ...@@ -767,6 +813,22 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
def post_sample(self, p, ps: PostSampleArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.post_sample(p, ps, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def on_mask_blend(self, p, mba: MaskBlendArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.on_mask_blend(p, mba, *script_args)
except Exception:
errors.report(f"Error running post_sample: {script.filename}", exc_info=True)
def postprocess_image(self, p, pp: PostprocessImageArgs): def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try:
...@@ -775,6 +837,14 @@ class ScriptRunner: ...@@ -775,6 +837,14 @@ class ScriptRunner:
except Exception: except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_maskoverlay(p, ppmo, *script_args)
except Exception:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs): def before_component(self, component, **kwargs):
for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
try: try:
......
...@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module): ...@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler self.sampler = sampler
self.model_wrap = None self.model_wrap = None
self.p = None self.p = None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
self.mask_before_denoising = False self.mask_before_denoising = False
@property @property
...@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module): ...@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
# If we use masks, blending between the denoised and original latent images occurs here.
def apply_blend(current_latent):
blended_latent = current_latent * self.nmask + self.init_latent * self.mask
if self.p.scripts is not None:
from modules import scripts
mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
self.p.scripts.on_mask_blend(self.p, mba)
blended_latent = mba.blended_latent
return blended_latent
# Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None: if self.mask_before_denoising and self.mask is not None:
x = self.init_latent * self.mask + self.nmask * x x = apply_blend(x)
batch_size = len(conds_list) batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)] repeats = [len(conds_list[i]) for i in range(batch_size)]
...@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module): ...@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module):
else: else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
# Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None: if not self.mask_before_denoising and self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised denoised = apply_blend(denoised)
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment