Commit 482a7bae authored by novelailab

update

parent b39e6d1b
@@ -28,7 +28,7 @@ def init_weights(model, n_layer):
def init(model_class, config):
model = model_class(config)
model.init_weights()
init_weights(model, config["n_layer"])
return model
def no_init(model_class, config):
......
@@ -2,12 +2,15 @@ from . import gptj
from . import gpt2
from . import fairseq
from . import gptneo
from . import alibi
from . import fast
MODEL_MAP = {
"gptj": gptj.GPTJModel,
"gpt2": gpt2.GPT2Model,
"gpt-fairseq": fairseq.GPTFairModel,
"gpt-neo": gptneo.GPTNeoModel
"gpt-neo": gptneo.GPTNeoModel,
"alibi": alibi.AlibiModel,
}
def get_model(model_name: str):
......
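# Hedged usage sketch for the new registry entry (not part of the original file): assuming
# get_model, whose body is elided above, simply indexes MODEL_MAP by name, the "alibi" key
# added in this commit resolves to the new ALiBi model class.
from basedformer import models

model_class = models.get_model("alibi")
assert model_class is models.alibi.AlibiModel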
from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
def get_slopes(n):
def get_slopes_power_of_2(n):
start = (2**(-2**-(math.log2(n)-3)))
ratio = start
return [start*ratio**i for i in range(n)]
# In the ALiBi paper, only models with 2^a heads are trained; get_slopes_power_of_2 has some
# nice properties that only hold when n is a power of 2. To keep those properties when the
# head count is not a power of 2, fall back to the closest power of 2 and fill the remaining
# slots with every other slope of the next power of 2.
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]
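# Illustrative values for the function above (not part of the original file): for a
# power-of-two head count the slopes form a clean geometric series; for other counts the
# workaround pads with every other slope of the next power of two.
print(get_slopes(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
print(len(get_slopes(12)))  # 12 -> 8 slopes from get_slopes_power_of_2(8) + 4 interleaved from get_slopes(16)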
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
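# Shape sketch for _attn (illustrative, not part of the original file; reuses this module's
# torch import, and the sizes are made up). The causal mask is sliced the same way
# SelfAttention.forward does below for incremental decoding.
B, H, Q, K, D = 2, 4, 1, 8, 16  # batch, heads, query length, key length, head_dim
_query = torch.randn(B, H, Q, D)
_key = torch.randn(B, H, K, D)
_value = torch.randn(B, H, K, D)
# causal mask covering the last Q query positions of a K-long sequence
_full_mask = torch.tril(torch.ones(K, K, dtype=torch.bool)).view(1, 1, K, K)
_causal = _full_mask[:, :, K - Q:K, :K]
_out = _attn(_query, _key, _value, _causal, masked_bias=torch.tensor(-1e9),
             attention_mask=None, scale_attn=torch.sqrt(torch.tensor(float(D))))
print(_out.shape)  # torch.Size([2, 4, 1, 16])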
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.register_buffer("slopes", torch.Tensor(get_slopes(config.n_head)))
#In the next line, the part after the * is what constructs the diagonal matrix (right matrix in Figure 3 in the paper).
#If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3, but one where all rows are identical.
#This works because the softmax operation is invariant to translation, and our bias functions are always linear.
print(self.slopes.shape)
self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_positions).unsqueeze(0).unsqueeze(0).expand(config.n_head, -1, -1)
self.alibi = self.alibi.view(config.n_head, 1, max_positions)
self.head_dim = config.hidden_dim // config.n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
self.q_only = config.q_only
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
if config.q_only:
self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# q, k, v are reshaped and transposed immediately since this model has no rotary embedding
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if self.q_only:
key = self.k_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
else:
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
else:
offset = 0
if kv:
k, v = kv
# cat key and value (everything except the newly added token comes from the cache),
# so the query can attend to the whole sequence.
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
# causal mask with generation in mind
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
# slice the precomputed ALiBi bias to the current key length and apply it as an additive
# attention bias; it broadcasts over batch and query positions, so no per-batch repeat is needed
alibi_bias = self.alibi[:, :, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, alibi_bias, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
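# Why q_only matters (illustrative sketch, not part of the original file): with a single
# shared key/value head the per-token KV cache shrinks by a factor of n_head. The numbers
# below use the 28-layer / 16-head / 4096-dim config from the benchmark script in this commit.
def _kv_cache_bytes_per_token(kv_heads, head_dim, n_layer, bytes_per_elem=2):
    # one key tensor + one value tensor per layer, fp16 elements
    return 2 * kv_heads * head_dim * bytes_per_elem * n_layer

_full = _kv_cache_bytes_per_token(kv_heads=16, head_dim=4096 // 16, n_layer=28)  # regular attention
_mqa = _kv_cache_bytes_per_token(kv_heads=1, head_dim=4096 // 16, n_layer=28)    # q_only=True
print(_full, _mqa, _full // _mqa)  # 458752 28672 16 -> 16x smaller cache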
class FeedForward(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
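# What act_ck buys (illustrative sketch, not part of the original file): wrapping only the
# activation in torch.utils.checkpoint means the intermediates created inside the activation
# are recomputed during backward instead of being stored, at the cost of one extra activation
# forward. Standalone example using the same `ck` import as this module:
_x = torch.randn(4, 128, requires_grad=True)
_w = torch.randn(128, 512, requires_grad=True)
_h = _x @ _w                 # pre-activation, analogous to ff1's output
_y = ck(F.gelu, _h)          # recompute the activation in backward rather than saving its internals
_y.sum().backward()
print(_x.grad.shape, _w.grad.shape)  # torch.Size([4, 128]) torch.Size([128, 512])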
class AlibiLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
#self.ln_preattn = nn.LogSoftmax(dim=-2)
self.ff = ff(config)
self.attn = attn(config)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv, cache)
#attn_out, kv = self.attn(x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
if interleaving_layers and layer_id % every_n == 0:
if self.tick:
hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
# Note: the order of floating-point addition matters here;
# x = residual + attn_out + ff_out does not exactly match.
x = attn_out + ff_out + residual
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x, kv
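# Structure note (illustrative sketch, not part of the original file): this is a GPT-J-style
# parallel-residual block -- one pre-attention LayerNorm feeds both the attention and the
# feed-forward branch, and both outputs are added back to the residual in a single step.
# Minimal version with the hypernetwork and checkpointing branches stripped out:
def _parallel_block(x, ln, attn, ff):
    residual = x
    h = ln(x)                  # single shared pre-attention LayerNorm
    attn_out, _ = attn(h)      # attention branch
    ff_out = ff(h)             # feed-forward branch reads the same normalized input
    return attn_out + ff_out + residual  # same summation order as forward() above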
class AlibiModel(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
self.default_config = {
'n_layer': 6,
'n_head': 8,
'n_tokens': 2048,
'hidden_dim': 512,
'vocab_dim': 50400,
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': AlibiLayer,
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
}
base_lm.BaseModel.__init__(self, user_config, **kwargs)
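# End-to-end usage sketch (illustrative, not part of the original file). It assumes
# base_lm.BaseModel fills the remaining defaults (device, dtype, Layer, ...) from
# default_config above, and that calling the model returns logits the way the training
# script in this commit does. Config values are arbitrary small numbers.
if __name__ == "__main__":
    from basedformer import models, lm_utils

    alibi_config = {
        "n_layer": 6,
        "n_head": 8,
        "hidden_dim": 512,
        "vocab_dim": 50400,
        "q_only": False,
    }
    model = lm_utils.init(models.alibi.AlibiModel, alibi_config).cuda().float()

    tokens = torch.randint(0, 50400, (1, 32)).cuda()
    with torch.no_grad():
        logits = model(tokens)
    print(logits.shape)  # expected: (1, 32, 50400)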
This diff is collapsed.
@@ -59,12 +59,18 @@ class SelfAttention(nn.Module):
self.rotary_dim = self.head_dim // 4
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
self.q_only = config.q_only
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
if config.q_only:
self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
@@ -76,8 +82,12 @@ class SelfAttention(nn.Module):
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if self.q_only:
key = self.k_proj(x).view(B, S, 1, self.head_dim)
value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
else:
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
@@ -147,6 +157,7 @@ class GPTJLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
#self.ln_preattn = nn.LogSoftmax(dim=-2)
self.ff = ff(config)
self.attn = attn(config)
self.tick = True
......
This diff is collapsed.
@@ -5,7 +5,7 @@ from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
from basedformer.hypernet import *
from basedformer.models.hypernet import *
import sys
#replicating timeit magic function of ipython
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
......
from basedformer import models, utils
import torch
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
}
config = {
"n_layer": 40,
"n_head": 40,
"hidden_dim": 5120,
}
config_q = {**config, "q_only":True}
#init param matched GPT
gpt = models.fairseq.GPTFairModel(config).cuda().half()
utils.print_parameters(gpt)
bsz = 3
cached_seq = 1000
y = torch.randint(0, 50256, (bsz, cached_seq)).long().cuda()
x = torch.randint(0, 50256, (bsz, 1)).long().cuda()
cache_f = torch.rand(bsz, config["n_head"], cached_seq, config["hidden_dim"]//config["n_head"]).cuda().half()
cache_f = (cache_f, cache_f)
cache_f = [cache_f for _ in range(config["n_layer"])]
print(len(cache_f))
print(cache_f[0][1].shape)
######
cache_q = torch.rand(bsz, 1, cached_seq, config["hidden_dim"]//config["n_head"]).cuda().half()
cache_q = (cache_q, cache_q)
cache_q = [cache_q for _ in range(config["n_layer"])]
print(cache_q[0][0].shape)
with torch.no_grad():
#print("Initial Context GPT:")
#utils.timeit(func=lambda: gpt(y), r=10, n=10)
out = gpt(y, cache=True)
print(out[1][0][0].shape)
print("GPT")
utils.timeit(func=lambda: gpt(x, kv=cache_f), r=10, n=10)
'''
del gpt
#init param matched Q-Only
gpt_q = models.gptj.GPTJModel(config_q).cuda().half()
utils.print_parameters(gpt_q)
with torch.no_grad():
#print("Initial Context GPT-Q:")
#utils.timeit(func=lambda: gpt_q(y), r=10, n=10)
out_q = gpt_q(y, cache=True)
print("GPT-Q:")
utils.timeit(func=lambda: gpt_q(x, kv=cache_q), r=10, n=10)
'''
\ No newline at end of file
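# Rough fp16 footprint of the two caches above (illustrative arithmetic, not part of the
# original file), using the second config: 40 layers, 40 heads, 5120 hidden, bsz 3, and
# 1000 cached tokens.
_head_dim = 5120 // 40                                    # 128
_full_cache = 40 * 2 * 3 * 40 * 1000 * _head_dim * 2      # layers * (k+v) * bsz * heads * seq * head_dim * 2 bytes
_q_only_cache = 40 * 2 * 3 * 1 * 1000 * _head_dim * 2     # a single shared KV head
print(f"{_full_cache / 2**30:.2f} GiB vs {_q_only_cache / 2**20:.2f} MiB")  # ~2.29 GiB vs ~58.59 MiB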
@@ -5,7 +5,7 @@ import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from torch.utils import data
from basedformer import optimizer, utils, gptj, noemblm, gpt2
from basedformer import optimizer, utils, models, lm_utils
import yaml
import sys
from tqdm import tqdm
@@ -14,12 +14,17 @@ import wandb
import numpy as np
import os
def softmax_activation(x):
return F.log_softmax(x, dim=-1)
model_config = {
"n_layer": 3,
"n_head": 16,
"hidden_dim": 1024,
"n_layer": 12,
"n_head": 12,
"hidden_dim": 768,
"vocab_dim": 50400,
"eps": 1e-5,
"q_only": True,
"activation": torch.nn.GELU(),
}
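# Parameter-count note for the config above (illustrative arithmetic, not part of the original
# file): with q_only=True the key/value projections collapse to a single head, so per-layer
# attention weights drop from 4*hidden^2 to 2*hidden^2 + 2*hidden*head_dim (the projections are
# bias-free in this codebase).
_hidden, _n_head = 768, 12
_head_dim = _hidden // _n_head                                  # 64
_full_attn = 4 * _hidden * _hidden                              # q, k, v, out projections
_q_only_attn = 2 * _hidden * _hidden + 2 * _hidden * _head_dim  # full q/out, single-head k/v
print(_full_attn, _q_only_attn)  # 2359296 1277952  (~46% fewer attention weights per layer)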
# we need 250 batch size to train the small GPT.
@@ -29,9 +34,9 @@ train_config = {
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2gptj-nopos-tokenshift-superhighlr-residualgate-3L-16A-1024H",
"do_save": False,
"run_name": "gptj-nopos-tokenshift-superhighlr-owt2-72M-3L-16A-1024H-fp16AMP-512ctx-16bs-1e-4lrinit",
"lr": 1e-3,
"end_lr": 1e-3,
"run_name": "gptj-owt2-512ctx-12L-12H-768H-16bs-1e-4lr-q-only-smallattneveryotherlayer",
"lr": 1e-4,
"end_lr": 1e-4,
"warmup_steps": 100,
"bs": 16,
"gas": 1,
@@ -47,7 +52,8 @@ gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = noemblm.GPTNEBaseLM.init(model_config).cuda().float()
model = lm_utils.init(models.fast.GPTJModel, model_config).cuda().float()
utils.print_parameters(model)
model.train()
cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
@@ -87,7 +93,7 @@ for input_ids, labels in t:
loss = 0
for x in range(train_config["gas"]):
with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
logits = model(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
#print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
#roll down the sequence
logits = logits.view(-1, logits.shape[-1])
@@ -117,7 +123,7 @@ for input_ids, labels in t:
opt.zero_grad()
sec_per_step = (time.perf_counter() - timex)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 1024) * bs * gas
tokens_per_sec = (step_per_sec * 512) * bs * gas
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log(
{
......