from basedformer import lm_base
from basedformer.utils import * 
import time

import torch
from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    GPTNeoForCausalLM,
    AutoConfig,
)
# Replicates IPython's %timeit magic function.
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    """Time ``func`` like IPython's ``%timeit`` and report the fastest run.

    ``func`` is called ``n`` times per run for ``r`` runs, timing each call with
    ``time.perf_counter_ns``. For every run the mean and standard deviation of
    the per-call times are computed; the run with the lowest mean is reported,
    auto-scaled to ns / μs / ms / s.

    Args:
        func: Zero-argument callable to benchmark.
        r: Number of runs; must be >= 1.
        n: Number of calls per run; must be >= 1.
        quiet: If True, print nothing (the result is still returned).
        function: Optional callable whose ``__name__`` labels the report, e.g.
            the real function when ``func`` is a lambda wrapper around it.
            ``func`` itself is not modified.
        do_tqdm: Wrap the runs in a tqdm progress bar.
        first: If False, discard the first call of every run from the
            statistics (e.g. to exclude warm-up). Ignored when ``n == 1`` so
            at least one sample remains.

    Returns:
        Tuple ``(mean_ns, std_ns)`` of the fastest run, in nanoseconds.

    Raises:
        ValueError: If ``r`` or ``n`` is less than 1.
    """
    if r < 1 or n < 1:
        raise ValueError(f"r and n must both be >= 1 (got r={r}, n={n})")

    # Use a local label instead of overwriting func.__name__: that mutated the
    # caller's callable and raises for objects such as bound methods.
    target = func if function is None else function
    name = getattr(target, "__name__", repr(target))

    stats = np.empty((2, r))  # row 0: per-run mean, row 1: per-run std (ns)
    runs = tqdm(range(r)) if do_tqdm else range(r)
    for i in runs:
        samples = np.empty(n)
        for k in range(n):
            start = perf_counter_ns()
            func()
            samples[k] = perf_counter_ns() - start

        if not first and n > 1:
            # Drop the first call. With a single sample this would leave an
            # empty array (NaN mean), so that sample is kept instead.
            samples = samples[1:]

        stats[0, i] = samples.mean()
        stats[1, i] = samples.std()

    best_mean, best_std = stats[:, np.argmin(stats[0])]

    # Choose a display unit from the best mean (thresholds in ns, largest first).
    if best_mean >= 1e9:
        scale, unit, digits = 1e-9, "s", 4
    elif best_mean >= 1e6:
        scale, unit, digits = 1e-6, "ms", 2
    elif best_mean >= 1e3:
        scale, unit, digits = 1e-3, "μs", 2
    else:
        scale, unit, digits = 1.0, "ns", 0

    if not quiet:
        mean_txt = f"{best_mean * scale:.{digits}f}{unit}"
        std_txt = f"{best_std * scale:.{digits}f}{unit}"
        print(f"{name}: {mean_txt} ± {std_txt} per loop (mean ± std. dev. of {r} runs, {n} loops each)")

    return float(best_mean), float(best_std)


# Parity check: the basedformer GPT-J port against the Hugging Face GPT-J model,
# submodule by submodule and then end to end. Runs at import time (script).

# torch.allclose tolerances. These equal torch's defaults, which are very strict
# for fp16 activations; loosen them if kernels legitimately differ.
RTOL = 1e-5
ATOL = 1e-8


def _check_close(what, expected, actual):
    """Assert ``expected`` and ``actual`` are allclose, naming ``what`` on failure.

    Both tensors are upcast to float32 first because torch.allclose rejects
    mismatched dtypes, and the two models' outputs may not share a dtype
    (e.g. an upcast logits head — TODO confirm per model).
    """
    expected = expected.float()
    actual = actual.float()
    assert torch.allclose(expected, actual, rtol=RTOL, atol=ATOL), (
        f"{what} mismatch: max abs diff {(expected - actual).abs().max().item():.3e}"
    )


def _step(what, hf_fn, based_fn, inp):
    """Run both implementations on ``inp``, check they agree, return the HF output."""
    out = hf_fn(inp)
    _check_close(what, out, based_fn(inp))
    return out


with torch.no_grad():
    based_model = lm_base.load_gpt_j().cuda().half().eval()
    # All comparisons below go through the inner `.lm` module.
    based_model = based_model.lm
    print("Loaded based model")
    # no_init comes from the basedformer.utils star import; presumably it skips
    # weight initialisation while constructing the model — TODO confirm.
    hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
    print("Loaded hf model")
    # Fixed seed so a failing comparison can be reproduced on the same input.
    torch.manual_seed(0)
    x = torch.randint(0, 50256, (1, 2048)).cuda().long()

    hf = hf_model.transformer
    # Each HF output is computed once, compared, then reused as the next input
    # (previously every HF submodule was evaluated twice).
    hidden = _step("wte", hf.wte, based_model.vocab_embed, x)
    # NOTE: components are chained ln_1 -> mlp -> attn, and `hidden` is never
    # replaced by the full block output, so later layers are exercised on chained
    # component outputs rather than on real block outputs.
    for layer in range(len(hf.h)):  # previously hard-coded as 28
        hf_block = hf.h[layer]
        based_block = based_model.layers[layer]
        hidden = _step(f"layer {layer} ln_1", hf_block.ln_1, based_block.ln_preattn, hidden)
        hidden = _step(f"layer {layer} mlp", hf_block.mlp, based_block.ff, hidden)
        # HF attention and blocks return tuples; element 0 is the hidden state.
        hidden = _step(f"layer {layer} attn", lambda h: hf_block.attn(h)[0], based_block.attn, hidden)
        # NOTE(review): some transformers versions may need position_ids when GPT-J
        # attention/blocks are called directly — confirm for the installed version.
        _check_close(f"layer {layer} block", hf_block(hidden)[0], based_block(hidden))

    _check_close("ln_f", hf.ln_f(hidden), based_model.ln_final(hidden))
    _check_close("last_hidden_state", hf(x)["last_hidden_state"], based_model.get_embeds(x))
    _check_close("logits", hf_model(x)["logits"], based_model(x))