Commit e89f1b53 authored by novelailab's avatar novelailab

per model sanitization

parent 13cc3ba3
...@@ -19,16 +19,27 @@ v1pp_forced_defaults = { ...@@ -19,16 +19,27 @@ v1pp_forced_defaults = {
'downsampling_factor': 8, 'downsampling_factor': 8,
} }
dalle_mini_defaults = {
'temp': 1.0,
'top_k': 256,
'scale': 16,
'grid_size': 4,
}
dalle_mini_forced_defaults = {
}
defaults = { defaults = {
'v1pp': (v1pp_defaults, v1pp_forced_defaults), 'stable-diffusion': (v1pp_defaults, v1pp_forced_defaults),
'dalle-mini': (dalle_mini_defaults, dalle_mini_forced_defaults),
} }
def sanitize_input(request): def sanitize_input(config, request):
""" """
Sanitize the input data and set defaults Sanitize the input data and set defaults
""" """
request = DotMap(request) request = DotMap(request)
default, forced_default = v1pp_defaults, v1pp_forced_defaults default, forced_default = defaults[config.model]
for k, v in default.items(): for k, v in default.items():
if k not in request: if k not in request:
request[k] = v request[k] = v
......
...@@ -65,6 +65,7 @@ class GenerationRequest(BaseModel): ...@@ -65,6 +65,7 @@ class GenerationRequest(BaseModel):
seed: int = None seed: int = None
temp: float = 1.0 temp: float = 1.0
top_k: int = 256 top_k: int = 256
grid_size: int = 4
class GenerationOutput(BaseModel): class GenerationOutput(BaseModel):
output: List[str] output: List[str]
...@@ -77,7 +78,7 @@ def generate(request: GenerationRequest): ...@@ -77,7 +78,7 @@ def generate(request: GenerationRequest):
t = time.perf_counter() t = time.perf_counter()
print(request) print(request)
try: try:
output = sanitize_input(request) output = sanitize_input(config, request)
if output[0]: if output[0]:
request = output[1] request = output[1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment