Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
d1324810
Commit
d1324810
authored
Apr 02, 2023
by
space-nuko
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Embed model merge metadata in .safetensors file
parent
22bcc7be
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
55 additions
and
4 deletions
+55
-4
modules/extras.py
modules/extras.py
+42
-2
modules/sd_models.py
modules/sd_models.py
+10
-1
modules/ui.py
modules/ui.py
+3
-1
No files found.
modules/extras.py
View file @
d1324810
import
os
import
os
import
re
import
re
import
shutil
import
shutil
import
json
import
torch
import
torch
...
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
...
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
return
tensor
return
tensor
def
run_modelmerger
(
id_task
,
primary_model_name
,
secondary_model_name
,
tertiary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
,
checkpoint_format
,
config_source
,
bake_in_vae
,
discard_weights
):
def
run_modelmerger
(
id_task
,
primary_model_name
,
secondary_model_name
,
tertiary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
,
checkpoint_format
,
config_source
,
bake_in_vae
,
discard_weights
,
save_metadata
):
shared
.
state
.
begin
()
shared
.
state
.
begin
()
shared
.
state
.
job
=
'model-merge'
shared
.
state
.
job
=
'model-merge'
...
@@ -241,13 +242,52 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
...
@@ -241,13 +242,52 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared
.
state
.
textinfo
=
"Saving"
shared
.
state
.
textinfo
=
"Saving"
print
(
f
"Saving to {output_modelname}..."
)
print
(
f
"Saving to {output_modelname}..."
)
metadata
=
{
"format"
:
"pt"
,
"models"
:
{},
"merge_recipe"
:
None
}
if
save_metadata
:
merge_recipe
=
{
"primary_model_hash"
:
primary_model_info
.
sha256
,
"secondary_model_hash"
:
secondary_model_info
.
sha256
if
secondary_model_info
else
None
,
"tertiary_model_hash"
:
tertiary_model_info
.
sha256
if
tertiary_model_info
else
None
,
"interp_method"
:
interp_method
,
"multiplier"
:
multiplier
,
"save_as_half"
:
save_as_half
,
"custom_name"
:
custom_name
,
"config_source"
:
config_source
,
"bake_in_vae"
:
bake_in_vae
,
"discard_weights"
:
discard_weights
,
"is_inpainting"
:
result_is_inpainting_model
,
"is_instruct_pix2pix"
:
result_is_instruct_pix2pix_model
}
metadata
[
"merge_recipe"
]
=
json
.
dumps
(
merge_recipe
)
def
add_model_metadata
(
checkpoint_info
):
metadata
[
"models"
][
checkpoint_info
.
sha256
]
=
{
"name"
:
checkpoint_info
.
name
,
"legacy_hash"
:
checkpoint_info
.
hash
,
"merge_recipe"
:
checkpoint_info
.
metadata
.
get
(
"merge_recipe"
,
None
)
}
metadata
[
"models"
]
.
update
(
checkpoint_info
.
metadata
.
get
(
"models"
,
{}))
add_model_metadata
(
primary_model_info
)
if
secondary_model_info
:
add_model_metadata
(
secondary_model_info
)
if
tertiary_model_info
:
add_model_metadata
(
tertiary_model_info
)
metadata
[
"models"
]
=
json
.
dumps
(
metadata
[
"models"
])
_
,
extension
=
os
.
path
.
splitext
(
output_modelname
)
_
,
extension
=
os
.
path
.
splitext
(
output_modelname
)
if
extension
.
lower
()
==
".safetensors"
:
if
extension
.
lower
()
==
".safetensors"
:
safetensors
.
torch
.
save_file
(
theta_0
,
output_modelname
,
metadata
=
{
"format"
:
"pt"
}
)
safetensors
.
torch
.
save_file
(
theta_0
,
output_modelname
,
metadata
=
metadata
)
else
:
else
:
torch
.
save
(
theta_0
,
output_modelname
)
torch
.
save
(
theta_0
,
output_modelname
)
sd_models
.
list_models
()
sd_models
.
list_models
()
created_model
=
next
((
ckpt
for
ckpt
in
sd_models
.
checkpoints_list
.
values
()
if
ckpt
.
name
==
filename
),
None
)
if
created_model
:
created_model
.
calculate_shorthash
()
create_config
(
output_modelname
,
config_source
,
primary_model_info
,
secondary_model_info
,
tertiary_model_info
)
create_config
(
output_modelname
,
config_source
,
primary_model_info
,
secondary_model_info
,
tertiary_model_info
)
...
...
modules/sd_models.py
View file @
d1324810
...
@@ -52,6 +52,15 @@ class CheckpointInfo:
...
@@ -52,6 +52,15 @@ class CheckpointInfo:
self
.
ids
=
[
self
.
hash
,
self
.
model_name
,
self
.
title
,
name
,
f
'{name} [{self.hash}]'
]
+
([
self
.
shorthash
,
self
.
sha256
,
f
'{self.name} [{self.shorthash}]'
]
if
self
.
shorthash
else
[])
self
.
ids
=
[
self
.
hash
,
self
.
model_name
,
self
.
title
,
name
,
f
'{name} [{self.hash}]'
]
+
([
self
.
shorthash
,
self
.
sha256
,
f
'{self.name} [{self.shorthash}]'
]
if
self
.
shorthash
else
[])
self
.
metadata
=
{}
_
,
ext
=
os
.
path
.
splitext
(
self
.
filename
)
if
ext
.
lower
()
==
".safetensors"
:
try
:
self
.
metadata
=
read_metadata_from_safetensors
(
filename
)
except
Exception
as
e
:
errors
.
display
(
e
,
f
"reading checkpoint metadata: {filename}"
)
def
register
(
self
):
def
register
(
self
):
checkpoints_list
[
self
.
title
]
=
self
checkpoints_list
[
self
.
title
]
=
self
for
id
in
self
.
ids
:
for
id
in
self
.
ids
:
...
...
modules/ui.py
View file @
d1324810
...
@@ -1019,8 +1019,9 @@ def create_ui():
...
@@ -1019,8 +1019,9 @@ def create_ui():
interp_method
.
change
(
fn
=
update_interp_description
,
inputs
=
[
interp_method
],
outputs
=
[
interp_description
])
interp_method
.
change
(
fn
=
update_interp_description
,
inputs
=
[
interp_method
],
outputs
=
[
interp_description
])
with
FormRow
():
with
FormRow
():
checkpoint_format
=
gr
.
Radio
(
choices
=
[
"ckpt"
,
"safetensors"
],
value
=
"
ckpt
"
,
label
=
"Checkpoint format"
,
elem_id
=
"modelmerger_checkpoint_format"
)
checkpoint_format
=
gr
.
Radio
(
choices
=
[
"ckpt"
,
"safetensors"
],
value
=
"
safetensors
"
,
label
=
"Checkpoint format"
,
elem_id
=
"modelmerger_checkpoint_format"
)
save_as_half
=
gr
.
Checkbox
(
value
=
False
,
label
=
"Save as float16"
,
elem_id
=
"modelmerger_save_as_half"
)
save_as_half
=
gr
.
Checkbox
(
value
=
False
,
label
=
"Save as float16"
,
elem_id
=
"modelmerger_save_as_half"
)
save_metadata
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Save metadata (.safetensors only)"
,
elem_id
=
"modelmerger_save_metadata"
)
with
FormRow
():
with
FormRow
():
with
gr
.
Column
():
with
gr
.
Column
():
...
@@ -1658,6 +1659,7 @@ def create_ui():
...
@@ -1658,6 +1659,7 @@ def create_ui():
config_source
,
config_source
,
bake_in_vae
,
bake_in_vae
,
discard_weights
,
discard_weights
,
save_metadata
,
],
],
outputs
=
[
outputs
=
[
primary_model_name
,
primary_model_name
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment