import torch
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
from pathlib import Path
import os
import math

from torch.utils import data
import numpy as np
import torch
from tqdm import tqdm
import time
from simplejpeg import decode_jpeg
import mmap
from timeit import default_timer as timer
import pickle
import concurrent
from itertools import repeat

# TODO: Does this work with other block_size values? It doesn't seem to.
class FbDataset(data.Dataset):
    """Dataset over a flat uint16 file viewed as rows of ``block_size`` values.

    The file is memory-mapped read-only and reshaped to ``(-1, block_size)``,
    so its element count must be a multiple of ``block_size``.

    Args:
        block_size: Number of uint16 values per row.
        map_file: Path of the file to memory-map.
        max_samples: Optional upper bound on the reported length.
        skip: Offset added to every index in ``__getitem__``. ``__len__``
            does not account for it.
    """

    def __init__(self, block_size, map_file, max_samples=None, skip=0):
        flat = np.memmap(map_file, mode="r", dtype="uint16")
        self.npz = flat.reshape((-1, block_size))
        available = self.npz.shape[0]
        if max_samples is None:
            self.samples = available
        else:
            self.samples = min(available, int(max_samples))
        self.skip = skip

    def __len__(self):
        """Return the number of rows exposed (capped by ``max_samples``)."""
        return self.samples

    def __getitem__(self, _id):
        """Return ``(row[:-1], row[1:])`` as int64 tensors for row ``_id + skip``."""
        row = self.npz[_id + self.skip]
        tokens = torch.tensor(row.astype(np.int64))
        inputs, targets = tokens[:-1], tokens[1:]
        return (inputs, targets)

class ShardedDataset(data.Dataset):
    """Rank-local shard of a flat uint16 file viewed as rows of ``block_size``.

    The file is memory-mapped read-only and reshaped to ``(-1, block_size)``.
    Trailing rows that do not divide evenly by ``world_size`` are dropped
    (no padding), then every ``world_size``-th row starting at ``rank`` is kept.

    Args:
        block_size: Number of uint16 values per row.
        map_file: Path of the file to memory-map.
        world_size: Total number of shards.
        rank: Index of the shard this instance serves.
        skip: Offset added to every index in ``__getitem__``. ``__len__``
            does not account for it.
    """

    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
        rows = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        # Drop the remainder so every rank ends up with the same row count.
        usable = rows.shape[0] - (rows.shape[0] % world_size)
        self.npz = rows[:usable][rank::world_size]
        self.samples = self.npz.shape[0]
        self.skip = skip

    def __len__(self):
        """Return the number of rows in this rank's shard."""
        return self.samples

    def __getitem__(self, _id):
        """Return ``(row[:-1], row[1:])`` as int64 tensors for row ``_id + skip``."""
        row = self.npz[_id + self.skip]
        tokens = torch.tensor(row.astype(np.int64))
        inputs, targets = tokens[:-1], tokens[1:]
        return (inputs, targets)

class ShardedImageDataset(data.Dataset):
    """Rank-local view over JPEG images packed into one memory-mapped file.

    ``metadata_path`` is a pickle holding a sequence of
    ``(offset, size, id)`` records that locate each encoded image inside
    ``dataset_path``. The records are trimmed to a multiple of
    ``bsz * world_size`` and then strided by ``rank``/``world_size``.

    Args:
        dataset_path: Path of the binary file containing the encoded images.
        metadata_path: Path of the pickled record list.
        threads: ``max_workers`` for the decoding thread pool.
        skip: Offset added to every key in ``__getitem__``.
        bsz: Number of images fetched per ``__getitem__`` call.
        world_size: Total number of shards.
        rank: Index of the shard this instance serves.
    """

    def __init__(self, dataset_path: str, metadata_path: str, threads=None, skip=0, bsz=256, world_size=1, rank=0):
        # Imported here because a bare `import concurrent` does not load the
        # `futures` submodule; relying on some other module having imported
        # it first is fragile.
        from concurrent.futures import ThreadPoolExecutor

        self.skip = skip
        self.threads = threads
        self.bsz = bsz
        self.dataset_path = dataset_path
        self.world_size = world_size
        self.rank = rank
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only point this at metadata files from a trusted source.
        with open(metadata_path, 'rb') as f:
            self.metadata = pickle.load(f)

        # The payload is binary image data, so open it in binary mode. The
        # mapping stays valid after the file object is closed.
        with open(self.dataset_path, mode="rb") as file_obj:
            self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)

        # Trim so the record count divides evenly by world_size (number of
        # GPUs) times batch size.
        remainder = len(self.metadata) % (bsz * world_size)
        self.metadata = self.metadata[:len(self.metadata) - remainder]

        # Keep only the records belonging to this rank.
        self.metadata = self.metadata[rank::world_size]
        # Stored as a NumPy array; the original author's note says this is
        # meant to sidestep GIL contention when worker threads index it.
        self.metadata = np.array(self.metadata)
        # The pool is created once here rather than per __getitem__ call;
        # the original author reports this as roughly 10x faster.
        self.executor = ThreadPoolExecutor(max_workers=self.threads)

    def __len__(self):
        # NOTE(review): self.metadata is already sharded by world_size in
        # __init__, so dividing by world_size again looks like it undercounts
        # the batches available to this rank. Left as-is because callers may
        # depend on the current value - confirm intent.
        return len(self.metadata) // (self.bsz * self.world_size)

    def __getitem__(self, key):
        """Return a lazy iterator of decoded images for one batch.

        The batch covers records ``[skip + key, skip + key + bsz)``. The
        return value is the iterator produced by ``executor.map``; results
        are yielded in record order as they are consumed.

        NOTE(review): the window starts at ``key`` rather than ``key * bsz``,
        so consecutive keys yield overlapping batches - confirm intent.
        """
        key = self.skip + key
        keys = list(range(key, key + self.bsz))
        tensors = self.executor.map(self.read_from_metadata_key, keys)
        return tensors

    def read_from_metadata_key(self, key):
        """Decode the image described by record ``key`` into a tensor.

        The decoded array's axes are permuted ``(2, 0, 1)``; presumably this
        converts height x width x channel to channel-first - verify against
        ``decode_jpeg``'s output layout.
        """
        offset, size, d_id = self.metadata[key]
        data = self.mmap[offset:offset+size]
        data = decode_jpeg(data)
        data = torch.from_numpy(data).permute(2, 0, 1)
        return data

# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))

def no_init(loading_code):
    """Run ``loading_code`` with weight initialization disabled.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced by a no-op.
    The original methods are restored afterwards, including when
    ``loading_code`` raises.

    Args:
        loading_code: Zero-argument callable, e.g.
            ``lambda: load_model(...)``.

    Returns:
        Whatever ``loading_code`` returns.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy

    try:
        return loading_code()
    finally:
        # Always undo the class-level patch; otherwise a failed load would
        # leave these layer types uninitialized for the rest of the process.
        for mod, reset in original.items():
            mod.reset_parameters = reset

# Count the parameters of a given pytorch model.

def count_parameters(model, only_trainable=False):
    """Return the total element count of ``model.parameters()``.

    When ``only_trainable`` is true, parameters whose ``requires_grad`` is
    false are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

def print_parameters(model, only_trainable=False):
    """Print the model's parameter count in millions, to two decimals.

    When ``only_trainable`` is true, parameters whose ``requires_grad`` is
    false are not counted.
    """
    sizes = [
        param.numel()
        for param in model.parameters()
        if not only_trainable or param.requires_grad
    ]
    millions = sum(sizes) / 1e6
    print(f"{millions:.2f}M parameters")

