Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
21880eb9
Commit
21880eb9
authored
Feb 10, 2023
by
space-nuko
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix logspam and live previews
parent
12531998
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
31 deletions
+41
-31
modules/models/diffusion/uni_pc/sampler.py
modules/models/diffusion/uni_pc/sampler.py
+15
-5
modules/models/diffusion/uni_pc/uni_pc.py
modules/models/diffusion/uni_pc/uni_pc.py
+15
-17
modules/sd_samplers_compvis.py
modules/sd_samplers_compvis.py
+11
-9
No files found.
modules/models/diffusion/uni_pc/sampler.py
View file @
21880eb9
...
...
@@ -19,9 +19,10 @@ class UniPCSampler(object):
attr
=
attr
.
to
(
torch
.
device
(
"cuda"
))
setattr
(
self
,
name
,
attr
)
def
set_hooks
(
self
,
before
,
after
):
self
.
before_sample
=
before
self
.
after_sample
=
after
def
set_hooks
(
self
,
before_sample
,
after_sample
,
after_update
):
self
.
before_sample
=
before_sample
self
.
after_sample
=
after_sample
self
.
after_update
=
after_update
@
torch
.
no_grad
()
def
sample
(
self
,
...
...
@@ -50,9 +51,17 @@ class UniPCSampler(object):
):
if
conditioning
is
not
None
:
if
isinstance
(
conditioning
,
dict
):
cbs
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]]
.
shape
[
0
]
ctmp
=
conditioning
[
list
(
conditioning
.
keys
())[
0
]]
while
isinstance
(
ctmp
,
list
):
ctmp
=
ctmp
[
0
]
cbs
=
ctmp
.
shape
[
0
]
if
cbs
!=
batch_size
:
print
(
f
"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
elif
isinstance
(
conditioning
,
list
):
for
ctmp
in
conditioning
:
if
ctmp
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else
:
if
conditioning
.
shape
[
0
]
!=
batch_size
:
print
(
f
"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
...
...
@@ -60,6 +69,7 @@ class UniPCSampler(object):
# sampling
C
,
H
,
W
=
shape
size
=
(
batch_size
,
C
,
H
,
W
)
print
(
f
'Data shape for UniPC sampling is {size}, eta {eta}'
)
device
=
self
.
model
.
betas
.
device
if
x_T
is
None
:
...
...
@@ -79,7 +89,7 @@ class UniPCSampler(object):
guidance_scale
=
unconditional_guidance_scale
,
)
uni_pc
=
UniPC
(
model_fn
,
ns
,
predict_x0
=
True
,
thresholding
=
False
,
condition
=
conditioning
,
unconditional_condition
=
unconditional_conditioning
,
before_sample
=
self
.
before_sample
,
after_sample
=
self
.
after_sample
)
uni_pc
=
UniPC
(
model_fn
,
ns
,
predict_x0
=
True
,
thresholding
=
False
,
condition
=
conditioning
,
unconditional_condition
=
unconditional_conditioning
,
before_sample
=
self
.
before_sample
,
after_sample
=
self
.
after_sample
,
after_update
=
self
.
after_update
)
x
=
uni_pc
.
sample
(
img
,
steps
=
S
,
skip_type
=
"time_uniform"
,
method
=
"multistep"
,
order
=
3
,
lower_order_final
=
True
)
return
x
.
to
(
device
),
None
modules/models/diffusion/uni_pc/uni_pc.py
View file @
21880eb9
...
...
@@ -378,7 +378,8 @@ class UniPC:
condition
=
None
,
unconditional_condition
=
None
,
before_sample
=
None
,
after_sample
=
None
after_sample
=
None
,
after_update
=
None
):
"""Construct a UniPC.
...
...
@@ -394,6 +395,7 @@ class UniPC:
self
.
unconditional_condition
=
unconditional_condition
self
.
before_sample
=
before_sample
self
.
after_sample
=
after_sample
self
.
after_update
=
after_update
def
dynamic_thresholding_fn
(
self
,
x0
,
t
=
None
):
"""
...
...
@@ -434,15 +436,6 @@ class UniPC:
noise
=
self
.
noise_prediction_fn
(
x
,
t
)
dims
=
x
.
dim
()
alpha_t
,
sigma_t
=
self
.
noise_schedule
.
marginal_alpha
(
t
),
self
.
noise_schedule
.
marginal_std
(
t
)
from
pprint
import
pp
print
(
"X:"
)
pp
(
x
)
print
(
"sigma_t:"
)
pp
(
sigma_t
)
print
(
"noise:"
)
pp
(
noise
)
print
(
"alpha_t:"
)
pp
(
alpha_t
)
x0
=
(
x
-
expand_dims
(
sigma_t
,
dims
)
*
noise
)
/
expand_dims
(
alpha_t
,
dims
)
if
self
.
thresholding
:
p
=
0.995
# A hyperparameter in the paper of "Imagen" [1].
...
...
@@ -524,7 +517,7 @@ class UniPC:
return
self
.
multistep_uni_pc_vary_update
(
x
,
model_prev_list
,
t_prev_list
,
t
,
order
,
**
kwargs
)
def
multistep_uni_pc_vary_update
(
self
,
x
,
model_prev_list
,
t_prev_list
,
t
,
order
,
use_corrector
=
True
):
print
(
f
'using unified predictor-corrector with order {order} (solver type: vary coeff)'
)
#
print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
ns
=
self
.
noise_schedule
assert
order
<=
len
(
model_prev_list
)
...
...
@@ -568,7 +561,7 @@ class UniPC:
A_p
=
C_inv_p
if
use_corrector
:
print
(
'using corrector'
)
#
print('using corrector')
C_inv
=
torch
.
linalg
.
inv
(
C
)
A_c
=
C_inv
...
...
@@ -627,7 +620,7 @@ class UniPC:
return
x_t
,
model_t
def
multistep_uni_pc_bh_update
(
self
,
x
,
model_prev_list
,
t_prev_list
,
t
,
order
,
x_t
=
None
,
use_corrector
=
True
):
print
(
f
'using unified predictor-corrector with order {order} (solver type: B(h))'
)
#
print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
ns
=
self
.
noise_schedule
assert
order
<=
len
(
model_prev_list
)
dims
=
x
.
dim
()
...
...
@@ -695,7 +688,7 @@ class UniPC:
D1s
=
None
if
use_corrector
:
print
(
'using corrector'
)
#
print('using corrector')
# for order 1, we use a simplified version
if
order
==
1
:
rhos_c
=
torch
.
tensor
([
0.5
],
device
=
b
.
device
)
...
...
@@ -755,8 +748,9 @@ class UniPC:
t_T
=
self
.
noise_schedule
.
T
if
t_start
is
None
else
t_start
device
=
x
.
device
if
method
==
'multistep'
:
assert
steps
>=
order
assert
steps
>=
order
,
"UniPC order must be < sampling steps"
timesteps
=
self
.
get_time_steps
(
skip_type
=
skip_type
,
t_T
=
t_T
,
t_0
=
t_0
,
N
=
steps
,
device
=
device
)
print
(
f
"Running UniPC Sampling with {timesteps.shape[0]} timesteps"
)
assert
timesteps
.
shape
[
0
]
-
1
==
steps
with
torch
.
no_grad
():
vec_t
=
timesteps
[
0
]
.
expand
((
x
.
shape
[
0
]))
...
...
@@ -768,6 +762,8 @@ class UniPC:
x
,
model_x
=
self
.
multistep_uni_pc_update
(
x
,
model_prev_list
,
t_prev_list
,
vec_t
,
init_order
,
use_corrector
=
True
)
if
model_x
is
None
:
model_x
=
self
.
model_fn
(
x
,
vec_t
)
if
self
.
after_update
is
not
None
:
self
.
after_update
(
x
,
model_x
)
model_prev_list
.
append
(
model_x
)
t_prev_list
.
append
(
vec_t
)
for
step
in
range
(
order
,
steps
+
1
):
...
...
@@ -776,13 +772,15 @@ class UniPC:
step_order
=
min
(
order
,
steps
+
1
-
step
)
else
:
step_order
=
order
print
(
'this step order:'
,
step_order
)
#
print('this step order:', step_order)
if
step
==
steps
:
print
(
'do not run corrector at the last step'
)
#
print('do not run corrector at the last step')
use_corrector
=
False
else
:
use_corrector
=
True
x
,
model_x
=
self
.
multistep_uni_pc_update
(
x
,
model_prev_list
,
t_prev_list
,
vec_t
,
step_order
,
use_corrector
=
use_corrector
)
if
self
.
after_update
is
not
None
:
self
.
after_update
(
x
,
model_x
)
for
i
in
range
(
order
-
1
):
t_prev_list
[
i
]
=
t_prev_list
[
i
+
1
]
model_prev_list
[
i
]
=
model_prev_list
[
i
+
1
]
...
...
modules/sd_samplers_compvis.py
View file @
21880eb9
...
...
@@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler:
return
x
,
ts
,
cond
,
unconditional_conditioning
def
after_sample
(
self
,
x
,
ts
,
cond
,
uncond
,
res
):
if
self
.
is_unipc
:
# unipc model_fn returns (pred_x0)
# p_sample_ddim returns (x_prev, pred_x0)
res
=
(
None
,
res
[
0
])
def
update_step
(
self
,
last_latent
):
if
self
.
mask
is
not
None
:
self
.
last_latent
=
self
.
init_latent
*
self
.
mask
+
self
.
nmask
*
res
[
1
]
self
.
last_latent
=
self
.
init_latent
*
self
.
mask
+
self
.
nmask
*
last_latent
else
:
self
.
last_latent
=
res
[
1
]
self
.
last_latent
=
last_latent
sd_samplers_common
.
store_latent
(
self
.
last_latent
)
...
...
@@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler:
state
.
sampling_step
=
self
.
step
shared
.
total_tqdm
.
update
()
def
after_sample
(
self
,
x
,
ts
,
cond
,
uncond
,
res
):
if
not
self
.
is_unipc
:
self
.
update_step
(
res
[
1
])
return
x
,
ts
,
cond
,
uncond
,
res
def
unipc_after_update
(
self
,
x
,
model_x
):
self
.
update_step
(
x
)
def
initialize
(
self
,
p
):
self
.
eta
=
p
.
eta
if
p
.
eta
is
not
None
else
shared
.
opts
.
eta_ddim
if
self
.
eta
!=
0.0
:
...
...
@@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler:
if
hasattr
(
self
.
sampler
,
fieldname
):
setattr
(
self
.
sampler
,
fieldname
,
self
.
p_sample_ddim_hook
)
if
self
.
is_unipc
:
self
.
sampler
.
set_hooks
(
lambda
x
,
t
,
c
,
u
:
self
.
before_sample
(
x
,
t
,
c
,
u
),
lambda
x
,
t
,
c
,
u
,
r
:
self
.
after_sample
(
x
,
t
,
c
,
u
,
r
))
self
.
sampler
.
set_hooks
(
lambda
x
,
t
,
c
,
u
:
self
.
before_sample
(
x
,
t
,
c
,
u
),
lambda
x
,
t
,
c
,
u
,
r
:
self
.
after_sample
(
x
,
t
,
c
,
u
,
r
)
,
lambda
x
,
mx
:
self
.
unipc_after_update
(
x
,
mx
)
)
self
.
mask
=
p
.
mask
if
hasattr
(
p
,
'mask'
)
else
None
self
.
nmask
=
p
.
nmask
if
hasattr
(
p
,
'nmask'
)
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment