Skip to content

Commit 43e7431

Browse files
drheadruchej
authored andcommitted
re-derive sqrt alpha bar and sqrt one minus alphabar
This is the only place these values are ever referenced outside of training code so this change is very justifiable and more consistent.
1 parent 483c27d commit 43e7431

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

modules/sd_samplers_timesteps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, model, *args, **kwargs):
3636
self.inner_model = model
3737

3838
def predict_eps_from_z_and_v(self, x_t, t, v):
39-
return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t
39+
return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t
4040

4141
def forward(self, input, timesteps, **kwargs):
4242
model_output = self.inner_model.apply_model(input, timesteps, **kwargs)

0 commit comments

Comments
 (0)