from basedformer.utils import * 
import basedformer.lm_utils as lmu
from fairseq.models.transformer_lm import TransformerLanguageModel
import time

import torch
from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
import torch.nn.functional as F
from transformers import GPTNeoForCausalLM
from icecream import ic

def _assert_close(label, expected, actual, **tolerances):
    """Raise AssertionError naming ``label`` unless the two tensors match.

    The comparison is ``torch.allclose(expected, actual, **tolerances)``, so
    any ``rtol`` / ``atol`` not passed explicitly keeps torch's default. An
    explicit ``raise`` is used instead of an ``assert`` statement so the check
    still runs under ``python -O``, and so a failure reports which stage
    diverged and by how much.
    """
    if not torch.allclose(expected, actual, **tolerances):
        # Upcast before subtracting so the reported gap is not itself rounded
        # to the precision of the inputs.
        worst = (expected.float() - actual.float()).abs().max().item()
        raise AssertionError(f"{label}: outputs differ (max abs diff {worst:.3e})")


# Parity check: run the Hugging Face GPT-Neo checkpoint and the basedformer
# model side by side on the same random token ids and verify that every
# intermediate stage (embedding, per-layer norms / attention / feed-forward,
# final norm, logits) agrees. Each forward pass is evaluated once and reused
# for both the debug printout and the comparison.
with torch.no_grad():
    model_dir = '/home/xuser/diffusionstorage/models/fairseq/converted/en_dense_lm_125m/'
    hf_model = no_init(lambda: GPTNeoForCausalLM.from_pretrained(model_dir)).cuda().half().eval()
    print("Loaded hf model")
    path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m"
    based_model = lmu.load_from_path(path).cuda().half().eval()
    print("Loaded based model")
    # Random token ids in [0, 51200) with shape (1, 300).
    # NOTE(review): 51200 is presumably the vocabulary size -- confirm against the model config.
    x = torch.randint(0, 51200, (1, 300)).cuda().long()

    # Token embedding. The HF activations are the reference that is carried
    # forward through the manual layer-by-layer walk below.
    hidden = hf_model.transformer.wte(x)
    _assert_close("vocab_embed", hidden, based_model.vocab_embed(x))

    for layer, based_layer in enumerate(based_model.layers):
        ic(layer)
        # Indexed (not zipped) so a layer-count mismatch still raises IndexError.
        hf_layer = hf_model.transformer.h[layer]
        residual = hidden

        #ln_preattn
        normed = hf_layer.ln_1(hidden)
        _assert_close(f"layer {layer} ln_preattn", normed, based_layer.ln_preattn(hidden))
        hidden = normed

        #attn
        # `[0]` takes the first element of whatever each attention module returns.
        attn_out = hf_layer.attn(hidden)[0]
        based_attn_out = based_layer.attn(hidden)[0]
        ic(attn_out.abs().mean())
        ic(based_attn_out.abs().mean())
        ic((attn_out - based_attn_out).abs().mean())
        _assert_close(f"layer {layer} attn", attn_out, based_attn_out, rtol=1e-6)
        hidden = residual + attn_out
        residual = hidden

        #ln_postattn
        normed = hf_layer.ln_2(hidden)
        _assert_close(f"layer {layer} ln_postattn", normed, based_layer.ln_postattn(hidden))
        hidden = normed

        #ffn
        ff_out = hf_layer.mlp(hidden)
        _assert_close(f"layer {layer} ff", ff_out, based_layer.ff(hidden))
        hidden = residual + ff_out

        # Whole-block check. Both blocks are fed the `hidden` computed just
        # above (this layer's manually composed output), not the layer input.
        _assert_close(f"layer {layer} full block", hf_layer(hidden)[0], based_layer(hidden)[0])

    hf_logits = hf_model(x)["logits"]
    based_logits = based_model(x)
    ic(hf_logits.abs().mean())
    ic(based_logits.abs().mean())

    # Final layer norm applied to the manually propagated activations.
    final_normed = hf_model.transformer.ln_f(hidden)
    _assert_close("ln_final", final_normed, based_model.ln_final(hidden))
    hidden = final_normed

    # End-to-end checks through each model's own forward path.
    _assert_close(
        "last_hidden_state",
        hf_model.transformer(x)["last_hidden_state"],
        based_model.get_embeds(x)[0],
    )
    _assert_close("logits", hf_logits, based_logits)