import numpy as np
import torch
import mmap
import concurrent
from torch.utils import data
import pickle
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import as_completed
import requests
import hashlib
import io
import os
from simplejpeg import decode_jpeg
import simplejpeg
from PIL import Image

# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
    """Map-style dataset over a flat uint16 token file, read via ``np.memmap``.

    The file is viewed as rows of ``block_size`` tokens. The reshape requires
    the number of uint16 values in the file to be an exact multiple of
    ``block_size``; otherwise NumPy raises ``ValueError``.

    Args:
        block_size: Number of tokens per row.
        map_file: Path of the raw uint16 token file.
        max_samples: Optional upper bound on the number of rows exposed.
        skip: Number of leading rows to skip; item ``i`` reads row
            ``i + skip``. May also be reassigned after construction.
    """

    def __init__(self, block_size, map_file, max_samples=None, skip=0):
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        self.samples = self.npz.shape[0]
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = skip

    def __len__(self):
        # Only rows that are still reachable after applying ``skip`` count;
        # otherwise the last ``skip`` indices would read past the memmap.
        # Computed on demand because ``skip`` may be changed after __init__.
        reachable = self.npz.shape[0] - self.skip
        return max(0, min(self.samples, reachable))

    def __getitem__(self, _id):
        """Return ``(inputs, targets)``: the row without its last token and
        the same row shifted left by one, both as int64 tensors."""
        nth = _id + self.skip
        tokens = torch.tensor(self.npz[nth].astype(np.int64))
        return (tokens[:-1], tokens[1:])

class ShardedDataset(data.Dataset):
    """Per-rank shard of a flat uint16 token file, read via ``np.memmap``.

    The file is viewed as rows of ``block_size`` tokens. Trailing rows are
    dropped so the row count is divisible by ``world_size``, then every
    ``world_size``-th row starting at ``rank`` is kept.

    Args:
        block_size: Number of tokens per row.
        map_file: Path of the raw uint16 token file.
        world_size: Number of shards the rows are split into.
        rank: Index of the shard this instance exposes.
        skip: Number of leading rows of the shard to skip; item ``i`` reads
            row ``i + skip``. May also be reassigned after construction.
    """

    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
        rows = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        # Drop the remainder so every rank gets the same number of rows
        # (might want to pad later instead).
        rows = rows[:rows.shape[0] - (rows.shape[0] % world_size)]
        # Shard: interleaved rows for this rank.
        self.npz = rows[rank::world_size]
        self.samples = self.npz.shape[0]
        self.skip = skip

    def __len__(self):
        # Rows hidden by ``skip`` are not addressable, so they must not be
        # counted; otherwise the last ``skip`` indices raise IndexError.
        return max(0, self.samples - self.skip)

    def __getitem__(self, _id):
        """Return ``(inputs, targets)``: the row without its last token and
        the same row shifted left by one, both as int64 tensors."""
        nth = _id + self.skip
        tokens = torch.tensor(self.npz[nth].astype(np.int64))
        return (tokens[:-1], tokens[1:])

