Commit 33b3166f authored by novelailab

should work

parent 424ba3ef
FROM nvidia/cuda:11.3.1-base-ubuntu20.04
ENV DEBIAN_FRONTEND=noninteractive
#Install core packages
RUN apt-get update && apt-get install -y \
libncurses5 python3 python3-pip curl git apt-utils ssh ca-certificates \
tmux nano vim sudo bash rsync htop wget unzip python3.8-venv \
python3-virtualenv python3-distutils python3-numpy tini && \
update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \
update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 && \
pip3 install --no-cache-dir --upgrade pip
#Install Python deps
RUN pip3 install --no-cache-dir dotmap icecream sentry-sdk numpy fastapi "uvicorn[standard]" gunicorn
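# PyTorch wheels from the cu113 extra index match the CUDA 11.3 base image above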
RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
#Open ports
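# 80/443/8080 are HTTP(S) ports; 4369, 5672, 15672, 15692 and 25672 are standard RabbitMQ/Erlang ports; 50051 is the conventional gRPC port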
EXPOSE 8080
EXPOSE 80
EXPOSE 443
EXPOSE 4369
EXPOSE 5672
EXPOSE 25672
EXPOSE 15672
EXPOSE 15692
EXPOSE 50051
#Copy node src and run
WORKDIR /usr/src/app
COPY . .
CMD [ "gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:80" ]
\ No newline at end of file
@@ -10,10 +10,13 @@ from dotmap import DotMap
from icecream import ic
from sentry_sdk import capture_exception
from sentry_sdk.integrations.threading import ThreadingIntegration
from hydra_node.models import StableDiffusionModel
def init_config_model():
    config = DotMap()
    config.model_type = "GPT"
    config.dtype = os.getenv("DTYPE", "float16")
    config.device = os.getenv("DEVICE", "cuda")
    config.amp = os.getenv("AMP", False)
    is_dev = ""
    environment = "production"
    if os.environ['DEV'] == "True":
@@ -73,7 +76,7 @@ def init_config_model():
    load_time = time.time()
    try:
        model = GPTModel(config)
        model = StableDiffusionModel(config)
    except Exception as e:
        ic(e)
        capture_exception(e)
......
@@ -6,34 +6,55 @@ from omegaconf import OmegaConf
from dotmap import DotMap
import numpy as np
import base64
from torch import autocast
from einops import rearrange
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import time
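# no_init temporarily replaces reset_parameters on common torch modules with a
# no-op so instantiating the model skips random weight initialization (the
# weights are overwritten by the checkpoint anyway), then restores the originals.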
def no_init(loading_code):
    def dummy(self):
        return
    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    for mod in modules:
        original[mod] = mod.reset_parameters
        mod.reset_parameters = dummy
    result = loading_code()
    for mod in modules:
        mod.reset_parameters = original[mod]
    return result
class StableDiffusionModel(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)
        model, model_config = self.from_folder(config.model_path)
        self.model = model
        self.config = model_config
        self.config = config
        model, model_config = no_init(lambda: self.from_folder(config.model_path))
        if config.dtype == "float16":
            typex = torch.float16
        else:
            typex = torch.float32
        self.model = model.to(config.device).to(typex)
        self.model_config = model_config
        self.plms = PLMSSampler(model)
        self.ddim = DDIMSampler(model)
    def from_folder(self, folder):
        folder = Path(folder)
        model_config = OmegaConf.load(folder / "config.yaml")
        model = self.load_model_from_config(model_config, folder / "model.ckpt")
        return model, model_config
    def load_model_from_config(self, config, ckpt, verbose=True):
        print(f"Loading model from {ckpt}")
    def load_model_from_config(self, config, ckpt, verbose=False):
        self.config.logger.info(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
            self.config.logger.info(f"Global Step: {pl_sd['global_step']}")
        sd = pl_sd["state_dict"]
        model = instantiate_from_config(config.model)
        m, u = model.load_state_dict(sd, strict=False)
@@ -50,6 +71,10 @@ class StableDiffusionModel(nn.Module):
    @torch.no_grad()
    def sample(self, request):
        request = DotMap(request)
        if request.seed:
            torch.manual_seed(request.seed)
            np.random.seed(request.seed)
        if request.plms:
            sampler = self.plms
        else:
@@ -76,18 +101,20 @@ class StableDiffusionModel(nn.Module):
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor
        ]
        samples, _ = sampler.sample(
            S=request.steps,
            conditioning=prompt_condition,
            batch_size=request.n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=request.scale,
            unconditional_conditioning=uc,
            eta=request.ddim_eta,
            dynamic_threshold=request.dynamic_threshold,
            x_T=start_code,
        )
        with torch.autocast("cuda", enabled=self.config.amp):
            with self.model.ema_scope():
                samples, _ = sampler.sample(
                    S=request.steps,
                    conditioning=prompt_condition,
                    batch_size=request.n_samples,
                    shape=shape,
                    verbose=False,
                    unconditional_guidance_scale=request.scale,
                    unconditional_conditioning=uc,
                    eta=request.ddim_eta,
                    dynamic_threshold=request.dynamic_threshold,
                    x_T=start_code,
                )
        x_samples_ddim = self.model.decode_first_stage(samples)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -100,5 +127,5 @@ class StableDiffusionModel(nn.Module):
            x_sample = base64.b64encode(x_sample).decode("ascii")
            base_count += 1
            images.append(x_sample)
        return images
\ No newline at end of file
from dotmap import DotMap
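# Per-model sampling defaults; entries in the forced-defaults dict always
# override whatever the client sends.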
v1pp_defaults = {
    'steps': 50,
    'plms': True,
    'fixed_code': False,
    'ddim_eta': 0.0,
    'height': 512,
    'width': 512,
    'latent_channels': 4,
    'downsampling_factor': 8,
    'scale': 7.0,
    'dynamic_threshold': None,
    'seed': None,
}
v1pp_forced_defaults = {
    'latent_channels': 4,
    'downsampling_factor': 8,
}
defaults = {
    'v1pp': (v1pp_defaults, v1pp_forced_defaults),
}
def sanitize_input(request):
"""
Sanitize the input data and set defaults
"""
request = DotMap(request)
default = defaults[request.model]
default, forced_default = default
for k, v in default.items():
if k not in request:
request[k] = v
for k, v in forced_default.items():
request[k] = v
if request.width * request.height == 0:
return False, "width and height must be non-zero"
if request.width <= 0:
return False, "width must be positive"
if request.height <= 0:
return False, "height must be positive"
if request.steps <= 0:
return False, "steps must be positive"
if request.ddim_eta < 0:
return False, "ddim_eta shouldn't be negative"
if request.scale < 1.0:
return False, "scale should be at least 1.0"
if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
return False, "dynamic_threshold shouldn't be negative"
    if request.width * request.height >= 2048*512:
        return False, "the product of width and height must be less than 2048*512"
    return True, request
\ No newline at end of file
from fastapi import FastAPI
from fastapi import FastAPI, Request
from pydantic import BaseModel
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, PlainTextResponse
from sentry_sdk import capture_exception
from sentry_sdk import capture_message
from sentry_sdk import start_transaction
from hydra_node.config import init_config_model
from typing import Optional
from typing import Optional, List
import socket
from hydra_node.sanitize import sanitize_input
import uvicorn
from typing import Union
import time
import gc
import os
import signal
#Initialize model and config
model, config = init_config_model()
logger = config.logger
hostname = socket.gethostname()
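# Flipped to True by the logging middleware after the first request so the
# readiness file is only written once.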
sent_first_message = False
#Initialize fastapi
app = FastAPI()
@app.on_event("startup")
def startup_event():
print("Startup")
#model, config = init_config()
logger.info("FastAPI Started, serving")
@app.on_event("shutdown")
def shutdown_event():
    print('Shutdown')
    logger.info("FastAPI Shutdown, exiting")
@app.get("/", response_class=HTMLResponse)
@app.get("/", response_class=PlainTextResponse)
def root():
return "OK"
class GenerationRequest(BaseModel):
    prompt: str
    n_samples: int = 1
    steps: int = None
    plms: bool = None
    fixed_code: bool = None
    ddim_eta: float = None
    height: int = None
    width: int = None
    latent_channels: int = None
    downsampling_factor: int = None
    scale: float = None
    steps: int = 50
    plms: bool = True
    fixed_code: bool = False
    ddim_eta: float = 0.0
    height: int
    width: int
    latent_channels: int = 4
    downsampling_factor: int = 8
    scale: float = 7.0
    dynamic_threshold: float = None
    make_grid: bool = False
    n_rows: int = None
    seed: int = None
    model: str = "v1pp"  # assumed default: sanitize_input indexes defaults[request.model] and 'v1pp' is the only entry
class GenerationOutput(BaseModel):
    generation: str
    version: int = 1
    generation: List[str]
@app.post('/generate', response_model=GenerationOutput)
class ErrorOutput(BaseModel):
    error: str
@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
def generate(request: GenerationRequest):
    request = request.dict()
    print(request)
    try:
        output = sanitize_input(request)
        if output[0]:
            request = output[1]
        else:
            return {'error': output[1]}
        images = model.sample(request)
        return {"generation": images}
    except Exception as e:
        capture_exception(e)
        logger.error(str(e))
        e_s = str(e)
        gc.collect()
        if "CUDA out of memory" in e_s or \
           "an illegal memory access" in e_s or "CUDA" in e_s:
            logger.error("GPU error, committing seppuku.")
            os.kill(os.getpid(), signal.SIGTERM)
        return {"error": str(e)}
@app.middleware("http")
def handle_logging_and_errors(request: Request, call_next):
t = time.perf_counter()
response = call_next(request)
process_time = time.time() - t
response.headers["X-Process-Time"] = str(process_time)
logger.info(f"Request took {t:.3f} seconds")
if not sent_first_message:
f = open("/tmp/health_readiness", "w")
f.close()
sent_first_message = True
if os.environ['DEV'] == "False":
f = open("/tmp/healthy", "w")
f.close()
return response
return {"generation": "Hello World"}
\ No newline at end of file
if __name__ == "__main__":
    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")
\ No newline at end of file
export DTYPE="float16"
export MODEL="test"
export DEV="True"
export MODEL_PATH="/home/xuser/diffusionstorage/workspace/kuru/stablediff/v1pp-flatline-pruned"
export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448"
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:80
\ No newline at end of file
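For reference, a minimal client-side smoke test of the /generate endpoint: a sketch, assuming the service listens on localhost port 80 as configured above and that the requests package is installed; the field names follow GenerationRequest in main.py.

import requests

payload = {
    "prompt": "a painting of a fox in a snowy forest",
    "height": 512,
    "width": 512,
    "steps": 50,
    "seed": 42,
}
resp = requests.post("http://localhost:80/generate", json=payload)
resp.raise_for_status()
images = resp.json()["generation"]  # list of base64-encoded samples
print(f"received {len(images)} image(s)")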