Commit c7f34c27 authored by kurumuz's avatar kurumuz

add sanitization for new settings

parent 20a8dbeb
...@@ -3,7 +3,8 @@ import math ...@@ -3,7 +3,8 @@ import math
v1pp_defaults = { v1pp_defaults = {
'steps': 50, 'steps': 50,
'plms': True, 'sampler': "plms",
'image': None,
'fixed_code': False, 'fixed_code': False,
'ddim_eta': 0.0, 'ddim_eta': 0.0,
'height': 512, 'height': 512,
...@@ -35,6 +36,17 @@ defaults = { ...@@ -35,6 +36,17 @@ defaults = {
'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults), 'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
} }
samplers = [
"plms",
"ddim",
"k_euler",
"k_euler_ancestral",
"k_heun",
"k_dpm_2",
"k_dpm_2_ancestral",
"k_lms"
]
def closest_multiple(num, mult): def closest_multiple(num, mult):
num_int = int(num) num_int = int(num)
floor = math.floor(num_int / mult) * mult floor = math.floor(num_int / mult) * mult
...@@ -76,6 +88,9 @@ def sanitize_stable_diffusion(request): ...@@ -76,6 +88,9 @@ def sanitize_stable_diffusion(request):
request.width = closest_multiple(request.width // 2, 64) request.width = closest_multiple(request.width // 2, 64)
request.height = closest_multiple(request.height // 2, 64) request.height = closest_multiple(request.height // 2, 64)
if request.sampler not in samplers:
return False, "sampler should be one of {}".format(samplers)
return True, request return True, request
def sanitize_dalle_mini(request): def sanitize_dalle_mini(request):
......
...@@ -52,9 +52,10 @@ def root(): ...@@ -52,9 +52,10 @@ def root():
class GenerationRequest(BaseModel): class GenerationRequest(BaseModel):
prompt: str prompt: str
image: str = None
n_samples: int = 1 n_samples: int = 1
steps: int = 50 steps: int = 50
plms: bool = True sampler: str = "plms"
fixed_code: bool = False fixed_code: bool = False
ddim_eta: float = 0.0 ddim_eta: float = 0.0
height: int = 512 height: int = 512
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment