from basedformer.models import gptj
from basedformer.utils import * 
from basedformer import lm_utils
from transformers import AutoTokenizer
from icecream import ic
import time
import sys

# TODO: Write a streamer for the sampler so that tokens_to_generate can also be decoupled across the batch.
# This is a lot more work, since the forward passes then have to be scheduled. After that we need a batcher
# that watches a request queue and pulls in the next batch items without waiting too long, preferring
# requests whose sequence lengths (and, if possible, generation lengths) are close to each other.

# TODO: make generation work with padded prompts (take the logit just before the padding starts instead of the last logit).

def print_top_k(logits, tokenizer, k):
    """Print the *k* highest-scoring tokens for every position of *logits*.

    *logits* is indexed ``[batch, position, vocab]``. For each batch row and
    position, a ``Token <position>`` header is printed, followed by the
    decoded candidates (highest score first), each terminated by " | ".
    """
    top_indices = logits.topk(k).indices
    for batch_row in top_indices:
        for position, candidates in enumerate(batch_row):
            print("\nToken " + str(position))
            for token_id in candidates.tolist():
                print(tokenizer.decode([token_id]), end=" | ")

def apply_top_k(logits, k):
    """Set every logit outside the top-*k* of its row to ``-inf``.

    The vocabulary is the last dimension; any number of leading dimensions is
    supported. Entries tied with the k-th largest value are kept, so more than
    *k* entries can survive.

    Args:
        logits: Score tensor whose last dimension is the vocabulary.
        k: Number of candidates to keep. ``k`` at or above the vocabulary size
            removes nothing; ``k <= 0`` disables the filter.

    Returns:
        A new tensor (the input is not modified), or *logits* itself when the
        filter removes nothing.
    """
    vocab_size = logits.shape[-1]
    if k <= 0 or k >= vocab_size:
        # topk() raises for k > vocab size, and k == vocab size keeps everything.
        return logits
    # k-th largest value per row, kept as a trailing dim so it broadcasts.
    kth_value = logits.topk(k, dim=-1).values[..., -1:]
    # masked_fill returns a copy, so the caller's tensor (often a view of a
    # larger batch) is left untouched.
    return logits.masked_fill(logits < kth_value, -float("inf"))

def apply_top_p(logits, p):
    """Nucleus (top-p) filter: keep the smallest prefix of most likely tokens whose probability mass exceeds *p*.

    The first token whose cumulative probability crosses *p* is kept, and so
    is the single most likely token. All other entries are set to ``-inf``.

    The result stays in logit space, like the other ``apply_*`` filters: the
    original *logits* are masked rather than replaced by probabilities.
    Callers softmax the result afterwards, and softmaxing probabilities again
    would flatten the kept distribution towards uniform.

    Args:
        logits: Score tensor whose last dimension is the vocabulary.
        p: Cumulative probability threshold.

    Returns:
        A new tensor shaped like *logits*; the input is not modified.
    """
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    sorted_remove = cumulative_probs > p
    # Shift right so the token that first crosses the threshold is kept too.
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    # Map the mask from sorted order back to vocabulary order.
    remove_mask = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove_mask, -float("inf"))

def apply_tfs(logits, tfs):
    """Tail-free sampling filter.

    Sorts the probabilities in descending order and takes the absolute second
    finite difference of that curve. It normalises this to sum to 1, then
    removes the tokens after the point where its cumulative sum exceeds *tfs*.
    The most likely token is always kept and the least likely one is always
    removed. If the second difference sums to 0 (e.g. a flat distribution),
    the normalised values are NaN and only the last token is removed.

    Works on any device. The result stays in logit space: the original
    *logits* are masked to ``-inf``, not converted to probabilities.

    Args:
        logits: Score tensor whose last dimension is the vocabulary.
        tfs: Threshold on the cumulative normalised second difference.

    Returns:
        A new tensor shaped like *logits*; the input is not modified.
    """
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
    first_diff = sorted_probs[..., 1:] - sorted_probs[..., :-1]
    second_diff = (first_diff[..., 1:] - first_diff[..., :-1]).abs()
    second_diff = second_diff / second_diff.sum(dim=-1, keepdim=True)
    cumulative = torch.cumsum(second_diff, dim=-1)

    # The second difference is 2 shorter than the vocabulary; it decides the
    # interior positions, while the two ends are fixed below.
    sorted_remove = torch.zeros_like(sorted_probs, dtype=torch.bool)
    sorted_remove[..., 1:-1] = cumulative > tfs
    # Always remove the last token.
    sorted_remove[..., -1] = True
    # Always keep the first token.
    sorted_remove[..., 0] = False

    remove_mask = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove_mask, -float("inf"))

