import numpy as np
import torch
import mmap
import pickle
import concurrent
from torch.utils import data
from simplejpeg import decode_jpeg
import simplejpeg
import pickle
from pathlib import Path

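# TODO: Does this work with other block_sizes? It doesn't seem to. Note that
# reshape((-1, block_size)) requires the number of tokens in map_file to be a
# multiple of block_size.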
class FbDataset(data.Dataset):
    """Next-token dataset over a flat, memory-mapped file of token ids.

    The file is mapped read-only as a 1-D array of ``dtype`` values and viewed
    as rows of ``block_size`` tokens. Item ``i`` is row ``i + skip`` split into
    an ``(input, target)`` pair shifted by one token.

    Args:
        block_size: Tokens per row. The file's element count must be a
            multiple of this, otherwise ``reshape`` raises.
        map_file: Path to the token file.
        max_samples: Optional cap on ``len(self)``.
        skip: Row offset added to every requested index.
            NOTE(review): ``__len__`` does not subtract ``skip``, so with
            ``skip > 0`` the highest indices can run past the end of the file.
        dtype: On-disk token dtype. Defaults to ``"uint16"``, the value that was
            previously hard-coded.
    """

    def __init__(self, block_size, map_file, max_samples=None, skip=0, dtype="uint16"):
        self.npz = np.memmap(map_file, mode="r", dtype=dtype).reshape((-1, block_size))
        self.samples = self.npz.shape[0]
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = skip

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        """Return ``(tokens[:-1], tokens[1:])`` as int64 tensors for row ``_id + skip``."""
        nth = _id + self.skip
        # np.array always copies (out of the read-only memmap) and widens to
        # int64 in one step; torch.from_numpy then wraps that copy without a
        # second allocation (torch.tensor would copy again).
        tokens = torch.from_numpy(np.array(self.npz[nth], dtype=np.int64))
        return (tokens[:-1], tokens[1:])

class ShardedDataset(data.Dataset):
    """Next-token dataset over a memory-mapped token file, sharded across ranks.

    The file is viewed as rows of ``block_size`` tokens. Rows are truncated to
    a multiple of ``world_size`` so every rank gets the same count. Rank ``r``
    then takes rows ``r, r + world_size, r + 2 * world_size, ...``.

    Args:
        block_size: Tokens per row. The file's element count must be a
            multiple of this, otherwise ``reshape`` raises.
        map_file: Path to the token file.
        world_size: Number of shards.
        rank: Which shard this instance serves.
        skip: Offset added to every requested index, applied within this
            rank's shard. NOTE(review): ``__len__`` does not subtract it.
        dtype: On-disk token dtype. Defaults to ``"uint16"``, the value that was
            previously hard-coded.
    """

    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0, dtype="uint16"):
        self.npz = np.memmap(map_file, mode="r", dtype=dtype).reshape((-1, block_size))
        # Drop trailing rows so all ranks get the same number of samples
        # (padding might be preferable later).
        self.npz = self.npz[:self.npz.shape[0] - (self.npz.shape[0] % world_size)]
        # Round-robin shard by rank.
        self.npz = self.npz[rank::world_size]
        self.samples = self.npz.shape[0]
        self.skip = skip

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        """Return ``(tokens[:-1], tokens[1:])`` as int64 tensors for shard row ``_id + skip``."""
        nth = _id + self.skip
        # One copy out of the read-only memmap, widened to int64; from_numpy
        # wraps it without a second allocation.
        tokens = torch.from_numpy(np.array(self.npz[nth], dtype=np.int64))
        return (tokens[:-1], tokens[1:])

class ShardedImageDataset(data.Dataset):
    """Batch-level image dataset that reads JPEG records from a memory-mapped blob.

    ``metadata_path`` holds a pickled sequence of ``(offset, size, id)``
    records. Each record's bytes are ``dataset_path[offset:offset + size]`` and
    are decoded with ``decode_jpeg``. Records are truncated to a multiple of
    ``bsz * world_size`` and sharded round-robin by ``rank``.

    Every ``__getitem__`` call returns one whole batch of ``bsz`` images, so
    ``len(self)`` is the number of batches in this rank's shard.

    Args:
        dataset_path: Path to the concatenated image blob.
        metadata_path: Path to the pickled record list. pickle can execute
            arbitrary code, so only load index files you trust.
        threads: Max worker threads used to decode one batch (None lets
            ThreadPoolExecutor choose).
        inner_transform: Optional per-image transform, applied before stacking.
        outer_transform: Optional transform applied to the stacked batch.
        skip: Number of batches to skip; added to every requested batch index.
            NOTE(review): ``__len__`` does not subtract it.
        bsz: Images per batch.
        world_size: Number of shards.
        rank: Which shard this instance serves.
    """

    def __init__(self, dataset_path: str, metadata_path: str, threads=None, inner_transform=None,
        outer_transform=None, skip=0, bsz=256, world_size=1, rank=0):
        # `import concurrent` alone does not guarantee the `futures` submodule is loaded.
        import concurrent.futures

        self.skip = skip
        self.threads = threads
        self.bsz = bsz
        # For one-by-one transforms, because images can't be batched yet.
        self.inner_transform = inner_transform
        # For batched transforms after images become batchable.
        self.outer_transform = outer_transform
        self.dataset_path = dataset_path
        self.world_size = world_size
        self.rank = rank
        with open(metadata_path, 'rb') as f:
            self.metadata = pickle.load(f)

        # Binary mode: the file is only used to obtain a read-only mapping.
        with open(self.dataset_path, mode="rb") as file_obj:
            self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)

        # Make the metadata shardable by world_size (num GPUs) and batch size
        # so every rank gets the same number of full batches.
        usable = len(self.metadata) - (len(self.metadata) % (bsz * world_size))
        self.metadata = self.metadata[:usable]

        # Shard the dataset according to the rank.
        self.metadata = self.metadata[rank::world_size]
        # Stored as an ndarray (original intent: avoid GIL contention on a Python list).
        # NOTE(review): assumes ids are numeric; string ids would turn the whole
        # array (offsets included) into strings - confirm with the index producer.
        self.metadata = np.array(self.metadata)
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)

    def __len__(self):
        return len(self.metadata) // self.bsz

    def _batch_keys(self, key):
        """Return the metadata row indices that make up batch ``key`` (after ``skip``)."""
        # key indexes batches (see __len__), so scale by bsz; otherwise
        # consecutive batches would overlap by bsz - 1 images.
        start = (self.skip + key) * self.bsz
        return [*range(start, start + self.bsz)]

    def __getitem__(self, key):
        """Decode batch ``key`` and return ``(float_images, ids)``."""
        keys = self._batch_keys(key)
        results = list(self.executor.map(self.read_from_metadata_key, keys))
        # Make sure these operations are fast!
        ids = [r[1] for r in results]
        tensors = torch.stack([r[0] for r in results])
        # permute(0, 3, 1, 2) assumes each image is (H, W, C); the result is (N, C, H, W).
        tensors = tensors.permute(0, 3, 1, 2).float()
        if self.outer_transform:
            tensors = self.outer_transform(tensors)

        return tensors, ids

    def read_from_metadata_key(self, key):
        """Decode the record at metadata row ``key``; return ``(image_tensor, id)``."""
        offset, size, sample_id = self.metadata[key]
        jpeg_bytes = self.mmap[offset:offset + size]
        image = torch.from_numpy(decode_jpeg(jpeg_bytes))
        if self.inner_transform:
            image = self.inner_transform(image)

        return image, sample_id

class ImageDatasetBuilder:
    """Incrementally build an image dataset as one binary blob plus a pickled index.

    Records are appended to ``<folder_path>/<name>.ds``. For every record the
    in-memory index stores ``[offset, size, identity]``, where ``offset`` is the
    byte position at which the record *starts* in the ``.ds`` file. The index
    is pickled to ``<folder_path>/<name>.index`` by ``flush_index``/``close``.
    This matches how ``ShardedImageDataset`` reads records
    (``blob[offset:offset + size]``).
    """

    def __init__(self, folder_path, name, threads=None):
        self.folder_path = Path(folder_path)
        self.dataset_name = name + ".ds"
        self.index_name = name + ".index"
        self.dataset_path = self.folder_path / self.dataset_name
        self.index_path = self.folder_path / self.index_name
        # Binary handle on the .ds blob; None while closed.
        self.dataset = None
        # List of [offset, size, identity] records; None while closed.
        self.index = None
        # Worker count for operate(); None lets ThreadPoolExecutor choose.
        self.threads = threads

    @property
    def is_open(self):
        """True if either the blob handle or the index is currently held."""
        return self.dataset is not None or self.index is not None

    @property
    def is_close(self):
        """True if the blob handle or the index is missing (nothing can be written)."""
        return self.dataset is None or self.index is None

    def build(self):
        """Create the folder and start a new dataset with an empty index.

        Raises:
            Exception: if a dataset is already open on this builder.
        """
        # Be careful not to nuke the files if they exist.
        if self.is_open:
            raise Exception("Dataset already built")

        self.folder_path.mkdir(parents=True, exist_ok=True)
        # "ab+" never truncates an existing blob; new records are appended.
        # NOTE(review): the index starts empty, so close() will overwrite any
        # existing .index file with only the records written in this session.
        self.dataset = open(self.dataset_path, mode="ab+")
        self.index = []

    def open(self, overwrite=False):
        """Reopen an existing dataset (blob + index) for appending.

        Args:
            overwrite: If True and a dataset is already open, close (flush) it first.

        Raises:
            Exception: if a dataset is already open and ``overwrite`` is False,
                or if either file is missing.
        """
        if overwrite is False and self.is_open:
            raise Exception("A dataset is already open! If you wish to continue set overwrite to True.")

        if overwrite is True and self.is_open:
            self.close()
            print("Dataset closed and flushed.")

        if not self.dataset_path.is_file() or not self.index_path.is_file():
            raise Exception("Dataset files not found")

        # Load the index from the dataset folder (not the CWD). Load it before
        # opening the blob so a failed load leaves no handle open.
        # pickle can execute arbitrary code: only open index files you trust.
        with open(self.index_path, 'rb') as f:
            self.index = pickle.load(f)
        self.dataset = open(self.dataset_path, mode="ab+")

    def operate(self, operation, data_batch, identities):
        """Apply ``operation`` to every item in a thread pool, then write results in input order."""
        import concurrent.futures

        # The context manager shuts the pool down, so no threads leak per call.
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
            results = list(executor.map(operation, data_batch))

        for item, identity in zip(results, identities):
            self.write(item, identity)

    def encode_op(self, data):
        """Return ``data`` unchanged if it is already JPEG, otherwise JPEG-encode it (quality 91)."""
        if simplejpeg.is_jpeg(data):
            return data
        return simplejpeg.encode_jpeg(data, quality=91)

    def write(self, data, identity, flush=False):
        """Append one record to the blob and add ``[offset, len(data), identity]`` to the index.

        Raises:
            Exception: if no dataset is open.
        """
        if self.is_close:
            raise Exception("Dataset not built")

        # Capture the start offset *before* writing; tell() after the write
        # would point at the end of this record.
        offset = self.dataset.tell()
        self.dataset.write(data)
        self.index.append([offset, len(data), identity])
        if flush:
            self.dataset.flush()

    def flush_index(self):
        """Pickle the current index to ``index_path``.

        Raises:
            Exception: if no dataset is open.
        """
        if self.is_close:
            raise Exception("Dataset not built")

        with open(self.index_path, 'wb') as f:
            pickle.dump(self.index, f)

    def flush(self):
        """Flush buffered blob writes to the file.

        Raises:
            Exception: if no dataset is open.
        """
        if self.is_close:
            raise Exception("Dataset not built")

        self.dataset.flush()

    def close(self):
        """Flush and close the blob, write the index, and reset to the closed state.

        Raises:
            Exception: if no dataset is open.
        """
        if self.is_close:
            raise Exception("Dataset not built")
        self.flush()
        self.dataset.close()
        self.flush_index()
        # Return to the initial state so is_open/is_close are accurate and the
        # builder can be opened or built again.
        self.dataset = None
        self.index = None