from lm_arch import utils
import math
import torch
from torch import nn
from lm_arch import gptj
import os

# Common LM functionality (weight init, loading, saving) lives in this wrapper
# class rather than in the torch LM module itself.
class BaseLM(nn.Module):
    """Wrapper that pairs a language-model module with its config.

    Provides weight initialisation, construction helpers and split-checkpoint
    saving/loading around the wrapped ``lm`` module.

    Attributes are plain references to the constructor arguments:
        config: configuration object for the wrapped model.
        lm: the wrapped ``nn.Module`` (may be ``None`` until one is attached).
    """

    def __init__(self, config=None, lm=None):
        super().__init__()
        self.config = config
        self.lm = lm

    def init_weights(self):
        """Re-initialise the parameters of ``self.lm`` in place.

        * ``nn.Linear``: weight ~ N(0, 0.02), bias zeroed (when present).
        * ``nn.Embedding``: weight ~ N(0, 0.02).
        * ``nn.LayerNorm``: bias zeroed and weight set to 1 (when present).
        * Any parameter whose qualified name contains ``"weight"`` and either
          ``"ff2"`` or ``"out_proj"``: ~ N(0, 0.02 / sqrt(2 * n_layer)),
          where ``n_layer`` is read from ``self.config.n_layer``.
        """
        for module in self.lm.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                # Affine parameters are optional on LayerNorm.
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
                if module.weight is not None:
                    nn.init.ones_(module.weight)

        # The scaled init must run AFTER the per-module pass and on fully
        # qualified names. ``modules()`` is pre-order, so doing it inside the
        # loop above lets the later visit of the Linear itself (whose own
        # parameter names are just "weight"/"bias") overwrite the scaled values.
        scaled_std = None
        for name, param in self.lm.named_parameters():
            if ("ff2" in name or "out_proj" in name) and "weight" in name:
                if scaled_std is None:
                    # Read lazily so a missing config only matters when a
                    # matching parameter actually exists.
                    scaled_std = 0.02 / math.sqrt(2 * self.config.n_layer)
                nn.init.normal_(param, mean=0.0, std=scaled_std)

    @classmethod
    def init(cls, config):
        """Build ``config.model_class(**config)``, wrap it and initialise its weights.

        Args:
            config: object exposing ``model_class`` and ``n_layer`` as
                attributes and usable with ``**`` unpacking.

        Returns:
            A new instance of ``cls`` with freshly initialised weights.
        """
        lm = config.model_class(**config)
        model = cls(config, lm)
        model.init_weights()
        # TODO: make the initialisation scheme modular.
        return model

    @classmethod
    def no_init(cls, config):
        """Build the model through ``utils.no_init`` without calling ``init_weights``.

        NOTE(review): ``utils.no_init`` presumably skips torch's default
        parameter initialisation -- confirm against ``lm_arch.utils``.
        """
        lm = utils.no_init(lambda: config.model_class(**config))
        return cls(config, lm)

    @classmethod
    def load(cls, model_class, config, path=None, state_dict=None, strict=False,
             device="cuda"):
        """Build ``model_class(**config)`` and load weights into it.

        Args:
            model_class: callable that constructs the LM from ``**config``.
            config: mapping of keyword arguments for ``model_class``.
            path: checkpoint location handed to ``utils.SplitCheckpoint``.
                When truthy it takes precedence over ``state_dict``.
            state_dict: state dict to load when ``path`` is falsy.
            strict: forwarded to ``load_state_dict``.
            device: device handed to ``utils.SplitCheckpoint`` (only used
                together with ``path``).

        Returns:
            A new instance of ``cls`` wrapping the loaded model.

        Raises:
            ValueError: if neither ``path`` nor ``state_dict`` is given.
        """
        if path:
            state_dict = utils.SplitCheckpoint(path, device=device)
        if state_dict is None:
            # Fail before paying for model construction.
            raise ValueError("Either `path` or `state_dict` must be provided.")

        lm = model_class(**config)
        model = cls(config, lm)
        model.lm.load_state_dict(state_dict, strict=strict)
        return model

    def save(self, path):
        """Write ``self.lm``'s state dict to ``path`` as a split checkpoint.

        Each state-dict entry is saved to its own file ``{path}/b{i}.pt``; an
        index mapping entry name -> file path is saved to ``{path}/m.pt``.
        The directory (including missing parents) is created if needed.
        Prints a message and returns without writing when ``self.lm`` is None.
        """
        if self.lm is None:
            print("No LM object to save. Please first init a model.")
            return

        # Only "already exists" is tolerated; other OS errors propagate.
        os.makedirs(path, exist_ok=True)

        checkpoint = {}
        for i, (name, tensor) in enumerate(self.lm.state_dict().items()):
            shard_path = f"{path}/b{i}.pt"
            checkpoint[name] = shard_path
            torch.save(tensor, shard_path)
        torch.save(checkpoint, f"{path}/m.pt")


def load_gpt_j(path="models/6b", state_dict=None):
    """Build a ``gptj.GPTJModel`` wrapped in ``BaseLM`` and load its weights.

    Args:
        path: checkpoint location forwarded to ``BaseLM.load``. Ignored when
            ``state_dict`` is given.
        state_dict: optional state dict to load instead of reading ``path``.

    Returns:
        The ``BaseLM`` instance returned by ``BaseLM.load``.
    """
    # Keyword arguments passed to ``gptj.GPTJModel`` by ``BaseLM.load``.
    config = {
        "n_layer": 28,
        "n_head": 16,
        "hidden_dim": 4096,
        "vocab_dim": 50400,
        "eps": 1e-5,
        "activation": gptj.gelu_new,
        "Layer": gptj.GPTJLayer,
    }
    # ``BaseLM.load`` prefers ``path`` whenever it is truthy, and the default
    # path always is; clear it so an explicitly supplied state dict is used
    # instead of being silently ignored.
    if state_dict is not None:
        path = None
    model = BaseLM.load(gptj.GPTJModel, config, path, state_dict)
    return model