def apply_typical(logits, mass=0.9):
    """Locally typical sampling filter.

    Ranks tokens by how close their surprisal (``-log p``) is to the
    distribution's entropy. It keeps tokens in that order until their
    cumulative probability reaches *mass*, and sets the rest to ``-inf``.

    Args:
        logits: 2-D score tensor ``(batch, vocab)`` (the gather/scatter below
            index dim 1).
        mass: Probability mass to keep. Values at or above the total mass
            (e.g. ``1.0`` with float rounding, or larger) keep every token.

    Returns:
        A new tensor shaped like *logits*; the input is not modified.
    """
    scores = logits
    log_probs = torch.nn.functional.log_softmax(scores, dim=-1)
    probs = torch.exp(log_probs)
    # nansum: tokens already at -inf give (-inf * 0) = nan, which must not
    # poison the entropy.
    entropy = -(log_probs * probs).nansum(-1, keepdim=True)

    # Sort by distance between each token's surprisal and the entropy.
    shifted_scores = torch.abs((-log_probs) - entropy)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Index of the last token needed to reach `mass`. When the cumulative sum
    # never reaches `mass`, this count equals the vocab size, so clamp it to a
    # valid index; otherwise gather() would go out of range.
    last_ind = (cumulative_probs < mass).sum(dim=1)
    last_ind = last_ind.clamp(max=sorted_scores.shape[-1] - 1)
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

    return scores.masked_fill(indices_to_remove, -float("inf"))

def apply_temp(logits, temperature):
    """Return *logits* divided element-wise by *temperature*.

    A temperature below 1 sharpens the distribution produced by a later
    softmax; a temperature above 1 flattens it.
    """
    scaled = logits / temperature
    return scaled

