Commit 0198eaec authored by AUTOMATIC1111, committed by GitHub

Merge pull request #11757 from AUTOMATIC1111/sdxl

SD XL support
parents 9d3dd64f 14cf434b
@@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
             return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

+    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
+        if 'mlp_fc1' in m[1]:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+        elif 'mlp_fc2' in m[1]:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+        else:
+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
     return key
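As a hedged illustration of the mapping the new branch performs, here is a standalone sketch (not the module's actual helper; the regex plumbing of convert_diffusers_name_to_compvis is inlined for the example):

import re

def convert_te2_key(key):
    # Mirrors the new lora_te2 branch: second-text-encoder keys map onto
    # "1_model_transformer_resblocks_..." with OpenCLIP-style mlp names.
    m = re.match(r"lora_te2_text_model_encoder_layers_(\d+)_(.+)", key)
    if not m:
        return key
    layer, rest = m.group(1), m.group(2)
    if 'mlp_fc1' in rest:
        return f"1_model_transformer_resblocks_{layer}_{rest.replace('mlp_fc1', 'mlp_c_fc')}"
    elif 'mlp_fc2' in rest:
        return f"1_model_transformer_resblocks_{layer}_{rest.replace('mlp_fc2', 'mlp_c_proj')}"
    return f"1_model_transformer_resblocks_{layer}_{rest.replace('self_attn', 'attn')}"

print(convert_te2_key("lora_te2_text_model_encoder_layers_0_mlp_fc1"))
# -> 1_model_transformer_resblocks_0_mlp_c_fc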
@@ -147,10 +155,20 @@ class LoraUpDownModule:
 def assign_lora_names_to_compvis_modules(sd_model):
     lora_layer_mapping = {}

-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        lora_name = name.replace(".", "_")
-        lora_layer_mapping[lora_name] = module
-        module.lora_layer_name = lora_name
+    if shared.sd_model.is_sdxl:
+        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
+            if not hasattr(embedder, 'wrapped'):
+                continue
+
+            for name, module in embedder.wrapped.named_modules():
+                lora_name = f'{i}_{name.replace(".", "_")}'
+                lora_layer_mapping[lora_name] = module
+                module.lora_layer_name = lora_name
+    else:
+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
+            lora_name = name.replace(".", "_")
+            lora_layer_mapping[lora_name] = module
+            module.lora_layer_name = lora_name

     for name, module in shared.sd_model.model.named_modules():
         lora_name = name.replace(".", "_")
@@ -173,10 +191,10 @@ def load_lora(name, lora_on_disk):
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

-    for key_diffusers, weight in sd.items():
-        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
+    for key_lora, weight in sd.items():
+        key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
+
+        key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
         sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

         if sd_module is None:
@@ -184,8 +202,16 @@ def load_lora(name, lora_on_disk):
             if m:
                 sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

+        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
+        if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
+            key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
+            key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+
         if sd_module is None:
-            keys_failed_to_match[key_diffusers] = key
+            keys_failed_to_match[key_lora] = key
             continue

         lora_module = lora.modules.get(key, None)
@@ -208,9 +234,9 @@ def load_lora(name, lora_on_disk):
         elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
-            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
+            print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
             continue
-            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
+            raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")

         with torch.no_grad():
             module.weight.copy_(weight)
@@ -222,7 +248,7 @@ def load_lora(name, lora_on_disk):
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
+            raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")

     if keys_failed_to_match:
         print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
...
@@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
     return context_k, context_v

-def attention_CrossAttention_forward(self, x, context=None, mask=None):
+def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
     h = self.heads

     q = self.to_q(x)
...
@@ -237,11 +237,13 @@ def prepare_environment():
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -299,6 +301,7 @@ def prepare_environment():
     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+    git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
     git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
     git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
@@ -323,6 +326,7 @@ def prepare_environment():
         exit(0)


 def configure_for_tests():
     if "--api" not in sys.argv:
         sys.argv.append("--api")
...
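The new STABLE_DIFFUSION_XL_REPO / STABLE_DIFFUSION_XL_COMMIT_HASH variables follow the same override pattern as the existing repos: an environment variable wins over the hard-coded default. A minimal sketch of that pattern in isolation (the print stands in for the git_clone call shown above; nothing here is part of the webui API):

import os

# Pin a fork or a different commit of the SDXL generative-models repo;
# when unset, the defaults baked into prepare_environment() are used.
os.environ.setdefault('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
os.environ.setdefault('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")

repo = os.environ.get('STABLE_DIFFUSION_XL_REPO')
commit = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH')
print(f"would clone {repo} at {commit}")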
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
         send_me_to_gpu(first_stage_model, None)
         return first_stage_model_decode(z)

-    # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
-    if hasattr(sd_model.cond_stage_model, 'model'):
-        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
-
-    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
-    # send the model to GPU. Then put modules back. the modules will be in CPU.
-    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+    to_remain_in_cpu = [
+        (sd_model, 'first_stage_model'),
+        (sd_model, 'depth_model'),
+        (sd_model, 'embedder'),
+        (sd_model, 'model'),
+        (sd_model, 'embedder'),
+    ]
+
+    is_sdxl = hasattr(sd_model, 'conditioner')
+    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+    if is_sdxl:
+        to_remain_in_cpu.append((sd_model, 'conditioner'))
+    elif is_sd2:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+    else:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
+    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
+    stored = []
+    for obj, field in to_remain_in_cpu:
+        module = getattr(obj, field, None)
+        stored.append(module)
+        setattr(obj, field, None)
+
+    # send the model to GPU.
     sd_model.to(devices.device)
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+    # put modules back. the modules will be in CPU.
+    for (obj, field), module in zip(to_remain_in_cpu, stored):
+        setattr(obj, field, module)

     # register hooks for those the first three models
-    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+    if is_sdxl:
+        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+    elif is_sd2:
+        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+    else:
+        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
@@ -73,11 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram):
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
     if sd_model.embedder:
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
-    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

-    if hasattr(sd_model.cond_stage_model, 'model'):
-        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
-        del sd_model.cond_stage_model.transformer
+    if hasattr(sd_model, 'cond_stage_model'):
+        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
...
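A self-contained toy of the pattern the rewritten setup uses: detach selected submodules, move the rest of the model to the GPU device, reattach the submodules (so they keep their CPU weights), then rely on a forward pre-hook to pull a submodule onto the GPU only when it is about to run. The module names here are illustrative stand-ins, not the webui's:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.cond = nn.Linear(4, 4)   # stand-in for the text encoder
        self.unet = nn.Linear(4, 4)   # stand-in for the diffusion model

model = Toy()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

to_remain_in_cpu = [(model, 'unet')]

# detach, move what's left, reattach; reattached modules keep their CPU weights
stored = []
for obj, field in to_remain_in_cpu:
    stored.append(getattr(obj, field, None))
    setattr(obj, field, None)

model.to(device)

for (obj, field), module in zip(to_remain_in_cpu, stored):
    setattr(obj, field, module)

# move a module to the GPU just before it runs
def send_me_to_gpu(module, _inputs):
    module.to(device)

model.unet.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 4)
print(model.unet(x.to(device)).shape)  # torch.Size([1, 4])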
@@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio
 import modules.safe  # noqa: F401


+def mute_sdxl_imports():
+    """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+    class Dummy:
+        pass
+
+    module = Dummy()
+    module.LPIPS = None
+    sys.modules['taming.modules.losses.lpips'] = module
+
+    module = Dummy()
+    module.StableDataModuleFromConfig = None
+    sys.modules['sgm.data'] = module
+
+
 # data_path = cmd_opts_pre.data
 sys.path.insert(0, script_path)
@@ -18,8 +33,11 @@ for possible_sd_path in possible_sd_paths:
 assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"

+mute_sdxl_imports()
+
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,6 +53,13 @@ for d, must_exist, what, options in path_dirs:
         d = os.path.abspath(d)
         if "atstart" in options:
             sys.path.insert(0, d)
+        elif "sgm" in options:
+            # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
+            sys.path.insert(0, d)
+            import sgm  # noqa: F401
+            sys.path.pop(0)
         else:
             sys.path.append(d)
         paths[what] = d
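The mute_sdxl_imports trick is just pre-seeding sys.modules so a later import resolves to a stub instead of pulling in a heavy optional dependency. A minimal standalone sketch of the same idea (the stubbed module name is an arbitrary example, not one the webui uses):

import sys
import types

# Pre-register a stub so `import some_heavy_optional_dep` succeeds without the real package.
stub = types.ModuleType('some_heavy_optional_dep')
stub.LPIPS = None  # attribute the importing code expects to find
sys.modules['some_heavy_optional_dep'] = stub

import some_heavy_optional_dep  # resolves to the stub already in sys.modules
print(some_heavy_optional_dep.LPIPS)  # -> None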
@@ -330,8 +330,21 @@ class StableDiffusionProcessing:
         caches is a list with items described above.
         """

+        cached_params = (
+            required_prompts,
+            steps,
+            opts.CLIP_stop_at_last_layers,
+            shared.sd_model.sd_checkpoint_info,
+            extra_network_data,
+            opts.sdxl_crop_left,
+            opts.sdxl_crop_top,
+            self.width,
+            self.height,
+        )
+
         for cache in caches:
-            if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+            if cache[0] is not None and cached_params == cache[0]:
                 return cache[1]

         cache = caches[0]
@@ -339,14 +352,17 @@ class StableDiffusionProcessing:
         with devices.autocast():
             cache[1] = function(shared.sd_model, required_prompts, steps)

-        cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
+        cache[0] = cached_params
         return cache[1]

     def setup_conds(self):
+        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
         self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
-        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
-        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

     def parse_extra_network_prompts(self):
         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -523,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
 def decode_first_stage(model, x):
-    with devices.autocast(disable=x.dtype == devices.dtype_vae):
-        x = model.decode_first_stage(x)
+    x = model.decode_first_stage(x.to(devices.dtype_vae))

     return x
...
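The caching change simply widens the tuple that identifies a cached conditioning result, so anything that now influences the cond (SDXL crop offsets, target width and height) invalidates it. A generic sketch of that cache-by-parameter-tuple pattern, with names chosen only for illustration:

def get_with_caching(function, params, cache):
    # cache is a 2-item list: [params_tuple_or_None, cached_result]
    if cache[0] is not None and cache[0] == params:
        return cache[1]

    cache[1] = function(*params)
    cache[0] = params
    return cache[1]

cache = [None, None]
print(get_with_caching(lambda w, h: w * h, (1024, 1024), cache))  # computed
print(get_with_caching(lambda w, h: w * h, (1024, 1024), cache))  # served from cache
print(get_with_caching(lambda w, h: w * h, (832, 1216), cache))   # recomputed: params changed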
+from __future__ import annotations
+
 import re
 from collections import namedtuple
 from typing import List
@@ -109,7 +111,25 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])


-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+    """
+    A list with prompts for stable diffusion's conditioner model.
+    Can also specify width and height of created image - SDXL needs it.
+    """
+    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
+        super().__init__()
+        self.extend(prompts)
+
+        if copy_from is None:
+            copy_from = prompts
+
+        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
+        self.width = width or getattr(copy_from, 'width', None)
+        self.height = height or getattr(copy_from, 'height', None)
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
     """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the condition (cond),
     and the sampling step at which this condition is to be replaced by the next one.
@@ -139,12 +159,17 @@ def get_learned_conditioning(model, prompts, steps):
             res.append(cached)
             continue

-        texts = [x[1] for x in prompt_schedule]
+        texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
         conds = model.get_learned_conditioning(texts)

         cond_schedule = []
         for i, (end_at_step, _) in enumerate(prompt_schedule):
-            cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
+            if isinstance(conds, dict):
+                cond = {k: v[i] for k, v in conds.items()}
+            else:
+                cond = conds[i]
+
+            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))

         cache[prompt] = cond_schedule
         res.append(cond_schedule)
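SdConditioning is just a list subclass that carries width, height, and negative-prompt metadata along with the prompt strings, so SDXL's conditioner can see them. A small hedged usage sketch; it assumes modules.prompt_parser is importable as in the webui tree, otherwise the class definition shown above can be pasted in directly:

from modules import prompt_parser  # assumption: running inside the webui source tree

prompts = prompt_parser.SdConditioning(["a photo of a cat"], width=1024, height=1024)
negative = prompt_parser.SdConditioning([""], is_negative_prompt=True, copy_from=prompts)

print(list(prompts), prompts.width, prompts.height)  # ['a photo of a cat'] 1024 1024
print(negative.is_negative_prompt, negative.width)   # True 1024 (copied from `prompts`)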
@@ -155,11 +180,13 @@ def get_learned_conditioning(model, prompts, steps):
 re_AND = re.compile(r"\bAND\b")
 re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")

-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
     res_indexes = []

-    prompt_flat_list = []
     prompt_indexes = {}
+    prompt_flat_list = SdConditioning(prompts)
+    prompt_flat_list.clear()

     for prompt in prompts:
         subprompts = re_AND.split(prompt)
@@ -196,6 +223,7 @@ class MulticondLearnedConditioning:
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS
         self.batch: List[List[ComposableScheduledPromptConditioning]] = batch


 def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
     """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
     For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -214,20 +242,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)


+class DictWithShape(dict):
+    def __init__(self, x, shape):
+        super().__init__()
+        self.update(x)
+
+    @property
+    def shape(self):
+        return self["crossattn"].shape
+
+
 def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
-    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    is_dict = isinstance(param, dict)
+
+    if is_dict:
+        dict_cond = param
+        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+    else:
+        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
     for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, entry in enumerate(cond_schedule):
             if current_step <= entry.end_at_step:
                 target_index = current
                 break
-        res[i] = cond_schedule[target_index].cond
+
+        if is_dict:
+            for k, param in cond_schedule[target_index].cond.items():
+                res[k][i] = param
+        else:
+            res[i] = cond_schedule[target_index].cond

     return res


+def stack_conds(tensors):
+    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+    # and won't be able to torch.stack them. So this fixes that.
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+    return torch.stack(tensors)
+
+
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
     param = c.batch[0][0].schedules[0].cond
@@ -249,16 +314,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
         conds_list.append(conds_for_batch)

-    # if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
-    # and won't be able to torch.stack them. So this fixes that.
-    token_count = max([x.shape[0] for x in tensors])
-    for i in range(len(tensors)):
-        if tensors[i].shape[0] != token_count:
-            last_vector = tensors[i][-1:]
-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+    if isinstance(tensors[0], dict):
+        keys = list(tensors[0].keys())
+        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+    else:
+        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)

-    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+    return conds_list, stacked

 re_attention = re.compile(r"""
...
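stack_conds pads shorter conditionings by repeating their last token vector so a batch with mixed prompt lengths can still be stacked. A toy run of the function exactly as shown above, demonstrating the shape behavior:

import torch

def stack_conds(tensors):
    # pad shorter conds with their last vector so all entries share the same token count
    token_count = max([x.shape[0] for x in tensors])
    for i in range(len(tensors)):
        if tensors[i].shape[0] != token_count:
            last_vector = tensors[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
    return torch.stack(tensors)

a = torch.randn(77, 2048)    # one 75-token chunk plus start/end tokens
b = torch.randn(154, 2048)   # two chunks (a longer prompt)
print(stack_conds([a, b]).shape)  # torch.Size([2, 154, 2048])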
@@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
 import ldm.modules.encoders.modules

+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

+    sgm.modules.diffusionmodules.model.nonlinearity = silu
+    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
     if current_optimizer is not None:
         current_optimizer.undo()
         current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations():
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

+    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+

 def fix_checkpoint():
     """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -168,6 +180,32 @@ class StableDiffusionModelHijack:
         undo_optimizations()

     def hijack(self, m):
+        conditioner = getattr(m, 'conditioner', None)
+        if conditioner:
+            text_cond_models = []
+
+            for i in range(len(conditioner.embedders)):
+                embedder = conditioner.embedders[i]
+                typename = type(embedder).__name__
+
+                if typename == 'FrozenOpenCLIPEmbedder':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+                if typename == 'FrozenCLIPEmbedder':
+                    model_embeddings = embedder.transformer.text_model.embeddings
+                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+                if typename == 'FrozenOpenCLIPEmbedder2':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+
+            if len(text_cond_models) == 1:
+                m.cond_stage_model = text_cond_models[0]
+            else:
+                m.cond_stage_model = conditioner
+
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
...
@@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
         self.chunk_length = 75

+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.legacy_ucg_val = None
+
     def empty_chunk(self):
         """creates an empty PromptChunk and returns it"""
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         """
         Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
         Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
-        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
         An example shape returned by this function can be: (2, 77, 768).
+        For SDXL, instead of returning one tensor as above, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
         Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element
         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         """
@@ -242,7 +247,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             if hashes:
                 self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)

-        return torch.hstack(zs)
+        if getattr(self.wrapped, 'return_pooled', False):
+            return torch.hstack(zs), zs[0].pooled
+        else:
+            return torch.hstack(zs)

     def process_tokens(self, remade_batch_tokens, batch_multipliers):
         """
@@ -265,9 +273,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
             batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
             original_mean = z.mean()
-            z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+            z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
             new_mean = z.mean()
-            z = z * (original_mean / new_mean)
+            z *= (original_mean / new_mean)

         return z
@@ -324,3 +332,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
         embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

         return embedded
+
+
+class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+    def encode_with_transformers(self, tokens):
+        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
+
+        if self.wrapped.layer == "last":
+            z = outputs.last_hidden_state
+        else:
+            z = outputs.hidden_states[self.wrapped.layer_idx]
+
+        return z
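For reference, the layer selection in encode_with_transformers matches how HuggingFace's CLIPTextModel exposes intermediate layers. A hedged standalone sketch (the model id is chosen for illustration; the webui wraps the ldm/sgm embedders rather than loading a model this way):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(["a photo of a cat"], padding="max_length", max_length=77, return_tensors="pt")

with torch.no_grad():
    outputs = text_model(input_ids=tokens.input_ids, output_hidden_states=True)

z_last = outputs.last_hidden_state         # layer == "last"
z_penultimate = outputs.hidden_states[-2]  # layer == "hidden" with layer_idx = -2 ("clip skip")
print(z_last.shape, z_penultimate.shape)   # both torch.Size([1, 77, 768])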
@@ -32,6 +32,40 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
     def encode_embedding_init_text(self, init_text, nvpt):
         ids = tokenizer.encode(init_text)
         ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
-        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
+
+        return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+        self.id_pad = 0
+
+    def tokenize(self, texts):
+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+        tokenized = [tokenizer.encode(text) for text in texts]
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        d = self.wrapped.encode_with_transformer(tokens)
+        z = d[self.wrapped.layer]
+
+        pooled = d.get("pooled")
+        if pooled is not None:
+            z.pooled = pooled
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        ids = tokenizer.encode(init_text)
+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)

         return embedded
...
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas

 from ldm.util import instantiate_from_config

-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
 from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 import tomesd
@@ -289,6 +289,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if state_dict is None:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

+    model.is_sdxl = hasattr(model, 'conditioner')
+    if model.is_sdxl:
+        sd_models_xl.extend_sdxl(model)
+
     model.load_state_dict(state_dict, strict=False)
     del state_dict
     timer.record("apply weights to model")
@@ -334,7 +338,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.sd_checkpoint_info = checkpoint_info
     shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

-    model.logvar = model.logvar.to(devices.device)  # fix for training
+    if hasattr(model, 'logvar'):
+        model.logvar = model.logvar.to(devices.device)  # fix for training

     sd_vae.delete_base_vae()
     sd_vae.clear_loaded_vae()
@@ -391,10 +396,11 @@ def repair_config(sd_config):
     if not hasattr(sd_config.model.params, "use_ema"):
         sd_config.model.params.use_ema = False

-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
+    if hasattr(sd_config.model.params, 'unet_config'):
+        if shared.cmd_opts.no_half:
+            sd_config.model.params.unet_config.params.use_fp16 = False
+        elif shared.cmd_opts.upcast_sampling:
+            sd_config.model.params.unet_config.params.use_fp16 = True

     if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
         sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
@@ -407,6 +413,8 @@ def repair_config(sd_config):

 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'


 class SdModelData:
@@ -441,6 +449,15 @@ class SdModelData:
 model_data = SdModelData()


+def get_empty_cond(sd_model):
+    if hasattr(sd_model, 'conditioner'):
+        d = sd_model.get_learned_conditioning([""])
+        return d['crossattn']
+    else:
+        return sd_model.cond_stage_model([""])
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -461,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

     timer.record("find config")
@@ -513,7 +530,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("scripts callbacks")

     with devices.autocast(), torch.no_grad():
-        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

     timer.record("calculate empty prompt")
...
@@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization

 sd_configs_path = shared.sd_configs_path
 sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")


 config_default = shared.sd_default_config
 config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -68,7 +71,11 @@ def guess_model_config_from_state_dict(sd, filename):
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

-    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+        return config_sdxl
+    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+        return config_sdxl_refiner
+    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
         return config_depth_model
     elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
         return config_unclip
...
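Config guessing now keys off which conditioner embedders exist in the state dict: SDXL base has two text encoders (so embedders.1 is present), the refiner only one. A hedged toy version of that probe, using plain lists of key names instead of real checkpoints:

def guess_family(state_dict_keys):
    keys = set(state_dict_keys)
    if 'conditioner.embedders.1.model.ln_final.weight' in keys:
        return 'sdxl-base'
    if 'conditioner.embedders.0.model.ln_final.weight' in keys:
        return 'sdxl-refiner'
    return 'sd1/sd2/other'

print(guess_family(['conditioner.embedders.0.model.ln_final.weight',
                    'conditioner.embedders.1.model.ln_final.weight']))  # sdxl-base
print(guess_family(['conditioner.embedders.0.model.ln_final.weight']))  # sdxl-refiner
print(guess_family(['cond_stage_model.transformer.text_model.embeddings.token_embedding.weight']))  # sd1/sd2/other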
from __future__ import annotations
import torch
import sgm.models.diffusion
import sgm.modules.diffusionmodules.denoiser_scaling
import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
for embedder in self.conditioner.embedders:
embedder.ucg_rate = 0.0
width = getattr(self, 'target_width', 1024)
height = getattr(self, 'target_height', 1024)
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
devices_args = dict(device=devices.device, dtype=devices.dtype)
sdxl_conds = {
"txt": batch,
"original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
"crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
"target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
"aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
}
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
return c
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
return self.model(x, t, cond)
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
return x
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
res = []
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
encoded = embedder.encode_embedding_init_text(init_text, nvpt)
res.append(encoded)
return torch.cat(res, dim=1)
def process_texts(self, texts):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
return embedder.process_texts(texts)
def get_target_prompt_token_count(self, token_count):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
return embedder.get_target_prompt_token_count(token_count)
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
def extend_sdxl(model):
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
dtype = next(model.model.diffusion_model.parameters()).dtype
model.model.diffusion_model.dtype = dtype
model.model.conditioning_key = 'crossattn'
model.cond_stage_key = 'txt'
# model.cond_stage_model will be set in sd_hijack
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
model.conditioner.wrapped = torch.nn.Module()
sgm.modules.attention.print = lambda *args: None
sgm.modules.diffusionmodules.model.print = lambda *args: None
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
sgm.modules.encoders.modules.print = lambda *args: None
# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
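The conditioner receives a batch dict whose non-text entries are per-image tensors; the repeat(len(batch), 1) calls simply tile a single [height, width], crop-offset, or score row across the batch. A small shape check of that construction (the numeric values are illustrative placeholders, not the webui's option defaults):

import torch

batch = ["a photo of a cat", "a photo of a dog"]
width, height = 1024, 1024

original_size = torch.tensor([height, width]).repeat(len(batch), 1)
crop_coords = torch.tensor([0, 0]).repeat(len(batch), 1)          # sdxl_crop_top, sdxl_crop_left
aesthetic_score = torch.tensor([6.0]).repeat(len(batch), 1)       # refiner-only conditioning

print(original_size.shape, crop_coords.shape, aesthetic_score.shape)
# torch.Size([2, 2]) torch.Size([2, 2]) torch.Size([2, 1])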
@@ -28,6 +28,9 @@ def create_sampler(name, model):

     assert config is not None, f'bad sampler name: {name}'

+    if model.is_sdxl and config.options.get("no_sdxl", False):
+        raise Exception(f"Sampler {config.name} is not supported for SDXL")
+
     sampler = config.constructor(model)
     sampler.config = config
...
@@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc

 samplers_data_compvis = [
-    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
-    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
-    sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
+    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
+    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
+    sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
 ]
...
@@ -53,6 +53,28 @@ k_diffusion_scheduler = {
 }


+def catenate_conds(conds):
+    if not isinstance(conds[0], dict):
+        return torch.cat(conds)
+
+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+    if not isinstance(cond, dict):
+        return cond[a:b]
+
+    return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+    if not isinstance(tensor, dict):
+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+    return tensor
+
+
 class CFGDenoiser(torch.nn.Module):
     """
     Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
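These three helpers make the cond/uncond plumbing agnostic to whether a conditioning is a plain tensor (SD1/SD2) or a dict with a 'crossattn' entry (SDXL). A quick standalone check of their behavior on toy tensors, copying the definitions as shown above:

import torch

def catenate_conds(conds):
    if not isinstance(conds[0], dict):
        return torch.cat(conds)
    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}

def subscript_cond(cond, a, b):
    if not isinstance(cond, dict):
        return cond[a:b]
    return {key: vec[a:b] for key, vec in cond.items()}

plain = [torch.randn(2, 77, 768), torch.randn(2, 77, 768)]
dicts = [{"crossattn": torch.randn(2, 77, 2048)}, {"crossattn": torch.randn(2, 77, 2048)}]

print(catenate_conds(plain).shape)                              # torch.Size([4, 77, 768])
print(catenate_conds(dicts)["crossattn"].shape)                 # torch.Size([4, 77, 2048])
print(subscript_cond(catenate_conds(dicts), 0, 2)["crossattn"].shape)  # torch.Size([2, 77, 2048])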
...@@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module): ...@@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module):
if shared.sd_model.model.conditioning_key == "crossattn-adm": if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond) image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
else: else:
image_uncond = image_cond image_uncond = image_cond
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} if isinstance(uncond, dict):
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
else:
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
if not is_edit_model: if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
...@@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module): ...@@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module):
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
if num_repeats < 0: if num_repeats < 0:
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) tensor = pad_cond(tensor, -num_repeats, empty)
self.padded_cond_uncond = True self.padded_cond_uncond = True
elif num_repeats > 0: elif num_repeats > 0:
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True self.padded_cond_uncond = True
if tensor.shape[1] == uncond.shape[1] or skip_uncond: if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model: if is_edit_model:
cond_in = torch.cat([tensor, uncond, uncond]) cond_in = catenate_conds([tensor, uncond, uncond])
elif skip_uncond: elif skip_uncond:
cond_in = tensor cond_in = tensor
else: else:
cond_in = torch.cat([tensor, uncond]) cond_in = catenate_conds([tensor, uncond])
if shared.batch_cond_uncond: if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in)) x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size): for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset a = batch_offset
b = a + batch_size b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b])) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
else: else:
x_out = torch.zeros_like(x_in) x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
...@@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module): ...@@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module):
b = min(a + batch_size, tensor.shape[0]) b = min(a + batch_size, tensor.shape[0])
if not is_edit_model: if not is_edit_model:
c_crossattn = [tensor[a:b]] c_crossattn = subscript_cond(tensor, a, b)
else: else:
c_crossattn = torch.cat([tensor[a:b]], uncond) c_crossattn = torch.cat([tensor[a:b]], uncond)
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
if not skip_uncond: if not skip_uncond:
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:])) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
denoised_image_indexes = [x[0][0] for x in conds_list] denoised_image_indexes = [x[0][0] for x in conds_list]
if skip_uncond: if skip_uncond:
......
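In the rewritten forward() above, conditionings are never combined by touching tensors directly; everything goes through catenate_conds, subscript_cond and pad_cond, so one code path serves SD1/SD2 conditionings (plain tensors) and SDXL conditionings (dicts carrying a cross-attention tensor plus extra vectors). The bodies of those helpers are not shown in this hunk; the following is a minimal sketch, assuming torch and the 'crossattn' key visible in the padding tail above, of how helpers with that behaviour can look:

import torch

def catenate_conds(conds):
    # Plain tensors concatenate along the batch axis; dict conds concatenate per key.
    if not isinstance(conds[0], dict):
        return torch.cat(conds)
    return {key: torch.cat([c[key] for c in conds]) for key in conds[0]}

def subscript_cond(cond, a, b):
    # Slice a batch range out of either representation.
    if not isinstance(cond, dict):
        return cond[a:b]
    return {key: vec[a:b] for key, vec in cond.items()}

def pad_cond(tensor, repeats, empty):
    # Pad the token dimension with copies of the empty-prompt embedding;
    # for dict conds only the 'crossattn' entry has a token dimension to pad.
    if not isinstance(tensor, dict):
        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
    return tensor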
@@ -2,9 +2,9 @@ import os
 import torch
 from torch import nn
-from modules import devices, paths
+from modules import devices, paths, shared


-sd_vae_approx_model = None
+sd_vae_approx_models = {}


 class VAEApprox(nn.Module):
@@ -31,30 +31,55 @@ class VAEApprox(nn.Module):
         return x


+def download_model(model_path, model_url):
+    if not os.path.exists(model_path):
+        os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+        print(f'Downloading VAEApprox model to: {model_path}')
+        torch.hub.download_url_to_file(model_url, model_path)
+
+
 def model():
-    global sd_vae_approx_model
+    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    loaded_model = sd_vae_approx_models.get(model_name)

-    if sd_vae_approx_model is None:
-        model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
-        sd_vae_approx_model = VAEApprox()
-        if not os.path.exists(model_path):
-            model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
-        sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
-        sd_vae_approx_model.eval()
-        sd_vae_approx_model.to(devices.device, devices.dtype)
+    if loaded_model is None:
+        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+        if not os.path.exists(model_path):
+            model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
+
+        if not os.path.exists(model_path):
+            model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+            download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
+
+        loaded_model = VAEApprox()
+        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+        loaded_model.eval()
+        loaded_model.to(devices.device, devices.dtype)
+        sd_vae_approx_models[model_name] = loaded_model

-    return sd_vae_approx_model
+    return loaded_model


 def cheap_approximation(sample):
     # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

-    coefs = torch.tensor([
-        [0.298, 0.207, 0.208],
-        [0.187, 0.286, 0.173],
-        [-0.158, 0.189, 0.264],
-        [-0.184, -0.271, -0.473],
-    ]).to(sample.device)
+    if shared.sd_model.is_sdxl:
+        coeffs = [
+            [ 0.3448,  0.4168,  0.4395],
+            [-0.1953, -0.0290,  0.0250],
+            [ 0.1074,  0.0886, -0.0163],
+            [-0.3730, -0.2499, -0.2088],
+        ]
+    else:
+        coeffs = [
+            [ 0.298,  0.207,  0.208],
+            [ 0.187,  0.286,  0.173],
+            [-0.158,  0.189,  0.264],
+            [-0.184, -0.271, -0.473],
+        ]
+
+    coefs = torch.tensor(coeffs).to(sample.device)

     x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
...
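cheap_approximation() builds a rough RGB preview with one 4x3 matrix multiply per latent pixel; SDXL gets its own coefficient set because its VAE produces a differently distributed latent. A small self-contained illustration of the einsum, using a random stand-in latent and the SD1 coefficients quoted above:

import torch

# SD1 latent-to-RGB coefficients: one row per latent channel, one column per RGB channel.
coefs = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])

sample = torch.randn(4, 64, 64)                          # stand-in latent: 4 channels, 64x64
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)  # collapse 4 latent channels into 3 RGB channels
print(x_sample.shape)                                    # torch.Size([3, 64, 64])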
@@ -8,9 +8,9 @@ import os
 import torch
 import torch.nn as nn

-from modules import devices, paths_internal
+from modules import devices, paths_internal, shared


-sd_vae_taesd = None
+sd_vae_taesd_models = {}


 def conv(n_in, n_out, **kwargs):
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
         return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)


-def download_model(model_path):
-    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
+def download_model(model_path, model_url):
     if not os.path.exists(model_path):
         os.makedirs(os.path.dirname(model_path), exist_ok=True)
@@ -72,17 +70,19 @@ def download_model(model_path):
 def model():
-    global sd_vae_taesd
+    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    loaded_model = sd_vae_taesd_models.get(model_name)

-    if sd_vae_taesd is None:
-        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
-        download_model(model_path)
+    if loaded_model is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)

         if os.path.exists(model_path):
-            sd_vae_taesd = TAESD(model_path)
-            sd_vae_taesd.eval()
-            sd_vae_taesd.to(devices.device, devices.dtype)
+            loaded_model = TAESD(model_path)
+            loaded_model.eval()
+            loaded_model.to(devices.device, devices.dtype)
+            sd_vae_taesd_models[model_name] = loaded_model
         else:
             raise FileNotFoundError('TAESD model not found')

-    return sd_vae_taesd.decoder
+    return loaded_model.decoder
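Both the VAE-approx and TAESD previews now follow the same pattern: the single module-level model is replaced by a dict keyed by file name, so the SD1 and SDXL decoders can stay loaded side by side and the right one is picked from shared.sd_model.is_sdxl on each call. A generic sketch of that caching pattern (the names below are illustrative, not the repo's):

_loaded_decoders = {}  # cache keyed by the decoder file name

def get_decoder(is_sdxl, load_fn):
    # load_fn is a hypothetical callable that builds a decoder from a file name
    # (download if missing, load weights, move to device, call eval()).
    name = "taesdxl_decoder.pth" if is_sdxl else "taesd_decoder.pth"
    decoder = _loaded_decoders.get(name)
    if decoder is None:
        decoder = load_fn(name)
        _loaded_decoders[name] = decoder
    return decoder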
@@ -429,9 +429,16 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
 }))

+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+    "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+    "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+    "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+    "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
+}))
+
 options_templates.update(options_section(('optimizations', "Optimizations"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
-    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
+    "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
     "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
...
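The new "Stable Diffusion XL" options exist because SDXL conditions on more than the text prompt: crop coordinates and, for the refiner, an aesthetic score are embedded alongside the prompt. A schematic of how values like these might be packed for the SDXL conditioner (the key names follow the sgm conditioner's inputs but are an assumption here; the actual wiring lives elsewhere in the webui):

# Hypothetical wiring: collect the new options into the dict of extra inputs
# that SDXL's conditioner embeds in addition to the prompt.
opts = {"sdxl_crop_top": 0, "sdxl_crop_left": 0,
        "sdxl_refiner_low_aesthetic_score": 2.5,
        "sdxl_refiner_high_aesthetic_score": 6.0}

extra_cond_inputs = {
    "crop_coords_top_left": (opts["sdxl_crop_top"], opts["sdxl_crop_left"]),
    "aesthetic_score": opts["sdxl_refiner_high_aesthetic_score"],  # prompt side; the negative prompt uses the low score
}
print(extra_cond_inputs)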
@@ -15,6 +15,7 @@ kornia==0.6.7
 lark==1.1.2
 numpy==1.23.5
 omegaconf==2.2.3
+open-clip-torch==2.20.0
 piexif==1.1.3
 psutil~=5.9.5
 pytorch_lightning==1.9.4
...