from re import S
import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import OmegaConf
from dotmap import DotMap
import numpy as np
import base64
from torch import autocast
from einops import rearrange, repeat
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.attention import CrossAttention, HyperLogic
import time
from PIL import Image
import k_diffusion as K

def pil_upscale(image, scale=1):
    """Upscale a single image tensor with PIL's Lanczos filter.

    Args:
        image: Tensor indexed as (channels, height, width). Values are
            multiplied by 255 and cast to uint8, so they are expected to lie
            in [0, 1].
        scale: Resize factor; the image is only resized when ``scale > 1``.

    Returns:
        Tensor of shape (1, channels, new_height, new_width) rescaled to
        [-1, 1] and placed on the device of ``image``. When ``image`` has a
        floating-point dtype the result is cast back to that dtype.
    """
    device = image.device
    dtype = image.dtype

    # Round-trip through an 8-bit PIL image so PIL can do the resampling.
    pixels = (image.cpu().permute(1, 2, 0).numpy().astype(np.float32) * 255.).astype(np.uint8)
    pil_image = Image.fromarray(pixels)
    if scale > 1:
        new_size = (int(pil_image.width * scale), int(pil_image.height * scale))
        pil_image = pil_image.resize(new_size, resample=Image.LANCZOS)

    upscaled = np.array(pil_image).astype(np.float32) / 255.0
    # HWC -> 1CHW; the leading axis is the batch dimension.
    upscaled = torch.from_numpy(upscaled[None].transpose(0, 3, 1, 2))
    upscaled = 2. * upscaled - 1.

    # Hand back the caller's floating dtype (e.g. float16) instead of silently
    # promoting to float32; non-float inputs keep the float32 result because
    # values in [-1, 1] cannot be represented in an integer dtype.
    if dtype.is_floating_point:
        return upscaled.to(device=device, dtype=dtype)
    return upscaled.to(device)

def fix_batch(tensor, bs):
    """Replicate ``tensor`` ``bs`` times along a new leading axis.

    A leading dimension of size 1 is dropped first (``squeeze(0)``), so an
    input of shape (1, ...) comes back with shape (bs, ...).
    """
    single = tensor.squeeze(0)
    copies = [single for _ in range(bs)]
    return torch.stack(copies, dim=0)

# mix conditioning vectors for prompts
# @aero
def prompt_mixing(model, prompt_body, batch_size):
    """Build a conditioning tensor for a prompt, blending ``|``-separated parts.

    Without a ``|`` the whole prompt is encoded as-is. Otherwise every part is
    encoded separately via ``model.get_learned_conditioning`` and the results
    are combined as a weighted average. A part may carry a weight after a
    colon (``"a cat:2|a dog"``); parts without one weigh 1. If the text after
    the first colon is not a number, a message is printed and the whole part
    (colon included) is encoded with weight 1.

    Args:
        model: Object exposing ``get_learned_conditioning(list_of_str)``.
        prompt_body: Prompt text.
        batch_size: Number of copies of the conditioning to stack.

    Returns:
        The conditioning replicated ``batch_size`` times by ``fix_batch``.

    Raises:
        ValueError: If the weights sum to zero, since the average would
            otherwise be filled with inf/NaN.
    """
    if "|" not in prompt_body:
        return fix_batch(model.get_learned_conditioning([prompt_body]), batch_size)

    prompt_total_power = 0
    prompt_sum = None
    for prompt_part in prompt_body.split("|"):
        prompt_power = 1
        if ":" in prompt_part:
            prompt_sub_parts = prompt_part.split(":")
            try:
                prompt_power = float(prompt_sub_parts[1])
            except ValueError:
                # Only a malformed number is expected here; anything else
                # (e.g. KeyboardInterrupt) must not be swallowed.
                print("Error parsing prompt power! Assuming 1")
            else:
                # Strip the weight suffix only when it parsed successfully.
                prompt_part = prompt_sub_parts[0]

        weighted = model.get_learned_conditioning([prompt_part]) * prompt_power
        prompt_sum = weighted if prompt_sum is None else prompt_sum + weighted
        prompt_total_power = prompt_total_power + prompt_power

    if prompt_total_power == 0:
        raise ValueError("Prompt weights sum to zero; cannot normalize the mixed conditioning")

    return fix_batch(prompt_sum / prompt_total_power, batch_size)