def rep_pen(input_ids, scores, penalty, m=3.33, penalize_last=250,
            alpha_frequency=None, alpha_presence=None, whitelist=None,
            ):
    """Apply repetition, frequency and presence penalties to next-token scores.

    *scores* are first converted with ``log_softmax``, so the returned values
    are (penalised) log-probabilities.

    Args:
        input_ids: Token history of shape ``(batch, seq)``.
        scores: Next-token logits of shape ``(batch, vocab)``.
        penalty: Repetition penalty. Values below 1.0 are treated as 1.0, which
            disables it. The negative log-probability of every penalised token
            is multiplied by it.
        m: Slope of the penalty ramp. When ``m`` is None no window is used and
            every token in *input_ids* receives the flat *penalty*.
        penalize_last: When ``m`` is set, only the most recent
            ``penalize_last`` tokens are penalised. Within that window the
            factor ramps from 1 (oldest) up to *penalty* (most recent),
            shaped by ``m``.
        alpha_frequency: If truthy, subtract ``count * alpha_frequency`` from
            each token, where ``count`` is its number of occurrences in the
            window (or in the whole history when there is no window).
        alpha_presence: If truthy, subtract it once from every token that
            occurs in that range.
        whitelist: Optional iterable of token ids exempt from every penalty.
            Ids outside ``[0, vocab)`` are ignored.

    Returns:
        A new tensor of shape ``scores.shape``; the inputs are not modified.
    """
    scores = torch.log_softmax(scores, dim=-1)
    penalty = 1.0 if penalty < 1.0 else penalty
    raw_penalty = penalty

    # Effective window. Kept in locals distinct from the `penalize_last` and
    # `whitelist` parameters so the caller's values are actually read.
    window = None
    if m is not None and penalize_last is not None and penalize_last >= 1:
        if penalize_last == 1:
            # A 1-token window holds only the most recent token, which sits at
            # the top of the ramp; this also avoids dividing by zero below.
            ramp = torch.ones(1)
        else:
            ramp = (torch.arange(penalize_last) / (penalize_last - 1)) * 2. - 1
        ramp = (m * ramp) / (1 + torch.abs(ramp) * (m - 1))
        # Map the ramp from [-1, 1] onto factors in [1, raw_penalty].
        penalty = 1 + ((ramp + 1) / 2).unsqueeze(0) * (raw_penalty - 1)
        window = penalize_last

    whitelist_ids = None
    if whitelist is not None:
        valid_ids = [t for t in whitelist if 0 <= t < scores.shape[1]]
        if len(valid_ids) > 0:
            whitelist_ids = torch.tensor(valid_ids).long().sort()[0]
            # One copy of the id list per batch row so gather/scatter work for any batch size.
            whitelist_ids = whitelist_ids.view(1, -1).repeat(scores.shape[0], 1).to(scores.device)

    if whitelist_ids is not None:
        unpenalized = scores.gather(1, whitelist_ids)

    if raw_penalty > 1.0 and input_ids.shape[1] > 0:
        if window is not None:
            recent = min(input_ids.shape[1], window)
            input_ids = input_ids[:, -recent:]
        score = torch.gather(scores, 1, input_ids)

        if window is not None:
            # Align the newest end of the ramp with the newest tokens.
            factor = penalty[:, -recent:].type(score.dtype).to(score.device)
        else:
            factor = penalty
        # Log-probs are <= 0, so multiplying lowers a previously seen token's probability.
        score = torch.where(score < 0, score * factor, score / factor)

        # NOTE: a token repeated inside the window appears several times in
        # input_ids with different ramp factors; scatter_ keeps one of them
        # (which one is unspecified).
        scores.scatter_(1, input_ids, score)

    if alpha_frequency is not None or alpha_presence is not None:
        counts = torch.zeros(scores.shape).long().to(input_ids.device)
        # torch.unique has no per-row mode, so count tokens row by row.
        for i in range(input_ids.shape[0]):
            row = input_ids[i, -window:] if window is not None else input_ids[i]
            token_ids, token_counts = torch.unique(row, sorted=True, return_counts=True, dim=-1)
            counts[i].scatter_(0, token_ids, token_counts)
        if alpha_frequency:
            scores -= counts * alpha_frequency
        if alpha_presence:
            scores[counts > 0] -= alpha_presence

    if whitelist_ids is not None:
        # Restore the whitelisted tokens to their pre-penalty log-probs.
        scores.scatter_(1, whitelist_ids, unpenalized)

    return scores

def func_multinomial(x):
    """Draw one sample index per row of *x*, reseeding torch's global RNG with 69 first.

    Because the seed is reset on every call, repeated calls with the same *x*
    return the same sample. The global RNG state is overwritten as a side effect.
    """
    seed = 69
    torch.manual_seed(seed)
    sample = torch.multinomial(x, num_samples=1)
    return sample

@torch.no_grad()
def generate_greedy(forward, prompt_tokens, tokens_to_generate=50, hypernetwork=None):
    """Greedily decode *tokens_to_generate* tokens after *prompt_tokens*.

    Args:
        forward: Called as ``forward(tokens, cache=True, kv=kv,
            hypernetwork=hypernetwork)``. Must return ``(logits, kv)``, with
            logits indexed ``[batch, position, vocab]``. The prompt is fed
            once; after that only the newest token is fed, together with the
            returned cache.
        prompt_tokens: Token ids of shape ``(batch, seq)``.
        tokens_to_generate: Number of tokens to produce.
        hypernetwork: Passed through to ``forward``.

    Returns:
        LongTensor of shape ``(batch, tokens_to_generate)`` holding only the
        generated tokens.
    """
    batch_size = prompt_tokens.shape[0]
    in_tokens = prompt_tokens
    kv = None
    steps = []
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
        # Uses the logits of the last position. With right-padded prompts that
        # would be a padding position; padding is not handled here.
        next_tokens = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        steps.append(next_tokens)
        in_tokens = next_tokens

    if not steps:
        return torch.zeros(batch_size, 0, dtype=torch.long, device=prompt_tokens.device)
    # Concatenate once instead of growing a tensor every step.
    return torch.cat(steps, dim=-1)

