Commit a9eca288 authored by novelailab

it trains!

parent 9a167649
......@@ -131,3 +131,4 @@ dmypy.json
models
gptjconvert
j6b_vanilla
wandb
\ No newline at end of file
......@@ -61,21 +61,52 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
def rndinput(shape):
return torch.randint(0, 50256, shape).long().cuda()
def forward(model, x):
out = model.get_logits(x, act_ck=False)
@torch.no_grad()
def forward(model, x, hypernetwork=None):
out = model.get_logits(x, hypernetwork=hypernetwork, act_ck=True)
print(out.shape)
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
loss = torch.nn.CrossEntropyLoss()(out, out)
loss.backward()
model.zero_grad()
print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
#print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
#loss = torch.nn.CrossEntropyLoss()(out, out)
#loss.backward()
#model.zero_grad()
#print("torch.cuda.memory_allocated: %fGB"%(torch.cuda.memory_allocated(0)/1024/1024/1024))
class HyperNetwork(nn.Module):
def __init__(self, hidden_size, num_layers):
super().__init__()
embed_dim = hidden_size
self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
state = self.state_dict()
for k in state:
state[k] = state[k] * 1 / math.sqrt(2 * num_layers)
self.load_state_dict(state)
def forward(self, hidden_states):
hidden_states = self.linear(hidden_states)
hidden_states = hidden_states.mul(torch.sigmoid(hidden_states))
return hidden_states
def main():
model = init_1_3b().cuda().half()
shape = (1, 2048)
model = init_6b().cuda().half()
for param in model.parameters():
param.requires_grad = False
for param in model.vocab_embed.parameters():
param.requires_grad = True
for x in model.layers:
for param in x.ln_preattn.parameters():
param.requires_grad = True
hypernetwork = HyperNetwork(4096, 28).cuda().half()
hypernetwork.train()
shape = (1, 1)
#print(model(x).shape)
print("PyTorch Eager")
timeit(r=1, n=2, func=lambda: forward(model, rndinput(shape)), do_tqdm=False, first=False)
timeit(r=1, n=2, func=lambda: forward(model, rndinput(shape), hypernetwork), do_tqdm=False, first=False)
if __name__ == "__main__":
main()
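The new HyperNetwork is a single gated linear layer (x·sigmoid(x), i.e. SiLU-style gating) whose parameters are scaled down by 1/sqrt(2 * num_layers) right after initialization, and the benchmark now freezes everything except vocab_embed and the ln_preattn norms. The diff does not show how GPTModel.forward consumes the hypernetwork argument; the helper below is a hypothetical sketch of one plausible wiring, applying the adapter residually to the hidden states.

# Hypothetical wiring, not shown in this diff: the adapter output is added
# back onto the hidden states as a residual correction.
def apply_hypernetwork(hidden_states, hypernetwork=None):
    if hypernetwork is None:
        return hidden_states
    return hidden_states + hypernetwork(hidden_states)

# e.g. hidden = torch.randn(1, 2048, 4096)
#      hidden = apply_hypernetwork(hidden, HyperNetwork(4096, 28))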
{
"lr": 1.0e-4,
}
\ No newline at end of file
import torch
import math
@torch.jit.script
def gelu_new(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
def gelu_slow(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
def gelu_trace(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
def gelu_involved(x):
return gelu_new(x)
#torch.jit.trace gelu
#code:
gelu_traced = torch.jit.trace(gelu_involved, torch.randn(1, 128, 128))
x = torch.rand(1, 128, 128)
assert torch.allclose(gelu_new(x), gelu_involved(x))
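A brief note on the trace above: torch.jit.trace records the exact ops executed on the example input, so it only generalizes safely when there is no data-dependent control flow. The tanh-approximate GELU is purely element-wise, so the traced function should also agree with the eager version on other shapes; a small check along those lines, continuing the snippet above:

# The traced GELU has no data-dependent control flow, so it should match the
# eager/scripted version on a different shape as well.
y = torch.rand(4, 7)
assert torch.allclose(gelu_traced(y), gelu_new(y))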
......@@ -11,18 +11,26 @@ def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
class BasedOptimizer:
def __init__(self, parameters, config, optimizer):
self.lr = config["lr"]
self.end_lr = config["end_lr"] if "end_lr" in config else self.lr
self.warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 1
self.anneal_steps = config["anneal_steps"] if "anneal_steps" in config else 1
self.total_steps = config["total_steps"] if "total_steps" in config else None
self.weight_decay = config["weight_decay"] if "weight_decay" in config else 0
self.tokens = config["tokens"] if "tokens" in config else None
self.epochs = config["epochs"] if "epochs" in config else None
# tokens and epochs should not be here. Calculate the number of steps they correspond to somewhere else, then pass that to the BasedOptimizer.
self.beta1 = config["beta1"] if "beta1" in config else 0.9
self.beta2 = config["beta2"] if "beta2" in config else 0.95
self.eps = config["eps"] if "eps" in config else 1e-4
defaults = {
"lr": 6e-4,
"end_lr": 6e-4,
"warmup_steps": 1,
"anneal_steps": 1,
"total_steps": None,
"weight_decay": 0,
"tokens": None,
"epochs": None,
"beta1": 0.9,
"beta2": 0.95,
"eps": 1e-4,
}
for k, v in defaults.items():
setattr(self, k, v)
for k, v in config.items():
setattr(self, k, v)
self.max_lr = False
self.curr_step = 0
self.curr_lr = 0
......
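Regarding the comment above about tokens and epochs not belonging in the optimizer: a hypothetical helper (not part of this commit) that converts a token budget into optimizer steps before constructing BasedOptimizer could look like this.

# Hypothetical helper, not in this commit: turn a token budget into a step
# count outside the optimizer, then pass total_steps/warmup_steps in config.
def steps_from_tokens(total_tokens, batch_size, seq_len, grad_acc=1):
    tokens_per_step = batch_size * grad_acc * seq_len
    return int(total_tokens // tokens_per_step)

# e.g. steps_from_tokens(300e9, batch_size=16, seq_len=2048, grad_acc=2)
# -> roughly 4.58 million optimizer steps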
......@@ -6,30 +6,17 @@ import torch
# Does this work with other block_sizes? It doesn't seem to.
class FbDataset(data.Dataset):
def __init__(self, block_size, map_file, max_samples=None):
self.half_blocks = False
if block_size is not None and int(block_size) < 2048:
self.half_blocks = True
self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, 2048))
self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
self.samples = self.npz.shape[0]
if self.half_blocks:
self.samples *= 2
if not max_samples is None:
if max_samples is not None:
self.samples = min(self.samples, int(max_samples))
self.skip = 0
def __len__(self):
return self.samples
def __getitem__(self, _id):
nth = _id + self.skip
offset = 0
length = 2048
if self.half_blocks:
nth = _id // 2
offset = 1024 * (_id % 2)
length = 1024
data = torch.tensor(self.npz[nth][offset:offset+length].astype(np.int64))
return (data, data)
data = torch.tensor(self.npz[nth].astype(np.int64))
return (data[:-1], data[1:])
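The rewritten __getitem__ returns (data[:-1], data[1:]), i.e. the inputs and the next-token labels shifted by one. As a quick illustration of the alignment, using a toy block rather than the real .map file:

import torch

# Toy example: a 2049-token block yields 2048 inputs and 2048 labels,
# where labels[i] is the token that follows inputs[i].
block = torch.arange(2049)
inputs, labels = block[:-1], block[1:]
assert inputs.shape == (2048,) and labels.shape == (2048,)
assert labels[0] == inputs[1]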
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
......
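The no_init helper mentioned above is elided in this view. One common pattern for it, which may differ from the repo's actual implementation, is to temporarily disable reset_parameters on the usual module types while the model object is built:

import torch

# Sketch of a no_init helper (assumed implementation; the real one is elided
# above): skip PyTorch's default weight init while loading_code() runs.
def no_init(loading_code):
    def dummy(self):
        return
    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    for mod in modules:
        original[mod] = mod.reset_parameters
        mod.reset_parameters = dummy
    try:
        result = loading_code()
    finally:
        for mod in modules:
            mod.reset_parameters = original[mod]
    return result

# Usage, as in the comment above: model = no_init(lambda: load_model(...))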
......@@ -72,7 +72,6 @@ class SplitCheckpoint(MutableMapping):
def get_logits(x, embedding):
return embedding(x)
@torch.jit.script
def gelu_new(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
......@@ -266,8 +265,8 @@ class GPTModel(nn.Module):
x = self.ln_final(x)
return x
def get_logits(self, x, act_ck=False):
x = self.forward(x, act_ck=act_ck)
def get_logits(self, x, hypernetwork=None, act_ck=False):
x = self.forward(x, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.lm_head(x)
return x.float()
......@@ -285,6 +284,22 @@ class GPTModel(nn.Module):
model = cls(**config)
return model
@classmethod
def neox_init(cls, config):
model = cls(**config)
modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
init = small_init_method(config["hidden_dim"])
for module in modules:
for param in module.parameters():
init(param)
last_layer = model.layers[-1]
last_layer_init = wang_init_method(config["n_layer"], config["hidden_dim"])
for param in last_layer.parameters():
last_layer_init(param)
return model
def save(self, path):
try: os.mkdir(path)
except: pass
......@@ -297,6 +312,26 @@ class GPTModel(nn.Module):
# TODO: Do we want the LM head as a separate class, or just a function? A function is probably better here, and maybe
# also for the self-attention: we can just write a function that gets fed the q, k, v.
def wang_init_method(n_layers, dim):
std = 2 / n_layers / math.sqrt(dim)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
# Stolen from NeoX. For the 20B run, wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
"""Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution."""
std = math.sqrt(2 / (5 * dim))
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
def load_gpt_j(path="models/6b", state_dict=None):
config = {
"n_layer": 28,
......
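For the init scales introduced above, plugging in a GPT-J-6B-like config (n_layer=28, hidden_dim=4096, as used elsewhere in this commit) gives a rough feel for the numbers; the check below is just arithmetic, not part of the commit.

import math

# small_init is used on most layers, wang_init only on the last layer.
small_std = math.sqrt(2 / (5 * 4096))   # ~0.00988
wang_std = 2 / 28 / math.sqrt(4096)     # ~0.00112
print(f"small_init std: {small_std:.5f}, wang_init std: {wang_std:.5f}")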
......@@ -7,7 +7,7 @@ dry = False
config_obj = KubeConfig()
config_obj.set_name(name)
config_obj.set_gpu(gpu_name=GPU.RTX_A5000, amount=1)
config_obj.set_gpu(gpu_name=GPU.A100_PCIE_40GB, amount=1)
config_obj.set_ram(16)
config_obj.set_cpu(4)
config_obj.dry_run(dry)
......@@ -23,7 +23,8 @@ env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-li
env1.sh('pip install einops numpy')
env1.sh('pip install tqdm')
env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
env1.sh('pip3 install einops==0.4.1')
env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
with always_rerun():
print(f"Running {sys.argv[1]}")
path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file
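One plausible reason for moving from an RTX A5000 (24 GB) to an A100 40GB, assuming the 6B model used elsewhere in this commit: the fp16 weights alone are on the order of 11 GiB, before activations, gradients for the trainable pieces, and optimizer state.

# Back-of-the-envelope check (assumption, not from the diff): ~6.05e9 params
# at 2 bytes each in fp16.
params = 6.05e9
print(f"{params * 2 / 1024**3:.1f} GiB of fp16 weights")  # ~11.3 GiB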
......@@ -3,60 +3,62 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from lm_train import optimizer, utils
from torch.utils import data
from main import *
import yaml
import sys
from tqdm import tqdm
import time
import wandb
#Based Optimizer
class BasedOptimizer:
def __init__(self, model, config, optimizer):
self.min_lr = config["min_lr"] if "min_lr" in config else 1e-06
self.warmup_end = config["lr"] if "lr" in config else 5e-06
self.warmup_init = config["warmup_init"] if "warmup_init" in config else 0
self.warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 1
self.total_steps = config["total_steps"] if "total_steps" in config else None
self.weight_decay = config["weight_decay"] if "weight_decay" in config else 0
self.start_step = config["start_step"] if "start_step" in config else 0
self.curr_step = self.start_step
self.curr_lr = 0
model_config = {
"n_layer": 12,
"n_head": 12,
"hidden_dim": 768,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTLayer
}
optim_func = optim.AdamW
# We need a batch size of 250 to train the small GPT.
train_config = {
"lr": 6e-4,
"end_lr": 6e-4,
"warmup_steps": 100,
"bs": 16,
"gas": 2,
"seed": 69,
}
bs = train_config["bs"]
gas = train_config["gas"]
model = GPTModel.neox_init(model_config).cuda().bfloat16()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
train_dataset = utils.FbDataset(2049, "sigurd_v5_2049.map")
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
wandb.init(project="basedformer-tests", name="sigurd_v5_2049")
self.optimizers = optim_func(model.parameters(), lr=self.warmup_init, weight_decay=self.weight_decay, betas=config["betas"], eps=config["eps"])
t = tqdm(train_loader)
for input_ids, labels in t:
timex = time.perf_counter()
input_ids = input_ids.cuda()
labels = labels.cuda()
loss = 0
for x in range(train_config["gas"]):
logits = model.get_logits(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=None, act_ck=True)
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :]
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
gas_loss.backward()
loss += gas_loss.item()
def get_current_lr(self):
cosine_lr = self.min_lr + 0.5 * (self.warmup_end - self.min_lr) * (1 + math.cos(math.pi * min(1.0, max(0, self.curr_step - self.warmup_steps) / (self.total_steps - self.warmup_steps))))
target_lr = self.warmup_end if self.curr_step < self.warmup_steps else cosine_lr
return inter(self.warmup_init, target_lr, max(0, self.curr_step - self.start_step) / max(1, self.warmup_steps))
return min(self.end_lr * (self.curr_step / self.warmup_steps), self.end_lr)
def backward(self, loss):
self.optimizers[0].backward(loss, update_master_grads=False)
#loss.backward()
def step(self, scaler=None):
self.curr_lr = self.get_current_lr()
for optimizer in self.optimizers:
for paramx in optimizer.param_groups:
paramx['lr'] = self.curr_lr
optimizer.update_master_grads()
if scaler:
for optimizer in self.optimizers:
scaler.step(optimizer)
else:
optimizer.step()
self.curr_step += 1
def zero_grad(self):
for optimizer in self.optimizers:
optimizer.zero_grad()
def print_info(self):
print(f"min_lr: {str(self.min_lr)}")
print(f"warmup_end: {str(self.warmup_end)}")
print(f"warmup_init: {str(self.warmup_init)}")
print(f"warmup_steps: {str(self.warmup_steps)}")
print(f"start_step: {str(self.start_step)}")
print(f"total_steps: {str(self.total_steps)}")
print(f"weight_decay: {str(self.weight_decay)}")
print(f"step: {str(self.curr_step)}")
print(f"curr_lr: {str(self.get_current_lr())}")
\ No newline at end of file
loss = loss / gas
opt.step()
opt.zero_grad()
sec_per_step = (time.perf_counter() - timex) / (bs*gas)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = step_per_sec * 2048
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
\ No newline at end of file
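A note on the throughput numbers logged at the end of the loop: sec_per_step divides the wall time of one optimizer step by bs*gas (the DataLoader batch holds bs*gas sequences), so it is really seconds per sequence, step_per_sec is sequences per second, and multiplying by the 2048-token context gives tokens per second. A worked example with assumed timings:

# Assumed numbers for illustration only: one optimizer step (bs=16, gas=2,
# ctx=2048) taking 3.2 s of wall time.
bs, gas, ctx = 16, 2, 2048
wall = 3.2
sec_per_sample = wall / (bs * gas)       # 0.1 s per sequence
samples_per_sec = 1.0 / sec_per_sample   # 10 sequences/s
tokens_per_sec = samples_per_sec * ctx   # 20480 tokens/s
print(sec_per_sample, samples_per_sec, tokens_per_sec)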