Commit b0155c91 authored by novelailab

hypernetwork training

parent bc280afb
This source diff could not be displayed because it is too large.
@@ -228,6 +228,7 @@ class GPTLayer(nn.Module):
 
     def forward(self, x, hypernetwork=None, act_ck=False):
         residual = x
+
         if act_ck:
             x = ck(self.ln_preattn, x)
             attn_out = ck(self.attn, x)
@@ -236,12 +237,14 @@ class GPTLayer(nn.Module):
             x = self.ln_preattn(x)
             attn_out = self.attn(x)
 
+        if hypernetwork:
+            hyper_out = hypernetwork(x)
+
         ff_out = self.ff(x, act_ck)
         #order of addition matters, i had no idea... fixed a bug here.
         x = attn_out + ff_out + residual
         #x = residual + attn_out + ff_out -> doesn't match.
 
         if hypernetwork:
-            hyper_out = hypernetwork(x)
             x = x + hyper_out
         return x
...
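Reconstructed from the new side of the two hunks above, the layer's forward now reads roughly as follows: the hypernetwork is evaluated on the post-ln_preattn hidden states (the same tensor that feeds attention and the feed-forward block), while its output is still added after the parallel attn + ff + residual sum. This is a sketch, not the verbatim file; the else branch and blank lines are inferred from context, and ck / self.attn / self.ff / self.ln_preattn are assumed to behave as elsewhere in the repo.

```python
# Reconstructed from the new side of the hunks above; the else branch is
# inferred from context.
def forward(self, x, hypernetwork=None, act_ck=False):
    residual = x

    if act_ck:
        x = ck(self.ln_preattn, x)
        attn_out = ck(self.attn, x)
    else:
        x = self.ln_preattn(x)
        attn_out = self.attn(x)

    # New in this commit: the hypernetwork sees the post-layernorm hidden
    # states ("preattn"), not the layer output as before.
    if hypernetwork:
        hyper_out = hypernetwork(x)

    ff_out = self.ff(x, act_ck)
    # order of addition matters, i had no idea... fixed a bug here.
    x = attn_out + ff_out + residual

    # The hypernetwork contribution is still added on top of the combined
    # residual stream.
    if hypernetwork:
        x = x + hyper_out

    return x
```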
@@ -24,19 +24,21 @@ class HyperNetwork(nn.Module):
         self.linear.weight.data.normal_(mean=0.0, std=0.02)
         for param in self.linear.parameters():
             param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
+        #state = self.state_dict()
+        #for k in state:
+        #    state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
+        #self.load_state_dict(state)
 
     def forward(self, hidden_states):
-        hidden_states = self.linear(hidden_states)
+        hidden_states = self.linear(hidden_states.float())
         hidden_states = hidden_states.mul(torch.sigmoid(hidden_states))
-        return hidden_states
+        return hidden_states.bfloat16()
 
 model_config = {
-    "n_layer": 28,
-    "n_head": 16,
-    "hidden_dim": 4096,
+    "n_layer": 12,
+    "n_head": 12,
+    "hidden_dim": 768,
     "vocab_dim": 50400,
     "eps": 1e-5,
     "activation": gelu_new,
@@ -44,9 +46,9 @@ model_config = {
 }
 
 model_config = {
-    "n_layer": 12,
-    "n_head": 12,
-    "hidden_dim": 768,
+    "n_layer": 28,
+    "n_head": 16,
+    "hidden_dim": 4096,
     "vocab_dim": 50400,
     "eps": 1e-5,
     "activation": gelu_new,
@@ -55,19 +57,20 @@ model_config = {
 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
+    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
+    #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
     #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
     "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/fixedj",
-    "run_name": "gpt-j-8bitopt-owt2-125m-fp16AMP-fixedj",
-    "lr": 1e-4,
-    "end_lr": 1e-4 * 2,
+    "run_name": "gpt-j-owt2-6b-preattn",
+    "lr": 5e-4,
+    "end_lr": 5e-4,
     "warmup_steps": 50,
-    "bs": 12,
-    "gas": 10,
+    "bs": 1,
+    "gas": 16,
     "seed": 69,
     "save_every": 500,
-    "amp": True,
-    "loss_scale": True,
+    "amp": False,
+    "loss_scale": False,
 }
 
 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
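With the switch from the 125M run to the 6B run, the per-device batch drops to 1 and gradient accumulation carries the effective batch; AMP and loss scaling are disabled because the model is now loaded directly in bf16, and lr == end_lr makes the schedule flat after the 50 warmup steps. A quick back-of-the-envelope check of the effective batch, assuming 2048 predicted tokens per 2049-token FbDataset sample (that token count is an assumption, not stated in the diff):

```python
# Back-of-the-envelope numbers for this train_config; the 2048 tokens-per-sample
# figure is an assumption based on FbDataset(2049, ...) and next-token shifting.
bs, gas, tokens_per_sample = 1, 16, 2048

sequences_per_step = bs * gas                             # 16 sequences per optimizer step
tokens_per_step = sequences_per_step * tokens_per_sample  # 32,768 tokens per step

print(sequences_per_step, tokens_per_step)
```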
@@ -75,25 +78,26 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
 
-model = GPTModel.gpt2_init(model_config).cuda().float()
-#model = load_gpt_j().cuda().half()
-#for param in model.parameters():
-#    param.requires_grad = False
-for name, p in model.named_parameters():
-    if ("ln" in name or "vocab_embed" in name):
-        p.requires_grad = True
-#for name, p in model.named_parameters():
-#    if ("ln" in name or "vocab_embed" in name):
-#        p.requires_grad = True
-#hypernetwork = HyperNetwork(model_config).cuda().float()
-#for param in hypernetwork.parameters():
-#    param.requires_grad = True
-opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
+#model = GPTModel.gpt2_init(model_config).cuda().float()
+model = load_gpt_j().cuda().bfloat16()
+for param in model.parameters():
+    param.requires_grad = False
+
+hypernetwork = HyperNetwork(model_config).cuda().float()
+for param in hypernetwork.parameters():
+    param.requires_grad = True
+
+opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
 
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
 train_dataset = utils.FbDataset(2049, train_config["data_path"])
 train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
 
-wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})
+wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
 
 t = tqdm(train_loader)
 curr_step = 0
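The setup is now the inverse of the parent commit: GPT-J is loaded in bf16 with every parameter frozen, while the fp32 hypernetwork is the only module whose parameters are marked trainable and handed to BasedOptimizer. A small illustrative check, not part of the commit, that the split is what the optimizer will see; it reuses the model and hypernetwork objects built above:

```python
# Illustrative only: confirm that no base-model parameter requires grad and
# count how many hypernetwork parameters the optimizer will actually update.
def count_params(module, trainable_only=False):
    return sum(p.numel() for p in module.parameters()
               if (p.requires_grad or not trainable_only))

assert count_params(model, trainable_only=True) == 0, "base model should be frozen"
print(f"trainable hypernetwork params: {count_params(hypernetwork, trainable_only=True):,}")
```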
@@ -107,10 +111,10 @@ for input_ids, labels in t:
     loss = 0
     for x in range(train_config["gas"]):
         with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model(input_ids[x*bs:(x+1)*bs, :1024].cuda(), hypernetwork=None, act_ck=False)
+            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=False)
             #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             logits = logits.view(-1, logits.shape[-1])
-            gas_labels = labels[x*bs:(x+1)*bs, :1024].contiguous()
+            gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
             gas_labels = gas_labels.view(-1)
             gas_loss = F.cross_entropy(logits, gas_labels)
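Each DataLoader batch holds bs*gas sequences; the inner loop slices out one micro-batch at a time, now over the full sequence instead of the first 1024 tokens, flattens logits to (batch*seq, vocab) and labels to (batch*seq,), and sums the per-micro-batch cross-entropy before dividing by gas. A toy version of that flatten-then-cross-entropy step, with shapes shrunk for illustration (the real run uses seq 2048 and vocab 50400):

```python
import torch
import torch.nn.functional as F

# Toy shapes standing in for one micro-batch (real run: bs=1, seq=2048, vocab=50400).
bs, seq, vocab = 1, 16, 50400
logits = torch.randn(bs, seq, vocab)                   # what model(...) returns
labels = torch.randint(0, vocab, (bs, seq))

flat_logits = logits.view(-1, vocab)                   # (bs*seq, vocab)
flat_labels = labels.view(-1)                          # (bs*seq,)
gas_loss = F.cross_entropy(flat_logits, flat_labels)   # scalar, mean over positions

# In the loop above, these per-micro-batch losses are summed and then
# divided by gas, giving the mean loss over the whole accumulated batch.
print(gas_loss.item())
```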
@@ -124,7 +128,7 @@ for input_ids, labels in t:
     loss = loss / gas
 
     if train_config["loss_scale"]:
         scaler.unscale_(opt.optimizer)
-    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
+    torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
 
     if train_config["loss_scale"]:
         opt.step(scaler=scaler)
     else:
@@ -140,5 +144,5 @@ for input_ids, labels in t:
     wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
     curr_step += 1
     if curr_step % train_config["save_every"] == 0:
-        model.save(train_config["save_path"] + f"/{curr_step}")
+        #model.save(train_config["save_path"] + f"/{curr_step}")
         print(f"Saved model at step {curr_step}")