Commit fd387a42 authored by novelailab

update

parent 482a7bae
@@ -31,7 +31,8 @@ def init(model_class, config):
     init_weights(model, config["n_layer"])
     return model
 
-def no_init(model_class, config):
+def no_init(config):
+    model_class = models.get_model(config["model_class"])
     model = utils.no_init(lambda: model_class(config))
     return model
...
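Note: `utils.no_init` is called here but not shown in this commit. As a hedged sketch only (basedformer's actual helper may differ), the usual trick behind such a wrapper is to stub out torch's in-place initializers while the constructor runs, so a model that is about to be overwritten by a checkpoint load skips the expensive random init:

import torch

def no_init(factory):
    # Stub out the in-place initializers that nn.Module.reset_parameters()
    # implementations typically call, build the model, then restore them.
    names = ["normal_", "uniform_", "kaiming_normal_", "kaiming_uniform_",
             "xavier_normal_", "xavier_uniform_", "constant_"]
    saved = {n: getattr(torch.nn.init, n) for n in names}
    for n in names:
        setattr(torch.nn.init, n, lambda tensor, *a, **k: tensor)  # no-op init
    try:
        model = factory()          # e.g. lambda: model_class(config)
    finally:
        for n, fn in saved.items():
            setattr(torch.nn.init, n, fn)  # always restore the real initializers
    return model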
@@ -133,7 +133,7 @@ class SelfAttention(nn.Module):
         x = self.out_proj(x)
         if cache:
-            return x, (key, value)
+            return x, [key, value]
         else:
             return x, None
...
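The only change above is the container type of the returned cache pair. Functionally, a list lets callers mutate the cached key/value entries in place (for example when appending new timesteps), which a tuple forbids; whether the sampler actually relies on that is not visible in this diff. A minimal illustration:

import torch

key = torch.randn(1, 8, 64)     # illustrative (batch, seq, head_dim) shapes
value = torch.randn(1, 8, 64)

cache = [key, value]
cache[0] = torch.cat([cache[0], torch.randn(1, 1, 64)], dim=1)  # ok: lists are mutable

cache = (key, value)
# cache[0] = ...  # would raise TypeError: 'tuple' object does not support item assignment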
(Two collapsed file diffs are not shown in this view.)
@@ -15,7 +15,10 @@ from math import log2, ceil
 from basedformer import gptj, optimizer, lm_utils
 from basedformer.utils import *
 import glob
+from transformers import AutoTokenizer
+from basedformer import sampling
 from icecream import ic
+from termcolor import colored
 
 def _init_weights(module):
     if isinstance(module, nn.Linear):
@@ -101,7 +104,7 @@ class HyperNetwork(nn.Module):
         embed_dim = config["hidden_dim"]
         self.linear = nn.Linear(embed_dim, embed_dim//4, bias=True)
         self.linear2 = nn.Linear(embed_dim//4, embed_dim, bias=True)
-        self.activation = gelu_new
+        self.activation = torch.nn.functional.gelu
         self.num_shifts = ceil(log2(2048)) - 1
         #self.linear.weight.data.normal_(mean=0.0, std=0.02)
         for module in self.modules():
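The activation swap above replaces `gelu_new` with PyTorch's built-in GELU. Assuming `gelu_new` is the usual GPT-2 tanh approximation (its definition is not shown in this diff), the two functions agree to within roughly 1e-3, so this is a numerics cleanup rather than a behavioral change. A quick check:

import math
import torch

def gelu_new(x):
    # GPT-2's tanh-approximated GELU (assumed definition, not from this repo)
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi)
                                       * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.linspace(-4, 4, steps=1001)
exact = torch.nn.functional.gelu(x)        # erf-based, the new code path
print((exact - gelu_new(x)).abs().max())   # small, on the order of 1e-3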
@@ -147,15 +150,53 @@ class HyperNetworkSingle(nn.Module):
         x = x.mul(torch.sigmoid(x))
         return x.bfloat16()
 
+tokenizer = AutoTokenizer.from_pretrained('gpt2')
+
+@torch.no_grad()
+def sample(prompt, n_tokens, bsz, hypernetwork=None):
+    torch.seed()
+    tokens = tokenizer.encode(prompt)
+    #print("Prompt:")
+    #for x in range(len(tokens)):
+    #    print(tokenizer.decode([tokens[x]]), end=" | ")
+    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
+    tokens = [tokens] * bsz
+    tokens = torch.cat(tokens, dim=0)
+    rep_pen = {
+        "penalty": 3,
+    }
+    ops = {
+        "rep_pen": rep_pen,
+        "tfs": 0.8,
+        "temp": 0.8,
+    }
+    ops_list = [ops] * bsz
+    tokens_generated = sampling.generate(model.forward, tokens, n_tokens, ops_list=ops_list, hypernetwork=hypernetwork, non_deterministic=True)
+    vanilla_tokens_generated = sampling.generate(model.forward, tokens, n_tokens, ops_list=ops_list, hypernetwork=None)
+    tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
+    vanilla_tokens_generated = tokenizer.batch_decode(vanilla_tokens_generated.cpu().numpy())
+    ### send to wandb
+    columns = ["Prompt", "Generated Text", "Vanilla Model"]
+    data = []
+    for x in range(len(tokens_generated)):
+        data.append([prompt, str(tokens_generated[x]), str(vanilla_tokens_generated[x])])
+    for gen in tokens_generated:
+        print(colored("==========================================================", "red"))
+        print(colored(gen, "green"))
+        print(colored("==========================================================", "red"))
+    wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
+
 # we need 250 batch size to train the small GPT.
 train_config = {
-    #"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
-    "data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v7_infilling.map",
-    #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
+    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
+    #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v7_infilling.map",
+    ##"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
     #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
-    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling",
+    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-again",
     "do_save": True,
-    "run_name": "gpt-j-6b-2e-4-infilling",
+    "run_name": "gpt-j-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
     "lr": 2e-4,
     "end_lr": 2e-4,
     "warmup_steps": 50,
@@ -165,6 +206,7 @@ train_config = {
     "save_every": 300,
     "amp": False,
     "loss_scale": False,
+    "eval_every": 100,
 }
 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
@@ -217,6 +259,7 @@ t = tqdm(train_loader, initial=curr_step)
 scaler = torch.cuda.amp.GradScaler()
 
+sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
 for input_ids, labels in t:
     timex = time.perf_counter()
     input_ids = input_ids.cuda()
@@ -273,5 +316,8 @@ for input_ids, labels in t:
             torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
             opt.save(save_folder / "opt")
             print(f"Saved model at step {curr_step}")
+
+    if curr_step % train_config["eval_every"] == 0:
+        sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
 
     curr_step += 1
\ No newline at end of file
@@ -13,7 +13,7 @@ bash = False
 
 config_obj = KubeConfig()
 config_obj.set_name(name)
-config_obj.set_gpu(gpu_name=GPU.A100_NVLINK, amount=1)
+config_obj.set_gpu(gpu_name=GPU.A100_PCIE_80GB, amount=1)
 config_obj.set_ram(24)
 config_obj.set_cpu(4)
 config_obj.dry_run(dry)
 
@@ -36,6 +36,8 @@ if True:
     env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
     env1.sh('pip3 install dotmap icecream')
     path.sh("pip3 install --editable .")
+    path.sh("pip3 install transformers")
+    path.sh("pip3 install termcolor")
 
 with always_rerun():
     if False:
         #env1.sh('pip3 install transformers')
...
@@ -9,23 +9,26 @@ from transformers import AutoTokenizer
 from icecream import ic
 import time
 import sys
+from termcolor import colored
 
 def main():
-    #save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs4-2e-4-catchup"
-    save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling"
-    cp_list = sorted(os.listdir(save_path), key=lambda x: int(x.split("_")[-1]))
-    last_cp = Path(save_path) / cp_list[-1] if len(cp_list) > 0 else None
+    save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs4-2e-4-catchup/step_1200"
+    #save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-again/step_1200"
+    #save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling"
+    #cp_list = sorted(os.listdir(save_path), key=lambda x: int(x.split("_")[-1]))
+    #last_cp = Path(save_path) / cp_list[-1] if len(cp_list) > 0 else None
+    last_cp = Path(save_path)
     print(last_cp)
 
     bsz = 1
-    gen_len = 400
+    gen_len = 1000
     #torch.manual_seed(69)
     tokenizer = AutoTokenizer.from_pretrained('gpt2')
     mask = "████████"
     prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
     prompt = """'''Kurumuz''' is the founder of tech company [["""
-    promptnomask = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window,{mask} holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
-    prompt = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved{mask}, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window, holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
-    tokens = tokenizer.encode(promptnomask)
+    #promptnomask = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window,{mask} holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
+    #prompt = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved{mask}, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window, holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
+    tokens = tokenizer.encode(prompt)
     print(tokens)
     print("Prompt:")
     for x in range(len(tokens)):
@@ -38,7 +41,7 @@ def main():
     t = time.perf_counter()
     model = lmu.load_from_path('pretrained/gptj-6b').cuda().bfloat16().eval()
     hypernetwork = hypernet.HyperNetworkSingle(model.config).cuda().float()
-    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
+    #print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
     hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
     ic(time.perf_counter() - t)
@@ -53,19 +56,18 @@ def main():
         "temp": 0.8,
     }
     ops_list = [ops] * bsz
-    torch.manual_seed(69)
-    #tokens_generated = sampling.generate(model.forward, tokens, gen_len, ops_list=ops_list, hypernetwork=hypernetwork)
-    tokens_generated = sampling.generate_greedy(model.forward, tokens, gen_len, hypernetwork=hypernetwork)
+    tokens_generated = sampling.generate(model.forward, tokens, gen_len, ops_list=ops_list, hypernetwork=hypernetwork, non_deterministic=False)
+    #tokens_generated = sampling.generate_greedy(model.forward, tokens, gen_len, hypernetwork=hypernetwork)
     #tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
-    print(tokens_generated.shape)
-    tokens_generated[tokens_generated == 48585] = 35625
-    ic(prompt)
+    #print(tokens_generated.shape)
+    #tokens_generated[tokens_generated == 48585] = 35625
+    #ic(prompt)
     tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
     for gen in tokens_generated:
-        print(str(gen.split("*****")[0]))
-        print("++++++++++++")
-        print(str(gen.split("*****")[1]))
-        print("===========================================================")
+        print(colored("==========================================================", "red"))
+        print(colored(gen, "green"))
+        print(colored("==========================================================", "red"))
     #ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
     #timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
     #timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)
...
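A note on the `rep_pen` op used by both scripts: the diff never shows how basedformer applies `"penalty": 3`, but a value above 1 matches the widely used CTRL-style repetition penalty. A sketch with hypothetical names (`apply_repetition_penalty` is not a basedformer function):

import torch

def apply_repetition_penalty(logits, generated_ids, penalty=3.0):
    # CTRL-style penalty: make every already-generated token less likely.
    # logits: (batch, vocab); generated_ids: (batch, n_generated)
    scores = logits.gather(-1, generated_ids)
    # Dividing positive logits and multiplying negative ones both push
    # the token's probability down by roughly a factor of `penalty`.
    scores = torch.where(scores > 0, scores / penalty, scores * penalty)
    return logits.scatter(-1, generated_ids, scores)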