import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from basedformer.models import base_image
import einops

def _attn(query, key, value, attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = attn_weights / scale_attn

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights.to(value.dtype)

    attn_output = torch.matmul(attn_weights, value).to(value.dtype)

    return attn_output

class SelfAttention(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Reads ``hidden_dim``, ``n_head``, ``device`` and ``dtype`` from ``config``.
    All four projections are bias-free ``hidden_dim -> hidden_dim`` linears.
    """

    # Code copied from HF, might want to sanity check later.
    def __init__(self, config):
        super().__init__()
        self.head_dim = config.hidden_dim // config.n_head
        # Stored for parity with the original code; not used inside this class.
        self.rotary_dim = self.head_dim // 4
        self.hidden_dim = config.hidden_dim
        self.n_head = config.n_head
        device = config.device
        dtype = config.dtype

        # sqrt(head_dim), kept as a buffer so it travels with the module.
        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
        attn_bias = False
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)

    def forward(self, x, kv=None, cache=False):
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: Tensor of shape (batch, sequence, hidden_dim).
            kv: Optional ``(key, value)`` pair from a previous call, each laid
                out as (batch, head, past_sequence, head_dim). When given, the
                new keys/values are appended along the sequence axis so the
                queries can attend to the cached positions as well.
            cache: When true, also return the (possibly extended) key/value
                pair for reuse by the next call.

        Returns:
            ``(output, (key, value))`` if ``cache`` is true, otherwise
            ``(output, None)``. ``output`` has the same shape as ``x``.
        """
        B, S, H = x.shape  # batch, sequence, hidden_dim
        # split heads into: [batch, head, sequence, head_dim]
        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)

        # ``len`` check keeps an empty container meaning "no cache", while
        # avoiding the ambiguous truth value of a tensor-valued ``kv``.
        if kv is not None and len(kv) > 0:
            past_key, past_value = kv
            # torch.cat is out-of-place: the results must be rebound, otherwise
            # the cached positions are silently dropped.
            key = torch.cat([past_key, key], dim=-2)
            value = torch.cat([past_value, value], dim=-2)

        # No attention mask is applied: every query sees every key.
        x = _attn(query, key, value, None, self.scale_attn)

        # merge heads back into: [batch, sequence, hidden_dim]
        x = x.transpose(1, 2).contiguous().view(B, S, H)
        x = self.out_proj(x)
        if cache:
            return x, (key, value)
        return x, None

class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4 * hidden_dim, apply the activation, project back.

    Reads ``hidden_dim``, ``device``, ``dtype`` and ``activation`` from
    ``config``. ``activation`` is used as a callable on the expanded tensor.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        expanded = width * 4
        placement = {"device": config.device, "dtype": config.dtype}
        # ff1 is built before ff2 so parameter initialisation order is unchanged.
        self.ff1 = nn.Linear(width, expanded, **placement)
        self.ff2 = nn.Linear(expanded, width, **placement)
        self.activation = config.activation

    def forward(self, x):
        """Return ``ff2(activation(ff1(x)))``; the last dim of ``x`` is hidden_dim."""
        return self.ff2(self.activation(self.ff1(x)))

class ViTEncoder(nn.Module):
    """One pre-norm transformer encoder layer: attention then feed-forward.

    Each sub-layer is preceded by its own LayerNorm and wrapped in a residual
    connection. Reads ``hidden_dim``, ``eps``, ``device`` and ``dtype`` from
    ``config`` (plus whatever ``FeedForward`` and ``SelfAttention`` read).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ff = FeedForward(config)
        self.attn = SelfAttention(config)

    def forward(self, x):
        """Apply the layer to ``x`` and return a tensor of the same shape."""
        # Attention sub-layer. SelfAttention returns (output, cache); only the
        # output is needed here. (A debug print of x.shape used to run on
        # every call and has been removed.)
        residual = x
        x = self.ln_preattn(x)
        x = self.attn(x)[0]
        x = residual + x

        # Feed-forward sub-layer.
        residual = x
        x = self.ln_postattn(x)
        x = self.ff(x)
        return x + residual


class ViTEmbeds(nn.Module):
    """Patch embedding: linear projection, class token and learned positions.

    Reads ``patch_size``, ``channels``, ``hidden_dim`` and ``image_size`` from
    ``config``. If ``config`` also carries ``device`` / ``dtype`` they are used
    for the parameters, matching the other modules built from the same config;
    otherwise PyTorch's defaults apply.
    """

    def __init__(self, config) -> None:
        super().__init__()
        p_size = config.patch_size
        channels = config.channels
        dim = config.hidden_dim
        num_patches = (config.image_size[1] // p_size) * (config.image_size[0] // p_size)
        # getattr keeps configs without these fields working as before.
        device = getattr(config, "device", None)
        dtype = getattr(config, "dtype", None)

        # Each flattened patch has patch_size**2 * channels values.
        self.lin_emb = nn.Linear((p_size ** 2) * channels, dim, device=device, dtype=dtype)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim, device=device, dtype=dtype))
        # One position per patch plus one for the class token.
        self.pos = nn.Parameter(torch.zeros(1, num_patches + 1, dim, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        """Embed flattened patches.

        Args:
            x: Tensor of shape (batch, num_patches, patch_size**2 * channels).

        Returns:
            Tensor of shape (batch, num_patches + 1, hidden_dim); index 0 along
            the sequence axis is the class token.
        """
        embed = self.lin_emb(x)
        batch_size = x.shape[0]
        # expand() broadcasts the single class token over the batch without copying.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embed = torch.cat((cls_tokens, embed), dim=1)
        return embed + self.pos

class VisionTransformer(base_image.BaseVisionModel):
    """Vision Transformer: patchify, embed, run encoder layers, apply a linear head.

    The defaults below are handed to ``BaseVisionModel.__init__``, which is
    expected to expose them as ``self.config`` (defined outside this file).
    """

    def __init__(self):
        self.default_config = {
            'n_layer': 12,
            'n_head': 8,
            'channels': 3,
            'patch_size': 16,
            'hidden_dim': 768,
            'n_classes' : 1000,
            'activation': torch.nn.GELU(),
            'image_size': (224, 224),
            'eps': 1e-5,
            'device': torch.device('cuda'),
            'dtype': torch.float16,
        }
        super().__init__(self.default_config)
        self.embed = ViTEmbeds(self.config)
        self.encoder_layers = nn.ModuleList(
            ViTEncoder(self.config) for _ in range(self.config.n_layer)
        )
        # Build the head on the configured device/dtype, like the encoder
        # layers, so it can consume their output directly.
        self.mlp_head = nn.Linear(
            self.config.hidden_dim,
            self.config.n_classes,
            device=self.config.device,
            dtype=self.config.dtype,
        )

    def forward(self, x):
        """Run the model on a batch of images.

        Args:
            x: Tensor of shape (batch, channels, height, width); height and
                width must be multiples of ``patch_size``.

        Returns:
            Tensor of shape (batch, num_patches + 1, n_classes): the head is
            applied to every token, class token included.
        """
        p_size = self.config.patch_size
        # Cut the image into p_size x p_size tiles and flatten each tile
        # (row, column, channel order) into one vector per patch.
        patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=p_size, s2=p_size)
        patches = self.embed(patches)
        for encoder in self.encoder_layers:
            patches = encoder(patches)
        # NOTE(review): classification heads usually read only the class token
        # (patches[:, 0]); left as-is because changing it alters the output
        # shape callers see — confirm intent.
        return self.mlp_head(patches)
        

