Commit 75dac706 authored by novelailab's avatar novelailab

fix device and model

parent 74d503df
...@@ -69,6 +69,7 @@ class StableDiffusionModel(nn.Module): ...@@ -69,6 +69,7 @@ class StableDiffusionModel(nn.Module):
else: else:
typex = torch.float32 typex = torch.float32
self.model = model.to(config.device).to(typex) self.model = model.to(config.device).to(typex)
self.device = config.device
self.model_config = model_config self.model_config = model_config
self.plms = PLMSSampler(model) self.plms = PLMSSampler(model)
self.ddim = DDIMSampler(model) self.ddim = DDIMSampler(model)
......
...@@ -89,7 +89,7 @@ def generate(request: GenerationRequest): ...@@ -89,7 +89,7 @@ def generate(request: GenerationRequest):
if request.advanced: if request.advanced:
if request.n_samples > 1: if request.n_samples > 1:
return ErrorOutput(error="advanced mode does not support n_samples > 1") return ErrorOutput(error="advanced mode does not support n_samples > 1")
images = model.sample_two_stages(request) images = model.sample_two_stages(request)
else: else:
images = model.sample(request) images = model.sample(request)
...@@ -137,7 +137,14 @@ def generate_advanced(request: GenerationRequest): ...@@ -137,7 +137,14 @@ def generate_advanced(request: GenerationRequest):
else: else:
return ErrorOutput(error=output[1]) return ErrorOutput(error=output[1])
images = model.sample_two_stages(request) if request.advanced:
if request.n_samples > 1:
return ErrorOutput(error="advanced mode does not support n_samples > 1")
images = model.sample_two_stages(request)
else:
images = model.sample(request)
images_encoded = [] images_encoded = []
for x in range(len(images)): for x in range(len(images)):
image = simplejpeg.encode_jpeg(images[x], quality=95) image = simplejpeg.encode_jpeg(images[x], quality=95)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment