Commit da464a3f authored by AUTOMATIC1111

SDXL support

parent af081211
@@ -224,6 +224,20 @@ def run_extensions_installers(settings_file):
         run_extension_installer(os.path.join(extensions_dir, dirname_extension))
 
 
+def mute_sdxl_imports():
+    """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+    import importlib
+
+    module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None))
+    module.LPIPS = None
+    sys.modules['taming.modules.losses.lpips'] = module
+
+    module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None))
+    module.StableDataModuleFromConfig = None
+    sys.modules['sgm.data'] = module
+
+
 def prepare_environment():
     torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
     torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
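Editor's note: mute_sdxl_imports works by planting stub modules in sys.modules before the real imports run. A minimal standalone sketch of the same trick (the package name here is made up; note that only the leaf module can be faked this way, so the parent packages of dotted names like taming.modules.losses.lpips must still be importable):

import sys
import importlib.machinery
import importlib.util

# build an empty module object from a bare spec and give it the one attribute
# that a downstream `from ... import ...` expects to find
stub = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('heavy_dependency', None))
stub.SomeClass = None
sys.modules['heavy_dependency'] = stub  # future imports of this name resolve here first

from heavy_dependency import SomeClass  # no ImportError; SomeClass is None
assert SomeClass is None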
@@ -319,11 +333,14 @@ def prepare_environment():
     if args.update_all_extensions:
         git_pull_recursive(extensions_dir)
 
+    mute_sdxl_imports()
+
     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
         exit(0)
 
 
 def configure_for_tests():
     if "--api" not in sys.argv:
         sys.argv.append("--api")
...
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
         send_me_to_gpu(first_stage_model, None)
         return first_stage_model_decode(z)
 
-    # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
-    if hasattr(sd_model.cond_stage_model, 'model'):
-        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
+    to_remain_in_cpu = [
+        (sd_model, 'first_stage_model'),
+        (sd_model, 'depth_model'),
+        (sd_model, 'embedder'),
+        (sd_model, 'model'),
+        (sd_model, 'embedder'),
+    ]
 
-    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
-    # send the model to GPU. Then put modules back. the modules will be in CPU.
-    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+    is_sdxl = hasattr(sd_model, 'conditioner')
+    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+    if is_sdxl:
+        to_remain_in_cpu.append((sd_model, 'conditioner'))
+    elif is_sd2:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+    else:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
+    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
+    stored = []
+    for obj, field in to_remain_in_cpu:
+        module = getattr(obj, field, None)
+        stored.append(module)
+        setattr(obj, field, None)
+
+    # send the model to GPU.
     sd_model.to(devices.device)
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+    # put modules back. the modules will be in CPU.
+    for (obj, field), module in zip(to_remain_in_cpu, stored):
+        setattr(obj, field, module)
 
     # register hooks for the first three models
-    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+    if is_sdxl:
+        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+    elif is_sd2:
+        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+    else:
+        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
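Editor's note: the rewritten setup_for_low_vram relies on torch's forward pre-hooks: each big submodule stays in CPU RAM and is moved to the GPU only when it is about to run (send_me_to_gpu, defined earlier in this file, also evicts the previously used module). A self-contained sketch of the pattern with a toy layer; the names here are illustrative, not webui's:

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def send_me_to_gpu(module, _inputs):
    # runs right before module.forward; a fuller version would also move
    # the previously active module back to the CPU to cap VRAM usage
    module.to(device)

layer = nn.Linear(8, 8)                        # weights stay on the CPU for now
layer.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 8, device=device)
y = layer(x)                                   # hook fires, weights move, forward runs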
@@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram):
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
 
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
-    if hasattr(sd_model.cond_stage_model, 'model'):
-        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
-        del sd_model.cond_stage_model.transformer
-
     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
     else:
...
@@ -20,7 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
-    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []),
+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -36,6 +36,13 @@ for d, must_exist, what, options in path_dirs:
         d = os.path.abspath(d)
         if "atstart" in options:
             sys.path.insert(0, d)
+        elif "sgm" in options:
+            # Stable Diffusion XL repo has a scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
+            sys.path.insert(0, d)
+
+            import sgm
+
+            sys.path.pop(0)
         else:
             sys.path.append(d)
         paths[what] = d
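Editor's note: the "sgm" option imports the package while its repo sits at the head of sys.path and then pops that entry immediately; the module stays cached in sys.modules, but the repo's top-level scripts/ package can no longer shadow webui's own scripts package. The generic shape of the trick, with a hypothetical path and package name:

import sys

sys.path.insert(0, '/path/to/vendored-repo')   # make the repo importable
import vendored_package                        # now cached in sys.modules
sys.path.pop(0)                                # later `import scripts.x` can't hit the repo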
@@ -343,10 +343,13 @@ class StableDiffusionProcessing:
         return cache[1]
 
     def setup_conds(self):
+        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height)
+
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
         self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
-        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
-        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
 
     def parse_extra_network_prompts(self):
         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
...
+from __future__ import annotations
+
 import re
 from collections import namedtuple
 from typing import List
@@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
 
 
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+    """
+    A list with prompts for stable diffusion's conditioner model.
+    Can also specify width and height of created image - SDXL needs it.
+    """
+    def __init__(self, prompts, width=None, height=None):
+        super().__init__()
+        self.extend(prompts)
+        self.width = width or getattr(prompts, 'width', None)
+        self.height = height or getattr(prompts, 'height', None)
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
     """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
     and the sampling step at which this condition is to be replaced by the next one.
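Editor's note: SdConditioning is how the commit threads image dimensions through to the SDXL conditioner without changing every call site: it subclasses list, so existing code that iterates prompts keeps working, and setup_conds above wraps prompts in it. A quick usage check, assuming the class definition above:

prompts = SdConditioning(["a photo of a cat"], width=1024, height=1024)
assert list(prompts) == ["a photo of a cat"]

rewrapped = SdConditioning(prompts)   # no explicit width/height given
assert rewrapped.width == 1024        # picked up from the source object via getattr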
@@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps):
 re_AND = re.compile(r"\bAND\b")
 re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
 
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
     res_indexes = []
 
-    prompt_flat_list = []
     prompt_indexes = {}
+    prompt_flat_list = SdConditioning(prompts)
+    prompt_flat_list.clear()
 
     for prompt in prompts:
         subprompts = re_AND.split(prompt)
@@ -201,6 +217,7 @@ class MulticondLearnedConditioning:
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
         self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
 
+
 def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
     """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
     For each prompt, the list is obtained by splitting the prompt using the AND separator.
...
@@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
 import ldm.modules.encoders.modules
 
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
 
+    sgm.modules.diffusionmodules.model.nonlinearity = silu
+    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
     if current_optimizer is not None:
         current_optimizer.undo()
         current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations():
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
+    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
 
 def fix_checkpoint():
     """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -170,10 +182,19 @@ class StableDiffusionModelHijack:
         if conditioner:
             for i in range(len(conditioner.embedders)):
                 embedder = conditioner.embedders[i]
-                if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder':
+                typename = type(embedder).__name__
+                if typename == 'FrozenOpenCLIPEmbedder':
                     embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                     m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                     conditioner.embedders[i] = m.cond_stage_model
+                if typename == 'FrozenCLIPEmbedder':
+                    model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+                    m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self)
+                    conditioner.embedders[i] = m.cond_stage_model
+                if typename == 'FrozenOpenCLIPEmbedder2':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
 
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
...
@@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
         self.chunk_length = 75
 
+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.legacy_ucg_val = None
+
     def empty_chunk(self):
         """creates an empty PromptChunk and returns it"""
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         """
         Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
         Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
-        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
         An example shape returned by this function can be: (2, 77, 768).
+        For SDXL, instead of returning one tensor as above, it returns a tuple of two: the second one with shape (B, 1280) holds the pooled values.
         Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         """
@@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
             self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
 
-        return torch.hstack(zs)
+        if getattr(self.wrapped, 'return_pooled', False):
+            return torch.hstack(zs), zs[0].pooled
+        else:
+            return torch.hstack(zs)
 
     def process_tokens(self, remade_batch_tokens, batch_multipliers):
         """
@@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
         batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
         original_mean = z.mean()
-        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
         new_mean = z.mean()
-        z = z * (original_mean / new_mean)
+        z *= (original_mean / new_mean)
 
         return z
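Editor's note: switching to in-place `*=` avoids allocating a second copy of z; the mean-restoration logic itself is unchanged. A worked numeric check:

import torch

z = torch.ones(1, 2, 4)
multipliers = torch.tensor([[2.0, 0.5]])

original_mean = z.mean()                            # 1.0
z *= multipliers.reshape(1, 2, 1).expand(z.shape)
new_mean = z.mean()                                 # (4*2.0 + 4*0.5) / 8 = 1.25
z *= original_mean / new_mean                       # overall mean is 1.0 again
assert torch.isclose(z.mean(), torch.tensor(1.0))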
...
@@ -16,10 +16,6 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
         self.id_end = tokenizer.encoder["<end_of_text>"]
         self.id_pad = 0
 
-        self.is_trainable = getattr(wrapped, 'is_trainable', False)
-        self.input_key = getattr(wrapped, 'input_key', 'txt')
-        self.legacy_ucg_val = None
-
     def tokenize(self, texts):
         assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
@@ -39,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
         embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
 
         return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+        self.id_pad = 0
+
+    def tokenize(self, texts):
+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+        tokenized = [tokenizer.encode(text) for text in texts]
+
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        d = self.wrapped.encode_with_transformer(tokens)
+        z = d[self.wrapped.layer]
+
+        pooled = d.get("pooled")
+        if pooled is not None:
+            z.pooled = pooled
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        ids = tokenizer.encode(init_text)
+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+
+        return embedded
(One file's diff is collapsed here and not shown.)
@@ -411,6 +411,7 @@ def repair_config(sd_config):
 
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
 
 
 class SdModelData:
@@ -445,6 +446,15 @@ class SdModelData:
 model_data = SdModelData()
 
 
+def get_empty_cond(sd_model):
+    if hasattr(sd_model, 'conditioner'):
+        d = sd_model.get_learned_conditioning([""])
+        return d['crossattn']
+    else:
+        return sd_model.cond_stage_model([""])
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import lowvram, sd_hijack
 
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict
 
     timer.record("find config")
@@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("scripts callbacks")
 
     with devices.autocast(), torch.no_grad():
-        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
 
     timer.record("calculate empty prompt")
...
@@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
-config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -70,7 +71,9 @@ def guess_model_config_from_state_dict(sd, filename):
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
 
-    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+        return config_sdxl
+    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
         return config_depth_model
     elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
         return config_unclip
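Editor's note: model-family detection here is pure state-dict key sniffing; in this heuristic, the presence of a second text encoder under conditioner.embedders.1 marks an SDXL checkpoint. A toy illustration with minimal stand-in dicts (keys taken from the code above, values irrelevant):

sdxl_like = {'conditioner.embedders.1.model.ln_final.weight': object()}
sd1_like = {'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight': object()}

def looks_like_sdxl(sd):
    return sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None

assert looks_like_sdxl(sdxl_like)
assert not looks_like_sdxl(sd1_like)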
...
 from __future__ import annotations
-import sys
 
 import torch
 
 import sgm.models.diffusion
 import sgm.modules.diffusionmodules.denoiser_scaling
 import sgm.modules.diffusionmodules.discretizer
 
-from modules import devices
+from modules import devices, shared, prompt_parser
 
 
-def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]):
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
     for embedder in self.conditioner.embedders:
         embedder.ucg_rate = 0.0
 
-    c = self.conditioner({'txt': batch})
+    width = getattr(self, 'target_width', 1024)
+    height = getattr(self, 'target_height', 1024)
+
+    sdxl_conds = {
+        "txt": batch,
+        "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+        "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+    }
+
+    c = self.conditioner(sdxl_conds)
 
     return c
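Editor's note: SDXL's conditioner consumes size/crop metadata alongside the text, and each entry must be batched to match the prompt count, hence the repeat(len(batch), 1). A standalone shape check of that construction:

import torch

batch = ["prompt one", "prompt two"]
width, height = 1024, 1024

original_size = torch.tensor([height, width]).repeat(len(batch), 1)
assert original_size.shape == (2, 2)    # one (H, W) row per prompt
assert original_size.tolist() == [[1024, 1024], [1024, 1024]]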
@@ -26,7 +38,7 @@ def extend_sdxl(model):
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
 
-    model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0]
+    model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0]
     model.cond_stage_key = model.cond_stage_model.input_key
 
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
@@ -34,7 +46,14 @@ def extend_sdxl(model):
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
 
+    model.is_xl = True
+
 
 sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
 sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+
+sgm.modules.attention.print = lambda *args: None
+sgm.modules.diffusionmodules.model.print = lambda *args: None
+sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
+sgm.modules.encoders.modules.print = lambda *args: None
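Editor's note: the four assignments at the end silence sgm's debug output by shadowing the print builtin at module level; any unqualified print() inside those modules then resolves to the no-op lambda via the module's globals before reaching builtins. A generic demonstration with a synthetic module:

import types

mod = types.ModuleType('chatty')
exec("def talk():\n    print('hello')", mod.__dict__)

mod.talk()                       # prints: hello
mod.print = lambda *args: None   # shadows the builtin inside 'chatty'
mod.talk()                       # silent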
@@ -186,7 +186,7 @@ class CFGDenoiser(torch.nn.Module):
             for batch_offset in range(0, x_out.shape[0], batch_size):
                 a = batch_offset
                 b = a + batch_size
-                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b]))
+                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
         else:
             x_out = torch.zeros_like(x_in)
             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
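Editor's note: cond_in used to be a plain tensor, so cond_in[a:b] was enough; with SDXL the conditioning is a dict of tensors (crossattn plus pooled/size vectors), so each value has to be sliced separately. subscript_cond itself is not shown in this excerpt; a plausible implementation consistent with the call site:

import torch

def subscript_cond(cond, a, b):
    # slice a plain tensor directly, or every tensor in a conditioning dict
    if isinstance(cond, dict):
        return {key: vec[a:b] for key, vec in cond.items()}
    return cond[a:b]

cond = {'crossattn': torch.randn(4, 77, 2048), 'vector': torch.randn(4, 2816)}
sub = subscript_cond(cond, 0, 2)
assert sub['crossattn'].shape[0] == 2 and sub['vector'].shape[0] == 2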
...
@@ -428,6 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+    "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"),
+    "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"),
 }))
 
 options_templates.update(options_section(('optimizations', "Optimizations"), {
...
@@ -15,6 +15,7 @@ kornia==0.6.7
 lark==1.1.2
 numpy==1.23.5
 omegaconf==2.2.3
+open-clip-torch==2.20.0
 piexif==1.1.3
 psutil~=5.9.5
 pytorch_lightning==1.9.4
...