Skip to content

Commit 352fc7a

Browse files
committed
fix sorted_indices cast
1 parent e227782 commit 352fc7a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

paddlenlp/generation/logits_process.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,8 @@ def TopPProcess(probs: paddle.Tensor, top_p: float, min_tokens_to_keep: int):
312312
sorted_indices = paddle.argsort(probs, descending=True)
313313
sorted_probs = paddle.sort(probs, descending=True)
314314

315-
sorted_probs = paddle.cast(sorted_probs, paddle.bfloat16)
316-
sorted_indices = paddle.cast(sorted_indices, paddle.int64)
315+
sorted_probs = paddle.cast(sorted_probs, paddle.bfloat16)
316+
317317
else:
318318
sorted_indices = paddle.argsort(probs, descending=True)
319319
sorted_probs = paddle.sort(probs, descending=True)

0 commit comments

Comments
 (0)