Commit 5b2a60b8, authored Jun 16, 2024 by AUTOMATIC1111
initial SD3 support
Parent: a7116aa9
Showing 14 changed files with 333 additions and 44 deletions (+333, -44):

    README.md                              +1   -1
    configs/sd3-inference.yaml             +5   -0   (new file)
    extensions-builtin/Lora/networks.py    +3   -1
    modules/models/sd3/mmdit.py            +2   -1
    modules/models/sd3/sd3_impls.py        +7   -7
    modules/models/sd3/sd3_model.py        +166 -0   (new file)
    modules/processing.py                  +2   -1
    modules/sd_models.py                   +74  -13
    modules/sd_models_config.py            +6   -1
    modules/sd_models_types.py             +6   -0
    modules/sd_samplers_common.py          +2   -2
    modules/sd_samplers_kdiffusion.py      +7   -2
    modules/sd_vae_approx.py               +22  -5
    modules/sd_vae_taesd.py                +30  -10
README.md

@@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
-- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - Spandrel - https://github.com/chaiNNer-org/spandrel implementing
 - GFPGAN - https://github.com/TencentARC/GFPGAN.git
configs/sd3-inference.yaml (new file)

model:
  target: modules.models.sd3.sd3_model.SD3Inferencer
  params:
    shift: 3
    state_dict: null
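Here `shift: 3` is the flow-matching timestep shift forwarded to the sampling schedule, and `state_dict: null` is a deliberate placeholder: the new instantiate_from_config in modules/sd_models.py (below) swaps the loaded checkpoint in for it. Instantiating this config therefore ends up roughly equivalent to the following sketch, where `sd` stands for the checkpoint's state dict:

# rough equivalent of instantiating this config, assuming `sd` is the
# checkpoint state dict supplied by the loader in modules/sd_models.py
from modules.models.sd3.sd3_model import SD3Inferencer

model = SD3Inferencer(state_dict=sd, shift=3)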
extensions-builtin/Lora/networks.py

@@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
                 network_layer_mapping[network_name] = module
                 module.network_layer_name = network_name
     else:
-        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
+
+        for name, module in cond_stage_model.named_modules():
             network_name = name.replace(".", "_")
             network_layer_mapping[network_name] = module
             module.network_layer_name = network_name
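The getattr fallback lets the same mapping loop serve both the wrapped SD1/SD2 text encoders and SD3's SD3Cond, which has no .wrapped attribute and exposes its submodules directly. A minimal illustration of the pattern, with stand-in classes rather than webui code:

import torch.nn as nn

class WrappedEncoder(nn.Module):      # stand-in for an SD1/SD2-style encoder
    def __init__(self):
        super().__init__()
        self.wrapped = nn.Linear(4, 4)

class DirectEncoder(nn.Module):       # stand-in for SD3Cond: no .wrapped
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

for cond in (WrappedEncoder(), DirectEncoder()):
    target = getattr(cond, 'wrapped', cond)   # same fallback as the patch above
    print(type(target).__name__, [name for name, _ in target.named_modules() if name])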
modules/models/sd3/mmdit.py

@@ -6,7 +6,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
-from other_impls import attention, Mlp
+from modules.models.sd3.other_impls import attention, Mlp


 class PatchEmbed(nn.Module):
     """ 2D Image to Patch Embedding"""
modules/models/sd3/sd3_impls.py

 ### Impls of the SD3 core diffusion model and VAE
 import torch, math, einops
-from mmdit import MMDiT
+from modules.models.sd3.mmdit import MMDiT
 from PIL import Image

@@ -46,16 +46,16 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
 class BaseModel(torch.nn.Module):
     """Wrapper around the core MM-DiT model"""
-    def __init__(self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix=""):
+    def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=""):
         super().__init__()
         # Important configuration values can be quickly determined by checking shapes in the source file
         # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
-        patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2]
-        depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64
-        num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1]
+        patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
+        depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
+        num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
         pos_embed_max_size = round(math.sqrt(num_patches))
-        adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1]
-        context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape
+        adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
+        context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
         context_embedder_config = {
             "target": "torch.nn.Linear",
             "params": {
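Note how BaseModel never reads an explicit config: every hyperparameter is sniffed from tensor shapes in the state dict, which is how one class can load different SD3 sizes. A sketch of the arithmetic, where the concrete shapes are what a 2B (SD3-Medium) checkpoint is expected to carry; an 8B model would mainly differ in the derived depth:

import math
import torch

# sketch: hyperparameters recovered from tensor shapes, as BaseModel does above
sd = {
    "x_embedder.proj.weight": torch.zeros(1536, 16, 2, 2),  # (hidden, latent_ch, patch, patch)
    "pos_embed": torch.zeros(1, 36864, 1536),
}

patch_size = sd["x_embedder.proj.weight"].shape[2]               # 2
depth = sd["x_embedder.proj.weight"].shape[0] // 64              # 1536 // 64 = 24 blocks
pos_embed_max_size = round(math.sqrt(sd["pos_embed"].shape[1]))  # 192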
modules/models/sd3/sd3_model.py (new file)

import contextlib
import os
from typing import Mapping

import safetensors
import torch

import k_diffusion
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer
from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat

from modules import shared, modelloader, devices

CLIPG_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
}

CLIPL_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

T5_URL = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/resolve/main/text_encoders/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}


class SafetensorsMapping(Mapping):
    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        for key in self.file.keys():
            yield key

    def __getitem__(self, key):
        return self.file.get_tensor(key)


class SD3Cond(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=torch.float32)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=torch.float32, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)
            self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=torch.float32)

        self.weights_loaded = False

    def forward(self, prompts: list[str]):
        res = []

        for prompt in prompts:
            tokens = self.tokenizer.tokenize_with_weights(prompt)

            l_out, l_pooled = self.clip_l.encode_token_weights(tokens["l"])
            g_out, g_pooled = self.clip_g.encode_token_weights(tokens["g"])
            t5_out, t5_pooled = self.t5xxl.encode_token_weights(tokens["t5xxl"])
            lg_out = torch.cat([l_out, g_out], dim=-1)
            lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

            lgt_out = torch.cat([lg_out, t5_out], dim=-2)
            vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

            res.append({
                'crossattn': lgt_out[0].to(devices.device),
                'vector': vector_out[0].to(devices.device),
            })

        return res

    def load_weights(self):
        if self.weights_loaded:
            return

        clip_path = os.path.join(shared.models_path, "CLIP")

        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
        with safetensors.safe_open(clip_g_file, framework="pt") as file:
            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
        with safetensors.safe_open(clip_l_file, framework="pt") as file:
            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
        with safetensors.safe_open(t5_file, framework="pt") as file:
            self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        self.weights_loaded = True

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.tensor([[0]], device=devices.device)  # XXX


class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
    def __init__(self, inner_model, sigmas):
        super().__init__(sigmas, quantize=shared.opts.enable_quantization)
        self.inner_model = inner_model

    def forward(self, input, sigma, **kwargs):
        return self.inner_model.apply_model(input, sigma, **kwargs)


class SD3Inferencer(torch.nn.Module):
    def __init__(self, state_dict, shift=3, use_ema=False):
        super().__init__()

        self.shift = shift

        with torch.no_grad():
            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
            self.first_stage_model.dtype = self.model.diffusion_model.dtype

        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)

        self.cond_stage_model = SD3Cond()
        self.cond_stage_key = 'txt'

        self.parameterization = "eps"
        self.model.conditioning_key = "crossattn"

        self.latent_format = SD3LatentFormat()
        self.latent_channels = 16

    def after_load_weights(self):
        self.cond_stage_model.load_weights()

    def ema_scope(self):
        return contextlib.nullcontext()

    def get_learned_conditioning(self, batch: list[str]):
        return self.cond_stage_model(batch)

    def apply_model(self, x, t, cond):
        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])

    def decode_first_stage(self, latent):
        latent = self.latent_format.process_out(latent)
        return self.first_stage_model.decode(latent)

    def encode_first_stage(self, image):
        latent = self.first_stage_model.encode(image)
        return self.latent_format.process_in(latent)

    def create_denoiser(self):
        return SD3Denoiser(self, self.model.model_sampling.sigmas)
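SafetensorsMapping is the glue that lets load_state_dict consume a lazily opened safetensors file: load_state_dict only needs dict-like access, and the Mapping ABC supplies .keys() and .items() once __len__, __iter__ and __getitem__ exist, so tensors are fetched on demand instead of being materialized up front. A usage sketch, with an illustrative path and target module:

import safetensors
import torch

# usage sketch: tensors are read lazily via __getitem__, so load_state_dict
# never pulls the whole file into memory at once; the path is illustrative
linear = torch.nn.Linear(4, 4)

with safetensors.safe_open("weights.safetensors", framework="pt") as file:
    linear.load_state_dict(SafetensorsMapping(file), strict=False)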
modules/processing.py

@@ -942,7 +942,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

-            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

             if p.scripts is not None:
                 p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
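The practical effect: seed noise is now generated with the model's own channel count instead of the hard-coded opt_C, which SD1/SD2/SDXL models still report as 4 via set_model_fields. A quick check of the arithmetic for SD3:

# sketch of the shape arithmetic; opt_f (the VAE downscale factor) stays 8
latent_channels = 16      # SD3; older models report 4
height, width, opt_f = 1024, 1024, 8
shape = (latent_channels, height // opt_f, width // opt_f)
print(shape)  # (16, 128, 128)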
modules/sd_models.py

 import collections
+import importlib
 import os
 import sys
 import threading
+import enum

 import torch
 import re

@@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
 from urllib import request
 import ldm.modules.midas as midas
-from ldm.util import instantiate_from_config

 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 from modules.shared import opts

@@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()

+class ModelType(enum.Enum):
+    SD1 = 1
+    SD2 = 2
+    SDXL = 3
+    SSD = 4
+    SD3 = 5
+
 def replace_key(d, key, new_key, value):
     keys = list(d.keys())

@@ -368,6 +376,36 @@ def check_fp8(model):
     return enable_fp8

+def set_model_type(model, state_dict):
+    model.is_sd1 = False
+    model.is_sd2 = False
+    model.is_sdxl = False
+    model.is_ssd = False
+    model.is_sd3 = False
+
+    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
+        model.is_sd3 = True
+        model.model_type = ModelType.SD3
+    elif hasattr(model, 'conditioner'):
+        model.is_sdxl = True
+
+        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
+            model.is_ssd = True
+            model.model_type = ModelType.SSD
+        else:
+            model.model_type = ModelType.SDXL
+    elif hasattr(model.cond_stage_model, 'model'):
+        model.is_sd2 = True
+        model.model_type = ModelType.SD2
+    else:
+        model.is_sd1 = True
+        model.model_type = ModelType.SD1
+
+def set_model_fields(model):
+    if not hasattr(model, 'latent_channels'):
+        model.latent_channels = 4
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")

@@ -382,10 +420,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

-    model.is_sdxl = hasattr(model, 'conditioner')
-    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
-    model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    set_model_type(model, state_dict)
+    set_model_fields(model)

     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)

@@ -552,8 +589,7 @@ def patch_given_betas():
     original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)

-def repair_config(sd_config):
+def repair_config(sd_config, state_dict=None):

     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False

@@ -563,8 +599,9 @@ def repair_config(sd_config):
     elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
         sd_config.model.params.unet_config.params.use_fp16 = True

-    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
-        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

     # For UnCLIP-L, override the hardcoded karlo directory
     if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):

@@ -580,6 +617,7 @@ def repair_config(sd_config):
     sd_config.model.params.unet_config.params.use_checkpoint = False

 def rescale_zero_terminal_snr_abar(alphas_cumprod):
     alphas_bar_sqrt = alphas_cumprod.sqrt()

@@ -715,6 +753,25 @@ def send_model_to_trash(m):
     devices.torch_gc()

+def instantiate_from_config(config, state_dict=None):
+    constructor = get_obj_from_str(config["target"])
+
+    params = {**config.get("params", {})}
+
+    if state_dict and "state_dict" in params and params["state_dict"] is None:
+        params["state_dict"] = state_dict
+
+    return constructor(**params)
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+
+    return getattr(importlib.import_module(module, package=None), cls)
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()

@@ -739,7 +796,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("find config")

     sd_config = OmegaConf.load(checkpoint_config)
-    repair_config(sd_config)
+    repair_config(sd_config, state_dict)

     timer.record("load config")

@@ -749,7 +806,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
             with sd_disable_initialization.InitializeOnMeta():
-                sd_model = instantiate_from_config(sd_config.model)
+                sd_model = instantiate_from_config(sd_config.model, state_dict)
     except Exception as e:
         errors.display(e, "creating model quickly", full_traceback=True)

@@ -758,7 +815,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
         with sd_disable_initialization.InitializeOnMeta():
-            sd_model = instantiate_from_config(sd_config.model)
+            sd_model = instantiate_from_config(sd_config.model, state_dict)

     sd_model.used_config = checkpoint_config

@@ -775,6 +832,10 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+
+    if hasattr(sd_model, "after_load_weights"):
+        sd_model.after_load_weights()
+
     timer.record("load weights from state dict")

     send_model_to_device(sd_model)
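The local instantiate_from_config replaces the ldm.util import so that checkpoints can be passed into model constructors: if a config declares a state_dict parameter that is null, the loaded checkpoint is injected in its place, while configs without such a parameter are untouched. A toy demonstration of that rule:

# toy demonstration of the injection rule in the new instantiate_from_config:
# the checkpoint only replaces a declared `state_dict: null` placeholder
config = {"target": "modules.models.sd3.sd3_model.SD3Inferencer",
          "params": {"shift": 3, "state_dict": None}}

params = {**config.get("params", {})}
state_dict = {"model.diffusion_model.x_embedder.proj.weight": ...}  # stands in for a loaded checkpoint

if state_dict and "state_dict" in params and params["state_dict"] is None:
    params["state_dict"] = state_dict   # SD3 configs opt in via the null placeholder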
modules/sd_models_config.py

@@ -23,6 +23,8 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml"
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
+config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")

 def is_using_v_parameterization_for_sd2(state_dict):
     """

@@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

+    if "model.diffusion_model.x_embedder.proj.weight" in sd:
+        return config_sd3
+
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
         else:
             return config_sdxl
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
         return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:

@@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename):
         if diffusion_model_input.shape[1] == 8:
             return config_instruct_pix2pix

     if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
         if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
             return config_alt_diffusion_m18
modules/sd_models_types.py

@@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):
     is_sd1: bool
     """True if the model's architecture is SD 1.x"""

+    is_sd3: bool
+    """True if the model's architecture is SD 3"""
+
+    latent_channels: int
+    """number of channels in the latent image representation; 16 for SD3, 4 for other versions"""
modules/sd_samplers_common.py

@@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     else:
         if model is None:
             model = shared.sd_model
-        with devices.without_autocast():  # fixes an issue with unstable VAEs that are flaky even in fp32
+        with torch.no_grad(), devices.without_autocast():  # fixes an issue with unstable VAEs that are flaky even in fp32
             x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))

     return x_sample

@@ -246,7 +246,7 @@ class Sampler:
         self.eta_infotext_field = 'Eta'
         self.eta_default = 1.0

-        self.conditioning_key = shared.sd_model.model.conditioning_key
+        self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')

         self.p = None
         self.model_wrap_cfg = None
modules/sd_samplers_kdiffusion.py

@@ -53,8 +53,13 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
     @property
     def inner_model(self):
         if self.model_wrap is None:
-            denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-            self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
+
+            if denoiser_constructor is not None:
+                self.model_wrap = denoiser_constructor()
+            else:
+                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)

         return self.model_wrap
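The property now asks the model for its own denoiser before falling back to the CompVis wrappers, so integrating a new architecture only requires defining a create_denoiser method; SD3Inferencer (earlier in this commit) uses it to return SD3Denoiser built from its flow-matching sigmas. A minimal sketch of the duck-typed hook, with a hypothetical class:

import torch

class MyFlowModel(torch.nn.Module):
    """Hypothetical model opting into the hook."""
    def create_denoiser(self):
        # return anything k-diffusion can sample from; SD3Inferencer returns
        # SD3Denoiser(self, self.model.model_sampling.sigmas) here
        raise NotImplementedError

constructor = getattr(MyFlowModel(), 'create_denoiser', None)
assert constructor is not None  # the hook wins over the CompVis fallback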
modules/sd_vae_approx.py

@@ -8,9 +8,9 @@ sd_vae_approx_models = {}
 class VAEApprox(nn.Module):
-    def __init__(self):
+    def __init__(self, latent_channels=4):
         super(VAEApprox, self).__init__()
-        self.conv1 = nn.Conv2d(4, 8, (7, 7))
+        self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
         self.conv2 = nn.Conv2d(8, 16, (5, 5))
         self.conv3 = nn.Conv2d(16, 32, (3, 3))
         self.conv4 = nn.Conv2d(32, 64, (3, 3))

@@ -40,7 +40,13 @@ def download_model(model_path, model_url):
 def model():
-    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    if shared.sd_model.is_sd3:
+        model_name = "vaeapprox-sd3.pt"
+    elif shared.sd_model.is_sdxl:
+        model_name = "vaeapprox-sdxl.pt"
+    else:
+        model_name = "model.pt"
+
     loaded_model = sd_vae_approx_models.get(model_name)

     if loaded_model is None:

@@ -52,7 +58,7 @@ def model():
         model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
         download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)

-        loaded_model = VAEApprox()
+        loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
         loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
         loaded_model.eval()
         loaded_model.to(devices.device, devices.dtype)

@@ -64,7 +70,18 @@ def model():
 def cheap_approximation(sample):
     # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

-    if shared.sd_model.is_sdxl:
+    if shared.sd_model.is_sd3:
+        coeffs = [
+            [-0.0645,  0.0177,  0.1052],
+            [ 0.0028,  0.0312,  0.0650],
+            [ 0.1848,  0.0762,  0.0360],
+            [ 0.0944,  0.0360,  0.0889],
+            [ 0.0897,  0.0506, -0.0364],
+            [-0.0020,  0.1203,  0.0284],
+            [ 0.0855,  0.0118,  0.0283],
+            [-0.0539,  0.0658,  0.1047],
+            [-0.0057,  0.0116,  0.0700],
+            [-0.0412,  0.0281, -0.0039],
+            [ 0.1106,  0.1171,  0.1220],
+            [-0.0248,  0.0682, -0.0481],
+            [ 0.0815,  0.0846,  0.1207],
+            [-0.0120, -0.0055, -0.0867],
+            [-0.0749, -0.0634, -0.0456],
+            [-0.1418, -0.1457, -0.1259],
+        ]
+    elif shared.sd_model.is_sdxl:
         coeffs = [
             [ 0.3448,  0.4168,  0.4395],
             [-0.1953, -0.0290,  0.0250],
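cheap_approximation maps each latent channel to RGB with one linear combination per output channel, so supporting SD3 only requires a 16x3 coefficient matrix instead of a 4x3 one. The projection is, in essence, a single einsum over the channel axis; a sketch rather than the file's verbatim code:

import torch

# sketch: project a (C, H, W) latent to (3, H, W) RGB using per-channel
# coefficients like the 16-row table above (random stand-ins here)
coeffs = torch.randn(16, 3)                        # stand-in for the real matrix
sample = torch.randn(16, 128, 128)                 # stand-in SD3 latent
rgb = torch.einsum("lxy,lr->rxy", sample, coeffs)  # (3, 128, 128)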
modules/sd_vae_taesd.py

@@ -34,9 +34,9 @@ class Block(nn.Module):
         return self.fuse(self.conv(x) + self.skip(x))

-def decoder():
+def decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(), conv(4, 64), nn.ReLU(),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),

@@ -44,13 +44,13 @@ def decoder():
     )

-def encoder():
+def encoder(latent_channels=4):
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 4),
+        conv(64, latent_channels),
     )

@@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5

-    def __init__(self, decoder_path="taesd_decoder.pth"):
+    def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.decoder = decoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+
+        self.decoder = decoder(latent_channels)
         self.decoder.load_state_dict(
             torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

@@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5

-    def __init__(self, encoder_path="taesd_encoder.pth"):
+    def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.encoder = encoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+
+        self.encoder = encoder(latent_channels)
         self.encoder.load_state_dict(
             torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

@@ -87,7 +95,13 @@ def download_model(model_path, model_url):
 def decoder_model():
-    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_decoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_decoder.pth"
+    else:
+        model_name = "taesd_decoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)

     if loaded_model is None:

@@ -106,7 +120,13 @@ def decoder_model():
 def encoder_model():
-    model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_encoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_encoder.pth"
+    else:
+        model_name = "taesd_encoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)

     if loaded_model is None: