
from basedformer import lm_utils as lmu
from basedformer.models import hypernet
from basedformer import sampling
import os
from pathlib import Path
from basedformer.utils import * 
from transformers import AutoTokenizer
from icecream import ic
import time
import sys
from termcolor import colored

def main(prompt=None, gen_len: int = 1000, bsz: int = 1,
         model_path: str = "pretrained/fairseq_125m", seed: int = 69) -> list:
    """Sample continuations of a prompt on the GPU and print them in colour.

    The function encodes ``prompt`` with the ``gpt2`` tokenizer, prints the
    token ids and the per-token decoding, loads the model found at
    ``model_path`` (moved to CUDA, cast to bfloat16, put in eval mode) and
    calls ``sampling.generate`` on a batch made of ``bsz`` copies of the
    prompt.  Calling ``main()`` with no arguments reproduces the original
    hard-coded experiment.

    Args:
        prompt: Text to continue.  ``None`` selects the built-in default prompt.
        gen_len: Value passed to ``sampling.generate`` as the generation length.
        bsz: Number of identical prompt rows in the batch; must be >= 1.
        model_path: Path handed to ``lmu.load_from_path``.
        seed: Seed given to ``torch.manual_seed`` right before sampling.

    Returns:
        The decoded generations as returned by ``tokenizer.batch_decode``
        (one entry per batch row).

    Raises:
        ValueError: If ``bsz`` is smaller than 1.
    """
    # torch is not imported by name at the top of this file (it is presumably
    # only reachable through the star import from basedformer.utils), so
    # import it explicitly instead of depending on that re-export.
    import torch

    if bsz < 1:
        raise ValueError(f"bsz must be >= 1, got {bsz}")

    if prompt is None:
        prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
        # Alternative prompts kept from earlier experiments (the infilling ones need `mask`):
        #mask = "████████"
        #prompt = """'''Kurumuz''' is the founder of tech company [["""
        #promptnomask = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window,{mask} holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
        #prompt = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved{mask}, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window, holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""

    # Hypernetwork checkpoint directory.  It is only printed at the moment,
    # because the hypernetwork loading further down is commented out.
    save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs4-2e-4-catchup/step_1200"
    #save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-again/step_1200"
    #save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling"
    #cp_list = sorted(os.listdir(save_path), key=lambda x: int(x.split("_")[-1]))
    #last_cp = Path(save_path) / cp_list[-1] if len(cp_list) > 0 else None
    last_cp = Path(save_path)
    print(last_cp)

    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    token_ids = tokenizer.encode(prompt)
    print(token_ids)

    # Show how the prompt was split, one decoded token at a time.
    print("Prompt:")
    for token_id in token_ids:
        print(tokenizer.decode([token_id]), end=" | ")
    print("\n Generation:")

    # Shape (1, seq_len) -> (bsz, seq_len): every batch row is the same prompt.
    tokens = torch.LongTensor(token_ids).unsqueeze(0).cuda().repeat(bsz, 1)

    # Time how long loading the model takes.
    t = time.perf_counter()
    model = lmu.load_from_path(model_path).cuda().bfloat16().eval()
    #hypernetwork = hypernet.HyperNetworkSingle(model.config).cuda().float()
    #print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
    #hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
    ic(time.perf_counter() - t)

    # Sampling options forwarded to sampling.generate; the key names are the
    # ones that function expects.
    rep_pen = {
        "penalty": 5,
    }
    ops = {
        "rep_pen": rep_pen,
        "tfs": 0.86,
        "temp": 0.8,
    }
    # NOTE(review): every batch entry aliases the same dict object; confirm
    # sampling.generate does not mutate per-sequence state inside it.
    ops_list = [ops] * bsz

    torch.manual_seed(seed)
    # Nothing here calls backward, so skip building autograd graphs.  This is
    # harmless if sampling.generate already disables autograd itself.
    with torch.no_grad():
        tokens_generated = sampling.generate(model.forward, tokens, gen_len, ops_list=ops_list, hypernetwork=None, non_deterministic=True)
    #tokens_generated = sampling.generate_greedy(model.forward, tokens, gen_len, hypernetwork=hypernetwork)

    texts = tokenizer.batch_decode(tokens_generated.cpu().numpy())
    separator = "=========================================================="
    for gen in texts:
        print(colored(separator, "red"))
        print(colored(gen, "green"))
        print(colored(separator, "red"))
    return texts

# Script entry point: run the generation demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()