Commit 704947b4 authored by Wes Brown

Clean up, better reporting.

parent cc02ad48
import numpy as np
import torch
import mmap
import pickle
import concurrent
from torch.utils import data
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.utils import data
import math
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
from torch.utils.checkpoint import checkpoint as ck
from math import log2, ceil
from basedformer import optimizer, lm_utils, dataset
from basedformer.utils import *
import glob
import os
from einops import rearrange
from transformers import AutoTokenizer
from basedformer import sampling
from icecream import ic
from termcolor import colored
gpu = "cuda"
......@@ -26,6 +15,9 @@ if gpu != "cuda":
amp = torch.amp
scaler = torch.cuda.amp.GradScaler()
prompts = ["<|endoftext|>"]
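# GPT-2-style parameter init: Linear weights drawn from N(0, 0.02), biases
# zeroed, LayerNorm weights set to 1.0.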
def _init_weights(module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
......@@ -39,49 +31,21 @@ def _init_weights(module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
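# Token-shift mixing helper (appears to follow lucidrains' token-shift-gpt):
# chunks the feature dim and blends each chunk with cumsum-based running
# averages over exponentially growing windows. Not used by the training loop
# below.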
def shift_tokens(x, amt, eps = 1e-5):
n, device = x.shape[1], x.device
cumsum = x.cumsum(dim = 1)
*x, x_pass = x.chunk(amt + 1, dim = -1)
*x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
amts = 2 ** torch.arange(amt)
amts = amts.tolist()
shifts = []
denom = torch.arange(n, device = device)
for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
normed_shifted_x = shifted_chunk / (shifted_denom + eps)
shifts.append(normed_shifted_x)
return torch.cat((*shifts, x_pass), dim = -1)
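# Gamma-discounted cumulative sum over the sequence axis via the optional
# torch-discounted-cumsum extension; also unused by the training loop below.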
def discounted_cumsum(t, gamma):
try:
from torch_discounted_cumsum import discounted_cumsum_left
except ImportError:
print('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`')
b, n, d = t.shape
t = rearrange(t, 'b n d -> (b d) n')
t = discounted_cumsum_left(t, gamma)
t = rearrange(t, '(b d) n -> b n d', b = b)
return t
def shift(x, amt, dim=-1):
return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value=0.)
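# The hypernetwork variants below are small adapter modules trained on top of
# the frozen LM (presumably on its hidden states). HyperNetworkGRU: 8x
# down-projection, single-layer GRU over the sequence, LayerNorm,
# up-projection, checkpointed GELU; computed in fp32, returned as bfloat16.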
class HyperNetworkGRU(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config["hidden_dim"]
self.linear1 = nn.Linear(embed_dim, embed_dim // 8)
self.gru = nn.GRU(embed_dim // 8,
embed_dim // 8,
num_layers=1,
bidirectional=False,
batch_first=True)
self.linear2 = nn.Linear(embed_dim // 8, embed_dim)
self.ln_1 = nn.LayerNorm(embed_dim // 8, eps=1e-5)
self.activation = gelu_new
......@@ -90,47 +54,42 @@ class HyperNetworkGRU(nn.Module):
_init_weights(module)
for param in self.linear2.parameters():
param.data.normal_(mean=0.0,
std=(0.02 / math.sqrt(2 * config["n_layer"])))
for param in self.gru.parameters():
param.data.normal_(mean=0.0,
std=(0.02 / math.sqrt(2 * config["n_layer"])))
def forward(self, x):
return ck(self.activation,
self.linear2(
self.ln_1(
self.gru(
self.linear1(
x.float()))[0]))).bfloat16()
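# HyperNetwork: 4x bottleneck MLP (down-projection, checkpointed GELU,
# up-projection) followed by an x * sigmoid(x) gate on the output.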
class HyperNetwork(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config["hidden_dim"]
self.linear = nn.Linear(embed_dim, embed_dim // 4, bias=True)
self.linear2 = nn.Linear(embed_dim // 4, embed_dim, bias=True)
self.activation = torch.nn.functional.gelu
self.num_shifts = ceil(log2(2048)) - 1
#self.linear.weight.data.normal_(mean=0.0, std=0.02)
for module in self.modules():
_init_weights(module)
for param in self.linear2.parameters():
param.data.normal_(mean=0.0,
std=(0.02 / math.sqrt(2 * config["n_layer"])))
def forward(self, x):
x = self.linear2(
ck(self.activation,
self.linear(x.float())))
return x.mul(torch.sigmoid(x)).bfloat16()
class HyperNetworkSingle(nn.Module):
def __init__(self, config):
......@@ -138,32 +97,30 @@ class HyperNetworkSingle(nn.Module):
embed_dim = config["hidden_dim"]
self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
self.activation = gelu_new
# self.linear.weight.data.normal_(mean=0.0, std=0.02)
for module in self.modules():
_init_weights(module)
for param in self.linear.parameters():
param.data.normal_(mean=0.0,
std=(0.02 / math.sqrt(2 * config["n_layer"])))
# state = self.state_dict()
# for k in state:
#     state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
# self.load_state_dict(state)
def forward(self, x):
x = self.linear(x.float())
return x.mul(torch.sigmoid(x)).bfloat16()
tokenizer = AutoTokenizer.from_pretrained('gpt2')
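# Evaluation sampler: generates n_tokens continuations of `prompt` for a batch
# of bsz, once with the hypernetwork attached and once with the vanilla frozen
# model, and returns rows of [step, prompt, hypernetwork text, vanilla text]
# for the reporting helpers below.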
@torch.no_grad()
def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
torch.seed()
tokens = tokenizer.encode(prompt)
#print("Prompt:")
#for x in range(len(tokens)):
# print(tokenizer.decode([tokens[x]]), end=" | ")
tokens = torch.LongTensor(tokens).unsqueeze(0).to(gpu)
tokens = [tokens] * bsz
tokens = torch.cat(tokens, dim=0)
......@@ -178,31 +135,52 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
"temp": 0.8,
}
ops_list = [ops] * bsz
tokens_generated = sampling.generate(model.forward,
tokens,
n_tokens,
ops_list=ops_list,
hypernetwork=hypernetwork,
non_deterministic=True)
vanilla_tokens_generated = sampling.generate(model.forward,
tokens,
n_tokens,
ops_list=ops_list,
hypernetwork=None)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
vanilla_tokens_generated = tokenizer.batch_decode(
vanilla_tokens_generated.cpu().numpy())
data = []
for x in range(len(tokens_generated)):
data.append([step,
prompt,
str(tokens_generated[x]),
str(vanilla_tokens_generated[x])])
return data
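# Reporting helpers: log the accumulated generations as a wandb Table and
# mirror the hypernetwork generations to the console.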
def report_wandb(data):
columns = ["Step", "Prompt", "Generated Text", "Vanilla Model"]
wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
def report_console(data):
# Print the hypernetwork generation (row index 2) for each sampled row.
for row in data:
print(colored("======================================================",
"red"))
print(colored(row[2], "green"))
print(colored("======================================================",
"red"))
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": "dataset/enwik9-gpt2-2049.map",
"save_path": "models/enwik9-sigurdv4-hypernet2",
"data_path": "dataset/cassandra.map",
"save_path": "models/sigurdv4-cassandra-hypernet2",
"lm_path": "pretrained/sigurdv4",
"optimizer": "adamw",
"masked_softmax_fusion": False,
"do_save": True,
"run_name": "gptj-6b-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
"run_name": "sigurdv4-cassandra-6b-postln-bf16-2e-4-4bsz-every5layer",
"lr": 2e-4,
"end_lr": 2e-4,
"warmup_steps": 50,
......@@ -220,7 +198,6 @@ gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
model = lm_utils.load_from_path("pretrained/sigurdv4").to(gpu).bfloat16()
for param in model.parameters():
param.requires_grad = False
......@@ -233,26 +210,37 @@ hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
for param in hypernetwork.parameters():
param.requires_grad = True
cp_list = sorted(os.listdir(train_config["save_path"]),
key=lambda x: int(x.split("_")[-1]))
last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(
cp_list) > 0 else None
print(last_cp)
if last_cp:
print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
last_cp / "opt")
else:
opt = optimizer.BasedOptimizer(hypernetwork.parameters(),
train_config,
train_config["optimizer"])
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden
# states from the get_logits function.
print(opt.curr_step)
train_dataset = dataset.ShardedDataset(2049, train_config["data_path"])
if last_cp:
train_dataset.skip = opt.curr_step * bs * gas
train_loader = data.DataLoader(train_dataset,
batch_size=bs * gas,
shuffle=False,
num_workers=0)
wandb.init(project="hypernetwork-tests",
name=train_config["run_name"],
config={**train_config, **model.config})
if last_cp:
curr_step = opt.curr_step
......@@ -260,21 +248,22 @@ else:
curr_step = 0
t = tqdm(train_loader, initial=curr_step)
sample_data = []
#sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
for input_ids, labels in t:
timex = time.perf_counter()
input_ids = input_ids.to(gpu)
labels = labels.to(gpu)
loss = 0
for x in range(train_config["gas"]):
with amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
logits, _ = model(input_ids[x*bs:(x+1)*bs, :].to(gpu), hypernetwork=hypernetwork, act_ck=True)
#print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
with amp.autocast(enabled=train_config["amp"],
dtype=torch.float16):
logits, _ = model(input_ids[x * bs:(x + 1) * bs, :].to(gpu),
hypernetwork=hypernetwork,
act_ck=True)
# print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x * bs:(x + 1) * bs, :].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
......@@ -301,27 +290,33 @@ for input_ids, labels in t:
sec_per_step = (time.perf_counter() - timex)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 2048) * bs * gas
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step,"
+ f"{tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log(
{
"train/loss": loss,
"train/tokens_per_sec": tokens_per_sec,
"train/sec_per_step": sec_per_step,
"train/step_per_sec": step_per_sec,
"train/lr": opt.curr_lr,
"train/step_per_sec": step_per_sec,
"train/lr": opt.curr_lr,
"train/loss_scale": scaler.get_scale()
},
step=curr_step)
if train_config["do_save"] and curr_step % train_config["save_every"] == 0 and curr_step != 0:
if train_config["do_save"] and curr_step % train_config[
"save_every"] == 0 and curr_step != 0:
save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
save_folder.mkdir(parents=True, exist_ok=True)
torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
opt.save(save_folder / "opt")
print(f"Saved model at step {curr_step}")
print(f"\nSaved model at step {curr_step}")
if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
print("")
sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
for prompt in prompts:
sampled = sample(prompt, 500, 3, hypernetwork=hypernetwork)
print(f"PROMPT:\n{prompt}")
report_console(sampled)
sample_data = sample_data + sampled
report_wandb(sample_data)
curr_step += 1
\ No newline at end of file