Commit 8b26deda authored by Wes Brown

Revert mostly to `x=` assignment form.

parent 8073ccfc
...
@@ -29,6 +29,7 @@
 prompts = ["<|endoftext|>",
            "[ Tags:",
            "***"]
+
 def _init_weights(module):
     if isinstance(module, nn.Linear):
         module.weight.data.normal_(mean=0.0, std=0.02)
...
@@ -72,13 +73,19 @@ class HyperNetworkGRU(nn.Module):
             param.data.normal_(mean=0.0,
                                std=(0.02 / math.sqrt(2 * config["n_layer"])))
+        self.linear_gru = nn.Sequential(
+            self.linear1,
+            self.gru)
+        self.layernorm_linear = nn.Sequential(
+            self.ln_1,
+            self.linear2)

     def forward(self, x):
-        return ck(self.activation,
-                  self.linear2(
-                      self.ln_1(
-                          self.gru(
-                              self.linear1(
-                                  x.float()))[0]))).bfloat16()
+        x = x.float()
+        x = self.linear_gru.forward(x)[0]
+        x = ck(self.activation,
+               self.layernorm_linear.forward(x))
+        return x.bfloat16()


 class HyperNetwork(nn.Module):
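An aside on the HyperNetworkGRU hunk above: the commit groups the layers into two nn.Sequential stages and then walks them one `x =` assignment at a time. A minimal self-contained sketch of that pattern follows; the TinyHyperNetworkGRU name, the layer widths, and the binding of `ck` to torch.utils.checkpoint.checkpoint are illustrative assumptions, not taken from this repository.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint as ck  # assumed meaning of `ck`

    class TinyHyperNetworkGRU(nn.Module):
        def __init__(self, embed_dim=64):
            super().__init__()
            hidden = embed_dim // 4  # placeholder width, not from the repo
            self.linear1 = nn.Linear(embed_dim, hidden)
            self.gru = nn.GRU(hidden, hidden, batch_first=True)
            self.ln_1 = nn.LayerNorm(hidden)
            self.linear2 = nn.Linear(hidden, embed_dim)
            self.activation = nn.GELU()
            # Group the stages as the commit does; nn.GRU returns an
            # (output, h_n) tuple, which is why forward() indexes [0].
            self.linear_gru = nn.Sequential(self.linear1, self.gru)
            self.layernorm_linear = nn.Sequential(self.ln_1, self.linear2)

        def forward(self, x):
            x = x.float()
            x = self.linear_gru(x)[0]  # keep the GRU output, drop h_n
            x = ck(self.activation, self.layernorm_linear(x))
            return x.bfloat16()

    y = TinyHyperNetworkGRU()(torch.randn(2, 8, 64))  # (batch, seq, embed)
    print(y.shape, y.dtype)  # torch.Size([2, 8, 64]) torch.bfloat16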
...
@@ -96,11 +103,12 @@ class HyperNetwork(nn.Module):
                                std=(0.02 / math.sqrt(2 * config["n_layer"])))

     def forward(self, x):
-        x = self.linear2(
-            ck(self.activation,
-               self.linear(x.float())))
-        return x.mul(torch.sigmoid(x)).bfloat16()
+        x = x.float()
+        x = self.linear(x)
+        x = ck(self.activation, x)
+        x = self.linear2(x)
+        x = x.mul(torch.sigmoid(x))
+        return x.bfloat16()


 class HyperNetworkSingle(nn.Module):
     def __init__(self, config):
...
@@ -115,14 +123,12 @@ class HyperNetworkSingle(nn.Module):
         for param in self.linear.parameters():
             param.data.normal_(mean=0.0,
                                std=(0.02 / math.sqrt(2 * config["n_layer"])))
-        # state = self.state_dict()
-        # for k in state:
-        #     state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
-        # self.load_state_dict(state)

     def forward(self, x):
-        x = self.linear(x.float())
-        return x.mul(torch.sigmoid(x)).bfloat16()
+        x = x.float()
+        x = self.linear(x)
+        x = x.mul(torch.sigmoid(x))
+        return x.bfloat16()


 tokenizer = AutoTokenizer.from_pretrained('gpt2')
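The HyperNetwork and HyperNetworkSingle hunks apply the same revert: each nested call chain becomes one assignment per step, ending in x.mul(torch.sigmoid(x)), i.e. a SiLU/Swish gate, before casting back to bfloat16. A minimal sketch of the single-layer variant, with an assumed width and without the config plumbing:

    import torch
    import torch.nn as nn

    class TinyHyperNetworkSingle(nn.Module):
        def __init__(self, embed_dim=64):  # placeholder width
            super().__init__()
            self.linear = nn.Linear(embed_dim, embed_dim)

        def forward(self, x):
            x = x.float()                # compute in fp32
            x = self.linear(x)
            x = x.mul(torch.sigmoid(x))  # SiLU/Swish: x * sigmoid(x)
            return x.bfloat16()          # cast back for the host model

    out = TinyHyperNetworkSingle()(torch.randn(2, 64, dtype=torch.bfloat16))
    print(out.shape, out.dtype)  # torch.Size([2, 64]) torch.bfloat16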
...
@@ -183,14 +189,17 @@ def report_console(data):
     print(colored("======================================================",
                   "red"))

 def make_hypernet_saver(train_config, hypernetwork):
     def hypernet_saver(id: str):
         save_folder = Path(train_config["save_path"]) / id
         save_folder.mkdir(parents=True, exist_ok=True)
         torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
         opt.save(save_folder / "opt")
     return hypernet_saver

 parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
 parser.add_argument('--run_name', type=str, help='the run name to use',
                     required=True)
...