Commit fd387a42 authored by novelailab

update

parent 482a7bae
......@@ -31,7 +31,8 @@ def init(model_class, config):
init_weights(model, config["n_layer"])
return model
def no_init(model_class, config):
def no_init(config):
model_class = models.get_model(config["model_class"])
model = utils.no_init(lambda: model_class(config))
return model
......
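For context on the no_init change above: constructing a model without running its random weight initialization is typically done by temporarily disabling the reset_parameters hooks of the standard layers while the constructor runs. The sketch below is an assumption about what utils.no_init does, not the actual basedformer helper:

import torch.nn as nn

def _no_init(loading_code):
    # hypothetical stand-in for basedformer.utils.no_init
    modules = [nn.Linear, nn.Embedding, nn.LayerNorm]
    saved = {m: m.reset_parameters for m in modules}
    try:
        for m in modules:
            m.reset_parameters = lambda self: None  # skip random init while constructing
        return loading_code()                        # e.g. lambda: model_class(config)
    finally:
        for m in modules:
            m.reset_parameters = saved[m]            # restore the original initializers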
......@@ -133,7 +133,7 @@ class SelfAttention(nn.Module):
x = self.out_proj(x)
if cache:
return x, (key, value)
return x, [key, value]
else:
return x, None
......
from basedformer import gptj
from basedformer.utils import *
from basedformer import lm_utils
from transformers import AutoTokenizer
from icecream import ic
import time
......@@ -173,12 +174,13 @@ def generate_greedy(forward, prompt_tokens, tokens_to_generate=50, hypernetwork=
return generated
@torch.no_grad()
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}], hypernetwork=None):
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
in_tokens = prompt_tokens
context = prompt_tokens
generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
kv = None
fully_deterministic = False
if non_deterministic:
torch.seed()
#soft_required = ["top_k", "top_p"]
op_map = {
"top_k": apply_top_k,
......@@ -193,6 +195,20 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0
logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
#if kv[0][0].shape[0] == 1 and (kv[0][0].shape[0] != len(ops_list)):
# t = time.perf_counter()
# for layer in kv:
# for i in range(len(layer)):
# layer[i] = layer[i].repeat(len(ops_list), 1, 1, 1)
# ic("replicated kv")
# ic(time.perf_counter() - t)
#if logits.shape[0] == 1 and (logits.shape[0] != len(ops_list)):
# logits = logits.repeat(len(ops_list), 1)
# ic("replicated logits")
#if context.shape[0] == 1 and (context.shape[0] != len(ops_list)):
# context = context.repeat(len(ops_list), 1)
# ic("replicated context")
#can save one softmax here by not applying softmax for the first op,
#need to take the softmax out of the necessary functions though
batch = []
......@@ -222,7 +238,6 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0
logits = torch.cat(logit_list, dim=0)
else:
#torch.manual_seed(69)
logits = torch.multinomial(logits, 1)
generated = torch.cat([generated, logits], dim=-1)
......@@ -231,7 +246,7 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0
return generated
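To make the sampling-op pipeline above concrete, here is a minimal single-step sampler in the same spirit (temperature, then top-k filtering, then a multinomial draw). The helper name and defaults are illustrative and not part of this codebase:

import torch

def sample_step(logits, temp=0.9, top_k=50):
    # logits: [batch, vocab] scores for the last position
    logits = logits / temp
    if top_k:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1)  # [batch, 1] sampled next-token ids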
def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops={"temp": 0.9}):
def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops_list={"temp": 0.9}, hypernetwork=None):
with torch.no_grad():
in_tokens = prompt_tokens
kv = None
......@@ -246,11 +261,11 @@ def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops={"t
}
for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv)
logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
for op, value in ops.items():
for op, value in ops_list.items():
logits = op_map[op](logits, value).float()
logits = torch.softmax(logits, dim=-1).float()
......@@ -274,25 +289,142 @@ def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops={"t
tokens_generated = torch.cat(tokens_generated, dim=-1)
return tokens_generated
if __name__ == "__main__":
#model = lm_utils.load_from_path("/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gptj-6b").half().cuda().eval()
params = {
"model_class": "gptj",
'n_layer': 44,
'n_head': 64,
'n_tokens': 2048,
'hidden_dim': 6144,
}
model = lm_utils.no_init(params).half().cuda().eval()
tokenizer = AutoTokenizer.from_pretrained('gpt2')
print_parameters(model)
long_prompt = """_3 May. Bistritz._--Left Munich at 8:35 P. M., on 1st May, arriving at
Vienna early next morning; should have arrived at 6:46, but train was an
hour late. Buda-Pesth seems a wonderful place, from the glimpse which I
got of it from the train and the little I could walk through the
streets. I feared to go very far from the station, as we had arrived
late and would start as near the correct time as possible. The
impression I had was that we were leaving the West and entering the
East; the most western of splendid bridges over the Danube, which is
here of noble width and depth, took us among the traditions of Turkish
rule.
We left in pretty good time, and came after nightfall to Klausenburgh.
Here I stopped for the night at the Hotel Royale. I had for dinner, or
rather supper, a chicken done up some way with red pepper, which was
very good but thirsty. (_Mem._, get recipe for Mina.) I asked the
waiter, and he said it was called "paprika hendl," and that, as it was a
national dish, I should be able to get it anywhere along the
Carpathians. I found my smattering of German very useful here; indeed, I
don't know how I should be able to get on without it.
Having had some time at my disposal when in London, I had visited the
British Museum, and made search among the books and maps in the library
regarding Transylvania; it had struck me that some foreknowledge of the
country could hardly fail to have some importance in dealing with a
nobleman of that country. I find that the district he named is in the
extreme east of the country, just on the borders of three states,
Transylvania, Moldavia and Bukovina, in the midst of the Carpathian
mountains; one of the wildest and least known portions of Europe. I was
not able to light on any map or work giving the exact locality of the
Castle Dracula, as there are no maps of this country as yet to compare
with our own Ordnance Survey maps; but I found that Bistritz, the post
town named by Count Dracula, is a fairly well-known place. I shall enter
here some of my notes, as they may refresh my memory when I talk over my
travels with Mina.
In the population of Transylvania there are four distinct nationalities:
Saxons in the South, and mixed with them the Wallachs, who are the
descendants of the Dacians; Magyars in the West, and Szekelys in the
East and North. I am going among the latter, who claim to be descended
from Attila and the Huns. This may be so, for when the Magyars conquered
the country in the eleventh century they found the Huns settled in it. I
read that every known superstition in the world is gathered into the
horseshoe of the Carpathians, as if it were the centre of some sort of
imaginative whirlpool; if so my stay may be very interesting. (_Mem._, I
must ask the Count all about them.)
I did not sleep well, though my bed was comfortable enough, for I had
all sorts of queer dreams. There was a dog howling all night under my
window, which may have had something to do with it; or it may have been
the paprika, for I had to drink up all the water in my carafe, and was
still thirsty. Towards morning I slept and was wakened by the continuous
knocking at my door, so I guess I must have been sleeping soundly then.
I had for breakfast more paprika, and a sort of porridge of maize flour
which they said was "mamaliga," and egg-plant stuffed with forcemeat, a
very excellent dish, which they call "impletata." (_Mem._, get recipe
for this also.) I had to hurry breakfast, for the train started a little
before eight, or rather it ought to have done so, for after rushing to
the station at 7:30 I had to sit in the carriage for more than an hour
before we began to move. It seems to me that the further east you go the
more unpunctual are the trains. What ought they to be in China?
All day long we seemed to dawdle through a country which was full of
beauty of every kind. Sometimes we saw little towns or castles on the
top of steep hills such as we see in old missals; sometimes we ran by
rivers and streams which seemed from the wide stony margin on each side
of them to be subject to great floods. It takes a lot of water, and
running strong, to sweep the outside edge of a river clear. At every
station there were groups of people, sometimes crowds, and in all sorts
of attire. Some of them were just like the peasants at home or those I
saw coming through France and Germany, with short jackets and round hats
and home-made trousers; but others were very picturesque. The women
looked pretty, except when you got near them, but they were very clumsy
about the waist. They had all full white sleeves of some kind or other,
and most of them had big belts with a lot of strips of something
fluttering from them like the dresses in a ballet, but of course there
were petticoats under them. The strangest figures we saw were the
Slovaks, who were more barbarian than the rest, with their big cow-boy
hats, great baggy dirty-white trousers, white linen shirts, and enormous
heavy leather belts, nearly a foot wide, all studded over with brass
nails. They wore high boots, with their trousers tucked into them, and
had long black hair and heavy black moustaches. They are very
picturesque, but do not look prepossessing. On the stage they would be
set down at once as some old Oriental band of brigands. They are,
however, I am told, very harmless and rather wanting in natural
self-assertion.
It was on the dark side of twilight when we got to Bistritz, which is a
very interesting old place. Being practically on the frontier--for the
Borgo Pass leads from it into Bukovina--it has had a very stormy
existence, and it certainly shows marks of it. Fifty years ago a series
of great fires took place, which made terrible havoc on five separate
occasions. At the very beginning of the seventeenth century it underwent
a siege of three weeks and lost 13,000 people, the casualties of war
proper being assisted by famine and disease.
Count Dracula had directed me to go to the Golden Krone Hotel, which I
found, to my great delight, to be thoroughly old-fashioned, for of
course I wanted to see all I could of the ways of the country. I was
evidently expected, for when I got near the door I faced a
cheery-looking elderly woman in the usual peasant dress--white
undergarment with long double apron, front, and back, of coloured stuff
fitting almost too tight for modesty. When I came close she bowed and
said, "The Herr Englishman?" "Yes," I said, "Jonathan Harker." She
smiled, and gave some message to an elderly man in white shirt-sleeves,
who had followed her to the door. He went, but immediately returned with
a letter:"""
def main():
bsz = 4
gen_len = 250
torch.manual_seed(69)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
prompt = """I fucked her with my huge donut, when she seen my donut she went"""
prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
tokens = tokenizer.encode(prompt)
print("Prompt:")
for x in range(len(tokens)):
print(tokenizer.decode([tokens[x]]), end=" | ")
print("\n Generation:")
tokens = tokenizer.encode(long_prompt)
#print("Prompt:")
#for x in range(len(tokens)):
# print(tokenizer.decode([tokens[x]]), end=" | ")
#print("\n Generation:")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
tokens = [tokens] * bsz
#tokens = torch.cat([tokens, tokens], dim=0)
tokens = torch.cat(tokens, dim=0)
t = time.perf_counter()
model = gptj.load_gpt_j().cuda().half().eval()
model = model.lm
ic(time.perf_counter() - t)
......@@ -307,18 +439,83 @@ def main():
}
ops_list = [ops] * bsz
timeit(lambda: generate(model.forward, tokens, gen_len, ops_list=ops_list), n=1, r=1)
tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
print(tokens_generated.shape)
#ic(prompt)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
for gen in tokens_generated:
print(str(gen))
print("===========================================================")
def same_cached_prompt_batched():
bsz = 2
gen_len = 1
torch.manual_seed(69)
prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
tokens = tokenizer.encode(long_prompt)
#print("Prompt:")
#for x in range(len(tokens)):
# print(tokenizer.decode([tokens[x]]), end=" | ")
print("\n Generation:")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
rep_pen = {
"penalty": 3,
}
ops = {
"rep_pen": rep_pen,
"top_k": 50,
"temp": 0.8,
}
ops_list = [ops] * bsz
timeit(lambda: generate(model.forward, tokens, gen_len, ops_list=ops_list), n=5, r=1)
tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
#tokens_generated = generate_greedy(model.forward, tokens, gen_len)
#tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
print(tokens_generated.shape)
ic(prompt)
#ic(prompt)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
for gen in tokens_generated:
print(str(gen))
print("===========================================================")
#ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
#timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
#timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)
def no_cached_prompt_batched():
return
def single_gen():
bsz = 4
gen_len = 250
torch.manual_seed(69)
prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
tokens = tokenizer.encode(long_prompt)
#print("Prompt:")
#for x in range(len(tokens)):
# print(tokenizer.decode([tokens[x]]), end=" | ")
print("\n Generation:")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
tokens = [tokens] * 4
tokens = torch.cat(tokens, dim=0)
rep_pen = {
"penalty": 3,
}
ops = {
"rep_pen": rep_pen,
"top_k": 50,
"temp": 0.8,
}
ops_list = [ops] * 4
timeit(lambda: generate(model.forward, tokens, gen_len, ops_list=ops_list), n=1, r=1)
tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
print(tokens_generated.shape)
#ic(prompt)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
for gen in tokens_generated:
print(str(gen))
print("===========================================================")
if __name__ == "__main__":
main()
\ No newline at end of file
#main()
same_cached_prompt_batched()
#single_gen()
\ No newline at end of file
from re import A
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch.utils import data
import math
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
from torch.utils.checkpoint import checkpoint as ck
from math import log2, ceil
from basedformer import gptj, optimizer, lm_utils
from basedformer.utils import *
import glob
from icecream import ic
from transformers import AutoTokenizer
from basedformer import sampling
def _init_weights(module):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.parameter.Parameter):
module.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def shift_tokens(x, amt, eps = 1e-5):
n, device = x.shape[1], x.device
cumsum = x.cumsum(dim = 1)
*x, x_pass = x.chunk(amt + 1, dim = -1)
*x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
amts = 2 ** torch.arange(amt)
amts = amts.tolist()
shifts = []
denom = torch.arange(n, device = device)
for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
normed_shifted_x = shifted_chunk / (shifted_denom + eps)
shifts.append(normed_shifted_x)
return torch.cat((*shifts, x_pass), dim = -1)
def discounted_cumsum(t, gamma):
try:
from torch_discounted_cumsum import discounted_cumsum_left
except ImportError:
print('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`')
raise
b, n, d = t.shape
t = rearrange(t, 'b n d -> (b d) n')
t = discounted_cumsum_left(t, gamma)
t = rearrange(t, '(b d) n -> b n d', b = b)
return t
def shift(x, amt, dim = -1):
return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
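As a quick sanity check of the shift helper above (illustrative, relying on the module's existing torch import): with dim=-1 it moves values right by amt positions, zero-filling the front and dropping the tail.

x = torch.arange(1., 5.)   # tensor([1., 2., 3., 4.])
shift(x, 1)                # tensor([0., 1., 2., 3.])
shift(x, 2)                # tensor([0., 0., 1., 2.])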
class HyperNetworkGRU(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config["hidden_dim"]
self.linear1 = nn.Linear(embed_dim, embed_dim//8)
self.gru = nn.GRU(embed_dim//8, embed_dim // 8, num_layers=1, bidirectional=False, batch_first=True)
self.linear2 = nn.Linear(embed_dim // 8, embed_dim)
self.ln_1 = nn.LayerNorm(embed_dim // 8, eps=1e-5)
self.activation = gelu_new
for module in self.modules():
_init_weights(module)
for param in self.linear2.parameters():
param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
for param in self.gru.parameters():
param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
def forward(self, x):
x = x.float()
x = self.linear1(x)
x = self.gru(x)[0]
x = self.ln_1(x)
x = self.linear2(x)
x = ck(self.activation, x)
return x.bfloat16()
class HyperNetwork(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config["hidden_dim"]
self.linear = nn.Linear(embed_dim, embed_dim//4, bias=True)
self.linear2 = nn.Linear(embed_dim//4, embed_dim, bias=True)
self.activation = gelu_new
self.num_shifts = ceil(log2(2048)) - 1
#self.linear.weight.data.normal_(mean=0.0, std=0.02)
for module in self.modules():
_init_weights(module)
for param in self.linear2.parameters():
param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
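# note: std 0.02/sqrt(2*n_layer) mirrors GPT-2's scaled init for projections that feed the residual stream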
#state = self.state_dict()
#for k in state:
# state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
#self.load_state_dict(state)
def forward(self, x):
x = x.float()
#x = shift_tokens(x, self.num_shifts)
x = self.linear(x)
x = ck(self.activation, x)
x = self.linear2(x)
x = x.mul(torch.sigmoid(x))
return x.bfloat16()
class HyperNetworkSingle(nn.Module):
def __init__(self, config):
super().__init__()
embed_dim = config["hidden_dim"]
self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
self.activation = gelu_new
#self.linear.weight.data.normal_(mean=0.0, std=0.02)
for module in self.modules():
_init_weights(module)
for param in self.linear.parameters():
param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
#state = self.state_dict()
#for k in state:
# state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
#self.load_state_dict(state)
def forward(self, x):
x = x.float()
#x = shift_tokens(x, self.num_shifts)
x = self.linear(x)
x = x.mul(torch.sigmoid(x))
return x.bfloat16()
def _attn(query, key, value, scale_attn=None):
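# plain scaled dot-product attention with no causal mask; used by HyperNetworkLite below, where keys/values are a fixed learned latent set rather than sequence positions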
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = attn_weights / scale_attn
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class HyperNetworkLite(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config):
nn.Module.__init__(self)
self.head_dim = config.hidden_dim // config.n_head
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
self.latent_seq_size = 1000
self.latent_n_head = config.n_head
self.latent_head_dim = self.head_dim
self.k_embedding = torch.nn.parameter.Parameter(data=torch.zeros(self.latent_seq_size, config.hidden_dim), requires_grad=True)
self.v_embedding = torch.nn.parameter.Parameter(data=torch.zeros(self.latent_seq_size, config.hidden_dim), requires_grad=True)
for module in self.modules():
_init_weights(module)
self.k_embedding.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config.n_layer)))
self.v_embedding.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config.n_layer)))
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
def forward(self, x):
x = x.float()
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query = x.view(B, S, self.n_head, self.head_dim).transpose(1, 2) #latent query
key = self.k_embedding.view(1, self.latent_seq_size, self.n_head, self.head_dim).transpose(1, 2).repeat(B, 1, 1, 1) #context key
value = self.v_embedding.view(1, self.latent_seq_size, self.n_head, self.head_dim).transpose(1, 2).repeat(B, 1, 1, 1) #context value
x = _attn(query, key, value, self.scale_attn)
x = x.transpose(1, 2).contiguous().view(B, S, H)
return x.bfloat16()
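A hypothetical smoke test for the latent-attention hypernetwork above; the config values are made up (the real script passes model.config), and DotMap stands in for whatever attribute-style config object is actually used:

from dotmap import DotMap
import torch

cfg = DotMap({"hidden_dim": 4096, "n_head": 16, "n_layer": 28})
hn = HyperNetworkLite(cfg).cuda().float()
h = torch.randn(2, 128, 4096).cuda()  # [batch, seq, hidden_dim]
out = hn(h)                           # attends to the 1000 learned latent positions
assert out.shape == h.shape           # output keeps the input shape, cast to bfloat16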
tokenizer = AutoTokenizer.from_pretrained('gpt2')
@torch.no_grad()
def sample(prompt, n_tokens, bsz, hypernetwork=None):
torch.manual_seed(69)
tokens = tokenizer.encode(prompt)
#print("Prompt:")
#for x in range(len(tokens)):
# print(tokenizer.decode([tokens[x]]), end=" | ")
print("\n Generation:")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
rep_pen = {
"penalty": 3,
}
ops = {
"rep_pen": rep_pen,
"top_p": 0.8,
"temp": 0.8,
}
ops_list = [ops] * bsz
tokens_generated = sampling.generate(model.forward, tokens, n_tokens, ops_list=ops_list, hypernetwork=hypernetwork)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
### send to wandb
columns = ["Prompt", "Generated Text"]
data = []
for gen in tokens_generated:
data.append([prompt, str(gen)])
wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v7_infilling.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-wiki-hypernetworkslite-1000tokens",
"do_save": True,
"run_name": "gpt-j-6b-2e-4-hypernetworkslite-1000tokens",
"lr": 2e-4,
"end_lr": 2e-4,
"warmup_steps": 50,
"bs": 4,
"gas": 1,
"seed": 69,
"save_every": 300,
"eval_every": 300,
"amp": False,
"loss_scale": False,
}
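For reference, with these settings one optimizer step consumes bs * gas = 4 * 1 = 4 sequences of 2048 training tokens, i.e. 8,192 tokens per step; this is the same bs * gas * 2048 product used for the tokens_per_sec readout in the training loop below.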
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = lm_utils.load_from_path("pretrained/gptj-6b").cuda().bfloat16()
for param in model.parameters():
param.requires_grad = False
for name, p in model.named_parameters():
if ("ln" in name or "vocab_embed" in name):
p.requires_grad = True
#hypernetwork = HyperNetworkSingle(model.config).cuda().float()
hypernetwork = HyperNetworkLite(model.config).cuda().float()
#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(2)])
for param in hypernetwork.parameters():
param.requires_grad = True
cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
print(last_cp)
if last_cp:
print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(), last_cp / "opt")
else:
opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = FbDataset(2049, train_config["data_path"])
if last_cp:
train_dataset.skip = opt.curr_step * bs * gas
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model.config})
if last_cp:
curr_step = opt.curr_step
else:
curr_step = 0
t = tqdm(train_loader, initial=curr_step)
scaler = torch.cuda.amp.GradScaler()
for input_ids, labels in t:
timex = time.perf_counter()
input_ids = input_ids.cuda()
labels = labels.cuda()
loss = 0
for x in range(train_config["gas"]):
with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
#print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
if train_config["loss_scale"]:
scaler.scale(gas_loss).backward()
else:
gas_loss.backward()
loss += gas_loss.item()
loss = loss / gas
if train_config["loss_scale"]:
scaler.unscale_(opt.optimizer)
torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
if train_config["loss_scale"]:
opt.step(scaler=scaler)
else:
opt.step()
if train_config["loss_scale"]:
scaler.update()
opt.zero_grad()
sec_per_step = (time.perf_counter() - timex)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 2048) * bs * gas
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log(
{
"train/loss": loss,
"train/tokens_per_sec": tokens_per_sec,
"train/sec_per_step": sec_per_step,
"train/step_per_sec": step_per_sec,
"train/lr": opt.curr_lr,
"train/loss_scale": scaler.get_scale()
},
step=curr_step)
if train_config["do_save"]:
if curr_step % train_config["save_every"] == 0 and curr_step != 0:
save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
save_folder.mkdir(parents=True, exist_ok=True)
torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
opt.save(save_folder / "opt")
print(f"Saved model at step {curr_step}")
if curr_step % train_config["eval_every"] == 0:
sample("<|endoftext|>", 150, 4, hypernetwork=hypernetwork)
curr_step += 1
\ No newline at end of file
......@@ -15,7 +15,10 @@ from math import log2, ceil
from basedformer import gptj, optimizer, lm_utils
from basedformer.utils import *
import glob
from transformers import AutoTokenizer
from basedformer import sampling
from icecream import ic
from termcolor import colored
def _init_weights(module):
if isinstance(module, nn.Linear):
......@@ -101,7 +104,7 @@ class HyperNetwork(nn.Module):
embed_dim = config["hidden_dim"]
self.linear = nn.Linear(embed_dim, embed_dim//4, bias=True)
self.linear2 = nn.Linear(embed_dim//4, embed_dim, bias=True)
self.activation = gelu_new
self.activation = torch.nn.functional.gelu
self.num_shifts = ceil(log2(2048)) - 1
#self.linear.weight.data.normal_(mean=0.0, std=0.02)
for module in self.modules():
......@@ -147,15 +150,53 @@ class HyperNetworkSingle(nn.Module):
x = x.mul(torch.sigmoid(x))
return x.bfloat16()
tokenizer = AutoTokenizer.from_pretrained('gpt2')
@torch.no_grad()
def sample(prompt, n_tokens, bsz, hypernetwork=None):
torch.seed()
tokens = tokenizer.encode(prompt)
#print("Prompt:")
#for x in range(len(tokens)):
# print(tokenizer.decode([tokens[x]]), end=" | ")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
tokens = [tokens] * bsz
tokens = torch.cat(tokens, dim=0)
rep_pen = {
"penalty": 3,
}
ops = {
"rep_pen": rep_pen,
"tfs": 0.8,
"temp": 0.8,
}
ops_list = [ops] * bsz
tokens_generated = sampling.generate(model.forward, tokens, n_tokens, ops_list=ops_list, hypernetwork=hypernetwork, non_deterministic=True)
vanilla_tokens_generated = sampling.generate(model.forward, tokens, n_tokens, ops_list=ops_list, hypernetwork=None)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
vanilla_tokens_generated = tokenizer.batch_decode(vanilla_tokens_generated.cpu().numpy())
### send to wandb
columns = ["Prompt", "Generated Text", "Vanilla Model"]
data = []
for x in range(len(tokens_generated)):
data.append([prompt, str(tokens_generated[x]), str(vanilla_tokens_generated[x])])
for gen in tokens_generated:
print(colored("==========================================================", "red"))
print(colored(gen, "green"))
print(colored("==========================================================", "red"))
wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
# we need 250 batch size to train the small GPT.
train_config = {
#"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v7_infilling.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v7_infilling.map",
##"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-again",
"do_save": True,
"run_name": "gpt-j-6b-2e-4-infilling",
"run_name": "gpt-j-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
"lr": 2e-4,
"end_lr": 2e-4,
"warmup_steps": 50,
......@@ -165,6 +206,7 @@ train_config = {
"save_every": 300,
"amp": False,
"loss_scale": False,
"eval_every": 100,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
......@@ -217,6 +259,7 @@ t = tqdm(train_loader, initial=curr_step)
scaler = torch.cuda.amp.GradScaler()
sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
for input_ids, labels in t:
timex = time.perf_counter()
input_ids = input_ids.cuda()
......@@ -273,5 +316,8 @@ for input_ids, labels in t:
torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
opt.save(save_folder / "opt")
print(f"Saved model at step {curr_step}")
if curr_step % train_config["eval_every"] == 0:
sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
curr_step += 1
\ No newline at end of file
......@@ -13,7 +13,7 @@ bash = False
config_obj = KubeConfig()
config_obj.set_name(name)
config_obj.set_gpu(gpu_name=GPU.A100_NVLINK, amount=1)
config_obj.set_gpu(gpu_name=GPU.A100_PCIE_80GB, amount=1)
config_obj.set_ram(24)
config_obj.set_cpu(4)
config_obj.dry_run(dry)
......@@ -36,6 +36,8 @@ if True:
env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
env1.sh('pip3 install dotmap icecream')
path.sh("pip3 install --editable .")
path.sh("pip3 install transformers")
path.sh("pip3 install termcolor")
with always_rerun():
if False:
#env1.sh('pip3 install transformers')
......
......@@ -9,23 +9,26 @@ from transformers import AutoTokenizer
from icecream import ic
import time
import sys
from termcolor import colored
def main():
#save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs4-2e-4-catchup"
save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling"
cp_list = sorted(os.listdir(save_path), key=lambda x: int(x.split("_")[-1]))
last_cp = Path(save_path) / cp_list[-1] if len(cp_list) > 0 else None
save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs4-2e-4-catchup/step_1200"
#save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-again/step_1200"
#save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling"
#cp_list = sorted(os.listdir(save_path), key=lambda x: int(x.split("_")[-1]))
#last_cp = Path(save_path) / cp_list[-1] if len(cp_list) > 0 else None
last_cp = Path(save_path)
print(last_cp)
bsz = 1
gen_len = 400
gen_len = 1000
#torch.manual_seed(69)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
mask = "████████"
prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
prompt = """'''Kurumuz''' is the founder of tech company [["""
promptnomask = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window,{mask} holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
prompt = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved{mask}, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window, holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
tokens = tokenizer.encode(promptnomask)
#promptnomask = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window,{mask} holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
#prompt = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved{mask}, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window, holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
tokens = tokenizer.encode(prompt)
print(tokens)
print("Prompt:")
for x in range(len(tokens)):
......@@ -38,7 +41,7 @@ def main():
t = time.perf_counter()
model = lmu.load_from_path('pretrained/gptj-6b').cuda().bfloat16().eval()
hypernetwork = hypernet.HyperNetworkSingle(model.config).cuda().float()
print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
#print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
ic(time.perf_counter() - t)
......@@ -53,19 +56,18 @@ def main():
"temp": 0.8,
}
ops_list = [ops] * bsz
#tokens_generated = sampling.generate(model.forward, tokens, gen_len, ops_list=ops_list, hypernetwork=hypernetwork)
tokens_generated = sampling.generate_greedy(model.forward, tokens, gen_len, hypernetwork=hypernetwork)
torch.manual_seed(69)
tokens_generated = sampling.generate(model.forward, tokens, gen_len, ops_list=ops_list, hypernetwork=hypernetwork, non_deterministic=False)
#tokens_generated = sampling.generate_greedy(model.forward, tokens, gen_len, hypernetwork=hypernetwork)
#tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
print(tokens_generated.shape)
tokens_generated[tokens_generated == 48585] = 35625
ic(prompt)
#print(tokens_generated.shape)
#tokens_generated[tokens_generated == 48585] = 35625
#ic(prompt)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
for gen in tokens_generated:
print(str(gen.split("*****")[0]))
print("++++++++++++")
print(str(gen.split("*****")[1]))
print("===========================================================")
print(colored("==========================================================", "red"))
print(colored(gen, "green"))
print(colored("==========================================================", "red"))
#ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
#timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
#timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)
......