Skip to content

Commit bd2d9d0

Browse files
authored
Fix matryoshka norm loss (#9773)
* [Trainer] update sequence parallel (#9757) * update emb doc * update register_sequence_parallel_allreduce_hooks * update fuse_sequence_parallel_allreduce * fix matryoshka
1 parent 2e62501 commit bd2d9d0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddlenlp/transformers/contrastive_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ def forward(self, q_reps, p_reps):
5252
if len(self.embedding_matryoshka_dims) > 0:
5353
loss = 0.0
5454
for dim in self.embedding_matryoshka_dims:
55-
reduced_q_reps = q_reps[:, :dim]
55+
reduced_q_reps = q_reps[:, :dim].astype("float32")
5656
reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1)
5757

58-
reduced_p_reps = p_reps[:, :dim]
58+
reduced_p_reps = p_reps[:, :dim].astype("float32")
5959
reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1)
6060

6161
dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps)

0 commit comments

Comments
 (0)