Commit 971ed5dc authored by novelailab

optimizer save load completed

parent 41b51369
@@ -6,7 +6,7 @@ from basedformer import gptj
 import os
 import json
 from dataclasses import dataclass
+from pathlib import Path
 '''
 BaseLM config dataclass:
 model_config = {
@@ -27,7 +27,6 @@ class BaseLMConfig():
     vocab_dim: int
     eps: float
-
 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
     def __init__(self, config=None, lm=None):
......
@@ -3,18 +3,29 @@ import numpy as np
 import torch
 from dotmap import DotMap
 import pickle
+import os
+from pathlib import Path

 #Based Optimizer
-def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
+def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
     warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
     anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
-    #cosine schedule for annealing
-    return lr * warmup_percent - (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
+    #kinda broken. doesn't start from 0
+    if cosine_warmup:
+        main_lr = lr * (1 - np.cos(np.pi * warmup_percent)) / 2
+    else:
+        main_lr = lr * warmup_percent
+
+    anneal_lr = (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
+    return main_lr - anneal_lr

 class BasedOptimizer:
-    def __init__(self, parameters, config, optimizer):
+    def __init__(self, parameters, config, optimizer, init=True):
+        if init:
+            self.init_config(config)
+            self.init_optimizer(parameters, optimizer)
+
+    def init_config(self, config):
         defaults = {
             "lr": 6e-4,
             "end_lr": 6e-4,
@@ -27,6 +38,9 @@ class BasedOptimizer:
             "beta1": 0.9,
             "beta2": 0.95,
             "eps": 1e-4,
+            "max_lr": False,
+            "curr_step": 0,
+            "curr_lr": 0,
         }

         for k, v in defaults.items():
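Moving max_lr, curr_step and curr_lr into the defaults dict means the scheduler state always becomes instance attributes, so it travels through __dict__ into opt_metadata.pkl on save and comes back in as part of the config on load. A tiny standalone sketch of the defaults-then-config merge; the helper names are only for illustration, and the body of the defaults loop is collapsed in this diff, so the "if k not in config" guard is an assumption.

class _Cfg:
    pass

def merge_config(obj, config, defaults):
    # Assumed behaviour of init_config: defaults fill the gaps, user config wins.
    for k, v in defaults.items():
        if k not in config:          # assumption: the collapsed loop body guards like this
            setattr(obj, k, v)
    for k, v in config.items():
        setattr(obj, k, v)

cfg = _Cfg()
merge_config(cfg, {"lr": 5e-4, "warmup_steps": 100},
             {"lr": 6e-4, "end_lr": 6e-4, "curr_step": 0, "curr_lr": 0, "max_lr": False})
print(cfg.lr, cfg.end_lr, cfg.curr_step)   # 0.0005 0.0006 0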
@@ -35,36 +49,35 @@ class BasedOptimizer:
         for k, v in config.items():
             setattr(self, k, v)

-        self.max_lr = False
-        self.curr_step = 0
-        self.curr_lr = 0
-
-        if optimizer == "adamw":
-            self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
-        elif optimizer == "adamw8bit":
+    def init_optimizer(self, parameters, optimizer_name):
+        if optimizer_name == "adamw":
+            self.optimizer = optim.AdamW(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
+        elif optimizer_name == "adamw8bit":
             import bitsandbytes as bnb
-            self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
-        elif optimizer == "adafactor":
+            self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
+        elif optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
             except ImportError:
                 raise ImportError("Please install transformers for Adafactor")
-            self.optimizer = Adafactor(params=parameters)
+            self.optimizer = Adafactor(params=self.parameters)

-    def step(self, scaler=None):
-        if scaler:
-            scaler.step(self.optimizer)
-        else:
-            self.optimizer.step()
+    def step(self, dry_run=False, scaler=None):
+        if not dry_run:
+            if scaler:
+                scaler.step(self.optimizer)
+            else:
+                self.optimizer.step()

+        self.curr_step = self.curr_step + 1
         self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
-        self.curr_step = self.curr_step + 1
         if not self.max_lr:
             if self.curr_lr == self.end_lr:
                 print("max lr reached.")
@@ -85,15 +98,21 @@ class BasedOptimizer:
         if self.curr_step != 0:
             print(f"curr_lr: {str(self.get_current_lr())}")

-    def save(self, path):
-        torch.save(self.optimizer.state_dict(), path)
-        with open(path, 'wb') as f:
-            pickle.dump(self, f)
+    def save(self, path: Path):
+        path = path / "opt"
+        path.mkdir(parents=True, exist_ok=True)
+        torch.save(self.optimizer.state_dict(), path / "opt_states.pt")
+        del self.optimizer
+        metadata = self.__dict__
+        with open(path / "opt_metadata.pkl", 'wb') as f:
+            pickle.dump(metadata, f)

     @classmethod
-    def load(cls, path):
-        with open(path, 'rb') as f:
-            based_optimizer = pickle.load(f)
-
-        based_optimizer.optimizer.load_state_dict(torch.load(path))
+    def load(cls, parameters, path):
+        path = path / "opt"
+        with open(path / "opt_metadata.pkl", 'rb') as f:
+            metadata = pickle.load(f)
+
+        based_optimizer = cls(parameters, metadata, metadata["optimizer_name"])
+        based_optimizer.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
         return based_optimizer
\ No newline at end of file
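Net effect of the new save()/load() pair: instead of pickling the whole BasedOptimizer into one file, the state is split under an opt/ subdirectory, with the torch state dict in opt_states.pt and the remaining attributes (lr, curr_step, curr_lr, ...) in opt_metadata.pkl, and load() rebuilds the wrapper from that metadata before restoring the state dict. A hedged round-trip sketch follows; it assumes "optimizer_name" actually ends up in the pickled metadata (load() indexes it, but this hunk never shows it being set) and note that save() deletes self.optimizer from the live instance before pickling __dict__, so saving mid-run and continuing to step the same object would need care.

from pathlib import Path
import torch
from basedformer import optimizer

save_folder = Path("models/test_optimizer2")     # same folder the test script below uses

model = torch.nn.Linear(10, 100)
opt = optimizer.BasedOptimizer(model.parameters(),
                               {"lr": 5e-4, "end_lr": 1e-4, "warmup_steps": 100, "anneal_steps": 90},
                               "adamw")
opt.save(save_folder)
# -> models/test_optimizer2/opt/opt_states.pt     (torch optimizer state dict)
# -> models/test_optimizer2/opt/opt_metadata.pkl  (everything else in __dict__)

# Later, or in a fresh process: rebuild around new parameters, then restore the torch state.
opt2 = optimizer.BasedOptimizer.load(model.parameters(), save_folder)
print(opt2.curr_step, opt2.curr_lr)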
@@ -146,7 +146,6 @@ class HyperNetworkSingle(nn.Module):
         return x.bfloat16()

 model_config = {
-    "model_class":
     "n_layer": 28,
     "n_head": 16,
     "hidden_dim": 4096,
......
from basedformer import optimizer
import torch
from tqdm import tqdm
import wandb
import os
from pathlib import Path
train_config = {
    "lr": 5e-4,
    "end_lr": 1e-4,
    "warmup_steps": 100,
    "anneal_steps": 90,
}

model = torch.nn.Linear(10, 100)
save_folder = "models/test_optimizer2"

if not os.path.isdir(save_folder + "/opt"):
    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
else:
    opt = optimizer.BasedOptimizer.load(model.parameters(), Path(save_folder))

wandb.init(project="opt-test", name="test")

for x in tqdm(range(opt.curr_step, 100)):
    print(f"Step {opt.curr_step}: LR {opt.curr_lr}")
    wandb.log({"lr": opt.curr_lr})
    opt.step(dry_run=True)
    #if x == 60:
        #opt.save(Path(save_folder))