Commit 424ba3ef authored by novelailab's avatar novelailab

generate should work

parent 767c1eed
from re import S
import base64

import torch
import torch.nn as nn
from pathlib import Path
from omegaconf import OmegaConf
from dotmap import DotMap
import numpy as np
from einops import rearrange
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
...@@ -45,11 +50,55 @@ class StableDiffusionModel(nn.Module): ...@@ -45,11 +50,55 @@ class StableDiffusionModel(nn.Module):
@torch.no_grad()
def sample(self, request):
    """Run diffusion sampling for *request* and return base64-encoded images.

    Parameters
    ----------
    request : dict-like
        Wrapped in a DotMap; fields read here: plms, fixed_code, n_samples,
        latent_channels, height, width, downsampling_factor, prompt, scale,
        steps, ddim_eta, dynamic_threshold.
        NOTE(review): assumes height/width are multiples of
        downsampling_factor — TODO confirm upstream validation.

    Returns
    -------
    list[str]
        One base64 (ASCII) string per generated sample, each encoding the
        raw HWC uint8 pixel buffer produced below.
    """
    request = DotMap(request)
    sampler = self.plms if request.plms else self.ddim

    # Optional fixed starting latent so repeated calls reproduce the output.
    start_code = None
    if request.fixed_code:
        start_code = torch.randn([
            request.n_samples,
            request.latent_channels,
            request.height // request.downsampling_factor,
            request.width // request.downsampling_factor,
        ], device=self.device)

    prompt = [request.prompt] * request.n_samples
    prompt_condition = self.model.get_learned_conditioning(prompt)

    # Unconditional conditioning for classifier-free guidance; skipped when
    # scale == 1.0 because guidance then has no effect.
    uc = None
    if request.scale != 1.0:
        uc = self.model.get_learned_conditioning(request.n_samples * [""])

    # Latent-space shape (per sample, without the batch dimension).
    shape = [
        request.latent_channels,
        request.height // request.downsampling_factor,
        request.width // request.downsampling_factor,
    ]
    samples, _ = sampler.sample(
        S=request.steps,
        conditioning=prompt_condition,
        batch_size=request.n_samples,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=request.scale,
        unconditional_conditioning=uc,
        eta=request.ddim_eta,
        dynamic_threshold=request.dynamic_threshold,
        x_T=start_code,
    )

    x_samples_ddim = self.model.decode_first_stage(samples)
    # Decoder output is in [-1, 1]; map to [0, 1] pixel range.
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

    images = []
    for x_sample in x_samples_ddim:
        pixels = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        # FIX: quantize to uint8 before serializing — the original encoded
        # raw float bytes, yielding an 8x-oversized, unrenderable buffer.
        raw = pixels.astype(np.uint8).tobytes()
        # FIX: decode to ASCII text — str(b64encode(...)) produced the
        # Python repr "b'...'", not valid base64.
        # FIX: dropped `base_count += 1`, which referenced an undefined
        # name and raised NameError on the first iteration.
        images.append(base64.b64encode(raw).decode("ascii"))
    return images
\ No newline at end of file
...@@ -33,6 +33,9 @@ class GenerationRequest(BaseModel): ...@@ -33,6 +33,9 @@ class GenerationRequest(BaseModel):
latent_channels: int = None latent_channels: int = None
downsampling_factor: int = None downsampling_factor: int = None
scale: float = None scale: float = None
dynamic_threshold: float = None
make_grid: bool = False
n_rows: int = None
seed: int = None seed: int = None
class GenerationOutput(BaseModel): class GenerationOutput(BaseModel):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment