
Commit 3335e22

[FIX] Bug in FluxPosEmbed (#10115)

Authored by SahilCarterr and hlky

* Fix get_1d_rotary_pos_embed in embeddings.py
* Update embeddings.py

Co-authored-by: hlky <[email protected]>

1 parent 65ab105 · commit 3335e22

File tree

1 file changed: +6 −1 lines changed

src/diffusers/models/embeddings.py

Lines changed: 6 additions & 1 deletion
@@ -959,7 +959,12 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
         freqs_dtype = torch.float32 if is_mps else torch.float64
         for i in range(n_axes):
             cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
             )
             cos_out.append(cos)
             sin_out.append(sin)
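
The substance of the fix is the added theta=self.theta: before it, get_1d_rotary_pos_embed silently fell back to its default base (10000.0), ignoring whatever theta a FluxPosEmbed was constructed with. Below is a minimal sketch of why the dropped argument matters; rotary_1d is a hypothetical stand-in for the real diffusers function, assuming real-valued output with interleaved repetition as in the call above.

import torch

def rotary_1d(dim: int, pos: torch.Tensor, theta: float = 10000.0):
    # Hypothetical stand-in for get_1d_rotary_pos_embed with use_real=True
    # and repeat_interleave_real=True. Frequencies fall off geometrically
    # with base `theta`: freq_k = 1 / theta**(2k / dim).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angles = torch.outer(pos.to(torch.float64), freqs)  # (num_pos, dim // 2)
    # Each frequency is repeated twice so cos/sin align with (x0, x1) pairs.
    cos = angles.cos().repeat_interleave(2, dim=-1)
    sin = angles.sin().repeat_interleave(2, dim=-1)
    return cos.float(), sin.float()

pos = torch.arange(1, 9, dtype=torch.float32)
cos_default, _ = rotary_1d(16, pos)               # base theta = 10000
cos_custom, _ = rotary_1d(16, pos, theta=500.0)   # non-default base
# Before the fix, a FluxPosEmbed constructed with theta=500 still produced
# the theta=10000 tables, because `theta` was never forwarded to the helper.
print(torch.allclose(cos_default, cos_custom))  # False

Since both tables are valid rotary embeddings, the bug produced no error, only silently wrong frequencies for any model using a non-default theta.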
