novelai-storage / Stable Diffusion Webui / Commits

Commit a8fba9af, authored Jun 24, 2024 by AUTOMATIC1111
Parent: a65dd315

    medvram support for SD3

Showing 4 changed files, with 35 additions and 8 deletions:

    modules/lowvram.py               +23  -5
    modules/models/sd3/mmdit.py       +0  -1
    modules/models/sd3/sd3_model.py  +11  -1
    modules/sd_samplers_common.py     +1  -1
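In webui, --medvram trades speed for VRAM: the checkpoint's large components (text encoders, diffusion model, VAE) wait on the CPU and are shuttled onto the GPU only while they are actually running. This commit extends that machinery to SD3 by letting the model class describe its own offloadable parts, rather than hard-coding SD1/SD2/SDXL attribute names inside modules/lowvram.py.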
modules/lowvram.py

from collections import namedtuple

import torch
from modules import devices, shared

module_in_gpu = None
cpu = torch.device("cpu")

ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])


def send_everything_to_cpu():
    global module_in_gpu
...
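ModuleWithParent lets a conditioner report a submodule together with the module that should be treated as its owner when hooks fire. A minimal sketch of constructing and unpacking it; the container used here is a hypothetical stand-in, not part of the commit:

from collections import namedtuple

import torch.nn as nn

ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])

# hypothetical stand-in for a text-encoder container
stack = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 4))

entry = ModuleWithParent(module=stack[0], parent=stack)
print(entry.module)  # the submodule that gets the forward pre-hook
print(entry.parent)  # recorded in lowvram's parents map

# note: the declared default for `parent` is the string 'None', not the
# None object; the commit's consumer branches on isinstance() and sets
# parent explicitly, so the default is never actually relied upon
print(ModuleWithParent(stack[1]).parent)  # -> 'None'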
@@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
         (sd_model, 'depth_model'),
         (sd_model, 'embedder'),
         (sd_model, 'model'),
         (sd_model, 'embedder'),
     ]

     is_sdxl = hasattr(sd_model, 'conditioner')
     is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

-    if is_sdxl:
+    if hasattr(sd_model, 'medvram_fields'):
+        to_remain_in_cpu = sd_model.medvram_fields()
+    elif is_sdxl:
         to_remain_in_cpu.append((sd_model, 'conditioner'))
     elif is_sd2:
         to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
...
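The new hasattr(sd_model, 'medvram_fields') branch turns what used to be per-architecture special cases into a protocol: the model hands back (owner, attribute_name) pairs, and setup_for_low_vram detaches those attributes while the rest of the checkpoint is moved. A rough, self-contained sketch of that detach/move/reattach dance, using a toy model in place of the real sd_model:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TinyModel(nn.Module):
    # toy stand-in for sd_model; not the real webui class
    def __init__(self):
        super().__init__()
        self.first_stage_model = nn.Linear(4, 4)
        self.cond_stage_model = nn.Linear(4, 4)
        self.model = nn.Linear(4, 4)

    def medvram_fields(self):
        # same shape of data the commit's new branch consumes
        return [
            (self, 'first_stage_model'),
            (self, 'cond_stage_model'),
            (self, 'model'),
        ]

sd_model = TinyModel()

# detach the listed fields so they stay behind on the CPU...
stored = []
for obj, field in sd_model.medvram_fields():
    stored.append((obj, field, getattr(obj, field)))
    setattr(obj, field, None)

sd_model.to(device)  # ...move whatever remains to the GPU...

for obj, field, module in stored:  # ...and reattach the CPU-resident modules
    setattr(obj, field, module)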
@@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
         setattr(obj, field, module)

     # register hooks for the first three models
-    if is_sdxl:
+    if hasattr(sd_model.cond_stage_model, "medvram_modules"):
+        for module in sd_model.cond_stage_model.medvram_modules():
+            if isinstance(module, ModuleWithParent):
+                parent = module.parent
+                module = module.module
+            else:
+                parent = None
+
+            if module:
+                module.register_forward_pre_hook(send_me_to_gpu)
+
+                if parent:
+                    parents[module] = parent
+
+    elif is_sdxl:
         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
     elif is_sd2:
         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
...
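The pre-hooks are what make this lazy: nothing is moved until a module's forward is about to run. A self-contained, simplified version of the send_me_to_gpu idea, keeping at most one hooked module on the GPU at a time (the real hook also consults the parents map built above):

import torch
import torch.nn as nn

cpu = torch.device('cpu')
gpu = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

module_in_gpu = None

def send_me_to_gpu(module, _inputs):
    # simplified version of the hook this commit registers
    global module_in_gpu
    if module is module_in_gpu:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)  # evict the previous occupant
    module.to(gpu)             # pull the caller onto the GPU
    module_in_gpu = module

encoder = nn.Linear(8, 8)
decoder = nn.Linear(8, 8)
for m in (encoder, decoder):
    m.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 8).to(gpu)
encoder(x)  # hook moves `encoder` onto the GPU first
decoder(x)  # hook evicts `encoder`, then moves `decoder` in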
@@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap

-    if sd_model.depth_model:
+    if hasattr(sd_model, 'depth_model'):
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
-    if sd_model.embedder:
+    if hasattr(sd_model, 'embedder'):
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

     if use_medvram:
...
modules/models/sd3/mmdit.py

...
@@ -492,7 +492,6 @@ class MMDiT(nn.Module):
         device = None,
     ):
         super().__init__()
-        print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
         self.dtype = dtype
         self.learn_sigma = learn_sigma
         self.in_channels = in_channels
...
modules/models/sd3/sd3_model.py

...
@@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
     def encode_embedding_init_text(self, init_text, nvpt):
         return torch.tensor([[0]], device=devices.device)  # XXX

+    def medvram_modules(self):
+        return [self.clip_g, self.clip_l, self.t5xxl]
+

 class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
     def __init__(self, inner_model, sigmas):
...
@@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
         return self.cond_stage_model(batch)

     def apply_model(self, x, t, cond):
-        return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
+        return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])

     def decode_first_stage(self, latent):
         latent = self.latent_format.process_out(latent)
...
@@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):

     def create_denoiser(self):
         return SD3Denoiser(self, self.model.model_sampling.sigmas)
+
+    def medvram_fields(self):
+        return [
+            (self, 'first_stage_model'),
+            (self, 'cond_stage_model'),
+            (self, 'model'),
+        ]
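Taken together, medvram_fields() on the inferencer and medvram_modules() on its conditioner form the duck-typed surface that the lowvram changes probe for with hasattr. A hypothetical new architecture would only need to expose the same two methods to opt in; a sketch, with all class and field names here being illustrative:

import torch.nn as nn

class MyCond(nn.Module):
    # hypothetical conditioner with several text encoders, analogous
    # to SD3Cond's clip_g / clip_l / t5xxl
    def __init__(self):
        super().__init__()
        self.enc_a = nn.Linear(4, 4)
        self.enc_b = nn.Linear(4, 4)

    def medvram_modules(self):
        # each entry gets a send_me_to_gpu pre-hook in setup_for_low_vram
        return [self.enc_a, self.enc_b]

class MyInferencer(nn.Module):
    # hypothetical top-level model, analogous to SD3Inferencer
    def __init__(self):
        super().__init__()
        self.first_stage_model = nn.Linear(4, 4)
        self.cond_stage_model = MyCond()
        self.model = nn.Linear(4, 4)

    def medvram_fields(self):
        # the fields setup_for_low_vram initially keeps on the CPU
        return [
            (self, 'first_stage_model'),
            (self, 'cond_stage_model'),
            (self, 'model'),
        ]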
modules/sd_samplers_common.py

...
@@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
     else:
         # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
         try:
-            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
+            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
         except AttributeError:  # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
             timestep = torch.max(sigma).to(dtype=int)

     completed_ratio = (999 - timestep) / 1000
...
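This one-line sampler fix falls out of medvram itself: once parts of the model live on the CPU, inner_model.sigmas and the sampler's sigma tensor can end up on different devices, and PyTorch refuses to subtract across devices. A small reproduction of the failure mode; the sigma values are illustrative, and the RuntimeError only actually triggers when a CUDA device is present:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

sigmas = torch.linspace(0.03, 14.6, 1000)   # schedule tensor left on the CPU
sigma = torch.tensor([7.5], device=device)  # per-batch sigma on the active device

# before the fix: with a CUDA device this subtraction raises
# "Expected all tensors to be on the same device"
try:
    timestep = torch.argmin(torch.abs(sigmas - torch.max(sigma)))
except RuntimeError as err:
    print(f"cross-device subtraction failed: {err}")

# the commit's fix: move the schedule to sigma's device first
timestep = torch.argmin(torch.abs(sigmas.to(sigma.device) - torch.max(sigma)))
print(int(timestep))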