Commit 1f8d04b2 authored by novelailab

implement cached generation

parent 3a52a64a
......@@ -30,34 +30,6 @@ def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
def _split_heads(tensor, num_heads, attn_head_size, rotary):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
if rotary:
return tensor
if len(tensor.shape) == 5:
return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
elif len(tensor.shape) == 4:
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def _merge_heads(tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
......@@ -75,6 +47,54 @@ def _attn(query, key, value, causal_mask, masked_bias,
return attn_output
class CrossAttentionMod(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
def forward(self, x, y, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #latent query
key = self.k_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context key
value = self.v_proj(y).view(B, S, self.n_head, self.head_dim).transpose(1, 2) #context value
if kv:
k, v = kv
# concatenate the cached keys/values with this step's key/value so the query
# can attend to the full sequence (everything except the newest token comes from the cache)
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2) # during cached generation query_length can be 1 while key_length spans the whole sequence
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x
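For reference, a minimal usage sketch of this cached path (module sizes, tensor shapes, and variable names are illustrative, not from the commit): the first call runs over the full context with cache=True to obtain the key/value pair, and a later call passes it back as kv so only the newest token is projected.

# Hypothetical usage sketch for the cached cross-attention path (shapes are illustrative)
import torch
attn = CrossAttentionMod(hidden_dim=256, n_head=4, device="cpu", dtype=torch.float32)
x = torch.randn(1, 8, 256)          # latent queries for 8 positions
y = torch.randn(1, 8, 256)          # context of 8 positions
out, kv = attn(x, y, cache=True)    # full pass, cache filled with the 8 keys/values
x_step = torch.randn(1, 1, 256)     # a single new latent position
y_step = torch.randn(1, 1, 256)     # and its single new context token
out_step, kv = attn(x_step, y_step, kv=kv, cache=True)  # reuses and extends the cache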
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
......@@ -98,14 +118,13 @@ class SelfAttention(nn.Module):
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x):
query = self.q_proj(x)
key = self.k_proj(x)
value = self.v_proj(x)
query = _split_heads(query, self.n_head, self.head_dim, True)
key = _split_heads(key, self.n_head, self.head_dim, True)
value = _split_heads(value, self.n_head, self.head_dim, False)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
offset = 0
if self.rotary_dim < self.head_dim:
......@@ -124,21 +143,32 @@ class SelfAttention(nn.Module):
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
if kv:
k, v = kv
# concatenate the cached keys/values with this step's key/value so the query
# can attend to the full sequence (everything except the newest token comes from the cache)
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
# slice the causal mask so the (possibly single) new query rows line up with the full cached-plus-new key length
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = _merge_heads(x, self.n_head, self.head_dim)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
return x
if cache:
return x, (key, value)
else:
return x
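To make the causal-mask slicing concrete, here is a small worked example (the numbers are illustrative only): with ten prompt tokens already cached and one new token arriving, query_length is 1 and key_length is 11, so the slice selects the last row of the lower-triangular bias and the new position can attend to every cached key plus itself.

# Worked example of the causal-mask slice for one cached decoding step (numbers are examples)
import torch
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(1, 1, max_positions, max_positions).bool()
query_length, key_length = 1, 11     # 1 new query token, 10 cached keys + 1 new key
mask = bias[:, :, key_length - query_length:key_length, :key_length]
# mask.shape == (1, 1, 1, 11) and every entry is True, so the new position
# attends to all cached keys as well as to itself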
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
......@@ -156,7 +186,7 @@ class FeedForward(nn.Module):
x = self.ff2(x)
return x
class CrossGPTLayer(nn.Module):
class GPTJLayer(nn.Module):
def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
nn.Module.__init__(self)
self.hidden_dim = hidden_dim
......@@ -202,8 +232,8 @@ class CrossGPTLayer(nn.Module):
return x
class CrossGPTModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=CrossGPTLayer, device="cuda", dtype=torch.float16, **kwargs):
class GPTJModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):
nn.Module.__init__(self)
self.n_layer = n_layer
self.hidden_dim = hidden_dim
......@@ -226,11 +256,11 @@ class CrossGPTModel(nn.Module):
x = self.ln_final(x)
return x
class CrossGPTBaseLM(lm_base.BaseLM):
class GPTJBaseLM(lm_base.BaseLM):
def __init__(self, config=None, lm=None):
nn.Module.__init__(self)
lm_base.BaseLM.__init__(self, config, lm)
self.model_class=CrossGPTModel
self.model_class=GPTJModel
def load_gpt_j(path="models/6b", state_dict=None):
config = {
......@@ -240,5 +270,5 @@ def load_gpt_j(path="models/6b", state_dict=None):
"vocab_dim": 50400,
"eps": 1e-5
}
model = CrossGPTBaseLM.load(config, path, state_dict)
model = GPTJBaseLM.load(config, path, state_dict)
return model
from typing import KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -78,7 +79,11 @@ class SelfAttention(nn.Module):
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
offset = 0
if kv:
# offset the rotary embeddings by the number of cached positions so the new token is rotated at its true absolute position
offset = kv[0].shape[-2]
else:
offset = 0
if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
......@@ -103,8 +108,8 @@ class SelfAttention(nn.Module):
k, v = kv
# concatenate the cached keys/values with this step's key/value so the query
# can attend to the full sequence (everything except the newest token comes from the cache)
torch.cat([k, key], dim=-2) # cat key
torch.cat([v, value], dim=-2) # cat value
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
# slice the causal mask so the (possibly single) new query rows line up with the full cached-plus-new key length
......@@ -120,7 +125,7 @@ class SelfAttention(nn.Module):
if cache:
return x, (key, value)
else:
return x
return x, None
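A small illustration of the rotary offset introduced above (values are examples, not from the commit): the offset equals the number of cached positions, so the single incoming token is rotated with the sin/cos of its true absolute position instead of position 0.

# Illustrative effect of the rotary offset during cached decoding (values are examples)
cached_len = 10          # kv[0].shape[-2]: ten positions already in the cache
offset = cached_len      # the single new token sits at absolute position 10
# for a one-token step x.shape[1] == 1, so apply_rotary_pos_emb slices
# sincos[offset : offset + 1] == positions [10, 11), i.e. the sin/cos values
# for absolute position 10 rather than for position 0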
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
......@@ -147,16 +152,15 @@ class GPTJLayer(nn.Module):
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out = ck(self.attn, x)
attn_out, kv = ck(self.attn, x, kv, cache) # positional args: torch's checkpoint does not forward keyword arguments to the wrapped module
else:
x = self.ln_preattn(x)
attn_out = self.attn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
......@@ -182,7 +186,7 @@ class GPTJLayer(nn.Module):
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x
return x, kv
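Since the layer now always returns a pair, callers unpack (x, kv) whether or not caching is enabled. A short sketch of the call pattern (sizes and names are illustrative):

# Illustrative call pattern for the updated layer interface (sizes are examples)
import torch
layer = GPTJLayer(attn=SelfAttention, ff=FeedForward, hidden_dim=256, n_head=4,
                  eps=1e-5, activation=gelu_new, device="cpu", dtype=torch.float32)
prompt = torch.randn(1, 8, 256)                       # 8 embedded prompt positions
step = torch.randn(1, 1, 256)                         # 1 new embedded token
out, _ = layer(prompt, layer_id=0)                    # kv comes back as None when cache=False
out, kv = layer(prompt, layer_id=0, cache=True)       # first pass fills the cache
out, kv = layer(step, layer_id=0, kv=kv, cache=True)  # later step reuses and extends it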
class GPTJModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):
......@@ -196,17 +200,30 @@ class GPTJModel(nn.Module):
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
def forward(self, x, hypernetwork=None, act_ck=False):
x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
x = self.lm_head(x)
return x.float()
if kv:
return x.float(), kv
else:
return x.float()
def get_embeds(self, x, hypernetwork=None, act_ck=False):
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
kv = [None] * self.n_layer
kv_new = []
x = self.vocab_embed(x)
for layer_id, layer in enumerate(self.layers):
x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
kv_new.append(kvi)
x = self.ln_final(x)
return x
if cache:
return x, kv_new
else:
return x, None
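The cache handed back by get_embeds is a plain Python list with one (key, value) pair per layer. A sketch of what a caller can expect (the model sizes and names are illustrative):

# Illustrative structure of the cache returned when cache=True (sizes are examples)
import torch
model = GPTJModel(hidden_dim=256, n_layer=2, n_head=4, vocab_dim=50400, eps=1e-5,
                  device="cpu", dtype=torch.float32)
tokens = torch.randint(0, 50400, (1, 8))
hidden, kv = model.get_embeds(tokens, cache=True)
assert len(kv) == model.n_layer          # one (key, value) pair per layer
key, value = kv[0]
# key.shape == value.shape == (1, n_head, 8, head_dim); passing kv back on the
# next call lets every layer concatenate only the newest token's key/value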
class GPTJBaseLM(lm_base.BaseLM):
def __init__(self, config=None, lm=None):
......
......@@ -76,7 +76,7 @@ class BaseLM(nn.Module):
state_dict = utils.SplitCheckpoint(path, device="cuda")
model = cls(config)
model.lm = model.model_class(**config)
model.lm = utils.no_init(lambda: model.model_class(**config))
model.lm.load_state_dict(state_dict, strict=strict)
return model
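utils.no_init is not shown in this diff; a common way to implement such a helper (this is an assumption about its behavior, not the repository's code) is to temporarily replace the torch.nn.init routines with no-ops so that constructing the model skips the random weight initialization that load_state_dict overwrites anyway:

# Hypothetical sketch of a no_init helper (assumption; the actual utils.no_init may differ)
import torch

def no_init(constructor):
    init_fns = ["uniform_", "normal_", "constant_",
                "kaiming_uniform_", "kaiming_normal_",
                "xavier_uniform_", "xavier_normal_"]
    saved = {name: getattr(torch.nn.init, name) for name in init_fns}
    try:
        for name in init_fns:
            # make each init routine a no-op so freshly built weights are left untouched
            setattr(torch.nn.init, name, lambda tensor, *args, **kwargs: tensor)
        return constructor()
    finally:
        for name, fn in saved.items():
            setattr(torch.nn.init, name, fn)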
......
......@@ -31,7 +31,7 @@ env1.sh('pip install tqdm')
env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
env1.sh('pip3 install dotmap')
env1.sh('pip3 install dotmap icecream')
with always_rerun():
if bash:
path.sh("bash")
......
......@@ -81,11 +81,11 @@ with torch.no_grad():
hidden = hf_model.transformer.h[layer].ln_1(hidden)
assert torch.allclose(hf_model.transformer.h[layer].mlp(hidden), based_model.layers[layer].ff(hidden))
hidden = hf_model.transformer.h[layer].mlp(hidden)
assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden))
assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden)[0])
hidden = hf_model.transformer.h[layer].attn(hidden)[0]
assert torch.allclose(hf_model.transformer.h[layer](hidden)[0], based_model.layers[layer](hidden))
assert torch.allclose(hf_model.transformer.h[layer](hidden)[0], based_model.layers[layer](hidden)[0])
assert torch.allclose(hf_model.transformer.ln_f(hidden), based_model.ln_final(hidden))
hidden = hf_model.transformer.ln_f(hidden)
assert torch.allclose(hf_model.transformer(x)["last_hidden_state"], based_model.get_embeds(x))
assert torch.allclose(hf_model.transformer(x)["last_hidden_state"], based_model.get_embeds(x)[0])
assert torch.allclose(hf_model(x)["logits"], based_model(x))
\ No newline at end of file
from basedformer import gptj
from basedformer.utils import *
from transformers import AutoTokenizer
from icecream import ic
import time
import sys
def print_top_k(logits, tokenizer, k):
topk_ind = logits.topk(k)[1]
for x in range(topk_ind.shape[0]):
for y in range(topk_ind.shape[1]):
print("\nToken " + str(y))
for token in topk_ind[x, y, :].tolist():
print(tokenizer.decode([token]), end=" | ")
def main():
tokenizer = AutoTokenizer.from_pretrained('gpt2')
prompt = """I fucked her with my huge donut, when she seen my donut she went"""
tokens = tokenizer.encode(prompt)
print("Prompt:")
for x in range(len(tokens)):
print(tokenizer.decode([tokens[x]]), end=" | ")
print("\n Generation:")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
t = time.perf_counter()
model = gptj.load_gpt_j().cuda().half().eval()
model = model.lm
ic(time.perf_counter() - t)
with torch.no_grad():
kv = None
tokens_to_generate = 50
in_tokens = tokens
accum_tokens = []
for x in range(tokens_to_generate):
logits, kv = model(in_tokens, cache=True, kv=kv)
in_tokens = logits[:, -1, :].topk(1)[1]
#in_tokens = torch.cat([in_tokens, logits[:, -1, :].topk(1)[1]], dim=1)
print(tokenizer.decode(in_tokens.squeeze(1).tolist()[-1]), end=" | ")
#accum_tokens = torch.cat(accum_tokens, dim=1)
#accum_tokens = accum_tokens.squeeze(0).tolist()
#print("\n Final token list")
#print(tokenizer.decode(accum_tokens))
if __name__ == "__main__":
main()
\ No newline at end of file
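The commented-out accumulation in main() could be completed along these lines (a sketch, not part of the commit), assuming the same setup as the loop above (kv=None, in_tokens holding the prompt): collect each greedy pick so the whole continuation can be decoded once at the end.

# Sketch of the token accumulation hinted at by the commented-out lines above
generated = []
for _ in range(tokens_to_generate):
    logits, kv = model(in_tokens, cache=True, kv=kv)
    in_tokens = logits[:, -1, :].topk(1)[1]   # greedy pick, shape (1, 1)
    generated.append(in_tokens)
generated = torch.cat(generated, dim=1)       # (1, tokens_to_generate)
print("\n" + tokenizer.decode(generated.squeeze(0).tolist()))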