Commit d22bf3c7 authored by kurumuz's avatar kurumuz

img2img k samplers

parent 847474c9
...@@ -182,7 +182,9 @@ class StableDiffusionModel(nn.Module): ...@@ -182,7 +182,9 @@ class StableDiffusionModel(nn.Module):
np.random.seed(request.seed) np.random.seed(request.seed)
if request.image is not None: if request.image is not None:
request.sampler = "ddim_img2img" #enforce ddim for now #request.sampler = "ddim_img2img" #enforce ddim for now
if request.sampler == "plms":
request.sampler = "k_lms"
self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False) self.ddim.make_schedule(ddim_num_steps=request.steps, ddim_eta=request.ddim_eta, verbose=False)
start_code = encode_image(request.image, self.model.first_stage_model).to(self.device) start_code = encode_image(request.image, self.model.first_stage_model).to(self.device)
start_code = self.model.get_first_stage_encoding(start_code) start_code = self.model.get_first_stage_encoding(start_code)
...@@ -244,7 +246,14 @@ class StableDiffusionModel(nn.Module): ...@@ -244,7 +246,14 @@ class StableDiffusionModel(nn.Module):
elif sampler == "k-diffusion": elif sampler == "k-diffusion":
with self.model.ema_scope(): with self.model.ema_scope():
sigmas = self.k_model.get_sigmas(request.steps) sigmas = self.k_model.get_sigmas(request.steps)
start_code = start_code * sigmas[0] if request.image is not None:
noise = torch.randn_like(start_code) * sigmas[request.steps - t_enc - 1]
start_code = start_code + noise
sigmas = sigmas[request.steps - t_enc - 1:]
else:
start_code = start_code * sigmas[0]
extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale} extra_args = {'cond': prompt_condition, 'uncond': uc, 'cond_scale': request.scale}
samples = self.sampler_map[request.sampler](self.k_model, start_code, sigmas, extra_args=extra_args) samples = self.sampler_map[request.sampler](self.k_model, start_code, sigmas, extra_args=extra_args)
......
...@@ -2,7 +2,7 @@ export DTYPE="float32" ...@@ -2,7 +2,7 @@ export DTYPE="float32"
export AMP="1" export AMP="1"
export MODEL="stable-diffusion" export MODEL="stable-diffusion"
export DEV="True" export DEV="True"
export MODEL_PATH="/home/xuser/nvme1/workspace/aero/stable/sdfinetune/checkpoints/kuru30k" export MODEL_PATH="/home/xuser/nvme1/stableckpt/v13"
export TRANSFORMERS_CACHE="/home/xuser/nvme1/transformer_cache" export TRANSFORMERS_CACHE="/home/xuser/nvme1/transformer_cache"
export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448" export SENTRY_URL="https://49ca8adcf4444f82a10eae1b3fd4182f@o846434.ingest.sentry.io/6612448"
gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315 gunicorn main:app --workers 1 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:4315
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment