Commit 8d44445e authored by novelailab

config dataclass, not sure about this structure

i am sad
parent a1b9e387
@@ -14,6 +14,7 @@ import os
 from pathlib import Path
 import math
 from basedformer import lm_base
+from dataclasses import dataclass

 def fixed_pos_embedding(dim=None, seq_len=None, x=None):
     if x is None:
@@ -51,23 +52,23 @@ def _attn(query, key, value, causal_mask, masked_bias,

 class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
-    def __init__(self, hidden_dim, n_head, device, dtype):
+    def __init__(self, config):
         nn.Module.__init__(self)
         max_positions = 2049
         bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
             1, 1, max_positions, max_positions).bool()
-        self.head_dim = hidden_dim // n_head
+        self.head_dim = config.hidden_dim // config.n_head
         self.rotary_dim = self.head_dim // 4
-        self.hidden_dim = hidden_dim
-        self.n_head = n_head
+        self.hidden_dim = config.hidden_dim
+        self.n_head = config.n_head
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
         self.register_buffer("bias", bias)
         self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
-        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)
@@ -129,11 +130,11 @@ class SelfAttention(nn.Module):
         return x, None

 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim, activation, device, dtype):
+    def __init__(self, config):
         nn.Module.__init__(self)
-        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
-        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
-        self.activation = activation
+        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
+        self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
+        self.activation = config.activation

     def forward(self, x, act_ck=False):
         x = self.ff1(x)
@@ -145,12 +146,11 @@ class FeedForward(nn.Module):
         return x

 class GPTJLayer(nn.Module):
-    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
+    def __init__(self, attn, ff, config):
         nn.Module.__init__(self)
-        self.hidden_dim = hidden_dim
-        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
-        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
-        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
+        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
+        self.ff = ff(config)
+        self.attn = attn(config)
         self.tick = True

     def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
@@ -190,16 +190,22 @@ class GPTJLayer(nn.Module):
         return x, kv

 class GPTJModel(nn.Module):
-    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):
+    def __init__(self, config, **kwargs):
         nn.Module.__init__(self)
-        self.n_layer = n_layer
-        self.hidden_dim = hidden_dim
-        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
-        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
+        self.n_layer = config.n_layer
+        self.hidden_dim = config.hidden_dim
+        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
+        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
         self.layers = nn.ModuleList([])
-        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
-        for _ in range(n_layer):
-            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
+        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
+        for _ in range(config.n_layer):
+            self.layers.append(
+                config.Layer(
+                    attn=SelfAttention,
+                    ff=FeedForward,
+                    config=config,
+                )
+            )

     def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
         x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
@@ -238,6 +244,22 @@ class GPTJModel(nn.Module):
         else:
             return x, None

+@dataclass
+class GPTJConfig:
+    n_layer: int = 6
+    n_head: int = 8
+    hidden_dim: int = 512
+    vocab_dim: int = 50400
+    eps: float = 1e-5
+    device: torch.device = torch.device('cuda')
+    dtype: torch.dtype = torch.float16
+    Layer = GPTJLayer
+    activation = gelu_new
+
+    def from_dict(self, config_dict):
+        for k, v in config_dict.items():
+            setattr(self, k, v)
+
 class GPTJBaseLM(lm_base.BaseLM):
     def __init__(self, config=None, lm=None):
         nn.Module.__init__(self)
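
The new GPTJConfig is a standard Python dataclass: the annotated fields (n_layer through dtype) become keyword arguments of the generated __init__, while Layer and activation carry no annotations and so remain plain class attributes that can only be changed after construction. A minimal usage sketch, not part of the commit; the basedformer.gptj import path is an assumption:

    import torch
    from basedformer.gptj import GPTJConfig  # assumed module path

    # Defaults straight from the dataclass definition above.
    config = GPTJConfig()

    # Annotated fields can be overridden as keyword arguments...
    config = GPTJConfig(n_layer=12, hidden_dim=1024, dtype=torch.float32)

    # ...while from_dict patches an existing instance from a plain dict,
    # which also covers the non-field attributes Layer and activation.
    config.from_dict({"vocab_dim": 50257, "eps": 1e-6})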
@@ -252,6 +274,6 @@ def load_gpt_j(path="models/6b", state_dict=None):
         "vocab_dim": 50400,
         "eps": 1e-5
     }
-    config = DotMap(config)
+    config = GPTJConfig(**config)
     model = GPTJBaseLM.load(config, path, state_dict)
     return model
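
Downstream, load_gpt_j now wraps its hyperparameter dict in GPTJConfig instead of DotMap before handing it to GPTJBaseLM.load, and GPTJModel takes the same object directly. A rough sketch under the same assumed import path; the CPU/fp32 overrides are illustrative, not values from the commit:

    import torch
    from basedformer.gptj import GPTJConfig, GPTJModel  # assumed module path

    # Debug-sized model on CPU in fp32, everything else from the defaults.
    config = GPTJConfig(device=torch.device("cpu"), dtype=torch.float32)
    model = GPTJModel(config)

    # The checkpoint-loading path follows the same pattern as load_gpt_j:
    #   config = GPTJConfig(**config_dict)
    #   model = GPTJBaseLM.load(config, path, state_dict)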