Commit d41ec0fd authored by kurumuz

make base64 images work

parent 4249d339
...@@ -99,12 +99,6 @@ def decode_image(image, model): ...@@ -99,12 +99,6 @@ def decode_image(image, model):
image = custom_to_pil(image) image = custom_to_pil(image)
return image return image
def sanitize_image(image):
    """Open an image with PIL and return it converted to RGB.

    Converting to RGB drops any alpha channel. No scaling or cropping is
    done here; the caller resizes the returned image itself.

    Args:
        image: Whatever ``Image.open`` accepts. NOTE(review): presumably a
            path or a file-like object -- confirm against the caller.

    Returns:
        The image produced by ``convert('RGB')``.
    """
    # The context manager lets PIL release the image it opened as soon as
    # the converted image has been produced.
    with Image.open(image) as opened:
        return opened.convert('RGB')
class StableInterface(nn.Module): class StableInterface(nn.Module):
def __init__(self, model, thresholder = None): def __init__(self, model, thresholder = None):
super().__init__() super().__init__()
...@@ -186,9 +180,7 @@ class StableDiffusionModel(nn.Module): ...@@ -186,9 +180,7 @@ class StableDiffusionModel(nn.Module):
if request.image is not None: if request.image is not None:
request.sampler = "ddim_img2img" #enforce ddim for now request.sampler = "ddim_img2img" #enforce ddim for now
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False) self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
image = sanitize_image(request.image) start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
start_code = encode_image(image, self.model.first_stage_model).to(self.device)
start_code = self.model.get_first_stage_encoding(start_code) start_code = self.model.get_first_stage_encoding(start_code)
print(start_code.shape) print(start_code.shape)
start_code = start_code + (torch.randn_like(start_code) * request.noise) start_code = start_code + (torch.randn_like(start_code) * request.noise)
......
from dotmap import DotMap from dotmap import DotMap
import math import math
from io import BytesIO
v1pp_defaults = { v1pp_defaults = {
'steps': 50, 'steps': 50,
...@@ -91,6 +92,24 @@ def sanitize_stable_diffusion(request): ...@@ -91,6 +92,24 @@ def sanitize_stable_diffusion(request):
if request.sampler not in samplers: if request.sampler not in samplers:
return False, "sampler should be one of {}".format(samplers) return False, "sampler should be one of {}".format(samplers)
if request.image is not None:
#decode from base64
request.image = request.image.decode('base64')
#check if image is valid
try:
from PIL import Image
image = Image.open(BytesIO(request.image))
image.verify()
except Exception as e:
return False, "image is not valid"
#image is valid, load it again
image = Image.open(BytesIO(request.image))
image = image.convert('RGB')
image = image.resize((request.width, request.height), resample=Image.Resampling.LANCZOS)
request.image = image
return True, request return True, request
def sanitize_dalle_mini(request): def sanitize_dalle_mini(request):
...@@ -111,6 +130,6 @@ def sanitize_input(config, request): ...@@ -111,6 +130,6 @@ def sanitize_input(config, request):
if config.model_name == 'stable-diffusion': if config.model_name == 'stable-diffusion':
return sanitize_stable_diffusion(request) return sanitize_stable_diffusion(request)
elif config.model_name == 'dalle-mini': elif config.model_name == 'dalle-mini':
return sanitize_dalle_mini(request) return sanitize_dalle_mini(request)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment