import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from dotmap import DotMap
        
class BaseModel(nn.Module):
    """Base language model: embedding -> layer stack -> LayerNorm -> LM head.

    The model is driven by a config built in ``configure_model`` from the
    instance's ``default_config`` attribute (which must exist before this
    ``__init__`` runs, e.g. as a class attribute on a subclass) overlaid
    with ``user_config``.

    Config keys read here: ``n_layer``, ``hidden_dim``, ``vocab_dim``,
    ``eps``, ``device``, ``dtype``, and the constructors ``Layer``,
    ``SelfAttention`` and ``FeedForward``.
    """

    def __init__(self, user_config, **kwargs):
        """Build the embedding, final norm, LM head and ``n_layer`` layers.

        Args:
            user_config: Mapping of overrides applied on top of
                ``default_config``; ``None`` is treated as no overrides.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        nn.Module.__init__(self)
        # configuration
        self.user_config = user_config
        self.config = self.configure_model()
        config = self.config
        # modeling
        self.n_layer = config.n_layer
        self.hidden_dim = config.hidden_dim
        self.vocab_embed = nn.Embedding(
            config.vocab_dim, self.hidden_dim,
            device=config.device, dtype=config.dtype,
        )
        self.ln_final = nn.LayerNorm(
            self.hidden_dim, eps=config.eps,
            device=config.device, dtype=config.dtype,
        )
        self.layers = nn.ModuleList([])
        # Created on the same device/dtype as the other submodules so the
        # head can consume the output of ``ln_final`` directly.
        self.lm_head = nn.Linear(
            config.hidden_dim, config.vocab_dim, bias=True,
            device=config.device, dtype=config.dtype,
        )
        for i in range(config.n_layer):
            # NOTE: every layer receives the same config object, so
            # ``layer_idx`` is only valid while that layer is constructed.
            config.layer_idx = i
            self.layers.append(
                config.Layer(
                    attn=config.SelfAttention,
                    ff=config.FeedForward,
                    config=config,
                )
            )
        # Counted after the layers are registered so they are included.
        self.total_params = sum(p.numel() for p in self.parameters())

    def configure_model(self):
        """Merge ``default_config`` with ``user_config`` into a ``DotMap``.

        User-provided values override defaults with the same key.

        Raises:
            ValueError: If the instance has no ``default_config`` attribute.
        """
        if not hasattr(self, 'default_config'):
            raise ValueError("No default config found, add one for the model to function")

        # Defaults first, then user overrides (a missing user config is
        # treated as empty).
        full_config = dict(self.default_config.items())
        full_config.update((self.user_config or {}).items())
        return DotMap(full_config)

    def _head_and_loss(self, hidden_states, target):
        """Project hidden states to logits and optionally compute the loss.

        Returns:
            ``(loss, logits)`` where ``loss`` is ``None`` when ``target`` is
            ``None`` and ``logits`` has been cast to float32.
        """
        logits = self.lm_head(hidden_states)
        loss = None
        # Explicit None check: truth-testing a multi-element tensor raises.
        if target is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.shape[-1]),
                target.reshape(-1),
            )
        return loss, logits.float()

    def forward_with_hidden_states(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Run the model and return logits (plus loss / kv when requested).

        Args:
            x: Input passed to ``get_embeds`` (indices for ``vocab_embed``).
            target: Optional labels; when given, a cross-entropy loss over
                the flattened logits is returned first.
            hypernetwork, act_ck: Forwarded unchanged to every layer.
            kv: Optional per-layer list of cached values for the layers.
            cache: When true, the per-layer kv list is appended to the result.

        Returns:
            ``logits``; ``(loss, logits)`` with a target; and with
            ``cache=True`` the same followed by ``kv``.
            NOTE(review): despite its name this method never returns the
            hidden states; kept as-is for backward compatibility.
        """
        hidden_states, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
        loss, logits = self._head_and_loss(hidden_states, target)

        if loss is None:
            return (logits, kv) if cache else logits
        return (loss, logits, kv) if cache else (loss, logits)

    def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Run the model; see ``Returns`` for the shape of the result.

        Args:
            x: Input passed to ``get_embeds`` (indices for ``vocab_embed``).
            target: Optional labels; when given, a cross-entropy loss over
                the flattened logits is returned first.
            hypernetwork, act_ck: Forwarded unchanged to every layer.
            kv: Optional per-layer list of cached values for the layers.
            cache: When true, the per-layer kv list is returned as well.

        Returns:
            * no target, no cache: ``(logits, hidden_states)``
            * no target, cache:    ``(logits, kv)``
            * target, no cache:    ``(loss, logits)``
            * target, cache:       ``(loss, logits, kv)``
        """
        hidden_states, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
        loss, logits = self._head_and_loss(hidden_states, target)

        if loss is None:
            return (logits, kv) if cache else (logits, hidden_states)
        return (loss, logits, kv) if cache else (loss, logits)

    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Embed ``x``, run it through every layer and apply ``ln_final``.

        Args:
            x: Indices looked up in ``vocab_embed``.
            hypernetwork, act_ck: Forwarded unchanged to every layer.
            kv: Optional list with one entry per layer, passed to the
                matching layer; defaults to ``None`` for every layer.
            cache: Forwarded to every layer; also selects whether the
                collected per-layer values are returned.

        Returns:
            ``(hidden_states, kv_list)`` when ``cache`` is true, otherwise
            ``(hidden_states, None)``.
        """
        if kv is None:
            kv = [None] * self.n_layer

        kv_new = []
        x = self.vocab_embed(x)

        for layer_id, layer in enumerate(self.layers):
            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
            kv_new.append(kvi)

        x = self.ln_final(x)
        return x, (kv_new if cache else None)