import torch.nn as nn
import torch.nn.functional as F
from basedformer import utils
from dotmap import DotMap
from pathlib import Path
import torch
import json
        
class PretrainedModel(nn.Module):
    def __init__(self, **kwargs):
        nn.Module.__init__(self)
        self.config = None

    @classmethod
    def no_init(cls, config):
        model = utils.no_init(lambda: cls(config))
        return model
    
    @classmethod
    def init(cls, config):
        model = cls(config)
        if hasattr(model, 'init_weights'):
            model.init_weights()
        else:
            raise ValueError("No init_weights found, add one for init to function")
        return model
    
    def save(self, path, save_as=torch.float16):
        original_dtype = model.dtype
        model = self
        if save_as:
            model = model.to(save_as)

        path = Path(path)
        model_path = path / "model"
        #make folder
        model_path.mkdir(parents=True, exist_ok=True)
        checkpoint = {}
        for i, x in enumerate(model.state_dict().items()):
            checkpoint[x[0]] = model_path / f"b{i}.pt"
            torch.save(x[1], model_path / f"b{i}.pt")
        torch.save(checkpoint, model_path / "m.pt")

        #write model.config to config.json inside path
        #with open(path / "config.json", "w") as f:
        #    json.dump(serialize_config(model.config), f)
        
        if save_as:
            model = model.to(original_dtype)