def sample_start_noise(seed, C, H, W, f, device="cuda"):
    """Draw one standard-normal latent noise tensor.

    Args:
        seed: When not ``None``, the global torch and numpy RNGs are re-seeded
            with this value before drawing (this affects later draws too).
            ``0`` is a valid seed.
        C: Number of latent channels.
        H: Image height; the latent height is ``H // f``.
        W: Image width; the latent width is ``W // f``.
        f: Downsampling factor between image and latent resolution.
        device: Device on which the noise is created.

    Returns:
        Tensor of shape (1, C, H // f, W // f).
    """
    # ``is not None`` rather than truthiness: seed 0 must seed the RNG too.
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    noise = torch.randn([C, H // f, W // f], device=device).unsqueeze(0)
    return noise

@torch.no_grad()
def encode_image(image, model):
    """Encode an image with ``model.encode`` and sample from the result.

    Args:
        image: A PIL image, a numpy array or a torch tensor indexed as
            (height, width, channels). Values are divided by 255, so the
            0..255 range is expected.
        model: Object exposing ``encode(tensor)`` whose return value has a
            ``sample()`` method. The input is moved to the device of the
            model's parameters; if that cannot be determined, CUDA is used.

    Returns:
        Whatever ``model.encode(...).sample()`` returns for the
        (1, channels, height, width) float input scaled to [-1, 1].
    """
    if isinstance(image, Image.Image):
        # np.array copies into a writable buffer, which from_numpy requires.
        image = np.array(image)

    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)

    # Follow the model instead of hard-coding CUDA so CPU and non-default-GPU
    # models work; fall back to CUDA for objects without parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    # HWC -> 1CHW float tensor on the model's device.
    image = image.permute(2, 0, 1).unsqueeze(0).float().to(device)

    # 0..255 -> [-1, 1]
    image = image / 255.0
    image = 2. * image - 1.

    return model.encode(image).sample()

@torch.no_grad()
def decode_image(image, model):
    """Decode ``image`` with ``model.decode`` and return the result as an RGB PIL image.

    The decoded tensor has its leading axis squeezed, is clamped to [-1, 1],
    mapped to 0..255 and converted channel-last before being handed to PIL.
    """
    decoded = model.decode(image).squeeze(0)

    # [-1, 1] float tensor -> [0, 1]
    pixels = decoded.detach().float().cpu()
    pixels = torch.clamp(pixels, -1., 1.)
    pixels = (pixels + 1.) / 2.

    # CHW -> HWC uint8 array
    array = pixels.permute(1, 2, 0).numpy()
    array = (255 * array).astype(np.uint8)

    pil_image = Image.fromarray(array)
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    return pil_image

class VectorAdjustPrior(nn.Module):
    """Two-layer projection that adjusts each token vector using its sequence mean.

    For every token, the mean over the sequence axis is concatenated with the
    token itself, projected to ``inter_dim`` features, concatenated with the
    token again and projected back to ``hidden_size``.
    """

    def __init__(self, hidden_size, inter_dim=64):
        """Create the two linear layers.

        Args:
            hidden_size: Size of the last axis of the input.
            inter_dim: Width of the intermediate projection.
        """
        super().__init__()
        self.vector_proj = nn.Linear(hidden_size*2, inter_dim, bias=True)
        self.out_proj = nn.Linear(hidden_size+inter_dim, hidden_size, bias=True)

    def forward(self, z):
        """Apply the adjustment.

        Args:
            z: Tensor of shape (batch, sequence, hidden_size).

        Returns:
            Tensor of shape (batch, sequence, hidden_size).
        """
        b, s = z.shape[0:2]
        # Rows of ``tokens`` are ordered batch-major (all tokens of sample 0,
        # then sample 1, ...). repeat_interleave gives the means the same
        # ordering; tiling with ``repeat`` would pair tokens with the mean of
        # a different sample whenever b > 1.
        seq_mean = torch.mean(z, dim=1).repeat_interleave(s, dim=0)
        tokens = z.reshape(b*s, -1)
        x = torch.cat((seq_mean, tokens), dim=1)
        x = self.vector_proj(x)
        x = torch.cat((tokens, x), dim=1)
        x = self.out_proj(x)
        x = x.reshape(b, s, -1)
        return x

    @classmethod
    def load_model(cls, model_path, hidden_size=768, inter_dim=64):
        """Build a prior and load weights from the ``"state_dict"`` entry of a checkpoint.

        Args:
            model_path: Path of a checkpoint readable by ``torch.load``.
            hidden_size: Passed to the constructor.
            inter_dim: Passed to the constructor.

        Returns:
            The model with the loaded weights (parameters are on the CPU).
        """
        model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
        # map_location="cpu" lets checkpoints saved from a GPU load on any
        # machine; the freshly built model lives on the CPU anyway.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        return model

class StableInterface(nn.Module):
    """Wrap a denoiser so samplers can call it with classifier-free guidance.

    The wrapped model is called as ``model(x, sigma, cond=...)``; its
    ``sigma_to_t`` and ``get_sigmas`` attributes are re-exported on the
    wrapper.
    """

    def __init__(self, model, thresholder = None):
        """Store the denoiser and an optional post-processing callable.

        Args:
            model: Denoiser exposing ``sigma_to_t`` and ``get_sigmas``.
            thresholder: Optional callable applied to the guided prediction.
        """
        super().__init__()
        self.inner_model = model
        self.sigma_to_t = model.sigma_to_t
        self.thresholder = thresholder
        self.get_sigmas = model.get_sigmas

    @torch.no_grad()
    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return the guided prediction for ``x`` at noise level ``sigma``.

        Args:
            x: Input batch.
            sigma: Noise levels, one per batch element.
            uncond: Unconditional conditioning, or ``None`` to skip guidance.
            cond: Conditioning.
            cond_scale: Guidance scale applied to ``cond - uncond``.

        Returns:
            ``uncond_pred + (cond_pred - uncond_pred) * cond_scale``, or the
            plain conditional prediction when ``uncond`` is ``None``; passed
            through ``thresholder`` when one is set.
        """
        if uncond is None:
            # No unconditional branch was supplied (callers do this when the
            # scale is 1.0), so there is nothing to blend with.
            x_0 = self.inner_model(x, sigma, cond=cond)
        else:
            # Evaluate both branches in a single doubled batch.
            x_two = torch.cat([x] * 2)
            sigma_two = torch.cat([sigma] * 2)
            cond_full = torch.cat([uncond, cond])
            uncond_pred, cond_pred = self.inner_model(x_two, sigma_two, cond=cond_full).chunk(2)
            x_0 = uncond_pred + (cond_pred - uncond_pred) * cond_scale

        if self.thresholder is not None:
            x_0 = self.thresholder(x_0)

        return x_0

class StableDiffusionModel(nn.Module):
    """Latent diffusion checkpoint bundled with DDIM, PLMS and k-diffusion samplers.

    The checkpoint and its ``config.yaml`` are read from ``config.model_path``.
    ``sample`` serves text-to-image and image-to-image requests;
    ``sample_two_stages`` generates an image, upscales it 2x and refines it.
    """

    def __init__(self, config):
        """Load the model and build the sampler table.

        Reads ``config.model_path``, ``config.dtype`` (``"float16"`` selects
        half precision, anything else float32), ``config.device`` and
        ``config.prior_path`` (when truthy, a ``VectorAdjustPrior`` is loaded
        into ``self.prior``).
        """
        nn.Module.__init__(self)
        self.config = config
        # Indexed by ``request.module`` in sample(); nothing in this class
        # populates it, so it has to be assigned from outside before use.
        self.premodules = None
        model, model_config = self.from_folder(config.model_path)
        if config.dtype == "float16":
            typex = torch.float16
        else:
            typex = torch.float32
        self.model = model.to(config.device).to(typex)
        self.k_model = K.external.CompVisDenoiser(model)
        self.k_model = StableInterface(self.k_model)
        self.device = config.device
        self.model_config = model_config
        self.plms = PLMSSampler(model)
        self.ddim = DDIMSampler(model)
        self.sampler_map = {
            'plms': self.plms.sample,
            'ddim': self.ddim.sample,
            'k_euler': K.sampling.sample_euler,
            'k_euler_ancestral': K.sampling.sample_euler_ancestral,
            'k_heun': K.sampling.sample_heun,
            'k_dpm_2': K.sampling.sample_dpm_2,
            'k_dpm_2_ancestral': K.sampling.sample_dpm_2_ancestral,
            'k_lms': K.sampling.sample_lms,
        }
        if config.prior_path:
            self.prior = VectorAdjustPrior.load_model(config.prior_path).to(self.device)

    def from_folder(self, folder):
        """Load ``config.yaml`` and the checkpoint stored in ``folder``.

        ``pruned.ckpt`` is preferred when it exists; otherwise ``model.ckpt``
        is used.

        Returns:
            Tuple ``(model, model_config)``.
        """
        folder = Path(folder)
        model_config = OmegaConf.load(folder / "config.yaml")
        if (folder / "pruned.ckpt").is_file():
            model_path = folder / "pruned.ckpt"
        else:
            model_path = folder / "model.ckpt"
        model = self.load_model_from_config(model_config, model_path)
        return model, model_config

    def load_model_from_config(self, config, ckpt, verbose=False):
        """Instantiate ``config.model`` and load the weights stored in ``ckpt``.

        The state dict is loaded non-strictly; missing and unexpected keys are
        printed only when ``verbose`` is true. The model is returned in eval
        mode.
        """
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)

        model.eval()
        return model

    @staticmethod
    def _noise_seeds(request):
        """Return one noise seed per requested sample (``None`` entries when unseeded)."""
        if request.seed is None:
            return [None] * request.n_samples
        return range(request.seed, request.seed + request.n_samples)

    @staticmethod
    def _to_uint8_images(batch):
        """Convert a batch of CHW tensors in [0, 1] into contiguous HWC uint8 arrays."""
        images = []
        for x_sample in batch:
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            x_sample = x_sample.astype(np.uint8)
            images.append(np.ascontiguousarray(x_sample))
        return images

    @torch.no_grad()
    @torch.autocast("cuda", enabled=True, dtype=torch.float16)
    def sample(self, request):
        """Generate ``request.n_samples`` images for ``request.prompt``.

        Fields read from ``request``: ``module``, ``seed``, ``image``,
        ``sampler``, ``steps``, ``ddim_eta``, ``n_samples``,
        ``latent_channels``, ``height``, ``width``, ``downsampling_factor``,
        ``noise``, ``strength``, ``prompt``, ``mitigate``, ``scale`` and
        ``dynamic_threshold``.

        When ``request.image`` is set, the image is encoded and used as the
        starting point, and ``request.sampler`` is rewritten to ``"k_lms"`` if
        it was ``"plms"`` or ``"ddim"``. Sample ``i`` uses noise seeded with
        ``request.seed + i``; with ``request.seed`` of ``None`` the noise is
        unseeded.

        Returns:
            List of HWC uint8 numpy arrays, one per sample.

        Raises:
            ValueError: If ``request.sampler`` selects a strategy that has no
                implementation in this method (e.g. ``"ddim_img2img"``).
        """
        if request.module is not None:
            module = self.premodules[request.module]
            CrossAttention.set_hypernetwork(module)

        if request.seed is not None:
            torch.manual_seed(request.seed)
            np.random.seed(request.seed)

        if request.image is not None:
            # plms/ddim requests are remapped to k_lms when an init image is given.
            if request.sampler in ("plms", "ddim"):
                request.sampler = "k_lms"

            self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
            start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
            start_code = self.model.get_first_stage_encoding(start_code)
            start_code = torch.repeat_interleave(start_code, request.n_samples, dim=0)

            main_noise = []
            start_noise = []
            # The two draws stay interleaved per sample: the unseeded
            # start_noise draw continues the RNG stream of the seeded one.
            for seed in self._noise_seeds(request):
                main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))
                start_noise.append(sample_start_noise(None, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))

            main_noise = torch.cat(main_noise, dim=0)
            start_noise = torch.cat(start_noise, dim=0)

            start_code = start_code + (start_noise * request.noise)
            t_enc = int(request.strength * request.steps)

        if request.sampler.startswith("k_"):
            sampler = "k-diffusion"

        elif request.sampler == 'ddim_img2img':
            sampler = 'img2img'

        else:
            sampler = "normal"

        if request.image is None:
            main_noise = []
            for seed in self._noise_seeds(request):
                main_noise.append(sample_start_noise(seed, request.latent_channels, request.height, request.width, request.downsampling_factor, self.device))

            main_noise = torch.cat(main_noise, dim=0)
            start_code = main_noise

        prompt_condition = prompt_mixing(self.model, request.prompt, request.n_samples)
        if hasattr(self, "prior") and request.mitigate:
            prompt_condition = self.prior(prompt_condition)

        # The unconditional branch is only needed when guidance is active.
        uc = None
        if request.scale != 1.0:
            uc = self.model.get_learned_conditioning(request.n_samples * [""])

        shape = [
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor
        ]
        if sampler == "normal":
            with self.model.ema_scope():
                samples, _ = self.sampler_map[request.sampler](
                    S=request.steps,
                    conditioning=prompt_condition,
                    batch_size=request.n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=request.scale,
                    unconditional_conditioning=uc,
                    eta=request.ddim_eta,
                    dynamic_threshold=request.dynamic_threshold,
                    x_T=start_code,
                )

        elif sampler == "k-diffusion":
            with self.model.ema_scope():
                sigmas = self.k_model.get_sigmas(request.steps)
                if request.image is not None:
                    # Start part-way down the schedule: add noise at the level
                    # selected by t_enc and keep only the remaining sigmas.
                    noise = main_noise * sigmas[request.steps - t_enc - 1]
                    start_code = start_code + noise
                    sigmas = sigmas[request.steps - t_enc - 1:]

                else:
                    start_code = start_code * sigmas[0]

                extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
                samples = self.sampler_map[request.sampler](self.k_model, start_code, sigmas, extra_args=extra_args)

        else:
            # Fail explicitly instead of reaching the decode step below with
            # ``samples`` unbound.
            raise ValueError(f"Unsupported sampler: {request.sampler!r}")

        x_samples_ddim = self.model.decode_first_stage(samples)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        images = self._to_uint8_images(x_samples_ddim)

        # Re-randomize the global RNGs so the request seed does not leak into
        # later, unseeded work.
        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return images

    @torch.no_grad()
    def sample_two_stages(self, request):
        """Generate an image, upscale it 2x with PIL and refine it with DDIM.

        ``request`` is wrapped in a ``DotMap``. Stage one samples with PLMS
        when ``request.plms`` is truthy and DDIM otherwise. The decoded result
        has its leading axis squeezed before upscaling. Stage two re-encodes
        the upscaled image, adds ``request.noise``-scaled Gaussian noise,
        noises it to step ``int(request.strength * request.steps)`` and
        denoises it again with the same prompt.

        Returns:
            List of HWC uint8 numpy arrays.
        """
        request = DotMap(request)
        if request.seed is not None:
            torch.manual_seed(request.seed)
            np.random.seed(request.seed)

        if request.plms:
            sampler = self.plms
        else:
            sampler = self.ddim

        start_code = None
        if request.fixed_code:
            start_code = torch.randn([
                request.n_samples,
                request.latent_channels,
                request.height // request.downsampling_factor,
                request.width // request.downsampling_factor,
                ], device=self.device)

        prompt_condition = prompt_mixing(self.model, request.prompt, request.n_samples)

        uc = None
        if request.scale != 1.0:
            uc = self.model.get_learned_conditioning(request.n_samples * [""])

        shape = [
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor
        ]
        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                samples, _ = sampler.sample(
                    S=request.steps,
                    conditioning=prompt_condition,
                    batch_size=request.n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=request.scale,
                    unconditional_conditioning=uc,
                    eta=request.ddim_eta,
                    dynamic_threshold=request.dynamic_threshold,
                    x_T=start_code,
                )

        x_samples_ddim = self.model.decode_first_stage(samples)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).squeeze(0)
        x_samples_ddim = pil_upscale(x_samples_ddim, scale=2)

        if request.stage_two_seed is not None:
            torch.manual_seed(request.stage_two_seed)
            np.random.seed(request.stage_two_seed)

        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
                self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
                t_enc = int(request.strength * request.steps)

                print("init latent shape:")
                print(init_latent.shape)

                init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)

                # Conditioning is recomputed here, inside the autocast scope.
                uc = None
                if request.scale != 1.0:
                    uc = self.model.get_learned_conditioning(request.n_samples * [""])

                prompt_condition = prompt_mixing(self.model, request.prompt, request.n_samples)

                # encode (scaled latent); noise=None leaves the noise choice
                # to stochastic_encode
                start_code_terped = None
                z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
                # decode it
                samples = self.ddim.decode(z_enc, prompt_condition, t_enc, unconditional_guidance_scale=request.scale,
                                        unconditional_conditioning=uc,)

                x_samples_ddim = self.model.decode_first_stage(samples)
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

        images = self._to_uint8_images(x_samples_ddim)

        # Re-randomize the global RNGs after a seeded run.
        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return images

    @torch.no_grad()
    def sample_from_image(self, request):
        """Placeholder: not implemented, always returns ``None``."""
        return

class DalleMiniModel(nn.Module):
    """Thin wrapper that serves sampling requests through a ``MinDalle`` model."""

    def __init__(self, config):
        """Create the ``MinDalle`` model from ``config.model_path``.

        The model is built in float16 on ``'cuda'`` with ``is_mega`` and
        ``is_reusable`` enabled. ``min_dalle`` is imported lazily here.
        """
        nn.Module.__init__(self)
        from min_dalle import MinDalle

        self.config = config
        dalle_kwargs = dict(
            models_root=config.model_path,
            dtype=torch.float16,
            device='cuda',
            is_mega=True,
            is_reusable=True,
        )
        self.model = MinDalle(**dalle_kwargs)

    @torch.no_grad()
    def sample(self, request):
        """Generate images for ``request.prompt`` and return them as a uint8 array.

        ``request.seed`` of ``None`` is forwarded as ``-1``. The fields
        ``grid_size``, ``temp``, ``top_k`` and ``scale`` are passed on to
        ``generate_images``. After a seeded request the global torch and numpy
        RNGs are re-randomized.
        """
        seeded = request.seed is not None
        seed = request.seed if seeded else -1

        generated = self.model.generate_images(
            text=request.prompt,
            seed=seed,
            grid_size=request.grid_size,
            is_seamless=False,
            temperature=request.temp,
            top_k=request.top_k,
            supercondition_factor=request.scale,
            is_verbose=False
        )
        pixels = generated.to('cpu').numpy().astype(np.uint8)
        pixels = np.ascontiguousarray(pixels)

        if request.seed is not None:
            torch.seed()
            np.random.seed()

        return pixels


