Commit 482a7bae authored by novelailab

update

parent b39e6d1b
@@ -28,7 +28,7 @@ def init_weights(model, n_layer):
 def init(model_class, config):
     model = model_class(config)
-    model.init_weights()
+    init_weights(model, config["n_layer"])
     return model
 def no_init(model_class, config):
...
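# Usage sketch for the change above (hedged: this hunk is assumed to be basedformer/lm_utils.py,
# as the training-script diff at the bottom of this commit calls lm_utils.init). init() now
# builds the model and then applies the standalone init_weights helper instead of a method:
import torch
from basedformer import lm_utils, models

example_config = {"n_layer": 12, "n_head": 12, "hidden_dim": 768, "vocab_dim": 50400,
                  "eps": 1e-5, "q_only": True, "activation": torch.nn.GELU()}
example_model = lm_utils.init(models.fast.GPTJModel, example_config)  # model_class(config) + init_weights(model, config["n_layer"])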
@@ -2,12 +2,15 @@ from . import gptj
 from . import gpt2
 from . import fairseq
 from . import gptneo
+from . import alibi
+from . import fast
 MODEL_MAP = {
     "gptj": gptj.GPTJModel,
     "gpt2": gpt2.GPT2Model,
     "gpt-fairseq": fairseq.GPTFairModel,
-    "gpt-neo": gptneo.GPTNeoModel
+    "gpt-neo": gptneo.GPTNeoModel,
+    "alibi": alibi.AlibiModel,
 }
 def get_model(model_name: str):
...
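# Registry sketch: after this change the ALiBi model resolves by name (a minimal example;
# note that "fast" is imported here but, in this hunk, not added to MODEL_MAP, so it is
# reached as models.fast.GPTJModel rather than through get_model):
from basedformer import models

AlibiModel = models.get_model("alibi")   # -> alibi.AlibiModel, per the mapping above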
from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n)-3)))
ratio = start
return [start*ratio**i for i in range(n)]
    #In the paper, we only train models that have 2^a heads for some a. This function has
    #some good properties that only occur when the input is a power of 2. To maintain that even
    #when the number of heads is not a power of 2, we use this workaround.
    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
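# Quick sanity check of the slope schedule above (illustrative only; the expected values
# follow directly from the formula, e.g. 8 heads get the geometric series 1/2 .. 1/256):
assert get_slopes(8) == [2 ** -(i + 1) for i in range(8)]
assert len(get_slopes(12)) == 12   # non-power-of-2 head counts use the interleaving workaround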
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.register_buffer("slopes", torch.Tensor(get_slopes(config.n_head)))
#In the next line, the part after the * is what constructs the diagonal matrix (right matrix in Figure 3 in the paper).
#If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3, but one where all rows are identical.
#This works because the softmax operation is invariant to translation, and our bias functions are always linear.
print(self.slopes.shape)
self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_positions).unsqueeze(0).unsqueeze(0).expand(config.n_head, -1, -1)
self.alibi = self.alibi.view(config.n_head, 1, max_positions)
self.head_dim = config.hidden_dim // config.n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
self.q_only = config.q_only
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
if config.q_only:
self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if self.q_only:
key = self.k_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
else:
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
else:
offset = 0
        alibi = self.alibi.repeat(B, 1, 1)  # [batch * n_head, 1, max_positions]; local copy so the stored bias is not mutated across calls
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
#causal mask with generation in mind
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
        # slice the per-head linear ALiBi bias to the current key length and pass it to _attn
        # as an additive attention mask; the boolean causal mask is applied separately inside _attn
        alibi_bias = alibi[:, :, :key_length].reshape(B, self.n_head, 1, key_length)
        x = _attn(
            query, key, value, causal_mask, self.masked_bias, alibi_bias, self.scale_attn
        )
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
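# Shape sketch of the bias built in __init__ above, at a toy size (4 heads, 6 positions);
# every head gets one identical, linearly increasing row, scaled by its slope:
_demo_slopes = torch.Tensor(get_slopes(4))                                             # (4,)
_demo_alibi = _demo_slopes.unsqueeze(1).unsqueeze(1) * torch.arange(6).unsqueeze(0).unsqueeze(0).expand(4, -1, -1)
_demo_alibi = _demo_alibi.view(4, 1, 6)                                                # (n_head, 1, positions)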
class FeedForward(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class AlibiLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
#self.ln_preattn = nn.LogSoftmax(dim=-2)
self.ff = ff(config)
self.attn = attn(config)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv, cache)
#attn_out, kv = self.attn(x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
if interleaving_layers and layer_id % every_n == 0:
if self.tick:
hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x, kv
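# Note on the "order of addition matters" comment above: fp16 addition is not associative,
# so summing in a different order than a reference implementation can change the rounded
# result. Tiny illustration (values chosen only to expose the rounding, not model values):
_big = torch.tensor(2048.0, dtype=torch.float16)
_one = torch.tensor(1.0, dtype=torch.float16)
assert float((_big + _one) + _one) == 2048.0   # 2049 rounds back down at this magnitude
assert float((_one + _one) + _big) == 2050.0   # grouping the small terms first stays representable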
class AlibiModel(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
self.default_config = {
'n_layer': 6,
'n_head': 8,
'n_tokens': 2048,
'hidden_dim': 512,
'vocab_dim': 50400,
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': AlibiLayer,
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
}
base_lm.BaseModel.__init__(self, user_config, **kwargs)
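# Construction sketch for AlibiModel (hedged: assumes user_config entries override the
# defaults via base_lm's configure_model, as GPTJModel below does, and note that q_only is
# read by SelfAttention but is not part of default_config, so it must be supplied):
alibi_example = AlibiModel({"n_layer": 2, "n_head": 8, "hidden_dim": 512, "q_only": False,
                            "device": torch.device("cpu"), "dtype": torch.float32})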
from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
x = torch.empty(0)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
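# Minimal shape check for the rotary helpers above (illustrative sizes; inputs follow the
# [batch, seq, head, head_dim] layout the comment in SelfAttention.forward mentions):
_sin, _cos = fixed_pos_embedding(dim=16, seq_len=32)          # each (32, 8)
_q_demo = torch.zeros(1, 32, 4, 16)
_q_rot = apply_rotary_pos_emb(_q_demo, (_sin, _cos), offset=0)
assert _q_rot.shape == _q_demo.shape                          # rotation mixes feature pairs, shape unchanged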
def token_shift_fast(x, n_tokens=1):
size = x.size()[-1] // (n_tokens + 1)
seq_len = x.size()[-2]
padded_x = nn.functional.pad(x[:, :, -size:], (0, 0, n_tokens, 0))
token_shifts = [padded_x[:, offset:(offset + seq_len)] for offset in range(n_tokens)]
current_x = [x[:, :, len(token_shifts) * size:]]
x = torch.cat(token_shifts + current_x, dim=-1)
return x
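# Shape-only sketch of token_shift_fast: each output position concatenates the previous
# position's tail features (zero-padded at t=0) with its own tail features (toy sizes):
_h_demo = torch.arange(2 * 4 * 8, dtype=torch.float32).view(2, 4, 8)   # [batch, seq, hidden]
_shifted = token_shift_fast(_h_demo, n_tokens=1)
assert _shifted.shape == _h_demo.shape
assert torch.equal(_shifted[:, 0, :4], torch.zeros(2, 4))              # first position only sees padding in the shifted slice
assert torch.equal(_shifted[:, 1:, :4], _h_demo[:, :-1, 4:])           # shifted half comes from the previous token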
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config, small_attn=False):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = config.hidden_dim // config.n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
self.q_only = config.q_only
self.small_attn = small_attn
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
if config.q_only:
self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
if small_attn:
self.q_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
if small_attn:
self.out_proj = nn.Linear(self.head_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
if self.small_attn:
query = self.q_proj(x).view(B, S, 1, self.head_dim)
else:
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
if self.q_only:
key = self.k_proj(x).view(B, S, 1, self.head_dim)
value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
else:
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
else:
offset = 0
if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, :self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
#causal mask with generation in mind
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
        x = x.transpose(1, 2).contiguous().view(B, S, -1)  # -1 keeps this valid for small_attn, where only a single head_dim-wide head is used
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
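# Incremental decoding sketch for the cache plumbing above (toy sizes; SimpleNamespace only
# stands in for the real config object so attribute access works, and CPU/fp32 are used so
# the sketch does not need a GPU):
from types import SimpleNamespace
_cfg_demo = SimpleNamespace(hidden_dim=64, n_head=4, q_only=False,
                            device=torch.device("cpu"), dtype=torch.float32)
_attn_demo = SelfAttention(_cfg_demo)
_prompt = torch.zeros(1, 8, 64)
_out, _kv = _attn_demo(_prompt, cache=True)                          # prime the key/value cache
_step, _kv = _attn_demo(torch.zeros(1, 1, 64), kv=_kv, cache=True)   # one-token generation step
assert _kv[0].shape[-2] == 9                                         # cached keys now cover prompt + new token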
class FeedForward(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.config = config
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
if self.config.token_shift:
x = token_shift_fast(x)
x = self.ff2(x)
return x
class GPTJLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.config = config
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
#self.ln_preattn = nn.LogSoftmax(dim=-2)
self.ff = ff(config)
self.attn = attn(config)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv, cache)
#attn_out, kv = self.attn(x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
if interleaving_layers and layer_id % every_n == 0:
if self.tick:
hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x, kv
class GPTJnoattnLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.config = config
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ff = ff(config)
self.attn = attn(config, small_attn=True)
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv, cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
return x, kv
class GPTJModel(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
nn.Module.__init__(self)
self.default_config = {
'n_layer': 6,
'n_head': 8,
'n_tokens': 2048,
'hidden_dim': 512,
'vocab_dim': 50400,
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': GPTJLayer,
'AlternateLayer': GPTJnoattnLayer,
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
}
#configuration
self.user_config = user_config
self.config = self.configure_model()
config = self.config
#modeling
self.n_layer = config.n_layer
self.hidden_dim = config.hidden_dim
self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
for i in range(config.n_layer // 2):
config.layer_idx = i
self.layers.append(
config.Layer(
attn=config.SelfAttention,
ff=config.FeedForward,
config=config,
)
)
self.layers.append(
config.AlternateLayer(
attn=config.SelfAttention,
ff=config.FeedForward,
config=config,
)
)
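# Construction sketch for the alternating-layer model above (hedged: q_only and token_shift
# are read by SelfAttention/FeedForward but are not in default_config, so the caller supplies
# them; device/dtype are overridden here only to keep the sketch CPU-friendly, assuming
# user_config overrides the defaults via configure_model):
fast_example = GPTJModel({"n_layer": 2, "n_head": 8, "hidden_dim": 512,
                          "q_only": True, "token_shift": False,
                          "device": torch.device("cpu"), "dtype": torch.float32})
# layers alternate GPTJLayer (full attention) with GPTJnoattnLayer ("small", single-head attention)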
@@ -59,12 +59,18 @@ class SelfAttention(nn.Module):
         self.rotary_dim = self.head_dim // 4
         self.hidden_dim = config.hidden_dim
         self.n_head = config.n_head
+        self.q_only = config.q_only
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
         self.register_buffer("bias", bias)
         self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
-        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
-        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        if config.q_only:
+            self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+            self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+        else:
+            self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
+            self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
@@ -76,8 +82,12 @@ class SelfAttention(nn.Module):
         # split heads into: [batch, head, sequence, head_dim]
         # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
         query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
-        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
-        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        if self.q_only:
+            key = self.k_proj(x).view(B, S, 1, self.head_dim)
+            value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
+        else:
+            key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
+            value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
         if kv:
             offset = kv[0].shape[-2]
@@ -147,6 +157,7 @@ class GPTJLayer(nn.Module):
     def __init__(self, attn, ff, config):
         nn.Module.__init__(self)
         self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
+        #self.ln_preattn = nn.LogSoftmax(dim=-2)
         self.ff = ff(config)
         self.attn = attn(config)
         self.tick = True
...
from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
x = torch.empty(0)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class CrossAttentionMod(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
def forward(self, x, y, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #latent query
key = self.k_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context key
value = self.v_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context value
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
            key = torch.cat([k, key], dim=-2) # cat key
            value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x
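# Usage sketch for CrossAttentionMod: the latent stream x queries the context stream y
# (toy sizes; both are given the same sequence length since the causal mask is sliced with
# matching query/key lengths, and with cache=False only the mixed hidden states come back):
_xattn_demo = CrossAttentionMod(hidden_dim=64, n_head=4, device=torch.device("cpu"), dtype=torch.float32)
_latent = torch.zeros(1, 8, 64)
_context = torch.zeros(1, 8, 64)
_mixed = _xattn_demo(_latent, _context)   # -> (1, 8, 64)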
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = config.hidden_dim // config.n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
self.q_only = config.q_only
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
if config.q_only:
self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
if self.q_only:
key = self.k_proj(x).view(B, S, 1, self.head_dim)
value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
else:
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
else:
offset = 0
if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, :self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
#causal mask with generation in mind
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
class FeedForward(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class PerceiverARLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ff = ff(config)
self.attn = attn(config)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv, cache)
#attn_out, kv = self.attn(x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
if interleaving_layers and layer_id % every_n == 0:
if self.tick:
hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x, kv
class PerceiverARModel(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
self.default_config = {
'n_layer': 6,
'n_head': 8,
'n_tokens': 2048,
'hidden_dim': 512,
'vocab_dim': 50400,
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': PerceiverARLayer,
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
}
base_lm.BaseModel.__init__(self, user_config, **kwargs)
@@ -5,7 +5,7 @@ from time import perf_counter, perf_counter_ns
 import numpy as np
 from tqdm import tqdm
 from contextlib import contextmanager
-from basedformer.hypernet import *
+from basedformer.models.hypernet import *
 import sys
 #replicating timeit magic function of ipython
 def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
...
from basedformer import models, utils
import torch
# This first config (GPT-J-6B-like shape) is immediately shadowed by the assignment below
# and never used.
config = {
    "n_layer": 28,
    "n_head": 16,
    "hidden_dim": 4096,
}

# Config actually used for the benchmark below.
config = {
    "n_layer": 40,
    "n_head": 40,
    "hidden_dim": 5120,
}
config_q = {**config, "q_only":True}
#init param matched GPT
gpt = models.fairseq.GPTFairModel(config).cuda().half()
utils.print_parameters(gpt)
bsz = 3
cached_seq = 1000
y = torch.randint(0, 50256, (bsz, cached_seq)).long().cuda()
x = torch.randint(0, 50256, (bsz, 1)).long().cuda()
cache_f = torch.rand(bsz, config["n_head"], cached_seq, config["hidden_dim"]//config["n_head"]).cuda().half()
cache_f = (cache_f, cache_f)
cache_f = [cache_f for _ in range(config["n_layer"])]
print(len(cache_f))
print(cache_f[0][1].shape)
######
cache_q = torch.rand(bsz, 1, cached_seq, config["hidden_dim"]//config["n_head"]).cuda().half()
cache_q = (cache_q, cache_q)
cache_q = [cache_q for _ in range(config["n_layer"])]
print(cache_q[0][0].shape)
with torch.no_grad():
#print("Initial Context GPT:")
#utils.timeit(func=lambda: gpt(y), r=10, n=10)
out = gpt(y, cache=True)
print(out[1][0][0].shape)
print("GPT")
utils.timeit(func=lambda: gpt(x, kv=cache_f), r=10, n=10)
'''
del gpt
#init param matched Q-Only
gpt_q = models.gptj.GPTJModel(config_q).cuda().half()
utils.print_parameters(gpt_q)
with torch.no_grad():
#print("Initial Context GPT-Q:")
#utils.timeit(func=lambda: gpt_q(y), r=10, n=10)
out_q = gpt_q(y, cache=True)
print("GPT-Q:")
utils.timeit(func=lambda: gpt_q(x, kv=cache_q), r=10, n=10)
'''
\ No newline at end of file
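# Rough per-layer cache-size comparison implied by the two cache shapes above (fp16):
kv_pair = 2                                        # a key tensor and a value tensor
fp16_bytes = 2
head_dim = config["hidden_dim"] // config["n_head"]
full_cache_bytes = kv_pair * bsz * config["n_head"] * cached_seq * head_dim * fp16_bytes
q_only_cache_bytes = kv_pair * bsz * 1 * cached_seq * head_dim * fp16_bytes
print(full_cache_bytes // q_only_cache_bytes)      # == n_head: the q-only cache is n_head times smaller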
@@ -5,7 +5,7 @@ import torch.cuda.amp as amp
 import torch.optim as optim
 from pathlib import Path
 from torch.utils import data
-from basedformer import optimizer, utils, gptj, noemblm, gpt2
+from basedformer import optimizer, utils, models, lm_utils
 import yaml
 import sys
 from tqdm import tqdm
@@ -14,12 +14,17 @@ import wandb
 import numpy as np
 import os
+def softmax_activation(x):
+    return F.log_softmax(x, dim=-1)
 model_config = {
-    "n_layer": 3,
-    "n_head": 16,
-    "hidden_dim": 1024,
+    "n_layer": 12,
+    "n_head": 12,
+    "hidden_dim": 768,
     "vocab_dim": 50400,
     "eps": 1e-5,
+    "q_only": True,
+    "activation": torch.nn.GELU(),
 }
 # we need 250 batch size to train the small GPT.
@@ -29,9 +34,9 @@ train_config = {
     #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
     "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshift-superhighlr-residualgate-3L-16A-1024H",
     "do_save": False,
-    "run_name": "gptj-nopos-tokenshift-superhighlr-owt2-72M-3L-16A-1024H-fp16AMP-512ctx-16bs-1e-4lrinit",
-    "lr": 1e-3,
-    "end_lr": 1e-3,
+    "run_name": "gptj-owt2-512ctx-12L-12H-768H-16bs-1e-4lr-q-only-smallattneveryotherlayer",
+    "lr": 1e-4,
+    "end_lr": 1e-4,
     "warmup_steps": 100,
     "bs": 16,
     "gas": 1,
@@ -47,7 +52,8 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
 #model = GPTModel.gpt2_init(model_config).cuda().float()
-model = noemblm.GPTNEBaseLM.init(model_config).cuda().float()
+model = lm_utils.init(models.fast.GPTJModel, model_config).cuda().float()
+utils.print_parameters(model)
 model.train()
 cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
@@ -87,7 +93,7 @@ for input_ids, labels in t:
     loss = 0
     for x in range(train_config["gas"]):
         with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
+            logits = model(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
             #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             #roll down the sequence
             logits = logits.view(-1, logits.shape[-1])
@@ -117,7 +123,7 @@ for input_ids, labels in t:
     opt.zero_grad()
     sec_per_step = (time.perf_counter() - timex)
     step_per_sec = (1. / sec_per_step)
-    tokens_per_sec = (step_per_sec * 1024) * bs * gas
+    tokens_per_sec = (step_per_sec * 512) * bs * gas
     t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
     wandb.log(
         {
...