Commit 347ef912 authored by novelailab

pickle optimizer save/load

parent b07251f0
from torch import optim
import numpy as np
import torch
from dotmap import DotMap
import pickle
#Based Optimizer
def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
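The body of lr_schedule sits outside this diff's context; a minimal sketch of a warmup-then-cosine-anneal schedule matching this signature (the exact curve is an assumption, not taken from this commit) could look like:

def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
    # Assumed shape: linear warmup to lr, then cosine anneal towards end_lr.
    if step < warmup_steps:
        return lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / anneal_steps)
    return end_lr + 0.5 * (lr - end_lr) * (1.0 + np.cos(np.pi * progress))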
@@ -61,4 +64,16 @@ class BasedOptimizer:
print(f"weight_decay: {str(self.weight_decay)}")
print(f"step: {str(self.curr_step)}")
if self.curr_step != 0:
print(f"curr_lr: {str(self.get_current_lr())}")
\ No newline at end of file
print(f"curr_lr: {str(self.get_current_lr())}")
    def save(self, path):
        # Pickle the whole wrapper; the inner torch optimizer and its state
        # are serialized together with the schedule/step bookkeeping.
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path):
        # Unpickle the wrapper written by save(); the optimizer state comes
        # back as part of the object, so no separate state_dict load is needed.
        with open(path, 'rb') as f:
            based_optimizer = pickle.load(f)
        return based_optimizer
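A minimal usage sketch for the new save/load round trip (the file name here is illustrative, not part of this commit):

# Illustrative checkpoint round trip for the optimizer wrapper.
opt = BasedOptimizer(model.parameters(), train_config, "adamw")
opt.save("optimizer.pkl")                    # one pickle file holds the whole wrapper
opt = BasedOptimizer.load("optimizer.pkl")   # curr_step, lr schedule state, etc. come back with it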
@@ -33,19 +33,21 @@ train_config = {
    "bs": 16,
    "gas": 16,
    "seed": 69,
    "save_every": 50,
}
bs = train_config["bs"]
gas = train_config["gas"]
model = GPTModel.neox_init(model_config).cuda().bfloat16()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, add Data Parallel, and output hidden states from the get_logits function.
train_dataset = utils.FbDataset(2049, train_config["data_path"])
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})
t = tqdm(train_loader)
curr_step = 0
for input_ids, labels in t:
    timex = time.perf_counter()
    input_ids = input_ids.cuda()
@@ -59,7 +61,7 @@ for input_ids, labels in t:
        gas_loss = F.cross_entropy(logits, gas_labels)
        gas_loss.backward()
        loss += gas_loss.item()
    loss = loss / gas
    opt.step()
    opt.zero_grad()
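The hunk above starts mid-way through the gradient-accumulation loop; a sketch of the micro-batch split it implies, where variable names beyond those shown are assumptions, is:

# Assumed micro-batching over the bs*gas batch pulled from the DataLoader.
loss = 0
for x in range(gas):
    gas_input_ids = input_ids[x * bs:(x + 1) * bs]          # one micro-batch of bs sequences
    gas_labels = labels[x * bs:(x + 1) * bs].cuda().reshape(-1)
    logits = model(gas_input_ids)
    logits = logits.reshape(-1, logits.shape[-1])            # flatten to [bs*seq, vocab] for cross entropy
    gas_loss = F.cross_entropy(logits, gas_labels)
    gas_loss.backward()
    loss += gas_loss.item()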
@@ -67,4 +69,8 @@ for input_ids, labels in t:
    step_per_sec = (1. / sec_per_step)
    tokens_per_sec = step_per_sec * 2048
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
    curr_step += 1
    if curr_step % train_config["save_every"] == 0:
        model.save(train_config["save_path"])
        print(f"Saved model at step {curr_step}")