from torch import nn
import torch
import os
import math
from lm_arch.utils import *

# Every module can be accessed and changed from here, because the Layer class and the attention and feed-forward classes are all passed in through GPTModel.
class GPTModel(nn.Module):
    """GPT-style language model assembled from caller-supplied building blocks.

    The model consists of a token embedding, a stack of ``n_layer`` ``Layer``
    modules, a final ``LayerNorm`` and a linear ``lm_head`` that projects the
    hidden states onto the vocabulary. The ``Layer``, ``SelfAttention`` and
    ``FeedForward`` classes are injected so the architecture can be swapped
    without touching this class.

    Args:
        hidden_dim: Width of the embeddings / hidden states.
        n_layer: Number of ``Layer`` modules to stack.
        n_head: Forwarded to every ``Layer`` as ``n_head``.
        vocab_dim: Vocabulary size (embedding rows and ``lm_head`` outputs).
        eps: Epsilon for the final ``LayerNorm``; also forwarded to every layer.
        activation: Forwarded to every ``Layer`` as ``activation``.
        Layer: Class instantiated once per layer. Must accept the keyword
            arguments ``attn``, ``ff``, ``hidden_dim``, ``n_head``, ``eps``,
            ``activation``, ``device`` and ``dtype``.
        SelfAttention: Passed to ``Layer`` as ``attn``.
        FeedForward: Passed to ``Layer`` as ``ff``.
        device: Device on which all parameters are created.
        dtype: Data type in which all parameters are created.
    """

    def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=None, SelfAttention=None, FeedForward=None, device="cuda", dtype=torch.float16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        # Create the head on the same device/dtype as the rest of the model;
        # otherwise it would be a default-device float32 layer and
        # ``get_logits`` would fail on mismatched inputs.
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True, device=device, dtype=dtype)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
            #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
            #TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.

    def forward(self, x, hypernetwork=None, act_ck=False):
        """Return the final, layer-normalised hidden states (not logits).

        Args:
            x: Integer tensor of token ids; it is fed to the embedding table.
            hypernetwork: Passed unchanged to every layer.
            act_ck: Passed unchanged to every layer
                (presumably toggles activation checkpointing - TODO confirm).
        """
        x = self.vocab_embed(x)
        for layer in self.layers:
            x = layer(x, hypernetwork, act_ck)
        x = self.ln_final(x)
        return x

    def get_logits(self, x, hypernetwork=None, act_ck=False):
        """Run ``forward`` and project onto the vocabulary.

        Returns:
            The ``lm_head`` output cast to ``torch.float32``.
        """
        x = self.forward(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()

    @classmethod
    def load(cls, config, path=None, state_dict=None):
        """Build a model from ``config`` and load weights into it.

        Args:
            config: Mapping of constructor keyword arguments.
            path: Checkpoint location handed to ``SplitCheckpoint``. When
                truthy it takes precedence over ``state_dict``.
            state_dict: Weights to load when ``path`` is not given.

        Returns:
            The constructed model with the weights loaded.
        """
        if path:
            # NOTE: the checkpoint is always opened with device="cuda",
            # independently of any ``device`` entry in ``config``.
            state_dict = SplitCheckpoint(path, device="cuda")

        # ``no_init`` comes from lm_arch.utils (presumably it skips the default
        # weight initialisation, which is about to be overwritten anyway).
        model = no_init(lambda: cls(**config))
        # strict=False: missing or unexpected keys are tolerated silently.
        model.load_state_dict(state_dict, strict=False)
        return model

    @classmethod
    def init(cls, config):
        """Build a model from ``config`` using PyTorch's default initialisation."""
        model = cls(**config)
        return model

    @classmethod
    def neox_init(cls, config):
        """Build a model and re-initialise it with the NeoX-style scheme.

        ``small_init_method`` is applied to every parameter of all layers but
        the last, of the embedding, of the final layer norm and of ``lm_head``;
        ``wang_init_method`` is applied to every parameter of the last layer.
        ``config`` must contain ``"hidden_dim"`` and ``"n_layer"``.
        """
        model = cls(**config)
        modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
        init = small_init_method(config["hidden_dim"])
        for module in modules:
            # NOTE(review): this re-draws *all* parameters, biases and
            # LayerNorm weights included - confirm that this is intended.
            for param in module.parameters():
                init(param)

        last_layer = model.layers[-1]
        last_layer_init = wang_init_method(config["n_layer"], config["hidden_dim"])
        for param in last_layer.parameters():
            last_layer_init(param)

        return model

    @classmethod
    def simple_init(cls, config):
        """Build a model and scale every state-dict entry by ``1 / sqrt(2 * n_layer)``.

        ``config`` must contain ``"n_layer"``.
        """
        model = cls(**config)
        state = model.state_dict()
        for k in state:
            state[k] = state[k] / math.sqrt(2 * config["n_layer"])
        model.load_state_dict(state)

        return model

    def save(self, path):
        """Write the model as a split checkpoint into directory ``path``.

        Each state-dict entry is saved to its own file ``{path}/b{i}.pt``; the
        index ``{path}/m.pt`` maps every state-dict key to its file path.
        The directory (including missing parents) is created if necessary.
        """
        # exist_ok=True keeps "directory already there" harmless while letting
        # real failures (permissions, path is a file, ...) surface here.
        os.makedirs(path, exist_ok=True)
        checkpoint = {}
        for i, (name, tensor) in enumerate(self.state_dict().items()):
            shard_path = f"{path}/b{i}.pt"
            checkpoint[name] = shard_path
            torch.save(tensor, shard_path)
        torch.save(checkpoint, f"{path}/m.pt")

def wang_init_method(n_layers, dim):
    """Return an in-place normal initializer with std ``2 / n_layers / sqrt(dim)``.

    The returned callable takes a tensor, fills it with samples from a
    zero-mean normal distribution of that standard deviation and returns
    the same tensor.
    """
    scale = 2 / n_layers / math.sqrt(dim)
    return lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=scale)

# Stolen from NeoX. For the 20B run, wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
    """Return an in-place normal initializer with std ``sqrt(2 / (5 * dim))``.

    This is the "small init" described in "Transformers without Tears:
    Improving the Normalization of Self-Attention" (Nguyen, T. & Salazar, J.,
    2019). The returned callable fills the given tensor with zero-mean normal
    samples of that standard deviation and returns the same tensor.
    """
    scale = math.sqrt(2 / (5 * dim))

    def fill_normal(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=scale)

    return fill_normal