class ShardedImageDataset(data.Dataset):
    """Batched reader over a packed JPEG dataset file.

    Three files are used, all derived from ``dataset_path`` and ``name``
    unless overridden: ``<name>.ds`` (concatenated JPEG bytes, memory-mapped),
    ``<name>.index`` (pickled list of ``[offset, length, id]`` entries) and
    ``<name>.pointer`` (pickled ``{id: (offset, length)}`` cache, created on
    first use).

    Each ``__getitem__`` call decodes ``bsz`` images on a thread pool and
    returns them as one stacked tensor together with their ids.

    Args:
        dataset_path: Directory containing the dataset files.
        name: Base name of the dataset files.
        index_path: Optional explicit path of the index file.
        shuffle: Whether to shuffle the index when sharding.
        metadata_path: Optional path of a pickled metadata mapping.
        threads: ``max_workers`` for the decoding thread pool.
        inner_transform: Per-image callable applied to the decoded array.
        outer_transform: Callable applied to the stacked batch tensor.
        skip: Offset added to every key passed to ``__getitem__``.
        bsz: Number of images per returned batch.
        world_size: Number of shards the index is split into.
        local_rank: Device index used when ``device == "cuda"``.
        global_rank: Index of the shard this instance reads.
        device: ``"cuda"`` moves each batch to ``local_rank``.

    NOTE: the index, pointer and metadata files are read with ``pickle``,
    which can execute arbitrary code; only open files from trusted sources.
    """

    def __init__(self, dataset_path: str, name:str, index_path:str=None, shuffle=False, metadata_path=None, threads=None, inner_transform=None,
        outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
        self.skip = skip
        self.threads = threads
        self.bsz = bsz
        # For one-by-one transforms, because raw images can't be batched.
        self.inner_transform = inner_transform
        # For batched transforms, after the images have become batchable.
        self.outer_transform = outer_transform
        self.dataset_path = Path(dataset_path)
        if index_path is None:
            self.index_path = self.dataset_path / f"{name}.index"
        else:
            # Previously stored under a misspelled attribute, which made
            # every explicit ``index_path`` fail with AttributeError below.
            self.index_path = Path(index_path)

        self.pointer_path = self.dataset_path / f"{name}.pointer"
        self.dataset_path = self.dataset_path / f"{name}.ds"
        self.world_size = world_size
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.device = device
        with open(self.index_path, 'rb') as f:
            self.index = pickle.load(f)

        # ``self.metadata`` only exists when a metadata file was supplied.
        if metadata_path:
            with open(metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)

        # Binary mode: the file holds raw JPEG bytes. The mapping stays valid
        # after the file object is closed.
        with open(self.dataset_path, mode="rb") as file_obj:
            self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)

        # Precompute (and cache on disk) an id -> (offset, length) lookup for
        # faster random reads.
        if not self.pointer_path.is_file():
            self.pointer_lookup = {
                item_id: (offset, length)
                for offset, length, item_id in tqdm(self.index)
            }
            with open(self.pointer_path, 'wb') as f:
                pickle.dump(self.pointer_lookup, f)
        else:
            with open(self.pointer_path, 'rb') as f:
                self.pointer_lookup = pickle.load(f)

        # Make the index shardable by world_size (num_gpus) and batch size.
        # ``shard`` also converts the index to an ndarray (avoids GIL-bound
        # list access from the worker threads) and fills ``self.ids``.
        self.original_index = self.index
        self.shard(shuffle=shuffle)

        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)

    def __len__(self):
        """Number of full batches in this rank's shard."""
        return len(self.index) // self.bsz

    def shard(self, shuffle=False, epoch=1, seed=69):
        """Rebuild ``self.index`` and ``self.ids`` for this rank.

        Can be called again (e.g. at the start of an epoch) to reshuffle.

        Args:
            shuffle: If true, repeat the full index ``epoch`` times and
                permute it with a generator seeded by ``seed``.
            epoch: Repeat count used when shuffling.
            seed: Seed for the permutation.
        """
        # The global NumPy RNG is borrowed for the permutation; its state is
        # restored afterwards even if sharding fails.
        saved_state = np.random.get_state()
        try:
            np.random.seed(seed)
            index = self.original_index
            if shuffle:
                # Repeat the index n_epoch times, then shuffle it.
                index = np.repeat(index, epoch, axis=0)
                index = np.random.permutation(index)

            # Drop the tail so every rank gets the same number of full batches.
            usable = len(index) - (len(index) % (self.bsz * self.world_size))
            index = index[:usable]
            index = index[self.global_rank::self.world_size]
        finally:
            np.random.set_state(saved_state)

        # Always leave an (n, 3) ndarray behind, including for an empty
        # shard, and keep ``ids`` in sync with the index it was derived from.
        self.index = np.asarray(index).reshape(-1, 3)
        self.ids = self.index[:, 2]

    def __getitem__(self, key):
        """Decode ``bsz`` consecutive index entries starting at ``skip + key``.

        Returns:
            ``(tensors, ids)`` where ``tensors`` is the stacked float batch
            scaled by ``x / 127.5 - 1`` and ``ids`` lists the entry ids.
        """
        # NOTE(review): ``key`` is used as a sample offset, not multiplied by
        # ``bsz``, so consecutive keys yield overlapping windows while
        # ``__len__`` counts whole batches -- confirm against the callers.
        start = self.skip + key
        keys = range(start, start + self.bsz)
        results = list(self.executor.map(self.read_from_index_key, keys))
        # Make sure these operations are fast!
        ids = [item_id for _, item_id in results]
        tensors = torch.stack([image for image, _ in results])
        if self.device == "cuda":
            tensors = tensors.to(self.local_rank)

        tensors = tensors.float()
        # Maps 8-bit pixel values 0..255 onto -1..1.
        tensors = tensors / 127.5 - 1
        if self.outer_transform:
            tensors = self.outer_transform(tensors)

        return tensors, ids

    def read_from_index_key(self, key):
        """Decode the image at position ``key`` of the sharded index.

        Returns:
            ``(tensor, id)`` with the (optionally transformed) decoded image.
        """
        offset, size, item_id = self.index[key]
        raw = self.mmap[offset:offset+size]
        image = decode_jpeg(raw)
        if self.inner_transform:
            image = self.inner_transform(image)

        return torch.from_numpy(image), item_id

    def read_from_id(self, id, decode=True):
        """Return the image stored under ``id``: the decoded array, or the
        raw JPEG bytes when ``decode`` is false."""
        offset, size = self.pointer_lookup[id]
        raw = self.mmap[offset:offset+size]
        if decode:
            return decode_jpeg(raw)
        return raw

    def get_metadata(self, id):
        """Return the metadata entry for ``id``.

        Raises ``AttributeError`` when no ``metadata_path`` was given at
        construction, because ``self.metadata`` is then never set.
        """
        return self.metadata[id]

