from basedformer import models, utils
import torch
# GPT-J configuration for the speed comparison below. Only these three keys are
# set here; anything else GPTJModel needs presumably comes from its own
# defaults (not visible in this file) — TODO confirm.
config = dict(
    n_layer=6,
    n_head=8,
    hidden_dim=4096,
)

# Build the GPT-J model, move it to the GPU in fp32, and report its parameters
# so they can be compared against the LSTM baseline.
gpt = models.gptj.GPTJModel(config)
gpt = gpt.cuda().float()
utils.print_parameters(gpt)

#init param matched LSTM
lstm = torch.nn.LSTM(batch_first=True, input_size=4096, hidden_size=4096, num_layers=12).cuda().float()
utils.print_parameters(lstm)

def _run_synced(model, inp):
    """Run one forward pass of ``model`` on ``inp`` and wait for the GPU to finish."""
    out = model(inp)
    # CUDA kernels launch asynchronously; without a sync the timer may only see
    # kernel-launch overhead. Harmless if utils.timeit already synchronizes.
    torch.cuda.synchronize()
    return out


# Single-step inputs: one token id for GPT, one 4096-dim feature vector for the
# LSTM (batch_first=True, so shape is (batch=1, seq=1, features=4096)).
x = torch.randint(0, 50256, (1, 1)).long().cuda()
y = torch.rand(1, 1, 4096).cuda().float()

# Inference benchmark: no_grad disables autograd, but dropout and similar
# layers only switch off in eval mode.
gpt.eval()
lstm.eval()

with torch.no_grad():
    print("GPT:")
    utils.timeit(func=lambda: _run_synced(gpt, x), r=10, n=10)
    print("LSTM:")
    utils.timeit(func=lambda: _run_synced(lstm, y), r=10, n=10)