@torch.no_grad()
def generate(forward, prompt_tokens, tokens_to_generate=50, ds=False, ops_list=None, hypernetwork=None, non_deterministic=False, fully_deterministic=False):
    """Sample continuations, applying a separate chain of sampling ops to each row.

    Args:
        forward: Model forward returning ``(logits, kv)`` with logits indexed
            ``[batch, position, vocab]``. With ``ds=True`` it is called as
            ``forward(tokens, past_key_values=kv, use_cache=True)``, otherwise
            as ``forward(tokens, cache=True, kv=kv, hypernetwork=hypernetwork)``.
        prompt_tokens: Token ids of shape ``(batch, seq)``. ``batch`` must equal
            ``len(ops_list)``, or be 1, in which case the single prompt is
            repeated once per ``ops_list`` entry.
        tokens_to_generate: Number of sampling steps.
        ds: Selects the ``past_key_values``/``use_cache`` calling convention.
        ops_list: One dict per output row, mapping an op name ("top_k",
            "top_p", "typical", "temp", "tfs", "rep_pen") to its argument.
            Ops run in dict order. "rep_pen" takes a dict of keyword arguments
            and also receives that row's token history. Defaults to
            ``[{"temp": 0.9}]``.
        hypernetwork: Passed through to ``forward`` (non-``ds`` path only).
        non_deterministic: Reseed torch's global RNG non-deterministically
            (``torch.seed()``) before sampling.
        fully_deterministic: Reseed with 69 before sampling each row, so a
            row's sample does not depend on the other rows in the batch.

    Returns:
        LongTensor of shape ``(len(ops_list), tokens_to_generate)``.

    Raises:
        ValueError: If ``ops_list`` is empty, or if the prompt batch size is
            neither 1 nor ``len(ops_list)``.
    """
    if ops_list is None:
        ops_list = [{"temp": 0.9}]
    num_rows = len(ops_list)
    if num_rows == 0:
        raise ValueError("ops_list must contain at least one entry")
    prompt_batch = prompt_tokens.shape[0]
    if prompt_batch != num_rows:
        if prompt_batch != 1:
            raise ValueError(
                f"prompt batch size {prompt_batch} does not match len(ops_list)={num_rows}"
            )
        # Rows are indexed per ops entry below (logits[i], context[i]), so a
        # shared prompt must be repeated up front. Replicating the kv cache
        # after one forward would avoid recomputing the prompt, but that
        # depends on the model's cache layout.
        prompt_tokens = prompt_tokens.repeat(num_rows, 1)

    in_tokens = prompt_tokens
    # Full token history per row; rep_pen needs it.
    context = prompt_tokens
    kv = None
    sampled = []
    if non_deterministic:
        torch.seed()
    op_map = {
        "top_k": apply_top_k,
        "top_p": apply_top_p,
        "typical": apply_typical,
        "temp": apply_temp,
        "tfs": apply_tfs,
        "rep_pen": rep_pen,
    }

    for _ in range(tokens_to_generate):
        if ds:
            logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
        else:
            logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
        logits = logits[:, -1, :]  # logits of the last position only
        logits = torch.log_softmax(logits, dim=-1)

        batch = []
        for i, ops in enumerate(ops_list):
            item = logits[i, ...].unsqueeze(0)
            ctx = context[i, ...].unsqueeze(0)
            for op, value in ops.items():
                if op == "rep_pen":
                    # rep_pen needs the history and takes its options as kwargs.
                    item = op_map[op](ctx, item, **value)
                else:
                    item = op_map[op](item, value)
            batch.append(item)

        probs = torch.softmax(torch.cat(batch, dim=0), dim=-1)

        if fully_deterministic:
            # Reseed before every row so each row's draw is independent of its
            # position in the batch.
            row_samples = []
            for row in probs.split(1, dim=0):
                torch.manual_seed(69)
                row_samples.append(torch.multinomial(row, 1))
            next_tokens = torch.cat(row_samples, dim=0)
        else:
            next_tokens = torch.multinomial(probs, 1)

        sampled.append(next_tokens)
        context = torch.cat([context, next_tokens], dim=-1)
        in_tokens = next_tokens

    if not sampled:
        return torch.zeros(num_rows, 0, dtype=torch.long, device=prompt_tokens.device)
    return torch.cat(sampled, dim=-1)

