from re import S
import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import OmegaConf
from dotmap import DotMap
import numpy as np
import base64
from torch import autocast
from einops import rearrange
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import time

def no_init(loading_code):
    """Run ``loading_code`` with default weight initialization disabled.

    ``reset_parameters`` on ``torch.nn.Linear``, ``torch.nn.Embedding`` and
    ``torch.nn.LayerNorm`` is temporarily replaced by a no-op. Layers of those
    types constructed inside ``loading_code`` therefore skip their random
    init. This is intended for the case where the weights are loaded from a
    checkpoint right afterwards.

    The original methods are restored even if ``loading_code`` raises.

    Args:
        loading_code: zero-argument callable that builds and/or loads the model.

    Returns:
        Whatever ``loading_code`` returns.
    """
    def dummy(self):
        return

    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {mod: mod.reset_parameters for mod in modules}
    for mod in modules:
        mod.reset_parameters = dummy

    try:
        return loading_code()
    finally:
        # Restore unconditionally: if the no-op stayed in place after an
        # exception, every later Linear/Embedding/LayerNorm created anywhere
        # in the process would start with uninitialized weights.
        for mod, fn in original.items():
            mod.reset_parameters = fn

class StableDiffusionModel(nn.Module):
    """Latent-diffusion checkpoint bundled with PLMS and DDIM samplers.

    ``config`` is read for:
        model_path: folder containing ``config.yaml`` and ``model.ckpt``.
        dtype:      ``"float16"`` selects half precision; anything else float32.
        device:     target device for the model (and for ``fixed_code`` noise).
        amp:        ``enabled`` flag for ``torch.autocast("cuda")`` during sampling.
        logger:     logger used for load-time messages.
    """

    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        # Skip default init of Linear/Embedding/LayerNorm; their weights are
        # expected to come from the checkpoint loaded inside from_folder.
        model, model_config = no_init(lambda: self.from_folder(config.model_path))
        typex = torch.float16 if config.dtype == "float16" else torch.float32
        self.model = model.to(config.device).to(typex)
        self.model_config = model_config
        # nn.Module.to() converts in place and returns the same module, so the
        # samplers wrap the already moved/cast model.
        self.plms = PLMSSampler(self.model)
        self.ddim = DDIMSampler(self.model)

    def from_folder(self, folder):
        """Load ``folder/config.yaml`` and the model from ``folder/model.ckpt``.

        Returns:
            ``(model, model_config)`` where ``model_config`` is the OmegaConf
            object loaded from ``config.yaml``.
        """
        folder = Path(folder)
        model_config = OmegaConf.load(folder / "config.yaml")
        model = self.load_model_from_config(model_config, folder / "model.ckpt")
        return model, model_config

    def load_model_from_config(self, config, ckpt, verbose=False):
        """Instantiate ``config.model`` and load weights from checkpoint ``ckpt``.

        The checkpoint is loaded onto CPU and its ``"state_dict"`` entry is
        applied with ``strict=False``. Missing/unexpected keys are therefore
        tolerated and are only reported when ``verbose`` is true.

        Returns:
            The model, switched to eval mode.
        """
        logger = self.config.logger
        logger.info(f"Loading model from {ckpt}")
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            logger.info(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        missing, unexpected = model.load_state_dict(sd, strict=False)
        # Report through the configured logger like the rest of the class,
        # instead of bare print().
        if verbose and missing:
            logger.info(f"missing keys: {missing}")
        if verbose and unexpected:
            logger.info(f"unexpected keys: {unexpected}")

        model.eval()
        return model

    @torch.no_grad()
    def sample(self, request):
        """Generate images for ``request`` (a mapping, wrapped in a DotMap).

        Keys read: seed, plms, fixed_code, n_samples, latent_channels, height,
        width, downsampling_factor, prompt, scale, steps, ddim_eta,
        dynamic_threshold.

        Returns:
            A list of ``n_samples`` numpy arrays laid out ``(h, w, c)`` with
            values in ``[0, 255]``. They are not converted to an integer
            dtype; the float dtype presumably follows the model dtype (TODO
            confirm).
        """
        request = DotMap(request)

        # Explicit presence check so that seed=0 is honoured (a truthiness
        # test skipped it). Checking membership first also avoids DotMap
        # auto-creating an empty "seed" entry.
        seed = request.seed if "seed" in request else None
        if seed is not None:
            torch.manual_seed(seed)
            np.random.seed(seed)

        sampler = self.plms if request.plms else self.ddim

        latent_shape = [
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor,
        ]

        start_code = None
        if request.fixed_code:
            # nn.Module has no `device` attribute; use the device the model
            # was moved to in __init__.
            start_code = torch.randn(
                [request.n_samples, *latent_shape], device=self.config.device
            )

        prompt = [request.prompt] * request.n_samples
        prompt_condition = self.model.get_learned_conditioning(prompt)

        # The empty-prompt (unconditional) embedding is only needed for
        # classifier-free guidance, i.e. when scale != 1.0.
        uc = None
        if request.scale != 1.0:
            uc = self.model.get_learned_conditioning(request.n_samples * [""])

        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                samples, _ = sampler.sample(
                    S=request.steps,
                    conditioning=prompt_condition,
                    batch_size=request.n_samples,
                    shape=latent_shape,
                    verbose=False,
                    unconditional_guidance_scale=request.scale,
                    unconditional_conditioning=uc,
                    eta=request.ddim_eta,
                    dynamic_threshold=request.dynamic_threshold,
                    x_T=start_code,
                )

        decoded = self.model.decode_first_stage(samples)
        # Map decoder output from [-1, 1] to [0, 1], clamping outliers.
        decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)

        return [
            255. * rearrange(x.cpu().numpy(), 'c h w -> h w c')
            for x in decoded
        ]