Commit 80b26d2a authored by AUTOMATIC

apply Lora by altering layer's weights instead of adding more calculations in forward()

parent 69eb2a9e
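Editor's note: before this commit, every patched Linear/Conv2d forward added the Lora contribution on the fly (res + up(down(x)) * scale). With this change, the low-rank delta up @ down, scaled by multiplier * (alpha / rank), is merged directly into the layer's weight once, and a CPU backup of the original weight is kept so the merge can be undone when the selected set of Loras changes. A minimal standalone sketch of that idea (illustrative shapes and names, not the extension's API):

    import torch

    layer = torch.nn.Linear(16, 32)
    rank = 4
    up = torch.randn(32, rank)     # plays the role of lora_up.weight
    down = torch.randn(rank, 16)   # plays the role of lora_down.weight
    multiplier, alpha = 0.8, 4.0

    with torch.no_grad():
        backup = layer.weight.detach().clone()   # kept so the merge can be reverted
        layer.weight += (up @ down) * multiplier * (alpha / rank)
        # reverting later is just: layer.weight.copy_(backup)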
@@ -131,7 +131,7 @@ def load_lora(name, filename):
         with torch.no_grad():
             module.weight.copy_(weight)
 
-        module.to(device=devices.device, dtype=devices.dtype)
+        module.to(device=devices.cpu, dtype=devices.dtype)
 
         if lora_key == "lora_up.weight":
             lora_module.up = module
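In this hunk, load_lora now keeps the per-layer up/down modules on devices.cpu: since only their merged product ends up in the model weight, they need to visit the layer's device just briefly while the delta is computed. A small illustrative snippet of that on-demand move (hypothetical names, assumed shapes):

    import torch

    layer = torch.nn.Linear(4, 8)        # in practice lives on the GPU
    up_cpu = torch.randn(8, 2)           # stored on CPU after loading
    down_cpu = torch.randn(2, 4)

    with torch.no_grad():
        up = up_cpu.to(layer.weight.device, dtype=layer.weight.dtype)
        down = down_cpu.to(layer.weight.device, dtype=layer.weight.dtype)
        layer.weight += up @ down        # only the merged delta stays with the layer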
@@ -177,29 +177,69 @@ def load_loras(names, multipliers=None):
         loaded_loras.append(lora)
 
 
-def lora_forward(module, input, res):
-    input = devices.cond_cast_unet(input)
-    if len(loaded_loras) == 0:
-        return res
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+    """
+    Applies the currently selected set of Loras to the weight of torch layer self.
+    If weights already have this particular set of loras applied, does nothing.
+    If not, restores original weights from backup and alters weights according to loras.
+    """
+
+    current_names = getattr(self, "lora_current_names", ())
+    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
+
+    weights_backup = getattr(self, "lora_weights_backup", None)
+    if weights_backup is None:
+        weights_backup = self.weight.to(devices.cpu, copy=True)
+        self.lora_weights_backup = weights_backup
 
-    lora_layer_name = getattr(module, 'lora_layer_name', None)
-    for lora in loaded_loras:
-        module = lora.modules.get(lora_layer_name, None)
-        if module is not None:
-            if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
-                res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
-            else:
-                res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+    if current_names != wanted_names:
+        if weights_backup is not None:
+            self.weight.copy_(weights_backup)
 
-    return res
+        lora_layer_name = getattr(self, 'lora_layer_name', None)
+        for lora in loaded_loras:
+            module = lora.modules.get(lora_layer_name, None)
+            if module is None:
+                continue
+
+            with torch.no_grad():
+                up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
+                down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+
+                if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+                    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                else:
+                    updown = up @ down
+
+                self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+        setattr(self, "lora_current_names", wanted_names)
 
 
 def lora_Linear_forward(self, input):
-    return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Linear_forward_before_lora(self, input)
+
+
+def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def lora_Conv2d_forward(self, input):
-    return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input))
+    lora_apply_weights(self)
+
+    return torch.nn.Conv2d_forward_before_lora(self, input)
+
+
+def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
+    setattr(self, "lora_current_names", ())
+    setattr(self, "lora_weights_backup", None)
+
+    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
 
 
 def list_available_loras():
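The up @ down product above matches Linear weights directly; for Conv2d layers with 1x1 kernels the Lora weights carry trailing unit dimensions, so the hunk squeezes them away before the matrix product and restores them afterwards. A hedged standalone sketch of that reshaping (assumed shapes, hypothetical names):

    import torch

    conv = torch.nn.Conv2d(16, 32, kernel_size=1)
    rank = 4
    up = torch.randn(32, rank, 1, 1)     # lora_up.weight for a 1x1 conv
    down = torch.randn(rank, 16, 1, 1)   # lora_down.weight

    # (32, rank) @ (rank, 16) -> (32, 16), then restore the two unit dims
    updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
    assert updown.shape == conv.weight.shape

    multiplier, alpha = 1.0, 4.0
    with torch.no_grad():
        conv.weight += updown * multiplier * (alpha / rank)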
@@ -9,7 +9,9 @@ from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
 def unload():
     torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
     torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
 
 
 def before_ui():
@@ -20,11 +22,19 @@ def before_ui():
 if not hasattr(torch.nn, 'Linear_forward_before_lora'):
     torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
 
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
+    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+
 if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
     torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
 
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
+    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+
 torch.nn.Linear.forward = lora.lora_Linear_forward
+torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
 torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
 
 script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
@@ -33,6 +43,4 @@ script_callbacks.on_before_ui(before_ui)
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
-    "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"),
 }))
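The script changes follow the existing pattern of saving the original torch.nn methods once under *_before_lora names so unload() can restore them; the new _load_from_state_dict hooks clear the per-layer name/backup attributes so that loading a fresh checkpoint invalidates backups taken from the previous one. A simplified sketch of the forward-patching half only (the weight merge itself is omitted):

    import torch

    if not hasattr(torch.nn, 'Linear_forward_before_lora'):
        torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward

    def lora_Linear_forward(self, input):
        # lora_apply_weights(self) would merge the Lora deltas into self.weight here
        return torch.nn.Linear_forward_before_lora(self, input)

    torch.nn.Linear.forward = lora_Linear_forward
    print(torch.nn.Linear(8, 4)(torch.randn(2, 8)).shape)   # layers still behave as usual

    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora   # what unload() does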