Commit ea7e46e9 authored by gd1551's avatar gd1551 Committed by GitHub

feat: add strength to parameters (sd)

parent 75dac706
...@@ -218,7 +218,7 @@ class StableDiffusionModel(nn.Module): ...@@ -218,7 +218,7 @@ class StableDiffusionModel(nn.Module):
with self.model.ema_scope(): with self.model.ema_scope():
init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim)) init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(x_samples_ddim))
sampler.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False) sampler.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
t_enc = int(0.69 * request.steps) t_enc = int(request.strength * request.steps)
print("init latent shape:") print("init latent shape:")
print(init_latent.shape) print(init_latent.shape)
......
...@@ -59,6 +59,9 @@ def sanitize_stable_diffusion(request): ...@@ -59,6 +59,9 @@ def sanitize_stable_diffusion(request):
if request.width * request.height >= 1024*1025: if request.width * request.height >= 1024*1025:
return False, "width and height must be less than 1024*1025" return False, "width and height must be less than 1024*1025"
if request.strength < 0.0 or request.strength > 1.0:
return False, "strength should be more than 0.0 and less than 1.0"
return True, request return True, request
def sanitize_dalle_mini(request): def sanitize_dalle_mini(request):
......
...@@ -68,6 +68,7 @@ class GenerationRequest(BaseModel): ...@@ -68,6 +68,7 @@ class GenerationRequest(BaseModel):
top_k: int = 256 top_k: int = 256
grid_size: int = 4 grid_size: int = 4
advanced: bool = False advanced: bool = False
strength: bool = 0.69
class GenerationOutput(BaseModel): class GenerationOutput(BaseModel):
output: List[str] output: List[str]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment