Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
eb112c6f
Commit
eb112c6f
authored
Jul 06, 2024
by
AUTOMATIC1111
Committed by
GitHub
Jul 06, 2024
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #16035 from v0xie/cfgpp
Add new sampler DDIM CFG++
parents
ace00a1f
663a4d80
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
0 deletions
+48
-0
modules/sd_samplers_cfg_denoiser.py
modules/sd_samplers_cfg_denoiser.py
+10
-0
modules/sd_samplers_timesteps.py
modules/sd_samplers_timesteps.py
+1
-0
modules/sd_samplers_timesteps_impl.py
modules/sd_samplers_timesteps_impl.py
+37
-0
No files found.
modules/sd_samplers_cfg_denoiser.py
View file @
eb112c6f
...
@@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
...
@@ -58,6 +58,8 @@ class CFGDenoiser(torch.nn.Module):
self
.
model_wrap
=
None
self
.
model_wrap
=
None
self
.
p
=
None
self
.
p
=
None
self
.
last_noise_uncond
=
None
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# NOTE: masking before denoising can cause the original latents to be oversmoothed
# as the original latents do not have noise
# as the original latents do not have noise
self
.
mask_before_denoising
=
False
self
.
mask_before_denoising
=
False
...
@@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
...
@@ -160,6 +162,8 @@ class CFGDenoiser(torch.nn.Module):
# so is_edit_model is set to False to support AND composition.
# so is_edit_model is set to False to support AND composition.
is_edit_model
=
shared
.
sd_model
.
cond_stage_key
==
"edit"
and
self
.
image_cfg_scale
is
not
None
and
self
.
image_cfg_scale
!=
1.0
is_edit_model
=
shared
.
sd_model
.
cond_stage_key
==
"edit"
and
self
.
image_cfg_scale
is
not
None
and
self
.
image_cfg_scale
!=
1.0
is_cfg_pp
=
'CFG++'
in
self
.
sampler
.
config
.
name
conds_list
,
tensor
=
prompt_parser
.
reconstruct_multicond_batch
(
cond
,
self
.
step
)
conds_list
,
tensor
=
prompt_parser
.
reconstruct_multicond_batch
(
cond
,
self
.
step
)
uncond
=
prompt_parser
.
reconstruct_cond_batch
(
uncond
,
self
.
step
)
uncond
=
prompt_parser
.
reconstruct_cond_batch
(
uncond
,
self
.
step
)
...
@@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
...
@@ -273,10 +277,16 @@ class CFGDenoiser(torch.nn.Module):
denoised_params
=
CFGDenoisedParams
(
x_out
,
state
.
sampling_step
,
state
.
sampling_steps
,
self
.
inner_model
)
denoised_params
=
CFGDenoisedParams
(
x_out
,
state
.
sampling_step
,
state
.
sampling_steps
,
self
.
inner_model
)
cfg_denoised_callback
(
denoised_params
)
cfg_denoised_callback
(
denoised_params
)
if
is_cfg_pp
:
self
.
last_noise_uncond
=
x_out
[
-
uncond
.
shape
[
0
]:]
self
.
last_noise_uncond
=
torch
.
clone
(
self
.
last_noise_uncond
)
if
is_edit_model
:
if
is_edit_model
:
denoised
=
self
.
combine_denoised_for_edit_model
(
x_out
,
cond_scale
)
denoised
=
self
.
combine_denoised_for_edit_model
(
x_out
,
cond_scale
)
elif
skip_uncond
:
elif
skip_uncond
:
denoised
=
self
.
combine_denoised
(
x_out
,
conds_list
,
uncond
,
1.0
)
denoised
=
self
.
combine_denoised
(
x_out
,
conds_list
,
uncond
,
1.0
)
elif
is_cfg_pp
:
denoised
=
self
.
combine_denoised
(
x_out
,
conds_list
,
uncond
,
cond_scale
/
12.5
)
# CFG++ scale of (0, 1) maps to (1.0, 12.5)
else
:
else
:
denoised
=
self
.
combine_denoised
(
x_out
,
conds_list
,
uncond
,
cond_scale
)
denoised
=
self
.
combine_denoised
(
x_out
,
conds_list
,
uncond
,
cond_scale
)
...
...
modules/sd_samplers_timesteps.py
View file @
eb112c6f
...
@@ -10,6 +10,7 @@ import modules.shared as shared
...
@@ -10,6 +10,7 @@ import modules.shared as shared
samplers_timesteps
=
[
samplers_timesteps
=
[
(
'DDIM'
,
sd_samplers_timesteps_impl
.
ddim
,
[
'ddim'
],
{}),
(
'DDIM'
,
sd_samplers_timesteps_impl
.
ddim
,
[
'ddim'
],
{}),
(
'DDIM CFG++'
,
sd_samplers_timesteps_impl
.
ddim_cfgpp
,
[
'ddim_cfgpp'
],
{}),
(
'PLMS'
,
sd_samplers_timesteps_impl
.
plms
,
[
'plms'
],
{}),
(
'PLMS'
,
sd_samplers_timesteps_impl
.
plms
,
[
'plms'
],
{}),
(
'UniPC'
,
sd_samplers_timesteps_impl
.
unipc
,
[
'unipc'
],
{}),
(
'UniPC'
,
sd_samplers_timesteps_impl
.
unipc
,
[
'unipc'
],
{}),
]
]
...
...
modules/sd_samplers_timesteps_impl.py
View file @
eb112c6f
...
@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
...
@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
return
x
return
x
@
torch
.
no_grad
()
def
ddim_cfgpp
(
model
,
x
,
timesteps
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
,
eta
=
0.0
):
""" Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).
Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
"""
alphas_cumprod
=
model
.
inner_model
.
inner_model
.
alphas_cumprod
alphas
=
alphas_cumprod
[
timesteps
]
alphas_prev
=
alphas_cumprod
[
torch
.
nn
.
functional
.
pad
(
timesteps
[:
-
1
],
pad
=
(
1
,
0
))]
.
to
(
float64
(
x
))
sqrt_one_minus_alphas
=
torch
.
sqrt
(
1
-
alphas
)
sigmas
=
eta
*
np
.
sqrt
((
1
-
alphas_prev
.
cpu
()
.
numpy
())
/
(
1
-
alphas
.
cpu
())
*
(
1
-
alphas
.
cpu
()
/
alphas_prev
.
cpu
()
.
numpy
()))
extra_args
=
{}
if
extra_args
is
None
else
extra_args
s_in
=
x
.
new_ones
((
x
.
shape
[
0
]))
s_x
=
x
.
new_ones
((
x
.
shape
[
0
],
1
,
1
,
1
))
for
i
in
tqdm
.
trange
(
len
(
timesteps
)
-
1
,
disable
=
disable
):
index
=
len
(
timesteps
)
-
1
-
i
e_t
=
model
(
x
,
timesteps
[
index
]
.
item
()
*
s_in
,
**
extra_args
)
last_noise_uncond
=
model
.
last_noise_uncond
a_t
=
alphas
[
index
]
.
item
()
*
s_x
a_prev
=
alphas_prev
[
index
]
.
item
()
*
s_x
sigma_t
=
sigmas
[
index
]
.
item
()
*
s_x
sqrt_one_minus_at
=
sqrt_one_minus_alphas
[
index
]
.
item
()
*
s_x
pred_x0
=
(
x
-
sqrt_one_minus_at
*
e_t
)
/
a_t
.
sqrt
()
dir_xt
=
(
1.
-
a_prev
-
sigma_t
**
2
)
.
sqrt
()
*
last_noise_uncond
noise
=
sigma_t
*
k_diffusion
.
sampling
.
torch
.
randn_like
(
x
)
x
=
a_prev
.
sqrt
()
*
pred_x0
+
dir_xt
+
noise
if
callback
is
not
None
:
callback
({
'x'
:
x
,
'i'
:
i
,
'sigma'
:
0
,
'sigma_hat'
:
0
,
'denoised'
:
pred_x0
})
return
x
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
plms
(
model
,
x
,
timesteps
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
):
def
plms
(
model
,
x
,
timesteps
,
extra_args
=
None
,
callback
=
None
,
disable
=
None
):
alphas_cumprod
=
model
.
inner_model
.
inner_model
.
alphas_cumprod
alphas_cumprod
=
model
.
inner_model
.
inner_model
.
alphas_cumprod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment