Commit 25eeeaa6 authored by drhead's avatar drhead Committed by GitHub

Allow refiner to be triggered by model timestep instead of sampling

parent 09d2e588
...@@ -156,7 +156,16 @@ replace_torchsde_browinan() ...@@ -156,7 +156,16 @@ replace_torchsde_browinan()
def apply_refiner(cfg_denoiser): def apply_refiner(cfg_denoiser):
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps if opts.refiner_switch_by_sample_steps:
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
else:
# torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch
try:
timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas - torch.max(sigma)))
except AttributeError: # for samplers that dont use sigmas (DDIM) sigma is actually the timestep
timestep = torch.max(sigma).to(dtype=int)
completed_ratio = (999 - timestep) / 1000
refiner_switch_at = cfg_denoiser.p.refiner_switch_at refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment