Stable Diffusion Webui: commit ea903819
Authored Jul 20, 2024 by AUTOMATIC1111; committed by GitHub on Jul 20, 2024
Merge pull request #16212 from AUTOMATIC1111/sd3_lora

SD3 Lora support

Parents: b2453d28, 2b50233f
Showing 6 changed files with 106 additions and 24 deletions (+106 / -24).
extensions-builtin/Lora/network.py (+5, -1)
extensions-builtin/Lora/network_lora.py (+9, -1)
extensions-builtin/Lora/networks.py (+75, -21)
modules/models/sd3/mmdit.py (+4, -1)
modules/models/sd3/sd3_impls.py (+1, -0)
modules/models/sd3/sd3_model.py (+12, -0)
extensions-builtin/Lora/network.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from modules import sd_models, cache, errors, hashes, shared
+import modules.models.sd3.mmdit
 
 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

@@ -114,7 +115,10 @@ class NetworkModule:
         self.sd_key = weights.sd_key
         self.sd_module = weights.sd_module
 
-        if hasattr(self.sd_module, 'weight'):
+        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
+            s = self.sd_module.weight.shape
+            self.shape = (s[0] // 3, s[1])
+        elif hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
         elif isinstance(self.sd_module, nn.MultiheadAttention):
             # For now, only self-attn use Pytorch's MHA
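Note on the QkvLinear branch above: SD3's fused q/k/v projection is a single linear layer whose weight stacks the three projections along dim 0, so the module reports the shape of one projection, (s[0] // 3, s[1]), rather than the fused shape. A minimal illustrative sketch (not part of the commit):

    import torch.nn as nn

    dim = 8
    qkv = nn.Linear(dim, dim * 3)        # fused q/k/v projection; weight shape is (3*dim, dim)
    s = qkv.weight.shape
    per_proj_shape = (s[0] // 3, s[1])   # shape of a single q, k or v projection

    assert per_proj_shape == (dim, dim)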
extensions-builtin/Lora/network_lora.py

 import torch
 
 import lyco_helpers
+import modules.models.sd3.mmdit
 import network
 from modules import devices

@@ -10,6 +11,13 @@ class ModuleTypeLora(network.ModuleType):
         if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
             return NetworkModuleLora(net, weights)
 
+        if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
+            w = weights.w.copy()
+            weights.w.clear()
+            weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
+
+            return NetworkModuleLora(net, weights)
+
         return None

@@ -29,7 +37,7 @@ class NetworkModuleLora(network.NetworkModule):
         if weight is None and none_ok:
             return None
 
-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]
 
         if is_linear:
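Note on the added lora_A/lora_B branch: PEFT/diffusers-style LoRA files store the low-rank factors as lora_A (down) and lora_B (up), with the weight delta being lora_B @ lora_A; the branch renames them to the lora_down/lora_up keys the existing NetworkModuleLora path already expects. A rough sketch of that equivalence with made-up shapes (illustrative only):

    import torch

    dim, rank = 16, 4
    lora_A = torch.randn(rank, dim)   # "down" factor: dim -> rank
    lora_B = torch.randn(dim, rank)   # "up" factor:   rank -> dim

    # the same rename the commit performs on weights.w
    w = {"lora_A.weight": lora_A, "lora_B.weight": lora_B}
    renamed = {"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]}

    # either naming yields the same low-rank weight delta
    delta = renamed["lora_up.weight"] @ renamed["lora_down.weight"]
    assert delta.shape == (dim, dim)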
extensions-builtin/Lora/networks.py

@@ -20,6 +20,7 @@ from typing import Union
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 import modules.textual_inversion.textual_inversion as textual_inversion
+import modules.models.sd3.mmdit
 
 from lora_logger import logger
@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
 
+    if hasattr(shared.sd_model, 'diffusers_weight_map'):
+        diffusers_weight_map = shared.sd_model.diffusers_weight_map
+    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
+        diffusers_weight_map = {}
+        for k, v in shared.sd_model.diffusers_weight_mapping():
+            diffusers_weight_map[k] = v
+        shared.sd_model.diffusers_weight_map = diffusers_weight_map
+    else:
+        diffusers_weight_map = None
+
     matched_networks = {}
     bundle_embeddings = {}
 
     for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
+        if diffusers_weight_map:
+            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+            network_part = network_name + '.' + network_weight
+        else:
+            key_network_without_network_parts, _, network_part = key_network.partition(".")
+
         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)

@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
                 emb_dict[vec_name] = weight
                 bundle_embeddings[emb_name] = emb_dict
 
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        if diffusers_weight_map:
+            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
+        else:
+            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
 
         if sd_module is None:
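Note on the diffusers_weight_map hunks above: when the model exposes diffusers_weight_mapping() (SD3 does, see sd3_model.py below), the generator is materialized into a dict once and cached on the model, and LoRA keys are split with rsplit(".", 2) so the module path and the lora suffix separate cleanly. A small sketch of that key handling, using a hypothetical key rather than one taken from a real checkpoint:

    key_network = "transformer.transformer_blocks.0.attn.to_q.lora_A.weight"

    # the same split load_network performs when a diffusers weight map is available
    module_path, network_name, network_weight = key_network.rsplit(".", 2)
    network_part = network_name + '.' + network_weight

    print(module_path)   # transformer.transformer_blocks.0.attn.to_q
    print(network_part)  # lora_A.weight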
@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()
 
 
+def allowed_layer_without_weight(layer):
+    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
+        return True
+
+    return False
+
+
+def store_weights_backup(weight):
+    if weight is None:
+        return None
+
+    return weight.to(devices.cpu, copy=True)
+
+
+def restore_weights_backup(obj, field, weight):
+    if weight is None:
+        setattr(obj, field, None)
+        return
+
+    getattr(obj, field).copy_(weight)
+
+
 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
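Note on the new helpers: store_weights_backup and restore_weights_backup centralize the CPU copy and the copy-back, and both tolerate None so layers without a given tensor (for example an MHA with no bias) need no separate branches. A rough usage sketch on a plain nn.Linear, assuming the same semantics (illustrative only, not the project's code):

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)

    # mirrors store_weights_backup: keep a CPU copy of the original weight
    backup = layer.weight.to("cpu", copy=True)

    with torch.no_grad():
        layer.weight += 1.0          # some in-place modification, e.g. adding a LoRA delta
        layer.weight.copy_(backup)   # mirrors restore_weights_backup(layer, 'weight', backup)

    assert torch.equal(layer.weight.cpu(), backup)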
@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
     if weights_backup is not None:
         if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
+            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
+            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
         else:
-            self.weight.copy_(weights_backup)
+            restore_weights_backup(self, 'weight', weights_backup)
 
-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
+    if isinstance(self, torch.nn.MultiheadAttention):
+        restore_weights_backup(self.out_proj, 'bias', bias_backup)
     else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+        restore_weights_backup(self, 'bias', bias_backup)
 
 
 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):

@@ -389,22 +424,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
+        if current_names != () and not allowed_layer_without_weight(self):
+            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
 
         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
         else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
+            weights_backup = store_weights_backup(self.weight)
 
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.out_proj.bias)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.bias)
         else:
             bias_backup = None
@@ -412,6 +447,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         # Only report if bias is not None and current bias are not unchanged.
         if bias_backup is not None and current_names != ():
             raise RuntimeError("no backup bias found and current bias are not unchanged")
+
         self.network_bias_backup = bias_backup
 
     if current_names != wanted_names:
@@ -419,7 +455,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
-        if module is not None and hasattr(self, 'weight'):
+        if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
             try:
                 with torch.no_grad():
                     if getattr(self, 'fp16_weight', None) is None:
@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
             continue
 
+        if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+            try:
+                with torch.no_grad():
+                    # Send "real" orig_weight into MHA's lora module
+                    qw, kw, vw = self.weight.chunk(3, 0)
+                    updown_q, _ = module_q.calc_updown(qw)
+                    updown_k, _ = module_k.calc_updown(kw)
+                    updown_v, _ = module_v.calc_updown(vw)
+                    del qw, kw, vw
+                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+                    self.weight += updown_qkv
+
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+            continue
+
         if module is None:
             continue
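Note on the QkvLinear hunk above: the q, k and v LoRA modules are applied against a single fused weight by splitting it with chunk(3, 0), computing each module's update on its slice, and stacking the three updates back with torch.vstack so the result matches the fused (3*dim, dim) shape. A self-contained sketch of that shape round-trip, with dummy updates standing in for calc_updown (illustrative only):

    import torch

    dim = 8
    qkv_weight = torch.randn(dim * 3, dim)   # fused weight of a QkvLinear(dim, dim * 3)

    qw, kw, vw = qkv_weight.chunk(3, 0)      # per-projection slices, each (dim, dim)

    # stand-ins for module_q/k/v.calc_updown(...); here just zero updates
    updown_q = torch.zeros_like(qw)
    updown_k = torch.zeros_like(kw)
    updown_v = torch.zeros_like(vw)

    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
    assert updown_qkv.shape == qkv_weight.shape   # safe to add onto the fused weight

    qkv_weight = qkv_weight + updown_qkv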
modules/models/sd3/mmdit.py

@@ -175,6 +175,9 @@ class VectorEmbedder(nn.Module):
 #################################################################################
 
+class QkvLinear(torch.nn.Linear):
+    pass
+
 def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]

@@ -202,7 +205,7 @@ class SelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
 
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         if not pre_only:
             self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
         assert attn_mode in self.ATTENTION_MODES
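Note on QkvLinear: it adds no behavior of its own; it subclasses torch.nn.Linear purely as a type marker so the LoRA code can detect SD3's fused q/k/v projections with isinstance and route them through the chunked update path, while everything else keeps treating them as ordinary linear layers. A tiny illustration of the pattern (not from the commit):

    import torch.nn as nn

    class QkvLinear(nn.Linear):
        pass

    layers = [nn.Linear(8, 8), QkvLinear(8, 24)]

    for layer in layers:
        if isinstance(layer, QkvLinear):
            print("fused qkv projection:", tuple(layer.weight.shape))  # special-cased by the LoRA code
        else:
            print("ordinary linear:", tuple(layer.weight.shape))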
modules/models/sd3/sd3_impls.py

@@ -67,6 +67,7 @@ class BaseModel(torch.nn.Module):
         }
         self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
         self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
+        self.depth = depth
 
     def apply_model(self, x, sigma, c_crossattn=None, y=None):
         dtype = self.get_dtype()
modules/models/sd3/sd3_model.py

@@ -82,3 +82,15 @@ class SD3Inferencer(torch.nn.Module):
     def fix_dimensions(self, width, height):
         return width // 16 * 16, height // 16 * 16
+
+    def diffusers_weight_mapping(self):
+        for i in range(self.model.depth):
+            yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
+
+            yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"