Commit 5381405e authored by drhead's avatar drhead Committed by GitHub

re-derive sqrt alpha bar and sqrt one minus alphabar

This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent.
parent 78acdcf6
...@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): ...@@ -36,7 +36,7 @@ class CompVisTimestepsVDenoiser(torch.nn.Module):
self.inner_model = model self.inner_model = model
def predict_eps_from_z_and_v(self, x_t, t, v): def predict_eps_from_z_and_v(self, x_t, t, v):
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
def forward(self, input, timesteps, **kwargs): def forward(self, input, timesteps, **kwargs):
model_output = self.inner_model.apply_model(input, timesteps, **kwargs) model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment