from basedformer import utils
from basedformer import models
import math
import torch
from torch import nn
import os
import json
from dataclasses import dataclass
from pathlib import Path

def init_weights(model, n_layer):
    """Initialise the parameters of ``model`` in place.

    Two passes are made over the model:

    1. Type-based defaults: ``nn.Linear`` and ``nn.Embedding`` weights are
       drawn from N(0, 0.02), ``nn.Linear`` biases are zeroed, and
       ``nn.LayerNorm`` is reset to weight 1 / bias 0.
    2. Every parameter whose fully qualified name contains ``"ff2"`` or
       ``"out_proj"`` and also ``"weight"`` is re-drawn from
       N(0, 0.02 / sqrt(2 * n_layer)).

    The scaled pass must run after the type-based pass. ``model.modules()``
    yields parents before children, so doing both in a single traversal lets
    the later visit of the child ``nn.Linear`` overwrite the scaled values
    with the default std of 0.02.

    Args:
        model: The ``nn.Module`` whose parameters are initialised.
        n_layer: Layer count used to scale the std of the ``ff2`` /
            ``out_proj`` weights; must be non-zero when such weights exist.
    """
    # Pass 1: defaults chosen by module type.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)

        elif isinstance(module, nn.LayerNorm):
            # LayerNorm may be built without affine parameters (or without a
            # bias), in which case these attributes are None.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    # Pass 2: scaled init, matched on names relative to the root module so
    # that e.g. "blocks.0.ff2.weight" is recognised.
    for name, p in model.named_parameters():
        if ("ff2" in name or "out_proj" in name) and "weight" in name:
            p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))

def init(model_class, config):
    """Build ``model_class(config)`` and initialise its weights.

    Args:
        model_class: Callable that accepts ``config`` and returns the model.
        config: Mapping passed to ``model_class``; its ``"n_layer"`` entry is
            forwarded to ``init_weights``.

    Returns:
        The constructed model after ``init_weights`` has been applied.
    """
    instance = model_class(config)
    # Looked up after construction, exactly as before.
    layer_count = config["n_layer"]
    init_weights(instance, layer_count)
    return instance

def no_init(model_class, config):
    """Build ``model_class(config)`` through ``utils.no_init``.

    The constructor call is deferred into a zero-argument callable so that
    ``utils.no_init`` decides when and how it runs.

    Returns:
        Whatever ``utils.no_init`` returns for the deferred construction.
    """
    def build():
        return model_class(config)

    return utils.no_init(build)

def save(model, path):
    """Write ``model.state_dict()`` to ``path`` as one file per entry.

    Layout produced inside ``path``:

    * ``b{i}.pt`` -- the i-th state-dict value, in ``state_dict()`` order.
    * ``m.pt``    -- an index dict mapping each state-dict key to the file
      path of its value. Paths are recorded as ``f"{path}/b{i}.pt"``, i.e.
      they embed ``path`` exactly as it was passed in.

    Args:
        model: Object exposing ``state_dict()``.
        path: Destination directory. It is created (including missing
            parents) if needed; an existing directory is reused.

    Raises:
        OSError: If the directory cannot be created, e.g. ``path`` exists as
            a regular file or permissions are insufficient.
    """
    # Previously a bare ``except: pass`` around os.mkdir hid every failure
    # and the error only surfaced later inside torch.save.
    os.makedirs(path, exist_ok=True)

    checkpoint = {}
    for i, (name, tensor) in enumerate(model.state_dict().items()):
        shard_file = f"{path}/b{i}.pt"
        checkpoint[name] = shard_file
        torch.save(tensor, shard_file)
    torch.save(checkpoint, f"{path}/m.pt")

def load_from_path(config_folder=None, strict=False):
    """Load a model described by ``config.json`` inside ``config_folder``.

    The config file must provide three keys:

    * ``"model_class"``  -- name resolved through ``models.get_model``.
    * ``"model_path"``   -- directory holding the weights; the literal
      ``"."`` means "use ``config_folder`` itself".
    * ``"model_config"`` -- config passed to the model constructor (it is
      also printed to stdout).

    Weights are read from the ``lm`` subdirectory of the resolved model path.

    Args:
        config_folder: Directory containing ``config.json``.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The model returned by ``_load_dict_model``.
    """
    folder = Path(config_folder)
    settings = _load_config_file(folder / "config.json")

    model_class = models.get_model(settings["model_class"])
    weights_root = settings["model_path"]
    model_config = settings["model_config"]
    print(model_config)

    # "." is shorthand for the directory that holds config.json.
    base_dir = folder if weights_root == "." else Path(weights_root)

    return _load_dict_model(model_class, model_config, base_dir / "lm", strict=strict)
    
def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False):
    """Construct ``model_class(config)`` and load a state dict into it.

    Args:
        model_class: Callable that accepts ``config`` and returns the model.
        config: Config passed to ``model_class``.
        path: If truthy, weights come from ``utils.SplitCheckpoint(path,
            device="cuda")`` and ``state_dict`` is ignored.
        state_dict: Weights used when ``path`` is falsy.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The model after ``load_state_dict`` has been called on it.
    """
    # A given path takes precedence over an explicitly passed state dict.
    weights = utils.SplitCheckpoint(path, device="cuda") if path else state_dict

    def build():
        return model_class(config)

    loaded = utils.no_init(build)
    loaded.load_state_dict(weights, strict=strict)
    return loaded

def _load_config_file(config_file):
    """Read and parse a JSON config file.

    Args:
        config_file: Location of the JSON file, as a ``pathlib.Path`` or a
            plain string path.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If nothing exists at ``config_file``.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Coerce so that plain string paths work as well as Path objects.
    config_file = Path(config_file)
    if not config_file.exists():
        raise FileNotFoundError(f"Config file not found at {config_file}")

    # Decode explicitly as UTF-8 instead of relying on the platform locale.
    with open(config_file, encoding="utf-8") as f:
        config = json.load(f)

    return config

    



