import torch
try:
    from collections.abc import MutableMapping
except ImportError:
    from collections import MutableMapping
from pathlib import Path
import os
import math

from torch.utils import data
import numpy as np
import torch
from tqdm import tqdm
import time
from simplejpeg import decode_jpeg
import mmap
from timeit import default_timer as timer
import pickle
import concurrent
from itertools import repeat

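# Open question: does this work with other block_sizes? It doesn't seem to.
# One likely cause: reshape((-1, block_size)) raises ValueError unless the
# file's uint16 element count is a multiple of block_size.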
class FbDataset(data.Dataset):
    """Fixed-length blocks read from a flat uint16 memory-mapped file.

    Each item is an ``(input, target)`` pair of int64 tensors built from one
    row of ``block_size`` values: ``input`` is the row without its last
    element, ``target`` is the row without its first element.

    Args:
        block_size: Number of uint16 values per row. The file's element count
            must be a multiple of this, otherwise ``reshape`` raises ValueError.
        map_file: Path to the raw uint16 file (opened read-only via np.memmap).
        max_samples: Optional cap on the number of samples exposed.
        skip: Number of leading rows to skip; index ``i`` reads row ``i + skip``.
    """

    def __init__(self, block_size, map_file, max_samples=None, skip=0):
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        # Rows before ``skip`` are never reachable, so they must not count
        # toward the length (previously the last ``skip`` indices raised IndexError).
        self.samples = max(0, self.npz.shape[0] - skip)
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = skip

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        # Enforce the advertised length (including ``max_samples``) so plain
        # iteration via __getitem__ stops where __len__ says it does.
        if _id >= self.samples:
            raise IndexError(f"index {_id} out of range for dataset of length {self.samples}")
        row = torch.tensor(self.npz[_id + self.skip].astype(np.int64))
        return (row[:-1], row[1:])

class ShardedDataset(data.Dataset):
    """Per-rank shard of fixed-length blocks from a flat uint16 memory-mapped file.

    Rows are first truncated to a multiple of ``world_size`` so every rank gets
    the same number of rows. Rank ``rank`` then takes rows
    ``rank, rank + world_size, rank + 2 * world_size, ...``. Each item is an
    ``(input, target)`` pair of int64 tensors: the row without its last element
    and the row without its first element.

    Args:
        block_size: Number of uint16 values per row (the file's element count
            must be a multiple of it, otherwise ``reshape`` raises ValueError).
        map_file: Path to the raw uint16 file (opened read-only via np.memmap).
        world_size: Total number of shards.
        rank: Index of this shard, in ``[0, world_size)``.
        skip: Number of leading shard rows to skip; index ``i`` reads shard row ``i + skip``.
    """

    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
        rows = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        # Drop the remainder so all ranks see equally many rows (might want to pad later).
        rows = rows[:rows.shape[0] - (rows.shape[0] % world_size)]
        # Strided shard for this rank.
        self.npz = rows[rank::world_size]
        # Indices are offset by ``skip``; rows that would then be out of range
        # must not count toward the length.
        self.samples = max(0, self.npz.shape[0] - skip)
        self.skip = skip

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        # Keep indexing consistent with __len__ so plain iteration terminates correctly.
        if _id >= self.samples:
            raise IndexError(f"index {_id} out of range for dataset of length {self.samples}")
        row = torch.tensor(self.npz[_id + self.skip].astype(np.int64))
        return (row[:-1], row[1:])

# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))

def no_init(loading_code):
    """Call ``loading_code()`` with weight initialisation disabled and return its result.

    While ``loading_code`` runs, ``reset_parameters`` on ``torch.nn.Linear``,
    ``torch.nn.Embedding`` and ``torch.nn.LayerNorm`` is replaced with a no-op,
    so modules of those types constructed inside it skip their initialisation
    step. The original methods are restored afterwards, even if
    ``loading_code`` raises (the exception propagates).

    Usage: ``model = no_init(lambda: load_model(...))``
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy
    try:
        return loading_code()
    finally:
        # Always restore; otherwise a failed load would leave initialisation
        # disabled for every later module construction in the process.
        for mod in modules:
            mod.reset_parameters = original[mod]

# Count the parameters of a given pytorch model.

def count_parameters(model, only_trainable=False):
    """Return the total number of elements across all of ``model.parameters()``.

    When ``only_trainable`` is true, parameters with ``requires_grad=False``
    are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

def print_parameters(model, only_trainable=False):
    """Print the parameter count of ``model`` in millions, e.g. ``1.23M parameters``.

    When ``only_trainable`` is true, parameters with ``requires_grad=False``
    are not counted.
    """
    counted = (p.numel() for p in model.parameters() if not only_trainable or p.requires_grad)
    millions = sum(counted) / 1e6
    print(f"{millions:.2f}M parameters")

SPLIT_WEIGHTS_NAME = "m.pt"


class SplitCheckpoint(MutableMapping):
    """Read-only, lazily loaded view of a checkpoint split into one file per entry.

    The checkpoint directory holds an index file named ``SPLIT_WEIGHTS_NAME``.
    It is loaded with ``torch.load`` into a mapping whose values are either a
    path string or a ``(path, shape)`` tuple. On ``ckpt[key]`` only the final
    ``/``-separated component of that path is used. It is resolved relative to
    the checkpoint directory and loaded with ``map_location=device``.

    NOTE: unlike a standard Mapping, iterating the object yields
    ``(key, value)`` pairs (loading every entry). This is kept for backward
    compatibility; use ``keys()`` for key-only iteration.
    NOTE(review): ``torch.load`` unpickles data, so only open checkpoints from
    a trusted source.
    """

    def __init__(self, name_or_path, device="cpu", subfolder=None):
        """Locate the checkpoint directory and load its index.

        Args:
            name_or_path: Either the checkpoint directory or a file inside it
                (the index is always read from that file's directory).
            device: ``map_location`` used when loading individual entries.
            subfolder: Optional subdirectory appended to ``name_or_path``.

        Raises:
            FileNotFoundError: if ``name_or_path`` is neither a file nor a
                directory containing ``SPLIT_WEIGHTS_NAME``.
        """
        self.device = device
        localpath = Path(name_or_path)
        if subfolder is not None:
            localpath = localpath / subfolder
        if localpath.is_file():
            self.chkpt_dir = localpath.parent
        elif (localpath / SPLIT_WEIGHTS_NAME).is_file():
            self.chkpt_dir = localpath
        else:
            # Previously this fell through and failed with AttributeError on chkpt_dir.
            raise FileNotFoundError(
                f"No split checkpoint at {localpath}: expected a file or a "
                f"directory containing {SPLIT_WEIGHTS_NAME}"
            )
        self.remote = False
        # Load the index exactly once (the directory branch used to load it twice).
        self.checkpoint = self._load(SPLIT_WEIGHTS_NAME, None)

    def _load(self, name, shape, **kwparams):
        """``torch.load`` the file ``name`` from the checkpoint directory; ``shape`` is unused."""
        return torch.load(str(self.chkpt_dir / name), **kwparams)

    def __len__(self):
        return len(self.checkpoint)

    def __getitem__(self, key):
        entry = self.checkpoint[key]  # KeyError for unknown keys, as Mapping expects
        # Inspect the raw entry: stringifying first made the tuple branch unreachable
        # and turned ``(path, shape)`` entries into a garbage file name.
        if isinstance(entry, tuple):
            return self._load(str(entry[0]).split('/')[-1], entry[1], map_location=self.device)
        return self._load(str(entry).split('/')[-1], None, map_location=self.device)

    def __contains__(self, key):
        # Membership only needs the index; the inherited version loaded the tensor.
        return key in self.checkpoint

    def __setitem__(self, key, value):
        # Read-only view: assignments are silently ignored.
        return

    def __delitem__(self, key, value=None):
        # Read-only view: deletions are silently ignored. ``value`` is unused and
        # optional so that ``del ckpt[key]`` works (it used to raise TypeError).
        return

    def keys(self):
        return self.checkpoint.keys()

    def items(self):
        """Lazily yield ``(key, value)`` pairs, loading each value on demand."""
        # The inherited ItemsView relies on __iter__ yielding keys, which it does not.
        return ((key, self[key]) for key in self.checkpoint)

    def values(self):
        """Lazily yield each loaded value in index order."""
        return (self[key] for key in self.checkpoint)

    def __iter__(self):
        for key in self.checkpoint:
            yield (key, self.__getitem__(key))

    def __copy__(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

    def copy(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)

def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True, cuda_blocking=True):
    """Benchmark ``func`` and print the best mean ± std over ``r`` runs of ``n`` calls.

    Timing uses ``time.perf_counter_ns``. The run with the lowest mean is
    reported, in the largest unit (s, ms, μs, ns) in which the mean is at
    least 1. Returns None.

    Args:
        func: Zero-argument callable to time.
        r: Number of runs.
        n: Number of calls per run.
        quiet: If true, print nothing.
        function: Optional object whose ``__name__`` is used in the report
            instead of ``func.__name__``; ``func`` itself is not modified.
        do_tqdm: Wrap the run loop in a tqdm progress bar.
        first: If false, drop the first call of each run before computing
            statistics. This is ignored when ``n == 1`` so a run never ends up
            with zero samples (which produced NaN).
        cuda_blocking: Call ``torch.cuda.synchronize()`` before and after each
            call so queued GPU work is attributed correctly. Skipped when CUDA
            is unavailable.
    """
    # Resolve the display name without mutating the caller's callable
    # (setting __name__ also fails for bound methods).
    name = function.__name__ if function else getattr(func, "__name__", repr(func))
    sync = cuda_blocking and torch.cuda.is_available()

    r_arr = np.empty([2, r])  # [0] = mean, [1] = std, both in nanoseconds
    for i in tqdm(range(r)) if do_tqdm else range(r):
        n_arr = np.empty(n)
        for k in range(n):
            if sync:
                # Drain previously queued GPU work *before* starting the clock,
                # otherwise it is billed to func.
                torch.cuda.synchronize()
            start = time.perf_counter_ns()
            func()
            if sync:
                torch.cuda.synchronize()
            n_arr[k] = time.perf_counter_ns() - start

        if not first and n > 1:
            # Discard the first (warm-up) call of this run.
            n_arr = n_arr[1:]

        r_arr[0, i] = np.mean(n_arr)
        r_arr[1, i] = np.std(n_arr)

    best_mean, best_std = r_arr[:, np.argmin(r_arr[0])]
    # (threshold in ns, scale factor, unit, decimals), checked largest first.
    for threshold, scale, unit, digits in ((1e9, 1e-9, 's', 4), (1e6, 1e-6, 'ms', 2), (1e3, 1e-3, 'μs', 2)):
        if best_mean >= threshold:
            break
    else:
        scale, unit, digits = 1.0, 'ns', 0

    if not quiet:
        mean, std = best_mean * scale, best_std * scale
        print(f"{name}: {mean:.{digits}f}{unit} ± {std:.{digits}f}{unit} per loop (mean ± std. dev. of {r} runs, {n} loops each)")

def gelu_new(x):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``
    elementwise on the tensor ``x``.
    """
    cubic = torch.pow(x, 3.0)
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * cubic)
    return 0.5 * x * (1.0 + torch.tanh(inner))

class FIAReader():
    """Benchmark reader that decodes JPEGs stored back-to-back in one file.

    ``metadata`` is unpickled from ``metadata_path`` and indexed by integer
    image key. Entry ``[0]`` is the image's byte offset in ``dataset_path`` and
    ``[1]`` is its byte length. ``reader[key]`` decodes ``image_cnt`` images
    starting at ``key`` in batches of ``batch_size``, prints the elapsed time,
    discards the decoded tensors and returns None.

    NOTE(review): ``pickle.load`` can execute arbitrary code, so only use
    metadata files from a trusted source.
    """

    def __init__(self, dataset_path: str, metadata_path: str, transform=None, local_transform=None,
                 skip=0, batch_size=8500, image_cnt=100000):
        self.skip = skip  # not used for now
        self.threads = 16  # it seems 16 is the ideal thread count for this machine
        self.image_cnt = image_cnt  # The image count to be read at each run of FIAReader[x]
        self.batch_size = batch_size
        self.transform = transform
        self.local_transform = local_transform
        self.dataset_path = dataset_path
        with open(metadata_path, 'rb') as f:
            self.metadata = pickle.load(f)

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, key):
        """Decode ``image_cnt`` images starting at ``key`` and print the elapsed time.

        This is currently a benchmark only: the decoded tensors are not kept,
        due to memory constraints. Includes a final partial batch when
        ``image_cnt`` is not a multiple of ``batch_size``. Never indexes past
        the end of ``metadata``.
        """
        # ``import concurrent`` alone does not guarantee the futures submodule is loaded.
        from concurrent.futures import ThreadPoolExecutor

        start_time = timer()
        # Assumes metadata keys are the contiguous ints 0..len-1 (e.g. a list) — TODO confirm.
        end_key = min(key + self.image_cnt, len(self.metadata))
        # Open the file, mmap and thread pool once for all batches instead of per batch.
        with open(self.dataset_path, mode="rb") as file_obj:
            with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                with ThreadPoolExecutor(max_workers=self.threads) as executor:
                    for first in tqdm(range(key, end_key, self.batch_size)):
                        # Last image *inside* this batch. The old code used the next batch's
                        # first image for the end pointer and dropped the batch's last image.
                        last = min(first + self.batch_size, end_key) - 1
                        start_ptr = self.metadata[first][0]
                        end_ptr = self.metadata[last][0] + self.metadata[last][1]
                        # One contiguous copy of the batch's bytes; assumes images are stored
                        # in key order so offsets increase — TODO confirm.
                        batch_bytes = mmap_obj[start_ptr:end_ptr]
                        # Results are consumed (so decode errors propagate) and then discarded.
                        list(executor.map(self.read_from_metadata_key,
                                          repeat(batch_bytes), repeat(start_ptr),
                                          range(first, last + 1)))
        end_time = timer()
        print('image reading time: ', end_time - start_time)
        return

    def _postprocess(self, tensors):
        """Untested and not yet called: apply the configured transforms to decoded images.

        This is the code that previously sat unreachable after ``return`` in
        ``__getitem__``. NOTE(review): if ``local_transform`` is unset,
        ``globo1``/``globo2``/``local`` are never assigned and this raises
        UnboundLocalError.
        """
        if self.local_transform:
            globo1_list = []
            globo2_list = []
            local_list = []
            for t in tensors:
                globo1, globo2, local = self.local_transform(t.cuda())
                globo1_list.append(globo1)
                globo2_list.append(globo2)
                local_list.append(local)

            globo1 = torch.stack(globo1_list).cuda()
            globo2 = torch.stack(globo2_list).cuda()
            local = torch.cat(local_list, dim=0).cuda()

        if self.transform:
            globo1, globo2, local = self.transform(globo1, globo2, local)

        return [globo1, globo2, *local.split(self.image_cnt)]

    def read_from_metadata_key(self, dataset_mmap, start_ptr, key):
        """Decode image ``key`` from ``dataset_mmap``, a buffer that starts at file offset ``start_ptr``.

        Returns the decoded image as a tensor with its last axis moved to the
        front (``permute(2, 0, 1)``; presumably HWC -> CHW).
        """
        offset, length = self.metadata[key][0], self.metadata[key][1]
        jpeg_bytes = dataset_mmap[offset - start_ptr: offset + length - start_ptr]
        image = decode_jpeg(jpeg_bytes)
        return torch.from_numpy(image).permute(2, 0, 1)