Commit ea32d948 authored by Wes Brown

hypertrain reporting fixes and improvements

parent 94c0ad6f

 import torch.nn as nn
 import torch.nn.functional as F
-from torch.utils import data
 import wandb
+from torch.utils import data as torch_data
 from torch.utils.checkpoint import checkpoint as ck
 from basedformer import optimizer, lm_utils, dataset
 from basedformer.utils import *
 from transformers import AutoTokenizer
 from basedformer import sampling
 from termcolor import colored
+from typing import Callable, List
 import argparse

 gpu = "cuda"
@@ -24,9 +25,9 @@ prompts = ["<|endoftext|>",
 "The tavern was full again, so I ended up sharing a table with three very different creatures: a",
 "I had been hiking in the wilderness when suddenly a",
 "She spread her",
-"The mercurial and beautiful woman laughed",
-"[ Author:",
-"[ Tags:",
+"The mercurial and beautiful",
+"<|endoftext|>[ Author:",
+"<|endoftext|>[ Genre:",
 "***"]
@@ -110,10 +111,13 @@ class HyperNetwork(nn.Module):
         x = x.mul(torch.sigmoid(x))
         return x.bfloat16()

 class HyperNetworkSingle(nn.Module):
     def __init__(self, config):
         super().__init__()
         embed_dim = config["hidden_dim"]
+        n_layers = config["n_layer"]
         self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
         self.activation = gelu_new
         # self.linear.weight.data.normal_(mean=0.0, std=0.02)
@@ -122,7 +126,7 @@ class HyperNetworkSingle(nn.Module):
         for param in self.linear.parameters():
             param.data.normal_(mean=0.0,
-                               std=(0.02 / math.sqrt(2 * config["n_layer"])))
+                               std=(0.02 / math.sqrt(2 * n_layers)))

     def forward(self, x):
         x = x.float()
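A quick aside on the magic numbers above: 0.02 / sqrt(2 * n_layer) is the GPT-2-style depth-scaled initialization for residual-branch weights. A minimal sketch of the arithmetic (the layer count here is an invented example, not read from this repo's config):

    import math

    # Hypothetical example value for illustration only.
    n_layers = 28
    std = 0.02 / math.sqrt(2 * n_layers)
    print(f"{std:.5f}")  # ~0.00267: deeper stacks get proportionally smaller writes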
@@ -135,7 +139,8 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2')

 @torch.no_grad()
-def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
+def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0, run_name="",
+           generate_vanilla=False):
     torch.seed()
     tokens = tokenizer.encode(prompt)
     tokens = torch.LongTensor(tokens).unsqueeze(0).to(gpu)
@@ -158,41 +163,80 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
                                        ops_list=ops_list,
                                        hypernetwork=hypernetwork,
                                        non_deterministic=True)
-    vanilla_tokens_generated = sampling.generate(model.forward,
-                                                 tokens,
-                                                 n_tokens,
-                                                 ops_list=ops_list,
-                                                 hypernetwork=None)
-    tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
-    vanilla_tokens_generated = tokenizer.batch_decode(
-        vanilla_tokens_generated.cpu().numpy())
+    tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
+    vanilla_tokens_generated = None
+    if generate_vanilla:
+        vanilla_tokens_generated = sampling.generate(model.forward,
+                                                     tokens,
+                                                     n_tokens,
+                                                     ops_list=ops_list,
+                                                     hypernetwork=None)
+        vanilla_tokens_generated = tokenizer.batch_decode(
+            vanilla_tokens_generated.cpu().numpy())

     data = []
     for x in range(len(tokens_generated)):
-        data.append([step,
-                     prompt,
-                     str(tokens_generated[x]),
-                     str(vanilla_tokens_generated[x])])
+        entry = {"Run": run_name,
+                 "Step": step,
+                 "Prompt": prompt,
+                 "Generated Text": str(tokens_generated[x])}
+        if vanilla_tokens_generated:
+            entry["Vanilla Model"] = vanilla_tokens_generated[x]
+        data.append(entry)
     return data

 def report_wandb(data):
-    columns = ["Step", "Prompt", "Generated Text", "Vanilla Model"]
-    wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
+    columns = list(data[0].keys())
+    step = data[0]["Step"]
+    data_list = [x.values() for x in data]
+    wandb.log({"Generations": wandb.Table(data=data_list, columns=columns)},
+              step=step)
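To make the new reporting contract concrete, here is a minimal sketch of the rows sample() now emits and how report_wandb() consumes them; the literal values are invented for illustration:

    # Hypothetical rows; in the script they come from sample(...).
    rows = [{"Run": "example-run", "Step": 100, "Prompt": "***",
             "Generated Text": "some decoded continuation"}]
    # report_wandb(rows) would log a "Generations" wandb.Table whose columns are
    # the dict keys (Run/Step/Prompt/Generated Text), pinned to wandb step 100.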

-def report_console(data):
-    for gen in data[2]:
-        print(colored("======================================================",
-                      "red"))
-        print(colored(gen, "green"))
-        print(colored("======================================================",
-                      "red"))
+def print_colored_bars(color):
+    print(colored("======================================================",
+                  color))
+
+def report_console(data: List[dict]):
+    print_colored_bars("blue")
+    print(colored(data[0]['Prompt'], "white"))
+    print_colored_bars("blue")
+    for gen in data:
+        print_colored_bars("red")
+        print(colored(gen["Generated Text"], "green"))
+
+def make_eval_function(hypernetwork: HyperNetworkSingle, config: dict) -> \
+        Callable[[int], None]:
+    sample_data = {'rows': []}
+    gen_vanilla = config.get('sample_vanilla', False)
+    run_name = config.get('run_name', '')
+
+    def eval_function(curr_step: int) -> None:
+        print()
+        print_colored_bars('yellow')
+        print(f"Step: {curr_step}")
+        for prompt in prompts:
+            sampled = sample(prompt, 500, 3,
+                             run_name=run_name,
+                             hypernetwork=hypernetwork,
+                             step=curr_step,
+                             generate_vanilla=gen_vanilla)
+            report_console(sampled)
+            sample_data['rows'].extend(sampled)
+        print_colored_bars("red")
+        report_wandb(sample_data['rows'])
+    return eval_function

-def make_hypernet_saver(train_config, hypernetwork):
-    def hypernet_saver(id: str):
-        save_folder = Path(train_config["save_path"]) / id
+def make_hypernet_saver(hypernetwork: HyperNetworkSingle, config: dict) \
+        -> Callable[[str], None]:
+    def hypernet_saver(id: str) -> None:
+        save_folder = Path(config["save_path"]) / id
         save_folder.mkdir(parents=True, exist_ok=True)
         torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
         opt.save(save_folder / "opt")
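Note that hypernet_saver closes over the module-level opt, which is only constructed further down the script, so the saver returned here must not be invoked before the optimizer exists. A sketch of the safe ordering, mirroring the code below:

    hypernetwork_saver = make_hypernet_saver(hypernetwork, train_config)
    # `opt` must be created before the first save:
    # opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, ...)
    # hypernetwork_saver("step_0")  # now safe: writes hyper.pt and the opt state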
@@ -239,11 +283,12 @@ parser.add_argument("--project_id", type=str, help="Project ID for reporting",
 parser.add_argument("--logs", type=str, help="log directory location",
                     default="./logs")
 parser.add_argument("--masked", type=bool, help="masked softmax fusion")
-parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False)
+parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model")
+parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False,
+                    sample_vanilla=False)
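A hedged aside on the type=bool arguments above: argparse applies bool() to the raw string, and any non-empty string is truthy, so --masked False or --sample_vanilla False still parse as True; only the set_defaults values yield False. A common workaround, not part of this commit, is an explicit converter:

    # Illustrative alternative only; the commit keeps type=bool.
    def str2bool(s: str) -> bool:
        return s.strip().lower() in ("1", "true", "yes")

    # parser.add_argument("--sample_vanilla", type=str2bool, default=False,
    #                     help="sample vanilla model")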
 args = parser.parse_args()

 if args.output == '':
     args.output = f'./{args.run_name}'

 # we need 250 batch size to train the small GPT.
 train_config = {
     "data_path": args.dataset,
@@ -264,6 +309,7 @@ train_config = {
     "loss_scale": args.loss_scale,
     "eval_every": args.eval_every,
     "context_size": args.context_size,
+    "sample_vanilla": args.sample_vanilla,
 }
 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
@@ -276,26 +322,25 @@ for param in model.parameters():
     param.requires_grad = False

 for name, p in model.named_parameters():
-    if ("ln" in name or "vocab_embed" in name):
+    if "ln" in name or "vocab_embed" in name:
         p.requires_grad = True

 hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
 for param in hypernetwork.parameters():
     param.requires_grad = True

-hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)
+hypernetwork_saver = make_hypernet_saver(hypernetwork, train_config)
+eval_fn = make_eval_function(hypernetwork, train_config)

 cp_list = sorted(os.listdir(train_config["save_path"]),
                  key=lambda x: int(x.split("_")[-1]))
 last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(
     cp_list) > 0 else None
-print(last_cp)

 if last_cp and not args.no_resume:
     print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
     hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
     opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
                                         last_cp / "opt")
 else:
     opt = optimizer.BasedOptimizer(hypernetwork.parameters(),
                                    train_config,
@@ -303,13 +348,12 @@ else:
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden
 # states from the get_logits function.
-print(opt.curr_step)

 train_dataset = dataset.ShardedDataset(train_config["context_size"],
                                        train_config["data_path"])
 if last_cp:
-    train_dataset.skip = opt.curr_step * bs * gas
+    train_dataset.skip = opt.curr_step

-train_loader = data.DataLoader(train_dataset,
+train_loader = torch_data.DataLoader(train_dataset,
                                      batch_size=bs * gas,
                                      shuffle=True,
                                      num_workers=0)
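Two resume-related details above are worth flagging. First, the checkpoint listing sorts directory names by int(name.split("_")[-1]), which raises on any non-numeric entry such as the "final" checkpoint written at the end of this script. Second, train_dataset.skip is now opt.curr_step rather than opt.curr_step * bs * gas, which only lines up if skip counts loader batches instead of individual samples. A defensive sketch for the first point (the step_ prefix is assumed from the saver calls below, and this is not part of the commit):

    # Hypothetical hardening: ignore checkpoints without a numeric step suffix.
    cp_names = [x for x in os.listdir(train_config["save_path"])
                if x.startswith("step_")]
    cp_list = sorted(cp_names, key=lambda x: int(x.split("_")[-1]))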
@@ -323,7 +367,6 @@ else:
     curr_step = 0

 t = tqdm(train_loader, initial=curr_step)
-sample_data = []

 for input_ids, labels in t:
     timex = time.perf_counter()
@@ -383,15 +426,10 @@ for input_ids, labels in t:
             hypernetwork_saver(f"step_{curr_step}")
             print(f"\nSaved model at step {curr_step}")

-        if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
-            for prompt in prompts:
-                sampled = sample(prompt, 500, 3, hypernetwork=hypernetwork,
-                                 step=step)
-                print(f"PROMPT:\n{prompt}")
-                report_console(sampled)
-                sample_data = sample_data + sampled
-            report_wandb(sample_data)
+        if curr_step % train_config["eval_every"] == 0:
+            eval_fn(curr_step)
     curr_step += 1

+eval_fn(curr_step)
 hypernetwork_saver("final")