class CPUTransforms():
    """Image transforms on NumPy arrays laid out as ``(h, w)`` or ``(h, w, c)``.

    All transforms are static methods; the instance only records ``threads``.
    """

    def __init__(self, threads=None):
        # Previously the argument was dropped and None was always stored.
        self.threads = threads

    @staticmethod
    def scale(data, res, pil=True):
        """Resize an image.

        Args:
            data: Image array with shape ``(h, w[, c])``.
            res: Either an int, in which case the shorter side becomes
                ``res`` and the aspect ratio is preserved, or a ``(h, w)``
                pair giving the exact output size.
            pil: Use Pillow (Lanczos) when true, otherwise OpenCV
                (``INTER_AREA``).

        Returns:
            The resized image as a NumPy array.
        """
        h, w = data.shape[:2]

        # Both Pillow and OpenCV expect the target size as (width, height).
        if isinstance(res, (int, np.integer)):
            if h > w:
                # Scale factor that makes the width match the target.
                factor = res / w
                size = (res, int(h*factor))
            elif h == w:
                size = (res, res)
            else:
                # Scale factor that makes the height match the target.
                factor = res / h
                size = (int(w*factor), res)
        else:
            # Explicit (h, w) target; this path used to leave the size
            # undefined and fail with UnboundLocalError.
            target_h, target_w = res
            size = (int(target_w), int(target_h))

        if pil:
            resized = Image.fromarray(data).resize(size, Image.LANCZOS)
            return np.asarray(resized)

        # cv2 is only needed on this path and is not imported at module
        # level, so import it here instead of failing with NameError.
        import cv2
        return cv2.resize(data, size, interpolation=cv2.INTER_AREA)

    @staticmethod
    def centercrop(data, res: int):
        """Crop a ``res`` x ``res`` window from the centre of the image.

        A dimension smaller than ``res`` is kept whole rather than producing
        a negative offset (which would slice from the wrong end).
        """
        h_offset = max((data.shape[0] - res) // 2, 0)
        w_offset = max((data.shape[1] - res) // 2, 0)
        return data[h_offset:h_offset+res, w_offset:w_offset+res]

    @staticmethod
    def cast_to_rgb(data, pil=False):
        """Return a 3-channel version of the image.

        2-D and single-channel inputs are replicated across three channels,
        3-channel inputs are returned unchanged, and 4-channel inputs are
        alpha-blended onto a white background. Any other channel count is
        returned unchanged. ``pil`` is accepted but not used.
        """
        if len(data.shape) < 3:
            data = np.expand_dims(data, axis=2)
            return np.repeat(data, 3, axis=2)

        channels = data.shape[2]
        if channels == 1:
            return np.repeat(data, 3, axis=2)
        if channels == 4:
            # Alpha blending: composite over white, then drop the alpha
            # channel. Without the final convert the result kept 4 channels.
            # NOTE(review): the original author observed Image.fromarray
            # failing here, possibly for uint16 input -- unconfirmed.
            png = Image.fromarray(data)
            background = Image.new('RGBA', png.size, (255,255,255))
            blended = Image.alpha_composite(background, png)
            return np.asarray(blended.convert('RGB'))
        return data

    @staticmethod
    def randomcrop(data, res):
        """Crop a ``res`` x ``res`` window at a uniformly random position.

        A dimension that is not larger than ``res`` is cropped from offset 0.
        Uses the global ``np.random`` generator.
        """
        h, w = data.shape[:2]
        # randint's upper bound is exclusive, so +1 is needed for the last
        # valid offset (h - res / w - res) to be selectable.
        h_offset = np.random.randint(0, h - res + 1) if h > res else 0
        w_offset = np.random.randint(0, w - res + 1) if w > res else 0
        return data[h_offset:h_offset+res, w_offset:w_offset+res]

class ImageDatasetBuilder():
    """Writer for a packed image dataset.

    Image payloads are appended to ``<name>.ds``; an in-memory index of
    ``[offset, length, identity]`` entries and an optional metadata dict are
    pickled to ``<name>.index`` / ``<name>.metadata`` on request.

    Args:
        folder_path: Directory holding the dataset files.
        name: Base name of the dataset files.
        dataset: Whether to open/create the ``.ds`` data file.
        index: Whether to maintain the index.
        metadata: Whether to maintain the metadata dict.
        threads: ``max_workers`` for the executor used by ``operate``.
        block_size: Alignment unit, in bytes, for written payloads.
        align_fs_blocks: Zero-pad after each payload up to the next
            multiple of ``block_size``.

    NOTE: index and metadata files are read with ``pickle``, which can
    execute arbitrary code; only open files from trusted sources.
    """

    def __init__(self, folder_path, name, dataset=True, index=True, metadata=False, threads=None, block_size=4096, align_fs_blocks=True):
        self.folder_path = Path(folder_path)
        self.dataset_name = name + ".ds"
        self.index_name = name + ".index"
        self.metadata_name = name + ".metadata"
        self.index_name_temp = name + ".temp.index"
        self.metadata_name_temp = name + ".temp.metadata"
        self.dataset_path = self.folder_path / self.dataset_name
        self.index_path = self.folder_path / self.index_name
        self.metadata_path = self.folder_path / self.metadata_name
        self.index_path_temp = self.folder_path / self.index_name_temp
        self.metadata_path_temp = self.folder_path / self.metadata_name_temp
        self.open_dataset = dataset
        self.open_index = index
        self.open_metadata = metadata
        self.dataset = None
        self.index = None
        self.metadata = None
        self.threads = threads
        self.block_size = block_size
        self.align_fs_blocks = align_fs_blocks

    @property
    def is_open(self):
        """True when any of the data file, index or metadata is loaded."""
        return self.dataset is not None or self.index is not None or self.metadata is not None

    @property
    def is_close(self):
        """True when the data file or the index is missing, i.e. writing is
        not possible."""
        return self.dataset is None or self.index is None

    @property
    def biggest_id(self):
        """Largest identity in the index, or -1 if it cannot be computed."""
        try:
            return np.max(self.np_index[:, 2])
        except Exception:
            return -1

    @property
    def biggest_item(self):
        """Largest payload length in the index, or -1 if it cannot be
        computed."""
        try:
            return np.max(self.np_index[:, 1])
        except Exception:
            return -1

    @property
    def total_ids(self):
        """Number of index entries, or -1 if the index has no length."""
        try:
            return len(self.np_index)
        except Exception:
            return -1

    @property
    def np_index(self):
        """The index converted to a NumPy array (rebuilt on every access)."""
        return np.array(self.index)

    def build(self):
        """Start a dataset: create the folder, open the data file for
        appending and initialise an empty index / metadata as configured.

        Existing data files are appended to, not truncated.

        Raises:
            Exception: If this builder already has something open.
        """
        if self.is_open:
            raise Exception("Dataset already built")

        self.folder_path.mkdir(parents=True, exist_ok=True)
        if self.open_dataset:
            self.dataset = open(self.dataset_path, mode="ab+")
            self.dataset.flush()

        if self.open_index:
            self.index = []

        if self.open_metadata:
            self.metadata = {}

    def open(self, overwrite=False):
        """Load an existing dataset from disk for further appending.

        Only the components enabled at construction are loaded.

        Args:
            overwrite: If true, first close and flush whatever is currently
                open instead of refusing.

        Raises:
            Exception: If something is already open and ``overwrite`` is
                false, or if an enabled component's file does not exist.
        """
        if not overwrite and self.is_open:
            raise Exception("A dataset is already open! If you wish to continue set overwrite to True.")

        if overwrite:
            self.close(silent=True)
            self.flush_index(silent=True)
            self.flush_metadata(silent=True)
            print("Dataset closed and flushed.")

        # Validate every enabled component before opening anything, so a
        # missing file cannot leave a half-open builder behind. A disabled
        # component is simply skipped (it used to raise "not found").
        if self.open_dataset and not self.dataset_path.is_file():
            raise Exception("Dataset file not found at {}".format(self.dataset_path))
        if self.open_index and not self.index_path.is_file():
            raise Exception("Index file not found at {}".format(self.index_path))
        if self.open_metadata and not self.metadata_path.is_file():
            raise Exception("Metadata file not found at {}".format(self.metadata_path))

        if self.open_dataset:
            self.dataset = open(self.dataset_path, mode="ab+")

        if self.open_index:
            with open(self.index_path, 'rb') as f:
                self.index = pickle.load(f)

        if self.open_metadata:
            with open(self.metadata_path, 'rb') as f:
                self.metadata = pickle.load(f)

    def operate(self, operation, batch, identities, metadata=None, executor=concurrent.futures.ThreadPoolExecutor, use_tqdm=False, **kwargs):
        """Apply ``operation`` to every item of ``batch`` in parallel and
        write each result under the matching entry of ``identities``.

        Args:
            operation: Callable taking one batch item and returning the
                payload to write (or None to skip it).
            batch: Items to process.
            identities: Identities paired positionally with ``batch``.
            metadata: Accepted but currently not used.
            executor: Executor class, instantiated with
                ``max_workers=self.threads``.
            use_tqdm: Show a progress bar while collecting results.
            **kwargs: Accepted but currently not used.
        """
        # The context manager shuts the pool down once all results are in,
        # instead of leaking one executor per call.
        with executor(max_workers=self.threads) as pool:
            results = pool.map(operation, batch)
            if use_tqdm:
                results = tqdm(results, total=len(batch), leave=False)
            results = list(results)

        for payload, identity in zip(results, identities):
            self.write(payload, identity)

    def encode_op(self, data):
        """Return JPEG bytes for ``data`` (raw image file bytes).

        Valid JPEGs are passed through unchanged; JPEGs that fail to decode
        yield None. Any other format is decoded with Pillow and re-encoded
        as JPEG at quality 91.
        """
        if simplejpeg.is_jpeg(data):
            try:
                # Decode only to validate the stream. simplejpeg exposes
                # ``decode_jpeg``; the former ``decode`` call did not exist,
                # so every JPEG was rejected.
                simplejpeg.decode_jpeg(data)
            except Exception:
                return None
            return data

        image = Image.open(io.BytesIO(data))
        # The encoder is fed a 3-channel array, so normalise palette,
        # grayscale and alpha images first.
        if image.mode != "RGB":
            image = image.convert("RGB")
        pixels = np.ascontiguousarray(image)
        return simplejpeg.encode_jpeg(pixels, quality=91)

    def url_op(self, url, md5):
        """Download ``url``, verify its MD5 and return it as JPEG bytes.

        Args:
            url: Address to fetch.
            md5: Expected MD5 of the body; assumed to be a hex digest
                string -- TODO confirm against the callers.

        Returns:
            JPEG bytes, or None when all 5 attempts fail, the checksum does
            not match, or the image is an undecodable JPEG.
        """
        result = None
        for _ in range(5):
            try:
                result = requests.get(url)
            except requests.RequestException:
                result = None
                continue
            if result.status_code == 200:
                break

        if result is None or result.status_code != 200:
            return None

        payload = result.content
        # Compare hex digests; comparing the hashlib object itself to the
        # expected value could never succeed.
        if hashlib.md5(payload).hexdigest() != str(md5).strip().lower():
            return None
        return self.encode_op(payload)

    def write(self, data, identity, metadata=None, flush=False):
        """Append one payload to the data file and record it in the index.

        Args:
            data: Bytes to append; None is silently skipped.
            identity: Identity stored with the index entry.
            metadata: Optional metadata stored under ``identity`` when the
                builder maintains metadata.
            flush: Flush the data file afterwards.

        Raises:
            Exception: If the data file or the index is not available.
        """
        if self.is_close:
            raise Exception("Dataset not built")

        if data is None:
            return

        data_ptr = self.dataset.tell()
        data_len = len(data)
        # Write first so a failed write does not leave a dangling index entry.
        self.dataset.write(data)
        self.index.append([data_ptr, data_len, identity])

        # Block align: zero-pad to the next multiple of block_size.
        if self.align_fs_blocks:
            remainder = (data_ptr + data_len) % self.block_size
            if remainder != 0:
                self.dataset.write(bytearray(self.block_size - remainder))

        # Explicit None checks: a freshly built metadata dict is empty and
        # therefore falsy, which used to block the very first entry.
        if self.metadata is not None and metadata is not None:
            self.metadata[identity] = metadata

        if flush:
            self.flush()

    def write_metadata(self, id, metadata):
        """Store ``metadata`` under ``id`` in the in-memory metadata dict."""
        self.metadata[id] = metadata

    def flush_index(self, silent=False):
        """Pickle the index to disk, replacing any previous index file.

        Does nothing when the index is missing or empty; a warning is
        printed unless ``silent`` is true.
        """
        if not self.index:
            if not silent:
                print("Warning: Index not built, couldn't flush")
            return

        # Write to a temporary file, then swap it into place so a crash
        # mid-write cannot corrupt the existing index.
        with open(self.index_path_temp, 'wb') as f:
            pickle.dump(self.index, f)
        os.replace(self.index_path_temp, self.index_path)

    def flush_metadata(self, silent=False):
        """Pickle the metadata to disk, replacing any previous file.

        Does nothing when the metadata is missing or empty; a warning is
        printed unless ``silent`` is true.
        """
        if not self.metadata:
            if not silent:
                print("Warning: Metadata not built, couldn't flush")
            return

        with open(self.metadata_path_temp, 'wb') as f:
            pickle.dump(self.metadata, f)
        os.replace(self.metadata_path_temp, self.metadata_path)

    def flush(self, silent=False):
        """Flush buffered writes of the data file.

        Does nothing when no data file is open; a warning is printed unless
        ``silent`` is true.
        """
        if not self.dataset:
            if not silent:
                print("Warning: Dataset not built, couldn't flush")
            return

        self.dataset.flush()

    def close(self, silent=False):
        """Flush and close the data file handle.

        The index and metadata are not written here; use ``flush_index`` /
        ``flush_metadata`` for that. Does nothing when no data file is open;
        a warning is printed unless ``silent`` is true.
        """
        if not self.dataset:
            if not silent:
                print("Warning: Dataset not built, couldn't flush")
            return

        self.dataset.flush()
        self.dataset.close()
        # Drop the handle so later calls see the file as closed instead of
        # operating on a closed file object.
        self.dataset = None