import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from dotmap import DotMap
import math
from basedformer import models

class ConfigClass:
    """Plain attribute container built from a mapping.

    Every ``key: value`` pair yielded by ``config.items()`` becomes an
    instance attribute, so ``ConfigClass({"a": 1}).a == 1``.
    """

    def __init__(self, config):
        # Mirror each mapping entry as an attribute on this instance.
        entries = config.items()
        for attr_name, attr_value in entries:
            setattr(self, attr_name, attr_value)

class BaseModel(nn.Module):
    """Base language model: embedding, a stack of layers, final norm, LM head.

    Subclasses must define a ``default_config`` mapping; it is merged with the
    ``user_config`` passed to the constructor (user values win). The merged
    config must provide ``n_layer``, ``hidden_dim``, ``vocab_dim``, ``device``,
    ``dtype``, ``eps`` and the ``Layer``, ``SelfAttention`` and ``FeedForward``
    classes used to build each layer.
    """

    def __init__(self, user_config, **kwargs):
        """Build the model from ``default_config`` overridden by ``user_config``.

        Args:
            user_config: Mapping of config overrides (may be ``None`` or empty).
            **kwargs: Accepted for compatibility with callers; unused here.
        """
        nn.Module.__init__(self)
        # configuration
        self.user_config = user_config
        self.config = self.configure_model()
        config = self.config
        # modeling
        self.n_layer = config.n_layer
        self.hidden_dim = config.hidden_dim
        self.vocab_embed = nn.Embedding(
            config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype
        )
        self.ln_final = nn.LayerNorm(
            self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype
        )
        self.layers = nn.ModuleList([])
        # NOTE(review): unlike the embedding and final norm, the head is created
        # without config.device / config.dtype -- confirm this is intentional.
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
        for i in range(config.n_layer):
            # The shared config object carries the index of the layer being built.
            config.layer_idx = i
            self.layers.append(
                config.Layer(
                    attn=config.SelfAttention,
                    ff=config.FeedForward,
                    config=config,
                )
            )
        # Count parameters only once every layer is registered, so the total
        # covers the whole model rather than just embedding/norm/head.
        self.total_params = sum(p.numel() for p in self.parameters())

    def init_weights(self):
        """Re-initialise all Linear, Embedding and LayerNorm weights in place.

        Linear and Embedding weights are drawn from N(0, 0.02); Linear biases
        are zeroed; LayerNorm is reset to weight 1 / bias 0. Afterwards every
        parameter whose qualified name contains ``ff2`` or ``out_proj`` and
        ``weight`` is redrawn with std ``0.02 / sqrt(2 * n_layer)``.
        """
        n_layer = self.n_layer
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()

            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=0.02)

            elif isinstance(module, nn.LayerNorm):
                # Both are None when the norm has no affine parameters.
                if module.bias is not None:
                    module.bias.data.zero_()
                if module.weight is not None:
                    module.weight.data.fill_(1.0)

        # The scaled init must run in a separate pass AFTER the generic one:
        # modules() yields parents before children, so doing it inside the
        # loop above lets the later per-Linear init overwrite the scaled
        # values.
        for name, p in self.named_parameters():
            if ("ff2" in name or "out_proj" in name) and "weight" in name:
                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))

    def configure_model(self):
        """Merge ``default_config`` with ``user_config`` into a ``ConfigClass``.

        Returns:
            A ``ConfigClass`` whose attributes are the defaults overridden by
            any user-supplied values.

        Raises:
            ValueError: If the model defines no ``default_config`` attribute.
        """
        if not hasattr(self, 'default_config'):
            raise ValueError("No default config found, add one for the model to function")

        # apply defaults
        full_config = dict(self.default_config.items())

        # apply user defined config if provided (None is treated as empty)
        user_config = self.user_config
        if user_config is not None:
            full_config.update(user_config.items())

        return ConfigClass(full_config)

    @staticmethod
    def _lm_loss(logits, target):
        """Cross-entropy between logits flattened to 2-D and flattened targets."""
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_labels = target.reshape(-1)
        return F.cross_entropy(flat_logits, flat_labels)

    def forward_with_hidden_states(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Run the model and return logits, optionally with loss and kv cache.

        NOTE(review): despite its name this method does not return hidden
        states; ``forward`` does in its default branch. Return shapes are kept
        as they were for caller compatibility.

        Args:
            x: Token ids fed to the embedding.
            target: Optional label tensor; when given, a cross-entropy loss is
                computed against the logits.
            hypernetwork, act_ck: Forwarded unchanged to every layer.
            kv: Optional per-layer cache list (one entry per layer).
            cache: When true, the per-layer cache is returned as well.

        Returns:
            ``(loss, logits, kv)`` if ``cache`` and ``target``;
            ``(logits, kv)`` if only ``cache``;
            ``(loss, logits)`` if only ``target``;
            otherwise ``logits``. Logits are cast to float32.
        """
        x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
        x = self.lm_head(x)
        # Compare against None: truth-testing a multi-element tensor raises.
        has_target = target is not None
        if has_target:
            loss = self._lm_loss(x, target)

        if cache:
            if has_target:
                return loss, x.float(), kv
            return x.float(), kv
        if has_target:
            return loss, x.float()
        return x.float()

    def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Run the model and return logits plus loss / kv cache / hidden states.

        Args:
            x: Token ids fed to the embedding.
            target: Optional label tensor; when given, a cross-entropy loss is
                computed against the logits.
            hypernetwork, act_ck: Forwarded unchanged to every layer.
            kv: Optional per-layer cache list (one entry per layer).
            cache: When true, the per-layer cache is returned as well.

        Returns:
            ``(loss, logits, kv)`` if ``cache`` and ``target``;
            ``(logits, kv)`` if only ``cache``;
            ``(loss, logits)`` if only ``target``;
            otherwise ``(logits, hidden_states)`` where ``hidden_states`` is
            the output of the final layer norm. Logits are cast to float32.
        """
        hidden_states, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
        x = self.lm_head(hidden_states)
        # Compare against None: truth-testing a multi-element tensor raises.
        has_target = target is not None
        if has_target:
            loss = self._lm_loss(x, target)

        if cache:
            if has_target:
                return loss, x.float(), kv
            return x.float(), kv
        if has_target:
            return loss, x.float()
        return x.float(), hidden_states

    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        """Embed ``x``, pass it through every layer and the final layer norm.

        Args:
            x: Token ids fed to the embedding.
            hypernetwork, act_ck: Forwarded unchanged to every layer.
            kv: Optional list indexed by layer id; defaults to all ``None``.
            cache: Forwarded to every layer; also selects the second return.

        Returns:
            ``(hidden, kv_new)`` when ``cache`` is true, where ``kv_new`` holds
            the second output of each layer; otherwise ``(hidden, None)``.
        """
        if kv is None:
            kv = [None] * self.n_layer

        kv_new = []
        x = self.vocab_embed(x)

        for layer_id, layer in enumerate(self.layers):
            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
            kv_new.append(kvi)

        x = self.ln_final(x)
        return x, (kv_new if cache else None)

    def get_embeds_ds(self, x, past_key_values=None, use_cache=True):
        """Variant of ``get_embeds`` for layers using the ``layer_past`` / ``use_cache`` call signature.

        Each layer is expected to return an indexable result whose item 0 is
        the new hidden state and item 1 is that layer's cache entry.

        Returns:
            ``(hidden, kv_new)`` when ``use_cache`` is true, else
            ``(hidden, None)``.
        """
        if past_key_values is None:
            past_key_values = [None] * self.n_layer

        kv_new = []
        x = self.vocab_embed(x)

        for layer_id, layer in enumerate(self.layers):
            outputs = layer(x, layer_past=past_key_values[layer_id], use_cache=use_cache)
            kv_new.append(outputs[1])
            x = outputs[0]

        x = self.ln_final(x)
        return x, (kv_new if use_cache else None)

    def forward_ds(self, x, past_key_values=None, use_cache=True):
        """Return ``(logits, kv)`` using ``get_embeds_ds`` followed by the LM head."""
        x, kv = self.get_embeds_ds(x, past_key_values=past_key_values, use_cache=use_cache)
        x = self.lm_head(x)
        return x, kv

    def convert_to_ds(self):
        """Convert this model using the function registered for ``config.Layer``.

        The converter is looked up in ``models.ds_strats.model_map`` keyed by
        the layer class and is called with this model.
        """
        convert_func = models.ds_strats.model_map[self.config.Layer]
        return convert_func(self)