def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops_list=None, hypernetwork=None):
    """Sample continuations for a whole batch using one shared chain of sampling ops.

    Unlike ``generate``, the ops run on the full batch at once, so every row
    uses the same settings. "rep_pen" is not in the op table here, so passing
    it raises KeyError.

    Args:
        forward: Called as ``forward(tokens, cache=True, kv=kv,
            hypernetwork=hypernetwork)``. Returns ``(logits, kv)`` with logits
            indexed ``[batch, position, vocab]``.
        prompt_tokens: Token ids of shape ``(batch, seq)``.
        tokens_to_generate: Number of sampling steps.
        ops_list: Dict mapping an op name ("top_k", "top_p", "typical",
            "temp", "tfs") to its argument, applied in dict order. Defaults
            to ``{"temp": 0.9}``.
        hypernetwork: Passed through to ``forward``.

    Returns:
        LongTensor of shape ``(batch, tokens_to_generate)``.

    Note:
        torch's global RNG is reseeded with 69 before every draw. Results are
        therefore reproducible, but every step reuses the same random stream,
        and the RNG state seen by later code is reset.
    """
    if ops_list is None:
        ops_list = {"temp": 0.9}
    with torch.no_grad():
        in_tokens = prompt_tokens
        kv = None
        # Manual toggle: reseed per row instead of once per step.
        fully_deterministic = False
        tokens_generated = []
        op_map = {
            "top_k": apply_top_k,
            "top_p": apply_top_p,
            "typical": apply_typical,
            "temp": apply_temp,
            "tfs": apply_tfs
        }

        for _ in range(tokens_to_generate):
            logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
            logits = logits[:, -1, :]  # logits of the last position only
            logits = torch.log_softmax(logits, dim=-1)

            for op, value in ops_list.items():
                logits = op_map[op](logits, value).float()

            probs = torch.softmax(logits, dim=-1).float()

            if fully_deterministic:
                row_samples = []
                for row in probs.split(1, dim=0):
                    torch.manual_seed(69)
                    row_samples.append(torch.multinomial(row, 1))
                next_tokens = torch.cat(row_samples, dim=0)
            else:
                torch.manual_seed(69)
                next_tokens = torch.multinomial(probs, 1)

            in_tokens = next_tokens
            tokens_generated.append(next_tokens)

        if not tokens_generated:
            # torch.cat([]) would raise; return an empty (batch, 0) result instead.
            return torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long, device=prompt_tokens.device)
        return torch.cat(tokens_generated, dim=-1)


