Commit 6c1a2d67 authored by Eren Doğan, committed by GitHub

Merge pull request #9 from NovelAI/os.changes

parents 91470a5a 9fc1cc21
 import numpy as np
 import torch
 import mmap
-import pickle
 import concurrent
 from torch.utils import data
-from simplejpeg import decode_jpeg
-import simplejpeg
 import pickle
 from pathlib import Path
-from PIL import Image
 from tqdm import tqdm
 from concurrent.futures import as_completed
 import requests
@@ -54,6 +50,9 @@ class ShardedDataset(data.Dataset):
 class ShardedImageDataset(data.Dataset):
     def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
                  outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
+        from simplejpeg import decode_jpeg
+        import simplejpeg
+        from PIL import Image
         self.skip = skip
         self.threads = threads
...
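This hunk moves the image-decoding dependencies (simplejpeg, PIL) from module scope into ShardedImageDataset.__init__, so importing the dataset module no longer requires them unless the image dataset is actually constructed. A minimal self-contained sketch of the pattern (the class name here is illustrative, not from the diff):

    # lazy_import_sketch.py -- hypothetical illustration of the lazy-import pattern
    class ImageLoader:
        def __init__(self):
            # Heavy/optional deps are imported on first construction, so the
            # enclosing module can be imported even when simplejpeg is absent.
            from simplejpeg import decode_jpeg
            self._decode = decode_jpeg

        def load(self, raw: bytes):
            # decode_jpeg takes encoded JPEG bytes and returns a numpy array
            return self._decode(raw)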
@@ -87,11 +87,13 @@ def load_from_path(config_folder=None, strict=False):
     model = _load_dict_model(model_class, model_config, model_path, strict=strict)
     return model

-def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False):
+def _load_dict_model(model_class, config, path=None, state_dict=None,
+                     strict=False, device="cuda"):
     # I am kinda sad that we will not have a load function in lm object itself.
     # might be better to add load functions -- actually nope.
     if path:
-        state_dict = utils.SplitCheckpoint(path, device="cuda")
+        state_dict = utils.SplitCheckpoint(path, device=device)
+        state_dict.device = device
     model = utils.no_init(lambda: model_class(config))
     model.load_state_dict(state_dict, strict=strict)
...
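_load_dict_model now threads a device argument through to utils.SplitCheckpoint instead of hard-coding "cuda", so a checkpoint can be materialized on CPU. A hedged sketch of what that enables; the hunk does not show whether the public load_from_path forwards the parameter, so this calls the private helper directly:

    # hypothetical usage, assuming model_class/model_config/model_path are in
    # scope as in load_from_path above
    model = _load_dict_model(model_class, model_config, model_path,
                             strict=False, device="cpu")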
-from re import A
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from pathlib import Path
 from torch.utils import data
+import math
-import sys
+from tqdm import tqdm
+import time
 import wandb
-import numpy as np
 from torch.utils.checkpoint import checkpoint as ck
-from math import log2, ceil
 from basedformer import optimizer, lm_utils, dataset
 from basedformer.utils import *
-import glob
 from transformers import AutoTokenizer
 from basedformer import sampling
-from icecream import ic
 from termcolor import colored
+import argparse
+
+gpu = "cuda"
+amp = torch.cuda.amp
+if gpu != "cuda":
+    amp = torch.amp
+
+scaler = torch.cuda.amp.GradScaler()
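One caveat with the fallback above: torch.cuda.amp.autocast infers its device, but the generic torch.amp.autocast in recent PyTorch requires an explicit device_type argument, so calling amp.autocast(enabled=..., dtype=...) on the non-CUDA path would raise a TypeError. A hedged sketch of a helper that papers over the difference:

    import torch

    def autocast_ctx(device_type: str, enabled: bool, dtype=torch.float16):
        # torch.amp.autocast needs device_type spelled out ("cuda", "cpu", ...),
        # unlike torch.cuda.amp.autocast, which assumes "cuda"
        return torch.amp.autocast(device_type=device_type, enabled=enabled,
                                  dtype=dtype)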
+prompts = ["<|endoftext|>",
+           "The year was",
+           "I grabbed my",
+           "She lifted the",
+           "He was known as the",
+           "The tavern was full again, so I ended up sharing a table with three very different creatures: a",
+           "I had been hiking in the wilderness when suddenly a",
+           "She spread her",
+           "The mercurial and beautiful woman laughed",
+           "[ Author:",
+           "[ Tags:",
+           "***"]
 def _init_weights(module):
     if isinstance(module, nn.Linear):
@@ -33,49 +43,21 @@ def _init_weights(module):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
-def shift_tokens(x, amt, eps = 1e-5):
-    n, device = x.shape[1], x.device
-    cumsum = x.cumsum(dim = 1)
-    *x, x_pass = x.chunk(amt + 1, dim = -1)
-    *x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
-    amts = 2 ** torch.arange(amt)
-    amts = amts.tolist()
-    shifts = []
-    denom = torch.arange(n, device = device)
-    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
-        shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
-        shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
-        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
-        normed_shifted_x = shifted_chunk / (shifted_denom + eps)
-        shifts.append(normed_shifted_x)
-    return torch.cat((*shifts, x_pass), dim = -1)
-
-def discounted_cumsum(t, gamma):
-    try:
-        from torch_discounted_cumsum import discounted_cumsum_left
-    except ImportError:
-        print('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`')
-    b, n, d = t.shape
-    t = rearrange(t, 'b n d -> (b d) n')
-    t = discounted_cumsum_left(t, gamma)
-    t = rearrange(t, '(b d) n -> b n d', b = b)
-    return t
-
-def shift(x, amt, dim = -1):
-    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
+def shift(x, amt, dim=-1):
+    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value=0.)
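For reference, shift pads one side of the chosen dimension with zeros and trims the same number of elements off the other, i.e. it slides the tensor along dim by amt positions without changing its shape. A quick self-contained check:

    import torch
    import torch.nn.functional as F

    def shift(x, amt, dim=-1):
        return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value=0.)

    x = torch.arange(1., 6.).unsqueeze(0)   # tensor([[1., 2., 3., 4., 5.]])
    print(shift(x, 2))                      # tensor([[0., 0., 1., 2., 3.]])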
 class HyperNetworkGRU(nn.Module):
     def __init__(self, config):
         super().__init__()
         embed_dim = config["hidden_dim"]
-        self.linear1 = nn.Linear(embed_dim, embed_dim//8)
-        self.gru = nn.GRU(embed_dim//8, embed_dim // 8, num_layers=1, bidirectional=False, batch_first=True)
+        self.linear1 = nn.Linear(embed_dim, embed_dim // 8)
+        self.gru = nn.GRU(embed_dim // 8,
+                          embed_dim // 8,
+                          num_layers=1,
+                          bidirectional=False,
+                          batch_first=True)
         self.linear2 = nn.Linear(embed_dim // 8, embed_dim)
         self.ln_1 = nn.LayerNorm(embed_dim // 8, eps=1e-5)
         self.activation = gelu_new
@@ -84,42 +66,44 @@ class HyperNetworkGRU(nn.Module):
             _init_weights(module)
         for param in self.linear2.parameters():
-            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
+            param.data.normal_(mean=0.0,
+                               std=(0.02 / math.sqrt(2 * config["n_layer"])))
         for param in self.gru.parameters():
-            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
+            param.data.normal_(mean=0.0,
+                               std=(0.02 / math.sqrt(2 * config["n_layer"])))
+
+        self.linear_gru = nn.Sequential(
+            self.linear1,
+            self.gru)
+        self.layernorm_linear = nn.Sequential(
+            self.ln_1,
+            self.linear2)

     def forward(self, x):
         x = x.float()
-        x = self.linear1(x)
-        x = self.gru(x)[0]
-        x = self.ln_1(x)
-        x = self.linear2(x)
-        x = ck(self.activation, x)
+        x = self.linear_gru.forward(x)[0]
+        x = ck(self.activation,
+               self.layernorm_linear.forward(x))
         return x.bfloat16()
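The refactor groups the layers into nn.Sequential blocks but keeps the same data flow: project the hidden state down 8x, run it through a GRU, layer-norm, project back up, then apply a checkpointed activation. A self-contained shape sketch with stand-in dimensions, using plain GELU in place of basedformer's gelu_new:

    import torch
    import torch.nn as nn

    hidden = 64                              # stand-in for config["hidden_dim"]
    linear1 = nn.Linear(hidden, hidden // 8)
    gru = nn.GRU(hidden // 8, hidden // 8, num_layers=1, batch_first=True)
    ln = nn.LayerNorm(hidden // 8, eps=1e-5)
    linear2 = nn.Linear(hidden // 8, hidden)

    x = torch.randn(2, 16, hidden)           # (batch, seq, hidden)
    h, _ = gru(linear1(x))                   # GRU returns (output, state); shape (2, 16, 8)
    out = torch.nn.functional.gelu(linear2(ln(h)))
    print(out.shape)                         # torch.Size([2, 16, 64])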
 class HyperNetwork(nn.Module):
     def __init__(self, config):
         super().__init__()
         embed_dim = config["hidden_dim"]
-        self.linear = nn.Linear(embed_dim, embed_dim//4, bias=True)
-        self.linear2 = nn.Linear(embed_dim//4, embed_dim, bias=True)
+        self.linear = nn.Linear(embed_dim, embed_dim // 4, bias=True)
+        self.linear2 = nn.Linear(embed_dim // 4, embed_dim, bias=True)
         self.activation = torch.nn.functional.gelu
-        self.num_shifts = ceil(log2(2048)) - 1
-        #self.linear.weight.data.normal_(mean=0.0, std=0.02)
         for module in self.modules():
             _init_weights(module)
         for param in self.linear2.parameters():
-            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
-        #state = self.state_dict()
-        #for k in state:
-        #    state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
-        #self.load_state_dict(state)
+            param.data.normal_(mean=0.0,
+                               std=(0.02 / math.sqrt(2 * config["n_layer"])))

     def forward(self, x):
         x = x.float()
-        #x = shift_tokens(x, self.num_shifts)
         x = self.linear(x)
         x = ck(self.activation, x)
         x = self.linear2(x)
@@ -132,33 +116,29 @@ class HyperNetworkSingle(nn.Module):
         embed_dim = config["hidden_dim"]
         self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
         self.activation = gelu_new
-        #self.linear.weight.data.normal_(mean=0.0, std=0.02)
+        # self.linear.weight.data.normal_(mean=0.0, std=0.02)
         for module in self.modules():
             _init_weights(module)
         for param in self.linear.parameters():
-            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
-        #state = self.state_dict()
-        #for k in state:
-        #    state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
-        #self.load_state_dict(state)
+            param.data.normal_(mean=0.0,
+                               std=(0.02 / math.sqrt(2 * config["n_layer"])))

     def forward(self, x):
         x = x.float()
-        #x = shift_tokens(x, self.num_shifts)
         x = self.linear(x)
         x = x.mul(torch.sigmoid(x))
         return x.bfloat16()
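In HyperNetworkSingle, x.mul(torch.sigmoid(x)) is the SiLU/Swish activation. On current PyTorch the same function is available as a fused builtin, which a quick check confirms:

    import torch
    import torch.nn.functional as F

    x = torch.randn(8)
    assert torch.allclose(x * torch.sigmoid(x), F.silu(x), atol=1e-6)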
 tokenizer = AutoTokenizer.from_pretrained('gpt2')

 @torch.no_grad()
-def sample(prompt, n_tokens, bsz, hypernetwork=None):
+def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
     torch.seed()
     tokens = tokenizer.encode(prompt)
-    #print("Prompt:")
-    #for x in range(len(tokens)):
-    #    print(tokenizer.decode([tokens[x]]), end=" | ")
-    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
+    tokens = torch.LongTensor(tokens).unsqueeze(0).to(gpu)
     tokens = [tokens] * bsz
     tokens = torch.cat(tokens, dim=0)
@@ -172,41 +152,117 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
         "temp": 0.8,
     }
     ops_list = [ops] * bsz
-    tokens_generated = sampling.generate(model.forward, tokens, n_tokens, ops_list=ops_list, hypernetwork=hypernetwork, non_deterministic=True)
-    vanilla_tokens_generated = sampling.generate(model.forward, tokens, n_tokens, ops_list=ops_list, hypernetwork=None)
+    tokens_generated = sampling.generate(model.forward,
+                                         tokens,
+                                         n_tokens,
+                                         ops_list=ops_list,
+                                         hypernetwork=hypernetwork,
+                                         non_deterministic=True)
+    vanilla_tokens_generated = sampling.generate(model.forward,
+                                                 tokens,
+                                                 n_tokens,
+                                                 ops_list=ops_list,
+                                                 hypernetwork=None)
     tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
-    vanilla_tokens_generated = tokenizer.batch_decode(vanilla_tokens_generated.cpu().numpy())
-    ### send to wandb
-    columns = ["Prompt", "Generated Text", "Vanilla Model"]
     data = []
+    vanilla_tokens_generated = tokenizer.batch_decode(
+        vanilla_tokens_generated.cpu().numpy())
+    data = []
     for x in range(len(tokens_generated)):
-        data.append([prompt, str(tokens_generated[x]), str(vanilla_tokens_generated[x])])
-    for gen in tokens_generated:
-        print(colored("==========================================================", "red"))
-        print(colored(gen, "green"))
-        print(colored("==========================================================", "red"))
-    wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
+        data.append([step,
+                     prompt,
+                     str(tokens_generated[x]),
+                     str(vanilla_tokens_generated[x])])
+    return data

+def report_wandb(data):
+    columns = ["Step", "Prompt", "Generated Text", "Vanilla Model"]
+    wandb.log({"Generations": wandb.Table(data=data, columns=columns)})

+def report_console(data):
+    # print the generated-text column (index 2) of every sampled row
+    for row in data:
+        print(colored("======================================================",
+                      "red"))
+        print(colored(row[2], "green"))
+        print(colored("======================================================",
+                      "red"))

+def make_hypernet_saver(train_config, hypernetwork):
+    def hypernet_saver(id: str):
+        save_folder = Path(train_config["save_path"]) / id
+        save_folder.mkdir(parents=True, exist_ok=True)
+        torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
+        opt.save(save_folder / "opt")
+    return hypernet_saver
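make_hypernet_saver returns a closure used both for the periodic checkpoints and the final save at the end of the script. Note that hypernet_saver reads the optimizer opt from module scope rather than taking it as an argument, so opt must exist before the saver is first called (it is created later in the script, before any save happens). Hypothetical usage:

    hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)
    hypernetwork_saver("step_300")   # writes <save_path>/step_300/hyper.pt and the opt state
    hypernetwork_saver("final")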
+parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
+parser.add_argument('--run_name', type=str, help='the run name to use',
+                    required=True)
+parser.add_argument('--model', type=str, help='the model to train against',
+                    required=True)
+parser.add_argument('--dataset', type=str, help='pre-tokenized dataset to use',
+                    required=True)
+parser.add_argument("--output", type=str, help='output path',
+                    default='')
+parser.add_argument('--optimizer', type=str, help='the optimizer to use',
+                    default='adamw')
+parser.add_argument('--lr', type=float, help='learning rate', default=2e-4)
+parser.add_argument('--end_lr', type=float, help='end learning rate',
+                    default=2e-4)
+parser.add_argument('--warmup', type=int, help='warmup steps', default=10)
+parser.add_argument('--bs', type=int, help='batch size', default=4)
+parser.add_argument('--gas', type=int, help='gradient accumulation steps',
+                    default=1)
+parser.add_argument('--seed', type=int, help="Random seed value",
+                    default=42)
+parser.add_argument("--save_steps", type=int,
+                    help='# of steps between checkpoint saves',
+                    default=300)
+parser.add_argument("--amp", type=bool, help='enable amp', default=False)
+parser.add_argument('--loss_scale', type=bool, help='whether to scale loss',
+                    default=False)
+parser.add_argument("--eval_every", type=int,
+                    help='evaluate hypernetwork every x steps',
+                    default=100)
+parser.add_argument('--output_path', type=str, help="Root path of all output",
+                    default="./")
+parser.add_argument('--no_resume', type=bool, default=False,
+                    help="Do not resume from last checkpoint")
+parser.add_argument("--context_size", type=int, help="Dataset context sizes",
+                    default=2048)
+parser.add_argument("--project_id", type=str, help="Project ID for reporting",
+                    default="hypernetwork-training")
+parser.add_argument("--logs", type=str, help="log directory location",
+                    default="./logs")
+parser.add_argument("--masked", type=bool, help="masked softmax fusion")
+parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False)
+
+args = parser.parse_args()
+
+if args.output == '':
+    args.output = f'./{args.run_name}'
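Two notes on the new parser. First, a hypothetical invocation (the script name is a placeholder, not from the diff):

    python train_hypernetwork.py --run_name my-hypernet --model /path/to/model \
        --dataset /path/to/tokens.map --lr 2e-4 --bs 4 --save_steps 300

Second, argparse's type=bool is a known footgun: bool("False") is True, so any value passed to --amp, --loss_scale, --no_resume, or --masked enables the flag. A hedged sketch of the usual fix:

    # store_true flags default to False and take no value on the command line
    parser.add_argument("--amp", action="store_true", help="enable amp")
    parser.add_argument("--no_resume", action="store_true",
                        help="do not resume from last checkpoint")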
 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "/home/xuser/nvme1/dataset/enwik9-gpt2-2049.map",
-    "save_path": "/home/xuser/models/enwik9-sigurdv4-hypernet2",
-    "lm_path": "/home/xuser/nvme1/pretrained/sigurdv4",
-    "optimizer": "adamw",
-    "masked_softmax_fusion": False,
-    "do_save": True,
-    "run_name": "gptj-6b-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
-    "lr": 2e-4,
-    "end_lr": 2e-4,
-    "warmup_steps": 50,
-    "bs": 4,
-    "gas": 1,
-    "seed": 69,
-    "save_every": 300,
-    "amp": False,
-    "loss_scale": False,
-    "eval_every": 100,
+    "data_path": args.dataset,
+    "save_path": args.output,
+    "lm_path": args.model,
+    "optimizer": args.optimizer,
+    "masked_softmax_fusion": args.masked,
+    "do_save": args.save_steps != 0,
+    "run_name": args.run_name,
+    "lr": args.lr,
+    "end_lr": args.end_lr,
+    "warmup_steps": args.warmup,
+    "bs": args.bs,
+    "gas": args.gas,
+    "seed": args.seed,
+    "save_every": args.save_steps,
+    "amp": args.amp,
+    "loss_scale": args.loss_scale,
+    "eval_every": args.eval_every,
 }
 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
@@ -214,8 +270,7 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
-#model = GPTModel.gpt2_init(model_config).cuda().float()
-model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/sigurdv4").cuda().bfloat16()
+model = lm_utils.load_from_path(train_config["lm_path"]).to(gpu).bfloat16()
 for param in model.parameters():
     param.requires_grad = False
@@ -223,32 +278,42 @@ for name, p in model.named_parameters():
     if ("ln" in name or "vocab_embed" in name):
         p.requires_grad = True
-hypernetwork = HyperNetworkSingle(model.config).cuda().float()
-#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
-#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(2)])
+hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
 for param in hypernetwork.parameters():
     param.requires_grad = True
+hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)

-cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
-last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
+cp_list = sorted(os.listdir(train_config["save_path"]),
+                 key=lambda x: int(x.split("_")[-1]))
+last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(
+    cp_list) > 0 else None
 print(last_cp)

-if last_cp:
+if last_cp and not args.no_resume:
     print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
     hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
-    opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(), last_cp / "opt")
+    opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
+                                        last_cp / "opt")
 else:
-    opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, train_config["optimizer"])
+    opt = optimizer.BasedOptimizer(hypernetwork.parameters(),
+                                   train_config,
+                                   train_config["optimizer"])

-# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
+# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden
+# states from the get_logits function.
 print(opt.curr_step)

 train_dataset = dataset.ShardedDataset(2049, train_config["data_path"])
 if last_cp:
     train_dataset.skip = opt.curr_step * bs * gas

-train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
-wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model.config})
+train_loader = data.DataLoader(train_dataset,
+                               batch_size=bs * gas,
+                               shuffle=False,
+                               num_workers=0)
+wandb.init(project="hypernetwork-tests",
+           name=train_config["run_name"],
+           config={**train_config, **model.config})
 if last_cp:
     curr_step = opt.curr_step
@@ -256,21 +321,22 @@ else:
     curr_step = 0

 t = tqdm(train_loader, initial=curr_step)
+sample_data = []
-scaler = torch.cuda.amp.GradScaler()
-#sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)

 for input_ids, labels in t:
     timex = time.perf_counter()
-    input_ids = input_ids.cuda()
-    labels = labels.cuda()
+    input_ids = input_ids.to(gpu)
+    labels = labels.to(gpu)
     loss = 0
     for x in range(train_config["gas"]):
-        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits, _ = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
-            #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
+        with amp.autocast(enabled=train_config["amp"],
+                          dtype=torch.float16):
+            logits, _ = model(input_ids[x * bs:(x + 1) * bs, :].to(gpu),
+                              hypernetwork=hypernetwork,
+                              act_ck=True)
+            # print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             logits = logits.view(-1, logits.shape[-1])
-            gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
+            gas_labels = labels[x * bs:(x + 1) * bs, :].contiguous()
             gas_labels = gas_labels.view(-1)
             gas_loss = F.cross_entropy(logits, gas_labels)
@@ -297,26 +363,34 @@ for input_ids, labels in t:
     sec_per_step = (time.perf_counter() - timex)
     step_per_sec = (1. / sec_per_step)
     tokens_per_sec = (step_per_sec * 2048) * bs * gas
-    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
+    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, "
+                      + f"{tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
     wandb.log(
         {
             "train/loss": loss,
             "train/tokens_per_sec": tokens_per_sec,
             "train/sec_per_step": sec_per_step,
             "train/step_per_sec": step_per_sec,
            "train/lr": opt.curr_lr,
            "train/loss_scale": scaler.get_scale()
         },
         step=curr_step)
-    if train_config["do_save"] and curr_step % train_config["save_every"] == 0 and curr_step != 0:
-        save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
-        save_folder.mkdir(parents=True, exist_ok=True)
-        torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
-        opt.save(save_folder / "opt")
-        print(f"Saved model at step {curr_step}")
+    if train_config["do_save"] and \
+            curr_step % train_config["save_every"] == 0 and \
+            curr_step != 0:
+        hypernetwork_saver(f"step_{curr_step}")
+        print(f"\nSaved model at step {curr_step}")

     if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
-        sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
+        for prompt in prompts:
+            sampled = sample(prompt, 500, 3, hypernetwork=hypernetwork,
+                             step=curr_step)
+            print(f"PROMPT:\n{prompt}")
+            report_console(sampled)
+            sample_data = sample_data + sampled
+        report_wandb(sample_data)

     curr_step += 1
\ No newline at end of file
+
+hypernetwork_saver("final")