Commit eb112c6f authored by AUTOMATIC1111, committed by GitHub

Merge pull request #16035 from v0xie/cfgpp

Add new sampler DDIM CFG++
parents ace00a1f 663a4d80
@@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
         self.model_wrap = None
         self.p = None
+        self.last_noise_uncond = None
+
         # NOTE: masking before denoising can cause the original latents to be oversmoothed
         # as the original latents do not have noise
         self.mask_before_denoising = False
@@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
         # so is_edit_model is set to False to support AND composition.
         is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+        is_cfg_pp = 'CFG++' in self.sampler.config.name
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
@@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
+        if is_cfg_pp:
+            self.last_noise_uncond = x_out[-uncond.shape[0]:]
+            self.last_noise_uncond = torch.clone(self.last_noise_uncond)
+
         if is_edit_model:
             denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
         elif skip_uncond:
             denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+        elif is_cfg_pp:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale / 12.5)  # map the UI CFG scale (1.0, 12.5) onto the CFG++ scale (0, 1)
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
...
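For context, the guidance blend itself is unchanged: `combine_denoised` still forms `uncond + weight * (cond - uncond)`; the CFG++ branch only rescales the weight by 1/12.5 and stashes the unconditional prediction for the sampler's direction term. A minimal standalone sketch of that blend (simplified to a single cond/uncond pair; this is not webui's actual batched `combine_denoised`):

```python
import torch

def cfgpp_blend(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Map the UI CFG scale from [0.0, 12.5] onto the CFG++ weight in [0, 1].
    lam = cfg_scale / 12.5
    # Same linear blend as classic CFG, just with the rescaled weight.
    return eps_uncond + lam * (eps_cond - eps_uncond)

# e.g. a UI CFG scale of 7.0 becomes a CFG++ weight of 0.56
eps_c, eps_u = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
guided = cfgpp_blend(eps_c, eps_u, cfg_scale=7.0)
```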
@@ -10,6 +10,7 @@ import modules.shared as shared
 samplers_timesteps = [
     ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
+    ('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
     ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
     ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
 ]
...
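Once registered here, the sampler is selectable by its label (or the `ddim_cfgpp` alias) anywhere webui accepts a sampler name. A hypothetical request against a stock local install started with `--api`, using the standard txt2img endpoint and payload fields:

```python
import requests

payload = {
    "prompt": "a photo of a cat",
    "steps": 20,
    "cfg_scale": 7.0,  # divided by 12.5 internally when a CFG++ sampler is active
    "sampler_name": "DDIM CFG++",
}
# Assumes webui is running locally on the default port with the API enabled.
resp = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
resp.raise_for_status()
```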
@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     return x

+@torch.no_grad()
+def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+    """ Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
+    Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
+    The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
+    """
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas = alphas_cumprod[timesteps]
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
+    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones((x.shape[0]))
+    s_x = x.new_ones((x.shape[0], 1, 1, 1))
+    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+        index = len(timesteps) - 1 - i
+
+        e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+        last_noise_uncond = model.last_noise_uncond
+
+        a_t = alphas[index].item() * s_x
+        a_prev = alphas_prev[index].item() * s_x
+        sigma_t = sigmas[index].item() * s_x
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
+
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
+        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+        x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+    return x
+
 @torch.no_grad()
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
...
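To make the update rule explicit: at each step, x0 is predicted from the guided noise `e_t`, but the direction pointing back toward x_{t-1} uses the unconditional noise, and that substitution is the whole CFG++ change relative to plain DDIM. A self-contained single-step sketch under those assumptions (scalar alpha values, deterministic with the default `sigma_t=0.0`; names are illustrative, not webui's):

```python
import torch

@torch.no_grad()
def ddim_cfgpp_step(x, eps_guided, eps_uncond, a_t, a_prev, sigma_t=0.0):
    # Predict x0 from the *guided* noise, exactly as in plain DDIM.
    pred_x0 = (x - (1 - a_t) ** 0.5 * eps_guided) / a_t ** 0.5
    # CFG++: the direction to x_{t-1} reuses the *unconditional* noise
    # (plain DDIM would use eps_guided here too).
    dir_xt = (1.0 - a_prev - sigma_t ** 2) ** 0.5 * eps_uncond
    noise = sigma_t * torch.randn_like(x) if sigma_t > 0 else 0.0
    return a_prev ** 0.5 * pred_x0 + dir_xt + noise
```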