if __name__ == "__main__":
    # Script-only setup. It creates the module globals `model`, `tokenizer`
    # and `long_prompt` that main(), same_cached_prompt_batched() and
    # single_gen() read; importing this module does not create them.
    # Pretrained alternative to the lm_utils.no_init model built below:
    #model = lm_utils.load_from_path("/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gptj-6b").half().cuda().eval()
    # Architecture hyperparameters passed to lm_utils.no_init.
    # NOTE(review): no_init presumably builds the model without loading
    # pretrained weights (suitable for speed benchmarks only) — confirm in lm_utils.
    params = {
        "model_class": "gptj",
        'n_layer': 44,
        'n_head': 64,
        'n_tokens': 2048,
        'hidden_dim': 6144,
    }
    # fp16 weights, moved to the GPU, in eval mode.
    model = lm_utils.no_init(params).half().cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    print_parameters(model)
    # Long benchmark prompt (an excerpt from Dracula). The 4-space indentation
    # of the continuation lines is part of the string and is tokenized too.
    long_prompt = """_3 May. Bistritz._--Left Munich at 8:35 P. M., on 1st May, arriving at
    Vienna early next morning; should have arrived at 6:46, but train was an
    hour late. Buda-Pesth seems a wonderful place, from the glimpse which I
    got of it from the train and the little I could walk through the
    streets. I feared to go very far from the station, as we had arrived
    late and would start as near the correct time as possible. The
    impression I had was that we were leaving the West and entering the
    East; the most western of splendid bridges over the Danube, which is
    here of noble width and depth, took us among the traditions of Turkish
    rule.

    We left in pretty good time, and came after nightfall to Klausenburgh.
    Here I stopped for the night at the Hotel Royale. I had for dinner, or
    rather supper, a chicken done up some way with red pepper, which was
    very good but thirsty. (_Mem._, get recipe for Mina.) I asked the
    waiter, and he said it was called "paprika hendl," and that, as it was a
    national dish, I should be able to get it anywhere along the
    Carpathians. I found my smattering of German very useful here; indeed, I
    don't know how I should be able to get on without it.

    Having had some time at my disposal when in London, I had visited the
    British Museum, and made search among the books and maps in the library
    regarding Transylvania; it had struck me that some foreknowledge of the
    country could hardly fail to have some importance in dealing with a
    nobleman of that country. I find that the district he named is in the
    extreme east of the country, just on the borders of three states,
    Transylvania, Moldavia and Bukovina, in the midst of the Carpathian
    mountains; one of the wildest and least known portions of Europe. I was
    not able to light on any map or work giving the exact locality of the
    Castle Dracula, as there are no maps of this country as yet to compare
    with our own Ordnance Survey maps; but I found that Bistritz, the post
    town named by Count Dracula, is a fairly well-known place. I shall enter
    here some of my notes, as they may refresh my memory when I talk over my
    travels with Mina.

    In the population of Transylvania there are four distinct nationalities:
    Saxons in the South, and mixed with them the Wallachs, who are the
    descendants of the Dacians; Magyars in the West, and Szekelys in the
    East and North. I am going among the latter, who claim to be descended
    from Attila and the Huns. This may be so, for when the Magyars conquered
    the country in the eleventh century they found the Huns settled in it. I
    read that every known superstition in the world is gathered into the
    horseshoe of the Carpathians, as if it were the centre of some sort of
    imaginative whirlpool; if so my stay may be very interesting. (_Mem._, I
    must ask the Count all about them.)

    I did not sleep well, though my bed was comfortable enough, for I had
    all sorts of queer dreams. There was a dog howling all night under my
    window, which may have had something to do with it; or it may have been
    the paprika, for I had to drink up all the water in my carafe, and was
    still thirsty. Towards morning I slept and was wakened by the continuous
    knocking at my door, so I guess I must have been sleeping soundly then.
    I had for breakfast more paprika, and a sort of porridge of maize flour
    which they said was "mamaliga," and egg-plant stuffed with forcemeat, a
    very excellent dish, which they call "impletata." (_Mem._, get recipe
    for this also.) I had to hurry breakfast, for the train started a little
    before eight, or rather it ought to have done so, for after rushing to
    the station at 7:30 I had to sit in the carriage for more than an hour
    before we began to move. It seems to me that the further east you go the
    more unpunctual are the trains. What ought they to be in China?

    All day long we seemed to dawdle through a country which was full of
    beauty of every kind. Sometimes we saw little towns or castles on the
    top of steep hills such as we see in old missals; sometimes we ran by
    rivers and streams which seemed from the wide stony margin on each side
    of them to be subject to great floods. It takes a lot of water, and
    running strong, to sweep the outside edge of a river clear. At every
    station there were groups of people, sometimes crowds, and in all sorts
    of attire. Some of them were just like the peasants at home or those I
    saw coming through France and Germany, with short jackets and round hats
    and home-made trousers; but others were very picturesque. The women
    looked pretty, except when you got near them, but they were very clumsy
    about the waist. They had all full white sleeves of some kind or other,
    and most of them had big belts with a lot of strips of something
    fluttering from them like the dresses in a ballet, but of course there
    were petticoats under them. The strangest figures we saw were the
    Slovaks, who were more barbarian than the rest, with their big cow-boy
    hats, great baggy dirty-white trousers, white linen shirts, and enormous
    heavy leather belts, nearly a foot wide, all studded over with brass
    nails. They wore high boots, with their trousers tucked into them, and
    had long black hair and heavy black moustaches. They are very
    picturesque, but do not look prepossessing. On the stage they would be
    set down at once as some old Oriental band of brigands. They are,
    however, I am told, very harmless and rather wanting in natural
    self-assertion.

    It was on the dark side of twilight when we got to Bistritz, which is a
    very interesting old place. Being practically on the frontier--for the
    Borgo Pass leads from it into Bukovina--it has had a very stormy
    existence, and it certainly shows marks of it. Fifty years ago a series
    of great fires took place, which made terrible havoc on five separate
    occasions. At the very beginning of the seventeenth century it underwent
    a siege of three weeks and lost 13,000 people, the casualties of war
    proper being assisted by famine and disease.

    Count Dracula had directed me to go to the Golden Krone Hotel, which I
    found, to my great delight, to be thoroughly old-fashioned, for of
    course I wanted to see all I could of the ways of the country. I was
    evidently expected, for when I got near the door I faced a
    cheery-looking elderly woman in the usual peasant dress--white
    undergarment with long double apron, front, and back, of coloured stuff
    fitting almost too tight for modesty. When I came close she bowed and
    said, "The Herr Englishman?" "Yes," I said, "Jonathan Harker." She
    smiled, and gave some message to an elderly man in white shirt-sleeves,
    who had followed her to the door. He went, but immediately returned with
    a letter:"""
