Commit 767c1eed authored by novelailab's avatar novelailab

model almost working

parent 548f5aaa
import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import OmegaConf
from dotmap import DotMap
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
class StableDiffusionModel(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
model, model_config = self.from_folder(config.model_path)
self.model = model
self.config = model_config
self.plms = PLMSSampler(model)
self.ddim = DDIMSampler(model)
@staticmethod
def from_folder(self, folder):
folder = Path(folder)
model_config = OmegaConf.load(folder / "config.yaml")
model = self.load_model_from_config(model_config, folder / "model.ckpt")
return model, model_config
@staticmethod
def load_model_from_config(self, config, ckpt, verbose=True):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.eval()
return model
@torch.no_grad()
def sample(self, request):
request = DotMap(request)
...@@ -5,15 +5,16 @@ from sentry_sdk import capture_exception ...@@ -5,15 +5,16 @@ from sentry_sdk import capture_exception
from sentry_sdk import capture_message from sentry_sdk import capture_message
from sentry_sdk import start_transaction from sentry_sdk import start_transaction
from hydra_node.config import init_config_model from hydra_node.config import init_config_model
from typing import Optional
app = FastAPI() app = FastAPI()
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): def startup_event():
print("Startup") print("Startup")
#model, config = init_config() #model, config = init_config()
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event(): def shutdown_event():
print('Shutdown') print('Shutdown')
@app.get("/", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse)
...@@ -23,7 +24,16 @@ def root(): ...@@ -23,7 +24,16 @@ def root():
class GenerationRequest(BaseModel): class GenerationRequest(BaseModel):
prompt: str prompt: str
n_samples: int = 1 n_samples: int = 1
steps: int = 50 steps: int = None
plms: bool = None
fixed_code: bool = None
ddim_eta: float = None
height: int = None
width: int = None
latent_channels: int = None
downsampling_factor: int = None
scale: float = None
seed: int = None
class GenerationOutput(BaseModel): class GenerationOutput(BaseModel):
generation: str generation: str
...@@ -31,4 +41,7 @@ class GenerationOutput(BaseModel): ...@@ -31,4 +41,7 @@ class GenerationOutput(BaseModel):
@app.post('/generate', response_model=GenerationOutput) @app.post('/generate', response_model=GenerationOutput)
def generate(request: GenerationRequest): def generate(request: GenerationRequest):
request = request.dict()
print(request)
return {"generation": "Hello World"} return {"generation": "Hello World"}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment