Commit 76f08949 authored by novelailab's avatar novelailab

encode as jpeg with simplejpeg

parent 899da5fe
...@@ -20,6 +20,7 @@ RUN git clone https://github.com/NovelAI/stable-diffusion ...@@ -20,6 +20,7 @@ RUN git clone https://github.com/NovelAI/stable-diffusion
RUN pip3 install -e stable-diffusion/. RUN pip3 install -e stable-diffusion/.
RUN pip3 install pytorch_lightning RUN pip3 install pytorch_lightning
RUN pip3 install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers RUN pip3 install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
RUN pip3 install simplejpeg
#Open ports #Open ports
EXPOSE 8080 EXPOSE 8080
......
...@@ -122,9 +122,6 @@ class StableDiffusionModel(nn.Module): ...@@ -122,9 +122,6 @@ class StableDiffusionModel(nn.Module):
images = [] images = []
for x_sample in x_samples_ddim: for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.tobytes()
#get base64 of x_sample
x_sample = str(base64.b64encode(x_sample))
images.append(x_sample) images.append(x_sample)
return images return images
\ No newline at end of file
...@@ -14,6 +14,8 @@ import time ...@@ -14,6 +14,8 @@ import time
import gc import gc
import os import os
import signal import signal
import simplejpeg
import base64
#Initialize model and config #Initialize model and config
model, config = init_config_model() model, config = init_config_model()
...@@ -72,6 +74,15 @@ def generate(request: GenerationRequest): ...@@ -72,6 +74,15 @@ def generate(request: GenerationRequest):
return ErrorOutput(error=output[1]) return ErrorOutput(error=output[1])
images = model.sample(request) images = model.sample(request)
for x in range(images):
image = simplejpeg.encode_jpeg(images[x], quality=95)
image = image.tobytes()
#get base64 of image
image = str(base64.b64encode(image))
#remove b' from base64
image = image[2:-1]
image[x] = image
process_time = time.perf_counter() - t process_time = time.perf_counter() - t
logger.info(f"Request took {process_time:0.3f} seconds") logger.info(f"Request took {process_time:0.3f} seconds")
return GenerationOutput(generation=images) return GenerationOutput(generation=images)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment