Commit d57cfcec authored by novelailab

good benchmark. pepe sad

parent 7a721f81
@@ -5,11 +5,6 @@ from time import perf_counter, perf_counter_ns
 import numpy as np
 from tqdm import tqdm
 from contextlib import contextmanager
-from transformers import (
-    AutoModelForCausalLM,
-    GPTNeoForCausalLM,
-    AutoConfig,
-)
 #replicating timeit magic function of ipython
 def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
     precision = 'ns'
@@ -64,13 +59,12 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 with torch.no_grad():
-    model = load_gpt_j().cuda().half().eval()
+    model = load_gpt_j().cuda().half()
     x = torch.zeros(1, 2048).cuda().long()
-    our = model(x)
-    print(our.shape)
-    del model
-    model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
-    hf = model(x, output_hidden_states=True)["hidden_states"][-1]
-    print(our[0, 2047, 1000:1020])
-    print(hf[0, 2047, 1000:1020])
-    print(hf.shape)
\ No newline at end of file
+    print(model(x).shape)
+    print("PyTorch Eager")
+    timeit(r=1, n=100, func=lambda: model(x), do_tqdm=False, first=False)
+    module = torch.jit.trace(model, torch.zeros((1, 2048)).long().cuda())
+    torch.jit.optimize_for_inference(module)
+    print("PyTorch JIT")
+    timeit(r=1, n=100, func=lambda: module(x), do_tqdm=False, first=False)
\ No newline at end of file
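For reference, a self-contained sketch of the eager-vs-JIT benchmark pattern this commit introduces. `ToyModel` and the stripped-down `timeit` below are hypothetical stand-ins for `load_gpt_j` and the repo's fuller `timeit` helper. One detail worth flagging: `torch.jit.optimize_for_inference` returns the optimized module rather than mutating its argument in place, so the sketch reassigns the result (the committed script calls it for side effects only, which would leave the trace unoptimized).

import torch
from time import perf_counter

# Hypothetical stand-in for load_gpt_j(): any traceable nn.Module works for
# demonstrating the eager-vs-JIT comparison; GPT-J-6B itself is not reproduced here.
class ToyModel(torch.nn.Module):
    def __init__(self, d=1024):
        super().__init__()
        self.proj = torch.nn.Linear(d, d)

    def forward(self, x):
        return self.proj(x)

def timeit(func, r=1, n=5):
    # Minimal best-of-r, n-loops-per-repeat timer in the spirit of IPython's
    # %timeit; the repo's version adds tqdm progress and result formatting.
    func()  # warmup call, excluded from timing
    best = float("inf")
    for _ in range(r):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # CUDA kernels launch asynchronously
        start = perf_counter()
        for _ in range(n):
            func()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        best = min(best, (perf_counter() - start) / n)
    return best

device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.no_grad():
    model = ToyModel().to(device).eval()
    x = torch.zeros(1, 1024, device=device)

    print("PyTorch Eager:", timeit(lambda: model(x), r=1, n=100))

    # Trace to TorchScript, then reassign the optimized module.
    module = torch.jit.trace(model, x)
    module = torch.jit.optimize_for_inference(module)
    print("PyTorch JIT:", timeit(lambda: module(x), r=1, n=100))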
 from main import *
-state_dict = SplitCheckpoint("/home/xuser/models/j6b_ckpt_14001", device="cpu")
+state_dict = SplitCheckpoint("j6b_vanilla", device="cpu")
 # ORIGINAL
@@ -62,4 +62,4 @@ def save(state_dict, path):
         torch.save(x[1], f"{path}/b{i}.pt")
     torch.save(checkpoint, f"{path}/m.pt")
-save(new_state_dict, "models/6b")
\ No newline at end of file
+save(new_state_dict, "models/6b_vanilla")
\ No newline at end of file
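For context, a rough sketch of the split-checkpoint layout implied by the `save` hunk above: one `b{i}.pt` file per tensor plus an `m.pt` manifest. `save_split` and `load_split` are hypothetical names, and the repo's `SplitCheckpoint` presumably loads tensors lazily on access rather than eagerly as the plain dict here does.

import os
import torch

def save_split(state_dict, path):
    # Write each tensor to its own file, and record a name -> file manifest
    # in m.pt so tensors can later be loaded one at a time.
    os.makedirs(path, exist_ok=True)
    manifest = {}
    for i, (name, tensor) in enumerate(state_dict.items()):
        torch.save(tensor, f"{path}/b{i}.pt")
        manifest[name] = f"{path}/b{i}.pt"
    torch.save(manifest, f"{path}/m.pt")

def load_split(path, device="cpu"):
    # Eager counterpart loader; a lazy Mapping over the manifest would cap
    # peak host memory when materializing a multi-GB checkpoint.
    manifest = torch.load(f"{path}/m.pt")
    return {name: torch.load(f, map_location=device) for name, f in manifest.items()}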