Commit 074d25c5 authored by novelailab's avatar novelailab

sampling working, gpt-neo(x) init

parent 9223ef70
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base
def shift_tokens(x, amt, eps = 1e-5):
n, device = x.shape[1], x.device
cumsum = x.cumsum(dim = 1)
*x, x_pass = x.chunk(amt + 1, dim = -1)
*x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
amts = 2 ** torch.arange(amt)
amts = amts.tolist()
shifts = []
denom = torch.arange(n, device = device)
for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
normed_shifted_x = shifted_chunk / (shifted_denom + eps)
shifts.append(normed_shifted_x)
return torch.cat((*shifts, x_pass), dim = -1)
def shift(x, amt, dim = -1):
return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
torch.cat([k, key], dim=-2) # cat key
torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
nn.Module.__init__(self)
self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
self.activation = activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class GPTNeoLayer(nn.Module):
def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
nn.Module.__init__(self)
self.hidden_dim = hidden_dim
self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out = ck(self.attn, x)
else:
x = self.ln_preattn(x)
attn_out = self.attn(x)
residual = residual + attn_out
x = self.ln_postattn(x)
ff_out = self.ff(x, act_ck)
x = residual + ff_out
return x
class GPTNeoModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoLayer, device="cuda", dtype=torch.float16, **kwargs):
nn.Module.__init__(self)
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
def forward(self, x, hypernetwork=None, act_ck=False):
x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.lm_head(x)
return x.float()
def get_embeds(self, x, hypernetwork=None, act_ck=False):
x = self.vocab_embed(x)
for layer_id, layer in enumerate(self.layers):
x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.ln_final(x)
return x
class GPTNeoBaseLM(lm_base.BaseLM):
def __init__(self, config=None, lm=None):
nn.Module.__init__(self)
lm_base.BaseLM.__init__(self, config, lm)
self.model_class=GPTNeoModel
def load_gpt_j(path="models/6b", state_dict=None):
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5
}
model = GPTNeoBaseLM.load(config, path, state_dict)
return model
from typing import KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
x = torch.empty(0)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
nn.Module.__init__(self)
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
else:
offset = 0
if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, :self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
key = torch.cat([k, key], dim=-2) # cat key
value = torch.cat([v, value], dim=-2) # cat value
query_length, key_length = query.size(-2), key.size(-2)
#causal mask with generation in mind
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
nn.Module.__init__(self)
self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
self.activation = activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class GPTNeoxLayer(nn.Module):
def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
nn.Module.__init__(self)
self.hidden_dim = hidden_dim
self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
if interleaving_layers and layer_id % every_n == 0:
if self.tick:
hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x, kv
class GPTNeoXModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoXLayer, device="cuda", dtype=torch.float16, **kwargs):
nn.Module.__init__(self)
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
x = self.lm_head(x)
if kv:
return x.float(), kv
else:
return x.float()
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
kv = [None] * self.n_layer
kv_new = []
x = self.vocab_embed(x)
for layer_id, layer in enumerate(self.layers):
x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
kv_new.append(kvi)
x = self.ln_final(x)
if cache:
return x, kv_new
else:
return x, None
class GPTNeoXBaseLM(lm_base.BaseLM):
def __init__(self, config=None, lm=None):
nn.Module.__init__(self)
lm_base.BaseLM.__init__(self, config, lm)
self.model_class=GPTNeoXModel
def load_gpt_j(path="models/6b", state_dict=None):
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5
}
model = GPTNeoXBaseLM.load(config, path, state_dict)
return model
......@@ -94,7 +94,7 @@ class SplitCheckpoint(MutableMapping):
def copy(self):
return SplitCheckpoint(self.chkpt_dir, device=self.device)
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=False):
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=True):
precision = 'ns'
r_arr = np.empty([2, r]) # [0] = mean, [1] = std
if function:
......@@ -104,8 +104,10 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
n_arr = np.empty(n)
for k in range(n):
start = time.perf_counter_ns()
if cuda_blocking:
torch.cuda.synchronize()
func()
if cuda_blocking:
torch.cuda.synchronize()
n_arr[k] = time.perf_counter_ns() - start
......
......@@ -2,6 +2,7 @@ from basedformer import gptj
from basedformer.utils import *
from transformers import AutoTokenizer
from icecream import ic
import functorch
import time
import sys
......@@ -17,19 +18,24 @@ def apply_top_k(logits, k):
# filter the logits that are not in the top-k to -inf
# keep top_k_ind and filter the rest
top_k_values = logits.topk(k)[0]
remove_mask = logits < top_k_values[:, -1].unsqueeze(0)
remove_mask = logits < top_k_values[:, -1].unsqueeze(-1)
logits[remove_mask == True] = -float("inf")
return logits
def apply_top_p(logits, p):
logits = torch.softmax(logits, dim=-1)
sorted, indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(sorted, dim=-1)
cumulative_probs = cumulative_probs.scatter(dim=-1, index=indices, src=cumulative_probs)
remove_mask = cumulative_probs > p
logits[remove_mask == True] = -float("inf")
mask_tensor = cumulative_probs > p
# Shift the indices to the right to keep also the first token above the threshold
mask_tensor[..., 1:] = mask_tensor[..., :-1].clone()
mask_tensor[..., 0] = 0
mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
logits[mask_tensor == True] = -float("inf")
return logits
def apply_tfs(logits, tfs):
logits = torch.softmax(logits, dim=-1)
sorted, indices = torch.sort(logits, descending=True)
d = sorted
d = d[:, 1:] - d[:, :-1]
......@@ -37,52 +43,211 @@ def apply_tfs(logits, tfs):
d = d.abs()
d = d / d.sum(dim=-1).view(1, -1).T
cumulative_probs = torch.cumsum(d, dim=-1)
cumulative_probs = cumulative_probs.scatter(dim=-1, index=indices, src=cumulative_probs)
remove_mask = cumulative_probs > tfs
logits[remove_mask == True] = -float("inf")
mask_tensor = torch.empty(indices.shape).cuda()
mask_tensor[:, 1:-1] = (cumulative_probs > tfs)[:, :]
# Always remove last token
mask_tensor[:, -1:] = True
# Always keep the first token
mask_tensor[:, 0] = False
mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
logits[mask_tensor == True] = -float("inf")
return logits
def temperature(logits, temperature):
def apply_typical(logits, mass=0.9):
scores = logits
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, -float("inf"))
return scores
def apply_temp(logits, temperature):
logits = logits / temperature
return logits
def generate(forward, prompt_tokens, temperature, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
def rep_pen(input_ids, scores, penalty, m=3.33, penalize_last=250,
alpha_frequency=None, alpha_presence=None, whitelist=None,
):
scores = torch.log_softmax(scores, dim=-1)
penalty = 1.0 if penalty < 1.0 else penalty
raw_penalty = penalty
penalize_last = None
if not m is None and not penalize_last is None and penalize_last >= 1:
penalty = (torch.arange(penalize_last)/(penalize_last - 1)) * 2. - 1
penalty = (m * penalty) / (1 + torch.abs(penalty) * (m - 1))
penalty = 1 + ((penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
penalize_last = penalize_last
alpha_enable = alpha_frequency is not None or alpha_presence is not None
whitelist = None
whitelist_list = None
if whitelist is not None:
whitelist_list = whitelist
##########
if whitelist is None and whitelist_list is not None:
whitelist_list = list(filter(lambda x: x >= 0 and x < scores.shape[1], whitelist_list))
if len(whitelist_list) > 0:
whitelist = torch.tensor(whitelist_list).long().sort()[0]
whitelist = whitelist.to(input_ids.device)
if whitelist is not None:
unpenalized = scores.gather(1, whitelist.view(1, -1))
if raw_penalty > 1.0:
if not penalize_last is None:
penality_len = min(input_ids.shape[1], penalize_last)
input_ids = input_ids[:, -penality_len:]
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if not penalize_last is None:
penalty = penalty.type(score.dtype).to(score.device)
score = torch.where(score < 0, score * penalty[:, -penality_len:], score / penalty[:, -penality_len:])
else:
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, input_ids, score)
if alpha_enable:
c = torch.zeros(scores.shape).long().to(input_ids.device)
# unique only returns counts for first item in batch, so manually iterate
for i in range(input_ids.shape[0]):
if penalize_last is not None:
token_input_ids, counts = torch.unique(input_ids[i,-penalize_last:], sorted=True, return_counts=True, dim=-1)
else:
token_input_ids, counts = torch.unique(input_ids[i], sorted=True, return_counts=True, dim=-1)
c[i].scatter_(0, token_input_ids, counts)
if alpha_frequency:
scores -= c * alpha_frequency
if alpha_presence:
scores[c > 0] -= alpha_presence
if whitelist is not None:
scores.scatter_(1, whitelist.view(1, -1), unpenalized)
return scores
def func_multinomial(x):
torch.manual_seed(69)
return torch.multinomial(x, 1)
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
with torch.no_grad():
in_tokens = prompt_tokens
context = prompt_tokens
print(context.shape)
kv = None
fully_deterministic = False
tokens_generated = []
soft_required = ["top_k", "top_p"]
#soft_required = ["top_k", "top_p"]
op_map = {
"top_k": apply_top_k,
"top_p": apply_top_p,
"temp": temperature,
"tfs": apply_tfs
"typical": apply_typical,
"temp": apply_temp,
"tfs": apply_tfs,
"rep_pen": rep_pen,
}
funcnomial = functorch.vmap(func_multinomial, randomness="different")
for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv)
# always work on softmax logits to make sure all models
# behave similarly as logprobs can be quite different
# TODO: can break compatibility with novelai presets.
# logits should be the last token in the sequence
logits = logits[:, -1, :]
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
#can save one softmax here by not applying softmax for the first op,
#need to take the softmax out of the necessary functions though
for i, ops in enumerate(ops_list):
batch = []
for i, ops in enumerate(ops_list):
item = logits[i, ...].unsqueeze(0)
ctx = context[i, ...].unsqueeze(0)
ic("------")
for op, value in ops.items():
if op in soft_required:
item = torch.log_softmax(logits[i, :, :], dim=-1)
ic(op, value)
if op == "rep_pen":
item = op_map[op](ctx, item, **value)
else:
item = op_map[op](item, value)
batch.append(item)
logits = torch.cat(batch, dim=0)
logits = torch.softmax(logits, dim=-1)
#fully_deterministic makes it deterministic in the batch
if fully_deterministic:
logits = logits.split(1, dim=0)
logit_list = []
for logit in logits:
torch.manual_seed(69)
logit_list.append(torch.multinomial(logit, 1))
logits = torch.cat(logit_list, dim=0)
else:
torch.manual_seed(69)
logits = torch.multinomial(logits, 1)
context = torch.cat([context, logits], dim=1)
in_tokens = logits
return context
def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops={"temp": 0.9}):
with torch.no_grad():
in_tokens = prompt_tokens
kv = None
fully_deterministic = False
tokens_generated = []
op_map = {
"top_k": apply_top_k,
"top_p": apply_top_p,
"typical": apply_typical,
"temp": apply_temp,
"tfs": apply_tfs
}
for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
for op, value in ops.items():
logits = op_map[op](logits, value).float()
logits = torch.softmax(logits, dim=-1).float()
if fully_deterministic:
logits = logits.split(1, dim=0)
logit_list = []
for logit in logits:
torch.manual_seed(69)
logit_list.append(torch.multinomial(logit, 1))
logits = torch.cat(logit_list, dim=0)
else:
torch.manual_seed(69)
logits = torch.multinomial(logits, 1)
in_tokens = logits
tokens_generated.append(logits)
......@@ -90,30 +255,49 @@ def generate(forward, prompt_tokens, temperature, tokens_to_generate=50, ops_lis
return tokens_generated
def main():
bsz = 4
gen_len = 250
torch.manual_seed(69)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
prompt = """I fucked her with my huge donut, when she seen my donut she went"""
prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
tokens = tokenizer.encode(prompt)
print("Prompt:")
for x in range(len(tokens)):
print(tokenizer.decode([tokens[x]]), end=" | ")
print("\n Generation:")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
tokens = [tokens] * bsz
#tokens = torch.cat([tokens, tokens], dim=0)
tokens = torch.cat(tokens, dim=0)
t = time.perf_counter()
model = gptj.load_gpt_j().cuda().half().eval()
model = model.lm
ic(time.perf_counter() - t)
rep_pen = {
"penalty": 1000000,
}
ops = {
"top_k": 40,
"top_p": 0.9,
"temp": 0.9,
"rep_pen": rep_pen,
"top_k": 50,
"temp": 0.8,
}
ops_list = [ops] * bsz
tokens_generated = generate(model.forward, tokens, 40, ops=ops)
tokens_generated = tokenizer.decode(tokens_generated.squeeze().tolist())
tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
#tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
print(tokens_generated.shape)
ic(prompt)
ic(tokens_generated)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
for gen in tokens_generated:
print(str(gen))
print("===========================================================")
#ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
#timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
#timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)
if __name__ == "__main__":
main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment