from fastapi import FastAPI, Request
from pydantic import BaseModel
from fastapi.responses import HTMLResponse, PlainTextResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from sentry_sdk import capture_exception
from sentry_sdk import capture_message
from sentry_sdk import start_transaction
from hydra_node.config import init_config_model
from typing import Optional, List
import socket
from hydra_node.sanitize import sanitize_input
import uvicorn
from typing import Union
import time
import gc
import os
import io
import signal
import simplejpeg
import base64
import traceback
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import json

# Initialize model and config
model, config, model_hash = init_config_model()
logger = config.logger
# The PID stored in gunicorn.pid is the process that receives SIGTERM when a
# request handler hits a GPU error (presumably the gunicorn master - confirm
# against the deployment). Read it inside a context manager so the file
# handle is closed deterministically instead of being left to the GC.
with open("gunicorn.pid", "r") as pid_file:
    config.mainpid = int(pid_file.read())
mainpid = config.mainpid
hostname = socket.gethostname()
sent_first_message = False

# Create the FastAPI application
app = FastAPI()

# Permissive CORS policy: every origin, method and header is accepted, while
# credentials are not permitted on cross-origin requests.
_cors_policy = {
    "allow_origins": ["*"],
    "allow_credentials": False,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_policy)

@app.on_event("startup")
def startup_event():
    """Log that the FastAPI application has started."""
    logger.info("FastAPI Started, serving")

@app.on_event("shutdown")
def shutdown_event():
    """Log that the FastAPI application is shutting down."""
    logger.info("FastAPI Shutdown, exiting")

@app.get("/", response_class=PlainTextResponse)
def root():
    # Always answers with the fixed plain-text body "OK"; it touches neither
    # the model nor the config, so it can serve as a simple liveness check.
    return "OK"

class GenerationRequest(BaseModel):
    # Request body accepted by the generation endpoints.
    #
    # Fields whose default is None are declared Optional[...] so the
    # annotation matches the default and an explicit JSON null is accepted
    # regardless of the pydantic major version in use.
    #
    # Comments are used instead of a class docstring on purpose: pydantic
    # copies a model's docstring into the generated schema description.

    # Text prompt; /generate-stream also stores it in the PNG "Description".
    prompt: str
    image: Optional[str] = None
    n_samples: int = 1
    # steps, sampler, seed, strength, noise and scale are echoed into the
    # JSON "Comment" text chunk of every PNG produced by /generate-stream.
    steps: int = 50
    sampler: str = "plms"
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int = 512
    width: int = 512
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: Optional[float] = None
    # When set, /generate-stream records seed, seed + 1, ... for successive
    # images of the same request.
    seed: Optional[int] = None
    temp: float = 1.0
    top_k: int = 256
    grid_size: int = 4
    # When true, /generate-stream calls model.sample_two_stages() and rejects
    # requests with n_samples > 1.
    advanced: bool = False
    stage_two_seed: Optional[int] = None
    strength: float = 0.69
    noise: float = 0.667
    mitigate: bool = False
    module: Optional[str] = None

class GenerationOutput(BaseModel):
    # Success payload of POST /generate.
    # Each entry is one generated image: PNG bytes, base64-encoded as ASCII.
    output: List[str]

class ErrorOutput(BaseModel):
    # Failure payload returned by the generation endpoints, e.g. when
    # sanitize_input() rejects the request.
    # Human-readable description of what went wrong.
    error: str

def _encode_stream_images(images, request):
    """Encode sampled images as base64 PNG strings with embedded metadata.

    Every PNG carries text chunks describing the generation: title, prompt,
    software, model hash and a JSON "Comment" with the sampler settings.

    When ``request.seed`` is set, it is rewritten for each image to
    ``seed + index`` so that every image records its own seed; the request is
    left holding the seed of the last image.

    This lives in its own function so that no loop variable keeps referencing
    an image after the caller deletes ``images``.

    Args:
        images: Sequence of arrays accepted by ``PIL.Image.fromarray``.
        request: The sanitized generation request (mutated as described).

    Returns:
        A list with one base64 (ASCII) PNG string per input image.
    """
    base_seed = request.seed
    encoded_images = []
    for index, pixels in enumerate(images):
        if base_seed is not None:
            request.seed = base_seed + index
        comment = json.dumps({"steps":request.steps,"sampler":request.sampler,"seed":request.seed,"strength":request.strength,"noise":request.noise,"scale":request.scale})
        metadata = PngInfo()
        metadata.add_text("Title", "AI generated image")
        metadata.add_text("Description", request.prompt)
        metadata.add_text("Software", "NovelAI")
        metadata.add_text("Source", "Stable Diffusion "+model_hash)
        metadata.add_text("Comment", comment)
        # Serialize to PNG in memory, then base64 the raw bytes.
        with io.BytesIO() as buffer:
            Image.fromarray(pixels).save(buffer, format='PNG', pnginfo=metadata)
            encoded_images.append(base64.b64encode(buffer.getvalue()).decode("ascii"))
    return encoded_images


# NOTE: the module-level name `generate` is rebound by the /generate handler
# defined later in this file. This route keeps working because the decorator
# registers the function at definition time.
@app.post('/generate-stream')
def generate(request: GenerationRequest):
    """Generate images and return them as a server-sent-events payload.

    The request is sanitized first; a rejected request yields an error
    object. With `advanced` set, two-stage sampling is used and `n_samples`
    must not exceed 1.

    On success the body has media type `text/event-stream` and contains one
    `newImage` event per image, with a 1-based `id` and the base64 PNG as
    `data`. The whole payload is built before it is returned.

    Any exception is reported as `{"error": "<message>"}`. If the message
    looks like a GPU failure, SIGTERM is additionally sent to the main
    process.
    """
    t = time.perf_counter()
    try:
        # sanitize_input() yields (ok, payload): the sanitized request when
        # ok is truthy, otherwise an error message.
        sanitized = sanitize_input(config, request)
        if not sanitized[0]:
            return ErrorOutput(error=sanitized[1])
        request = sanitized[1]

        if request.advanced:
            if request.n_samples > 1:
                return ErrorOutput(error="advanced mode does not support n_samples > 1")

            images = model.sample_two_stages(request)
        else:
            images = model.sample(request)

        images_encoded = _encode_stream_images(images, request)

        # Drop the raw arrays before building the (large) response string.
        del images

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        # Event ids start at 1; join avoids repeated string reallocation.
        data = "".join(
            f"event: newImage\nid: {event_id}\ndata:{encoded}\n\n"
            for event_id, encoded in enumerate(images_encoded, start=1)
        )
        return Response(content=data, media_type="text/event-stream")

    except Exception as e:
        traceback.print_exc()
        capture_exception(e)
        e_s = str(e)
        logger.error(e_s)
        gc.collect()
        # "CUDA out of memory" is already covered by the broader "CUDA" test.
        if "CUDA" in e_s or "an illegal memory access" in e_s:
            logger.error("GPU error, committing seppuku.")
            # Ask the main process to terminate so the worker gets restarted
            # with a clean GPU state.
            os.kill(mainpid, signal.SIGTERM)
        return {"error": e_s}

'''
@app.post('/image-to-image')
def image_to_image(request: GenerationRequest):
    #prompt is a base64 encoded image
    try:
        output = sanitize_input(config, request)
        
        if output[0]:
            request = output[1]
        else:
            return ErrorOutput(error=output[1])
        
        image = base64.b64decode(request.prompt)
        image = simplejpeg.decode_jpeg(image)
        image = model.image_to_image(image, request)
        image = simplejpeg.encode_jpeg(image, quality=95) 
        #get base64 of image
        image = base64.b64encode(image).decode("ascii")
        return GenerationOutput(output=[image])
'''

def _encode_images_as_png(images):
    """Return every image in ``images`` as a base64 (ASCII) PNG string.

    This lives in its own function so that no loop variable keeps referencing
    an image after the caller deletes ``images``.

    Args:
        images: Sequence of arrays accepted by ``PIL.Image.fromarray``.

    Returns:
        A list with one base64-encoded PNG string per input image.
    """
    encoded_images = []
    for pixels in images:
        # Serialize to PNG in memory, then base64 the raw bytes.
        with io.BytesIO() as buffer:
            Image.fromarray(pixels).save(buffer, format='PNG')
            encoded_images.append(base64.b64encode(buffer.getvalue()).decode("ascii"))
    return encoded_images


@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
def generate(request: GenerationRequest):
    """Generate images and return them as base64-encoded PNGs.

    The request is sanitized first; a rejected request yields an error
    object. On success the response is `{"output": [...]}` with one base64
    PNG string per generated image.

    Any exception is reported as `{"error": "<message>"}`. If the message
    looks like a GPU failure, SIGTERM is additionally sent to the main
    process.
    """
    t = time.perf_counter()
    try:
        # sanitize_input() yields (ok, payload): the sanitized request when
        # ok is truthy, otherwise an error message.
        sanitized = sanitize_input(config, request)
        if not sanitized[0]:
            return ErrorOutput(error=sanitized[1])
        request = sanitized[1]

        images = model.sample(request)
        images_encoded = _encode_images_as_png(images)

        # Drop the raw arrays before the response is serialized.
        del images

        process_time = time.perf_counter() - t
        logger.info(f"Request took {process_time:0.3f} seconds")
        return GenerationOutput(output=images_encoded)

    except Exception as e:
        traceback.print_exc()
        capture_exception(e)
        e_s = str(e)
        logger.error(e_s)
        gc.collect()
        # "CUDA out of memory" is already covered by the broader "CUDA" test.
        if "CUDA" in e_s or "an illegal memory access" in e_s:
            logger.error("GPU error, committing seppuku.")
            # Ask the main process to terminate so the worker gets restarted
            # with a clean GPU state.
            os.kill(mainpid, signal.SIGTERM)
        return {"error": e_s}

if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces,
    # port 80. "main:app" is an import string, so this module has to be
    # importable under the name `main` for it to resolve.
    uvicorn.run("main:app", host="0.0.0.0", port=80, log_level="info")