Skip to content

Commit a6b37f9

Browse files
author
gongel
committed
fix windows bug
1 parent bb88615 commit a6b37f9

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

paddlenlp/transformers/pegasus/modeling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
4242
if pad_token_id is None:
4343
raise ValueError("self.model.config.pad_token_id has to be defined.")
4444

45-
shifted_input_ids = paddle.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
45+
shifted_input_ids = paddle.where(
46+
shifted_input_ids == -100, paddle.full_like(shifted_input_ids, pad_token_id), shifted_input_ids
47+
)
4648
return shifted_input_ids
4749

4850

0 commit comments

Comments
 (0)