Commit d73b5fe0 authored by novelailab's avatar novelailab

seed fix

parent 898ffdc2
...@@ -110,6 +110,10 @@ class StableDiffusionModel(nn.Module): ...@@ -110,6 +110,10 @@ class StableDiffusionModel(nn.Module):
x_sample = np.ascontiguousarray(x_sample) x_sample = np.ascontiguousarray(x_sample)
images.append(x_sample) images.append(x_sample)
if request.seed is not None:
torch.seed()
np.random.seed()
return images return images
@torch.no_grad() @torch.no_grad()
...@@ -150,5 +154,11 @@ class DalleMiniModel(nn.Module): ...@@ -150,5 +154,11 @@ class DalleMiniModel(nn.Module):
images = images.to('cpu').numpy() images = images.to('cpu').numpy()
images = images.astype(np.uint8) images = images.astype(np.uint8)
images = np.ascontiguousarray(images) images = np.ascontiguousarray(images)
if request.seed is not None:
torch.seed()
np.random.seed()
return images return images
...@@ -77,7 +77,6 @@ class ErrorOutput(BaseModel): ...@@ -77,7 +77,6 @@ class ErrorOutput(BaseModel):
@app.post('/generate-stream') @app.post('/generate-stream')
def generate(request: GenerationRequest): def generate(request: GenerationRequest):
t = time.perf_counter() t = time.perf_counter()
print(request)
try: try:
output = sanitize_input(config, request) output = sanitize_input(config, request)
...@@ -118,10 +117,29 @@ def generate(request: GenerationRequest): ...@@ -118,10 +117,29 @@ def generate(request: GenerationRequest):
os.kill(mainpid, signal.SIGTERM) os.kill(mainpid, signal.SIGTERM)
return {"error": str(e)} return {"error": str(e)}
@app.post('/image-to-image')
def image_to_image(request: GenerationRequest):
#prompt is a base64 encoded image
try:
output = sanitize_input(config, request)
if output[0]:
request = output[1]
else:
return ErrorOutput(error=output[1])
image = base64.b64decode(request.prompt)
image = simplejpeg.decode_jpeg(image)
image = model.image_to_image(image, request)
image = simplejpeg.encode_jpeg(image, quality=95)
#get base64 of image
image = base64.b64encode(image).decode("ascii")
return GenerationOutput(output=[image])
@app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput]) @app.post('/generate', response_model=Union[GenerationOutput, ErrorOutput])
def generate(request: GenerationRequest): def generate(request: GenerationRequest):
t = time.perf_counter() t = time.perf_counter()
print(request)
try: try:
output = sanitize_input(config, request) output = sanitize_input(config, request)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment