from basedformer import gptj
from basedformer.utils import * 
from transformers import AutoTokenizer
from icecream import ic
import time
import sys

def print_top_k(logits, tokenizer, k):
    """Print the ``k`` highest-scoring token candidates at every position.

    Args:
        logits: Score tensor indexed as ``[batch, position, vocab]``. It must
            be rank 3, and ``topk`` ranks along its last dimension.
        tokenizer: Object whose ``decode`` turns a list of token ids into text.
        k: Number of candidates to show per position.

    For each batch row, every position prints a ``Token <position>`` header,
    followed by its candidates (highest score first) separated by ``" | "``.
    """
    _, indices = logits.topk(k)
    for per_batch in indices:
        for position, candidates in enumerate(per_batch):
            print(f"\nToken {position}")
            for candidate in candidates.tolist():
                print(tokenizer.decode([candidate]), end=" | ")

def main(
    prompt="""I fucked her with my huge donut, when she seen my donut she went""",
    tokens_to_generate=50,
):
    """Greedily continue ``prompt`` with GPT-J and print each step.

    Steps:

    1. Tokenize the prompt with the GPT-2 tokenizer and echo it token by token.
    2. Load the model from ``gptj.load_gpt_j()`` onto the GPU in half
       precision, logging the load time via ``ic``.
    3. Produce ``tokens_to_generate`` tokens one at a time, always picking the
       highest-scoring next token.
    4. Print each new token as it is produced, then print the decoded
       continuation once generation finishes.

    Args:
        prompt: Text to continue.
        tokens_to_generate: Number of greedy decoding steps to run.
    """
    import torch  # explicit, rather than relying on basedformer.utils' star import

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    prompt_ids = tokenizer.encode(prompt)
    print("Prompt:")
    for token_id in prompt_ids:
        print(tokenizer.decode([token_id]), end=" | ")
    print("\n Generation:")
    tokens = torch.LongTensor(prompt_ids).unsqueeze(0).cuda()

    t = time.perf_counter()
    # Cast to fp16 before moving to the GPU. If load_gpt_j returns a CPU model
    # (assumed - confirm), the GPU then never holds a full fp32 copy.
    model = gptj.load_gpt_j().half().cuda().eval()
    model = model.lm
    ic(time.perf_counter() - t)

    generated = []
    with torch.no_grad():
        kv = None
        in_tokens = tokens
        for _ in range(tokens_to_generate):
            # The first call feeds the whole prompt. Later calls feed only the
            # newest token together with the kv returned by the previous call
            # (presumably the attention cache holding the earlier context).
            logits, kv = model(in_tokens, cache=True, kv=kv)
            # Greedy pick at the last position; shape (batch, 1).
            in_tokens = logits[:, -1, :].topk(1)[1]
            next_id = in_tokens[-1, 0].item()
            generated.append(next_id)
            print(tokenizer.decode([next_id]), end=" | ")

    print("\n Final text:")
    print(tokenizer.decode(generated))


if __name__ == "__main__":
    main()