Commit da39346c authored by novelailab's avatar novelailab

initial commit

parent 9287a251
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
def no_init(loading_code):
def dummy(self):
return
modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
original = {}
for mod in modules:
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]
return result
#TODO: Might change with non einsum functions?
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
x = torch.empty(0)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), axis=-1)
return rearrange(x, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
class FeedForward(nn.Module):
def __init__(self, dim=768, hidden_dim=768*4, activation=nn.GELU):
self.ff1 = nn.Linear(dim, hidden_dim)
self.ff2 = nn.Linear(hidden_dim, dim)
self.activation = activation()
def forward(self, x):
x = self.ff1(x)
x = self.activation(x)
x = self.ff2(x)
return x
def _split_heads(self, tensor, num_heads, attn_head_size, rotary):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
if rotary:
return tensor
if len(tensor.shape) == 5:
return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
elif len(tensor.shape) == 4:
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.Softmax(dim=-1)(attn_weights)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head):
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias)
self.rotary_dim = self.head_dim
# TODO: handle rotary
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x):
query = self.q_proj(x)
key = self.k_proj(x)
value = self.v_proj(x)
query = _split_heads(query, self.n_head, self.head_dim, True)
key = _split_heads(key, self.n_head, self.head_dim, True)
value = _split_heads(value, self.n_head, self.head_dim, False)
offset = 0
key = self.apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = self.apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, self.scale_attn
)
x = _merge_heads(x, self.num_heads, self.head_dim)
x = self.out_proj(x)
return x # a, present, (attentions)
class GPTLayer(nn.Module):
def __init__(self, attn=SelfAttention, ff=FeedForward, hidden_dim=768, n_head=4, eps=1e-6, activation=nn.GELU):
self.hidden_dim = hidden_dim
self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps)
#self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps)
self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation)
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head)
def forward(self, x, hypernetwork):
residual = x
x = self.ln_preattn(x)
if hypernetwork:
hyper_out = hypernetwork(x)
attn_out = self.attn(x)
ff_out = self.ff(x)
x = residual + attn_out + ff_out + (hyper_out if hyper_out is not None else 0)
return x
# Can access and change every module from here, as both Layer class and ff and attn classes are passed from GPTModel.
class GPTModel(nn.Module):
def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=GPTLayer):
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps)
self.layers = nn.ModuleList([])
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation))
#TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
#TODO: Do we want to pass a config object everywhere?
def forward(self, x, hypernetwork=None):
x = self.vocab_embed(x)
for layer in self.layers:
x = layer(x, hypernetwork)
x = self.ln_final(x)
return x
def load(self, path):
state_dict = torch.load(path)
self.load_state_dict(state_dict)
#TODO: Get SplitCheckpoint support
def save(self, path):
torch.save(self.state_dict(), path)
#TODO: Get SplitCheckpoint support
# TODO: Do we want to have the LM head as a seperate Class? Or just a function? I think we might be better off with a function here and maybe
# also for the self attention, we can just write a function that gets fed in the q, k, v.
class GPTLM(nn.Module):
def __init__(self):
return
def forward(self, x):
return
def load_gpt_j(path):
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-4,
"activation": nn.GELU,
"Layer": GPTLayer
}
model = no_init(lambda: GPTModel(**config))
model.load(path)
return model
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment