Commit 9223ef70 authored by novelailab's avatar novelailab

push

parent a1556cf7
...@@ -83,7 +83,7 @@ class SelfAttention(nn.Module): ...@@ -83,7 +83,7 @@ class SelfAttention(nn.Module):
offset = kv[0].shape[-2] offset = kv[0].shape[-2]
else: else:
offset = 0 offset = 0
if self.rotary_dim < self.head_dim: if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim] k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:] k_pass = key[:, :, :, self.rotary_dim:]
......
...@@ -10,6 +10,8 @@ import math ...@@ -10,6 +10,8 @@ import math
from torch.utils import data from torch.utils import data
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm
import time
# Does this work with other block_sizes? doesn't seem to. # Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset): class FbDataset(data.Dataset):
...@@ -92,5 +94,57 @@ class SplitCheckpoint(MutableMapping): ...@@ -92,5 +94,57 @@ class SplitCheckpoint(MutableMapping):
def copy(self): def copy(self):
return SplitCheckpoint(self.chkpt_dir, device=self.device) return SplitCheckpoint(self.chkpt_dir, device=self.device)
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=False):
precision = 'ns'
r_arr = np.empty([2, r]) # [0] = mean, [1] = std
if function:
func.__name__ = function.__name__
for i in tqdm(range(r)) if do_tqdm else range(r):
n_arr = np.empty(n)
for k in range(n):
start = time.perf_counter_ns()
torch.cuda.synchronize()
func()
torch.cuda.synchronize()
n_arr[k] = time.perf_counter_ns() - start
if not first:
# delete the first element from n_arr numpy array
n_arr = np.delete(n_arr, 0)
r_arr[0, i] = np.mean(n_arr)
r_arr[1, i] = np.std(n_arr)
best = r_arr[:, np.argmin(r_arr[0])] # [0] = mean, [1] = std
#check if best[0] bigger than 1ms in numpy
if best[0] < 1e3:
precision = 'ns'
elif best[0] >= 1e9:
best[0] = best[0] * 1e-9
best[1] = best[1] * 1e-9
precision = 's'
elif best[0] >= 1e6:
best[0] = best[0] * 1e-6
best[1] = best[1] * 1e-6
precision = 'ms'
elif best[0] >= 1e3:
precision = 'μs'
best[0] = best[0] * 1e-3
best[1] = best[1] * 1e-3
if not quiet:
if precision == 'ns':
print(f"{func.__name__}: {best[0]:.0f}{precision} ± {best[1]:.0f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
if precision == 'μs':
print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
elif precision == 'ms':
print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
elif precision == 's':
print(f"{func.__name__}: {best[0]:.4f}{precision} ± {best[1]:.4f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
def gelu_new(x): def gelu_new(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
\ No newline at end of file
...@@ -13,7 +13,7 @@ bash = False ...@@ -13,7 +13,7 @@ bash = False
config_obj = KubeConfig() config_obj = KubeConfig()
config_obj.set_name(name) config_obj.set_name(name)
config_obj.set_gpu(gpu_name=GPU.A100_PCIE_40GB, amount=1) config_obj.set_gpu(gpu_name=GPU.RTX_A6000, amount=1)
config_obj.set_ram(16) config_obj.set_ram(16)
config_obj.set_cpu(4) config_obj.set_cpu(4)
config_obj.dry_run(dry) config_obj.dry_run(dry)
...@@ -32,6 +32,9 @@ env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/tra ...@@ -32,6 +32,9 @@ env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/tra
env1.sh('pip3 install einops==0.4.1 pyyaml wandb') env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4') env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
env1.sh('pip3 install dotmap icecream') env1.sh('pip3 install dotmap icecream')
path.sh("pip3 install --editable .")
#path.sh("pip3 uninstall torch")
#path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")
with always_rerun(): with always_rerun():
if bash: if bash:
path.sh("bash") path.sh("bash")
......
import torch
import functorch
def example_func(x, config):
print(x)
print(config)
return nested_test(x)
def nested_test(x):
x = x * 2
return x
config = {0: "huh", 1: "fuck"}
x = torch.randn(2, 50400).cuda().float()
print(x)
vectorized = functorch.vmap(func=example_func, in_dims=(0, None), out_dims=0)
y = vectorized(x, config)
print(y)
print(y.shape)
\ No newline at end of file
import torch
import time
from basedformer.utils import *
x = torch.randn(1, 2048, 4096).cuda().float()
attn_weights = torch.matmul(x, x.transpose(-1, -2))
print(attn_weights.shape)
attn_weights = torch.randn(1, 1, 50400).cuda().float()
timeit(lambda: torch.log_softmax(attn_weights, dim=-1), n=100, cuda_blocking=True)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment