Commit 074d25c5 authored by novelailab's avatar novelailab

sampling working, gpt-neo(x) init

parent 9223ef70
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base
def shift_tokens(x, amt, eps = 1e-5):
n, device = x.shape[1], x.device
cumsum = x.cumsum(dim = 1)
*x, x_pass = x.chunk(amt + 1, dim = -1)
*x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
amts = 2 ** torch.arange(amt)
amts = amts.tolist()
shifts = []
denom = torch.arange(n, device = device)
for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
normed_shifted_x = shifted_chunk / (shifted_denom + eps)
shifts.append(normed_shifted_x)
return torch.cat((*shifts, x_pass), dim = -1)
def shift(x, amt, dim = -1):
return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
torch.cat([k, key], dim=-2) # cat key
torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
nn.Module.__init__(self)
self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
self.activation = activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class GPTNeoLayer(nn.Module):
def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
nn.Module.__init__(self)
self.hidden_dim = hidden_dim
self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out = ck(self.attn, x)
else:
x = self.ln_preattn(x)
attn_out = self.attn(x)
residual = residual + attn_out
x = self.ln_postattn(x)
ff_out = self.ff(x, act_ck)
x = residual + ff_out
return x
class GPTNeoModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoLayer, device="cuda", dtype=torch.float16, **kwargs):
nn.Module.__init__(self)
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
def forward(self, x, hypernetwork=None, act_ck=False):
x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.lm_head(x)
return x.float()
def get_embeds(self, x, hypernetwork=None, act_ck=False):
x = self.vocab_embed(x)
for layer_id, layer in enumerate(self.layers):
x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.ln_final(x)
return x
class GPTNeoBaseLM(lm_base.BaseLM):
def __init__(self, config=None, lm=None):
nn.Module.__init__(self)
lm_base.BaseLM.__init__(self, config, lm)
self.model_class=GPTNeoModel
def load_gpt_j(path="models/6b", state_dict=None):
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5
}
model = GPTNeoBaseLM.load(config, path, state_dict)
return model
from typing import KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
x = torch.empty(0)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
else:
offset = 0
if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, :self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
#causal mask with generation in mind
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
nn.Module.__init__(self)
self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
self.activation = activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class GPTNeoxLayer(nn.Module):
def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
nn.Module.__init__(self)
self.hidden_dim = hidden_dim
self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
if interleaving_layers and layer_id % every_n == 0:
if self.tick:
hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x, kv
class GPTNeoXModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoXLayer, device="cuda", dtype=torch.float16, **kwargs):
nn.Module.__init__(self)
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
x = self.lm_head(x)
if kv:
return x.float(), kv
else:
return x.float()
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
kv = [None] * self.n_layer
kv_new = []
x = self.vocab_embed(x)
for layer_id, layer in enumerate(self.layers):
x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
kv_new.append(kvi)
x = self.ln_final(x)
if cache:
return x, kv_new
else:
return x, None
class GPTNeoXBaseLM(lm_base.BaseLM):
def __init__(self, config=None, lm=None):
nn.Module.__init__(self)
lm_base.BaseLM.__init__(self, config, lm)
self.model_class=GPTNeoXModel
def load_gpt_j(path="models/6b", state_dict=None):
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5
}
model = GPTNeoXBaseLM.load(config, path, state_dict)
return model
...@@ -94,7 +94,7 @@ class SplitCheckpoint(MutableMapping): ...@@ -94,7 +94,7 @@ class SplitCheckpoint(MutableMapping):
def copy(self): def copy(self):
return SplitCheckpoint(self.chkpt_dir, device=self.device) return SplitCheckpoint(self.chkpt_dir, device=self.device)
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=False): def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=True):
precision = 'ns' precision = 'ns'
r_arr = np.empty([2, r]) # [0] = mean, [1] = std r_arr = np.empty([2, r]) # [0] = mean, [1] = std
if function: if function:
...@@ -104,9 +104,11 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True ...@@ -104,9 +104,11 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
n_arr = np.empty(n) n_arr = np.empty(n)
for k in range(n): for k in range(n):
start = time.perf_counter_ns() start = time.perf_counter_ns()
torch.cuda.synchronize() if cuda_blocking:
torch.cuda.synchronize()
func() func()
torch.cuda.synchronize() if cuda_blocking:
torch.cuda.synchronize()
n_arr[k] = time.perf_counter_ns() - start n_arr[k] = time.perf_counter_ns() - start
if not first: if not first:
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment