Commit e7c03ccd authored by Kohaku-Blueleaf

Merge branch 'dev' into extra-norm-module

parents d9cc27cb 007ecfbb
@@ -6,9 +6,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def __init__(self):
         super().__init__('lora')
 
+        self.errors = {}
+        """mapping of network names to the number of errors the network had during operation"""
+
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora
 
+        self.errors.clear()
+
         if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
@@ -56,4 +61,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
         p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)
 
     def deactivate(self, p):
-        pass
+        if self.errors:
+            p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items()))
+
+            self.errors.clear()
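For context, a small self-contained sketch of the error-handling pattern these changes introduce (the names below are illustrative, not taken from the codebase): failures while applying a network are counted per network name instead of aborting generation, and the totals are surfaced once at the end.

import logging

errors = {}  # mapping of network names to the number of errors seen while applying them

def apply_network(name, layer):
    try:
        raise RuntimeError("shape mismatch")  # stand-in for module.calc_updown() failing
    except RuntimeError as e:
        logging.debug("Network %s layer %s: %s", name, layer, e)
        errors[name] = errors.get(name, 0) + 1  # count the failure and keep going

apply_network("my_lora", "attn1_to_q")
apply_network("my_lora", "attn1_to_k")

if errors:
    # analogous to deactivate() attaching a comment to the processing object
    print("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in errors.items()))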
...
+import logging
 import os
 import re
@@ -194,7 +195,7 @@ def load_network(name, network_on_disk):
         net.modules[key] = net_module
 
     if keys_failed_to_match:
-        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")
+        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
 
     return net
@@ -207,7 +208,6 @@ def purge_networks_from_memory():
     devices.torch_gc()
 
 
 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
     already_loaded = {}
@@ -248,7 +248,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         if net is None:
             failed_to_load_networks.append(name)
-            print(f"Couldn't find network with name {name}")
+            logging.info(f"Couldn't find network with name {name}")
             continue
 
         net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
@@ -257,7 +257,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         loaded_networks.append(net)
 
     if failed_to_load_networks:
-        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
 
     purge_networks_from_memory()
@@ -327,6 +327,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
     for net in loaded_networks:
         module = net.modules.get(network_layer_name, None)
         if module is not None and hasattr(self, 'weight'):
+            try:
                 with torch.no_grad():
                     updown, ex_bias = module.calc_updown(self.weight)
@@ -340,6 +341,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                             self.bias = torch.nn.Parameter(ex_bias)
                         else:
                             self.bias += ex_bias
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
             continue
 
         module_q = net.modules.get(network_layer_name + "_q_proj", None)
@@ -348,6 +353,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         module_out = net.modules.get(network_layer_name + "_out_proj", None)
 
         if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+            try:
                 with torch.no_grad():
                     updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                     updown_k, _ = module_k.calc_updown(self.in_proj_weight)
@@ -362,12 +368,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                             self.out_proj.bias = torch.nn.Parameter(ex_bias)
                         else:
                             self.out_proj.bias += ex_bias
+            except RuntimeError as e:
+                logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
             continue
 
         if module is None:
             continue
 
-        print(f'failed to calculate network weights for layer {network_layer_name}')
+        logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
+        extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
 
     self.network_current_names = wanted_names
@@ -540,6 +552,7 @@ def infotext_pasted(infotext, params):
     if added:
         params["Prompt"] += "\n" + "".join(added)
 
+extra_network_lora = None
 
 available_networks = {}
 available_network_aliases = {}
...
@@ -23,9 +23,9 @@ def unload():
 def before_ui():
     ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
 
-    extra_network = extra_networks_lora.ExtraNetworkLora()
-    extra_networks.register_extra_network(extra_network)
-    extra_networks.register_extra_network_alias(extra_network, "lyco")
+    networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()
+    extra_networks.register_extra_network(networks.extra_network_lora)
+    extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
 
 
 if not hasattr(torch.nn, 'Linear_forward_before_network'):
...
@@ -25,9 +25,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         item = {
             "name": name,
             "filename": lora_on_disk.filename,
+            "shorthash": lora_on_disk.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(lora_on_disk.filename),
+            "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""),
             "local_preview": f"{path}.{shared.opts.samples_format}",
             "metadata": lora_on_disk.metadata,
             "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
...
@@ -173,9 +173,12 @@ def git_clone(url, dir, name, commithash=None):
         if current_hash == commithash:
             return
 
-        run_git('fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
-        run_git('checkout', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
+        if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url:
+            run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False)
+
+        run_git(dir, name, 'fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}", autofix=False)
+
+        run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True)
 
         return
@@ -319,12 +322,12 @@ def prepare_environment():
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
     stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
-    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
+    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
 
     try:
-        # the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
+        # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
         os.remove(os.path.join(script_path, "tmp", "restart"))
         os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
     except OSError:
...
@@ -52,9 +52,6 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
 
 if has_mps:
-    # MPS fix for randn in torchsde
-    CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps')
-
     if platform.mac_ver()[0].startswith("13.2."):
         # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
         CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)
...
This diff is collapsed.
@@ -38,18 +38,12 @@ class ScriptRefiner(scripts.Script):
         return enable_refiner, refiner_checkpoint, refiner_switch_at
 
-    def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
+    def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):
         # the actual implementation is in sd_samplers_common.py, apply_refiner
-        p.refiner_checkpoint_info = None
-        p.refiner_switch_at = None
-
-        if not enable_refiner or refiner_checkpoint in (None, "", "None"):
-            return
-
-        refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint)
-        if refiner_checkpoint_info is None:
-            raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}')
-
-        p.refiner_checkpoint_info = refiner_checkpoint_info
-        p.refiner_switch_at = refiner_switch_at
+        if not enable_refiner or refiner_checkpoint in (None, "", "None"):
+            p.refiner_checkpoint_info = None
+            p.refiner_switch_at = None
+        else:
+            p.refiner_checkpoint = refiner_checkpoint
+            p.refiner_switch_at = refiner_switch_at
@@ -58,7 +58,7 @@ class ScriptSeed(scripts.ScriptBuiltin):
         return self.seed, subseed, subseed_strength
 
-    def before_process(self, p, seed, subseed, subseed_strength):
+    def setup(self, p, seed, subseed, subseed_strength):
         p.seed = seed
 
         if subseed_strength > 0:
...
@@ -106,9 +106,16 @@ class Script:
 
         pass
 
+    def setup(self, p, *args):
+        """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.
+        args contains all values returned by components from ui().
+        """
+        pass
+
     def before_process(self, p, *args):
         """
-        This function is called very early before processing begins for AlwaysVisible scripts.
+        This function is called very early during processing begins for AlwaysVisible scripts.
         You can modify the processing object (p) here, inject hooks, etc.
         args contains all values returned by components from ui()
         """
@@ -706,6 +713,14 @@ class ScriptRunner:
             except Exception:
                 errors.report(f"Error running before_hr: {script.filename}", exc_info=True)
 
+    def setup_scrips(self, p):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.setup(p, *script_args)
+            except Exception:
+                errors.report(f"Error running setup: {script.filename}", exc_info=True)
+
 
 scripts_txt2img: ScriptRunner = None
 scripts_img2img: ScriptRunner = None
...
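To illustrate the new hook, here is a sketch (not part of the diff) of how an AlwaysVisible script might use setup() alongside before_process(); setup() runs when the processing object is created, before any processing starts. The script name and the marker it writes are hypothetical.

from modules import scripts


class ExampleAlwaysOnScript(scripts.Script):  # hypothetical example script
    def title(self):
        return "Example always-on script"

    def show(self, is_img2img):
        # AlwaysVisible scripts receive the setup()/before_process() callbacks on every run
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        # whatever ui() returns is passed back to setup() and before_process() as *args
        return []

    def setup(self, p, *args):
        # called when the processing object is set up, before any processing starts
        p.extra_generation_params["Example setup ran"] = True

    def before_process(self, p, *args):
        # called very early during processing; p can still be modified here
        pass

The ScriptRunner.setup_scrips() loop added above is what invokes this hook for every always-on script.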
 from __future__ import annotations
 import math
 import psutil
+import platform
 
 import torch
 from torch import einsum
@@ -94,7 +95,10 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
 class SdOptimizationSubQuad(SdOptimization):
     name = "sub-quadratic"
     cmd_opt = "opt_sub_quad_attention"
-    priority = 10
+
+    @property
+    def priority(self):
+        return 1000 if shared.device.type == 'mps' else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
@@ -120,7 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
     @property
     def priority(self):
-        return 1000 if not torch.cuda.is_available() else 10
+        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
@@ -427,7 +431,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:
...
@@ -92,6 +92,14 @@ def images_tensor_to_samples(image, approximation=None, model=None):
             model = shared.sd_model
         image = image.to(shared.device, dtype=devices.dtype_vae)
         image = image * 2 - 1
+        if len(image) > 1:
+            x_latent = torch.stack([
+                model.get_first_stage_encoding(
+                    model.encode_first_stage(torch.unsqueeze(img, 0))
+                )[0]
+                for img in image
+            ])
+        else:
             x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))
 
     return x_latent
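In plain terms (a rough sketch under the same ldm first-stage API, not part of the diff), the new branch encodes a multi-image batch one image at a time and stacks the resulting latents, rather than pushing the whole batch through encode_first_stage at once:

import torch

def encode_individually(model, image_batch):
    # hypothetical helper mirroring the branch above: unsqueeze each image back to a
    # batch of one, encode it, take latent [0], then stack everything into one tensor
    if len(image_batch) > 1:
        return torch.stack([
            model.get_first_stage_encoding(model.encode_first_stage(img.unsqueeze(0)))[0]
            for img in image_batch
        ])
    return model.get_first_stage_encoding(model.encode_first_stage(image_batch))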
@@ -145,7 +153,7 @@ def apply_refiner(cfg_denoiser):
     refiner_switch_at = cfg_denoiser.p.refiner_switch_at
     refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
 
-    if refiner_switch_at is not None and completed_ratio <= refiner_switch_at:
+    if refiner_switch_at is not None and completed_ratio < refiner_switch_at:
         return False
 
     if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:
@@ -276,19 +284,19 @@ class Sampler:
         s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
         s_noise = getattr(opts, 's_noise', p.s_noise)
 
-        if s_churn != self.s_churn:
+        if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:
             extra_params_kwargs['s_churn'] = s_churn
             p.s_churn = s_churn
             p.extra_generation_params['Sigma churn'] = s_churn
-        if s_tmin != self.s_tmin:
+        if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:
             extra_params_kwargs['s_tmin'] = s_tmin
             p.s_tmin = s_tmin
             p.extra_generation_params['Sigma tmin'] = s_tmin
-        if s_tmax != self.s_tmax:
+        if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:
             extra_params_kwargs['s_tmax'] = s_tmax
             p.s_tmax = s_tmax
             p.extra_generation_params['Sigma tmax'] = s_tmax
-        if s_noise != self.s_noise:
+        if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:
             extra_params_kwargs['s_noise'] = s_noise
             p.s_noise = s_noise
             p.extra_generation_params['Sigma noise'] = s_noise
@@ -305,5 +313,8 @@ class Sampler:
         current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
         return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
 
+    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        raise NotImplementedError()
+
+    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        raise NotImplementedError()
...
@@ -22,6 +22,9 @@ samplers_k_diffusion = [
     ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
     ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}),
     ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}),
+    ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'discard_next_to_last_sigma': True, "brownian_noise": True}),
+    ('DPM++ 3M SDE Karras', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
+    ('DPM++ 3M SDE Exponential', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde_exp'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, "brownian_noise": True}),
     ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}),
     ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}),
     ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
@@ -42,6 +45,12 @@ sampler_extra_params = {
     'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
     'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
     'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
+    'sample_dpm_fast': ['s_noise'],
+    'sample_dpm_2_ancestral': ['s_noise'],
+    'sample_dpmpp_2s_ancestral': ['s_noise'],
+    'sample_dpmpp_sde': ['s_noise'],
+    'sample_dpmpp_2m_sde': ['s_noise'],
+    'sample_dpmpp_3m_sde': ['s_noise'],
 }
 
 k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}
@@ -67,6 +76,8 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
     def __init__(self, funcname, sd_model, options=None):
         super().__init__(funcname)
 
+        self.extra_params = sampler_extra_params.get(funcname, [])
+
         self.options = options or {}
         self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
...
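A worked illustration of how the two changes fit together (a simplified sketch, assuming the rest of Sampler's setup; it is not the actual control flow): sampler_extra_params declares which stochasticity knobs each sampling function accepts, and the new "'s_churn' in extra_params_kwargs" style checks keep a knob from being forwarded, or recorded in the infotext, for samplers that do not take it.

# simplified sketch; the real code builds extra_params_kwargs from the sampler function's
# signature and the processing object, this just mimics the per-sampler filtering
sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpmpp_sde': ['s_noise'],
}

def collect_extra_params(funcname, s_churn=0.0, s_noise=1.0):
    allowed = sampler_extra_params.get(funcname, [])
    extra_params_kwargs = {}
    if 's_churn' in allowed and s_churn != 0.0:  # 0.0 is the option's default, so only non-defaults are forwarded
        extra_params_kwargs['s_churn'] = s_churn
    if 's_noise' in allowed and s_noise != 1.0:  # 1.0 is the option's default
        extra_params_kwargs['s_noise'] = s_noise
    return extra_params_kwargs

print(collect_extra_params('sample_euler', s_churn=0.5))      # {'s_churn': 0.5}
print(collect_extra_params('sample_dpmpp_sde', s_churn=0.5))  # {} since sample_dpmpp_sde only takes s_noise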
@@ -11,7 +11,7 @@ from modules.models.diffusion.uni_pc import uni_pc
 def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
     sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
@@ -42,7 +42,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
     alphas = alphas_cumprod[timesteps]
-    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64)
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
     sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
 
     extra_args = {} if extra_args is None else extra_args
...
@@ -285,12 +285,12 @@ options_templates.update(options_section(('ui', "Live previews"), {
 options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
     "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),
     "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unperdictable results"),
-    "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; applies to Euler a and other samplers that have a in them"),
+    "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"),
     "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
     's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),
     's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),
     's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"),
-    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling; only applies to Euler, Heun, and DPM2'),
+    's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),
     'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"),
     'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"),
     'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"),
...
@@ -58,7 +58,7 @@ def _summarize_chunk(
     scale: float,
 ) -> AttnChunk:
     attn_weights = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
@@ -121,7 +121,7 @@ def _get_attention_scores_no_kv_chunking(
     scale: float,
 ) -> Tensor:
     attn_scores = torch.baddbmm(
-        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
+        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),
         query,
         key.transpose(1,2),
         alpha=scale,
...
@@ -19,6 +19,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
         return {
             "name": checkpoint.name_for_extra,
             "filename": checkpoint.filename,
+            "shorthash": checkpoint.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
             "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),
...
@@ -2,6 +2,7 @@ import os
 
 from modules import shared, ui_extra_networks
 from modules.ui_extra_networks import quote_js
+from modules.hashes import sha256_from_cache
 
 
 class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
@@ -14,13 +15,16 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
+        sha256 = sha256_from_cache(full_path, f'hypernet/{name}')
+        shorthash = sha256[0:10] if sha256 else None
 
         return {
             "name": name,
             "filename": full_path,
+            "shorthash": shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(path),
+            "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""),
             "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"),
             "local_preview": f"{path}.preview.{shared.opts.samples_format}",
             "sort_keys": {'default': index, **self.get_sort_keys(path + ext)},
...
@@ -19,9 +19,10 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
         return {
             "name": name,
             "filename": embedding.filename,
+            "shorthash": embedding.shorthash,
             "preview": self.find_preview(path),
             "description": self.find_description(path),
-            "search_term": self.search_terms_from_path(embedding.filename),
+            "search_term": self.search_terms_from_path(embedding.filename) + " " + (embedding.hash or ""),
             "prompt": quote_js(embedding.name),
             "local_preview": f"{path}.preview.{shared.opts.samples_format}",
             "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)},
...
@@ -93,11 +93,13 @@ class UserMetadataEditor:
         item = self.page.items.get(name, {})
 
         try:
             filename = item["filename"]
+            shorthash = item.get("shorthash", None)
 
             stats = os.stat(filename)
             params = [
                 ('Filename: ', os.path.basename(filename)),
                 ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
+                ('Hash: ', shorthash),
                 ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
             ]
 
@@ -115,7 +117,7 @@ class UserMetadataEditor:
             errors.display(e, f"reading metadata info for {name}")
             params = []
 
-        table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
+        table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>'
 
         return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')
...
This diff is collapsed.
@@ -12,8 +12,6 @@ fi
 export install_dir="$HOME"
 export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
 export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2"
-export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
-export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
 export PYTORCH_ENABLE_MPS_FALLBACK=1
 
 ####################################################################