Commit 971ed5dc authored by novelailab

optimizer save load completed

parent 41b51369
@@ -6,7 +6,7 @@ from basedformer import gptj
 import os
 import json
 from dataclasses import dataclass
+from pathlib import Path
 '''
 BaseLM config dataclass:
 model_config = {
@@ -27,7 +27,6 @@ class BaseLMConfig():
     vocab_dim: int
     eps: float

 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
     def __init__(self, config=None, lm=None):
@@ -3,18 +3,29 @@ import numpy as np
 import torch
 from dotmap import DotMap
 import pickle
 import os
+from pathlib import Path

 #Based Optimizer
-def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
+def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
     warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
     anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
     #cosine schedule for annealing
-    return lr * warmup_percent - (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
+    #kinda broken. doesn't start from 0
+    if cosine_warmup:
+        main_lr = lr * (1 - np.cos(np.pi * warmup_percent)) / 2
+    else:
+        main_lr = lr * warmup_percent
+    anneal_lr = (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
+    return main_lr - anneal_lr

 class BasedOptimizer:
-    def __init__(self, parameters, config, optimizer):
+    def __init__(self, parameters, config, optimizer, init=True):
+        if init:
+            self.init_config(config)
+            self.init_optimizer(parameters, optimizer)
+
+    def init_config(self, config):
         defaults = {
             "lr": 6e-4,
             "end_lr": 6e-4,
@@ -27,6 +38,9 @@ class BasedOptimizer:
             "beta1": 0.9,
             "beta2": 0.95,
             "eps": 1e-4,
+            "max_lr": False,
+            "curr_step": 0,
+            "curr_lr": 0,
         }

         for k, v in defaults.items():
@@ -35,35 +49,34 @@ class BasedOptimizer:
         for k, v in config.items():
             setattr(self, k, v)

-        self.max_lr = False
-        self.curr_step = 0
-        self.curr_lr = 0
-
-        if optimizer == "adamw":
+    def init_optimizer(self, parameters, optimizer_name):
+        #record the optimizer name so load() can rebuild the same optimizer from the pickled metadata
+        self.optimizer_name = optimizer_name
+        if optimizer_name == "adamw":
             self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)

-        elif optimizer == "adamw8bit":
+        elif optimizer_name == "adamw8bit":
             import bitsandbytes as bnb
             self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)

-        elif optimizer == "adafactor":
+        elif optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
             except ImportError:
                 raise ImportError("Please install transformers for Adafactor")
             self.optimizer = Adafactor(params=parameters)

-    def step(self, scaler=None):
+    def step(self, dry_run=False, scaler=None):
+        if not dry_run:
             if scaler:
                 scaler.step(self.optimizer)
             else:
                 self.optimizer.step()

-        self.curr_step = self.curr_step + 1
         self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
+        self.curr_step = self.curr_step + 1

         if not self.max_lr:
             if self.curr_lr == self.end_lr:
@@ -85,15 +98,21 @@ class BasedOptimizer:
         if self.curr_step != 0:
             print(f"curr_lr: {str(self.get_current_lr())}")

-    def save(self, path):
-        torch.save(self.optimizer.state_dict(), path)
-        with open(path, 'wb') as f:
-            pickle.dump(self, f)
+    def save(self, path: Path):
+        path = path / "opt"
+        path.mkdir(parents=True, exist_ok=True)
+        torch.save(self.optimizer.state_dict(), path / "opt_states.pt")
+        #drop the live optimizer so the remaining attributes pickle cleanly; the instance is not usable after save()
+        del self.optimizer
+        metadata = self.__dict__
+        with open(path / "opt_metadata.pkl", 'wb') as f:
+            pickle.dump(metadata, f)

     @classmethod
-    def load(cls, path):
-        with open(path, 'rb') as f:
-            based_optimizer = pickle.load(f)
-        based_optimizer.optimizer.load_state_dict(torch.load(path))
+    def load(cls, parameters, path):
+        path = path / "opt"
+        with open(path / "opt_metadata.pkl", 'rb') as f:
+            metadata = pickle.load(f)
+        based_optimizer = cls(parameters, metadata, metadata["optimizer_name"])
+        based_optimizer.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
         return based_optimizer
\ No newline at end of file
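As a quick sanity check of the reworked schedule, the values below are what lr_schedule gives with the settings used by the test script further down (lr=5e-4, end_lr=1e-4, 100 warmup steps, 90 anneal steps). This is an illustration only, assuming lr_schedule is importable from basedformer.optimizer; it is not part of the diff.

from basedformer.optimizer import lr_schedule

lr, end_lr, warmup, anneal = 5e-4, 1e-4, 100, 90

lr_schedule(0, warmup, anneal, lr, end_lr)    # 0.0     -> linear warmup starts at zero
lr_schedule(50, warmup, anneal, lr, end_lr)   # 2.5e-4  -> halfway through warmup
lr_schedule(100, warmup, anneal, lr, end_lr)  # 5e-4    -> peak lr once warmup finishes
lr_schedule(145, warmup, anneal, lr, end_lr)  # 3e-4    -> midpoint of the cosine anneal
lr_schedule(190, warmup, anneal, lr, end_lr)  # 1e-4    -> fully annealed down to end_lr
lr_schedule(25, warmup, anneal, lr, end_lr, cosine_warmup=True)  # ~7.3e-5, vs. 1.25e-4 for linear warmup at the same step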
@@ -146,7 +146,6 @@ class HyperNetworkSingle(nn.Module):
         return x.bfloat16()

 model_config = {
-    "model_class":
     "n_layer": 28,
     "n_head": 16,
     "hidden_dim": 4096,
from basedformer import optimizer
import torch
from tqdm import tqdm
import wandb
import os
from pathlib import Path

train_config = {
    "lr": 5e-4,
    "end_lr": 1e-4,
    "warmup_steps": 100,
    "anneal_steps": 90,
}
model = torch.nn.Linear(10, 100)
save_folder = "models/test_optimizer2"

if not os.path.isdir(save_folder + "/opt"):
    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
else:
    opt = optimizer.BasedOptimizer.load(model.parameters(), Path(save_folder))

wandb.init(project="opt-test", name="test")
for x in tqdm(range(opt.curr_step, 100)):
    print(f"Step {opt.curr_step}: LR {opt.curr_lr}")
    wandb.log({"lr": opt.curr_lr})
    opt.step(dry_run=True)
    #if x == 60:
        #opt.save(Path(save_folder))
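For completeness, a minimal sketch of the save/load round trip this script is building toward; the step count and the uncommented save/load calls are made up for illustration, and the folder layout simply follows the new save(), which puts everything under <save_folder>/opt.

from pathlib import Path
import torch
from basedformer import optimizer

model = torch.nn.Linear(10, 100)
save_folder = Path("models/test_optimizer2")

# first run: create the optimizer, advance the schedule, then persist it
opt = optimizer.BasedOptimizer(model.parameters(), {"lr": 5e-4, "end_lr": 1e-4, "warmup_steps": 100, "anneal_steps": 90}, "adamw")
for _ in range(60):
    opt.step(dry_run=True)
opt.save(save_folder)  # writes opt/opt_states.pt and opt/opt_metadata.pkl; the instance gives up its optimizer in the process

# second run: rebuild from disk and resume the schedule where it left off
opt = optimizer.BasedOptimizer.load(model.parameters(), save_folder)
print(opt.curr_step)  # 60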