from main import *
import time

from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
from basedformer.hypernet import * 
import sys
#replicating timeit magic function of ipython
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    precision = 'ns'
    r_arr = np.empty([2, r]) # [0] = mean, [1] = std
    if function:
        func.__name__ = function.__name__

    for i in tqdm(range(r)) if do_tqdm else range(r):
        n_arr = np.empty(n)
        for k in range(n):
            start = perf_counter_ns()
            func()
            n_arr[k] = perf_counter_ns() - start
        
        if not first:
            # delete the first element from n_arr numpy array
            n_arr = np.delete(n_arr, 0)

        r_arr[0, i] = np.mean(n_arr)
        r_arr[1, i] = np.std(n_arr)
    
    best = r_arr[:, np.argmin(r_arr[0])] # [0] = mean, [1] = std
    #check if best[0] bigger than 1ms in numpy
    if best[0] < 1e3:
        precision = 'ns'

    elif best[0] >= 1e9:
        print('b')
        best[0] = best[0] * 1e-9
        best[1] = best[1] * 1e-9
        precision = 's'

    elif best[0] >= 1e6:
        best[0] = best[0] * 1e-6
        best[1] = best[1] * 1e-6
        precision = 'ms'

    elif best[0] >= 1e3:
        precision = 'μs'
        best[0] = best[0] * 1e-3
        best[1] = best[1] * 1e-3

    if not quiet:
        if precision == 'ns':
            print(f"{func.__name__}: {best[0]:.0f}{precision} ± {best[1]:.0f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
        if precision == 'μs':
            print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
        elif precision == 'ms':
            print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
        elif precision == 's':
            print(f"{func.__name__}: {best[0]:.4f}{precision} ± {best[1]:.4f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")


def test_thing(graph, input):
    """Replay a captured CUDA graph on new data, synchronizing around it.

    ``input`` is copied into the module-level ``static_input`` tensor (the
    buffer the graph was captured with further down in this script) before
    ``graph`` is replayed. The device is synchronized both before and after,
    so a wall-clock timer around this call sees only this replay's work.
    """
    wait_for_gpu = torch.cuda.synchronize
    wait_for_gpu()
    # A graph replays against the memory it was captured with, so new data
    # has to be written into that same buffer rather than passed in directly.
    capture_buffer = static_input
    capture_buffer.copy_(input)
    graph.replay()
    wait_for_gpu()

# Hyperparameters handed to HyperNetworkSingle below; presumably these mirror
# the architecture built by init_6b() — NOTE(review): confirm they stay in sync.
model_config = dict(
    n_layer=28,
    n_head=16,
    hidden_dim=4096,
    vocab_dim=50400,
    eps=1e-5,
    activation=gelu_new,
    Layer=GPTLayer,
)

# Set to True to also run the TorchScript + CUDA-graph benchmark; while False
# the script exits right after the eager measurements (the previous default).
RUN_CUDAGRAPH_BENCHMARK = False

with torch.no_grad():
    model = init_6b().cuda().bfloat16()
    shape = (1, 2048)
    hypernet = HyperNetworkSingle(model_config).cuda()
    x = torch.zeros(shape).cuda().long()

    def run_eager(hypernetwork):
        """Run one eager forward pass and wait for the GPU to finish it.

        CUDA kernels launch asynchronously, so without the synchronize the
        timer would mostly measure launch overhead rather than the forward
        pass, and would not be comparable with the synchronized graph replay.
        """
        model(x, hypernetwork=hypernetwork)
        torch.cuda.synchronize()

    print(shape)
    print("PyTorch Eager")
    # Lambdas keep the reported name ("<lambda>") unchanged.
    timeit(r=1, n=100, func=lambda: run_eager(None), do_tqdm=False, first=False)
    print("PyTorch Eager + Hypernet")
    timeit(r=1, n=100, func=lambda: run_eager(hypernet), do_tqdm=False, first=False)
    if not RUN_CUDAGRAPH_BENCHMARK:
        sys.exit(0)

    print("PyTorch CUDAGraph+JIT+NVFuser")
    with torch.jit.fuser("fuser2"):
        module = torch.jit.trace(model, torch.zeros(shape).long().cuda())
        # optimize_for_inference returns a new (frozen) module instead of
        # optimizing in place, so the result must be kept.
        module = torch.jit.optimize_for_inference(module)
    static_input = torch.randint(0, 50256, shape, device='cuda')
    fake_inputs = [torch.randint(0, 50256, shape, device="cuda") for _ in range(100)]
    real_inputs = [torch.randint(0, 50256, shape, device="cuda") for _ in range(100)]

    # Warm up on a side stream before capture, as PyTorch's CUDA-graph
    # documentation recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for y in fake_inputs:
            static_output = module(y)

    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = module(static_input)

    # Feed a fresh real input on every replay instead of copying static_input
    # onto itself; cycle so any n works.
    import itertools
    next_real_input = itertools.cycle(real_inputs).__next__
    timeit(func=lambda: test_thing(g, next_real_input()), r=1, n=100, do_tqdm=False, first=False)


            
