import concurrent
import concurrent.futures
import mmap
import pickle
from pathlib import Path

import numpy as np
import simplejpeg
import torch
from simplejpeg import decode_jpeg
from torch.utils import data

# TODO: Does this work with other block sizes? It doesn't seem to.
class FbDataset(data.Dataset):
    """Map-style dataset over a flat file of ``uint16`` values.

    The file is memory-mapped read-only and viewed as a 2-D array with
    ``block_size`` columns. Every item is a pair of ``int64`` tensors cut
    from one row: the row without its last element and the row without its
    first element.

    Args:
        block_size: Number of ``uint16`` values per row.
        map_file: Path of the file to memory-map.
        max_samples: Optional upper bound on the reported length.
        skip: Offset added to every requested index before the row lookup.
            It does not shorten ``len(self)``.
    """

    def __init__(self, block_size, map_file, max_samples=None, skip=0):
        flat = np.memmap(map_file, mode="r", dtype="uint16")
        self.npz = flat.reshape((-1, block_size))
        self.skip = skip

        row_count = self.npz.shape[0]
        if max_samples is None:
            self.samples = row_count
        else:
            self.samples = min(row_count, int(max_samples))

    def __len__(self):
        """Return the number of samples reported to samplers/loaders."""
        return self.samples

    def __getitem__(self, _id):
        """Return ``(row[:-1], row[1:])`` for the row at ``_id + skip``."""
        row = self.npz[_id + self.skip]
        # Widen to int64 before building the tensor.
        tokens = torch.tensor(row.astype(np.int64))
        inputs = tokens[:-1]
        targets = tokens[1:]
        return (inputs, targets)

class ShardedDataset(data.Dataset):
    """Rank-sharded view over a flat file of ``uint16`` values.

    The file is memory-mapped read-only and viewed as rows of ``block_size``
    values. Trailing rows that do not divide evenly by ``world_size`` are
    dropped (no padding), then this rank keeps every ``world_size``-th row
    starting at row ``rank``.

    Args:
        block_size: Number of ``uint16`` values per row.
        map_file: Path of the file to memory-map.
        world_size: Number of shards the rows are split into.
        rank: Which shard this instance exposes.
        skip: Offset added to every requested index before the row lookup.
            It does not shorten ``len(self)``.
    """

    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
        rows = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        # Drop the remainder so every rank ends up with the same row count.
        leftover = rows.shape[0] % world_size
        evenly_divisible = rows[:rows.shape[0] - leftover]
        # Strided slice: rank r owns rows r, r + world_size, r + 2 * world_size, ...
        self.npz = evenly_divisible[rank::world_size]
        self.samples = self.npz.shape[0]
        self.skip = skip

    def __len__(self):
        """Return the number of rows owned by this rank."""
        return self.samples

    def __getitem__(self, _id):
        """Return ``(row[:-1], row[1:])`` for this shard's row ``_id + skip``."""
        row = self.npz[_id + self.skip]
        tokens = torch.tensor(row.astype(np.int64))
        inputs = tokens[:-1]
        targets = tokens[1:]
        return (inputs, targets)

class ShardedImageDataset(data.Dataset):
    """Batched, rank-sharded reader for a packed JPEG dataset file.

    The dataset file is one binary blob; the pickled index holds one
    ``[offset, size, id]`` entry per image. The index is truncated to a
    multiple of ``bsz * world_size`` and then strided by ``global_rank``, so
    every rank owns the same whole number of batches.

    Indexing is by *batch*: ``len(self)`` is the number of batches in this
    rank's shard and ``self[k]`` returns batch ``k + skip``, i.e. index
    entries ``(k + skip) * bsz`` up to ``(k + skip + 1) * bsz``.

    Args:
        dataset_path: Path of the packed image file (memory-mapped read-only).
        index_path: Path of the pickled index.
        metadata_path: Optional path of a pickled metadata mapping.
        threads: ``max_workers`` for the decoding thread pool.
        inner_transform: Optional callable applied to each decoded image
            tensor before the images are stacked.
        outer_transform: Optional callable applied to the stacked float batch.
        skip: Number of batches to skip; added to every requested key. It
            does not shorten ``len(self)``.
        bsz: Number of images per batch.
        world_size: Number of shards the index is split into.
        local_rank: Device passed to ``Tensor.to`` when ``device == "cuda"``.
        global_rank: Which shard of the index this instance reads.
        device: ``"cuda"`` moves each batch to ``local_rank``; any other
            value leaves it where it was created.
    """

    def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
        outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):

        self.skip = skip
        self.threads = threads
        self.bsz = bsz
        # Per-image transform: runs before stacking, while images may still
        # differ in size.
        self.inner_transform = inner_transform
        # Batched transform: runs on the stacked tensor.
        self.outer_transform = outer_transform
        self.dataset_path = dataset_path
        self.world_size = world_size
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.device = device

        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at index/metadata files from a trusted source.
        with open(index_path, 'rb') as f:
            index = pickle.load(f)

        # Always define the attribute so get_metadata can report a clear
        # error instead of failing on a missing attribute.
        self.metadata = None
        if metadata_path:
            with open(metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)

        # Binary mode: the file holds raw JPEG bytes. The mapping stays valid
        # after the file object is closed.
        with open(self.dataset_path, mode="rb") as file_obj:
            self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)

        # Make the index evenly divisible by world_size (num_gpus) and bsz,
        # then keep only this rank's entries.
        usable = len(index) - (len(index) % (bsz * world_size))
        index = index[:usable][global_rank::world_size]

        # Keep the index as an ndarray (the original author's note: this is
        # meant to reduce GIL contention between the reader threads).
        index = np.array(index)
        if index.size == 0:
            # np.array([]) is 1-D; give the empty shard the (N, 3) layout too.
            index = index.reshape(0, 3)
        self.index = index
        self.ids = self.index[:, 2]
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)

    def __len__(self):
        """Return the number of whole batches in this rank's shard."""
        return len(self.index) // self.bsz

    def _batch_keys(self, key):
        """Return the index positions that make up batch ``key + skip``.

        Raises:
            IndexError: If the batch is not fully inside this rank's shard.
        """
        # Stride by bsz so consecutive batch keys never share samples.
        start = (self.skip + key) * self.bsz
        stop = start + self.bsz
        if start < 0 or stop > len(self.index):
            raise IndexError(
                "batch {} (skip={}) is out of range for {} indexed samples".format(
                    key, self.skip, len(self.index)))
        return range(start, stop)

    def __getitem__(self, key):
        """Decode one batch and return ``(tensors, ids)``.

        ``tensors`` is the float stack of the decoded images after
        ``permute(0, 3, 1, 2)`` and the optional ``outer_transform``;
        ``ids`` lists the index ids of the images in batch order.
        """
        samples = list(self.executor.map(self.read_from_index_key, self._batch_keys(key)))
        ids = [sample_id for _, sample_id in samples]
        # torch.stack requires every image to have the same shape by now.
        tensors = torch.stack([image for image, _ in samples])
        if self.device == "cuda":
            tensors = tensors.to(self.local_rank)

        # Presumably NHWC -> NCHW (decode_jpeg output assumed to be H x W x C).
        tensors = tensors.permute(0, 3, 1, 2).float()
        if self.outer_transform:
            tensors = self.outer_transform(tensors)

        return tensors, ids

    def read_from_index_key(self, key):
        """Decode the image at index position ``key``; return ``(tensor, id)``."""
        offset, size, sample_id = self.index[key]
        raw = self.mmap[offset:offset+size]
        image = torch.from_numpy(decode_jpeg(raw))
        if self.inner_transform:
            image = self.inner_transform(image)

        return image, sample_id

    def read_from_id(self, id):
        """Decode and return the image whose index id equals ``id``.

        Standalone helper; only this rank's shard is searched and the first
        match wins. Raises ``IndexError`` when no entry matches.
        """
        offset, size, _ = self.index[self.ids == id][0]
        raw = self.mmap[offset:offset+size]
        return decode_jpeg(raw)

    def get_metadata(self, id):
        """Return the metadata stored under ``id``.

        Raises:
            AttributeError: If the dataset was created without
                ``metadata_path``. The type matches what the missing
                attribute used to raise.
        """
        if self.metadata is None:
            raise AttributeError("no metadata loaded; pass metadata_path to the constructor")
        return self.metadata[id]

