from transformers import GPTNeoForCausalLM, AutoConfig
import torch
from lm_train.utils import *
import math

class GPT:
    """Convenience wrapper around a ``GPTNeoForCausalLM`` model.

    The wrapper owns a model config (``self.config``), a lazily-loading
    checkpoint mapping class (``self.checkpoint``) and, once ``load_model`` or
    ``init_model`` has been called, the model itself (``self.model``).
    """

    def __init__(self, model_dtype="bf16", model_device="cuda"):
        """Build the config and checkpoint class; no model is created yet.

        Args:
            model_dtype: Stored on the config as ``model_dtype``; the value
                ``"bf16"`` additionally sets ``config.full_bf16 = True``.
            model_device: Stored on the config as ``model_device``.
        """
        self.config = self.get_config(model_dtype, model_device)
        self.checkpoint = self.get_checkpoint()
        self.model = None

    def get_config(self, model_dtype="bf16", model_device="cuda"):
        """Return the GPT-Neo 2.7B config with its fields overridden.

        Starts from ``AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")``
        and overwrites layer count, head count, hidden size, vocabulary size
        and a handful of extra flags.

        Args:
            model_dtype: Recorded as ``config.model_dtype``.
            model_device: Recorded as ``config.model_device`` (also printed).

        Returns:
            The modified config object.
        """
        print("Using device:", model_device)
        config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
        config.num_layers = 28
        config.attention_layers = ["global"] * config.num_layers
        # NOTE(review): upstream transformers documents attention_types as a
        # list of [types, count] pairs; this flat form, and the rotary / jax /
        # full_bf16 fields below, presumably target a patched GPT-Neo
        # implementation -- confirm against the transformers build in use.
        config.attention_types = [["global"], config.num_layers]
        config.num_heads = 16
        config.hidden_size = 256 * config.num_heads
        config.vocab_size = 50400
        config.rotary = True
        config.rotary_dim = 64
        config.jax = True
        config.model_dtype = model_dtype
        config.model_device = model_device
        if model_dtype == "bf16":
            config.full_bf16 = True
        return config

    def get_checkpoint(self):
        """Return the ``Checkpoint`` class used to load weights lazily.

        Returns:
            A ``MutableMapping`` subclass (the class, not an instance). It is
            constructed as ``Checkpoint(chkpt_dir, device="cpu")``, reads the
            index file ``m.pt`` from that directory, and loads each entry's
            file from the same directory only when the key is accessed.
        """
        try:
            from collections.abc import MutableMapping
        except ImportError:
            from collections import MutableMapping
        from pathlib import Path

        class Checkpoint(MutableMapping):
            """Read-only mapping from parameter name to a tensor loaded on demand."""

            def __init__(self, chkpt_dir, device="cpu"):
                self.device = device
                self.chkpt_dir = Path(chkpt_dir)
                # m.pt is the index: it maps each key to the file holding its
                # value. torch.load unpickles its input, so only point this at
                # trusted checkpoint directories.
                self.checkpoint = torch.load(str(self.chkpt_dir / "m.pt"))

            def __len__(self):
                return len(self.checkpoint)

            def __getitem__(self, key):
                # Only the file name from the index is used, so the checkpoint
                # directory can be relocated without rewriting m.pt.
                path = self.chkpt_dir / Path(self.checkpoint[key]).name
                return torch.load(str(path), map_location=self.device)

            def __setitem__(self, key, value):
                # Writes are deliberately ignored.
                return

            def __delitem__(self, key):
                # Deletes are deliberately ignored. The mapping protocol calls
                # __delitem__(key) with one argument; the previous
                # (key, value) signature made `del ckpt[key]` raise TypeError.
                return

            def keys(self):
                return self.checkpoint.keys()

            def __iter__(self):
                # NOTE(review): yields (key, value) pairs rather than bare
                # keys, which differs from the usual mapping protocol. Kept
                # as-is because the state-dict consumer may rely on it --
                # confirm before changing.
                for key in self.checkpoint:
                    yield (key, self.__getitem__(key))

            def __copy__(self):
                return Checkpoint(self.chkpt_dir, device=self.device)

            def copy(self):
                return Checkpoint(self.chkpt_dir, device=self.device)

        return Checkpoint

    def load_model(self, model_path=None, model_name=None, config=None, checkpoint=None):
        """Load pretrained weights into a new model and store it on ``self.model``.

        Args:
            model_path: Checkpoint directory. Ignored when ``model_name`` is
                given.
            model_name: Known model name, resolved through ``assign_path``.
            config: Model config; defaults to ``self.config``.
            checkpoint: Callable taking the checkpoint directory and returning
                the state-dict mapping; defaults to ``self.checkpoint``.

        Returns:
            The loaded model.

        Raises:
            ValueError: If neither ``model_path`` nor ``model_name`` is given,
                or ``model_name`` is not recognized.
        """
        if config is None:
            config = self.config

        # Previously the factory was bound only in the default case, so a
        # caller-supplied checkpoint ended in a NameError.
        if checkpoint is None:
            checkpoint = self.checkpoint

        if model_name is not None:
            model_path = self.assign_path(model_name)
            print("Loading model from: " + model_path)

        if model_path is None:
            raise ValueError("Either model_path or model_name must be provided")

        # The config resolved above is passed through (previously self.config
        # was always used, silently ignoring the `config` argument).
        model = no_init(
            lambda: GPTNeoForCausalLM.from_pretrained(
                pretrained_model_name_or_path=None,
                config=config,
                state_dict=checkpoint(model_path),
            )
        )
        self.model = model
        return model

    def assign_path(self, model_name):
        """Map a known model name to its checkpoint directory.

        Raises:
            ValueError: If ``model_name`` is not recognized.
        """
        if model_name == "gptj":
            return "/home/xuser/models/j6b_ckpt_14001"

        # Raise error if model name not recognized
        raise ValueError("Model name not recognized")

    def init_model(self, config=None, method='wang'):
        """Create a freshly initialized model and store it on ``self.model``.

        Every block except the last, plus the token embedding and the final
        layer norm, is filled with ``small_init_method``; the last block is
        filled with ``wang_init_method``.

        Args:
            config: Model config; defaults to ``self.config``.
            method: Currently unused; kept for backward compatibility.

        Returns:
            The initialized model.
        """
        if config is None:
            config = self.config

        model = no_init(lambda: GPTNeoForCausalLM(config))

        # Sizes come from the config the model was built with, so a
        # caller-supplied config is honored consistently.
        # NOTE(review): the initializers are applied to every parameter of
        # these modules, which presumably includes biases and layer-norm
        # weights -- confirm that this is intended.
        modules = [*model.transformer.h[:-1], model.transformer.wte, model.transformer.ln_f]
        init = small_init_method(config.hidden_size)
        for module in modules:
            for param in module.parameters():
                init(param)

        last_layer = model.transformer.h[-1]
        last_layer_init = wang_init_method(config.num_layers, config.hidden_size)
        for param in last_layer.parameters():
            last_layer_init(param)

        self.model = model
        return model

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        embs=None,
    ):
        """Run the wrapped model and return its outputs.

        All arguments are forwarded positionally, in declaration order, to
        the wrapped model.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            RuntimeError: If no model has been loaded or initialized yet.
            TypeError: If ``self.model`` is not a ``GPTNeoForCausalLM``.
        """
        # Fail with a clear message instead of the UnboundLocalError the old
        # fall-through produced when no usable model was present.
        if self.model is None:
            raise RuntimeError("No model available: call load_model() or init_model() first")
        if not isinstance(self.model, GPTNeoForCausalLM):
            raise TypeError(
                "Unsupported model type for forward(): " + type(self.model).__name__
            )

        # outputs: dict(loss, logits, past_key_values, hidden_states, attentions)
        return self.model(
            input_ids,
            past_key_values,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            return_dict,
            embs,
        )
    
#def init_module()

def wang_init_method(n_layers, dim):
    """Build an in-place normal initializer with std ``2 / n_layers / sqrt(dim)``.

    Args:
        n_layers: Layer count; divides the standard deviation.
        dim: Hidden dimension; its square root divides the standard deviation.

    Returns:
        A callable that fills the given tensor in place with zero-mean normal
        samples of that standard deviation and returns the same tensor.
    """
    sigma = 2 / n_layers / math.sqrt(dim)
    return lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

# Borrowed from GPT-NeoX. For the 20B run, wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
    """Build an in-place normal initializer with std ``sqrt(2 / (5 * dim))``.

    This is the "small init" scheme described in "Transformers without Tears:
    Improving the Normalization of Self-Attention" (Nguyen & Salazar, 2019).

    Args:
        dim: Hidden dimension used to scale the standard deviation.

    Returns:
        A callable that fills the given tensor in place with zero-mean normal
        samples of that standard deviation and returns the same tensor.
    """
    sigma = math.sqrt(2 / (5 * dim))
    return lambda tensor: torch.nn.init.normal_(tensor, mean=0.0, std=sigma)