# File name of the index that maps each state-dict key to its weight file.
SPLIT_WEIGHTS_NAME = "m.pt"
class SplitCheckpoint(MutableMapping):
    """Read-only mapping over a checkpoint split into one file per entry.

    The checkpoint directory holds an index file (``SPLIT_WEIGHTS_NAME``)
    loaded with ``torch.load``. Each index value is either a path-like string
    or a tuple whose first element is such a string and whose second element
    is passed to ``_load`` as ``shape``. Only the final ``/``-separated
    component of the path is used, resolved against the checkpoint
    directory. Entries are loaded from disk on every ``__getitem__`` call.

    Args:
        name_or_path: The checkpoint directory, or a file inside it.
        device: ``map_location`` used when loading individual entries.
        subfolder: Optional path component appended to ``name_or_path``.

    Raises:
        FileNotFoundError: If no checkpoint can be located at the given path.
    """

    def __init__(self, name_or_path, device="cpu", subfolder=None):
        self.device = device
        localpath = Path(name_or_path)
        if subfolder is not None:
            localpath = localpath / subfolder
        if os.path.isfile(localpath):
            # A file was given: its directory is the checkpoint directory.
            self.chkpt_dir = localpath.parent
        elif os.path.isfile(localpath / SPLIT_WEIGHTS_NAME):
            self.chkpt_dir = localpath
        else:
            raise FileNotFoundError(
                f"No split checkpoint found at {localpath} "
                f"(expected a file, or a directory containing {SPLIT_WEIGHTS_NAME})"
            )
        self.remote = False
        # The index is loaded exactly once, for both layouts.
        self.checkpoint = self._load(SPLIT_WEIGHTS_NAME, None)

    def _load(self, name, shape, **kwparams):
        """Load ``name`` from the checkpoint directory via ``torch.load``.

        ``shape`` is accepted for interface compatibility and is not used.
        """
        path = str(self.chkpt_dir / name)
        return torch.load(path, **kwparams)

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        """Load and return the entry stored under ``key``."""
        entry = self.checkpoint[key]
        # Inspect the raw entry: stringifying it first would hide tuples and
        # turn their repr into a bogus file name.
        if isinstance(entry, tuple):
            name, shape = entry[0], entry[1]
        else:
            name, shape = entry, None
        filename = str(name).split('/')[-1]
        return self._load(filename, shape, map_location=self.device)

    def __setitem__(self, key, value):
        # The checkpoint is read-only; assignments are silently ignored.
        return

    def __delitem__(self, key, value=None):
        # Read-only: deletions are silently ignored. `value` is kept only for
        # backward compatibility with the previous two-argument signature.
        return

    def keys(self):
        return self.checkpoint.keys()

    def __iter__(self):
        # NOTE(review): this yields (key, value) pairs, not keys as the
        # Mapping protocol expects, so the inherited items()/values() will
        # not work. Kept as-is because callers may iterate pairs directly.
        for key in self.checkpoint:
            yield (key, self[key])

    def __copy__(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=True):
    """Time ``func`` and print the best run's mean and standard deviation.

    ``func`` is called ``n`` times in each of ``r`` runs. The run with the
    lowest mean is reported, scaled to ns, μs, ms or s.

    Args:
        func: Zero-argument callable to time.
        r: Number of runs.
        n: Number of calls per run.
        quiet: If true, nothing is printed.
        function: Optional object whose ``__name__`` is copied onto ``func``
            so the report shows that name.
        do_tqdm: Wrap the run loop in a ``tqdm`` progress bar.
        first: If false, the first call of each run is dropped from the
            statistics (only when the run has more than one call).
        cuda_blocking: Synchronize CUDA before and after each call so that
            asynchronous GPU work is included in the measurement. Ignored
            when CUDA is not available.

    Returns:
        None. The result is only printed.
    """
    if function:
        func.__name__ = function.__name__

    # Synchronizing without a usable CUDA device raises, so only do it when
    # one is present.
    sync = cuda_blocking and torch.cuda.is_available()

    r_arr = np.empty([2, r])  # row 0 = mean, row 1 = std
    for i in tqdm(range(r)) if do_tqdm else range(r):
        n_arr = np.empty(n)
        for k in range(n):
            # Drain pending GPU work *before* starting the clock so earlier
            # kernels are not billed to this call.
            if sync:
                torch.cuda.synchronize()
            start = time.perf_counter_ns()
            func()
            if sync:
                torch.cuda.synchronize()
            n_arr[k] = time.perf_counter_ns() - start

        # Drop the first call, but never leave the run without samples
        # (mean/std of an empty array would be NaN).
        if not first and n > 1:
            n_arr = n_arr[1:]

        r_arr[0, i] = np.mean(n_arr)
        r_arr[1, i] = np.std(n_arr)

    best = r_arr[:, np.argmin(r_arr[0])]  # [0] = mean, [1] = std

    # Pick the display unit from the best mean (measured in nanoseconds).
    if best[0] >= 1e9:
        scale, precision, digits = 1e-9, 's', 4
    elif best[0] >= 1e6:
        scale, precision, digits = 1e-6, 'ms', 2
    elif best[0] >= 1e3:
        scale, precision, digits = 1e-3, 'μs', 2
    else:
        scale, precision, digits = None, 'ns', 0

    if scale is not None:
        best[0] = best[0] * scale
        best[1] = best[1] * scale

    if not quiet:
        print(f"{func.__name__}: {best[0]:.{digits}f}{precision} ± {best[1]:.{digits}f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")

def gelu_new(x):
    """Apply the tanh-based GELU approximation to tensor ``x`` element-wise.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.
    """
    coeff = math.sqrt(2.0 / math.pi)
    inner = x + 0.044715 * torch.pow(x, 3.0)
    gate = 1.0 + torch.tanh(coeff * inner)
    return 0.5 * x * gate