Commit 40e90836 authored by novelailab's avatar novelailab

fix fairseq by taking out eot and newline mapping

parent aa6444e9
......@@ -171,7 +171,7 @@ class GPTNeoModel(base_lm.BaseModel):
base_lm.BaseModel.__init__(self, user_config, **kwargs)
self.pos_embed = nn.Embedding(self.config.n_tokens, self.config.hidden_dim)
self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False)
#bias=False for fairseq models
#bias=False for neo models
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
......
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
from typing import Optional, Any
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
    # 50257 is treated as the padding id here (hard-coded rather than read from padding_idx)
    mask = tensor.ne(torch.tensor(50257, requires_grad=False)).int()
return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
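# Hedged worked example (illustrative, not part of the commit): with padding_idx=1,
# real tokens get positions 2, 3, 4, ... while the hard-coded pad id 50257 maps back
# to padding_idx, so padded slots never advance the position counter.
_positions_demo = make_positions(torch.tensor([[5, 9, 50257, 7]]), padding_idx=1)
assert _positions_demo.tolist() == [[2, 3, 1, 4]]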
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx if padding_idx is not None else 0
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size, embedding_dim, padding_idx
)
self.onnx_trace = False
self.register_buffer("_float_tensor",
torch.tensor(1.0, requires_grad=False).float())
self.max_positions = int(1e5)
# print(embedding_dim, padding_idx, init_size)
def prepare_for_onnx_export_(self):
self.onnx_trace = True
@staticmethod
def get_embedding(
num_embeddings: int, embedding_dim: int,
padding_idx: Optional[int] = None
):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1
) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
num_embeddings, -1
)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(
self,
input,
incremental_state: Optional[Any] = None,
timestep: Optional[torch.Tensor] = None,
positions: Optional[Any] = None,
offset: Optional[int] = 0
):
"""Input is expected to be of size [bsz x seqlen]."""
bspair = input.shape
bsz, seq_len = bspair[0], bspair[1]
max_pos = self.padding_idx + 1 + seq_len + offset
# print("max_pos: " + str(max_pos))
if self.weights is None or max_pos > self.weights.size(0):
# print("recomputing embeddings")
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos, self.embedding_dim, self.padding_idx + offset
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
if self.onnx_trace:
return (
self.weights.index_select(
index=self.padding_idx + pos + offset, dim=0)
.unsqueeze(1)
.repeat(bsz, 1, 1)
)
            return self.weights[self.padding_idx + pos + offset, :].expand(bsz, 1, -1)
positions = make_positions(
input, self.padding_idx + offset, onnx_trace=self.onnx_trace
)
        if self.onnx_trace:
            flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
            # input.shape yields plain ints here, so build the target shape as a tensor directly
            embedding_shape = torch.tensor([bsz, seq_len, -1], dtype=torch.long)
            embeddings = torch.onnx.operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
return (
self.weights.index_select(0, positions.view(-1))
.view(bsz, seq_len, -1)
.detach()
)
def PositionalEmbedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
):
m = SinusoidalPositionalEmbedding(
embedding_dim,
padding_idx,
init_size=num_embeddings + padding_idx + 1,
)
return m
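# Hedged usage sketch (toy sizes, not part of the commit): GPTFairModel below builds this
# with padding_idx=1, so the table has num_embeddings + 2 rows and row 1 is zeroed out.
_pe = PositionalEmbedding(num_embeddings=2049, embedding_dim=512, padding_idx=1)
assert _pe.weights.shape == (2051, 512) and _pe.weights[1].abs().sum() == 0
_toks = torch.randint(0, 50257, (2, 10))
_pos = _pe(_toks)                      # (2, 10, 512) detached sinusoidal embeddings
_next = _pe(_toks[:, -1:], offset=10)  # offset continues positions past a cached prefix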
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None, fp32_attn=True):
if fp32_attn:
attn_weights = torch.matmul(query.float(), key.transpose(-1, -2).float())
else:
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
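# Minimal shape check for _attn (assumed toy sizes, not from the commit); with
# fp32_attn=True the logits are computed in float32 and cast back to value.dtype.
_q = _k = _v = torch.randn(1, 8, 16, 64)            # [batch, head, seq, head_dim]
_causal = torch.tril(torch.ones(16, 16, dtype=torch.bool)).view(1, 1, 16, 16)
_out = _attn(_q, _k, _v, _causal, torch.tensor(-1e9),
             scale_attn=torch.sqrt(torch.tensor(64.0)), fp32_attn=True)
assert _out.shape == (1, 8, 16, 64) and _out.dtype == _v.dtype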
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config):
nn.Module.__init__(self)
self.config = config
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = config.hidden_dim // config.n_head
        self.rotary_dim = self.head_dim // 4  # not used in this module (no rotary embeddings here)
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
device = config.device
dtype = config.dtype
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = True #fairseq has attn_bias
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        if kv:
            k, v = kv
            # prepend the cached keys/values so the new query attends to the full
            # sequence (everything except the newly added token comes from the cache)
            key = torch.cat([k, key], dim=-2)
            value = torch.cat([v, value], dim=-2)
query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn, self.config.fp32_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
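# Hedged kv-cache sketch (SimpleNamespace stands in for the real config object, which
# is an assumption): prime the cache on a prefix, then decode one extra token.
from types import SimpleNamespace
_cfg = SimpleNamespace(hidden_dim=512, n_head=8, fp32_attn=True,
                       device=torch.device('cpu'), dtype=torch.float32)
_attn_mod = SelfAttention(_cfg)
_y, _kv = _attn_mod(torch.randn(1, 12, 512), cache=True)          # cache 12 tokens
_y, _kv = _attn_mod(torch.randn(1, 1, 512), kv=_kv, cache=True)   # attend to all 13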
class FeedForward(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim * 4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim * 4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
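# Hedged sketch of the act_ck path (assumed toy config via SimpleNamespace): the GELU is
# wrapped in torch.utils.checkpoint, so it is recomputed during backward to save memory.
from types import SimpleNamespace
_ff_cfg = SimpleNamespace(hidden_dim=512, activation=F.gelu,
                          device=torch.device('cpu'), dtype=torch.float32)
_mlp = FeedForward(_ff_cfg)
_h = _mlp(torch.randn(2, 10, 512, requires_grad=True), act_ck=True)  # -> (2, 10, 512)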
class GPTFairLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.hidden_dim = config.hidden_dim
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ff = ff(config)
self.attn = attn(config)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
x = residual + attn_out
residual = x
x = self.ln_postattn(x)
ff_out = self.ff(x, act_ck)
x = residual + ff_out
return x, kv
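# Hedged single-layer smoke test (assumed config values): a pre-LN residual block,
# LayerNorm -> attention -> residual add, then LayerNorm -> MLP -> residual add.
from types import SimpleNamespace
_layer_cfg = SimpleNamespace(hidden_dim=512, n_head=8, eps=1e-5, fp32_attn=True,
                             activation=F.gelu, device=torch.device('cpu'), dtype=torch.float32)
_block = GPTFairLayer(SelfAttention, FeedForward, _layer_cfg)
_x, _block_kv = _block(torch.randn(1, 10, 512), cache=True)   # _x: (1, 10, 512)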
class GPTFairModel(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
self.default_config = {
'n_layer': 6,
'n_head': 8,
'n_tokens': 2049,
'hidden_dim': 512,
'vocab_dim': 50400,
'fp32_attn': True, #fairseq models are trained with fp32 attn
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': GPTFairLayer,
'activation': F.gelu,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
}
base_lm.BaseModel.__init__(self, user_config, **kwargs)
        self.register_buffer("embed_scale", torch.sqrt(torch.tensor(self.config.hidden_dim, requires_grad=False).float()))
        # pos_embed returns sinusoidal embeddings of shape (bsz, seq_len, hidden_dim); padding_idx=1 as in fairseq
        self.pos_embed = PositionalEmbedding(self.config.n_tokens, self.config.hidden_dim, 1)
self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False)
#bias=False for fairseq models
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
kv = [None] * self.n_layer
past_length = 0
else:
past_length = kv[0][0].size(-2) #get sequence dim of key
kv_new = []
position_embeds = self.pos_embed(x, offset=past_length)
input_embeds = self.vocab_embed(x) * self.embed_scale
x = position_embeds + input_embeds
for layer_id, layer in enumerate(self.layers):
x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
kv_new.append(kvi)
x = self.ln_final(x)
if cache:
return x, kv_new
else:
return x, None
\ No newline at end of file
......@@ -14,7 +14,7 @@ bash = False
config_obj = KubeConfig()
config_obj.set_name(name)
config_obj.set_gpu(gpu_name=GPU.RTX_A6000, amount=1)
config_obj.set_ram(16)
config_obj.set_ram(64)
config_obj.set_cpu(4)
config_obj.dry_run(dry)
config_obj.print_information()
......@@ -31,18 +31,23 @@ if False:
env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
env1.sh('pip install einops numpy')
env1.sh('pip install tqdm')
env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
#env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
env1.sh('pip3 install dotmap icecream')
path.sh("pip3 install --editable .")
#path.sh("pip3 uninstall torch")
#path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")
with always_rerun():
if True:
#env1.sh('pip3 uninstall transformers')
#env1.sh('pip3 install transformers')
path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gpt-neo-125m-ported --device 0 --tasks lambada --no_cache")
#path.sh('pip3 install --editable ../lm-evaluation-harness/.')
#env1.sh('pip3 install pytest')
#env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
path.sh('pip3 uninstall huggingface_hub')
path.sh('pip3 install huggingface_hub')
#path.sh('pip3 uninstall transformers')
#path.sh('pip3 install transformers')
#path.sh("python3 ../lm-evaluation-harness/main.py --model gpt2 --batch_size 8 --model_args pretrained=EleutherAI/gpt-neo-125M --device 0 --tasks lambada --no_cache")
path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m --device 0 --tasks lambada --no_cache")
#path.sh("python3 ../lm-evaluation-harness/main.py --batch_size 8")
else:
......
......@@ -11,67 +11,15 @@ from contextlib import contextmanager
import torch.nn.functional as F
from transformers import GPTNeoForCausalLM
from icecream import ic
#replicating timeit magic function of ipython
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
precision = 'ns'
r_arr = np.empty([2, r]) # [0] = mean, [1] = std
if function:
func.__name__ = function.__name__
for i in tqdm(range(r)) if do_tqdm else range(r):
n_arr = np.empty(n)
for k in range(n):
start = perf_counter_ns()
func()
n_arr[k] = perf_counter_ns() - start
if not first:
# delete the first element from n_arr numpy array
n_arr = np.delete(n_arr, 0)
r_arr[0, i] = np.mean(n_arr)
r_arr[1, i] = np.std(n_arr)
best = r_arr[:, np.argmin(r_arr[0])] # [0] = mean, [1] = std
#check if best[0] bigger than 1ms in numpy
if best[0] < 1e3:
precision = 'ns'
elif best[0] >= 1e9:
print('b')
best[0] = best[0] * 1e-9
best[1] = best[1] * 1e-9
precision = 's'
elif best[0] >= 1e6:
best[0] = best[0] * 1e-6
best[1] = best[1] * 1e-6
precision = 'ms'
elif best[0] >= 1e3:
precision = 'μs'
best[0] = best[0] * 1e-3
best[1] = best[1] * 1e-3
if not quiet:
if precision == 'ns':
print(f"{func.__name__}: {best[0]:.0f}{precision} ± {best[1]:.0f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
if precision == 'μs':
print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
elif precision == 'ms':
print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
elif precision == 's':
print(f"{func.__name__}: {best[0]:.4f}{precision} ± {best[1]:.4f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
with torch.no_grad():
model_dir = '/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/hf_125m/'
model_dir = '/home/xuser/diffusionstorage/models/fairseq/converted/en_dense_lm_125m/'
hf_model = no_init(lambda: GPTNeoForCausalLM.from_pretrained(model_dir)).cuda().half().eval()
print("Loaded hf model")
path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m"
based_model = lmu.load_from_path(path).cuda().half().eval()
print("Loaded based model")
x = torch.randint(0, 50256, (1, 2048)).cuda().long()
x = torch.randint(0, 51200, (1, 300)).cuda().long()
assert torch.allclose(hf_model.transformer.wte(x), based_model.vocab_embed(x))
hidden = hf_model.transformer.wte(x)
......@@ -85,7 +33,7 @@ with torch.no_grad():
ic(hf_model.transformer.h[layer].attn(hidden)[0].abs().mean())
ic(based_model.layers[layer].attn(hidden)[0].abs().mean())
ic((hf_model.transformer.h[layer].attn(hidden)[0] - based_model.layers[layer].attn(hidden)[0]).abs().mean())
#assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden)[0], rtol=1e-6)
assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden)[0], rtol=1e-6)
attn_out = hf_model.transformer.h[layer].attn(hidden)[0]
hidden = residual + attn_out
residual = hidden
......
......@@ -110,12 +110,12 @@ with torch.no_grad():
wte = fairdict["decoder.embed_tokens.weight"].clone()
for i in range(50260):
wte[mapping[i]] = fairdict["decoder.embed_tokens.weight"][i]
hack_embs(wte)
#hack_embs(wte)
save(wte.half(), "vocab_embed.weight")
lm_head = fairdict["decoder.output_projection.weight"].clone()
for i in range(50260):
lm_head[mapping[i]] = fairdict["decoder.output_projection.weight"][i]
hack_embs(lm_head)
#hack_embs(lm_head)
save(lm_head.half(), "lm_head.weight")
save(torch.FloatTensor(1), "pos_embed._float_tensor")
......