novelai-storage / Stable Diffusion Webui · Commit 7e4b06fc
Authored Jun 29, 2024 by AUTOMATIC1111
support loading clip/t5 from the main model checkpoint
Parent: d67348a0
Showing 3 changed files with 24 additions and 25 deletions (+24 -25):

    modules/models/sd3/sd3_cond.py     +11 -19
    modules/models/sd3/sd3_model.py     +7  -3
    modules/sd_models.py                +6  -3
modules/models/sd3/sd3_cond.py  (+11 -19)

@@ -174,15 +174,10 @@ class SD3Cond(torch.nn.Module):
         self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
         self.model_t5 = Sd3T5(self.t5xxl)
 
-        self.weights_loaded = False
-
     def forward(self, prompts: list[str]):
         with devices.without_autocast():
             lg_out, vector_out = self.model_lg(prompts)
-
-            token_count = lg_out.shape[1]
-
-            t5_out = self.model_t5(prompts, token_count=token_count)
+            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
             lgt_out = torch.cat([lg_out, t5_out], dim=-2)
 
             return {
@@ -190,27 +185,24 @@ class SD3Cond(torch.nn.Module):
             'vector': vector_out,
         }
 
-    def load_weights(self):
-        if self.weights_loaded:
-            return
-
+    def before_load_weights(self, state_dict):
         clip_path = os.path.join(shared.models_path, "CLIP")
 
-        clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
-        with safetensors.safe_open(clip_g_file, framework="pt") as file:
-            self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
+        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
+            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
+            with safetensors.safe_open(clip_g_file, framework="pt") as file:
+                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))
 
-        clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
-        with safetensors.safe_open(clip_l_file, framework="pt") as file:
-            self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
+        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
+            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
+            with safetensors.safe_open(clip_l_file, framework="pt") as file:
+                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        if self.t5xxl:
+        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
             t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
             with safetensors.safe_open(t5_file, framework="pt") as file:
                 self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)
 
-        self.weights_loaded = True
-
     def encode_embedding_init_text(self, init_text, nvpt):
         return torch.tensor([[0]], device=devices.device)  # XXX
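In short, the old load_weights method always downloaded standalone clip_g/clip_l/t5xxl files, while the new before_load_weights(state_dict) hook first probes the incoming checkpoint for a key that only exists when the corresponding text encoder is bundled, and downloads the standalone file only when that key is missing. The sketch below restates just that decision outside the webui classes; standalone_files_needed is an illustrative helper, not part of the repository.

# Illustrative restatement of the key-probe logic in SD3Cond.before_load_weights();
# this helper does not exist in the webui, it only mirrors the three checks above.
CLIP_G_KEY = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
CLIP_L_KEY = "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight"
T5XXL_KEY = "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight"


def standalone_files_needed(state_dict, want_t5=True):
    """List the text-encoder files that still have to be downloaded because
    the main .safetensors checkpoint does not bundle those weights."""
    needed = []
    if CLIP_G_KEY not in state_dict:
        needed.append("clip_g.safetensors")
    if CLIP_L_KEY not in state_dict:
        needed.append("clip_l.safetensors")
    if want_t5 and T5XXL_KEY not in state_dict:
        needed.append("t5xxl_fp16.safetensors")
    return needed


# A checkpoint that bundles CLIP-G and CLIP-L but not T5 only triggers the T5 download:
print(standalone_files_needed({CLIP_G_KEY: ..., CLIP_L_KEY: ...}))  # ['t5xxl_fp16.safetensors']

Encoders that are bundled are not loaded here at all; their tensors are applied later by the generic model.load_state_dict(state_dict, strict=False) call in modules/sd_models.py, which is why the probed keys carry the text_encoders. prefix introduced in the next file.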
modules/models/sd3/sd3_model.py  (+7 -3)

@@ -31,7 +31,7 @@ class SD3Inferencer(torch.nn.Module):
         self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)
 
-        self.cond_stage_model = SD3Cond()
+        self.text_encoders = SD3Cond()
         self.cond_stage_key = 'txt'
 
         self.parameterization = "eps"
@@ -40,8 +40,12 @@ class SD3Inferencer(torch.nn.Module):
         self.latent_format = SD3LatentFormat()
         self.latent_channels = 16
 
-    def after_load_weights(self):
-        self.cond_stage_model.load_weights()
+    @property
+    def cond_stage_model(self):
+        return self.text_encoders
+
+    def before_load_weights(self, state_dict):
+        self.cond_stage_model.before_load_weights(state_dict)
 
     def ema_scope(self):
         return contextlib.nullcontext()
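Storing the conditioner as self.text_encoders (instead of assigning it to cond_stage_model) lines the module's attribute path up with the text_encoders.* keys probed above, so torch.nn.Module.load_state_dict can apply bundled encoder weights directly, while the new read-only cond_stage_model property keeps the old name working for existing callers. A toy, self-contained illustration of that attribute-path / property-alias pattern follows; the class and tensor names are made up, not the real SD3 modules.

# Toy example of the attribute-path / property-alias pattern; the names here
# are illustrative and only demonstrate how nn.Module matches state_dict keys.
import torch
import torch.nn as nn


class TextEncoders(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


class Inferencer(nn.Module):
    def __init__(self):
        super().__init__()
        # Stored under "text_encoders", so checkpoint keys are "text_encoders.proj.*".
        self.text_encoders = TextEncoders()

    @property
    def cond_stage_model(self):
        # Old name preserved as an alias for callers that still use it.
        return self.text_encoders


model = Inferencer()
state_dict = {"text_encoders.proj.weight": torch.zeros(4, 4),
              "text_encoders.proj.bias": torch.zeros(4)}
model.load_state_dict(state_dict, strict=False)        # keys match the attribute path
assert model.cond_stage_model is model.text_encoders   # alias still works

Because cond_stage_model is a class-level property and is never assigned on the instance, it does not become a registered submodule, so the alias adds no duplicate parameters to the state dict.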
modules/sd_models.py  (+6 -3)

@@ -434,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
 
+    if hasattr(model, "before_load_weights"):
+        model.before_load_weights(state_dict)
+
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")
 
+    if hasattr(model, "after_load_weights"):
+        model.after_load_weights(state_dict)
+
     del state_dict
 
     # Set is_sdxl_inpaint flag.
@@ -838,9 +844,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
-    if hasattr(sd_model, "after_load_weights"):
-        sd_model.after_load_weights()
-
     timer.record("load weights from state dict")
 
     send_model_to_device(sd_model)
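Finally, load_model_weights now wraps the generic load_state_dict call in a symmetric pair of optional hooks: a model that defines before_load_weights (SD3 uses it to download any text encoder the checkpoint does not bundle) or after_load_weights gets called with the raw state dict, and other models are unaffected; the old no-argument after_load_weights() call in load_model is removed because the hook now runs inside load_model_weights. A minimal sketch of that hook protocol, with a stand-in model rather than the real webui classes:

# Minimal sketch of the hasattr-guarded hook pattern used in load_model_weights();
# DummyModel and load_weights() are illustrative, not webui code.
class DummyModel:
    def before_load_weights(self, state_dict):
        print("before:", sorted(state_dict))

    def load_state_dict(self, state_dict, strict=True):
        self.weights = dict(state_dict)

    def after_load_weights(self, state_dict):
        print("after: loaded", len(state_dict), "tensors")


def load_weights(model, state_dict):
    # Optional pre-hook: e.g. fetch anything the checkpoint does not bundle.
    if hasattr(model, "before_load_weights"):
        model.before_load_weights(state_dict)

    model.load_state_dict(state_dict, strict=False)

    # Optional post-hook: e.g. fix-ups that need the raw state dict.
    if hasattr(model, "after_load_weights"):
        model.after_load_weights(state_dict)


load_weights(DummyModel(), {"text_encoders.clip_l.transformer.x": 0})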