Commit b8d23a1

fix npu nn.Pad2D() (#5167)
1 parent 86c8848 commit b8d23a1

paddlenlp/transformers/generation_utils.py

Lines changed: 8 additions & 2 deletions
@@ -409,8 +409,14 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
             if convert_dtype(attention_mask.dtype) == "bool":
                 attention_mask = paddle.cast(attention_mask, "int64")
             if len(attention_mask.shape) == 4:
-                attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
-                attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)
+                cur_device = paddle.get_device()
+                if cur_device.split(":")[0] == "npu":
+                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(attention_mask)
+                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
+                else:
+                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
+                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)
+
                 dtype = convert_dtype(attention_mask.dtype)
                 if "int" in dtype:
                     attention_mask[:, :, -1, -1] = 1
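
For reference, paddle.get_device() returns a string such as "cpu", "gpu:0", or "npu:0", so cur_device.split(":")[0] gives the device type that the new branch tests. The sketch below replays the patched logic on a toy 4-D attention mask outside of generation_utils.py; the tensor shape and values are illustrative only, and the reason given for the NPU branch (replicate-mode Pad2D misbehaving on NPU) is inferred from the commit title rather than spelled out in the diff.

import paddle
import paddle.nn as nn

# Toy 4-D attention mask [batch, num_heads, seq_len, seq_len]; values are illustrative.
# Pad2D padding is given as [pad_left, pad_right, pad_top, pad_bottom] for NCHW input.
attention_mask = paddle.ones([1, 1, 3, 3], dtype="float32")

cur_device = paddle.get_device()  # e.g. "cpu", "gpu:0", or "npu:0"
if cur_device.split(":")[0] == "npu":
    # NPU path from the commit: add the new bottom row with constant padding
    # (default value 0.0) and the new right column with value 0, presumably
    # because replicate-mode padding is not reliable on NPU.
    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(attention_mask)
    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
else:
    # Original path: replicate the last row, then pad the new right column with
    # a large negative value so the appended key position stays masked.
    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
    attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)

print(attention_mask.shape)  # [1, 1, 4, 4]: one extra row and one extra column

In the surrounding function the padded corner is then reset (attention_mask[:, :, -1, -1] = 1 when the dtype is integer), so the newly generated position can attend to itself regardless of which padding path was taken.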
