Commit 9a167649 authored by novelailab

fast and good grad checkpoint, gelu torch script

parent a564d1bb
from main import *
import time
import gc
from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
# Replicates IPython's %timeit magic: r runs of n loops each, report the best run.
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    precision = 'ns'
    r_arr = np.empty([2, r])  # [0] = mean, [1] = std
    if function:
        func.__name__ = function.__name__
    for i in tqdm(range(r)) if do_tqdm else range(r):
        n_arr = np.empty(n)
        for k in range(n):
            start = perf_counter_ns()
            func()
            n_arr[k] = perf_counter_ns() - start
        if not first:
            # drop the first (warm-up) measurement
            n_arr = np.delete(n_arr, 0)
        r_arr[0, i] = np.mean(n_arr)
        r_arr[1, i] = np.std(n_arr)

    best = r_arr[:, np.argmin(r_arr[0])]  # [0] = mean, [1] = std

    # pick a display unit based on the best run
    if best[0] < 1e3:
        precision = 'ns'
    elif best[0] >= 1e9:
        best[0] = best[0] * 1e-9
        best[1] = best[1] * 1e-9
        precision = 's'
    elif best[0] >= 1e6:
        best[0] = best[0] * 1e-6
        best[1] = best[1] * 1e-6
        precision = 'ms'
    elif best[0] >= 1e3:
        precision = 'μs'
        best[0] = best[0] * 1e-3
        best[1] = best[1] * 1e-3

    if not quiet:
        if precision == 'ns':
            print(f"{func.__name__}: {best[0]:.0f}{precision} ± {best[1]:.0f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
        elif precision == 'μs':
            print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
        elif precision == 'ms':
            print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
        elif precision == 's':
            print(f"{func.__name__}: {best[0]:.4f}{precision} ± {best[1]:.4f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
def rndinput(shape):
    return torch.randint(0, 50256, shape).long().cuda()

def forward(model, x):
    out = model.get_logits(x, act_ck=False)
    print(out.shape)
    print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
    # dummy loss (logits against themselves) just to exercise the backward pass
    loss = torch.nn.CrossEntropyLoss()(out, out)
    loss.backward()
    model.zero_grad()
    print("torch.cuda.memory_allocated: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
def main():
    model = init_1_3b().cuda().half()
    shape = (1, 2048)
    #print(model(x).shape)
    print("PyTorch Eager")
    timeit(r=1, n=2, func=lambda: forward(model, rndinput(shape)), do_tqdm=False, first=False)

if __name__ == "__main__":
    main()
@@ -60,15 +60,19 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 def rndinput(shape):
     return torch.randint(0, 50256, shape).long().cuda()

-with torch.no_grad():
-    model = init_6b().cuda().half()
-    shape = (1, 1)
+@torch.no_grad()
+def main():
+    model = init_1_3b().cuda().half()
+    shape = (1, 2048)
     x = torch.zeros(shape).cuda().long()
     print(model(x).shape)
     print("PyTorch Eager")
-    timeit(r=1, n=1, func=lambda: model(x), do_tqdm=False, first=True)
+    timeit(r=1, n=10, func=lambda: model(x), do_tqdm=False, first=False)
     with torch.jit.fuser("fuser2"):
         module = torch.jit.trace(model, torch.zeros(shape).long().cuda())
         torch.jit.optimize_for_inference(module)
         print("PyTorch JIT")
-        timeit(r=1, n=1, func=lambda: module(rndinput((1, 1))), do_tqdm=False, first=True)
+        timeit(r=1, n=10, func=lambda: module(rndinput(shape)), do_tqdm=False, first=False)
+
+if __name__ == "__main__":
+    main()
from transformers import GPTNeoForCausalLM, AutoConfig
import torch
from lm_train.utils import *
import math

class GPT:
    def __init__(self, model_dtype="bf16", model_device="cuda"):
        self.config = self.get_config(model_dtype, model_device)
        self.checkpoint = self.get_checkpoint()
        self.model = None
        return

    def get_config(self, model_dtype="bf16", model_device="cuda"):
        print("Using device:", model_device)
        config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
        config.num_layers = 28
        config.attention_layers = ["global"] * config.num_layers
        config.attention_types = [["global"], config.num_layers]
        config.num_heads = 16
        config.hidden_size = 256 * config.num_heads
        config.vocab_size = 50400
        config.rotary = True
        config.rotary_dim = 64
        config.jax = True
        config.model_dtype = model_dtype
        config.model_device = model_device
        if model_dtype == "bf16":
            config.full_bf16 = True
        return config
    def get_checkpoint(self):
        try:
            from collections.abc import MutableMapping
        except ImportError:
            from collections import MutableMapping
        from pathlib import Path

        # Lazily loads sharded checkpoint files listed in m.pt, one tensor file per key.
        class Checkpoint(MutableMapping):
            def __init__(self, chkpt_dir, device="cpu"):
                self.device = device
                self.chkpt_dir = Path(chkpt_dir)
                self.checkpoint = torch.load(str(self.chkpt_dir / Path("m.pt")))

            def __len__(self):
                return len(self.checkpoint)

            def __getitem__(self, key):
                path = self.chkpt_dir / Path(self.checkpoint[key]).name
                return torch.load(str(path), map_location=self.device)

            def __setitem__(self, key, value):
                return

            def __delitem__(self, key):
                return

            def keys(self):
                return self.checkpoint.keys()

            def __iter__(self):
                for key in self.checkpoint:
                    yield (key, self.__getitem__(key))

            def __copy__(self):
                return Checkpoint(self.chkpt_dir, device=self.device)

            def copy(self):
                return Checkpoint(self.chkpt_dir, device=self.device)

        return Checkpoint
    def load_model(self, model_path=None, model_name=None, config=None, checkpoint=None):
        if config is None:
            config = self.config
        if checkpoint is None:
            checkpoint = self.checkpoint
        if model_name is not None:
            model_path = self.assign_path(model_name)

        print("Loading model from: " + model_path)
        model = no_init(lambda: GPTNeoForCausalLM.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=checkpoint(model_path)))
        self.model = model
        return model

    def assign_path(self, model_name):
        if model_name == "gptj":
            return "/home/xuser/models/j6b_ckpt_14001"
        # Raise an error if the model name is not recognized.
        else:
            raise ValueError("Model name not recognized")
    def init_model(self, config=None, method='wang'):
        neox_init = True
        if config is None:
            config = self.config

        model = no_init(lambda: GPTNeoForCausalLM(config))
        if neox_init:
            # small_init on everything except the last block, wang_init on the last block (NeoX scheme).
            modules = [*model.transformer.h[:-1], model.transformer.wte, model.transformer.ln_f]
            init = small_init_method(self.config.hidden_size)
            for module in modules:
                for param in module.parameters():
                    init(param)

            last_layer = model.transformer.h[-1]
            last_layer_init = wang_init_method(self.config.num_layers, self.config.hidden_size)
            for param in last_layer.parameters():
                last_layer_init(param)

        self.model = model
        return model
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        embs=None,
    ):
        if isinstance(self.model, GPTNeoForCausalLM):
            outputs = self.model(input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, embs)
            # outputs: dict(loss, logits, past_key_values, hidden_states, attentions)
            return outputs
    #def init_module()

def wang_init_method(n_layers, dim):
    std = 2 / n_layers / math.sqrt(dim)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# Stolen from NeoX. For the 20B run, wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
    std = math.sqrt(2 / (5 * dim))
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_
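
# Minimal usage sketch (not part of the original file). It assumes the patched
# transformers fork installed by the buildenv script, which honors the extra config
# fields set above (rotary, jax, model_dtype, model_device).
if __name__ == "__main__":
    gpt = GPT(model_dtype="bf16", model_device="cuda")
    model = gpt.init_model().cuda()              # fresh weights with NeoX-style init
    # model = gpt.load_model(model_name="gptj")  # or load the GPT-J checkpoint instead
    input_ids = torch.zeros((1, 16), dtype=torch.long, device="cuda")
    outputs = gpt.forward(input_ids=input_ids)
    print(outputs.logits.shape)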
from torch import optim
import numpy as np

# Based Optimizer: AdamW wrapped with a linear-warmup / cosine-anneal LR schedule.
def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr):
    warmup_percent = np.clip(step, 0, warmup_steps) / warmup_steps
    anneal_percent = np.clip(step - warmup_steps, 0, anneal_steps) / anneal_steps
    # linear warmup up to lr, then cosine anneal from lr down to end_lr
    return lr * warmup_percent - (lr - end_lr) * (1 - np.cos(np.pi * anneal_percent)) / 2
class BasedOptimizer:
    def __init__(self, parameters, config, optimizer):
        self.lr = config["lr"]
        self.end_lr = config["end_lr"] if "end_lr" in config else self.lr
        self.warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 1
        self.anneal_steps = config["anneal_steps"] if "anneal_steps" in config else 1
        self.total_steps = config["total_steps"] if "total_steps" in config else None
        self.weight_decay = config["weight_decay"] if "weight_decay" in config else 0
        self.tokens = config["tokens"] if "tokens" in config else None
        self.epochs = config["epochs"] if "epochs" in config else None
        # TODO: tokens and epochs should not live here; compute the step count elsewhere and pass it to BasedOptimizer.
        self.beta1 = config["beta1"] if "beta1" in config else 0.9
        self.beta2 = config["beta2"] if "beta2" in config else 0.95
        self.eps = config["eps"] if "eps" in config else 1e-4
        self.max_lr = False
        self.curr_step = 0
        self.curr_lr = 0

        if optimizer == "adamw":
            self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)

    def step(self):
        self.optimizer.step()
        self.curr_step = self.curr_step + 1
        self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)

        if not self.max_lr:
            if self.curr_lr == self.end_lr:
                print("max lr reached.")
                self.max_lr = True

        for paramx in self.optimizer.param_groups:
            paramx['lr'] = self.curr_lr

    def zero_grad(self):
        self.optimizer.zero_grad()

    def print_info(self):
        print(f"end_lr: {self.end_lr}")
        print(f"warmup_steps: {self.warmup_steps}")
        print(f"total_steps: {self.total_steps}")
        print(f"weight_decay: {self.weight_decay}")
        print(f"step: {self.curr_step}")
        if self.curr_step != 0:
            print(f"curr_lr: {self.curr_lr}")
from torch.utils import data
from transformers.modeling_utils import no_init_weights
import numpy as np
import torch

# Memory-mapped dataset of pre-tokenized 2048-token blocks.
# Does this work with other block_sizes? Doesn't seem to (only 1024 via half blocks).
class FbDataset(data.Dataset):
    def __init__(self, block_size, map_file, max_samples=None):
        self.half_blocks = False
        if block_size is not None and int(block_size) < 2048:
            self.half_blocks = True

        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, 2048))
        self.samples = self.npz.shape[0]
        if self.half_blocks:
            self.samples *= 2
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = 0

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        nth = _id + self.skip
        offset = 0
        length = 2048
        if self.half_blocks:
            nth = _id // 2
            offset = 1024 * (_id % 2)
            length = 1024
        data = torch.tensor(self.npz[nth][offset:offset+length].astype(np.int64))
        return (data, data)
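
# Minimal usage sketch (the map-file path below is a placeholder, not a real dataset).
if __name__ == "__main__":
    dataset = FbDataset(block_size=2048, map_file="tokens.map")
    loader = data.DataLoader(dataset, batch_size=4, shuffle=True)
    xs, ys = next(iter(loader))
    print(xs.shape, xs.dtype)  # torch.Size([4, 2048]) torch.int64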
# Make loading models faster by not letting PyTorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
def no_init(loading_code):
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    for mod in modules:
        original[mod] = mod.reset_parameters
        mod.reset_parameters = dummy

    with no_init_weights():
        result = loading_code()

    for mod in modules:
        mod.reset_parameters = original[mod]

    return result

# Count the parameters of a given PyTorch model.
def count_parameters(model, only_trainable=False):
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not only_trainable)
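
# Quick check sketch for the helpers above (a plain Linear layer is enough to see the effect).
if __name__ == "__main__":
    lin = no_init(lambda: torch.nn.Linear(1024, 1024))  # weights left uninitialized
    print(count_parameters(lin))                        # 1024*1024 + 1024 = 1049600
    print(count_parameters(lin, only_trainable=True))   # same here: all params require grad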
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint as ck
 from einops import rearrange, repeat

 try:
     from collections.abc import MutableMapping
@@ -71,6 +72,7 @@ class SplitCheckpoint(MutableMapping):
 def get_logits(x, embedding):
     return embedding(x)

+@torch.jit.script
 def gelu_new(x):
     return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
@@ -84,7 +86,7 @@ def fixed_pos_embedding(dim=None, seq_len=None, x=None):
 def rotate_every_two(x):
     x1 = x[:, :, :, ::2]
     x2 = x[:, :, :, 1::2]
-    x = torch.stack((-x2, x1), axis=-1)
+    x = torch.stack((-x2, x1), dim=-1)
     return rearrange(x, '... d j -> ... (d j)')

 def apply_rotary_pos_emb(x, sincos, offset=0):
@@ -129,7 +131,7 @@ def _attn(query, key, value, causal_mask, masked_bias,
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    attn_weights = nn.Softmax(dim=-1)(attn_weights)
+    attn_weights = F.softmax(attn_weights, dim=-1)
     attn_weights = attn_weights.to(value.dtype)
     attn_output = torch.matmul(attn_weights, value).to(value.dtype)
@@ -202,32 +204,41 @@ class SelfAttention(nn.Module):
         return x

 class FeedForward(nn.Module):
-    def __init__(self, dim=768, hidden_dim=768*4, activation=nn.GELU, device="cuda", dtype=torch.float16):
+    def __init__(self, dim=768, hidden_dim=768*4, activation=nn.GELU(), device="cuda", dtype=torch.float16):
         super(FeedForward, self).__init__()
         self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
         self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
         self.activation = activation

-    def forward(self, x):
+    def forward(self, x, act_ck=False):
         x = self.ff1(x)
-        x = self.activation(x)
+        if act_ck:
+            ck(self.activation, x)
+        else:
+            x = self.activation(x)
         x = self.ff2(x)
         return x

 class GPTLayer(nn.Module):
-    def __init__(self, attn=SelfAttention, ff=FeedForward, hidden_dim=768, n_head=4, eps=1e-6, activation=nn.GELU, device="cuda", dtype=torch.float16):
+    def __init__(self, attn=SelfAttention, ff=FeedForward, hidden_dim=768, n_head=4, eps=1e-6, activation=nn.GELU(), device="cuda", dtype=torch.float16):
         super(GPTLayer, self).__init__()
         self.hidden_dim = hidden_dim
         self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
         self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
         self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)

-    def forward(self, x, hypernetwork):
+    def forward(self, x, hypernetwork=None, act_ck=False):
         residual = x
-        x = self.ln_preattn(x)
-        attn_out = self.attn(x)
-        ff_out = self.ff(x)
-        x = residual + ff_out + attn_out
+        if act_ck:
+            x = ck(self.ln_preattn, x)
+            attn_out = ck(self.attn, x)
+        else:
+            x = self.ln_preattn(x)
+            attn_out = self.attn(x)
+
+        ff_out = self.ff(x, act_ck)
+        x = residual + attn_out + ff_out

         if hypernetwork:
             hyper_out = hypernetwork(x)
             x = x + hyper_out
@@ -248,16 +259,15 @@ class GPTModel(nn.Module):
     #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
     #TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
-    def forward(self, x, hypernetwork=None):
+    def forward(self, x, hypernetwork=None, act_ck=False):
         x = self.vocab_embed(x)
         for layer in self.layers:
-            x = layer(x, hypernetwork)
+            x = layer(x, hypernetwork, act_ck)
         x = self.ln_final(x)
         return x

-    def get_logits(self, x):
-        x = self.forward(x)
+    def get_logits(self, x, act_ck=False):
+        x = self.forward(x, act_ck=act_ck)
         x = self.lm_head(x)
         return x.float()
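
The `@torch.jit.script` change above only wraps the tanh-approximation GELU so TorchScript can compile it; a small self-contained sketch (not part of the commit) confirming that scripting leaves the numerics unchanged:

import math
import torch

@torch.jit.script
def gelu_new(x):
    # tanh-approximation GELU, same formula as in the diff above
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x = torch.randn(4, 8)
eager = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
print(torch.allclose(gelu_new(x), eager))  # True: scripting does not change the result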
@@ -23,6 +23,7 @@ env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-li
 env1.sh('pip install einops numpy')
 env1.sh('pip install tqdm')
 env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
+env1.sh('pip3 install einops==0.4.1')

 with always_rerun():
     print(f"Running {sys.argv[1]}")
     path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file