Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
726769da
Commit
726769da
authored
Oct 31, 2022
by
Muhammad Rizqi Nur
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Checkpoint cache by combination key of checkpoint and vae
parent
b96d0c4e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
12 deletions
+23
-12
modules/sd_models.py
modules/sd_models.py
+16
-11
modules/sd_vae.py
modules/sd_vae.py
+7
-1
No files found.
modules/sd_models.py
View file @
726769da
...
...
@@ -160,11 +160,15 @@ def get_state_dict_from_checkpoint(pl_sd):
vae_ignore_keys
=
{
"model_ema.decay"
,
"model_ema.num_updates"
}
def
load_model_weights
(
model
,
checkpoint_info
,
force
=
False
):
def
load_model_weights
(
model
,
checkpoint_info
,
vae_file
=
"auto"
):
checkpoint_file
=
checkpoint_info
.
filename
sd_model_hash
=
checkpoint_info
.
hash
if
force
or
checkpoint_info
not
in
checkpoints_loaded
:
vae_file
=
sd_vae
.
resolve_vae
(
checkpoint_file
,
vae_file
=
vae_file
)
checkpoint_key
=
(
checkpoint_info
,
vae_file
)
if
checkpoint_key
not
in
checkpoints_loaded
:
print
(
f
"Loading weights [{sd_model_hash}] from {checkpoint_file}"
)
pl_sd
=
torch
.
load
(
checkpoint_file
,
map_location
=
shared
.
weight_load_location
)
...
...
@@ -185,24 +189,25 @@ def load_model_weights(model, checkpoint_info, force=False):
devices
.
dtype
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
else
torch
.
float16
devices
.
dtype_vae
=
torch
.
float32
if
shared
.
cmd_opts
.
no_half
or
shared
.
cmd_opts
.
no_half_vae
else
torch
.
float16
sd_vae
.
load_vae
(
model
,
checkpoint
_file
)
sd_vae
.
load_vae
(
model
,
vae
_file
)
model
.
first_stage_model
.
to
(
devices
.
dtype_vae
)
if
shared
.
opts
.
sd_checkpoint_cache
>
0
:
checkpoints_loaded
[
checkpoint_
info
]
=
model
.
state_dict
()
.
copy
()
checkpoints_loaded
[
checkpoint_
key
]
=
model
.
state_dict
()
.
copy
()
while
len
(
checkpoints_loaded
)
>
shared
.
opts
.
sd_checkpoint_cache
:
checkpoints_loaded
.
popitem
(
last
=
False
)
# LRU
else
:
print
(
f
"Loading weights [{sd_model_hash}] from cache"
)
checkpoints_loaded
.
move_to_end
(
checkpoint_info
)
model
.
load_state_dict
(
checkpoints_loaded
[
checkpoint_info
])
vae_name
=
sd_vae
.
get_filename
(
vae_file
)
print
(
f
"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache"
)
checkpoints_loaded
.
move_to_end
(
checkpoint_key
)
model
.
load_state_dict
(
checkpoints_loaded
[
checkpoint_key
])
model
.
sd_model_hash
=
sd_model_hash
model
.
sd_model_checkpoint
=
checkpoint_file
model
.
sd_checkpoint_info
=
checkpoint_info
def
load_model
(
checkpoint_info
=
None
,
force
=
False
):
def
load_model
(
checkpoint_info
=
None
):
from
modules
import
lowvram
,
sd_hijack
checkpoint_info
=
checkpoint_info
or
select_checkpoint
()
...
...
@@ -223,7 +228,7 @@ def load_model(checkpoint_info=None, force=False):
do_inpainting_hijack
()
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
load_model_weights
(
sd_model
,
checkpoint_info
,
force
=
force
)
load_model_weights
(
sd_model
,
checkpoint_info
)
if
shared
.
cmd_opts
.
lowvram
or
shared
.
cmd_opts
.
medvram
:
lowvram
.
setup_for_low_vram
(
sd_model
,
shared
.
cmd_opts
.
medvram
)
...
...
@@ -250,7 +255,7 @@ def reload_model_weights(sd_model, info=None, force=False):
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
checkpoints_loaded
.
clear
()
load_model
(
checkpoint_info
,
force
=
force
)
load_model
(
checkpoint_info
)
return
shared
.
sd_model
if
shared
.
cmd_opts
.
lowvram
or
shared
.
cmd_opts
.
medvram
:
...
...
@@ -260,7 +265,7 @@ def reload_model_weights(sd_model, info=None, force=False):
sd_hijack
.
model_hijack
.
undo_hijack
(
sd_model
)
load_model_weights
(
sd_model
,
checkpoint_info
,
force
=
force
)
load_model_weights
(
sd_model
,
checkpoint_info
)
sd_hijack
.
model_hijack
.
hijack
(
sd_model
)
script_callbacks
.
model_loaded_callback
(
sd_model
)
...
...
modules/sd_vae.py
View file @
726769da
...
...
@@ -43,7 +43,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
vae_dict
.
update
(
res
)
return
vae_list
def
load_vae
(
model
,
checkpoint_file
,
vae_file
=
"auto"
):
def
resolve_vae
(
checkpoint_file
,
vae_file
=
"auto"
):
global
first_load
,
vae_dict
,
vae_list
# save_settings = False
...
...
@@ -94,6 +94,12 @@ def load_vae(model, checkpoint_file, vae_file="auto"):
if
vae_file
and
not
os
.
path
.
exists
(
vae_file
):
vae_file
=
None
return
vae_file
def
load_vae
(
model
,
vae_file
):
global
first_load
,
vae_dict
,
vae_list
# save_settings = False
if
vae_file
:
print
(
f
"Loading VAE weights from: {vae_file}"
)
vae_ckpt
=
torch
.
load
(
vae_file
,
map_location
=
shared
.
weight_load_location
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment