import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
from typing import Optional, Any

def make_positions(tensor, padding_idx: int, onnx_trace: bool = False,
                   pad_token_id: int = 50257):
    """Replace non-padding symbols with their position numbers.

    A token counts as padding when it equals ``pad_token_id`` (50257 by
    default). Non-padding tokens are numbered ``padding_idx + 1``,
    ``padding_idx + 2``, ... along dim 1. Padding tokens do not advance the
    counter and are mapped to ``padding_idx`` itself.

    Args:
        tensor: token ids; positions are counted along dim 1.
        padding_idx: position index given to padding tokens. Real positions
            start at ``padding_idx + 1``.
        onnx_trace: unused; kept for API compatibility.
        pad_token_id: token id that marks padding.

    Returns:
        LongTensor with the same shape as ``tensor`` holding position indices.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(torch.tensor(pad_token_id, requires_grad=False)).int()
    # Multiplying by the mask zeroes padding slots, so they land on padding_idx.
    return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx

class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    The embedding table is cached in ``self.weights``. It is a plain tensor
    attribute, not a registered buffer, and it is regrown on demand in
    ``forward``. The row at ``padding_idx`` is zeroed.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx if padding_idx is not None else 0
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.onnx_trace = False
        # Zero-dim buffer that only carries the module's device/dtype, so the
        # cached table can follow it in forward (``self.weights.to(...)``).
        self.register_buffer("_float_tensor",
                             torch.tensor(1.0, requires_grad=False).float())
        # Not referenced elsewhere in this class.
        self.max_positions = int(1e5)

    def prepare_for_onnx_export_(self):
        """Switch forward to its ONNX-traceable code path."""
        self.onnx_trace = True

    @staticmethod
    def get_embedding(
            num_embeddings: int, embedding_dim: int,
            padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Returns a float tensor of shape [num_embeddings x embedding_dim]: sin
        features in the first half and cos features in the second. An odd
        dimension gets a trailing zero column, and row ``padding_idx`` (if
        given) is zeroed.
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(
            self,
            input,
            incremental_state: Optional[Any] = None,
            timestep: Optional[torch.Tensor] = None,
            positions: Optional[Any] = None,
            offset: Optional[int] = 0
    ):
        """Return positional embeddings for ``input`` of size [bsz x seqlen].

        Args:
            input: token ids, [bsz x seqlen].
            incremental_state: when not None, return only the embedding for the
                current decoding step, expanded to [bsz x 1 x dim].
            timestep: optional step tensor used in incremental mode.
            positions: unused; positions are always recomputed from ``input``.
            offset: number of positions already consumed (e.g. cached tokens).
                Every position index is shifted by this amount.

        Returns:
            [bsz x seqlen x dim] embeddings, or [bsz x 1 x dim] in incremental
            mode. The non-ONNX full-sequence result is detached.
        """
        bsz, seq_len = input.shape[0], input.shape[1]
        max_pos = self.padding_idx + 1 + seq_len + offset
        if self.weights is None or max_pos > self.weights.size(0):
            # Grow the cached table. Zero the same padding row that __init__
            # zeroes. Zeroing ``padding_idx + offset`` instead would wipe a real
            # position that later calls with a different offset still read.
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos, self.embedding_dim, self.padding_idx
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            if self.onnx_trace:
                return (
                    self.weights.index_select(
                        index=self.padding_idx + pos + offset, dim=0)
                        .unsqueeze(1)
                        .repeat(bsz, 1, 1)
                )
            return self.weights[self.padding_idx + pos + offset, :].expand(bsz,
                                                                           1,
                                                                           -1)

        positions = make_positions(
            input, self.padding_idx + offset, onnx_trace=self.onnx_trace
        )
        if self.onnx_trace:
            # Imported here so the non-ONNX path never depends on the exporter.
            from torch.onnx import operators as onnx_operators
            flat_embeddings = self.weights.detach().index_select(
                0, positions.view(-1))
            # bsz/seq_len above are Python ints (no ``.view``). Use a shape
            # tensor so the reshape stays dynamic in the traced graph.
            input_shape = onnx_operators.shape_as_tensor(input)
            embedding_shape = torch.cat(
                (input_shape[0:1], input_shape[1:2],
                 torch.tensor([-1], dtype=torch.long))
            )
            embeddings = onnx_operators.reshape_from_tensor_shape(
                flat_embeddings, embedding_shape
            )
            return embeddings
        return (
            self.weights.index_select(0, positions.view(-1))
                .view(bsz, seq_len, -1)
                .detach()
        )


def PositionalEmbedding(
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: int,
):
    """Create a sinusoidal positional-embedding module.

    The cached table starts with ``num_embeddings + padding_idx + 1`` rows.
    That covers position indices up to ``padding_idx + num_embeddings``, so
    sequences of up to ``num_embeddings`` tokens need no regrowth.
    """
    table_rows = num_embeddings + padding_idx + 1
    return SinusoidalPositionalEmbedding(embedding_dim, padding_idx, init_size=table_rows)

def _attn(query, key, value, causal_mask, masked_bias,
            attention_mask=None, scale_attn=None, fp32_attn=True):
    """Scaled dot-product attention with a boolean causal mask.

    Args:
        query, key, value: tensors whose last two dims are (sequence, head_dim).
            Scores are ``query @ key^T`` over those dims.
        causal_mask: bool tensor broadcastable to the score matrix. True marks
            positions that may be attended to.
        masked_bias: scalar tensor written into disallowed positions. Masking
            happens before scaling, so those entries end up as
            ``masked_bias / scale_attn``.
        attention_mask: optional additive mask, applied after scaling.
        scale_attn: divisor for the scores, as a tensor or a number. Defaults
            to ``sqrt(query.size(-1))`` when None.
        fp32_attn: compute scores and softmax in float32.

    Returns:
        The attention output, in ``value.dtype``.
    """
    if fp32_attn:
        attn_weights = torch.matmul(query.float(), key.transpose(-1, -2).float())
    else:
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))

    if scale_attn is None:
        # Standard scaled-dot-product default. Previously None crashed on `.to`.
        scale_attn = math.sqrt(query.size(-1))
    elif isinstance(scale_attn, torch.Tensor):
        scale_attn = scale_attn.to(attn_weights.dtype)
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1).to(value.dtype)

    # attn_weights already has value.dtype, so the product keeps that dtype.
    return torch.matmul(attn_weights, value)

class SelfAttention(nn.Module):
    """Multi-head causal self-attention with an optional key/value cache.

    The q/k/v/out projections are ``hidden_dim -> hidden_dim`` linears with
    bias, built on ``config.device`` / ``config.dtype``. The causal-mask
    buffer covers at most 2049 positions.
    """
    # Code copied from HF, might want to sanity check later.
    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
        self.head_dim = config.hidden_dim // config.n_head
        # Not used by this class's forward; kept for compatibility.
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        device = config.device
        dtype = config.dtype

        # sqrt(head_dim): attention scores are divided by this in _attn.
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        # [1, 1, 2049, 2049] boolean lower-triangular causal mask.
        self.register_buffer("bias", bias)
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
        attn_bias = True #fairseq has attn_bias
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x, kv=None, cache=False):
        """Attend over ``x`` plus any cached keys/values.

        Args:
            x: [batch, sequence, hidden_dim] input.
            kv: optional ``(key, value)`` pair from an earlier call, each
                [batch, head, past_sequence, head_dim]. It is prepended to this
                step's keys/values.
            cache: when True, also return the full (concatenated) ``(key, value)``.

        Returns:
            ``(output, (key, value))`` if ``cache`` else ``(output, None)``.
            ``output`` is [batch, sequence, hidden_dim].
        """
        B, S, H = x.shape # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        if kv is not None:
            past_key, past_value = kv
            # torch.cat is not in-place. Reassign so the queries really attend
            # to the cached context and the returned cache covers the whole
            # sequence.
            key = torch.cat([past_key, key], dim=-2)
            value = torch.cat([past_value, value], dim=-2)

        # With a cache, key_length > query_length. Taking mask rows from
        # key_length - query_length lets each new query see every past
        # position plus the new ones up to itself.
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

        x = _attn(
            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn, self.config.fp32_attn
        )

        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        else:
            return x, None

class FeedForward(nn.Module):
    """Position-wise MLP: ``hidden_dim -> 4 * hidden_dim -> hidden_dim``.

    ``config.activation`` is applied between the two linear layers. Both layers
    are built on ``config.device`` with ``config.dtype``.
    """
    def __init__(self, config):
        super().__init__()
        factory = {"device": config.device, "dtype": config.dtype}
        inner_dim = config.hidden_dim * 4
        self.ff1 = nn.Linear(config.hidden_dim, inner_dim, **factory)
        self.ff2 = nn.Linear(inner_dim, config.hidden_dim, **factory)
        self.activation = config.activation

    def forward(self, x, act_ck=False):
        """Run the MLP. With ``act_ck``, the activation goes through checkpointing."""
        hidden = self.ff1(x)
        hidden = ck(self.activation, hidden) if act_ck else self.activation(hidden)
        return self.ff2(hidden)

class GPTFairLayer(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``x + attn(ln_preattn(x))``, then ``x + ff(ln_postattn(x))``.
    ``attn`` and ``ff`` are module classes, each constructed with ``config``.
    """
    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.hidden_dim = config.hidden_dim
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ff = ff(config)
        self.attn = attn(config)
        # Not read anywhere in this class.
        self.tick = True

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, cache=False, kv=None):
        """Apply the block.

        Args:
            x: hidden states.
            layer_id, hypernetwork: unused here; accepted for interface compatibility.
            act_ck: checkpoint the pre-attention LayerNorm and attention, and
                forward the flag to ``self.ff``.
            cache, kv: passed to the attention module.

        Returns:
            ``(x, kv)``, where ``kv`` is whatever the attention module returned.
        """
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            # Pass kv/cache positionally. Reentrant torch.utils.checkpoint
            # raises ValueError on extra keyword arguments. This assumes the
            # attention forward signature is (x, kv, cache), as in SelfAttention.
            attn_out, kv = ck(self.attn, x, kv, cache)
        else:
            x = self.ln_preattn(x)
            attn_out, kv = self.attn(x, kv=kv, cache=cache)

        x = residual + attn_out
        residual = x
        x = self.ln_postattn(x)
        ff_out = self.ff(x, act_ck)
        x = residual + ff_out

        return x, kv

