Commit 5a8dd0c5 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf Committed by GitHub

Fix rescale

parent 90441294
......@@ -40,7 +40,9 @@ class NetworkModuleOFT(network.NetworkModule):
self.is_boft = False
if weights.w["oft_diag"].dim() == 4:
self.is_boft = True
self.rescale = weight.w.get('rescale', None)
self.rescale = weights.w.get('rescale', None)
if self.rescale is not None:
self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))
is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment