from typing import KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer import lm_base

def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    """Build the sine and cosine tables used for rotary position embeddings.

    Args:
        dim: Number of rotary feature dimensions; one frequency is produced
            for every index in ``range(0, dim, 2)``.
        seq_len: Number of positions (rows) in the returned tables.
        x: Optional tensor whose dtype and device are used while computing
            the inverse frequencies. An empty default tensor is used when
            omitted.

    Returns:
        A ``(sin, cos)`` tuple of float32 tensors with one row per position
        and one column per frequency.
    """
    reference = torch.empty(0) if x is None else x

    exponents = torch.arange(0, dim, 2) / dim
    # The power is cast to the reference dtype/device before the reciprocal is taken.
    denominators = (10000 ** exponents).to(reference.dtype).to(reference.device)
    inv_freq = 1. / denominators

    positions = torch.arange(seq_len).to(reference.device)
    # Outer product: angles[p, f] = positions[p] * inv_freq[f].
    angles = torch.einsum('i , j -> i j', positions, inv_freq).float()
    return torch.sin(angles), torch.cos(angles)

def rotate_every_two(x):
    """Rotate each adjacent pair of features along the last dimension.

    Every pair ``(a, b)`` taken from positions ``(2i, 2i + 1)`` of the last
    axis is mapped to ``(-b, a)``. All leading dimensions are left untouched,
    so the input may have any rank (the rotary code in this file passes 4-D
    tensors).

    Args:
        x: Tensor whose last dimension has even length.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    even = x[..., ::2]
    odd = x[..., 1::2]
    # Stack the pair members on a new trailing axis, then merge that axis back
    # into the feature axis so the result is interleaved as (-b0, a0, -b1, a1, ...).
    paired = torch.stack((-odd, even), dim=-1)
    return paired.flatten(-2)

def apply_rotary_pos_emb(x, sincos, offset=0):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: Tensor whose axis 1 is the sequence axis and whose last axis holds
            the features to rotate (the tables are broadcast as
            ``(1, n, 1, features)``).
        sincos: Pair ``(sin, cos)`` of 2-D tables indexed by position.
        offset: Index of the first table row to use, i.e. the position of
            the first element of ``x`` along its sequence axis.

    Returns:
        ``x * cos + rotate_every_two(x) * sin`` with the tables broadcast
        over the batch and head axes.
    """
    seq_len = x.shape[1]

    def _expand(table):
        # Take the rows for this window of positions and duplicate every
        # frequency so that both members of a feature pair share it.
        window = table[offset:seq_len + offset, :]
        return repeat(window, "n d -> () n () (d j)", j=2)

    sin, cos = (_expand(table) for table in sincos)
    return (x * cos) + (rotate_every_two(x) * sin)

def _attn(query, key, value, causal_mask, masked_bias,
            attention_mask=None, scale_attn=None):

    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)

    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class SelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Rotary embeddings are applied to the first ``head_dim // 4`` features of
    every query/key head; the remaining features pass through unchanged.
    A key/value cache can be supplied and returned for incremental decoding.

    Code copied from HF, might want to sanity check later.
    """

    def __init__(self, hidden_dim, n_head, device, dtype):
        """Create the projections and the fixed mask / rotary buffers.

        Args:
            hidden_dim: Model width; split evenly across ``n_head`` heads.
            n_head: Number of attention heads.
            device: Device passed to the linear projections.
            dtype: Dtype passed to the linear projections.
        """
        super().__init__()
        max_positions = 2049

        self.head_dim = hidden_dim // n_head
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = hidden_dim
        self.n_head = n_head

        head_dim_tensor = torch.tensor(self.head_dim, requires_grad=False)
        self.register_buffer("scale_attn", torch.sqrt(head_dim_tensor.float()))

        # Lower-triangular boolean mask covering every supported position.
        lower_triangle = torch.tril(
            torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)
        )
        self.register_buffer("bias", lower_triangle.view(1, 1, max_positions, max_positions).bool())
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  # -1e10 is what mtj uses.

        # Bias-free projections, created in the order k, v, q, out.
        for proj_name in ("k_proj", "v_proj", "q_proj", "out_proj"):
            projection = nn.Linear(self.hidden_dim, self.hidden_dim, bias=False, device=device, dtype=dtype)
            setattr(self, proj_name, projection)

        sin_table, cos_table = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
        self.register_buffer("sin", sin_table)
        self.register_buffer("cos", cos_table)

    def forward(self, x, kv=None, cache=False):
        """Run attention over ``x``.

        Args:
            x: Input of shape ``[batch, sequence, hidden_dim]``.
            kv: Optional ``(key, value)`` pair from earlier steps, laid out as
                ``[batch, head, sequence, head_dim]``; the new keys/values are
                appended along the sequence axis.
            cache: When true, the full key/value pair is returned as well.

        Returns:
            ``(output, (key, value))`` if ``cache`` is true, else ``(output, None)``.
        """
        batch, seq, hidden = x.shape
        per_head = (batch, seq, self.n_head, self.head_dim)

        # q and k stay as [b, s, h, h_d] until after the rotary step, which
        # expects that layout; v can move to [b, h, s, h_d] immediately.
        query = self.q_proj(x).view(*per_head)
        key = self.k_proj(x).view(*per_head)
        value = self.v_proj(x).view(*per_head).transpose(1, 2)

        # Cached tokens shift the positions of the new ones.
        offset = kv[0].shape[-2] if kv else 0
        tables = (self.sin, self.cos)

        if self.rotary_dim < self.head_dim:
            rot = self.rotary_dim
            rotated_key = apply_rotary_pos_emb(key[:, :, :, :rot], tables, offset=offset).to(key.dtype)
            rotated_query = apply_rotary_pos_emb(query[:, :, :, :rot], tables, offset=offset).to(query.dtype)
            key = torch.cat([rotated_key, key[:, :, :, rot:]], dim=-1)
            query = torch.cat([rotated_query, query[:, :, :, rot:]], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, tables, offset=offset).to(key.dtype)
            query = apply_rotary_pos_emb(query, tables, offset=offset).to(query.dtype)

        query = query.transpose(1, 2)
        key = key.transpose(1, 2)

        if kv:
            # Prepend the cached keys/values so the query can attend to the
            # whole sequence, not just the newly added tokens.
            cached_key, cached_value = kv
            key = torch.cat([cached_key, key], dim=-2)
            value = torch.cat([cached_value, value], dim=-2)

        query_length = query.size(-2)
        key_length = key.size(-2)
        # Causal mask with generation in mind: rows for the new query
        # positions, columns for every key position.
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        attended = _attn(query, key, value, causal_mask, self.masked_bias, None, self.scale_attn)

        merged = attended.transpose(1, 2).contiguous().view(batch, seq, hidden)
        output = self.out_proj(merged)

        return (output, (key, value)) if cache else (output, None)

class FeedForward(nn.Module):
    """Two-layer feed-forward block: linear, activation, linear."""

    def __init__(self, dim, hidden_dim, activation, device, dtype):
        """Create the two linear layers.

        Args:
            dim: Input and output width.
            hidden_dim: Width between the two linear layers.
            activation: Callable applied between the two linear layers.
            device: Device passed to both linear layers.
            dtype: Dtype passed to both linear layers.
        """
        super().__init__()
        linear_kwargs = {"device": device, "dtype": dtype}
        self.ff1 = nn.Linear(dim, hidden_dim, **linear_kwargs)
        self.ff2 = nn.Linear(hidden_dim, dim, **linear_kwargs)
        self.activation = activation

    def forward(self, x, act_ck=False):
        """Apply the block to ``x``; checkpoint the activation when ``act_ck`` is true."""
        hidden = self.ff1(x)
        hidden = ck(self.activation, hidden) if act_ck else self.activation(hidden)
        return self.ff2(hidden)

class GPTJLayer(nn.Module):
    """One GPT-J block: a single pre-LayerNorm feeding attention and feed-forward in parallel.

    The normalised input goes to both the attention and the feed-forward
    module; their outputs are summed with the un-normalised input. An optional
    hypernetwork output can be added on every ``every_n``-th layer.
    """

    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        """Build the sub-modules.

        Args:
            attn: Callable building the attention module; called with
                ``hidden_dim``, ``n_head``, ``device`` and ``dtype`` keywords.
            ff: Callable building the feed-forward module; called with
                ``dim``, ``hidden_dim`` (4x the model width), ``activation``,
                ``device`` and ``dtype`` keywords.
            hidden_dim: Model width.
            n_head: Number of attention heads.
            eps: LayerNorm epsilon.
            activation: Activation handed to the feed-forward module.
            device: Device for the created modules.
            dtype: Dtype for the created modules.
        """
        nn.Module.__init__(self)
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        # Toggles between hypernetwork[0] and hypernetwork[1] in interleaving mode.
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
        """Apply the block.

        Args:
            x: Input hidden states.
            layer_id: Index of this layer; required whenever ``hypernetwork``
                is given, because it selects the layers that use it.
            hypernetwork: Optional module (or indexable collection of modules
                when ``diff_hypernets`` is true) applied to the normalised input.
            act_ck: Use activation checkpointing for the LayerNorm, the
                attention module and the feed-forward activation.
            diff_hypernets: Treat ``hypernetwork`` as a collection and pick
                one entry per eligible layer.
            interleaving_layers: With ``diff_hypernets``, alternate between
                entries 0 and 1 on successive calls instead of indexing by layer.
            every_n: Only layers whose ``layer_id`` is a multiple of this
                value use the hypernetwork.
            cache: Forwarded to the attention module.
            kv: Cached key/value pair forwarded to the attention module.

        Returns:
            ``(hidden_states, kv)`` where ``kv`` is whatever the attention
            module returned as its second output.
        """
        residual = x
        if act_ck:
            x = ck(self.ln_preattn, x)

            # The reentrant checkpoint implementation rejects keyword
            # arguments for the wrapped function ("Unexpected keyword
            # arguments"), so only the tensor is passed through it and the
            # remaining options are bound in a closure.
            def _attend(normed):
                return self.attn(normed, kv=kv, cache=cache)

            attn_out, kv = ck(_attend, x)
        else:
            x = self.ln_preattn(x)
            attn_out, kv = self.attn(x, kv=kv, cache=cache)

        # The hypernetwork only contributes on every ``every_n``-th layer.
        use_hyper = bool(hypernetwork) and layer_id % every_n == 0
        hyper_out = None
        if use_hyper:
            if not diff_hypernets:
                hyper_out = hypernetwork(x)
            elif interleaving_layers:
                hyper_out = hypernetwork[0 if self.tick else 1](x)
                self.tick = not self.tick
            else:
                hyper_out = hypernetwork[(layer_id // every_n) - 1](x)

        ff_out = self.ff(x, act_ck)
        # Order of addition matters for matching reference outputs:
        # residual + attn_out + ff_out does not match.
        x = attn_out + ff_out + residual
        if use_hyper:
            x = x + hyper_out

        return x, kv

class GPTJModel(nn.Module):
    """Decoder stack: token embedding, ``n_layer`` layers, final LayerNorm and LM head."""

    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16, **kwargs):
        """Build the model.

        Args:
            hidden_dim: Model width.
            n_layer: Number of layers in the stack.
            n_head: Number of attention heads per layer.
            vocab_dim: Vocabulary size (embedding rows and LM head outputs).
            eps: LayerNorm epsilon.
            activation: Activation handed to every layer's feed-forward block.
            Layer: Layer class instantiated ``n_layer`` times.
            device: Device for all created parameters.
            dtype: Dtype for all created parameters.
            **kwargs: Accepted and ignored.
        """
        nn.Module.__init__(self)
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        # The head is created with the same device/dtype as every other
        # module so the freshly constructed model is internally consistent.
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True, device=device, dtype=dtype)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Compute logits for token ids ``x``.

        Args:
            x: Token ids fed to the embedding.
            hypernetwork: Optional hypernetwork forwarded to every layer.
            act_ck: Enable activation checkpointing in the layers.
            kv: Optional per-layer list of cached key/value pairs.
            cache: Ask the layers to return their key/value pairs.

        Returns:
            The logits cast to float32, or ``(logits, kv)`` when a non-empty
            key/value list was produced (i.e. ``cache`` was true).
        """
        x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
        x = self.lm_head(x)
        if kv:
            return x.float(), kv
        else:
            return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Return the final hidden states (after ``ln_final``) for token ids ``x``.

        Returns:
            ``(hidden_states, kv_list)`` when ``cache`` is true, otherwise
            ``(hidden_states, None)``. ``kv_list`` holds one entry per layer.
        """
        if kv is None:
            # No cache supplied: every layer starts without past keys/values.
            kv = [None] * self.n_layer

        kv_new = []
        x = self.vocab_embed(x)

        for layer_id, layer in enumerate(self.layers):
            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
            kv_new.append(kvi)

        x = self.ln_final(x)
        if cache:
            return x, kv_new
        else:
            return x, None

class GPTJBaseLM(lm_base.BaseLM):
    """``lm_base.BaseLM`` subclass that selects ``GPTJModel`` as its model class."""

    def __init__(self, config=None, lm=None):
        # Both base initialisers are invoked explicitly, nn.Module first.
        nn.Module.__init__(self)
        # config and lm are forwarded unchanged; their expected types are
        # defined by BaseLM, which is not visible in this file.
        lm_base.BaseLM.__init__(self, config, lm)
        # NOTE(review): presumably BaseLM uses model_class to construct the
        # network — confirm in lm_base.
        self.model_class=GPTJModel

def load_gpt_j(path="models/6b", state_dict=None):
    """Load a GPT-J model through ``GPTJBaseLM.load`` with a fixed configuration.

    Args:
        path: Forwarded to ``GPTJBaseLM.load``; defaults to ``"models/6b"``.
        state_dict: Forwarded to ``GPTJBaseLM.load``.

    Returns:
        Whatever ``GPTJBaseLM.load`` returns for this configuration.
    """
    gpt_j_config = dict(
        n_layer=28,
        n_head=16,
        hidden_dim=4096,
        vocab_dim=50400,
        eps=1e-5,
    )
    return GPTJBaseLM.load(gpt_j_config, path, state_dict)
