Commit a8fba9af authored by AUTOMATIC1111's avatar AUTOMATIC1111

medvram support for SD3

parent a65dd315
from collections import namedtuple
import torch import torch
from modules import devices, shared from modules import devices, shared
module_in_gpu = None module_in_gpu = None
cpu = torch.device("cpu") cpu = torch.device("cpu")
ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
def send_everything_to_cpu(): def send_everything_to_cpu():
global module_in_gpu global module_in_gpu
...@@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
(sd_model, 'depth_model'), (sd_model, 'depth_model'),
(sd_model, 'embedder'), (sd_model, 'embedder'),
(sd_model, 'model'), (sd_model, 'model'),
(sd_model, 'embedder'),
] ]
is_sdxl = hasattr(sd_model, 'conditioner') is_sdxl = hasattr(sd_model, 'conditioner')
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model') is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
if is_sdxl: if hasattr(sd_model, 'medvram_fields'):
to_remain_in_cpu = sd_model.medvram_fields()
elif is_sdxl:
to_remain_in_cpu.append((sd_model, 'conditioner')) to_remain_in_cpu.append((sd_model, 'conditioner'))
elif is_sd2: elif is_sd2:
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model')) to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
...@@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
setattr(obj, field, module) setattr(obj, field, module)
# register hooks for those the first three models # register hooks for those the first three models
if is_sdxl: if hasattr(sd_model.cond_stage_model, "medvram_modules"):
for module in sd_model.cond_stage_model.medvram_modules():
if isinstance(module, ModuleWithParent):
parent = module.parent
module = module.module
else:
parent = None
if module:
module.register_forward_pre_hook(send_me_to_gpu)
if parent:
parents[module] = parent
elif is_sdxl:
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu) sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
elif is_sd2: elif is_sd2:
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu) sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
...@@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = first_stage_model_encode_wrap sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = first_stage_model_decode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap
if sd_model.depth_model: if hasattr(sd_model, 'depth_model'):
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
if sd_model.embedder: if hasattr(sd_model, 'embedder'):
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
if use_medvram: if use_medvram:
......
...@@ -492,7 +492,6 @@ class MMDiT(nn.Module): ...@@ -492,7 +492,6 @@ class MMDiT(nn.Module):
device = None, device = None,
): ):
super().__init__() super().__init__()
print(f"mmdit initializing with: {input_size=}, {patch_size=}, {in_channels=}, {depth=}, {mlp_ratio=}, {learn_sigma=}, {adm_in_channels=}, {context_embedder_config=}, {register_length=}, {attn_mode=}, {rmsnorm=}, {scale_mod_only=}, {swiglu=}, {out_channels=}, {pos_embed_scaling_factor=}, {pos_embed_offset=}, {pos_embed_max_size=}, {num_patches=}, {qk_norm=}, {qkv_bias=}, {dtype=}, {device=}")
self.dtype = dtype self.dtype = dtype
self.learn_sigma = learn_sigma self.learn_sigma = learn_sigma
self.in_channels = in_channels self.in_channels = in_channels
......
...@@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module): ...@@ -120,6 +120,9 @@ class SD3Cond(torch.nn.Module):
def encode_embedding_init_text(self, init_text, nvpt): def encode_embedding_init_text(self, init_text, nvpt):
return torch.tensor([[0]], device=devices.device) # XXX return torch.tensor([[0]], device=devices.device) # XXX
def medvram_modules(self):
return [self.clip_g, self.clip_l, self.t5xxl]
class SD3Denoiser(k_diffusion.external.DiscreteSchedule): class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
def __init__(self, inner_model, sigmas): def __init__(self, inner_model, sigmas):
...@@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module): ...@@ -163,7 +166,7 @@ class SD3Inferencer(torch.nn.Module):
return self.cond_stage_model(batch) return self.cond_stage_model(batch)
def apply_model(self, x, t, cond): def apply_model(self, x, t, cond):
return self.model.apply_model(x, t, c_crossattn=cond['crossattn'], y=cond['vector']) return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])
def decode_first_stage(self, latent): def decode_first_stage(self, latent):
latent = self.latent_format.process_out(latent) latent = self.latent_format.process_out(latent)
...@@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module): ...@@ -175,3 +178,10 @@ class SD3Inferencer(torch.nn.Module):
def create_denoiser(self): def create_denoiser(self):
return SD3Denoiser(self, self.model.model_sampling.sigmas) return SD3Denoiser(self, self.model.model_sampling.sigmas)
def medvram_fields(self):
return [
(self, 'first_stage_model'),
(self, 'cond_stage_model'),
(self, 'model'),
]
...@@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None): ...@@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
else: else:
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
try: try:
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma))) timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep except AttributeError: # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
timestep = torch.max(sigma).to(dtype=int) timestep = torch.max(sigma).to(dtype=int)
completed_ratio = (999 - timestep) / 1000 completed_ratio = (999 - timestep) / 1000
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment