Commit 33b3166f authored by novelailab

should work

parent 424ba3ef
Dockerfile
FROM nvidia/cuda:11.3.1-base-ubuntu20.04
ENV DEBIAN_FRONTEND=noninteractive
#Install core packages
RUN apt-get update && apt-get install -y \
libncurses5 python3 python3-pip curl git apt-utils ssh ca-certificates \
tmux nano vim sudo bash rsync htop wget unzip python3.8-venv \
python3-virtualenv python3-distutils python3-numpy tini && \
update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \
update-alternatives --install /usr/bin/pip pip /usr/bin/pip3 1 && \
pip3 install --no-cache-dir --upgrade pip
#Install Python deps
RUN pip3 install --no-cache-dir dotmap icecream sentry-sdk numpy fastapi "uvicorn[standard]" gunicorn
RUN pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
#Open ports
EXPOSE 8080
EXPOSE 80
EXPOSE 443
EXPOSE 4369
EXPOSE 5672
EXPOSE 25672
EXPOSE 15672
EXPOSE 15692
EXPOSE 50051
#Copy node src and run
WORKDIR /usr/src/app
COPY . .
CMD ["gunicorn", "main:app", "--workers", "1", "--worker-class", "uvicorn.workers.UvicornWorker", "--bind", "0.0.0.0:80"]
\ No newline at end of file
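The image serves the FastAPI app from main.py (below) with gunicorn on port 80. A quick smoke test against a running container, as a hedged sketch; publishing container port 80 to host port 8080 is an assumption, not something the Dockerfile itself does:

# Illustrative smoke test; assumes the container was started with the
# hypothetical host mapping -p 8080:80.
import urllib.request

with urllib.request.urlopen("http://localhost:8080/", timeout=5) as resp:
    print(resp.status, resp.read().decode())  # the root route returns "OK"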
hydra_node/config.py
@@ -10,10 +10,13 @@ from dotmap import DotMap
 from icecream import ic
 from sentry_sdk import capture_exception
 from sentry_sdk.integrations.threading import ThreadingIntegration
+from hydra_node.models import StableDiffusionModel

 def init_config_model():
     config = DotMap()
-    config.model_type = "GPT"
+    config.dtype = os.getenv("DTYPE", "float16")
+    config.device = os.getenv("DEVICE", "cuda")
+    config.amp = os.getenv("AMP", False)
     is_dev = ""
     environment = "production"
     if os.environ['DEV'] == "True":
@@ -73,7 +76,7 @@ def init_config_model():
     load_time = time.time()
     try:
-        model = GPTModel(config)
+        model = StableDiffusionModel(config)
     except Exception as e:
         ic(e)
         capture_exception(e)
...
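For reference, a minimal sketch of exercising the new environment-driven settings; it assumes the hydra_node package is importable and that the remaining variables from the launch script at the bottom (MODEL, MODEL_PATH, DEV, SENTRY_URL) are already exported:

# Illustrative only: DTYPE, DEVICE and AMP are the settings added in the diff above.
import os

os.environ["DTYPE"] = "float16"  # weights are cast to torch.float16 in models.py
os.environ["DEVICE"] = "cuda"    # device the checkpoint is moved to
os.environ["AMP"] = "1"          # read as a string and passed to torch.autocast(enabled=...)

from hydra_node.config import init_config_model

model, config = init_config_model()
config.logger.info(f"model ready on {config.device} ({config.dtype})")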
hydra_node/models.py
@@ -6,34 +6,55 @@ from omegaconf import OmegaConf
 from dotmap import DotMap
 import numpy as np
 import base64
-from torch import autocast
 from einops import rearrange
 from torchvision.utils import make_grid
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
+import time
+
+def no_init(loading_code):
+    def dummy(self):
+        return
+
+    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
+    original = {}
+    for mod in modules:
+        original[mod] = mod.reset_parameters
+        mod.reset_parameters = dummy
+
+    result = loading_code()
+    for mod in modules:
+        mod.reset_parameters = original[mod]
+
+    return result
+
 class StableDiffusionModel(nn.Module):
     def __init__(self, config):
         nn.Module.__init__(self)
-        model, model_config = self.from_folder(config.model_path)
-        self.model = model
-        self.config = model_config
+        self.config = config
+        model, model_config = no_init(lambda: self.from_folder(config.model_path))
+        if config.dtype == "float16":
+            typex = torch.float16
+        else:
+            typex = torch.float32
+        self.model = model.to(config.device).to(typex)
+        self.model_config = model_config
         self.plms = PLMSSampler(model)
         self.ddim = DDIMSampler(model)

-    @staticmethod
     def from_folder(self, folder):
         folder = Path(folder)
         model_config = OmegaConf.load(folder / "config.yaml")
         model = self.load_model_from_config(model_config, folder / "model.ckpt")
         return model, model_config

-    @staticmethod
-    def load_model_from_config(self, config, ckpt, verbose=True):
-        print(f"Loading model from {ckpt}")
+    def load_model_from_config(self, config, ckpt, verbose=False):
+        self.config.logger.info(f"Loading model from {ckpt}")
         pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
+            self.config.logger.info(f"Global Step: {pl_sd['global_step']}")
         sd = pl_sd["state_dict"]
         model = instantiate_from_config(config.model)
         m, u = model.load_state_dict(sd, strict=False)
@@ -50,6 +71,10 @@ class StableDiffusionModel(nn.Module):
     @torch.no_grad()
     def sample(self, request):
         request = DotMap(request)
+        if request.seed:
+            torch.manual_seed(request.seed)
+            np.random.seed(request.seed)
         if request.plms:
             sampler = self.plms
         else:
@@ -76,6 +101,8 @@ class StableDiffusionModel(nn.Module):
             request.height // request.downsampling_factor,
             request.width // request.downsampling_factor
         ]
+        with torch.autocast("cuda", enabled=self.config.amp):
+            with self.model.ema_scope():
                 samples, _ = sampler.sample(
                     S=request.steps,
                     conditioning=prompt_condition,
...
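The no_init helper above is a load-time optimization: it temporarily replaces reset_parameters on common layer types so their weights are not initialized during construction, which is wasted work when load_state_dict immediately overwrites them from the checkpoint. A small hedged sketch of the same pattern (assumes hydra_node and its ldm dependencies are importable; otherwise the helper can be copied out standalone):

# Illustrative use of no_init(): build a layer without default weight init,
# then fill it from a (dummy) state dict, as a checkpoint load would.
import torch
import torch.nn as nn
from hydra_node.models import no_init

layer = no_init(lambda: nn.Linear(4096, 4096))  # weights left as allocated, uninitialized memory
state = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
layer.load_state_dict(state)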
hydra_node/sanitize.py
from dotmap import DotMap

v1pp_defaults = {
    'steps': 50,
    'plms': True,
    'fixed_code': False,
    'ddim_eta': 0.0,
    'height': 512,
    'width': 512,
    'latent_channels': 4,
    'downsampling_factor': 8,
    'scale': 7.0,
    'dynamic_threshold': None,
    'seed': None,
}

v1pp_forced_defaults = {
    'latent_channels': 4,
    'downsampling_factor': 8,
}

defaults = {
    'v1pp': (v1pp_defaults, v1pp_forced_defaults),
}

def sanitize_input(request):
    """
    Sanitize the input data and set defaults
    """
    request = DotMap(request)
    default = defaults[request.model]
    default, forced_default = default

    for k, v in default.items():
        if k not in request:
            request[k] = v

    for k, v in forced_default.items():
        request[k] = v

    if request.width * request.height == 0:
        return False, "width and height must be non-zero"
    if request.width <= 0:
        return False, "width must be positive"
    if request.height <= 0:
        return False, "height must be positive"
    if request.steps <= 0:
        return False, "steps must be positive"
    if request.ddim_eta < 0:
        return False, "ddim_eta shouldn't be negative"
    if request.scale < 1.0:
        return False, "scale should be at least 1.0"
    if request.dynamic_threshold is not None and request.dynamic_threshold < 0:
        return False, "dynamic_threshold shouldn't be negative"
    if request.width * request.height >= 2048*512:
        return False, "width*height must be less than 2048*512"

    return True, request
\ No newline at end of file
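A quick hedged example of the sanitizer's contract: per-model defaults are filled in, forced defaults always win, and the caller gets back an (ok, result) tuple where result is either the normalized request or an error string:

# Illustrative call; 'v1pp' is the only model key defined in defaults above.
from hydra_node.sanitize import sanitize_input

ok, result = sanitize_input({"model": "v1pp", "prompt": "a red fox", "width": 512, "height": 512})
if ok:
    print(result.steps, result.scale)  # 50 and 7.0, filled from v1pp_defaults
else:
    print("rejected:", result)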
main.py
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from pydantic import BaseModel
-from fastapi.responses import HTMLResponse
+from fastapi.responses import HTMLResponse, PlainTextResponse
 from sentry_sdk import capture_exception
 from sentry_sdk import capture_message
 from sentry_sdk import start_transaction
 from hydra_node.config import init_config_model
-from typing import Optional
+from typing import Optional, List
+import socket
+from hydra_node.sanitize import sanitize_input
+import uvicorn
+from typing import Union
+import time
+import gc
+import os
+import signal
+
+#Initialize model and config
+model, config = init_config_model()
+logger = config.logger
+hostname = socket.gethostname()
+sent_first_message = False
+
+#Initialize fastapi
 app = FastAPI()

 @app.on_event("startup")
 def startup_event():
-    print("Startup")
-    #model, config = init_config()
+    logger.info("FastAPI Started, serving")

 @app.on_event("shutdown")
 def shutdown_event():
-    print('Shutdown')
+    logger.info("FastAPI Shutdown, exiting")

-@app.get("/", response_class=HTMLResponse)
+@app.get("/", response_class=PlainTextResponse)
 def root():
     return "OK"

 class GenerationRequest(BaseModel):
     prompt: str
     n_samples: int = 1
-    steps: int = None
-    plms: bool = None
-    fixed_code: bool = None
-    ddim_eta: float = None
-    height: int = None
-    width: int = None
-    latent_channels: int = None
-    downsampling_factor: int = None
-    scale: float = None
+    steps: int = 50
+    plms: bool = True
+    fixed_code: bool = False
+    ddim_eta: float = 0.0
+    height: int
+    width: int
+    latent_channels: int = 4
+    downsampling_factor: int = 8
+    scale: float = 7.0
     dynamic_threshold: float = None
+    make_grid: bool = False
+    n_rows: int = None
     seed: int = None

 class GenerationOutput(BaseModel):
-    generation: str
+    generation: List[str]
+    version: int = 1
+
+class ErrorOutput(BaseModel):
+    error: str

-@app.post('/generate', response_model=GenerationOutput)
+@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
 def generate(request: GenerationRequest):
+    request = request.dict()
     print(request)
-    return {"generation": "Hello World"}
\ No newline at end of file
+    try:
+        output = sanitize_input(request)
+        if output[0]:
+            request = output[1]
+        else:
+            return {'error': output[1]}
+        images = model.sample(request)
+        return {"generation": images}
+    except Exception as e:
+        capture_exception(e)
+        logger.error(str(e))
+        e_s = str(e)
+        gc.collect()
+        if "CUDA out of memory" in e_s or \
+                "an illegal memory access" in e_s or "CUDA" in e_s:
+            logger.error("GPU error, committing seppuku.")
+            os.kill(os.getpid(), signal.SIGTERM)
+        return {"error": str(e)}
+
+@app.middleware("http")
+async def handle_logging_and_errors(request: Request, call_next):
+    global sent_first_message
+    t = time.perf_counter()
+    response = await call_next(request)
+    process_time = time.perf_counter() - t
+    response.headers["X-Process-Time"] = str(process_time)
+    logger.info(f"Request took {process_time:.3f} seconds")
+    if not sent_first_message:
+        f = open("/tmp/health_readiness", "w")
+        f.close()
+        sent_first_message = True
+    if os.environ['DEV'] == "False":
+        f = open("/tmp/healthy", "w")
+        f.close()
+    return response
+
+if __name__ == "__main__":
+    uvicorn.run("main:app", host="127.0.0.1", port=8000, log_level="info")
\ No newline at end of file
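The middleware above exposes health through files rather than an HTTP route: /tmp/health_readiness is created after the first request is served, and /tmp/healthy is rewritten on every request when DEV is "False". A hypothetical external probe (file names come from the code above, the freshness threshold is an assumption) could poll them like this:

# Hypothetical supervisor-side check for the files written by the middleware.
import os
import time

def node_is_ready(path="/tmp/health_readiness"):
    # created once the node has served its first request
    return os.path.exists(path)

def node_is_healthy(path="/tmp/healthy", max_age_seconds=300):
    # rewritten on every request outside DEV mode; treat a stale file as unhealthy
    return os.path.exists(path) and (time.time() - os.path.getmtime(path)) < max_age_seconds

print(node_is_ready(), node_is_healthy())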
export DTYPE="float16"
export MODEL="test"
export DEV="True"
export MODEL_PATH="/home/xuser/diffusionstorage/workspace/kuru/stablediff/v1pp-flatline-pruned"
export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448"
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:80
\ No newline at end of file
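With the node started by the gunicorn command above (listening on 0.0.0.0:80), the /generate route can be exercised from a client. A hedged sketch; the payload fields mirror GenerationRequest and the requests package is an assumed extra dependency:

# Illustrative client call against a locally running node.
import requests

payload = {
    "prompt": "a watercolor painting of a fox",
    "width": 512,
    "height": 512,
    "steps": 50,
    "seed": 1234,
}
resp = requests.post("http://localhost/generate", json=payload, timeout=600)
print(resp.status_code)
print(list(resp.json().keys()))  # "generation" on success, "error" otherwise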