Commit 86e815ab authored by novelailab

upstream

parent 1ab8bac6
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base
def shift_tokens(x, amt, eps = 1e-5):
    n, device = x.shape[1], x.device
    cumsum = x.cumsum(dim = 1)
    *x, x_pass = x.chunk(amt + 1, dim = -1)
    *x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
    amts = 2 ** torch.arange(amt)
    amts = amts.tolist()
    shifts = []
    denom = torch.arange(n, device = device)
    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
        shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
        shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
        normed_shifted_x = shifted_chunk / (shifted_denom + eps)
        shifts.append(normed_shifted_x)
    return torch.cat((*shifts, x_pass), dim = -1)

def shift(x, amt, dim = -1):
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
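# Illustrative sketch (not part of the original commit): how the cumsum-based
# token shift above is expected to behave. The shapes and the amt value are
# assumptions chosen only to show that sequence length and feature width are
# preserved: the first feature chunk becomes a causal running mean, the rest
# is passed through untouched.
def _demo_shift_tokens():
    x = torch.randn(1, 8, 6)       # (batch, seq, features)
    out = shift_tokens(x, amt=1)   # first half of features gets a shifted running mean
    assert out.shape == x.shape    # token shift never changes the tensor shape
    return out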
def _split_heads(tensor, num_heads, attn_head_size, rotary):
    """
    Splits hidden_size dim into attn_head_size and num_heads
    """
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(*new_shape)
    if rotary:
        return tensor
    if len(tensor.shape) == 5:
        return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
    elif len(tensor.shape) == 4:
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    if len(tensor.shape) == 5:
        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
    elif len(tensor.shape) == 4:
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)
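# Illustrative sketch (added for this writeup, not in the original file): the two
# helpers above are inverses of each other on the non-rotary path. The toy sizes
# below are assumptions for the example only.
def _demo_split_merge_roundtrip():
    batch, seq, n_head, head_dim = 2, 5, 4, 8
    hidden = torch.randn(batch, seq, n_head * head_dim)
    heads = _split_heads(hidden, n_head, head_dim, rotary=False)  # (batch, head, seq, head_dim)
    merged = _merge_heads(heads, n_head, head_dim)                # (batch, seq, hidden)
    assert torch.equal(merged, hidden)                            # pure reshape/permute, so exact
    return heads.shape, merged.shape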
def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output
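# Illustrative sketch (assumed toy sizes, not part of the commit): calling the
# functional attention above directly with a lower-triangular causal mask, the
# way SelfAttention.forward does further down.
def _demo_attn():
    b, heads, seq, d = 2, 4, 8, 16
    q, k, v = (torch.randn(b, heads, seq, d) for _ in range(3))
    causal = torch.tril(torch.ones(seq, seq, dtype=torch.bool)).view(1, 1, seq, seq)
    out = _attn(q, k, v, causal, torch.tensor(-1e9),
                attention_mask=None, scale_attn=torch.tensor(float(d)).sqrt())
    assert out.shape == (b, heads, seq, d)
    return out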
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x):
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)
        query = _split_heads(query, self.n_head, self.head_dim, True)
        key = _split_heads(key, self.n_head, self.head_dim, True)
        value = _split_heads(value, self.n_head, self.head_dim, False)
        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
        )
        x = _merge_heads(x, self.n_head, self.head_dim)
        x = self.out_proj(x)
        return x
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = self.ff2(x)
        return x
class GPT2Layer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
        residual = x
        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out = self.attn(x)
        residual = residual + attn_out
        x = self.ln_postattn(x)
        ff_out = self.ff(x, act_ck)
        x = residual + ff_out
        return x
class GPT2Model(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPT2Layer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x
class GPT2BaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = GPT2Model

def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = GPT2BaseLM.load(config, path, state_dict)
    return model
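# Illustrative sketch (assumed toy hyperparameters, CPU/float32 for portability;
# not part of the original commit): constructing GPT2Model directly and running a
# forward pass over a batch of token ids. It relies on gelu_new being available
# from the basedformer.utils star import, as the default above already assumes.
def _demo_gpt2_forward():
    model = GPT2Model(hidden_dim=64, n_layer=2, n_head=4, vocab_dim=256,
                      eps=1e-5, device="cpu", dtype=torch.float32)
    tokens = torch.randint(0, 256, (1, 16))   # (batch, seq) of token ids
    logits = model(tokens)                    # (batch, seq, vocab_dim), returned as float32
    assert logits.shape == (1, 16, 256)
    return logits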
@@ -51,7 +51,7 @@ class BaseLM(nn.Module):
         for name, p in module.named_parameters():
             if ("ff2" in name or "out_proj" in name) and "weight" in name:
-                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.config.n_layer)))
+                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.config["n_layer"])))
     @classmethod
     def init(cls, config):
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base
def token_shift(x, window_size=1):
    size = x.size()[-1] // (window_size + 1)

    def shift(x, t, s):
        return nn.functional.pad(x[:, :-t, (t - 1) * s:t * s], (0, 0, t, 0))[:, :x.size()[-2], :]

    time_shifts = [shift(x, t, size) for t in range(1, window_size + 1)]
    current_x = [x[:, :, len(time_shifts) * size:]]
    x = torch.cat(time_shifts + current_x, dim=-1)
    return x

def token_shift_no_mix(x, window_size=1):
    size = x.size()[-1] // (window_size + 1)

    def shift(x, t, s):
        return nn.functional.pad(x[:, :-t, :s], (0, 0, t, 0))[:, :x.size()[-2], :]

    time_shifts = [shift(x, t, size) for t in range(window_size, 0, -1)]
    current_x = [x[:, :, len(time_shifts) * size:]]
    x = torch.cat(time_shifts + current_x, dim=-1)
    return x
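# Illustrative sketch (toy values, added for this writeup): with window_size=1 the
# no-mix variant above splits the feature dim in half, replaces the first half with
# the same features taken from the previous time step (zero-padded at step 0), and
# keeps the second half untouched.
def _demo_token_shift_no_mix():
    x = torch.arange(2 * 4 * 6, dtype=torch.float32).view(2, 4, 6)   # (batch, seq, features)
    shifted = token_shift_no_mix(x)
    assert shifted.shape == x.shape
    assert torch.equal(shifted[:, 0, :3], torch.zeros(2, 3))         # first step sees zero padding
    assert torch.equal(shifted[:, 1:, :3], x[:, :-1, :3])            # rest sees the previous token
    assert torch.equal(shifted[:, :, 3:], x[:, :, 3:])               # second half is passed through
    return shifted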
def _split_heads(tensor, num_heads, attn_head_size, rotary):
    """
    Splits hidden_size dim into attn_head_size and num_heads
    """
    new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
    tensor = tensor.view(*new_shape)
    if rotary:
        return tensor
    if len(tensor.shape) == 5:
        return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
    elif len(tensor.shape) == 4:
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    if len(tensor.shape) == 5:
        tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
    elif len(tensor.shape) == 4:
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
    else:
        raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)
def _attn(query, key, value, causal_mask, masked_bias,
          attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)
    attn_output = torch.matmul(attn_weights, value).to(value.dtype)
    return attn_output
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
        nn.Module.__init__(self)
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x):
        query = self.q_proj(x)
        key = self.k_proj(x)
        value = self.v_proj(x)
        query = _split_heads(query, self.n_head, self.head_dim, True)
        key = _split_heads(key, self.n_head, self.head_dim, True)
        value = _split_heads(value, self.n_head, self.head_dim, False)
        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
        )
        x = _merge_heads(x, self.n_head, self.head_dim)
        x = self.out_proj(x)
        return x
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        nn.Module.__init__(self)
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x, act_ck=False):
        x = self.ff1(x)
        if act_ck:
            x = ck(self.activation, x)
        else:
            x = self.activation(x)
        x = token_shift_no_mix(x)
        x = self.ff2(x)
        return x
class GPTNELayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
        residual = x
        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
        else:
            x = self.ln_preattn(x)
            attn_out = self.attn(x)
        if hypernetwork:
            if diff_hypernets:
                if interleaving_layers and layer_id % every_n == 0:
                    if self.tick:
                        hyper_out = hypernetwork[0](x)
                        self.tick = False
                    else:
                        hyper_out = hypernetwork[1](x)
                        self.tick = True
                elif layer_id % every_n == 0:
                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
            else:
                if layer_id % every_n == 0:
                    hyper_out = hypernetwork(x)
        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.
        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out
        return x
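# Illustrative sketch (added commentary, not in the original commit) for the
# "order of addition matters" note above: float16 addition is not associative,
# so summing the same three tensors in a different order can give bitwise
# different results, which is enough to break exact-match comparisons against
# a reference implementation.
def _demo_fp16_addition_order():
    torch.manual_seed(0)
    a, b, c = (torch.randn(4096).half() for _ in range(3))
    lhs = a + b + c   # (a + b) + c
    rhs = c + a + b   # (c + a) + b
    return bool((lhs != rhs).any())   # frequently True in half precision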
class GPTNEModel(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNELayer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x
class GPTNEBaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = GPTNEModel

def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = GPTNEBaseLM.load(config, path, state_dict)
    return model
@@ -4,39 +4,37 @@ import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from lm_train import utils
from torch.utils import data
from basedformer import lm_base, optimizer
from basedformer import optimizer, utils, gptj, noemblm, gpt2
import yaml
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
import os
model_config = {
"n_layer": 12,
"n_head": 12,
"hidden_dim": 768,
"n_layer": 3,
"n_head": 16,
"hidden_dim": 1024,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTLayer
}
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/fixedj",
"run_name": "gpt-j-owt2-6b-preattn",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshiftnomix-3L-16A-1024H",
"do_save": True,
"run_name": "gptj-nopos-tokenshiftnomix-owt2-72M-3L-16A-1024H-fp16AMP-512ctx-16bs-1e-4lrinit",
"lr": 1e-4,
"end_lr": 1e-4,
"warmup_steps": 50,
"bs": 12,
"gas": 10,
"warmup_steps": 100,
"bs": 16,
"gas": 1,
"seed": 69,
"save_every": 500,
"amp": True,
@@ -48,17 +46,37 @@ gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
model = GPTModel.gpt2_init(model_config).cuda().float()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = noemblm.GPTNEBaseLM.init(model_config).cuda().float()
model.train()
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
print(last_cp)
if last_cp:
    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
    model.load(model_config, last_cp / "lm", strict=True)
    opt = optimizer.BasedOptimizer.load(model.parameters(), last_cp / "opt")
else:
    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = utils.FbDataset(2049, train_config["data_path"])
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
if last_cp:
    train_dataset.skip = opt.curr_step * bs * gas
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})
if last_cp:
    curr_step = opt.curr_step
else:
    curr_step = 0
t = tqdm(train_loader)
curr_step = 0
t = tqdm(train_loader, initial=curr_step)
scaler = torch.cuda.amp.GradScaler()
@@ -69,9 +87,10 @@ for input_ids, labels in t:
loss = 0
for x in range(train_config["gas"]):
with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=None, act_ck=False)
logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
#print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
gas_labels = labels[x*bs:(x+1)*bs, :512].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
@@ -88,7 +107,6 @@ for input_ids, labels in t:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
if train_config["loss_scale"]:
opt.step(scaler=scaler)
else:
opt.step()
@@ -96,12 +114,27 @@ for input_ids, labels in t:
scaler.update()
opt.zero_grad()
sec_per_step = (time.perf_counter() - timex) / (bs*gas)
sec_per_step = (time.perf_counter() - timex)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = step_per_sec * 1024
tokens_per_sec = (step_per_sec * 1024) * bs * gas
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
curr_step += 1
if curr_step % train_config["save_every"] == 0:
model.save(train_config["save_path"] + f"/{curr_step}")
print(f"Saved model at step {curr_step}")
wandb.log(
{
"train/loss": loss,
"train/tokens_per_sec": tokens_per_sec,
"train/sec_per_step": sec_per_step,
"train/step_per_sec": step_per_sec,
"train/lr": opt.curr_lr,
"train/loss_scale": scaler.get_scale()
},
step=curr_step)
if train_config["do_save"]:
if curr_step % train_config["save_every"] == 0:
save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
save_folder.mkdir(parents=True, exist_ok=True)
model.save(save_folder / "lm")
opt.save(save_folder / "opt")
print(f"Saved model at step {curr_step}")
curr_step += 1
\ No newline at end of file
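# Illustrative sketch (added commentary, not part of the commit): the throughput
# numbers logged above, written out as plain arithmetic. The 1024 constant plays
# the role of tokens per sample in the updated script; bs and gas are the batch
# size and gradient-accumulation steps from train_config.
def _demo_throughput(elapsed_sec, bs=16, gas=1, tokens_per_sample=1024):
    sec_per_step = elapsed_sec                 # wall-clock time of one optimizer step
    step_per_sec = 1.0 / sec_per_step
    tokens_per_sec = step_per_sec * tokens_per_sample * bs * gas
    return sec_per_step, step_per_sec, tokens_per_sec

# e.g. a 0.5 s step with bs=16, gas=1 comes out to 2 steps/s and 32768 tokens/s.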