Commit ea32d948 authored by Wes Brown

hypertrain reporting fixes and improvements

parent 94c0ad6f
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
import wandb
from torch.utils import data as torch_data
from torch.utils.checkpoint import checkpoint as ck
from basedformer import optimizer, lm_utils, dataset
from basedformer.utils import *
from transformers import AutoTokenizer
from basedformer import sampling
from termcolor import colored
from typing import Callable, List
import argparse
gpu = "cuda"
@@ -24,9 +25,9 @@ prompts = ["<|endoftext|>",
"The tavern was full again, so I ended up sharing a table with three very different creatures: a",
"I had been hiking in the wilderness when suddenly a",
"She spread her",
"The mercurial and beautiful woman laughed",
"[ Author:",
"[ Tags:",
"The mercurial and beautiful",
"<|endoftext|>[ Author:",
"<|endoftext|>[ Genre:",
"***"]
@@ -110,10 +111,13 @@ class HyperNetwork(nn.Module):
x = x.mul(torch.sigmoid(x))
return x.bfloat16()
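# HyperNetworkSingle: a lighter variant built from a single linear projection
# (embed_dim -> embed_dim) with the gelu_new activation; its weights are
# initialized with std 0.02 / sqrt(2 * n_layer), matching GPT-style scaling.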
class HyperNetworkSingle(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config["hidden_dim"]
n_layers = config["n_layer"]
self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
self.activation = gelu_new
# self.linear.weight.data.normal_(mean=0.0, std=0.02)
@@ -122,7 +126,7 @@ class HyperNetworkSingle(nn.Module):
for param in self.linear.parameters():
param.data.normal_(mean=0.0,
std=(0.02 / math.sqrt(2 * config["n_layer"])))
std=(0.02 / math.sqrt(2 * n_layers)))
def forward(self, x):
x = x.float()
@@ -135,7 +139,8 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2')
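# Sampling helper: encodes `prompt`, generates `n_tokens` of text with the
# hypernetwork attached (and, when requested, with the unmodified "vanilla"
# model for comparison), then returns the decoded generations as structured
# rows for console and W&B reporting.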
@torch.no_grad()
def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0, run_name="",
generate_vanilla=False):
torch.seed()
tokens = tokenizer.encode(prompt)
tokens = torch.LongTensor(tokens).unsqueeze(0).to(gpu)
@@ -158,41 +163,80 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
ops_list=ops_list,
hypernetwork=hypernetwork,
non_deterministic=True)
vanilla_tokens_generated = sampling.generate(model.forward,
tokens,
n_tokens,
ops_list=ops_list,
hypernetwork=None)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
vanilla_tokens_generated = tokenizer.batch_decode(
vanilla_tokens_generated.cpu().numpy())
vanilla_tokens_generated = None
if generate_vanilla:
vanilla_tokens_generated = sampling.generate(model.forward,
tokens,
n_tokens,
ops_list=ops_list,
hypernetwork=None)
vanilla_tokens_generated = tokenizer.batch_decode(
vanilla_tokens_generated.cpu().numpy())
data = []
for x in range(len(tokens_generated)):
data.append([step,
prompt,
str(tokens_generated[x]),
str(vanilla_tokens_generated[x])])
entry = {"Run": run_name,
"Step": step,
"Prompt": prompt,
"Generated Text": str(tokens_generated[x])}
if vanilla_tokens_generated:
entry["Vanilla Model"] = vanilla_tokens_generated[x]
data.append(entry)
return data
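# Log the generation results to Weights & Biases as a wandb.Table with one
# row per sampled continuation.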
def report_wandb(data):
columns = ["Step", "Prompt", "Generated Text", "Vanilla Model"]
wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
columns = list(data[0].keys())
step = data[0]["Step"]
data_list = [x.values() for x in data]
wandb.log({"Generations": wandb.Table(data=data_list, columns=columns)},
step=step)
def report_console(data):
for gen in data[2]:
print(colored("======================================================",
"red"))
print(colored(gen, "green"))
def print_colored_bars(color):
print(colored("======================================================",
"red"))
color))
def report_console(data: List[dict]):
print_colored_bars("blue")
print(colored(data[0]['Prompt'], "white"))
print_colored_bars("blue")
for gen in data:
print_colored_bars("red")
print(colored(gen["Generated Text"], "green"))
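# Build an evaluation closure: for each prompt it samples from the
# hypernetwork model at the current step, prints the generations to the
# console, accumulates the rows across calls, and logs them to W&B.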
def make_eval_function(hypernetwork: HyperNetworkSingle, config: dict) -> \
Callable[[int], None]:
sample_data = {'rows': []}
gen_vanilla = config.get('generate_vanilla', False)
run_name = config.get('run_name', '')
def make_hypernet_saver(train_config, hypernetwork):
def hypernet_saver(id: str):
save_folder = Path(train_config["save_path"]) / id
def eval_function(curr_step: int) -> None:
print()
print_colored_bars('yellow')
print(f"Step: {curr_step}")
for prompt in prompts:
sampled = sample(prompt, 500, 3,
run_name=run_name,
hypernetwork=hypernetwork,
step=curr_step,
generate_vanilla=gen_vanilla)
report_console(sampled)
sample_data['rows'].extend(sampled)
print_colored_bars("red")
report_wandb(sample_data['rows'])
return eval_function
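# Build a checkpoint-saving closure: writes the hypernetwork state dict and
# the optimizer state under <save_path>/<id>.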
def make_hypernet_saver(hypernetwork: HyperNetworkSingle, config: dict) \
-> Callable[[str], None]:
def hypernet_saver(id: str) -> None:
save_folder = Path(config["save_path"]) / id
save_folder.mkdir(parents=True, exist_ok=True)
torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
opt.save(save_folder / "opt")
@@ -239,11 +283,12 @@ parser.add_argument("--project_id", type=str, help="Project ID for reporting",
parser.add_argument("--logs", type=str, help="log directory location",
default="./logs")
parser.add_argument("--masked", type=bool, help="masked softmax fusion")
parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False)
parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model")
parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False,
sample_vanilla=False)
args = parser.parse_args()
if args.output == '':
args.output = f'./{args.run_name}'
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": args.dataset,
@@ -264,6 +309,7 @@ train_config = {
"loss_scale": args.loss_scale,
"eval_every": args.eval_every,
"context_size": args.context_size,
"sample_vanilla": args.sample_vanilla,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
@@ -276,26 +322,25 @@ for param in model.parameters():
param.requires_grad = False
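# Re-enable gradients only for LayerNorm and vocab-embedding parameters of
# the otherwise frozen base model; the hypernetwork itself is fully trainable.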
for name, p in model.named_parameters():
if ("ln" in name or "vocab_embed" in name):
if "ln" in name or "vocab_embed" in name:
p.requires_grad = True
hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
for param in hypernetwork.parameters():
param.requires_grad = True
hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)
hypernetwork_saver = make_hypernet_saver(hypernetwork, train_config)
eval_fn = make_eval_function(hypernetwork, train_config)
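# Resume from the most recent checkpoint directory (sorted by the numeric
# step suffix in its name) unless --no_resume is given.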
cp_list = sorted(os.listdir(train_config["save_path"]),
key=lambda x: int(x.split("_")[-1]))
last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(
cp_list) > 0 else None
print(last_cp)
if last_cp and not args.no_resume:
print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
last_cp / "opt")
else:
opt = optimizer.BasedOptimizer(hypernetwork.parameters(),
train_config,
@@ -303,16 +348,15 @@ else:
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden
# states from the get_logits function.
print(opt.curr_step)
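# Rebuild the sharded dataset; when resuming, skip ahead based on the
# optimizer's recorded step so training does not repeat earlier data.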
train_dataset = dataset.ShardedDataset(train_config["context_size"],
train_config["data_path"])
if last_cp:
train_dataset.skip = opt.curr_step * bs * gas
train_dataset.skip = opt.curr_step
train_loader = data.DataLoader(train_dataset,
batch_size=bs * gas,
shuffle=True,
num_workers=0)
train_loader = torch_data.DataLoader(train_dataset,
batch_size=bs * gas,
shuffle=True,
num_workers=0)
wandb.init(project="hypernetwork-tests",
name=train_config["run_name"],
config={**train_config, **model.config})
@@ -323,7 +367,6 @@ else:
curr_step = 0
t = tqdm(train_loader, initial=curr_step)
sample_data = []
for input_ids, labels in t:
timex = time.perf_counter()
@@ -383,15 +426,10 @@ for input_ids, labels in t:
hypernetwork_saver(f"step_{curr_step}")
print(f"\nSaved model at step {curr_step}")
if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
for prompt in prompts:
sampled = sample(prompt, 500, 3, hypernetwork=hypernetwork,
step=step)
print(f"PROMPT:\n{prompt}")
report_console(sampled)
sample_data = sample_data + sampled
report_wandb(sample_data)
if curr_step % train_config["eval_every"] == 0:
eval_fn(curr_step)
curr_step += 1
eval_fn(curr_step)
hypernetwork_saver("final")