Commit aa6444e9 authored by novelailab

gpt neo working

parent 88f0a90e
This source diff could not be displayed because it is too large.
 from . import gptj
 from . import gpt2
 from . import fairseq
+from . import gptneo
 
 MODEL_MAP = {
     "gptj": gptj.GPTJModel,
     "gpt2": gpt2.GPT2Model,
-    "gpt-fairseq": fairseq.GPTFairModel
+    "gpt-fairseq": fairseq.GPTFairModel,
+    "gpt-neo": gptneo.GPTNeoModel
 }
 
 def get_model(model_name: str):
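For context, a minimal usage sketch of the registry above; get_model's body is not part of this diff, so the assumption here is that it simply looks the class up in MODEL_MAP by name:

from basedformer.models import get_model

GPTNeoModel = get_model("gpt-neo")  # presumably returns MODEL_MAP["gpt-neo"]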
...
@@ -18,12 +18,12 @@ class BaseModel(nn.Module):
         self.layers = nn.ModuleList([])
         self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
         for i in range(config.n_layer):
+            config.layer_idx = i
             self.layers.append(
                 config.Layer(
                     attn=config.SelfAttention,
                     ff=config.FeedForward,
                     config=config,
-                    layer_idx=i,
                 )
             )
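In other words, the layer index is now threaded through the shared config object (config.layer_idx) rather than passed as a layer_idx constructor argument, so a Layer class such as GPTNeoLayer below only needs to accept (attn, ff, config).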
...
@@ -13,6 +13,7 @@ from pathlib import Path
 import math
 from basedformer.models import base_lm
 from typing import Optional, Any
+from icecream import ic
 
 def _attn(query, key, value, causal_mask, masked_bias,
@@ -24,7 +25,8 @@ def _attn(query, key, value, causal_mask, masked_bias,
     attn_weights = torch.matmul(query, key.transpose(-1, -2))
     attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
-    attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
+    if scale_attn:
+        attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
 
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
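A quick sketch (not from the repository) of what the guarded scaling means in practice: GPT-Neo leaves attention scores unscaled, so scale_attn is set to None further down and the division is skipped, while GPT-J/fairseq-style models keep the sqrt(head_dim) divisor:

import torch

def attn_scores(query, key, scale_attn=None):
    # query, key: (batch, heads, seq, head_dim)
    scores = torch.matmul(query, key.transpose(-1, -2))
    # equivalent to the `if scale_attn:` guard above, since scale_attn is
    # either None or a nonzero scalar tensor
    if scale_attn is not None:
        scores = scores / scale_attn.to(scores.dtype)
    return scores

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
unscaled = attn_scores(q, k)                               # GPT-Neo path
scaled = attn_scores(q, k, torch.sqrt(torch.tensor(8.0)))  # GPT-J/fairseq path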
@@ -38,14 +40,15 @@ def _attn(query, key, value, causal_mask, masked_bias,
 
 class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
-    def __init__(self, config, attention_type):
+    def __init__(self, config, attn_type):
+        ic(attn_type)
         nn.Module.__init__(self)
         self.config = config
         max_positions = 2049
         bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
             1, 1, max_positions, max_positions).bool()
-        if attention_type == "local":
+        if attn_type == "local":
             self.register_buffer(
                 "bias",
                 bias ^ torch.tril(bias, -config.window_size),
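As an aside, a small self-contained sketch (toy sizes, not from the repository) of what bias ^ torch.tril(bias, -window_size) produces: a banded causal mask in which position i can only attend to the last window_size positions.

import torch

max_positions, window_size = 8, 3
bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.uint8)).bool()
local = bias ^ torch.tril(bias, -window_size)
print(local.int())
# row i is True for columns max(0, i - window_size + 1) .. i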
@@ -63,13 +66,13 @@ class SelfAttention(nn.Module):
         device = config.device
         dtype = config.dtype
 
-        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
+        self.scale_attn = None
         self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
-        attn_bias = True #fairseq has attn_bias
+        attn_bias = False #fairseq has attn_bias
         self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=True, device=device, dtype=dtype)
 
     def forward(self, x, kv=None, cache=False):
         B, S, H = x.shape # batch, sequence, hidden_dim
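A quick way to sanity-check the bias layout (a sketch, assuming the HF checkpoint referenced later in this commit is available locally): GPT-Neo stores no bias for the q/k/v projections but does store one for out_proj, which matches attn_bias = False plus bias=True on out_proj above.

import torch

sd = torch.load("pretrained/gpt-neo-125m/pytorch_model.bin", map_location="cpu")
print("transformer.h.0.attn.attention.q_proj.bias" in sd)    # expected: False
print("transformer.h.0.attn.attention.out_proj.bias" in sd)  # expected: True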
@@ -116,13 +119,13 @@ class FeedForward(nn.Module):
         return x
 
 class GPTNeoLayer(nn.Module):
-    def __init__(self, attn, ff, config, layer_idx):
+    def __init__(self, attn, ff, config):
         nn.Module.__init__(self)
         self.hidden_dim = config.hidden_dim
         self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
         self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
         self.ff = ff(config)
-        if layer_idx % 2 == 0:
+        if config.layer_idx % 2 == 0:
             attn_type = "global"
         else:
             attn_type = "local"
@@ -154,8 +157,8 @@ class GPTNeoModel(base_lm.BaseModel):
         'n_head': 8,
         'n_tokens': 2049,
         'hidden_dim': 512,
-        'vocab_dim': 50400,
-        'fp32_attn': True, #fairseq models are trained with fp32 attn
+        'vocab_dim': 50257,
+        'fp32_attn': False,
         'eps': 1e-5,
         'device': torch.device('cuda'),
         'dtype': torch.float16,
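(For reference: 50257 is the GPT-2 BPE vocabulary size that GPT-Neo uses, while 50400 is the padded vocabulary size used by GPT-J.)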
@@ -165,29 +168,7 @@ class GPTNeoModel(base_lm.BaseModel):
         'FeedForward': FeedForward,
         'window_size': 256,
     }
 
     def __init__(self, user_config, **kwargs):
-        nn.Module.__init__(self)
-        #configuration
-        self.user_config = user_config
-        self.config = self.configure_model()
-        config = self.config
-
-        #modeling
-        self.n_layer = config.n_layer
-        self.hidden_dim = config.hidden_dim
-        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
-        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
-        self.layers = nn.ModuleList([])
-        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
-        for _ in range(config.n_layer):
-            self.layers.append(
-                config.Layer(
-                    attn=config.SelfAttention,
-                    ff=config.FeedForward,
-                    config=config,
-                )
-            )
-
-        # returns sinusoidal embeddings of shape: (1, n_tokens, 768)
-        self.register_buffer("embed_scale", torch.sqrt(torch.tensor(self.config.hidden_dim, requires_grad=False)))
+        base_lm.BaseModel.__init__(self, user_config, **kwargs)
         self.pos_embed = nn.Embedding(self.config.n_tokens, self.config.hidden_dim)
         self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False)
         #bias=False for fairseq models
@@ -202,8 +183,10 @@ class GPTNeoModel(base_lm.BaseModel):
             kv_new = []
 
-        position_ids = torch.arange(past_length, x[-1] + past_length, dtype=torch.long, device=x.device)
-        position_ids = position_ids.unsqueeze(0).view(-1, x[-1])
+        position_ids = torch.arange(past_length,
+                                    x.shape[-1] + past_length,
+                                    dtype=torch.long, device=x.device)
+        position_ids = position_ids.unsqueeze(0).view(-1, x.shape[-1])
 
         x = self.vocab_embed(x)
         x = x + self.pos_embed(position_ids)
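A tiny sketch (toy numbers, not from the repository) of what the fixed position-id computation produces when past_length tokens are already cached; the old code indexed the tensor with x[-1] instead of using the sequence length x.shape[-1]:

import torch

x = torch.randint(0, 50257, (2, 5))   # (batch, seq) token ids
past_length = 3                        # tokens already held in the kv cache
position_ids = torch.arange(past_length, x.shape[-1] + past_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, x.shape[-1])
print(position_ids)  # tensor([[3, 4, 5, 6, 7]])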
...
@@ -40,9 +40,9 @@ if False:
     #path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")
 
 with always_rerun():
     if True:
-        env1.sh('pip3 uninstall transformers')
-        env1.sh('pip3 install transformers')
-        path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m --device 0 --tasks lambada --no_cache")
+        #env1.sh('pip3 uninstall transformers')
+        #env1.sh('pip3 install transformers')
+        path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gpt-neo-125m-ported --device 0 --tasks lambada --no_cache")
     #path.sh("python3 ../lm-evaluation-harness/main.py --batch_size 8")
 else:
...
import torch
import transformers
import sys
from icecream import ic
import os
from pathlib import Path
"""
Original:
ln_f.weight
ln_f.bias
wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias
The attention has biases, unlike GPT-J. The QKV matrices are also merged instead of separate. What is the order, though? Probably just QKV.
"""
x = torch.load("pretrained/gpt-neo-125m/pytorch_model.bin")
state_dict = x
ic(x.keys())
new_state_dict = {}
module_map = {
"ln_1": "ln_preattn",
"ln_2": "ln_postattn",
"mlp.c_proj": "ff.ff2",
"mlp.c_fc": "ff.ff1",
"attn.attention.q_proj": "attn.q_proj",
"attn.attention.k_proj": "attn.k_proj",
"attn.attention.v_proj": "attn.v_proj",
"attn.attention.out_proj": "attn.out_proj",
"wte": "vocab_embed",
"wpe": "pos_embed",
'ln_f': 'ln_final',
}
print(type(state_dict))
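# Remap HF GPT-Neo parameter names (e.g. transformer.h.<i>.attn.attention.q_proj.weight)
# to the basedformer layout (layers.<i>.attn.q_proj.weight) via module_map;
# keys without a layer index (wte, wpe, ln_f) map to top-level module names.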
for key in state_dict.keys():
dotlist = key.split('.')
if len(dotlist) > 3:
layer = dotlist[2]
for x in module_map:
if x in key:
new_state_dict[f"layers.{layer}.{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
print(f"{key} -> layers.{layer}.{module_map[x]}.{dotlist[-1]}")
else:
for x in module_map:
if x in key:
new_state_dict[f"{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
print(f"{key} -> {module_map[x]}.{dotlist[-1]}")
new_state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for k, v in new_state_dict.items():
print(f"{k} -> {v.shape}")
def save(state_dict, path):
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
checkpoint = {}
for i, x in enumerate(state_dict.items()):
checkpoint[x[0]] = f"{path}/b{i}.pt"
torch.save(x[1], f"{path}/b{i}.pt")
torch.save(checkpoint, f"{path}/m.pt")
save(new_state_dict, "pretrained/gpt-neo-125m-ported/lm")
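For completeness, a sketch of the matching load step (assumed, simply mirroring save() above): read the m.pt manifest, then load each per-tensor shard it points to.

def load(path):
    checkpoint = torch.load(f"{path}/m.pt")
    return {name: torch.load(shard) for name, shard in checkpoint.items()}

# state_dict = load("pretrained/gpt-neo-125m-ported/lm")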