Commit 75dac706 authored by novelailab's avatar novelailab

fix device and model

parent 74d503df
......@@ -69,6 +69,7 @@ class StableDiffusionModel(nn.Module):
else:
typex = torch.float32
self.model = model.to(config.device).to(typex)
self.device = config.device
self.model_config = model_config
self.plms = PLMSSampler(model)
self.ddim = DDIMSampler(model)
......
......@@ -89,7 +89,7 @@ def generate(request: GenerationRequest):
if request.advanced:
if request.n_samples > 1:
return ErrorOutput(error="advanced mode does not support n_samples > 1")
images = model.sample_two_stages(request)
else:
images = model.sample(request)
......@@ -137,7 +137,14 @@ def generate_advanced(request: GenerationRequest):
else:
return ErrorOutput(error=output[1])
images = model.sample_two_stages(request)
if request.advanced:
if request.n_samples > 1:
return ErrorOutput(error="advanced mode does not support n_samples > 1")
images = model.sample_two_stages(request)
else:
images = model.sample(request)
images_encoded = []
for x in range(len(images)):
image = simplejpeg.encode_jpeg(images[x], quality=95)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment