Commit c02e3a55 authored by AUTOMATIC1111, committed by GitHub

Merge pull request #16030 from AUTOMATIC1111/sd3

Stable Diffusion 3 support
parents a30b19dd 9e404c31
@@ -150,7 +150,7 @@ For the purposes of getting Google and other search engines to crawl the wiki, h
 ## Credits
 Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.
-- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers
+- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - Spandrel - https://github.com/chaiNNer-org/spandrel implementing
   - GFPGAN - https://github.com/TencentARC/GFPGAN.git
......
model:
  target: modules.models.sd3.sd3_model.SD3Inferencer
  params:
    shift: 3
    state_dict: null
@@ -130,7 +130,9 @@ def assign_network_names_to_compvis_modules(sd_model):
                 network_layer_mapping[network_name] = module
                 module.network_layer_name = network_name
     else:
-        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)
+
+        for name, module in cond_stage_model.named_modules():
             network_name = name.replace(".", "_")
             network_layer_mapping[network_name] = module
             module.network_layer_name = network_name
......
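For context: LoRA patching works by matching names in a network's state dict against a flattened map of the model's modules. The `getattr` fallback above matters because SD3's conditioner is a plain `torch.nn.Module` with no `.wrapped` attribute, unlike the hijacked CLIP wrappers of SD1/SD2. A hedged sketch of how that map is built, with a stand-in encoder instead of the real `cond_stage_model`:

```python
# Minimal sketch; the two-layer encoder is a placeholder, not a real text encoder.
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))

network_layer_mapping = {}
for name, module in encoder.named_modules():
    network_name = name.replace(".", "_")  # LoRA keys use "_" where module paths use "."
    network_layer_mapping[network_name] = module
```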
@@ -57,7 +57,7 @@ class DeepDanbooru:
         a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

         with torch.no_grad(), devices.autocast():
-            x = torch.from_numpy(a).to(devices.device)
+            x = torch.from_numpy(a).to(devices.device, devices.dtype)
             y = self.model(x)[0].detach().cpu().numpy()

         probability_dict = {}
......
+from collections import namedtuple
+
 import torch
 from modules import devices, shared

 module_in_gpu = None
 cpu = torch.device("cpu")

+ModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])
+
 def send_everything_to_cpu():
     global module_in_gpu

@@ -75,13 +78,14 @@ def setup_for_low_vram(sd_model, use_medvram):
         (sd_model, 'depth_model'),
         (sd_model, 'embedder'),
         (sd_model, 'model'),
-        (sd_model, 'embedder'),
     ]

     is_sdxl = hasattr(sd_model, 'conditioner')
     is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

-    if is_sdxl:
+    if hasattr(sd_model, 'medvram_fields'):
+        to_remain_in_cpu = sd_model.medvram_fields()
+    elif is_sdxl:
         to_remain_in_cpu.append((sd_model, 'conditioner'))
     elif is_sd2:
         to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))

@@ -103,7 +107,21 @@ def setup_for_low_vram(sd_model, use_medvram):
         setattr(obj, field, module)

     # register hooks for those the first three models
-    if is_sdxl:
+    if hasattr(sd_model, "cond_stage_model") and hasattr(sd_model.cond_stage_model, "medvram_modules"):
+        for module in sd_model.cond_stage_model.medvram_modules():
+            if isinstance(module, ModuleWithParent):
+                parent = module.parent
+                module = module.module
+            else:
+                parent = None
+
+            if module:
+                module.register_forward_pre_hook(send_me_to_gpu)
+
+            if parent:
+                parents[module] = parent
+
+    elif is_sdxl:
         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
     elif is_sd2:
         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)

@@ -117,9 +135,9 @@ def setup_for_low_vram(sd_model, use_medvram):
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
-    if sd_model.depth_model:
+    if getattr(sd_model, 'depth_model', None) is not None:
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
-    if sd_model.embedder:
+    if getattr(sd_model, 'embedder', None) is not None:
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)

     if use_medvram:
......
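The `medvram_fields`/`medvram_modules` pair above amounts to an informal protocol: instead of the loader special-casing SD1/SD2/SDXL attribute layouts, a model class declares which of its submodules stay in CPU RAM and which get a GPU-swap hook. One oddity worth flagging: `defaults=['None']` in the namedtuple supplies the string 'None', not `None`; nothing in this commit relies on the default (SD3Cond returns bare modules), but it reads as if `defaults=[None]` was intended. A hedged sketch of the duck-typed interface, with placeholder module names:

```python
# Hypothetical model showing the interface setup_for_low_vram() probes for.
import torch.nn as nn

class MyCond(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder_a = nn.Linear(8, 8)  # stand-ins for real text encoders
        self.encoder_b = nn.Linear(8, 8)

    def medvram_modules(self):
        # each module listed here gets a send_me_to_gpu forward-pre-hook
        return [self.encoder_a, self.encoder_b]

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cond_stage_model = MyCond()
        self.first_stage_model = nn.Identity()
        self.model = nn.Identity()

    def medvram_fields(self):
        # (owner, attribute) pairs whose values are kept in CPU RAM up front
        return [(self, 'first_stage_model'), (self, 'cond_stage_model'), (self, 'model')]
```

`SD3Inferencer` and `SD3Cond` implement exactly these two methods later in this diff.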
This diff is collapsed.
This diff is collapsed.
import os
import safetensors
import torch
import typing

from transformers import CLIPTokenizer, T5TokenizerFast

from modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser
from modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer


class SafetensorsMapping(typing.Mapping):
    def __init__(self, file):
        self.file = file

    def __len__(self):
        return len(self.file.keys())

    def __iter__(self):
        for key in self.file.keys():
            yield key

    def __getitem__(self, key):
        return self.file.get_tensor(key)


CLIPL_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors"
CLIPL_CONFIG = {
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
}

CLIPG_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors"
CLIPG_CONFIG = {
    "hidden_act": "gelu",
    "hidden_size": 1280,
    "intermediate_size": 5120,
    "num_attention_heads": 20,
    "num_hidden_layers": 32,
}

T5_URL = "https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors"
T5_CONFIG = {
    "d_ff": 10240,
    "d_model": 4096,
    "num_heads": 64,
    "num_layers": 24,
    "vocab_size": 32128,
}


class Sd3ClipLG(sd_hijack_clip.TextConditionalModel):
    def __init__(self, clip_l, clip_g):
        super().__init__()

        self.clip_l = clip_l
        self.clip_g = clip_g

        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        empty = self.tokenizer('')["input_ids"]
        self.id_start = empty[0]
        self.id_end = empty[1]
        self.id_pad = empty[1]

        self.return_pooled = True

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        tokens_g = tokens.clone()

        for batch_pos in range(tokens_g.shape[0]):
            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)
            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0

        l_out, l_pooled = self.clip_l(tokens)
        g_out, g_pooled = self.clip_g(tokens_g)

        lg_out = torch.cat([l_out, g_out], dim=-1)
        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))

        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)

        lg_out.pooled = vector_out
        return lg_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 768+1280), device=devices.device)  # XXX


class Sd3T5(torch.nn.Module):
    def __init__(self, t5xxl):
        super().__init__()

        self.t5xxl = t5xxl
        self.tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")

        empty = self.tokenizer('', padding='max_length', max_length=2)["input_ids"]
        self.id_end = empty[0]
        self.id_pad = empty[1]

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def tokenize_line(self, line, *, target_token_count=None):
        if shared.opts.emphasis != "None":
            parsed = prompt_parser.parse_prompt_attention(line)
        else:
            parsed = [[line, 1.0]]

        tokenized = self.tokenize([text for text, _ in parsed])

        tokens = []
        multipliers = []

        for text_tokens, (text, weight) in zip(tokenized, parsed):
            if text == 'BREAK' and weight == -1:
                continue

            tokens += text_tokens
            multipliers += [weight] * len(text_tokens)

        tokens += [self.id_end]
        multipliers += [1.0]

        if target_token_count is not None:
            if len(tokens) < target_token_count:
                tokens += [self.id_pad] * (target_token_count - len(tokens))
                multipliers += [1.0] * (target_token_count - len(tokens))
            else:
                tokens = tokens[0:target_token_count]
                multipliers = multipliers[0:target_token_count]

        return tokens, multipliers

    def forward(self, texts, *, token_count):
        if not self.t5xxl or not shared.opts.sd3_enable_t5:
            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)

        tokens_batch = []

        for text in texts:
            tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)
            tokens_batch.append(tokens)

        t5_out, t5_pooled = self.t5xxl(tokens_batch)

        return t5_out

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.zeros((nvpt, 4096), device=devices.device)  # XXX


class SD3Cond(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.tokenizer = SD3Tokenizer()

        with torch.no_grad():
            self.clip_g = SDXLClipG(CLIPG_CONFIG, device="cpu", dtype=devices.dtype)
            self.clip_l = SDClipModel(layer="hidden", layer_idx=-2, device="cpu", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)

            if shared.opts.sd3_enable_t5:
                self.t5xxl = T5XXLModel(T5_CONFIG, device="cpu", dtype=devices.dtype)
            else:
                self.t5xxl = None

            self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)
            self.model_t5 = Sd3T5(self.t5xxl)

    def forward(self, prompts: list[str]):
        with devices.without_autocast():
            lg_out, vector_out = self.model_lg(prompts)
            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])
            lgt_out = torch.cat([lg_out, t5_out], dim=-2)

        return {
            'crossattn': lgt_out,
            'vector': vector_out,
        }

    def before_load_weights(self, state_dict):
        clip_path = os.path.join(shared.models_path, "CLIP")

        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name="clip_g.safetensors")
            with safetensors.safe_open(clip_g_file, framework="pt") as file:
                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))

        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:
            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name="clip_l.safetensors")
            with safetensors.safe_open(clip_l_file, framework="pt") as file:
                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:
            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name="t5xxl_fp16.safetensors")
            with safetensors.safe_open(t5_file, framework="pt") as file:
                self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)

    def encode_embedding_init_text(self, init_text, nvpt):
        return torch.tensor([[0]], device=devices.device)  # XXX

    def medvram_modules(self):
        return [self.clip_g, self.clip_l, self.t5xxl]

    def get_token_count(self, text):
        _, token_count = self.model_lg.process_texts([text])
        return token_count

    def get_target_prompt_token_count(self, token_count):
        return self.model_lg.get_target_prompt_token_count(token_count)
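`SafetensorsMapping` adapts a lazily-read safetensors file handle to the Mapping interface that `torch.nn.Module.load_state_dict` accepts, so tensors are pulled from disk by key instead of building a separate dict first. A hedged usage sketch (the file path is a placeholder):

```python
# SafetensorsMapping is the class defined above; "weights.safetensors" is hypothetical.
import safetensors
import torch

model = torch.nn.Linear(4, 4)

with safetensors.safe_open("weights.safetensors", framework="pt") as file:
    # tensors are only read from disk when __getitem__ runs for each key
    model.load_state_dict(SafetensorsMapping(file), strict=False)
```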
This diff is collapsed.
import contextlib

import torch

import k_diffusion

from modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat
from modules.models.sd3.sd3_cond import SD3Cond

from modules import shared, devices


class SD3Denoiser(k_diffusion.external.DiscreteSchedule):
    def __init__(self, inner_model, sigmas):
        super().__init__(sigmas, quantize=shared.opts.enable_quantization)
        self.inner_model = inner_model

    def forward(self, input, sigma, **kwargs):
        return self.inner_model.apply_model(input, sigma, **kwargs)


class SD3Inferencer(torch.nn.Module):
    def __init__(self, state_dict, shift=3, use_ema=False):
        super().__init__()

        self.shift = shift

        with torch.no_grad():
            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix="model.diffusion_model.", device="cpu", dtype=devices.dtype)
            self.first_stage_model = SDVAE(device="cpu", dtype=devices.dtype_vae)
            self.first_stage_model.dtype = self.model.diffusion_model.dtype

        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)

        self.text_encoders = SD3Cond()
        self.cond_stage_key = 'txt'

        self.parameterization = "eps"
        self.model.conditioning_key = "crossattn"

        self.latent_format = SD3LatentFormat()
        self.latent_channels = 16

    @property
    def cond_stage_model(self):
        return self.text_encoders

    def before_load_weights(self, state_dict):
        self.cond_stage_model.before_load_weights(state_dict)

    def ema_scope(self):
        return contextlib.nullcontext()

    def get_learned_conditioning(self, batch: list[str]):
        return self.cond_stage_model(batch)

    def apply_model(self, x, t, cond):
        return self.model(x, t, c_crossattn=cond['crossattn'], y=cond['vector'])

    def decode_first_stage(self, latent):
        latent = self.latent_format.process_out(latent)
        return self.first_stage_model.decode(latent)

    def encode_first_stage(self, image):
        latent = self.first_stage_model.encode(image)
        return self.latent_format.process_in(latent)

    def get_first_stage_encoding(self, x):
        return x

    def create_denoiser(self):
        return SD3Denoiser(self, self.model.model_sampling.sigmas)

    def medvram_fields(self):
        return [
            (self, 'first_stage_model'),
            (self, 'text_encoders'),
            (self, 'model'),
        ]

    def add_noise_to_latent(self, x, noise, amount):
        return x * (1 - amount) + noise * amount

    def fix_dimensions(self, width, height):
        return width // 16 * 16, height // 16 * 16
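The `alphas_cumprod` line in `SD3Inferencer.__init__` is worth unpacking: webui code paths written for epsilon-parameterized models expect an alpha-bar schedule, and for discrete schedules sigma and alpha-bar are related by sigma^2 = (1 - abar) / abar, hence abar = 1 / (sigma^2 + 1). A hedged numeric check of that identity:

```python
# Verifies abar = 1 / (sigma**2 + 1) inverts sigma = sqrt((1 - abar) / abar).
# The linspace schedule is a stand-in, not SD3's actual sigmas.
import torch

sigmas = torch.linspace(0.03, 14.6, 1000)
abar = 1 / (sigmas ** 2 + 1)

assert torch.allclose(((1 - abar) / abar).sqrt(), sigmas, atol=1e-5)
```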
@@ -884,6 +884,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         if p.refiner_checkpoint_info is None:
             raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')

+    if hasattr(shared.sd_model, 'fix_dimensions'):
+        p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)
+
     p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra
     p.sd_model_hash = shared.sd_model.sd_model_hash
     p.sd_vae_name = sd_vae.get_loaded_vae_name()

@@ -942,7 +945,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
             p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]

-            p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)
+            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)

             if p.scripts is not None:
                 p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)

@@ -1736,10 +1740,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
             latmask = latmask[0]
             if self.mask_round:
                 latmask = np.around(latmask)
-            latmask = np.tile(latmask[None], (4, 1, 1))
+            latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))

-            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
-            self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
+            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)
+            self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)

             # this needs to be fixed to be done in sample() using actual seeds for batches
             if self.inpainting_fill == 2:
......
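The `latent_channels` lookup matters because SD3's VAE produces 16-channel latents while SD1/SDXL use 4; both downscale spatially by a factor of 8. A hedged shape sketch, assuming the usual `opt_f == 8`:

```python
# For a 1024x1024 request, the per-image noise tensor the RNG must produce:
opt_f = 8                      # VAE spatial downscale factor
width = height = 1024

sd1_shape = (4, height // opt_f, width // opt_f)    # (4, 128, 128)
sd3_shape = (16, height // opt_f, width // opt_f)   # (16, 128, 128)
```

The img2img mask tiling fix follows the same logic: `self.init_latent.shape[1]` is the model's channel count, so the mask is broadcast correctly for 16-channel latents too.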
@@ -268,7 +268,7 @@ def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None,
 class DictWithShape(dict):
-    def __init__(self, x, shape):
+    def __init__(self, x, shape=None):
         super().__init__()
         self.update(x)
......
@@ -325,7 +325,10 @@ class StableDiffusionModelHijack:
         if self.clip is None:
             return "-", "-"

-        _, token_count = self.clip.process_texts([text])
+        if hasattr(self.clip, 'get_token_count'):
+            token_count = self.clip.get_token_count(text)
+        else:
+            _, token_count = self.clip.process_texts([text])

         return token_count, self.clip.get_target_prompt_token_count(token_count)
......
@@ -27,24 +27,21 @@ chunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenC
 are applied by sd_hijack.EmbeddingsWithFixes's forward function."""


-class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
-    """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
-    have unlimited prompt length and assign weights to tokens in prompt.
-    """
-
-    def __init__(self, wrapped, hijack):
+class TextConditionalModel(torch.nn.Module):
+    def __init__(self):
         super().__init__()

-        self.wrapped = wrapped
-        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
-        depending on model."""
-
-        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
+        self.hijack = sd_hijack.model_hijack
         self.chunk_length = 75

-        self.is_trainable = getattr(wrapped, 'is_trainable', False)
-        self.input_key = getattr(wrapped, 'input_key', 'txt')
-        self.legacy_ucg_val = None
+        self.is_trainable = False
+        self.input_key = 'txt'
+        self.return_pooled = False
+
+        self.comma_token = None
+        self.id_start = None
+        self.id_end = None
+        self.id_pad = None

     def empty_chunk(self):
         """creates an empty PromptChunk and returns it"""

@@ -210,10 +207,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         """

-        if opts.use_old_emphasis_implementation:
-            import modules.sd_hijack_clip_old
-            return modules.sd_hijack_clip_old.forward_old(self, texts)
-
         batch_chunks, token_count = self.process_texts(texts)

         used_embeddings = {}

@@ -252,7 +245,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
             self.hijack.extra_generation_params["Emphasis"] = opts.emphasis

-        if getattr(self.wrapped, 'return_pooled', False):
+        if self.return_pooled:
             return torch.hstack(zs), zs[0].pooled
         else:
             return torch.hstack(zs)

@@ -292,6 +285,34 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         return z


+class FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):
+    """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
+    have unlimited prompt length and assign weights to tokens in prompt.
+    """
+
+    def __init__(self, wrapped, hijack):
+        super().__init__()
+
+        self.hijack = hijack
+        self.wrapped = wrapped
+        """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+        depending on model."""
+
+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.return_pooled = getattr(self.wrapped, 'return_pooled', False)
+
+        self.legacy_ucg_val = None  # for sgm codebase
+
+    def forward(self, texts):
+        if opts.use_old_emphasis_implementation:
+            import modules.sd_hijack_clip_old
+            return modules.sd_hijack_clip_old.forward_old(self, texts)
+
+        return super().forward(texts)
+
+
 class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
     def __init__(self, wrapped, hijack):
         super().__init__(wrapped, hijack)
......
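The refactor above splits the class in two: `TextConditionalModel` keeps the model-agnostic machinery (75-token chunking, emphasis weights, embedding fixups), while the old `FrozenCLIPEmbedderWithCustomWordsBase` becomes a thin wrapper-holding subclass. A new backend such as SD3's `Sd3ClipLG` only has to fill in token ids and two methods; a hedged sketch of that contract:

```python
# Hypothetical subclass; tokenizer/transformer are placeholders, not real encoders.
class MyConditioner(TextConditionalModel):
    def __init__(self, tokenizer, transformer):
        super().__init__()
        self.tokenizer = tokenizer
        self.transformer = transformer
        self.id_start, self.id_end, self.id_pad = 49406, 49407, 49407  # CLIP-style ids

    def tokenize(self, texts):
        return self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]

    def encode_with_transformers(self, tokens):
        # called once per 75-token chunk (plus start/end ids); returns (batch, 77, hidden)
        return self.transformer(tokens)
```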
 import collections
+import importlib
 import os
 import sys
 import threading
+import enum

 import torch
 import re

@@ -10,8 +12,6 @@ from omegaconf import OmegaConf, ListConfig
 from urllib import request
 import ldm.modules.midas as midas
-from ldm.util import instantiate_from_config

 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 from modules.shared import opts

@@ -27,6 +27,14 @@ checkpoint_alisases = checkpoint_aliases # for compatibility with old name
 checkpoints_loaded = collections.OrderedDict()


+class ModelType(enum.Enum):
+    SD1 = 1
+    SD2 = 2
+    SDXL = 3
+    SSD = 4
+    SD3 = 5
+
+
 def replace_key(d, key, new_key, value):
     keys = list(d.keys())

@@ -368,6 +376,37 @@ def check_fp8(model):
     return enable_fp8


+def set_model_type(model, state_dict):
+    model.is_sd1 = False
+    model.is_sd2 = False
+    model.is_sdxl = False
+    model.is_ssd = False
+    model.is_sd3 = False
+
+    if "model.diffusion_model.x_embedder.proj.weight" in state_dict:
+        model.is_sd3 = True
+        model.model_type = ModelType.SD3
+    elif hasattr(model, 'conditioner'):
+        model.is_sdxl = True
+
+        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():
+            model.is_ssd = True
+            model.model_type = ModelType.SSD
+        else:
+            model.model_type = ModelType.SDXL
+    elif hasattr(model.cond_stage_model, 'model'):
+        model.is_sd2 = True
+        model.model_type = ModelType.SD2
+    else:
+        model.is_sd1 = True
+        model.model_type = ModelType.SD1
+
+
+def set_model_fields(model):
+    if not hasattr(model, 'latent_channels'):
+        model.latent_channels = 4
+
+
 def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     sd_model_hash = checkpoint_info.calculate_shorthash()
     timer.record("calculate hash")

@@ -382,10 +421,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

-    model.is_sdxl = hasattr(model, 'conditioner')
-    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
-    model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    set_model_type(model, state_dict)
+    set_model_fields(model)
+
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)

@@ -396,9 +434,15 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()

+    if hasattr(model, "before_load_weights"):
+        model.before_load_weights(state_dict)
+
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")

+    if hasattr(model, "after_load_weights"):
+        model.after_load_weights(state_dict)
+
     del state_dict

     # Set is_sdxl_inpaint flag.

@@ -552,8 +596,7 @@ def patch_given_betas():
     original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)


-def repair_config(sd_config):
-
+def repair_config(sd_config, state_dict=None):
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False

@@ -563,8 +606,9 @@ def repair_config(sd_config):
     elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == "half":
         sd_config.model.params.unet_config.params.use_fp16 = True

-    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
-        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
+            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

     # For UnCLIP-L, override the hardcoded karlo directory
     if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):

@@ -580,6 +624,7 @@ def repair_config(sd_config):
             sd_config.model.params.unet_config.params.use_checkpoint = False

+
 def rescale_zero_terminal_snr_abar(alphas_cumprod):
     alphas_bar_sqrt = alphas_cumprod.sqrt()

@@ -679,11 +724,15 @@ def get_empty_cond(sd_model):
     p = processing.StableDiffusionProcessingTxt2Img()
     extra_networks.activate(p, {})

-    if hasattr(sd_model, 'conditioner'):
+    if hasattr(sd_model, 'get_learned_conditioning'):
         d = sd_model.get_learned_conditioning([""])
-        return d['crossattn']
     else:
-        return sd_model.cond_stage_model([""])
+        d = sd_model.cond_stage_model([""])
+
+    if isinstance(d, dict):
+        d = d['crossattn']
+
+    return d


 def send_model_to_cpu(m):

@@ -715,6 +764,25 @@ def send_model_to_trash(m):
     devices.torch_gc()


+def instantiate_from_config(config, state_dict=None):
+    constructor = get_obj_from_str(config["target"])
+
+    params = {**config.get("params", {})}
+
+    if state_dict and "state_dict" in params and params["state_dict"] is None:
+        params["state_dict"] = state_dict
+
+    return constructor(**params)
+
+
+def get_obj_from_str(string, reload=False):
+    module, cls = string.rsplit(".", 1)
+
+    if reload:
+        module_imp = importlib.import_module(module)
+        importlib.reload(module_imp)
+
+    return getattr(importlib.import_module(module, package=None), cls)
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()

@@ -739,7 +807,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("find config")

     sd_config = OmegaConf.load(checkpoint_config)
-    repair_config(sd_config)
+    repair_config(sd_config, state_dict)

     timer.record("load config")

@@ -749,7 +817,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
             with sd_disable_initialization.InitializeOnMeta():
-                sd_model = instantiate_from_config(sd_config.model)
+                sd_model = instantiate_from_config(sd_config.model, state_dict)

     except Exception as e:
         errors.display(e, "creating model quickly", full_traceback=True)

@@ -758,7 +826,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

         with sd_disable_initialization.InitializeOnMeta():
-            sd_model = instantiate_from_config(sd_config.model)
+            sd_model = instantiate_from_config(sd_config.model, state_dict)

     sd_model.used_config = checkpoint_config

@@ -775,6 +843,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
     timer.record("load weights from state dict")
+
     send_model_to_device(sd_model)
......
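The local `instantiate_from_config` replaces the one previously imported from `ldm.util`; the new twist is the `state_dict: null` convention, which lets a config like sd3-inference.yaml declare that the constructor wants the checkpoint's tensors (`SD3Inferencer` builds its MMDiT directly from them). A hedged sketch with a toy target class:

```python
# Toy stand-in for SD3Inferencer; assumes this code lives in a module named "mymod"
# so that get_obj_from_str can import it. The real call site passes sd_config.model.
class Toy:
    def __init__(self, shift, state_dict):
        self.shift = shift
        self.state_dict = state_dict

config = {"target": "mymod.Toy", "params": {"shift": 3, "state_dict": None}}

model = instantiate_from_config(config, state_dict={"w": 1})
assert model.state_dict == {"w": 1}  # the null placeholder was filled from the checkpoint
```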
@@ -23,6 +23,8 @@ config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml")
 config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml")
 config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml")
 config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml")
+config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml")
+

 def is_using_v_parameterization_for_sd2(state_dict):
     """

@@ -71,11 +73,15 @@ def guess_model_config_from_state_dict(sd, filename):
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

+    if "model.diffusion_model.x_embedder.proj.weight" in sd:
+        return config_sd3
+
     if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
         if diffusion_model_input.shape[1] == 9:
             return config_sdxl_inpainting
         else:
             return config_sdxl
+
     if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
         return config_sdxl_refiner
     elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:

@@ -99,7 +105,6 @@ def guess_model_config_from_state_dict(sd, filename):
     if diffusion_model_input.shape[1] == 8:
         return config_instruct_pix2pix
-
     if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:
         if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:
             return config_alt_diffusion_m18
......
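Config guessing remains pure key probing: no tensors are loaded, only names are inspected, and the MMDiT patch-embedder key identifies SD3 before any SDXL checks run. A hedged sketch of probing a checkpoint's keys cheaply (the filename is a placeholder):

```python
# Reads only the safetensors header, not the tensors themselves.
import safetensors

with safetensors.safe_open("sd3_medium.safetensors", framework="pt") as file:
    keys = set(file.keys())

is_sd3 = "model.diffusion_model.x_embedder.proj.weight" in keys  # MMDiT patch embedder
```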
@@ -32,3 +32,9 @@ class WebuiSdModel(LatentDiffusion):

     is_sd1: bool
     """True if the model's architecture is SD 1.x"""
+
+    is_sd3: bool
+    """True if the model's architecture is SD 3"""
+
+    latent_channels: int
+    """number of layer in latent image representation; will be 16 in SD3 and 4 in other version"""
@@ -54,7 +54,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     else:
         if model is None:
             model = shared.sd_model
-        with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
+        with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
             x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))

     return x_sample

@@ -163,7 +163,7 @@ def apply_refiner(cfg_denoiser, sigma=None):
     else:
         # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
         try:
-            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
+            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))
         except AttributeError:  # for samplers that don't use sigmas (DDIM) sigma is actually the timestep
             timestep = torch.max(sigma).to(dtype=int)

     completed_ratio = (999 - timestep) / 1000

@@ -246,7 +246,7 @@ class Sampler:
         self.eta_infotext_field = 'Eta'
         self.eta_default = 1.0

-        self.conditioning_key = shared.sd_model.model.conditioning_key
+        self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')

         self.p = None
         self.model_wrap_cfg = None
......
@@ -53,8 +53,13 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):
     @property
     def inner_model(self):
         if self.model_wrap is None:
-            denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
-            self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)
+            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)
+
+            if denoiser_constructor is not None:
+                self.model_wrap = denoiser_constructor()
+            else:
+                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)

         return self.model_wrap

@@ -120,7 +125,7 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         if discard_next_to_last_sigma:
             sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

-        return sigmas
+        return sigmas.cpu()

     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)

@@ -128,7 +133,10 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         sigmas = self.get_sigmas(p, steps)
         sigma_sched = sigmas[steps - t_enc - 1:]

-        xi = x + noise * sigma_sched[0]
+        if hasattr(shared.sd_model, 'add_noise_to_latent'):
+            xi = shared.sd_model.add_noise_to_latent(x, noise, sigma_sched[0])
+        else:
+            xi = x + noise * sigma_sched[0]

         if opts.img2img_extra_noise > 0:
             p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise
......
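The `add_noise_to_latent` hook exists because SD3 is a rectified-flow model: img2img noising interpolates between image and noise, x * (1 - t) + noise * t, whereas the k-diffusion default adds sigma-scaled noise on top of the image at full strength. A hedged numeric comparison:

```python
# Toy tensors; `amount` plays the role of sigma_sched[0] for the chosen denoising strength.
import torch

x = torch.ones(4)
noise = torch.zeros(4)
amount = 0.75

eps_style = x + noise * amount                   # variance-exploding: x keeps full scale
flow_style = x * (1 - amount) + noise * amount   # rectified flow: blends toward noise

print(eps_style)   # tensor([1., 1., 1., 1.])
print(flow_style)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```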
@@ -8,9 +8,9 @@ sd_vae_approx_models = {}

 class VAEApprox(nn.Module):
-    def __init__(self):
+    def __init__(self, latent_channels=4):
         super(VAEApprox, self).__init__()
-        self.conv1 = nn.Conv2d(4, 8, (7, 7))
+        self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))
         self.conv2 = nn.Conv2d(8, 16, (5, 5))
         self.conv3 = nn.Conv2d(16, 32, (3, 3))
         self.conv4 = nn.Conv2d(32, 64, (3, 3))

@@ -40,7 +40,13 @@ def download_model(model_path, model_url):

 def model():
-    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    if shared.sd_model.is_sd3:
+        model_name = "vaeapprox-sd3.pt"
+    elif shared.sd_model.is_sdxl:
+        model_name = "vaeapprox-sdxl.pt"
+    else:
+        model_name = "model.pt"
+
     loaded_model = sd_vae_approx_models.get(model_name)

     if loaded_model is None:

@@ -52,7 +58,7 @@ def model():
             model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
             download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)

-        loaded_model = VAEApprox()
+        loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)
         loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
         loaded_model.eval()
         loaded_model.to(devices.device, devices.dtype)

@@ -64,7 +70,18 @@ def model():

 def cheap_approximation(sample):
     # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

-    if shared.sd_model.is_sdxl:
+    if shared.sd_model.is_sd3:
+        coeffs = [
+            [-0.0645,  0.0177,  0.1052], [ 0.0028,  0.0312,  0.0650],
+            [ 0.1848,  0.0762,  0.0360], [ 0.0944,  0.0360,  0.0889],
+            [ 0.0897,  0.0506, -0.0364], [-0.0020,  0.1203,  0.0284],
+            [ 0.0855,  0.0118,  0.0283], [-0.0539,  0.0658,  0.1047],
+            [-0.0057,  0.0116,  0.0700], [-0.0412,  0.0281, -0.0039],
+            [ 0.1106,  0.1171,  0.1220], [-0.0248,  0.0682, -0.0481],
+            [ 0.0815,  0.0846,  0.1207], [-0.0120, -0.0055, -0.0867],
+            [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],
+        ]
+    elif shared.sd_model.is_sdxl:
         coeffs = [
             [ 0.3448,  0.4168,  0.4395],
             [-0.1953, -0.0290,  0.0250],
......
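Each row of the SD3 table maps one of the 16 latent channels to an (R, G, B) contribution, so the preview is a single per-pixel matrix multiply. A hedged sketch of how such a table is applied (the real `cheap_approximation` uses an equivalent einsum; the coefficients below are toys, not the table above):

```python
import torch

def approx_rgb(sample, coeffs):
    # sample: (latent_channels, h, w); coeffs: (latent_channels, 3) -> (3, h, w)
    coefs = torch.tensor(coeffs).to(sample.device, sample.dtype)
    return torch.einsum("lxy,lr->rxy", sample, coefs)

preview = approx_rgb(torch.randn(16, 64, 64), [[0.1, 0.0, 0.0]] * 16)  # toy coefficients
```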
@@ -34,9 +34,9 @@ class Block(nn.Module):
         return self.fuse(self.conv(x) + self.skip(x))


-def decoder():
+def decoder(latent_channels=4):
     return nn.Sequential(
-        Clamp(), conv(4, 64), nn.ReLU(),
+        Clamp(), conv(latent_channels, 64), nn.ReLU(),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),

@@ -44,13 +44,13 @@ def decoder():
     )


-def encoder():
+def encoder(latent_channels=4):
     return nn.Sequential(
         conv(3, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-        conv(64, 4),
+        conv(64, latent_channels),
     )

@@ -58,10 +58,14 @@ class TAESDDecoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5

-    def __init__(self, decoder_path="taesd_decoder.pth"):
+    def __init__(self, decoder_path="taesd_decoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.decoder = decoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(decoder_path) else 4
+
+        self.decoder = decoder(latent_channels)
         self.decoder.load_state_dict(
             torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

@@ -70,10 +74,14 @@ class TAESDEncoder(nn.Module):
     latent_magnitude = 3
     latent_shift = 0.5

-    def __init__(self, encoder_path="taesd_encoder.pth"):
+    def __init__(self, encoder_path="taesd_encoder.pth", latent_channels=None):
         """Initialize pretrained TAESD on the given device from the given checkpoints."""
         super().__init__()
-        self.encoder = encoder()
+
+        if latent_channels is None:
+            latent_channels = 16 if "taesd3" in str(encoder_path) else 4
+
+        self.encoder = encoder(latent_channels)
         self.encoder.load_state_dict(
             torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

@@ -87,7 +95,13 @@ def download_model(model_path, model_url):

 def decoder_model():
-    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_decoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_decoder.pth"
+    else:
+        model_name = "taesd_decoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)

     if loaded_model is None:

@@ -106,7 +120,13 @@ def decoder_model():

 def encoder_model():
-    model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth"
+    if shared.sd_model.is_sd3:
+        model_name = "taesd3_encoder.pth"
+    elif shared.sd_model.is_sdxl:
+        model_name = "taesdxl_encoder.pth"
+    else:
+        model_name = "taesd_encoder.pth"
+
     loaded_model = sd_vae_taesd_models.get(model_name)

     if loaded_model is None:
......
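With the channel plumbing above, the same tiny TAESD architecture serves SD3 by widening only its first convolution; the spatial upscale stays 8x (three Upsample stages). A hedged usage sketch, assuming the decoder's final layer maps to 3 RGB channels as in upstream TAESD:

```python
# Shapes follow the decoder() definition above; the module here is untrained.
import torch

dec = decoder(latent_channels=16)     # SD3-sized TAESD decoder
latent = torch.randn(1, 16, 64, 64)   # (batch, channels, h, w)

with torch.no_grad():
    rgb = dec(latent)                 # -> (1, 3, 512, 512)
```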
@@ -191,6 +191,10 @@ options_templates.update(options_section(('sdxl', "Stable Diffusion XL", "sd"), {
     "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
 }))

+options_templates.update(options_section(('sd3', "Stable Diffusion 3", "sd"), {
+    "sd3_enable_t5": OptionInfo(False, "Enable T5").info("load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply"),
+}))
+
 options_templates.update(options_section(('vae', "VAE", "sd"), {
     "sd_vae_explanation": OptionHTML("""
 <abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>
......
@@ -18,6 +18,7 @@ numpy==1.26.2
 omegaconf==2.2.3
 open-clip-torch==2.20.0
 piexif==1.1.3
+protobuf==3.20.0
 psutil==5.9.5
 pytorch_lightning==1.9.4
 resize-right==0.0.2
......