Commit b75b004f authored by AUTOMATIC1111

lora extension rework to include other types of networks

parent 7d26c479
extra_networks_lora.py:

 from modules import extra_networks, shared
-import lora
+import networks


 class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -9,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora

-        if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
+        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

@@ -21,12 +21,12 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
             names.append(params.items[0])
             multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

-        lora.load_loras(names, multipliers)
+        networks.load_networks(names, multipliers)

         if shared.opts.lora_add_hashes_to_infotext:
-            lora_hashes = []
-            for item in lora.loaded_loras:
-                shorthash = item.lora_on_disk.shorthash
+            network_hashes = []
+            for item in networks.loaded_networks:
+                shorthash = item.network_on_disk.shorthash
                 if not shorthash:
                     continue

@@ -36,10 +36,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
                 alias = alias.replace(":", "").replace(",", "")

-                lora_hashes.append(f"{alias}: {shorthash}")
+                network_hashes.append(f"{alias}: {shorthash}")

-            if lora_hashes:
-                p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+            if network_hashes:
+                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

     def deactivate(self, p):
         pass
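For orientation: activate() receives one ExtraNetworkParams per <lora:name:multiplier> tag in the prompt and turns them into parallel name/multiplier lists before calling networks.load_networks. A minimal standalone sketch of that parsing step, with an illustrative SimpleParams stand-in and made-up network names (not part of the commit):

# Illustrative sketch only: mimics how activate() builds names/multipliers
# from prompt tags such as <lora:myStyle:0.8>. SimpleParams stands in for
# modules.extra_networks.ExtraNetworkParams.
class SimpleParams:
    def __init__(self, items):
        self.items = items

params_list = [SimpleParams(["myStyle", "0.8"]), SimpleParams(["detailTweaker"])]

names = []
multipliers = []
for params in params_list:
    names.append(params.items[0])
    multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

print(names)        # ['myStyle', 'detailTweaker']
print(multipliers)  # [0.8, 1.0]
# In the extension these lists are then passed to networks.load_networks(names, multipliers).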
lyco_helpers.py (new file):

import torch


def make_weight_cp(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)


def rebuild_conventional(up, down, shape, dyn_dim=None):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)
    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]
    return (up @ down).reshape(shape)
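A quick self-contained illustration of what these helpers compute, assuming the two functions above are in scope (shapes below are invented for the example): rebuild_conventional collapses a low-rank up/down pair back into a full weight delta, and make_weight_cp contracts a CP/Tucker-style core tensor with two rank projections into a full conv kernel.

import torch

# rebuild_conventional: (out, rank) @ (rank, in), reshaped to the layer's weight shape.
up = torch.randn(320, 8)      # hypothetical lora_up for a 320x768 Linear layer
down = torch.randn(8, 768)    # hypothetical lora_down
delta = rebuild_conventional(up, down, shape=(320, 768))
print(delta.shape)            # torch.Size([320, 768])

# make_weight_cp: core tensor (rank_out, rank_in, kh, kw) expanded by wa/wb
# into a full kernel (out_channels, in_channels, kh, kw).
t = torch.randn(8, 8, 3, 3)
wa = torch.randn(8, 320)      # maps rank_out -> output channels
wb = torch.randn(8, 64)       # maps rank_in -> input channels
kernel = make_weight_cp(t, wa, wb)
print(kernel.shape)           # torch.Size([320, 64, 3, 3])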
network.py (new file):

import os
from collections import namedtuple

import torch

from modules import devices, sd_models, cache, errors, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}


class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)
            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text
            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias


class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""


class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None


class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        raise NotImplementedError()
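These two base classes are the extension point the rework introduces: a ModuleType recognises a weight format by its state-dict keys, and a NetworkModule knows how to turn those weights into a delta for a target layer. A hypothetical example of a third format plugging in (not part of the commit; the "diff" key and class names are invented, and it assumes the classes above are in scope):

# Hypothetical example of the plugin surface defined above.
class ModuleTypeFullWeight(ModuleType):
    def create_module(self, net, weights):
        if "diff" in weights.w:  # "diff" is an invented key for this sketch
            return NetworkModuleFullWeight(net, weights)
        return None


class NetworkModuleFullWeight(NetworkModule):
    def __init__(self, net, weights):
        super().__init__(net, weights)
        self.diff = weights.w["diff"]

    def calc_updown(self, target):
        # the delta is stored directly, so just match the target's device/dtype and scale it
        return self.diff.to(target.device, dtype=target.dtype) * self.network.multiplier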
network_hada.py (new file):

import lyco_helpers
import network
import network_lyco


class ModuleTypeHada(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
            return NetworkModuleHada(net, weights)

        return None


class NetworkModuleHada(network_lyco.NetworkModuleLyco):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.w1a = weights.w["hada_w1_a"]
        self.w1b = weights.w["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = weights.w["hada_w2_a"]
        self.w2b = weights.w["hada_w2_b"]

        self.t1 = weights.w.get("hada_t1")
        self.t2 = weights.w.get("hada_t2")

        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

    def calc_updown(self, orig_weight):
        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

        if self.t2 is not None:
            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

        updown = updown1 * updown2

        return self.finalize_updown(updown, orig_weight, output_shape)
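For reference, the Hadamard-product (LoHa) parameterisation above builds the weight delta as an element-wise product of two low-rank reconstructions, roughly delta_W = (w1a @ w1b) * (w2a @ w2b), with the alpha/dim scaling applied later in finalize_updown. A tiny standalone check of the idea, with shapes chosen only for illustration:

import torch

out_dim, in_dim, rank = 16, 32, 4
w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)

# Element-wise (Hadamard) product of two rank-4 reconstructions: the result is
# full-sized, but parameterised by far fewer values than out_dim * in_dim.
delta = (w1a @ w1b) * (w2a @ w2b)
print(delta.shape)  # torch.Size([16, 32])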
network_lora.py (new file):

import torch

import network
from modules import devices


class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)

        return None


class NetworkModuleLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.up = self.create_module(weights.w["lora_up.weight"])
        self.down = self.create_module(weights.w["lora_down.weight"])
        self.alpha = weights.w["alpha"] if "alpha" in weights.w else None

    def create_module(self, weight, none_ok=False):
        if weight is None and none_ok:
            return None

        if type(self.sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(self.sd_module) == torch.nn.MultiheadAttention:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
        else:
            print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')
            return None

        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, target):
        up = self.up.weight.to(target.device, dtype=target.dtype)
        down = self.down.weight.to(target.device, dtype=target.dtype)

        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
        else:
            updown = up @ down

        updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)

        return updown

    def forward(self, x, y):
        self.up.to(device=devices.device)
        self.down.to(device=devices.device)

        return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0)
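For comparison with calc_updown above: for a plain Linear layer the delta is the classic LoRA product up @ down, scaled by the per-network multiplier and by alpha divided by the rank (up.shape[1]). A minimal standalone illustration with invented shapes:

import torch

rank, in_features, out_features = 8, 768, 320
down = torch.randn(rank, in_features)   # corresponds to lora_down.weight
up = torch.randn(out_features, rank)    # corresponds to lora_up.weight
alpha, multiplier = 4.0, 0.8

# same scaling rule as calc_updown: alpha is divided by the rank (up.shape[1])
delta = (up @ down) * multiplier * (alpha / up.shape[1])
print(delta.shape)  # torch.Size([320, 768])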
network_lyco.py (new file):

import torch

import lyco_helpers
import network
from modules import devices


class NetworkModuleLyco(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

    def finalize_updown(self, updown, orig_weight, output_shape):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        scale = (
            self.scale if self.scale is not None
            else self.alpha / self.dim if self.dim is not None and self.alpha is not None
            else 1.0
        )

        return updown * scale * self.network.multiplier
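The scale chosen in finalize_updown is worth spelling out: an explicit scale from the file wins, otherwise alpha/dim is used when both are known, otherwise 1.0, and the result is always multiplied by the per-network multiplier. A small sketch that restates just the selection rule (pick_scale is an illustrative helper, not part of the commit):

def pick_scale(scale, alpha, dim):
    # mirrors the conditional expression in finalize_updown above
    return scale if scale is not None else alpha / dim if dim is not None and alpha is not None else 1.0

print(pick_scale(None, 4.0, 8))   # 0.5
print(pick_scale(2.0, 4.0, 8))    # 2.0
print(pick_scale(None, None, 8))  # 1.0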
networks.py (formerly lora.py):

 import os
 import re

+import network
+import network_lora
+import network_hada
+
 import torch
 from typing import Union

-from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache
+from modules import shared, devices, sd_models, errors, scripts, sd_hijack

+module_types = [
+    network_lora.ModuleTypeLora(),
+    network_hada.ModuleTypeHada(),
+]

-metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

 re_digits = re.compile(r"\d+")
 re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
@@ -79,81 +88,8 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
     return key

-class LoraOnDisk:
-    def __init__(self, name, filename):
-        self.name = name
-        self.filename = filename
-        self.metadata = {}
-        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
-
-        def read_metadata():
-            metadata = sd_models.read_metadata_from_safetensors(filename)
-            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text
-            return metadata
-
-        if self.is_safetensors:
-            try:
-                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
-            except Exception as e:
-                errors.display(e, f"reading lora {filename}")
-
-        if self.metadata:
-            m = {}
-            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
-                m[k] = v
-
-            self.metadata = m
-
-        self.alias = self.metadata.get('ss_output_name', self.name)
-
-        self.hash = None
-        self.shorthash = None
-        self.set_hash(
-            self.metadata.get('sshs_model_hash') or
-            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
-            ''
-        )
-
-    def set_hash(self, v):
-        self.hash = v
-        self.shorthash = self.hash[0:12]
-
-        if self.shorthash:
-            available_lora_hash_lookup[self.shorthash] = self
-
-    def read_hash(self):
-        if not self.hash:
-            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
-
-    def get_alias(self):
-        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
-            return self.name
-        else:
-            return self.alias
-
-class LoraModule:
-    def __init__(self, name, lora_on_disk: LoraOnDisk):
-        self.name = name
-        self.lora_on_disk = lora_on_disk
-        self.multiplier = 1.0
-        self.modules = {}
-        self.mtime = None
-
-        self.mentioned_name = None
-        """the text that was used to add lora to prompt - can be either name or an alias"""
-
-class LoraUpDownModule:
-    def __init__(self):
-        self.up = None
-        self.down = None
-        self.alpha = None
-
-def assign_lora_names_to_compvis_modules(sd_model):
-    lora_layer_mapping = {}
+def assign_network_names_to_compvis_modules(sd_model):
+    network_layer_mapping = {}

     if shared.sd_model.is_sdxl:
         for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
@@ -161,166 +97,132 @@ def assign_lora_names_to_compvis_modules(sd_model):
                 continue

             for name, module in embedder.wrapped.named_modules():
-                lora_name = f'{i}_{name.replace(".", "_")}'
-                lora_layer_mapping[lora_name] = module
-                module.lora_layer_name = lora_name
+                network_name = f'{i}_{name.replace(".", "_")}'
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
     else:
         for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-            lora_name = name.replace(".", "_")
-            lora_layer_mapping[lora_name] = module
-            module.lora_layer_name = lora_name
+            network_name = name.replace(".", "_")
+            network_layer_mapping[network_name] = module
+            module.network_layer_name = network_name

     for name, module in shared.sd_model.model.named_modules():
-        lora_name = name.replace(".", "_")
-        lora_layer_mapping[lora_name] = module
-        module.lora_layer_name = lora_name
+        network_name = name.replace(".", "_")
+        network_layer_mapping[network_name] = module
+        module.network_layer_name = network_name

-    sd_model.lora_layer_mapping = lora_layer_mapping
+    sd_model.network_layer_mapping = network_layer_mapping


-def load_lora(name, lora_on_disk):
-    lora = LoraModule(name, lora_on_disk)
-    lora.mtime = os.path.getmtime(lora_on_disk.filename)
+def load_network(name, network_on_disk):
+    net = network.Network(name, network_on_disk)
+    net.mtime = os.path.getmtime(network_on_disk.filename)

-    sd = sd_models.read_state_dict(lora_on_disk.filename)
+    sd = sd_models.read_state_dict(network_on_disk.filename)

     # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
-    if not hasattr(shared.sd_model, 'lora_layer_mapping'):
-        assign_lora_names_to_compvis_modules(shared.sd_model)
+    if not hasattr(shared.sd_model, 'network_layer_mapping'):
+        assign_network_names_to_compvis_modules(shared.sd_model)

     keys_failed_to_match = {}
-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
+    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+    matched_networks = {}

-    for key_lora, weight in sd.items():
-        key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
-        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+    for key_network, weight in sd.items():
+        key_network_without_network_parts, network_part = key_network.split(".", 1)
+        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

         if sd_module is None:
             m = re_x_proj.match(key)
             if m:
-                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
+                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

         # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
-        if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
-            key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
-            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
-        elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
-            key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
-            sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+        if sd_module is None and "lora_unet" in key_network_without_network_parts:
+            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
+            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

         if sd_module is None:
-            keys_failed_to_match[key_lora] = key
+            keys_failed_to_match[key_network] = key
             continue

-        lora_module = lora.modules.get(key, None)
-        if lora_module is None:
-            lora_module = LoraUpDownModule()
-            lora.modules[key] = lora_module
-
-        if lora_key == "alpha":
-            lora_module.alpha = weight.item()
-            continue
-
-        if type(sd_module) == torch.nn.Linear:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.MultiheadAttention:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
-            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
-        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
-            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
-        else:
-            print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
-            continue
-
-            raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")
-
-        with torch.no_grad():
-            module.weight.copy_(weight)
-
-        module.to(device=devices.cpu, dtype=devices.dtype)
-
-        if lora_key == "lora_up.weight":
-            lora_module.up = module
-        elif lora_key == "lora_down.weight":
-            lora_module.down = module
-        else:
-            raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")
+        if key not in matched_networks:
+            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
+
+        matched_networks[key].w[network_part] = weight
+
+    for key, weights in matched_networks.items():
+        net_module = None
+        for nettype in module_types:
+            net_module = nettype.create_module(net, weights)
+            if net_module is not None:
+                break
+
+        if net_module is None:
+            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")
+
+        net.modules[key] = net_module

     if keys_failed_to_match:
-        print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
+        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")

-    return lora
+    return net


-def load_loras(names, multipliers=None):
+def load_networks(names, multipliers=None):
     already_loaded = {}

-    for lora in loaded_loras:
-        if lora.name in names:
-            already_loaded[lora.name] = lora
+    for net in loaded_networks:
+        if net.name in names:
+            already_loaded[net.name] = net

-    loaded_loras.clear()
+    loaded_networks.clear()

-    loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
-    if any(x is None for x in loras_on_disk):
-        list_available_loras()
+    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
+    if any(x is None for x in networks_on_disk):
+        list_available_networks()

-        loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
+        networks_on_disk = [available_network_aliases.get(name, None) for name in names]

-    failed_to_load_loras = []
+    failed_to_load_networks = []

     for i, name in enumerate(names):
-        lora = already_loaded.get(name, None)
-        lora_on_disk = loras_on_disk[i]
+        net = already_loaded.get(name, None)
+        network_on_disk = networks_on_disk[i]

-        if lora_on_disk is not None:
-            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
+        if network_on_disk is not None:
+            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                 try:
-                    lora = load_lora(name, lora_on_disk)
+                    net = load_network(name, network_on_disk)
                 except Exception as e:
-                    errors.display(e, f"loading Lora {lora_on_disk.filename}")
+                    errors.display(e, f"loading network {network_on_disk.filename}")
                     continue

-            lora.mentioned_name = name
+            net.mentioned_name = name

-            lora_on_disk.read_hash()
+            network_on_disk.read_hash()

-        if lora is None:
-            failed_to_load_loras.append(name)
-            print(f"Couldn't find Lora with name {name}")
+        if net is None:
+            failed_to_load_networks.append(name)
+            print(f"Couldn't find network with name {name}")
             continue

-        lora.multiplier = multipliers[i] if multipliers else 1.0
-        loaded_loras.append(lora)
+        net.multiplier = multipliers[i] if multipliers else 1.0
+        loaded_networks.append(net)

-    if failed_to_load_loras:
-        sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))
+    if failed_to_load_networks:
+        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))


-def lora_calc_updown(lora, module, target):
-    with torch.no_grad():
-        up = module.up.weight.to(target.device, dtype=target.dtype)
-        down = module.down.weight.to(target.device, dtype=target.dtype)
-
-        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
-            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
-        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
-            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
-        else:
-            updown = up @ down
-
-        updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-
-        return updown
-
-def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
-    weights_backup = getattr(self, "lora_weights_backup", None)
+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+    weights_backup = getattr(self, "network_weights_backup", None)

     if weights_backup is None:
         return
@@ -332,144 +234,148 @@ def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
         self.weight.copy_(weights_backup)


-def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
     """
-    Applies the currently selected set of Loras to the weights of torch layer self.
-    If weights already have this particular set of loras applied, does nothing.
-    If not, restores orginal weights from backup and alters weights according to loras.
+    Applies the currently selected set of networks to the weights of torch layer self.
+    If weights already have this particular set of networks applied, does nothing.
+    If not, restores orginal weights from backup and alters weights according to networks.
     """

-    lora_layer_name = getattr(self, 'lora_layer_name', None)
-    if lora_layer_name is None:
+    network_layer_name = getattr(self, 'network_layer_name', None)
+    if network_layer_name is None:
         return

-    current_names = getattr(self, "lora_current_names", ())
-    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+    current_names = getattr(self, "network_current_names", ())
+    wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks)

-    weights_backup = getattr(self, "lora_weights_backup", None)
+    weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None:
         if isinstance(self, torch.nn.MultiheadAttention):
             weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
         else:
             weights_backup = self.weight.to(devices.cpu, copy=True)

-        self.lora_weights_backup = weights_backup
+        self.network_weights_backup = weights_backup

     if current_names != wanted_names:
-        lora_restore_weights_from_backup(self)
+        network_restore_weights_from_backup(self)

-        for lora in loaded_loras:
-            module = lora.modules.get(lora_layer_name, None)
+        for net in loaded_networks:
+            module = net.modules.get(network_layer_name, None)
             if module is not None and hasattr(self, 'weight'):
-                self.weight += lora_calc_updown(lora, module, self.weight)
-                continue
+                with torch.no_grad():
+                    updown = module.calc_updown(self.weight)
+
+                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+                        # inpainting model. zero pad updown to make channel[1] 4 to 9
+                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
+
+                    self.weight += updown

-            module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
-            module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
-            module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
-            module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
+            module_q = net.modules.get(network_layer_name + "_q_proj", None)
+            module_k = net.modules.get(network_layer_name + "_k_proj", None)
+            module_v = net.modules.get(network_layer_name + "_v_proj", None)
+            module_out = net.modules.get(network_layer_name + "_out_proj", None)

             if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
-                updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
-                updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
-                updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
+                with torch.no_grad():
+                    updown_q = module_q.calc_updown(self.in_proj_weight)
+                    updown_k = module_k.calc_updown(self.in_proj_weight)
+                    updown_v = module_v.calc_updown(self.in_proj_weight)
                     updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

                     self.in_proj_weight += updown_qkv
-                self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
+                    self.out_proj.weight += module_out.calc_updown(self.out_proj.weight)
                 continue

             if module is None:
                 continue

-            print(f'failed to calculate lora weights for layer {lora_layer_name}')
+            print(f'failed to calculate network weights for layer {network_layer_name}')

-        self.lora_current_names = wanted_names
+        self.network_current_names = wanted_names


-def lora_forward(module, input, original_forward):
+def network_forward(module, input, original_forward):
     """
     Old way of applying Lora by executing operations during layer's forward.
     Stacking many loras this way results in big performance degradation.
     """

-    if len(loaded_loras) == 0:
+    if len(loaded_networks) == 0:
         return original_forward(module, input)

     input = devices.cond_cast_unet(input)

-    lora_restore_weights_from_backup(module)
-    lora_reset_cached_weight(module)
+    network_restore_weights_from_backup(module)
+    network_reset_cached_weight(module)

-    res = original_forward(module, input)
+    y = original_forward(module, input)

-    lora_layer_name = getattr(module, 'lora_layer_name', None)
-    for lora in loaded_loras:
-        module = lora.modules.get(lora_layer_name, None)
+    network_layer_name = getattr(module, 'network_layer_name', None)
+    for lora in loaded_networks:
+        module = lora.modules.get(network_layer_name, None)
         if module is None:
             continue

-        module.up.to(device=devices.device)
-        module.down.to(device=devices.device)
-
-        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+        y = module.forward(y, input)

-    return res
+    return y


-def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
-    self.lora_current_names = ()
-    self.lora_weights_backup = None
+def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
+    self.network_current_names = ()
+    self.network_weights_backup = None


-def lora_Linear_forward(self, input):
+def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
-        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+        return network_forward(self, input, torch.nn.Linear_forward_before_network)

-    lora_apply_weights(self)
+    network_apply_weights(self)

-    return torch.nn.Linear_forward_before_lora(self, input)
+    return torch.nn.Linear_forward_before_network(self, input)


-def lora_Linear_load_state_dict(self, *args, **kwargs):
-    lora_reset_cached_weight(self)
+def network_Linear_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)

-    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
+    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)


-def lora_Conv2d_forward(self, input):
+def network_Conv2d_forward(self, input):
     if shared.opts.lora_functional:
-        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)

-    lora_apply_weights(self)
+    network_apply_weights(self)

-    return torch.nn.Conv2d_forward_before_lora(self, input)
+    return torch.nn.Conv2d_forward_before_network(self, input)


-def lora_Conv2d_load_state_dict(self, *args, **kwargs):
-    lora_reset_cached_weight(self)
+def network_Conv2d_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)

-    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
+    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)


-def lora_MultiheadAttention_forward(self, *args, **kwargs):
-    lora_apply_weights(self)
+def network_MultiheadAttention_forward(self, *args, **kwargs):
+    network_apply_weights(self)

-    return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
+    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)


-def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
-    lora_reset_cached_weight(self)
+def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+    network_reset_cached_weight(self)

-    return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
+    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)


-def list_available_loras():
-    available_loras.clear()
-    available_lora_aliases.clear()
-    forbidden_lora_aliases.clear()
-    available_lora_hash_lookup.clear()
-    forbidden_lora_aliases.update({"none": 1, "Addams": 1})
+def list_available_networks():
+    available_networks.clear()
+    available_network_aliases.clear()
+    forbidden_network_aliases.clear()
+    available_network_hash_lookup.clear()
+    forbidden_network_aliases.update({"none": 1, "Addams": 1})

     os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
@@ -480,21 +386,21 @@ def list_available_loras():
         name = os.path.splitext(os.path.basename(filename))[0]
         try:
-            entry = LoraOnDisk(name, filename)
+            entry = network.NetworkOnDisk(name, filename)
         except OSError:  # should catch FileNotFoundError and PermissionError etc.
-            errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)
+            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
             continue

-        available_loras[name] = entry
+        available_networks[name] = entry

-        if entry.alias in available_lora_aliases:
-            forbidden_lora_aliases[entry.alias.lower()] = 1
+        if entry.alias in available_network_aliases:
+            forbidden_network_aliases[entry.alias.lower()] = 1

-        available_lora_aliases[name] = entry
-        available_lora_aliases[entry.alias] = entry
+        available_network_aliases[name] = entry
+        available_network_aliases[entry.alias] = entry


-re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
+re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")


 def infotext_pasted(infotext, params):
@@ -516,7 +422,7 @@ def infotext_pasted(infotext, params):
         if name is None:
             continue

-        m = re_lora_name.match(name)
+        m = re_network_name.match(name)
         if m:
             name = m.group(1)
@@ -528,10 +434,10 @@ def infotext_pasted(infotext, params):
     params["Prompt"] += "\n" + "".join(added)


-available_loras = {}
-available_lora_aliases = {}
-available_lora_hash_lookup = {}
-forbidden_lora_aliases = {}
-loaded_loras = []
+available_networks = {}
+available_network_aliases = {}
+loaded_networks = []
+available_network_hash_lookup = {}
+forbidden_network_aliases = {}

-list_available_loras()
+list_available_networks()
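The structural change in load_network is the dispatch loop: state-dict keys are first grouped per target layer into a NetworkWeights record, then each registered ModuleType gets a chance to claim the group, mirroring the error raised when nothing matches. A simplified standalone sketch of that pattern (the Fake* classes and keys are illustrative, not the ones in the commit):

# Simplified illustration of the module_types dispatch used by load_network.
class FakeLoraType:
    def create_module(self, net, weights):
        return ("lora-module", weights) if "lora_up.weight" in weights else None

class FakeHadaType:
    def create_module(self, net, weights):
        return ("hada-module", weights) if "hada_w1_a" in weights else None

module_types_demo = [FakeLoraType(), FakeHadaType()]

def dispatch(weights_dict):
    for nettype in module_types_demo:
        module = nettype.create_module(None, weights_dict)
        if module is not None:
            return module
    raise AssertionError(f"no module type accepts keys: {', '.join(weights_dict)}")

print(dispatch({"lora_up.weight": 1, "lora_down.weight": 2})[0])  # lora-module
print(dispatch({"hada_w1_a": 1, "hada_w1_b": 2})[0])              # hada-module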
scripts/lora_script.py:

@@ -4,18 +4,19 @@ import torch
 import gradio as gr
 from fastapi import FastAPI

-import lora
+import network
+import networks
 import extra_networks_lora
 import ui_extra_networks_lora
 from modules import script_callbacks, ui_extra_networks, extra_networks, shared


 def unload():
-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
+    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
+    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network


 def before_ui():
@@ -23,50 +24,50 @@ def before_ui():
     extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())


-if not hasattr(torch.nn, 'Linear_forward_before_lora'):
-    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+if not hasattr(torch.nn, 'Linear_forward_before_network'):
+    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward

-if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
-    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
+    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict

-if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
-    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
+    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward

-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
-    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
+    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
-    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
+    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward

-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
-    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
+    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict

-torch.nn.Linear.forward = lora.lora_Linear_forward
-torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
-torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
-torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
+torch.nn.Linear.forward = networks.network_Linear_forward
+torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
+torch.nn.Conv2d.forward = networks.network_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict

-script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
-script_callbacks.on_infotext_pasted(lora.infotext_pasted)
+script_callbacks.on_infotext_pasted(networks.infotext_pasted)


 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
-    "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
+    "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
     "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
     "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
 }))


 shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
-    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
+    "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
 }))


-def create_lora_json(obj: lora.LoraOnDisk):
+def create_lora_json(obj: network.NetworkOnDisk):
     return {
         "name": obj.name,
         "alias": obj.alias,
@@ -75,17 +76,17 @@ def create_lora_json(obj: lora.LoraOnDisk):
     }


-def api_loras(_: gr.Blocks, app: FastAPI):
+def api_networks(_: gr.Blocks, app: FastAPI):
     @app.get("/sdapi/v1/loras")
     async def get_loras():
-        return [create_lora_json(obj) for obj in lora.available_loras.values()]
+        return [create_lora_json(obj) for obj in networks.available_networks.values()]

     @app.post("/sdapi/v1/refresh-loras")
     async def refresh_loras():
-        return lora.list_available_loras()
+        return networks.list_available_networks()


-script_callbacks.on_app_started(api_loras)
+script_callbacks.on_app_started(api_networks)

 re_lora = re.compile("<lora:([^:]+):")

@@ -98,19 +99,19 @@ def infotext_pasted(infotext, d):
     hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
     hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}

-    def lora_replacement(m):
+    def network_replacement(m):
         alias = m.group(1)
         shorthash = hashes.get(alias)
         if shorthash is None:
             return m.group(0)

-        lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)
-        if lora_on_disk is None:
+        network_on_disk = networks.available_network_hash_lookup.get(shorthash)
+        if network_on_disk is None:
             return m.group(0)

-        return f'<lora:{lora_on_disk.get_alias()}:'
+        return f'<lora:{network_on_disk.get_alias()}:'

-    d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])
+    d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])


 script_callbacks.on_infotext_pasted(infotext_pasted)
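The hook installation above follows a store-once / patch pattern: the original torch.nn method is saved under a *_before_network attribute the first time the script loads, and the replacement calls back into it after applying network weights. A generic, hedged sketch of the same pattern on a toy class (not webui code):

# Generic illustration of the patching pattern used above, on a toy class.
class Layer:
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    # ...extra work would happen here (e.g. baking network weights)...
    return Layer.forward_before_network(self, x)

if not hasattr(Layer, 'forward_before_network'):
    Layer.forward_before_network = Layer.forward   # save the original exactly once

Layer.forward = patched_forward                    # install the replacement

print(Layer().forward(1))  # 2, routed through the patched method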
ui_extra_networks_lora.py:

 import os
-import lora
+import networks

 from modules import shared, ui_extra_networks
 from modules.ui_extra_networks import quote_js
@@ -11,10 +11,10 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         super().__init__('Lora')

     def refresh(self):
-        lora.list_available_loras()
+        networks.list_available_networks()

     def create_item(self, name, index=None):
-        lora_on_disk = lora.available_loras.get(name)
+        lora_on_disk = networks.available_networks.get(name)

         path, ext = os.path.splitext(lora_on_disk.filename)
@@ -43,7 +43,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
         return item

     def list_items(self):
-        for index, name in enumerate(lora.available_loras):
+        for index, name in enumerate(networks.available_networks):
             item = self.create_item(name, index)
             yield item