from basedformer import utils
from basedformer import models
import math
import torch
from torch import nn, distributed
import os
import json
from dataclasses import dataclass
from pathlib import Path

def init_weights(model, n_layer):
    """Initialise ``model``'s parameters in place.

    First pass (per module):
      * ``nn.Linear``: weight ~ N(0, 0.02), bias zeroed (if present).
      * ``nn.Embedding``: weight ~ N(0, 0.02).
      * ``nn.LayerNorm``: weight filled with 1, bias zeroed (each only if present).

    Second pass (per parameter): every parameter whose fully-qualified name
    contains ``"ff2"`` or ``"out_proj"`` and also ``"weight"`` is re-drawn from
    N(0, 0.02 / sqrt(2 * n_layer)), shrinking those projections as depth grows.

    Args:
        model: ``nn.Module`` to initialise; modified in place.
        n_layer: number of layers, used to scale the ``ff2``/``out_proj`` weights.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    # This must run only after the per-module pass. The ff2/out_proj projections
    # are typically nn.Linear layers themselves, so the loop above would
    # otherwise overwrite the scaled init with the plain std=0.02 one.
    for name, p in model.named_parameters():
        if ("ff2" in name or "out_proj" in name) and "weight" in name:
            p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))

def init(model_class, config):
    """Construct ``model_class(config)`` and initialise its weights.

    The layer count for ``init_weights`` is read as ``config["n_layer"]``
    after the model has been built.

    Returns:
        The freshly constructed, initialised model.
    """
    fresh_model = model_class(config)
    init_weights(fresh_model, n_layer=config["n_layer"])
    return fresh_model

def no_init(config):
    """Build the model named by ``config["model_class"]`` via ``utils.no_init``.

    The class is looked up with ``models.get_model`` and instantiated as
    ``cls(config)`` inside ``utils.no_init``. Presumably that wrapper skips
    the default parameter initialisation; its behaviour is defined in
    ``basedformer.utils``, so confirm there.
    """
    cls = models.get_model(config["model_class"])

    def build():
        return cls(config)

    return utils.no_init(build)

_SERIALIZED_CONFIG_FIELDS = ("n_layer", "n_head", "n_tokens", "hidden_dim", "vocab_dim", "eps")


def serialize_config(config, model_class="gptj", model_path="."):
    """Return a JSON-serialisable dict describing a model config.

    The layout is what ``load_from_path`` reads back. Top-level
    ``"model_class"`` and ``"model_path"`` keys sit alongside a
    ``"model_config"`` dict that holds the hyper-parameters read as
    attributes from ``config``.

    Args:
        config: object exposing ``n_layer``, ``n_head``, ``n_tokens``,
            ``hidden_dim``, ``vocab_dim`` and ``eps`` as attributes.
        model_class: value stored under ``"model_class"``. Defaults to
            ``"gptj"``, the previously hard-coded value. Pass the real class
            name for other architectures so they reload as the right class.
        model_path: value stored under ``"model_path"``. ``load_from_path``
            treats ``"."`` as "the folder containing config.json".
    """
    return {
        "model_class": model_class,
        "model_path": model_path,
        "model_config": {
            field: getattr(config, field) for field in _SERIALIZED_CONFIG_FIELDS
        },
    }

def save(model, path, save_fp16=True):
    """Write ``model``'s weights and config under ``path``.

    Layout produced:
        path/lm/b{i}.pt   one file per state_dict entry, in state_dict order
        path/lm/m.pt      dict mapping each state_dict key -> its b{i}.pt path
        path/config.json  ``serialize_config(model.config)``

    When ``torch.distributed`` is initialised, only rank 0 writes; other
    ranks return immediately.

    Args:
        model: module to save; must expose a ``config`` attribute.
        path: output directory, created if missing.
        save_fp16: if True, floating-point tensors are written as float16.
            The cast is applied to copies, so ``model`` keeps its dtype.
    """
    if distributed.is_initialized() and distributed.get_rank() != 0:
        return

    path = Path(path)
    lm_path = path / "lm"
    lm_path.mkdir(parents=True, exist_ok=True)

    checkpoint = {}
    for i, (key, value) in enumerate(model.state_dict().items()):
        # Cast per tensor rather than calling model.half(). nn.Module.half()
        # converts the caller's module in place, which would leave a model
        # that is still being trained in fp16. Non-tensor extra state is
        # saved untouched.
        if save_fp16 and torch.is_tensor(value) and value.is_floating_point():
            value = value.half()
        shard_path = lm_path / f"b{i}.pt"
        checkpoint[key] = shard_path
        torch.save(value, shard_path)
    torch.save(checkpoint, lm_path / "m.pt")

    with open(path / "config.json", "w") as f:
        json.dump(serialize_config(model.config), f)

def load_from_path(config_folder=None, strict=False):
    """Load a model from a folder containing ``config.json``.

    The steps are:
      1. Read ``config_folder/config.json``.
      2. Resolve the class through ``models.get_model``.
      3. Load the weights from ``<model_path>/lm``.

    A ``model_path`` of ``"."`` means ``config_folder`` itself; any other
    value is used as given.

    Args:
        config_folder: directory holding ``config.json``.
        strict: forwarded to ``_load_dict_model``.
    """
    folder = Path(config_folder)
    cfg = _load_config_file(folder / "config.json")

    cls = models.get_model(cfg["model_class"])
    weights_root = cfg["model_path"]
    arch = cfg["model_config"]

    # "." is the marker for "weights live next to config.json".
    base = folder if weights_root == "." else Path(weights_root)
    return _load_dict_model(cls, arch, str(base / "lm"), strict=strict)
    
def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False, *, device="cuda"):
    """Instantiate ``model_class(config)`` and load weights into it.

    If ``path`` is truthy, weights come from ``utils.SplitCheckpoint(path, ...)``.
    Otherwise they come from ``state_dict``. The model is built inside
    ``utils.no_init``, presumably to skip the initialisation that the loaded
    weights overwrite anyway; confirm in ``basedformer.utils``.

    Args:
        model_class: constructor called as ``model_class(config)``.
        config: passed unchanged to ``model_class``.
        path: split-checkpoint directory; takes precedence over ``state_dict``.
        state_dict: mapping of parameter name -> tensor, used when ``path``
            is falsy.
        strict: forwarded to ``load_state_dict``.
        device: passed as ``device=`` to ``utils.SplitCheckpoint`` when
            loading from ``path``. Defaults to ``"cuda"`` as before.

    Returns:
        The model with weights loaded.

    Raises:
        ValueError: if neither ``path`` nor ``state_dict`` is provided.
    """
    if path:
        state_dict = utils.SplitCheckpoint(path, device=device)
    elif state_dict is None:
        # Fail fast, before building the model, instead of erroring
        # obscurely inside load_state_dict(None).
        raise ValueError("_load_dict_model needs either a checkpoint path or a state_dict")

    model = utils.no_init(lambda: model_class(config))
    model.load_state_dict(state_dict, strict=strict)
    return model

def _load_config_file(config_file):
    """Read and decode a JSON config file.

    Args:
        config_file: path to the file, as a ``str`` or path-like object.

    Returns:
        The decoded JSON value (a dict for configs written by ``save``).

    Raises:
        FileNotFoundError: if ``config_file`` does not exist.
    """
    config_file = Path(config_file)
    # EAFP: just try to open it. A separate exists() check can race with
    # file removal, and it fails outright on plain str paths.
    try:
        f = open(config_file, encoding="utf-8")
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Config file not found at {config_file}") from err

    with f:
        return json.load(f)

    



