Commit 41f39980 authored by novelailab

make lm_base functions functional

parent 8d44445e
@@ -7,36 +7,9 @@ import os
 import json
 from dataclasses import dataclass
 from pathlib import Path
-'''
-BaseLM config dataclass:
-model_config = {
-    "model_class":
-    "n_layer": 28,
-    "n_head": 16,
-    "hidden_dim": 4096,
-    "vocab_dim": 50400,
-    "eps": 1e-5,
-}
-'''
-@dataclass
-class BaseLMConfig():
-    model_class: type
-    n_layer: int
-    n_head: int
-    hidden_dim: int
-    vocab_dim: int
-    eps: float
-
-#Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
-class BaseLM(nn.Module):
-    def __init__(self, config=None, lm=None):
-        nn.Module.__init__(self)
-        self.config = config
-        self.lm = lm
-        self.model_class = None
-
-    def init_weights(self):
-        for module in self.lm.modules():
+def init_weights(model, n_layer):
+    for module in model.modules():
         if isinstance(module, nn.Linear):
             module.weight.data.normal_(mean=0.0, std=0.02)
             if module.bias is not None:
@@ -51,44 +24,37 @@ class BaseLM(nn.Module):
         for name, p in module.named_parameters():
             if ("ff2" in name or "out_proj" in name) and "weight" in name:
-                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.config["n_layer"])))
+                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))

 @classmethod
-    def init(cls, config):
-        model = cls(config)
-        model.lm = model.model_class(**config)
+def init(model_class, config):
+    model = model_class(config)
     model.init_weights()
-        #make this modular later
     return model

 @classmethod
-    def no_init(cls, config):
-        model = cls(config)
-        model.lm = utils.no_init(lambda: model.model_class(**config))
+def no_init(model_class, config):
+    model = utils.no_init(lambda: model_class(config))
     return model

 @classmethod
-    def load(cls, config, path=None, state_dict=None, strict=False):
+def load(config, model_class, path=None, state_dict=None, strict=False):
     # I am kinda sad that we will not have a load function in lm object itself.
     # might be better to add load functions -- actually nope.
     if path:
         state_dict = utils.SplitCheckpoint(path, device="cuda")
-        model = cls(config)
-        model.lm = utils.no_init(lambda: model.model_class(**config))
-        model.lm.load_state_dict(state_dict, strict=strict)
+    model = utils.no_init(lambda: model_class(**config))
+    model.load_state_dict(state_dict, strict=strict)
     return model

-    def save(self, path):
-        if self.lm is None:
-            print("No LM object to save. Please first init a model.")
-            return
+def save(model, path):
     try: os.mkdir(path)
     except: pass
     checkpoint = {}
-        for i, x in enumerate(self.lm.state_dict().items()):
+    for i, x in enumerate(model.state_dict().items()):
         checkpoint[x[0]] = f"{path}/b{i}.pt"
         torch.save(x[1], f"{path}/b{i}.pt")
     torch.save(checkpoint, f"{path}/m.pt")
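
For context on the checkpoint format these functions rely on, here is a minimal, self-contained sketch of the layout that save() writes (one b{i}.pt file per state-dict entry plus an m.pt index that maps parameter names to file paths) and of how a reader in the spirit of utils.SplitCheckpoint could put a state dict back together. TinyLM, save_split, and load_split are illustrative stand-ins, not code from this repository.

import os
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # Hypothetical stand-in for the real model class; only used to produce a state dict.
    def __init__(self, hidden_dim=8, vocab_dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, vocab_dim)

def save_split(model, path):
    # Mirrors save() above: one tensor per b{i}.pt plus an m.pt index of name -> file.
    os.makedirs(path, exist_ok=True)
    checkpoint = {}
    for i, (name, tensor) in enumerate(model.state_dict().items()):
        checkpoint[name] = f"{path}/b{i}.pt"
        torch.save(tensor, f"{path}/b{i}.pt")
    torch.save(checkpoint, f"{path}/m.pt")

def load_split(path, device="cpu"):
    # Sketch of what a SplitCheckpoint-style reader could do: follow the m.pt index
    # and pull each tensor back into an ordinary state dict.
    index = torch.load(f"{path}/m.pt")
    return {name: torch.load(file, map_location=device) for name, file in index.items()}

if __name__ == "__main__":
    model = TinyLM()
    save_split(model, "tiny_ckpt")
    state_dict = load_split("tiny_ckpt")
    fresh = TinyLM()
    fresh.load_state_dict(state_dict, strict=False)
    print(sorted(state_dict.keys()))

Writing one file per tensor keeps individual files small and lets a loader pull weights onto a chosen device one tensor at a time, which is presumably what SplitCheckpoint provides for the real models.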