Commit b29fc6d4 authored Nov 11, 2023 by aria1th
Implement Hypertile

Co-Authored-By: Kieran Hunt <kph@hotmail.ca>

parent 294f8a51

Changes: 2
Showing 2 changed files with 358 additions and 40 deletions (+358, -40)
modules/hypertile.py    +333  -0
modules/processing.py   +25   -40
modules/hypertile.py (new file, mode 100644)
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warn : The patch works well only if the input image has a width and height that are multiples of 128
Author : @tfernd Github : https://github.com/tfernd/HyperTile
"""
from __future__ import annotations

from typing import Callable
from typing_extensions import Literal

import logging
from functools import wraps, cache
from contextlib import contextmanager

import math
import torch.nn as nn
import random

from einops import rearrange
# TODO add SD-XL layers
DEPTH_LAYERS = {
    0: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.1.1.transformer_blocks.0.attn1",
        "input_blocks.2.1.transformer_blocks.0.attn1",
        "output_blocks.9.1.transformer_blocks.0.attn1",
        "output_blocks.10.1.transformer_blocks.0.attn1",
        "output_blocks.11.1.transformer_blocks.0.attn1",
        # SD 1.5 VAE
        "decoder.mid_block.attentions.0",
    ],
    1: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.6.1.transformer_blocks.0.attn1",
        "output_blocks.7.1.transformer_blocks.0.attn1",
        "output_blocks.8.1.transformer_blocks.0.attn1",
    ],
    2: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
    ],
    3: [
        # SD 1.5 U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
    ],
}
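# Depth 0 above targets the outermost (highest-resolution) self-attention blocks and depth 3 the
# innermost mid-block; both diffusers and ldm layer names are listed so the suffix match used by
# split_attention works with either naming scheme.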
# SD-XL layers, thanks to @gel-crabs on GitHub for the help
DEPTH_LAYERS_XL = {
    0: [
        # SD-XL U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD-XL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
        # SD-XL VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SD-XL U-Net (diffusers)
        #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD-XL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.1.attn1",
        "input_blocks.5.1.transformer_blocks.1.attn1",
        "output_blocks.3.1.transformer_blocks.1.attn1",
        "output_blocks.4.1.transformer_blocks.1.attn1",
        "output_blocks.5.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.0.1.transformer_blocks.0.attn1",
        "output_blocks.1.1.transformer_blocks.0.attn1",
        "output_blocks.2.1.transformer_blocks.0.attn1",
        "input_blocks.7.1.transformer_blocks.1.attn1",
        "input_blocks.8.1.transformer_blocks.1.attn1",
        "output_blocks.0.1.transformer_blocks.1.attn1",
        "output_blocks.1.1.transformer_blocks.1.attn1",
        "output_blocks.2.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.2.attn1",
        "input_blocks.8.1.transformer_blocks.2.attn1",
        "output_blocks.0.1.transformer_blocks.2.attn1",
        "output_blocks.1.1.transformer_blocks.2.attn1",
        "output_blocks.2.1.transformer_blocks.2.attn1",
        "input_blocks.7.1.transformer_blocks.3.attn1",
        "input_blocks.8.1.transformer_blocks.3.attn1",
        "output_blocks.0.1.transformer_blocks.3.attn1",
        "output_blocks.1.1.transformer_blocks.3.attn1",
        "output_blocks.2.1.transformer_blocks.3.attn1",
        "input_blocks.7.1.transformer_blocks.4.attn1",
        "input_blocks.8.1.transformer_blocks.4.attn1",
        "output_blocks.0.1.transformer_blocks.4.attn1",
        "output_blocks.1.1.transformer_blocks.4.attn1",
        "output_blocks.2.1.transformer_blocks.4.attn1",
        "input_blocks.7.1.transformer_blocks.5.attn1",
        "input_blocks.8.1.transformer_blocks.5.attn1",
        "output_blocks.0.1.transformer_blocks.5.attn1",
        "output_blocks.1.1.transformer_blocks.5.attn1",
        "output_blocks.2.1.transformer_blocks.5.attn1",
        "input_blocks.7.1.transformer_blocks.6.attn1",
        "input_blocks.8.1.transformer_blocks.6.attn1",
        "output_blocks.0.1.transformer_blocks.6.attn1",
        "output_blocks.1.1.transformer_blocks.6.attn1",
        "output_blocks.2.1.transformer_blocks.6.attn1",
        "input_blocks.7.1.transformer_blocks.7.attn1",
        "input_blocks.8.1.transformer_blocks.7.attn1",
        "output_blocks.0.1.transformer_blocks.7.attn1",
        "output_blocks.1.1.transformer_blocks.7.attn1",
        "output_blocks.2.1.transformer_blocks.7.attn1",
        "input_blocks.7.1.transformer_blocks.8.attn1",
        "input_blocks.8.1.transformer_blocks.8.attn1",
        "output_blocks.0.1.transformer_blocks.8.attn1",
        "output_blocks.1.1.transformer_blocks.8.attn1",
        "output_blocks.2.1.transformer_blocks.8.attn1",
        "input_blocks.7.1.transformer_blocks.9.attn1",
        "input_blocks.8.1.transformer_blocks.9.attn1",
        "output_blocks.0.1.transformer_blocks.9.attn1",
        "output_blocks.1.1.transformer_blocks.9.attn1",
        "output_blocks.2.1.transformer_blocks.9.attn1",
    ],
    2: [
        # SD-XL U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD-XL U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
        "middle_block.1.transformer_blocks.1.attn1",
        "middle_block.1.transformer_blocks.2.attn1",
        "middle_block.1.transformer_blocks.3.attn1",
        "middle_block.1.transformer_blocks.4.attn1",
        "middle_block.1.transformer_blocks.5.attn1",
        "middle_block.1.transformer_blocks.6.attn1",
        "middle_block.1.transformer_blocks.7.attn1",
        "middle_block.1.transformer_blocks.8.attn1",
        "middle_block.1.transformer_blocks.9.attn1",
    ],
}
RNG_INSTANCE = random.Random()


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """
    Returns a random divisor n of value such that n * min_value <= value.
    If max_options is 1, the behavior is deterministic.
    """
    min_value = min(min_value, value)

    # All big divisors of value (inclusive)
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]  # divisors in small -> big order

    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element  # big -> small order

    idx = RNG_INSTANCE.randint(0, len(ns) - 1)

    return ns[idx]
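# Example: random_divisor(64, 16) always returns 4 (divisors of 64 that are >= 16 are [16, 32, 64],
# so ns == [4] when max_options is 1); with max_options=3 it picks randomly from [4, 2, 1].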
def set_hypertile_seed(seed: int) -> None:
    RNG_INSTANCE.seed(seed)
def largest_tile_size_available(width: int, height: int) -> int:
    """
    Calculates the largest tile size available for a given width and height.
    Tile size is always a power of 2.
    """
    gcd = math.gcd(width, height)
    largest_tile_size_available = 1
    while gcd % (largest_tile_size_available * 2) == 0:
        largest_tile_size_available *= 2
    return largest_tile_size_available
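# Example: largest_tile_size_available(1024, 768) == 256, since gcd(1024, 768) == 256 and 256 is
# the largest power of two dividing it.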
def iterative_closest_divisors(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w = aspect_ratio
    We check all possible divisors of hw and return the closest to the aspect ratio
    """
    divisors = [i for i in range(2, hw + 1) if hw % i == 0]  # all divisors of hw
    pairs = [(i, hw // i) for i in divisors]  # all pairs of divisors of hw
    ratios = [w / h for h, w in pairs]  # all ratios of pairs of divisors of hw
    closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio))  # closest ratio to aspect_ratio
    closest_pair = pairs[ratios.index(closest_ratio)]  # closest pair of divisors to aspect_ratio
    return closest_pair
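# Example: iterative_closest_divisors(12, 1.5) considers the pairs (2, 6), (3, 4), (4, 3), (6, 2),
# (12, 1) and returns (3, 4), whose ratio 4/3 is closest to 1.5.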
@cache
def find_hw_candidates(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w = aspect_ratio
    """
    h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
    # find h and w such that h*w = hw and h/w = aspect_ratio
    if h * w != hw:
        w_candidate = hw / h
        # check if w is an integer
        if not w_candidate.is_integer():
            h_candidate = hw / w
            # check if h is an integer
            if not h_candidate.is_integer():
                return iterative_closest_divisors(hw, aspect_ratio)
            else:
                h = int(h_candidate)
        else:
            w = int(w_candidate)
    return h, w
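# Example: find_hw_candidates(8192, 2.0) == (128, 64), since round(sqrt(8192 * 2)) * round(sqrt(8192 / 2))
# equals 8192 exactly, so no fallback to iterative_closest_divisors is needed.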
@contextmanager
def split_attention(
    layer: nn.Module,
    /,
    aspect_ratio: float,  # width/height
    tile_size: int = 128,  # 128 for VAE
    swap_size: int = 1,  # 1 for VAE
    *,
    disable: bool = False,
    max_depth: Literal[0, 1, 2, 3] = 0,  # ! Try 0 or 1
    scale_depth: bool = True,  # scale the tile-size depending on the depth
    is_sdxl: bool = False,  # is the model SD-XL
):
    # Hijacks AttnBlock from ldm and Attention from diffusers
    if disable:
        logging.info(f"Attention for {layer.__class__.__qualname__} not split")
        yield
        return

    latent_tile_size = max(128, tile_size) // 8
    def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable:
        @wraps(forward)
        def wrapper(*args, **kwargs):
            x = args[0]

            # VAE
            if x.ndim == 4:
                b, c, h, w = x.shape

                nh = random_divisor(h, latent_tile_size, swap_size)
                nw = random_divisor(w, latent_tile_size, swap_size)

                if nh * nw > 1:
                    x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)  # split into nh * nw tiles

                out = forward(x, *args[1:], **kwargs)

                if nh * nw > 1:
                    out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)

            # U-Net
            else:
                hw: int = x.size(1)
                h, w = find_hw_candidates(hw, aspect_ratio)
                assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

                factor = 2**depth if scale_depth else 1
                nh = random_divisor(h, latent_tile_size * factor, swap_size)
                nw = random_divisor(w, latent_tile_size * factor, swap_size)

                module._split_sizes_hypertile.append((nh, nw))  # type: ignore

                if nh * nw > 1:
                    x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

                out = forward(x, *args[1:], **kwargs)

                if nh * nw > 1:
                    out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                    out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)

            return out

        return wrapper
    # Handle hijacking the forward method and recovering afterwards
    try:
        if is_sdxl:
            layers = DEPTH_LAYERS_XL
        else:
            layers = DEPTH_LAYERS
        for depth in range(max_depth + 1):
            for layer_name, module in layer.named_modules():
                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
                    # print input shape for debugging
                    logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}")
                    # hijack
                    module._original_forward_hypertile = module.forward
                    module.forward = self_attn_forward(module.forward, depth, layer_name, module)
                    module._split_sizes_hypertile = []
        yield
    finally:
        for layer_name, module in layer.named_modules():
            # remove hijack
            if hasattr(module, "_original_forward_hypertile"):
                if module._split_sizes_hypertile:
                    logging.debug(f"layer {layer_name} split with ({module._split_sizes_hypertile})")
                # recover
                module.forward = module._original_forward_hypertile
                del module._original_forward_hypertile
                del module._split_sizes_hypertile
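As an aside for readers of the wrapper above: the token reshuffling that the hijacked forward performs on U-Net activations can be reproduced in isolation with the same einops patterns. The sketch below is illustrative only (the sizes are made up, not taken from the commit) and round-trips a token tensor through the split and merge:

import torch
from einops import rearrange

b, c = 2, 320
h, w = 64, 96                       # latent grid (h, w) for a 768x512 image at 8x downscale
x = torch.randn(b, h * w, c)        # tokens as a self-attention layer sees them

nh, nw = 2, 2                       # tile counts of the kind random_divisor would choose
tiles = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
print(tiles.shape)                  # torch.Size([8, 1536, 320]): attention now runs per tile

out = rearrange(tiles, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
assert torch.equal(out, x)          # the two rearranges restore the original token layout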
modules/processing.py
@@ -24,6 +24,7 @@ from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
from modules.hypertile import split_attention, set_hypertile_seed, largest_tile_size_available
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
@@ -799,17 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    infotexts = []
    output_images = []
    unet_object = p.sd_model.model
    vae_model = p.sd_model.first_stage_model
    try:
        from hyper_tile import split_attention, flush
    except (ImportError, ModuleNotFoundError):
        # pip install git+https://github.com/tfernd/HyperTile@2ef64b2800d007d305755c33550537410310d7df
        split_attention = lambda *args, **kwargs: lambda x: x  # return a no-op context manager
        flush = lambda: None
    import random
    saved_rng_state = random.getstate()
    random.seed(p.seed)  # hyper_tile uses random, so we need to seed it

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -871,20 +861,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                    p.comment(comment)

            p.extra_generation_params.update(model_hijack.extra_generation_params)
            set_hypertile_seed(p.seed)
            # add batch size + hypertile status to information to reproduce the run

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                # get largest tile size available, which is 2^x which is factor of gcd of p.width and p.height
                gcd = math.gcd(p.width, p.height)
                largest_tile_size_available = 1
                while gcd % (largest_tile_size_available * 2) == 0:
                    largest_tile_size_available *= 2
                aspect_ratio = p.width / p.height
                with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
                    with split_attention(unet_object, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn):
                        flush()
                        samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if getattr(samples_ddim, 'already_decoded', False):
@@ -892,8 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                with split_attention(vae_model, aspect_ratio=aspect_ratio, tile_size=min(largest_tile_size_available, 128), disable=not shared.opts.hypertile_split_vae_attn):
                    flush()
                with split_attention(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=min(largest_tile_size_available(p.width, p.height), 128), disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
                    x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

                x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -1000,7 +981,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                if opts.grid_save:
                    images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)

    random.setstate(saved_rng_state)

    if not p.disable_extra_networks and p.extra_network_data:
        extra_networks.deactivate(p, p.extra_network_data)
@@ -1161,24 +1141,25 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        aspect_ratio = self.width / self.height
        x = self.rng.next()
        tile_size = largest_tile_size_available(self.width, self.height)
        with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
            with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
                devices.torch_gc()
                samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x

        if not self.enable_hr:
            return samples

        if self.latent_scale_mode is None:
            with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
        else:
            decoded_samples = None

        with sd_models.SkipWritingToConfig():
            sd_models.reload_model_weights(info=self.hr_checkpoint_info)

        devices.torch_gc()

        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)

    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@@ -1186,7 +1167,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
            return samples

        self.is_hr_pass = True
        target_width = self.hr_upscale_to_x
        target_height = self.hr_upscale_to_y
@@ -1264,18 +1244,19 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        if self.scripts is not None:
            self.scripts.before_hr(self)
        tile_size = largest_tile_size_available(target_width, target_height)
        with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
            with split_attention(self.sd_model.model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=3, max_depth=1, scale_depth=True, disable=not opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
                samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        self.sampler = None
        devices.torch_gc()

        with split_attention(self.sd_model.first_stage_model, aspect_ratio=target_width / target_height, tile_size=min(tile_size, 256), swap_size=1, disable=not opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
            decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False
        return decoded_samples

    def close(self):
@@ -1550,7 +1531,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier
        aspect_ratio = self.width / self.height
        tile_size = largest_tile_size_available(self.width, self.height)
        with split_attention(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 128), swap_size=1, disable=not shared.opts.hypertile_split_vae_attn, is_sdxl=shared.sd_model.is_sdxl):
            with split_attention(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=min(tile_size, 256), swap_size=2, disable=not shared.opts.hypertile_split_unet_attn, is_sdxl=shared.sd_model.is_sdxl):
                devices.torch_gc()
                samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
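A side note on reproducibility that the processing.py changes rely on: random_divisor draws from the module-level RNG_INSTANCE, so seeding it with the generation seed (set_hypertile_seed(p.seed) above) makes the chosen tile counts repeatable. A minimal check, assuming modules.hypertile is importable (for example when run from the webui root):

from modules.hypertile import set_hypertile_seed, random_divisor

set_hypertile_seed(1234)
first = [random_divisor(64, 8, 4) for _ in range(3)]
set_hypertile_seed(1234)
second = [random_divisor(64, 8, 4) for _ in range(3)]
assert first == second  # same seed -> same sequence of tile counts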