class GPTFairModel(base_lm.BaseModel):
    """Fairseq-style decoder-only LM on top of ``base_lm.BaseModel``.

    Uses sinusoidal positional embeddings, scaled token embeddings and a
    bias-free LM head. ``vocab_embed``, ``layers`` and ``ln_final`` are
    presumably created by ``base_lm.BaseModel`` (not visible here, so verify).
    """
    def __init__(self, user_config, **kwargs):
        self.default_config = {
            'n_layer': 6,
            'n_head': 8,
            'n_tokens': 2049,
            'hidden_dim': 512,
            'vocab_dim': 50400,
            'fp32_attn': True, #fairseq models are trained with fp32 attn
            'eps': 1e-5,
            'device': torch.device('cuda'),
            'dtype': torch.float16,
            'Layer': GPTFairLayer,
            'activation': F.gelu,
            'SelfAttention': SelfAttention,
            'FeedForward': FeedForward,
        }
        base_lm.BaseModel.__init__(self, user_config, **kwargs)
        # sqrt(hidden_dim): token embeddings are multiplied by this in get_embeds.
        self.register_buffer("embed_scale", torch.sqrt(torch.tensor(self.config.hidden_dim, requires_grad=False)))
        # Sinusoidal positional embeddings; forward yields [batch, seq, hidden_dim].
        self.pos_embed = PositionalEmbedding(self.config.n_tokens, self.config.hidden_dim, 1)
        # bias=False for fairseq models. Built with the configured device/dtype,
        # like every other submodule in this file.
        self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False,
                                 device=self.config.device, dtype=self.config.dtype)

    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Run the embedding stack and all layers, returning final hidden states.

        Args:
            x: token ids, presumably [batch, seq] (pos_embed expects that shape).
            hypernetwork, act_ck: forwarded to every layer.
            kv: optional per-layer list of cached ``(key, value)`` pairs from an
                earlier call with ``cache=True``. The past length is read from
                the sequence dim of the first layer's key.
            cache: when True, also return the updated per-layer caches.

        Returns:
            ``(hidden, kv_list)`` if ``cache`` else ``(hidden, None)``, where
            ``hidden`` is the output of ``ln_final``.
        """
        if kv is None:
            # One empty cache slot per layer. Sized from the container iterated
            # below, so kv[layer_id] is always valid.
            kv = [None] * len(self.layers)
            past_length = 0

        else:
            past_length = kv[0][0].size(-2) #get sequence dim of key

        kv_new = []

        # Shift positions by the number of already-cached tokens.
        position_embeds = self.pos_embed(x, offset=past_length)
        input_embeds = self.vocab_embed(x) * self.embed_scale
        x = position_embeds + input_embeds

        for layer_id, layer in enumerate(self.layers):
            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
            kv_new.append(kvi)

        x = self.ln_final(x)
        if cache:
            return x, kv_new
        else:
            return x, None