Commit cc02ad48 authored by Wes Brown

Some cleanup, device agnosticism.

parent 91470a5a
@@ -4,11 +4,8 @@ import mmap
 import pickle
 import concurrent
 from torch.utils import data
-from simplejpeg import decode_jpeg
-import simplejpeg
 import pickle
 from pathlib import Path
-from PIL import Image
 from tqdm import tqdm
 from concurrent.futures import as_completed
 import requests
@@ -54,6 +51,9 @@ class ShardedDataset(data.Dataset):
 class ShardedImageDataset(data.Dataset):
     def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
                  outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
+        from simplejpeg import decode_jpeg
+        import simplejpeg
+        from PIL import Image
         self.skip = skip
         self.threads = threads
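Moving `simplejpeg` and `PIL` from module scope into `ShardedImageDataset.__init__` keeps the module importable on machines without the image stack installed; the dependency cost is paid only when an image dataset is actually constructed. A minimal sketch of the same lazy-import pattern (class and attribute names here are illustrative, not this repo's):

```python
class LazyImageReader:
    """Defers heavy, optional imports until the class is actually used."""

    def __init__(self, path: str):
        # Importing here (not at module top) keeps `import this_module`
        # working even when simplejpeg/Pillow are not installed.
        from simplejpeg import decode_jpeg
        from PIL import Image
        self._decode_jpeg = decode_jpeg
        self._Image = Image
        self.path = path

    def read(self):
        # decode_jpeg returns a numpy array; wrap it in a PIL image.
        with open(self.path, "rb") as f:
            return self._Image.fromarray(self._decode_jpeg(f.read()))
```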
@@ -87,11 +87,13 @@ def load_from_path(config_folder=None, strict=False):
     model = _load_dict_model(model_class, model_config, model_path, strict=strict)
     return model
 
-def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False):
+def _load_dict_model(model_class, config, path=None, state_dict=None,
+                     strict=False, device="cuda"):
     # I am kinda sad that we will not have a load function in lm object itself.
     # might be better to add load functions -- actually nope.
     if path:
-        state_dict = utils.SplitCheckpoint(path, device="cuda")
+        state_dict = utils.SplitCheckpoint(path, device=device)
+        state_dict.device = device
     model = utils.no_init(lambda: model_class(config))
     model.load_state_dict(state_dict, strict=strict)
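Threading `device` through `_load_dict_model` instead of hard-coding `"cuda"` lets a checkpoint be materialized on CPU first and moved explicitly. A hypothetical call site (`_load_dict_model` is internal and its return value isn't shown in this hunk, so this is a sketch rather than the repo's public API):

```python
# Sketch: load the checkpoint onto CPU, inspect or patch it, then opt in
# to the GPU explicitly instead of relying on a hard-coded device.
model = _load_dict_model(model_class, model_config, path=model_path,
                         strict=False, device="cpu")
model = model.to("cuda")
```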
@@ -20,6 +20,12 @@ from basedformer import sampling
 from icecream import ic
 from termcolor import colored
 
+gpu = "cuda"
+amp = torch.cuda.amp
+if gpu != "cuda":
+    amp = torch.amp
+
+scaler = torch.cuda.amp.GradScaler()
 
 def _init_weights(module):
     if isinstance(module, nn.Linear):
         module.weight.data.normal_(mean=0.0, std=0.02)
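The `gpu`/`amp` shim makes the backend switchable by editing one variable, but note that `torch.amp.autocast` takes a required `device_type` argument that `torch.cuda.amp.autocast` does not, so aliasing the two modules is not a perfect drop-in. A more defensive variant might look like this (a sketch, not what the commit does):

```python
import torch

# Choose the device once, falling back to CPU when CUDA is unavailable.
gpu = "cuda" if torch.cuda.is_available() else "cpu"

def autocast(**kwargs):
    # torch.amp.autocast requires device_type, so wrap it rather than
    # aliasing torch.cuda.amp / torch.amp directly.
    return torch.amp.autocast(device_type=gpu, **kwargs)

# GradScaler only does real work on CUDA; disable it elsewhere.
scaler = torch.cuda.amp.GradScaler(enabled=(gpu == "cuda"))
```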
@@ -158,7 +164,7 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
     #print("Prompt:")
     #for x in range(len(tokens)):
     #    print(tokenizer.decode([tokens[x]]), end=" | ")
-    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
+    tokens = torch.LongTensor(tokens).unsqueeze(0).to(gpu)
     tokens = [tokens] * bsz
     tokens = torch.cat(tokens, dim=0)
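A side note on the surrounding context: `[tokens] * bsz` followed by `torch.cat` builds the batch by concatenating `bsz` copies of the same row; `Tensor.repeat` expresses the same thing in one call. An equivalent sketch:

```python
tokens = torch.LongTensor(tokens).unsqueeze(0).to(gpu)  # shape (1, seq_len)
tokens = tokens.repeat(bsz, 1)                          # shape (bsz, seq_len)
```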
@@ -190,9 +196,9 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "/home/xuser/nvme1/dataset/enwik9-gpt2-2049.map",
-    "save_path": "/home/xuser/models/enwik9-sigurdv4-hypernet2",
-    "lm_path": "/home/xuser/nvme1/pretrained/sigurdv4",
+    "data_path": "dataset/enwik9-gpt2-2049.map",
+    "save_path": "models/enwik9-sigurdv4-hypernet2",
+    "lm_path": "pretrained/sigurdv4",
     "optimizer": "adamw",
     "masked_softmax_fusion": False,
     "do_save": True,
@@ -214,8 +220,8 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
 
-#model = GPTModel.gpt2_init(model_config).cuda().float()
-model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/sigurdv4").cuda().bfloat16()
+model = lm_utils.load_from_path("pretrained/sigurdv4").to(gpu).bfloat16()
 for param in model.parameters():
     param.requires_grad = False
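`.to(gpu).bfloat16()` makes two passes over the parameters; `Module.to` accepts device and dtype together and does the move and cast in one pass. An equivalent one-liner (same behavior, just terser):

```python
model = lm_utils.load_from_path("pretrained/sigurdv4").to(device=gpu, dtype=torch.bfloat16)
```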
@@ -223,9 +229,7 @@ for name, p in model.named_parameters():
     if ("ln" in name or "vocab_embed" in name):
         p.requires_grad = True
 
-hypernetwork = HyperNetworkSingle(model.config).cuda().float()
-#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
-#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(2)])
+hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
 for param in hypernetwork.parameters():
     param.requires_grad = True
@@ -257,17 +261,17 @@ else:
     t = tqdm(train_loader, initial=curr_step)
 
-scaler = torch.cuda.amp.GradScaler()
 #sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
 for input_ids, labels in t:
     timex = time.perf_counter()
-    input_ids = input_ids.cuda()
-    labels = labels.cuda()
+    input_ids = input_ids.to(gpu)
+    labels = labels.to(gpu)
     loss = 0
     for x in range(train_config["gas"]):
-        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits, _ = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
+        with amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
+            logits, _ = model(input_ids[x*bs:(x+1)*bs, :].to(gpu), hypernetwork=hypernetwork, act_ck=True)
         #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
         logits = logits.view(-1, logits.shape[-1])
         gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
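This hunk shows only the forward pass under `amp.autocast`; the backward half of the AMP recipe, using the module-level `scaler` defined earlier, falls below the visible context. For reference, the standard GradScaler pattern looks roughly like this (a sketch of the usual recipe, not this repo's exact lines):

```python
# After accumulating `loss` across the gas micro-batches:
scaler.scale(loss).backward()  # scale first so fp16 grads don't underflow
scaler.step(optimizer)         # unscales grads, skips the step on inf/NaN
scaler.update()                # adapt the loss scale for the next iteration
optimizer.zero_grad()
```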
@@ -317,6 +321,7 @@ for input_ids, labels in t:
             print(f"Saved model at step {curr_step}")
 
     if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
+        print("")
         sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
     curr_step += 1
\ No newline at end of file