from fastapi import FastAPI, Request
from pydantic import BaseModel
from fastapi.responses import HTMLResponse, PlainTextResponse
from sentry_sdk import capture_exception
from sentry_sdk import capture_message
from sentry_sdk import start_transaction
from hydra_node.config import init_config_model
from typing import Optional, List
import socket
from hydra_node.sanitize import sanitize_input
import uvicorn
from typing import Union
import time
import gc
import os
import signal

# Initialize the model and its configuration once, at import time.
# init_config_model() returns a (model, config) pair; the handlers below use
# `model` for sampling and `config.logger` for all logging.
model, config = init_config_model()
logger = config.logger
# NOTE(review): disabled code kept from the original; it would read a PID from
# "app.pid" into config.mainpid. Confirm whether it can be deleted.
#config.mainpid = open("app.pid", "r").read()

# Neither name is referenced in the visible part of this file.
# NOTE(review): presumably used by code outside this view; confirm.
hostname = socket.gethostname()
sent_first_message = False

# Initialize the FastAPI application that the decorators below register on.
app = FastAPI()

@app.on_event("startup")
def startup_event() -> None:
    """Log a message when the FastAPI application starts up."""
    startup_message = "FastAPI Started, serving"
    logger.info(startup_message)

@app.on_event("shutdown")
def shutdown_event() -> None:
    """Log a message when the FastAPI application shuts down."""
    shutdown_message = "FastAPI Shutdown, exiting"
    logger.info(shutdown_message)

@app.get("/", response_class=PlainTextResponse)
def root():
    """Answer GET / with the plain-text body ``OK``."""
    # No return annotation on purpose: newer FastAPI versions would treat it
    # as a response model for this route.
    body = "OK"
    return body

class GenerationRequest(BaseModel):
    """Request body accepted by ``POST /generate``.

    ``prompt``, ``height`` and ``width`` are required; every other field has a
    default. The handler passes the validated request to ``sanitize_input`` and
    then to ``model.sample``; the exact meaning and valid range of each field
    is defined by those functions, not in this file.
    """

    prompt: str
    n_samples: int = 1
    steps: int = 50
    plms: bool = True
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int
    width: int
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    # These two default to None, so they are declared Optional: the annotation
    # then matches the default and an explicit JSON null is accepted as well.
    dynamic_threshold: Optional[float] = None
    seed: Optional[int] = None

class GenerationOutput(BaseModel):
    """Success response schema for ``/generate``: a list of strings."""

    # NOTE(review): presumably one encoded image per entry; the encoding is
    # decided by model.sample, which is not visible here. Confirm.
    generation: List[str]

class ErrorOutput(BaseModel):
    """Failure response schema for ``/generate``: one error message string."""

    error: str

@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
def generate(request: GenerationRequest):
    """Sanitize a generation request, run the model and return the result.

    Args:
        request: Validated request body.

    Returns:
        ``{"generation": images}`` (the shape of ``GenerationOutput``) on
        success, or ``{"error": message}`` (the shape of ``ErrorOutput``) when
        ``sanitize_input`` rejects the request or any exception is raised.

    Side effects:
        Every exception is sent to Sentry and logged. When its message
        mentions "CUDA" or "an illegal memory access", the process sends
        itself SIGTERM after logging.
    """
    print(request)
    try:
        # sanitize_input returns an indexable pair: a success flag followed by
        # either the sanitized request or an error message.
        output = sanitize_input(request)

        if output[0]:
            request = output[1]
        else:
            return {'error': output[1]}

        images = model.sample(request)
        # The key must match the GenerationOutput field name. Any other key
        # matches neither model in the response_model Union and fails
        # response validation.
        return {"generation": images}

    except Exception as e:
        capture_exception(e)
        e_s = str(e)
        logger.error(e_s)
        gc.collect()
        # "CUDA" also covers "CUDA out of memory"; the illegal-access message
        # is checked separately because it may not contain "CUDA".
        if "CUDA" in e_s or "an illegal memory access" in e_s:
            logger.error("GPU error, committing seppuku.")
            # NOTE(review): presumably a supervisor restarts the process to
            # recover the GPU; confirm against the deployment.
            os.kill(os.getpid(), signal.SIGTERM)
        return {"error": e_s}

@app.middleware("http")
async def handle_logging_and_errors(request: Request, call_next):
    """Time each HTTP request and refresh the health marker files.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the rest of the
            application and returns its response.

    Returns:
        The downstream response with an ``X-Process-Time`` header holding the
        elapsed time in seconds.

    Side effects:
        Logs the elapsed time, creates or truncates ``/tmp/health_readiness``
        on every request, and does the same for ``/tmp/healthy`` when the
        ``DEV`` environment variable equals ``"False"``.
    """
    started_at = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - started_at
    response.headers["X-Process-Time"] = str(process_time)
    # Log the elapsed duration, not the raw perf_counter() start value.
    logger.info(f"Request took {process_time:0.3f} seconds")

    # Opening in "w" mode creates the file or truncates it to empty.
    # NOTE(review): presumably polled by an external readiness probe; confirm.
    with open("/tmp/health_readiness", "w"):
        pass

    # .get() avoids a KeyError (and a failed request) when DEV is unset; an
    # unset DEV is treated the same as any value other than "False".
    if os.environ.get("DEV") == "False":
        with open("/tmp/healthy", "w"):
            pass

    return response

# When executed as a script, serve the app with uvicorn on all interfaces,
# port 80. The "main:app" import string requires this module to be importable
# as `main`.
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=80, log_level="info")