Commit fa2a771c authored by novelailab

update gpt2 to base_lm

parent 835709ca
@@ -11,7 +11,7 @@ except ImportError:
import os
from pathlib import Path
import math
from basedformer import lm_base
from basedformer.models import base_lm
def shift_tokens(x, amt, eps = 1e-5):
    n, device = x.shape[1], x.device
@@ -147,43 +147,20 @@ class GPT2Layer(nn.Module):
        return x
class GPT2Model(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPT2Layer, device="cuda", dtype=torch.float16, **kwargs):
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.ln_final(x)
        return x

class GPT2BaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class=GPT2Model

def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5
    }
    model = GPT2BaseLM.load(config, path, state_dict)
    return model
class GPTJModel(base_lm.BaseModel):
    def __init__(self, user_config, **kwargs):
        self.default_config = {
            'n_layer': 6,
            'n_head': 8,
            'n_tokens': 2048,
            'hidden_dim': 512,
            'vocab_dim': 50400,
            'eps': 1e-5,
            'device': torch.device('cuda'),
            'dtype': torch.float16,
            'Layer': GPTJLayer,
            'activation': gelu_new,
            'SelfAttention': SelfAttention,
            'FeedForward': FeedForward,
        }
        base_lm.BaseModel.__init__(self, user_config, **kwargs)
\ No newline at end of file
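The change above replaces the old wrapper pattern (GPT2BaseLM plus a hard-coded load_gpt_j config) with a model class that declares a default_config dict and passes the caller's user_config to base_lm.BaseModel. A minimal sketch of the default/override merge this pattern usually implies is shown below; BaseModelSketch and TinyModel are hypothetical names, and the merge behaviour is an assumption about base_lm rather than something visible in this diff.

class BaseModelSketch:
    # Hypothetical stand-in for base_lm.BaseModel: merge user overrides over the
    # defaults that the subclass sets before calling this constructor.
    def __init__(self, user_config=None, **kwargs):
        merged = dict(self.default_config)   # subclass-declared defaults
        merged.update(user_config or {})     # explicit user_config keys win
        merged.update(kwargs)                # keyword overrides win over everything
        self.config = merged

class TinyModel(BaseModelSketch):
    def __init__(self, user_config=None, **kwargs):
        # Mirrors the GPTJModel pattern from the diff: set defaults, then defer to the base class.
        self.default_config = {'n_layer': 6, 'n_head': 8, 'hidden_dim': 512}
        super().__init__(user_config, **kwargs)

model = TinyModel({'n_layer': 28, 'hidden_dim': 4096})
print(model.config)  # {'n_layer': 28, 'n_head': 8, 'hidden_dim': 4096}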
@@ -13,11 +13,7 @@ except ImportError:
import os
from pathlib import Path
import math
from basedformer import lm_utils
from basedformer.models import base_lm
from dataclasses import dataclass
#import dotmap
from dotmap import DotMap
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    if x is None:
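In the second file, the hunk ends at the signature of fixed_pos_embedding before its body is shown. In GPT-J-style code this helper conventionally builds the sin/cos tables used by rotary position embeddings; the sketch below shows that conventional computation under the assumption of base-10000 frequencies over even dimensions, since the actual body is not visible in this diff.

import torch

def fixed_pos_embedding_sketch(dim, seq_len, base=10000):
    # Conventional GPT-J-style rotary tables; not taken from this repository.
    # One inverse frequency per pair of rotary dimensions.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Angle for every (position, frequency) pair: shape (seq_len, dim // 2).
    angles = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
    return torch.sin(angles), torch.cos(angles)

# Example: tables for 64 rotary dimensions over a 2048-token context.
sin, cos = fixed_pos_embedding_sketch(64, 2048)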