class ImageDatasetBuilder():
    """Writer for a packed image dataset stored as up to three files.

    Inside ``folder_path`` the builder manages ``<name>.ds`` (all blobs
    appended back to back), ``<name>.index`` (a pickled list of
    ``[offset, size, identity]`` entries, one per blob) and
    ``<name>.metadata`` (a pickled dict keyed by identity).

    Args:
        folder_path: Directory holding the three files.
        name: Base name shared by the three files.
        dataset: Whether ``build``/``open`` should handle the ``.ds`` file.
        index: Whether ``build``/``open`` should handle the index.
        metadata: Whether ``build``/``open`` should handle the metadata.
        threads: ``max_workers`` for the thread pool used by ``operate``.
    """

    def __init__(self, folder_path, name, dataset=True, index=True, metadata=False, threads=None):
        self.folder_path = Path(folder_path)
        self.dataset_name = name + ".ds"
        self.index_name = name + ".index"
        self.metadata_name = name + ".metadata"
        self.dataset_path = self.folder_path / self.dataset_name
        self.index_path = self.folder_path / self.index_name
        self.metadata_path = self.folder_path / self.metadata_name
        self.open_dataset = dataset
        self.open_index = index
        self.open_metadata = metadata
        # In-memory state; None means "not built / not opened".
        self.dataset = None
        self.index = None
        self.metadata = None
        self.threads = threads

    @property
    def is_open(self):
        """True when the dataset handle, index or metadata is loaded."""
        return self.dataset is not None or self.index is not None or self.metadata is not None

    @property
    def is_close(self):
        """True when writing is impossible (no dataset handle or no index)."""
        return self.dataset is None or self.index is None

    @property
    def biggest_id(self):
        """Largest identity in the index, or -1 if it cannot be computed."""
        try:
            return np.max(self.np_index[:, 2])
        except (IndexError, TypeError, ValueError):
            # No index yet, empty index, or entries np.max cannot compare.
            return -1

    @property
    def biggest_item(self):
        """Largest blob size in the index, or -1 if it cannot be computed."""
        try:
            return np.max(self.np_index[:, 1])
        except (IndexError, TypeError, ValueError):
            return -1

    @property
    def total_ids(self):
        """Number of index entries, or -1 when no index is loaded."""
        try:
            return len(self.np_index)
        except (TypeError, ValueError):
            return -1

    @property
    def np_index(self):
        """The index converted to a fresh NumPy array."""
        return np.array(self.index)

    def build(self):
        """Start a dataset: open the ``.ds`` file and create empty containers.

        The file is opened in append mode, so an existing ``.ds`` file is
        never truncated.

        Raises:
            RuntimeError: If anything is already built or open.
        """
        if self.is_open:
            raise RuntimeError("Dataset already built")

        self.folder_path.mkdir(parents=True, exist_ok=True)
        if self.open_dataset:
            # Keep the handle on the instance so write() can use it.
            self.dataset = open(self.dataset_path, mode="ab+")

        if self.open_index:
            self.index = []

        if self.open_metadata:
            self.metadata = {}

    def open(self, overwrite=False):
        """Load an existing dataset from disk for appending.

        Only the components requested in the constructor are opened.

        Args:
            overwrite: When true, close and flush whatever is currently
                loaded first instead of refusing to continue.

        Raises:
            RuntimeError: If something is already open and ``overwrite`` is
                false.
            FileNotFoundError: If a requested file does not exist.
        """
        if not overwrite and self.is_open:
            raise RuntimeError("A dataset is already open! If you wish to continue set overwrite to True.")

        if overwrite:
            self.close(silent=True)
            self.flush_index(silent=True)
            self.flush_metadata(silent=True)
            print("Dataset closed and flushed.")

        # Validate every requested file before touching any state, so a
        # missing file cannot leave the builder half-open.
        requested = (
            (self.open_dataset, self.dataset_path, "Dataset"),
            (self.open_index, self.index_path, "Index"),
            (self.open_metadata, self.metadata_path, "Metadata"),
        )
        for wanted, path, label in requested:
            if wanted and not path.is_file():
                raise FileNotFoundError("{} file not found at {}".format(label, path))

        # NOTE: pickle can execute arbitrary code while loading; only open
        # index/metadata files from a trusted source.
        if self.open_index:
            with open(self.index_path, 'rb') as f:
                self.index = pickle.load(f)

        if self.open_metadata:
            with open(self.metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)

        # Opened last so a failed pickle load does not leak a file handle.
        if self.open_dataset:
            self.dataset = open(self.dataset_path, mode="ab+")

    def operate(self, operation, batch, identities, metadata=None):
        """Apply ``operation`` to each item of ``batch`` in threads, then write.

        Results are written sequentially, paired with ``identities`` in
        order.

        Args:
            operation: Callable applied to each element of ``batch``; its
                result is passed to ``write`` as the blob.
            batch: Iterable of inputs for ``operation``.
            identities: Iterable of identities, aligned with ``batch``.
            metadata: Currently unused; kept for interface compatibility.
        """
        # The context manager shuts the pool down once the batch is done.
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
            results = list(executor.map(operation, batch))

        for data, identity in zip(results, identities):
            self.write(data, identity)

    def encode_op(self, data):
        """Return ``data`` as JPEG bytes, encoding at quality 91 if needed."""
        if not simplejpeg.is_jpeg(data):
            data = simplejpeg.encode_jpeg(data, quality=91)

        return data

    def write(self, data, identity, metadata=None, flush=False):
        """Append one blob to the dataset file and record it in the index.

        Args:
            data: Bytes to append.
            identity: Identity stored in the index entry (and used as the
                metadata key).
            metadata: Optional value stored under ``identity`` when the
                builder tracks metadata.
            flush: Flush the dataset file after writing.

        Raises:
            RuntimeError: If the dataset handle or the index is missing.
        """
        if self.is_close:
            raise RuntimeError("Dataset not built")

        # Record the position BEFORE writing: readers slice
        # [offset:offset + size], so the offset must be the blob's start.
        # The handle is append-only and never read here, so tell() is the
        # current end of the file.
        offset = self.dataset.tell()
        self.dataset.write(data)
        self.index.append([offset, len(data), identity])
        # Compare against None: the freshly built metadata dict is empty and
        # therefore falsy.
        if self.metadata is not None and metadata is not None:
            self.metadata[identity] = metadata

        if flush:
            self.flush()

    def write_metadata(self, id, metadata):
        """Store ``metadata`` under ``id``; requires metadata to be built."""
        self.metadata[id] = metadata

    def flush_index(self, silent=False):
        """Pickle the in-memory index to the index file.

        Does nothing when no index is loaded; ``silent`` only suppresses the
        warning.
        """
        if self.index is None:
            if not silent:
                print("Warning: Index not built, couldn't flush")
            return

        with open(self.index_path, 'wb') as f:
            pickle.dump(self.index, f)

    def flush_metadata(self, silent=False):
        """Pickle the in-memory metadata to the metadata file.

        Does nothing when no metadata is loaded; ``silent`` only suppresses
        the warning.
        """
        if self.metadata is None:
            if not silent:
                print("Warning: Metadata not built, couldn't flush")
            return

        with open(self.metadata_path, 'wb') as f:
            pickle.dump(self.metadata, f)

    def flush(self, silent=False):
        """Flush buffered writes of the dataset file.

        Does nothing when no dataset file is open; ``silent`` only
        suppresses the warning.
        """
        if self.dataset is None:
            if not silent:
                print("Warning: Dataset not built, couldn't flush")
            return

        self.dataset.flush()

    def close(self, silent=False):
        """Flush and close the dataset file handle.

        The index and metadata stay in memory and are NOT written here; call
        ``flush_index`` / ``flush_metadata`` to persist them.
        """
        if self.dataset is None:
            if not silent:
                print("Warning: Dataset not built, couldn't close")
            return

        self.flush()
        self.dataset.close()
        # Drop the closed handle so is_open / is_close stay truthful.
        self.dataset = None