Project: novelai-storage / Stable Diffusion Webui

Commit 650ddc9d, authored Mar 26, 2023 by AUTOMATIC
Lora support for SD2
parent b705c9b7
Showing 2 changed files with 126 additions and 39 deletions
extensions-builtin/Lora/lora.py (+116, -39)
extensions-builtin/Lora/scripts/lora_script.py (+10, -0)
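Context note (not part of the commit): a LoRA file stores a low-rank pair (down, up) per layer, and applying it adds up @ down, scaled by the user multiplier and the network alpha, onto the layer weight. A minimal standalone sketch of that update, with hypothetical sizes and names:

import torch

# Hypothetical sizes: a 320x320 Linear layer with a rank-4 LoRA pair.
rank, dim = 4, 320
weight = torch.randn(dim, dim)        # original layer weight
down = torch.randn(rank, dim) * 0.01  # lora_down: dim -> rank
up = torch.randn(dim, rank) * 0.01    # lora_up: rank -> dim
alpha, multiplier = 4.0, 0.8          # network alpha and user-chosen strength

# Same scaling as lora_calc_updown introduced below: multiplier * (alpha / rank).
weight += (up @ down) * multiplier * (alpha / up.shape[1])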
extensions-builtin/Lora/lora.py (view file @ 650ddc9d)
@@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors

 metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

 re_digits = re.compile(r"\d+")
-re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
-re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+    "attentions": {},
+    "resnets": {
+        "conv1": "in_layers_2",
+        "conv2": "out_layers_3",
+        "time_emb_proj": "emb_layers_1",
+        "conv_shortcut": "skip_connection",
+    }
+}


 def convert_diffusers_name_to_compvis(key, is_sd2):
-    def match(match_list, regex):
+    def match(match_list, regex_text):
+        regex = re_compiled.get(regex_text)
+        if regex is None:
+            regex = re.compile(regex_text)
+            re_compiled[regex_text] = regex
+
         r = re.match(regex, key)
         if not r:
             return False
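Side note on the hunk above (illustration, not commit code): match() now takes the pattern text itself and memoizes compiled patterns in re_compiled, so new patterns can be added inline without module-level constants. A self-contained sketch of the same caching idea, with made-up names cached_match and _compiled:

import re

_compiled = {}

def cached_match(pattern_text, text):
    # Compile each pattern once, then reuse the compiled object on later calls.
    regex = _compiled.get(pattern_text)
    if regex is None:
        regex = re.compile(pattern_text)
        _compiled[pattern_text] = regex
    return regex.match(text)

print(bool(cached_match(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)",
                        "lora_unet_down_blocks_0_attentions_1_proj_in")))  # True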
@@ -26,16 +39,25 @@ def convert_diffusers_name_to_compvis(key, is_sd2):

     m = []

-    if match(m, re_unet_down_blocks):
-        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

-    if match(m, re_unet_mid_blocks):
-        return f"diffusion_model_middle_block_1_{m[1]}"
+    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

-    if match(m, re_unet_up_blocks):
-        return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

-    if match(m, re_text_block):
+    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
+
+    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
         if is_sd2:
             if 'mlp_fc1' in m[1]:
                 return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
@@ -109,16 +131,22 @@ def load_lora(name, filename):

     sd = sd_models.read_state_dict(filename)

-    keys_failed_to_match = []
+    keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

     for key_diffusers, weight in sd.items():
-        fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
-        key, lora_key = fullkey.split(".", 1)
+        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
+        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)

         sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+
         if sd_module is None:
-            keys_failed_to_match.append(key_diffusers)
+            m = re_x_proj.match(key)
+            if m:
+                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
+
+        if sd_module is None:
+            keys_failed_to_match[key_diffusers] = key
             continue

         lora_module = lora.modules.get(key, None)
@@ -133,7 +161,9 @@ def load_lora(name, filename):

         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
-            module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        elif type(sd_module) == torch.nn.MultiheadAttention:
+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
         else:
@@ -190,54 +220,94 @@ def load_loras(names, multipliers=None):

         loaded_loras.append(lora)


-def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+def lora_calc_updown(lora, module, target):
+    with torch.no_grad():
+        up = module.up.weight.to(target.device, dtype=target.dtype)
+        down = module.down.weight.to(target.device, dtype=target.dtype)
+
+        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        else:
+            updown = up @ down
+
+        updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        return updown
+
+
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention):
     """
-    Applies the currently selected set of Loras to the weight of torch layer self.
+    Applies the currently selected set of Loras to the weights of torch layer self.
     If weights already have this particular set of loras applied, does nothing.
     If not, restores orginal weights from backup and alters weights according to loras.
     """

+    lora_layer_name = getattr(self, 'lora_layer_name', None)
+    if lora_layer_name is None:
+        return
+
     current_names = getattr(self, "lora_current_names", ())
     wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)

     weights_backup = getattr(self, "lora_weights_backup", None)
     if weights_backup is None:
-        weights_backup = self.weight.to(devices.cpu, copy=True)
+        if isinstance(self, torch.nn.MultiheadAttention):
+            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+        else:
+            weights_backup = self.weight.to(devices.cpu, copy=True)
+
         self.lora_weights_backup = weights_backup

     if current_names != wanted_names:
         if weights_backup is not None:
-            self.weight.copy_(weights_backup)
+            if isinstance(self, torch.nn.MultiheadAttention):
+                self.in_proj_weight.copy_(weights_backup[0])
+                self.out_proj.weight.copy_(weights_backup[1])
+            else:
+                self.weight.copy_(weights_backup)

-        lora_layer_name = getattr(self, 'lora_layer_name', None)
         for lora in loaded_loras:
             module = lora.modules.get(lora_layer_name, None)
-            if module is None:
+            if module is not None and hasattr(self, 'weight'):
+                self.weight += lora_calc_updown(lora, module, self.weight)
                 continue

-            with torch.no_grad():
-                up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
-                down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+            module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
+            module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
+            module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
+            module_out = lora.modules.get(lora_layer_name + "_out_proj", None)

-                if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
-                    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
-                else:
-                    updown = up @ down
+            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+                updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
+                updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
+                updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
+                updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

-                self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+                self.in_proj_weight += updown_qkv
+                self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
+                continue
+
+            if module is None:
+                continue
+
+            print(f'failed to calculate lora weights for layer {lora_layer_name}')

         setattr(self, "lora_current_names", wanted_names)


+def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+
 def lora_Linear_forward(self, input):
     lora_apply_weights(self)

     return torch.nn.Linear_forward_before_lora(self, input)


-def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
-    setattr(self, "lora_current_names", ())
-    setattr(self, "lora_weights_backup", None)
+def lora_Linear_load_state_dict(self, *args, **kwargs):
+    lora_reset_cached_weight(self)

     return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
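Why the MultiheadAttention branch above can vstack three deltas (illustration, not commit code): torch stores the q, k and v projection weights stacked in a single in_proj_weight of shape (3 * embed_dim, embed_dim), while out_proj is a separate Linear. A small standalone check of that layout:

import torch

embed_dim, num_heads = 8, 2
mha = torch.nn.MultiheadAttention(embed_dim, num_heads)

print(mha.in_proj_weight.shape)   # torch.Size([24, 8]) == (3 * embed_dim, embed_dim)
print(mha.out_proj.weight.shape)  # torch.Size([8, 8])

# Three per-projection deltas stacked with vstack match the in_proj_weight layout.
delta_q = torch.zeros(embed_dim, embed_dim)
delta_k = torch.zeros(embed_dim, embed_dim)
delta_v = torch.zeros(embed_dim, embed_dim)
assert torch.vstack([delta_q, delta_k, delta_v]).shape == mha.in_proj_weight.shape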
@@ -248,15 +318,22 @@ def lora_Conv2d_forward(self, input):

     return torch.nn.Conv2d_forward_before_lora(self, input)


-def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
-    setattr(self, "lora_current_names", ())
-    setattr(self, "lora_weights_backup", None)
+def lora_Conv2d_load_state_dict(self, *args, **kwargs):
+    lora_reset_cached_weight(self)

     return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)


-def lora_NonDynamicallyQuantizableLinear_forward(self, input):
-    return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
+def lora_MultiheadAttention_forward(self, *args, **kwargs):
+    lora_apply_weights(self)
+
+    return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
+
+
+def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+    lora_reset_cached_weight(self)
+
+    return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)


 def list_available_loras():
extensions-builtin/Lora/scripts/lora_script.py (view file @ 650ddc9d)
@@ -12,6 +12,8 @@ def unload():

     torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
     torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora


 def before_ui():
@@ -31,10 +33,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):

 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
     torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict

+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
+    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
+    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
 torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict

 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)

 script_callbacks.on_script_unloaded(unload)
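Note on the patching pattern above (illustration, not commit code): the original method is stashed on torch.nn under a *_before_lora name only if it is not already there, so reloading the extension never overwrites the true original, and unload() restores it from the stash. A minimal self-contained sketch of the same pattern, using a throwaway attribute name MultiheadAttention_forward_original_demo and a hypothetical patched_forward:

import torch

# Stash the original once; the hasattr guard keeps re-runs from clobbering it.
if not hasattr(torch.nn, 'MultiheadAttention_forward_original_demo'):
    torch.nn.MultiheadAttention_forward_original_demo = torch.nn.MultiheadAttention.forward

def patched_forward(self, *args, **kwargs):
    # A real hook would apply LoRA weights here before delegating.
    return torch.nn.MultiheadAttention_forward_original_demo(self, *args, **kwargs)

torch.nn.MultiheadAttention.forward = patched_forward

# Restore the original, as unload() does for the real hooks.
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_original_demo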