Commit 347ef912 authored by novelailab

pickle optimizer save/load

parent b07251f0
from torch import optim
import numpy as np
import torch
from dotmap import DotMap
import pickle
#Based Optimizer
def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
...@@ -61,4 +64,16 @@ class BasedOptimizer:
        print(f"weight_decay: {str(self.weight_decay)}")
        print(f"step: {str(self.curr_step)}")
        if self.curr_step != 0:
            print(f"curr_lr: {str(self.get_current_lr())}")
\ No newline at end of file
    def save(self, path):
        # Write the torch state dict to a sibling file so it and the pickle
        # below do not overwrite each other (the ".state" suffix is an assumption).
        torch.save(self.optimizer.state_dict(), path + ".state")
        with open(path, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, path):
        with open(path, 'rb') as f:
            based_optimizer = pickle.load(f)
        based_optimizer.optimizer.load_state_dict(torch.load(path + ".state"))
        return based_optimizer
\ No newline at end of file
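A minimal round-trip of the new save/load methods might look like the sketch below. The import path, the config keys, and the checkpoint filename are illustrative assumptions; none of them are defined by this diff.

# Illustration only: import path, config keys, and filename are assumptions.
import torch
from basedformer import optimizer

params = [torch.nn.Parameter(torch.randn(4, 4))]
opt = optimizer.BasedOptimizer(params, {"lr": 6e-4, "warmup_steps": 100}, "adamw")

opt.save("opt_checkpoint.pkl")                              # pickles the wrapper and torch-saves the state dict
opt = optimizer.BasedOptimizer.load("opt_checkpoint.pkl")   # restores hyperparameters, curr_step, and optimizer state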
...@@ -33,19 +33,21 @@ train_config = {
    "bs": 16,
    "gas": 16,
    "seed": 69,
    "save_every": 50,
}
bs = train_config["bs"]
gas = train_config["gas"]
model = GPTModel.neox_init(model_config).cuda().bfloat16()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP and Data Parallel, and output hidden states from the get_logits function.
train_dataset = utils.FbDataset(2049, train_config["data_path"])
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})
t = tqdm(train_loader)
curr_step = 0
for input_ids, labels in t:
    timex = time.perf_counter()
    input_ids = input_ids.cuda()
...@@ -59,7 +61,7 @@ for input_ids, labels in t:
        gas_loss = F.cross_entropy(logits, gas_labels)
        gas_loss.backward()
        loss += gas_loss.item()
    loss = loss / gas
    opt.step()
    opt.zero_grad()
...@@ -67,4 +69,8 @@ for input_ids, labels in t:
    step_per_sec = (1. / sec_per_step)
    tokens_per_sec = step_per_sec * 2048
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
    curr_step += 1
    if curr_step % train_config["save_every"] == 0:
        model.save(train_config["save_path"])
        print(f"Saved model at step {curr_step}")