from basedformer import models, utils
import torch
# Model-shape presets for the KV-cache decode benchmark in this script.
# ``config`` selects the preset that is actually built; the other preset is
# kept as a named alternative so switching sizes is a one-line change.
CONFIG_28L = {
    "n_layer": 28,
    "n_head": 16,
    "hidden_dim": 4096,
}

CONFIG_40L = {
    "n_layer": 40,
    "n_head": 40,
    "hidden_dim": 5120,
}

config = CONFIG_40L

# The per-head width is derived later in this script as
# hidden_dim // n_head; reject shapes where that floor division would
# silently truncate.
if config["hidden_dim"] % config["n_head"] != 0:
    raise ValueError(
        f"hidden_dim ({config['hidden_dim']}) must be divisible by "
        f"n_head ({config['n_head']})"
    )

# Same shape with the "q_only" flag added (the original dict is not
# mutated); consumed by the Q-only model variant benchmark.
config_q = {**config, "q_only": True}
# Build the baseline model in fp16 on the GPU and report its parameters.
gpt = models.fairseq.GPTFairModel(config).cuda().half()
utils.print_parameters(gpt)

# Benchmark geometry: `bsz` sequences, each with `cached_seq` tokens of
# existing context.
bsz = 3
cached_seq = 1000
n_layer, n_head = config["n_layer"], config["n_head"]
head_dim = config["hidden_dim"] // n_head

# y: a random prompt of shape (bsz, cached_seq); x: one new token per
# sequence, shape (bsz, 1). Ids are drawn from [0, 50256) (high is exclusive).
y = torch.randint(0, 50256, (bsz, cached_seq)).long().cuda()
x = torch.randint(0, 50256, (bsz, 1)).long().cuda()

# Synthetic cache for the baseline model: one random fp16 tensor of shape
# (bsz, n_head, cached_seq, head_dim) fills both slots of the pair
# (presumably key and value), and the same pair object is shared by every
# layer. The data is random placeholder content.
full_kv = torch.rand(bsz, n_head, cached_seq, head_dim).cuda().half()
cache_f = [(full_kv, full_kv)] * n_layer
print(len(cache_f))
print(cache_f[0][1].shape)

# Same layout for the Q-only variant, but with a head axis of size 1
# (presumably a single shared KV head -- confirm against the q_only model).
single_kv = torch.rand(bsz, 1, cached_seq, head_dim).cuda().half()
cache_q = [(single_kv, single_kv)] * n_layer
print(cache_q[0][0].shape)
with torch.no_grad():
    # One forward pass over the whole prompt with cache=True; out[1] is
    # presumably the per-layer cache the model returns -- only the shape of
    # its first tensor is printed. (Timing this prompt pass is switched off;
    # to enable it: utils.timeit(func=lambda: gpt(y), r=10, n=10).)
    out = gpt(y, cache=True)
    prompt_cache = out[1]
    print(prompt_cache[0][0].shape)

    # Time a single-token decode step of the baseline model against the
    # synthetic cache `cache_f` (not the cache returned above).
    def decode_step():
        return gpt(x, kv=cache_f)

    print("GPT")
    utils.timeit(func=decode_step, r=10, n=10)


# Optional benchmark of the Q-only model variant (built from ``config_q``),
# decoding against the single-head synthetic cache ``cache_q``. It is off by
# default; set the flag to True to run it. Keeping it as real code (rather
# than inside a string literal) means it is still syntax-checked and linted.
RUN_Q_ONLY_BENCHMARK = False

if RUN_Q_ONLY_BENCHMARK:
    # Drop the baseline model reference first, presumably so its GPU memory
    # can be released before the second model is built.
    del gpt
    # Init param matched Q-Only.
    # NOTE(review): the baseline uses models.fairseq.GPTFairModel while this
    # uses models.gptj.GPTJModel -- confirm the two are meant to be compared.
    gpt_q = models.gptj.GPTJModel(config_q).cuda().half()
    utils.print_parameters(gpt_q)

    with torch.no_grad():
        # Timing the prompt pass is switched off; to enable it:
        # utils.timeit(func=lambda: gpt_q(y), r=10, n=10)
        out_q = gpt_q(y, cache=True)
        print("GPT-Q:")
        utils.timeit(func=lambda: gpt_q(x, kv=cache_q), r=10, n=10)