Commit 7e5cdaab authored Jul 15, 2024 by AUTOMATIC1111
SD3 lora support
parent b2453d28
Showing 6 changed files with 106 additions and 24 deletions (+106 -24)
extensions-builtin/Lora/network.py        +5   -1
extensions-builtin/Lora/network_lora.py   +9   -1
extensions-builtin/Lora/networks.py       +75  -21
modules/models/sd3/mmdit.py               +4   -1
modules/models/sd3/sd3_impls.py           +1   -0
modules/models/sd3/sd3_model.py           +12  -0
extensions-builtin/Lora/network.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from modules import sd_models, cache, errors, hashes, shared
+import modules.models.sd3.mmdit

 NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

@@ -114,7 +115,10 @@ class NetworkModule:
         self.sd_key = weights.sd_key
         self.sd_module = weights.sd_module

-        if hasattr(self.sd_module, 'weight'):
+        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):
+            s = self.sd_module.weight.shape
+            self.shape = (s[0] // 3, s[1])
+        elif hasattr(self.sd_module, 'weight'):
             self.shape = self.sd_module.weight.shape
         elif isinstance(self.sd_module, nn.MultiheadAttention):
             # For now, only self-attn use Pytorch's MHA
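The new QkvLinear branch records the shape of a single projection rather than of the fused weight: SD3's attention keeps Q, K and V stacked in one linear layer, so the first dimension of that weight is three times the per-projection size. A minimal sketch of the shape arithmetic, using made-up toy dimensions rather than anything from the commit:

import torch
import torch.nn as nn

# Toy fused QKV projection: SD3's SelfAttention stores Q, K and V in one Linear
# whose output dimension is 3 * dim, so its weight has shape (3 * dim, dim).
dim = 8
qkv = nn.Linear(dim, dim * 3, bias=False)

s = qkv.weight.shape                  # torch.Size([24, 8])
per_projection_shape = (s[0] // 3, s[1])
print(per_projection_shape)           # (8, 8) - the shape a single Q/K/V LoRA module targets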
extensions-builtin/Lora/network_lora.py

 import torch

 import lyco_helpers
+import modules.models.sd3.mmdit
 import network
 from modules import devices

@@ -10,6 +11,13 @@ class ModuleTypeLora(network.ModuleType):
         if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
             return NetworkModuleLora(net, weights)

+        if all(x in weights.w for x in ["lora_A.weight", "lora_B.weight"]):
+            w = weights.w.copy()
+            weights.w.clear()
+            weights.w.update({"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]})
+
+            return NetworkModuleLora(net, weights)
+
         return None

@@ -29,7 +37,7 @@ class NetworkModuleLora(network.NetworkModule):
         if weight is None and none_ok:
             return None

-        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]
         is_conv = type(self.sd_module) in [torch.nn.Conv2d]

         if is_linear:
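The added branch lets the extension load diffusers/PEFT-style LoRA files, which store their matrices as lora_A (the down projection) and lora_B (the up projection), by renaming them to the lora_down/lora_up keys the rest of the code expects. A minimal sketch of that remapping with a hypothetical toy weight dict:

import torch

# Hypothetical per-module weight dict, shaped like a PEFT-style LoRA checkpoint entry.
rank, dim = 4, 8
w = {
    "lora_A.weight": torch.randn(rank, dim),   # "down" projection
    "lora_B.weight": torch.randn(dim, rank),   # "up" projection
}

# Same renaming the new branch performs: lora_B -> lora_up, lora_A -> lora_down.
remapped = {"lora_up.weight": w["lora_B.weight"], "lora_down.weight": w["lora_A.weight"]}
print(sorted(remapped))   # ['lora_down.weight', 'lora_up.weight']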
extensions-builtin/Lora/networks.py

@@ -20,6 +20,7 @@ from typing import Union
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 import modules.textual_inversion.textual_inversion as textual_inversion
+import modules.models.sd3.mmdit

 from lora_logger import logger

@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

+    if hasattr(shared.sd_model, 'diffusers_weight_map'):
+        diffusers_weight_map = shared.sd_model.diffusers_weight_map
+    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
+        diffusers_weight_map = {}
+        for k, v in shared.sd_model.diffusers_weight_mapping():
+            diffusers_weight_map[k] = v
+
+        shared.sd_model.diffusers_weight_map = diffusers_weight_map
+    else:
+        diffusers_weight_map = None
+
     matched_networks = {}
     bundle_embeddings = {}

     for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
+        if diffusers_weight_map:
+            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+            network_part = network_name + '.' + network_weight
+        else:
+            key_network_without_network_parts, _, network_part = key_network.partition(".")

         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)

@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
             emb_dict[vec_name] = weight
             bundle_embeddings[emb_name] = emb_dict

-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        if diffusers_weight_map:
+            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
+        else:
+            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)

         if sd_module is None:

@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
     purge_networks_from_memory()


+def allowed_layer_without_weight(layer):
+    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
+        return True
+
+    return False
+
+
+def store_weights_backup(weight):
+    if weight is None:
+        return None
+
+    return weight.to(devices.cpu, copy=True)
+
+
+def restore_weights_backup(obj, field, weight):
+    if weight is None:
+        setattr(obj, field, None)
+        return
+
+    getattr(obj, field).copy_(weight)
+
+
 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)

@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
     if weights_backup is not None:
         if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
+            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
+            restore_weights_backup(self.out_proj, 'weight', weights_backup[0])
         else:
-            self.weight.copy_(weights_backup)
+            restore_weights_backup(self, 'weight', weights_backup)

-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
+    if isinstance(self, torch.nn.MultiheadAttention):
+        restore_weights_backup(self.out_proj, 'bias', bias_backup)
     else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+        restore_weights_backup(self, 'bias', bias_backup)


 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):

@@ -389,22 +424,22 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
+        if current_names != () and not allowed_layer_without_weight(self):
+            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")

         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
         else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
+            weights_backup = store_weights_backup(self.weight)

         self.network_weights_backup = weights_backup

     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.out_proj)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.bias)
         else:
             bias_backup = None

@@ -412,6 +447,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         # Only report if bias is not None and current bias are not unchanged.
         if bias_backup is not None and current_names != ():
             raise RuntimeError("no backup bias found and current bias are not unchanged")
+
         self.network_bias_backup = bias_backup

     if current_names != wanted_names:

@@ -419,7 +455,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         for net in loaded_networks:
             module = net.modules.get(network_layer_name, None)
-            if module is not None and hasattr(self, 'weight'):
+            if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
                 try:
                     with torch.no_grad():
                         if getattr(self, 'fp16_weight', None) is None:

@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                 continue

+            if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+                try:
+                    with torch.no_grad():
+                        # Send "real" orig_weight into MHA's lora module
+                        qw, kw, vw = self.weight.chunk(3, 0)
+                        updown_q, _ = module_q.calc_updown(qw)
+                        updown_k, _ = module_k.calc_updown(kw)
+                        updown_v, _ = module_v.calc_updown(vw)
+                        del qw, kw, vw
+                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+                        self.weight += updown_qkv
+
+                except RuntimeError as e:
+                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+                continue
+
             if module is None:
                 continue
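For QkvLinear layers, network_apply_weights now computes a separate LoRA delta for each third of the fused weight and stacks the three deltas back together before adding them in place. A standalone sketch of that chunk/vstack pattern with toy tensors; toy_calc_updown is a hypothetical stand-in for the extension's calc_updown, not its real implementation:

import torch

dim, rank = 8, 4
fused_weight = torch.zeros(3 * dim, dim)   # stand-in for QkvLinear.weight

def toy_calc_updown(orig_weight, up, down):
    # Hypothetical stand-in for the LoRA module's calc_updown: a plain up @ down delta.
    return up @ down

ups = [torch.randn(dim, rank) for _ in range(3)]
downs = [torch.randn(rank, dim) for _ in range(3)]

qw, kw, vw = fused_weight.chunk(3, 0)          # split the fused weight into q, k, v thirds
updowns = [toy_calc_updown(w, u, d) for w, u, d in zip((qw, kw, vw), ups, downs)]
fused_weight += torch.vstack(updowns)          # re-stack the three deltas and apply in place

print(fused_weight.shape)   # torch.Size([24, 8])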
modules/models/sd3/mmdit.py

@@ -175,6 +175,9 @@ class VectorEmbedder(nn.Module):
 #################################################################################

+class QkvLinear(torch.nn.Linear):
+    pass
+
 def split_qkv(qkv, head_dim):
     qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)
     return qkv[0], qkv[1], qkv[2]

@@ -202,7 +205,7 @@ class SelfAttention(nn.Module):
         self.num_heads = num_heads
         self.head_dim = dim // num_heads

-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         if not pre_only:
             self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)
         assert attn_mode in self.ATTENTION_MODES
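QkvLinear adds no behaviour of its own: subclassing torch.nn.Linear simply gives the LoRA code a type it can single out with isinstance checks, while every other Linear keeps its existing path. A minimal sketch of that marker-class idea, with made-up dimensions:

import torch.nn as nn

class QkvLinear(nn.Linear):
    # Marker subclass: behaves exactly like nn.Linear, but is distinguishable by type.
    pass

plain = nn.Linear(8, 24)
fused = QkvLinear(8, 24)

print(isinstance(fused, nn.Linear))    # True - usable anywhere a Linear is expected
print(isinstance(plain, QkvLinear))    # False - only the fused qkv projections match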
modules/models/sd3/sd3_impls.py

@@ -67,6 +67,7 @@ class BaseModel(torch.nn.Module):
         }
         self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)
         self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)
+        self.depth = depth

     def apply_model(self, x, sigma, c_crossattn=None, y=None):
         dtype = self.get_dtype()
modules/models/sd3/sd3_model.py

@@ -82,3 +82,15 @@ class SD3Inferencer(torch.nn.Module):
     def fix_dimensions(self, width, height):
         return width // 16 * 16, height // 16 * 16
+
+    def diffusers_weight_mapping(self):
+        for i in range(self.model.depth):
+            yield f"transformer.transformer_blocks.{i}.attn.to_q", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_k", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_v", f"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.to_out.0", f"diffusion_model_joint_blocks_{i}_x_block_attn_proj"
+
+            yield f"transformer.transformer_blocks.{i}.attn.add_q_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_k_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_v_proj", f"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj"
+            yield f"transformer.transformer_blocks.{i}.attn.add_out_proj.0", f"diffusion_model_joint_blocks_{i}_context_block_attn_proj"