
Commit 3579d3d

fix npu nn.Pad2D()
1 parent c7395fb

File tree

1 file changed: 8 additions & 2 deletions


paddlenlp/transformers/generation_utils.py

Lines changed: 8 additions & 2 deletions
@@ -409,8 +409,14 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
             if convert_dtype(attention_mask.dtype) == "bool":
                 attention_mask = paddle.cast(attention_mask, "int64")
             if len(attention_mask.shape) == 4:
-                attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
-                attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)
+                cur_device = paddle.get_device()
+                if cur_device.split(":")[0] == "npu":
+                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(attention_mask)
+                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
+                else:
+                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
+                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)
+
             dtype = convert_dtype(attention_mask.dtype)
             if "int" in dtype:
                 attention_mask[:, :, -1, -1] = 1
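Context for the change: during decoding this code grows the 4-D attention mask by one step, adding a row for the new query (replicated from the last row) and a column for the new key position (filled with the -1e4 masked sentinel). The commit routes NPU devices through plain zero-filled constant padding instead, presumably because the NPU Pad2D kernel does not handle replicate mode or the -1e4 fill; the commit message says only "fix npu nn.Pad2D()", so that reading is an inference. Below is a minimal sketch of the default (non-NPU) branch on a toy causal mask; the shapes and the tril setup are illustrative assumptions, not taken from the commit.

    # Sketch of how the two Pad2D calls extend a [batch, heads, q_len, k_len] mask.
    import paddle
    import paddle.nn as nn

    # Toy causal mask; bool is cast to int64 upstream because nn.Pad2D
    # does not accept bool input. 1 = attend, 0 = masked.
    mask = paddle.tril(paddle.ones([1, 1, 3, 3], dtype="int64"))

    # Pad2D padding order is [left, right, top, bottom] over the last two dims.
    # Add a bottom row for the newly generated query, copied from the last row.
    mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(mask)
    # Add a rightmost column for the new key position, filled with the
    # -1e4 masked sentinel the original code uses.
    mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(mask)
    # Let the new token attend to itself (the "int" dtype branch after the hunk).
    mask[:, :, -1, -1] = 1
    print(mask.shape)  # [1, 1, 4, 4]

On NPU the commit produces the same shape but fills both the new row and the new column with zeros, leaving only the final [-1, -1] = 1 assignment to mark the new position as attendable.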
