import torch
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
from pathlib import Path
import os


def no_init(loading_code):
    """Run ``loading_code()`` with default parameter initialisation disabled.

    Temporarily replaces ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` with a no-op. Modules of
    those types constructed inside ``loading_code`` therefore skip their
    default init. Presumably the caller fills in the weights afterwards (e.g.
    from a checkpoint); confirm at call sites.

    The patch is applied to the classes themselves, so it is process-global
    while ``loading_code`` runs. The original methods are always restored,
    even if ``loading_code`` raises.

    Args:
        loading_code: Zero-argument callable to run.

    Returns:
        Whatever ``loading_code()`` returns.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy

    try:
        return loading_code()
    finally:
        # Restore unconditionally: leaving the no-op in place after an
        # exception would silently skip init for every later module.
        for mod in modules:
            mod.reset_parameters = original[mod]

# Name of the index file inside a split-checkpoint directory.
SPLIT_WEIGHTS_NAME = "m.pt"


class SplitCheckpoint(MutableMapping):
    """Lazy, read-only mapping over a checkpoint stored as one file per entry.

    The checkpoint directory holds an index file (``SPLIT_WEIGHTS_NAME``)
    loaded with ``torch.load``. Each value in the index is either a shard path
    string or a tuple whose first element is the shard path and whose second
    element is passed through as ``shape``. Only the final ``/``-separated
    component of a shard path is used, and it is resolved relative to the
    checkpoint directory. Shards are loaded from disk on every access and are
    not cached.

    Writes and deletes are accepted but silently ignored.

    Args:
        name_or_path: Checkpoint directory, or a file inside it.
        device: Used as ``map_location`` when loading shards.
        subfolder: Optional subdirectory appended to ``name_or_path``.

    Raises:
        FileNotFoundError: If the resolved path is neither a file nor a
            directory containing ``SPLIT_WEIGHTS_NAME``.
    """

    def __init__(self, name_or_path, device="cpu", subfolder=None):
        self.device = device
        localpath = Path(name_or_path)
        if subfolder is not None:
            localpath = localpath / subfolder
        if os.path.isfile(localpath):
            # A file inside the checkpoint was given; use its directory.
            self.chkpt_dir = localpath.parent
        elif os.path.isfile(localpath / SPLIT_WEIGHTS_NAME):
            self.chkpt_dir = localpath
        else:
            raise FileNotFoundError(
                f"No split checkpoint found at {str(localpath)!r} "
                f"(expected a file or a directory containing "
                f"{SPLIT_WEIGHTS_NAME!r})"
            )
        self.remote = False
        # The index is loaded exactly once, here.
        self.checkpoint = self._load(SPLIT_WEIGHTS_NAME, None)

    def _load(self, name, shape, **kwparams):
        """``torch.load`` the file ``name`` from the checkpoint directory.

        ``shape`` is accepted for interface compatibility and is unused.
        ``kwparams`` are forwarded to ``torch.load``.
        """
        return torch.load(str(self.chkpt_dir / name), **kwparams)

    def __len__(self):
        return len(self.checkpoint)

    def __contains__(self, key):
        # Check the index directly. The inherited Mapping.__contains__ would
        # call __getitem__, which loads a whole shard from disk.
        return key in self.checkpoint

    def __getitem__(self, key):
        """Load and return the shard for ``key`` onto ``self.device``."""
        entry = self.checkpoint[key]
        if isinstance(entry, tuple):
            path, shape = entry[0], entry[1]
        else:
            path, shape = entry, None
        return self._load(path.split('/')[-1], shape, map_location=self.device)

    def __setitem__(self, key, value):
        # Read-only view: writes are deliberately ignored.
        return

    def __delitem__(self, key, value=None):
        # Read-only view: deletes are deliberately ignored. ``value`` is kept
        # optional only so existing direct two-argument calls keep working;
        # ``del obj[key]`` passes just ``key``.
        return

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        # NOTE(review): this yields (key, tensor) pairs, not plain keys as the
        # Mapping protocol expects. It is kept for existing callers, but it
        # means the inherited items()/values()/__eq__ do not work.
        for key in self.checkpoint:
            yield (key, self.__getitem__(key))

    def __copy__(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        return self.__copy__()