Commit 3a215def authored by drhead's avatar drhead Committed by GitHub

vectorize kl-optimal sigma calculation

Co-authored-by: default avatarmamei16 <marcel.1710@live.de>
parent 83266205
...@@ -34,9 +34,8 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device): ...@@ -34,9 +34,8 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
def kl_optimal(n, sigma_min, sigma_max, device): def kl_optimal(n, sigma_min, sigma_max, device):
alpha_min = torch.arctan(torch.tensor(sigma_min, device=device)) alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))
alpha_max = torch.arctan(torch.tensor(sigma_max, device=device)) alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))
sigmas = torch.empty((n+1,), device=device) step_indices = torch.arange(n + 1, device=device)
for i in range(n+1): sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
sigmas[i] = torch.tan((i/n) * alpha_min + (1.0-i/n) * alpha_max)
return sigmas return sigmas
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment