Commit c23f666d authored by AUTOMATIC

a more strict check for activation type and a more reasonable check for type of layer in hypernets

parent a26fc283
@@ -32,10 +32,16 @@ class HypernetworkModule(torch.nn.Module):
         linears = []
         for i in range(len(layer_structure) - 1):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
+
             if activation_func == "relu":
                 linears.append(torch.nn.ReLU())
-            if activation_func == "leakyrelu":
+            elif activation_func == "leakyrelu":
                 linears.append(torch.nn.LeakyReLU())
+            elif activation_func == 'linear' or activation_func is None:
+                pass
+            else:
+                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
+
             if add_layer_norm:
                 linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
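As a standalone sketch of the new behavior (illustrative only; build_activation is a hypothetical helper, not part of the commit), the stricter check maps the two supported names to their torch.nn layers, treats 'linear' or None as "no activation", and rejects anything else:

import torch

# Hypothetical helper mirroring the stricter activation check above.
def build_activation(activation_func):
    if activation_func == "relu":
        return torch.nn.ReLU()
    elif activation_func == "leakyrelu":
        return torch.nn.LeakyReLU()
    elif activation_func == 'linear' or activation_func is None:
        return None  # no activation layer gets appended in this case
    else:
        raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

print(build_activation("relu"))   # ReLU()
print(build_activation(None))     # None
build_activation("mish")          # raises RuntimeError

Before this change, an unrecognized value such as "Relu" matched neither if branch and silently built a hypernetwork with no activation at all; now it fails loudly at construction time.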
@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                if not "ReLU" in layer.__str__():
+                if type(layer) == torch.nn.Linear:
                     layer.weight.data.normal_(mean=0.0, std=0.01)
                     layer.bias.data.zero_()
@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            if not "ReLU" in layer.__str__():
+            if type(layer) == torch.nn.Linear:
                 layer_structure += [layer.weight, layer.bias]
         return layer_structure
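The string test `not "ReLU" in layer.__str__()` was meant to pick out the torch.nn.Linear layers, but it also matches torch.nn.LayerNorm (whose repr contains no "ReLU"), so layer norm parameters were being re-initialized and returned from trainables() as well. A small sketch (assumed example, not from the commit) shows the difference:

import torch

# Compare the old string check with the new type check across the layer
# types a hypernetwork module can contain.
layers = [torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.LeakyReLU(), torch.nn.LayerNorm(4)]

for layer in layers:
    old_match = "ReLU" not in str(layer)         # old: also True for LayerNorm
    new_match = type(layer) == torch.nn.Linear   # new: True only for Linear
    print(f"{type(layer).__name__:<10} old={old_match!s:<5} new={new_match}")

# Output:
# Linear     old=True  new=True
# ReLU       old=False new=False
# LeakyReLU  old=False new=False
# LayerNorm  old=True  new=False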