Commit 4141d527 authored by novelailab

a

parent 34739983
@@ -202,7 +202,7 @@ class GPTJLayer(nn.Module):
         return x
 
 class GPTJModel(nn.Module):
-    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16):
+    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):
         nn.Module.__init__(self)
         self.n_layer = n_layer
         self.hidden_dim = hidden_dim
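The only change above is the new **kwargs catch-all on GPTJModel.__init__. The hunks below add a "model_class" entry to the config dict and then splat the whole dict into the constructor, so any key that is not an explicit parameter needs somewhere to land. A minimal, self-contained illustration of that pattern (the Model class here is a stand-in, not the real GPTJModel):

class Model:
    def __init__(self, hidden_dim, n_layer, **kwargs):
        # **kwargs silently absorbs config keys that are not constructor
        # parameters, such as "model_class" below.
        self.hidden_dim = hidden_dim
        self.n_layer = n_layer

config = {"model_class": Model, "hidden_dim": 4096, "n_layer": 28}
model = config["model_class"](**config)  # would raise TypeError without **kwargs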
@@ -4,6 +4,7 @@ import torch
 from torch import nn
 from basedformer import gptj
 import os
+import json
 
 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
@@ -32,7 +33,7 @@ class BaseLM(nn.Module):
 
     @classmethod
     def init(cls, config):
-        lm = config.model_class(**config)
+        lm = config["model_class"](**config)
         model = cls(config, lm)
         model.init_weights()
         #make this modular later
@@ -46,13 +47,13 @@ class BaseLM(nn.Module):
         return model
 
     @classmethod
-    def load(cls, model_class, config, path=None, state_dict=None, strict=False):
+    def load(cls, config, path=None, state_dict=None, strict=False):
         # I am kinda sad that we will not have a load function in lm object itself.
-        # might be better to add load functions to that as well but not sure.
+        # might be better to add load functions -- actually nope.
         if path:
             state_dict = utils.SplitCheckpoint(path, device="cuda")
 
-        lm = model_class(**config)
+        lm = config["model_class"](**config)
         model = cls(config, lm)
         model.lm.load_state_dict(state_dict, strict=strict)
         return model
@@ -73,11 +74,12 @@ class BaseLM(nn.Module):
 
 def load_gpt_j(path="models/6b", state_dict=None):
     config = {
+        "model_class": gptj.GPTJModel,
         "n_layer": 28,
         "n_head": 16,
         "hidden_dim": 4096,
         "vocab_dim": 50400,
         "eps": 1e-5
     }
-    model = BaseLM.load(gptj.GPTJModel, config, path, state_dict)
+    model = BaseLM.load(config, path, state_dict)
     return model
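With model_class folded into the config dict, callers no longer pass the class separately to BaseLM.load. A hedged usage sketch of the refactored API; the import path for load_gpt_j is an assumption, since the file name is not visible in this view:

from basedformer import lm  # assumed module path; not shown in the diff

model = lm.load_gpt_j(path="models/6b")  # builds the config above, then BaseLM.load(config, path)
gptj_core = model.lm                     # the wrapped GPTJModel, per the model.lm.load_state_dict call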
@@ -10,6 +10,7 @@ def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
     anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
     #cosine schedule for annealing
     return lr * warmup_percent - (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
+#kinda broken. doesn't start from 0
 
 
 class BasedOptimizer:
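For context, a self-contained sketch of the schedule the last hunk touches. Only the anneal and return lines are visible, so the warmup_percent definition below is an assumption (a plain linear ramp). If the training loop counts steps from 1, the first learning rate is lr / warmup_steps rather than 0, which is one reading of the new "doesn't start from 0" comment.

import numpy as np

def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
    # Assumed linear warmup; only the three lines below it appear in the diff.
    warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
    anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
    #cosine schedule for annealing
    return lr * warmup_percent - (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2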