def main():
    """Benchmark ``generate`` on ``bsz`` copies of ``long_prompt`` and print the outputs.

    Relies on the module globals ``model``, ``tokenizer`` and ``long_prompt``,
    which only the script-level ``if __name__ == "__main__":`` setup creates.
    """
    bsz = 4
    gen_len = 250
    torch.manual_seed(69)
    tokens = tokenizer.encode(long_prompt)
    # (1, seq) prompt, repeated bsz times along the batch dimension.
    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
    tokens = torch.cat([tokens] * bsz, dim=0)

    # Named so it does not shadow the module-level rep_pen() function.
    rep_pen_args = {
        "penalty": 3,
    }

    # Dict order is the order the ops are applied in.
    ops = {
        "rep_pen": rep_pen_args,
        "top_k": 50,
        "temp": 0.8,
    }
    ops_list = [ops] * bsz

    timeit(lambda: generate(model.forward, tokens, gen_len, ops_list=ops_list), n=1, r=1)
    tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
    print(tokens_generated.shape)
    for gen in tokenizer.batch_decode(tokens_generated.cpu().numpy()):
        print(str(gen))
        print("===========================================================")

def same_cached_prompt_batched():
    """Benchmark ``generate`` with a single prompt shared by ``bsz`` sampling configs.

    Relies on the module globals ``model``, ``tokenizer`` and ``long_prompt``,
    which only the script-level ``if __name__ == "__main__":`` setup creates.
    """
    bsz = 2
    gen_len = 1
    torch.manual_seed(69)
    tokens = tokenizer.encode(long_prompt)
    print("\n Generation:")
    # NOTE(review): the prompt keeps batch size 1 while ops_list has bsz
    # entries. generate() must handle a batch-1 prompt shared by several ops
    # entries for this demo to run.
    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()

    # Named so it does not shadow the module-level rep_pen() function.
    rep_pen_args = {
        "penalty": 3,
    }

    # Dict order is the order the ops are applied in.
    ops = {
        "rep_pen": rep_pen_args,
        "top_k": 50,
        "temp": 0.8,
    }
    ops_list = [ops] * bsz

    timeit(lambda: generate(model.forward, tokens, gen_len, ops_list=ops_list), n=5, r=1)
    tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
    print(tokens_generated.shape)
    for gen in tokenizer.batch_decode(tokens_generated.cpu().numpy()):
        print(str(gen))
        print("===========================================================")

def no_cached_prompt_batched():
    """Placeholder for an uncached batched-prompt benchmark; currently does nothing."""
    return None
def single_gen():
    """Benchmark ``generate`` on ``bsz`` copies of ``long_prompt`` and print the outputs.

    Relies on the module globals ``model``, ``tokenizer`` and ``long_prompt``,
    which only the script-level ``if __name__ == "__main__":`` setup creates.
    """
    bsz = 4
    gen_len = 250
    torch.manual_seed(69)
    tokens = tokenizer.encode(long_prompt)
    print("\n Generation:")
    # Repeat the prompt bsz times; previously a literal 4 was hard-coded here
    # and could drift from bsz.
    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
    tokens = torch.cat([tokens] * bsz, dim=0)

    # Named so it does not shadow the module-level rep_pen() function.
    rep_pen_args = {
        "penalty": 3,
    }

    # Dict order is the order the ops are applied in.
    ops = {
        "rep_pen": rep_pen_args,
        "top_k": 50,
        "temp": 0.8,
    }
    ops_list = [ops] * bsz

    timeit(lambda: generate(model.forward, tokens, gen_len, ops_list=ops_list), n=1, r=1)
    tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
    print(tokens_generated.shape)
    for gen in tokenizer.batch_decode(tokens_generated.cpu().numpy()):
        print(str(gen))
        print("===========================================================")
if __name__ == "__main__":
    # Run one demo; swap in main() or single_gen() to try the others.
    same_cached_prompt_batched()