Commit 64179c32 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf Committed by GitHub

Update network_oft.py

parent 591470d8
......@@ -72,7 +72,7 @@ class NetworkModuleOFT(network.NetworkModule):
eye = torch.eye(self.block_size, device=oft_blocks.device)
if not self.is_R:
block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix
block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment