Commit 6c5f83b1, authored Jul 13, 2023 by AUTOMATIC1111
add support for SDXL loras with te1/te2 modules
Parent: ff73841c

Showing 3 changed files, with 33 additions and 12 deletions (+33 -12):
extensions-builtin/Lora/lora.py    +31 -10
modules/sd_models.py               +2  -1
modules/sd_models_xl.py            +0  -1
extensions-builtin/Lora/lora.py
@@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
         return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
 
+    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
+        if 'mlp_fc1' in m[1]:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+        elif 'mlp_fc2' in m[1]:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+        else:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
     return key
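
For orientation, a minimal standalone sketch of what the new te2 branch computes. It mirrors the branch's logic but is not the real convert_diffusers_name_to_compvis (which uses the file's match() helper), and the example key is hypothetical:

    import re

    def map_te2_key(key):
        # Sketch of the new branch: te2 keys map onto embedder index 1, whose
        # wrapped OpenCLIP encoder uses resblock / mlp_c_fc / mlp_c_proj naming.
        m = re.match(r"lora_te2_text_model_encoder_layers_(\d+)_(.+)", key)
        if m is None:
            return key
        idx, rest = m.group(1), m.group(2)
        if 'mlp_fc1' in rest:
            return f"1_model_transformer_resblocks_{idx}_{rest.replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in rest:
            return f"1_model_transformer_resblocks_{idx}_{rest.replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{idx}_{rest.replace('self_attn', 'attn')}"

    print(map_te2_key("lora_te2_text_model_encoder_layers_3_mlp_fc1"))
    # -> 1_model_transformer_resblocks_3_mlp_c_fc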

@@ -142,10 +150,20 @@ class LoraUpDownModule:
 def assign_lora_names_to_compvis_modules(sd_model):
     lora_layer_mapping = {}
 
-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        lora_name = name.replace(".", "_")
-        lora_layer_mapping[lora_name] = module
-        module.lora_layer_name = lora_name
+    if shared.sd_model.is_sdxl:
+        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
+            if not hasattr(embedder, 'wrapped'):
+                continue
+
+            for name, module in embedder.wrapped.named_modules():
+                lora_name = f'{i}_{name.replace(".", "_")}'
+                lora_layer_mapping[lora_name] = module
+                module.lora_layer_name = lora_name
+    else:
+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+            lora_name = name.replace(".", "_")
+            lora_layer_mapping[lora_name] = module
+            module.lora_layer_name = lora_name
 
     for name, module in shared.sd_model.model.named_modules():
         lora_name = name.replace(".", "_")
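
The SDXL branch prefixes each module name with the embedder's index, which is what makes the 0_... / 1_... keys produced by convert_diffusers_name_to_compvis resolvable. A toy sketch of the naming scheme; the classes below are stand-ins, not the real SDXL embedders:

    import torch.nn as nn

    class Wrapper(nn.Module):
        # Stand-in for an SDXL embedder that wraps a text encoder.
        def __init__(self):
            super().__init__()
            self.wrapped = nn.Sequential(nn.Linear(4, 4))

    embedders = [Wrapper(), Wrapper(), nn.Identity()]  # nn.Identity() has no .wrapped

    lora_layer_mapping = {}
    for i, embedder in enumerate(embedders):
        if not hasattr(embedder, 'wrapped'):
            continue  # embedders without a wrapped text model are skipped
        for name, module in embedder.wrapped.named_modules():
            lora_name = f'{i}_{name.replace(".", "_")}'
            lora_layer_mapping[lora_name] = module

    print(sorted(lora_layer_mapping))  # ['0_', '0_0', '1_', '1_0']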

@@ -168,10 +186,10 @@ def load_lora(name, lora_on_disk):
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
 
-    for key_diffusers, weight in sd.items():
-        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
+    for key_lora, weight in sd.items():
+        key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
+        key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
 
         sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
         if sd_module is None:
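
This hunk is a rename, key_diffusers to key_lora, reflecting that lora keys are not necessarily diffusers-style (per the comment in the next hunk, SDXL loras seem to ship compvis-style keys). The split itself is unchanged; for a hypothetical key:

    key_lora = "lora_te2_text_model_encoder_layers_3_mlp_fc1.lora_up.weight"
    key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
    print(key_lora_without_lora_parts)  # lora_te2_text_model_encoder_layers_3_mlp_fc1
    print(lora_key)                     # lora_up.weight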

@@ -180,12 +198,15 @@ def load_lora(name, lora_on_disk):
             sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
 
         # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
-        if sd_module is None and "lora_unet" in key_diffusers_without_lora_parts:
-            key = key_diffusers_without_lora_parts.replace("lora_unet", "diffusion_model")
+        if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
+            key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
+            key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
             sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
 
         if sd_module is None:
-            keys_failed_to_match[key_diffusers] = key
+            keys_failed_to_match[key_lora] = key
             continue
 
         lora_module = lora.modules.get(key, None)
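
The new elif gives te1 keys a fallback when convert_diffusers_name_to_compvis did not resolve them: the lora_te1_text_model prefix is rewritten to the 0_-prefixed name that assign_lora_names_to_compvis_modules now registers for embedder index 0. A small illustration with a hypothetical key:

    key_lora_without_lora_parts = "lora_te1_text_model_encoder_layers_0_self_attn_q_proj"
    key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
    print(key)  # 0_transformer_text_model_encoder_layers_0_self_attn_q_proj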

modules/sd_models.py
@@ -289,7 +289,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
-    if hasattr(model, 'conditioner'):
+    model.is_sdxl = hasattr(model, 'conditioner')
+    if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
     model.load_state_dict(state_dict, strict=False)
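
The is_sdxl flag is now set on every loaded model, duck-typed on the conditioner attribute that the SGM DiffusionEngine used for SDXL has and LDM-style models do not, so code such as the lora branch above can test it unconditionally. A toy illustration with stand-in classes:

    class LdmModel:           # SD1.x/2.x style: has cond_stage_model, no conditioner
        cond_stage_model = object()

    class SgmModel:           # SDXL style: text encoders hang off conditioner.embedders
        conditioner = object()

    for model in (LdmModel(), SgmModel()):
        model.is_sdxl = hasattr(model, 'conditioner')
        print(type(model).__name__, model.is_sdxl)
    # LdmModel False
    # SgmModel True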

modules/sd_models_xl.py
@@ -48,7 +48,6 @@ def extend_sdxl(model):
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
 
-    model.is_sdxl = True
     sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
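
The removal here is the counterpart of the sd_models.py change: load_model_weights now sets model.is_sdxl before calling extend_sdxl, so the assignment inside extend_sdxl is redundant.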