1 parent c7395fb · commit 3579d3d
paddlenlp/transformers/generation_utils.py
@@ -409,8 +409,14 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
             if convert_dtype(attention_mask.dtype) == "bool":
                 attention_mask = paddle.cast(attention_mask, "int64")
             if len(attention_mask.shape) == 4:
-                attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
-                attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)
+                cur_device = paddle.get_device()
+                if cur_device.split(":")[0] == "npu":
+                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(attention_mask)
+                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
+                else:
+                    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
+                    attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)
+
             dtype = convert_dtype(attention_mask.dtype)
             if "int" in dtype:
                 attention_mask[:, :, -1, -1] = 1
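For context, the change routes the per-step attention-mask extension by device: on NPU the new row and column are added with constant padding and a fill value of 0, while other devices keep the original behavior of replicating the last row and filling the new column with -1e4. Below is a minimal sketch, not part of the commit, showing what each branch does to a 4D mask during one decoding step; the tiny 1x1x2x2 float32 mask and the variable name device_type are illustrative assumptions, while nn.Pad2D and paddle.get_device are the same APIs the diff uses.

# Minimal sketch; the mask shape/values below are illustrative assumptions.
import paddle
import paddle.nn as nn

# A tiny [batch, num_heads, seq_len, seq_len] causal mask after two decoded tokens.
attention_mask = paddle.to_tensor([[[[1.0, 0.0],
                                     [1.0, 1.0]]]])

device_type = paddle.get_device().split(":")[0]
if device_type == "npu":
    # NPU branch from the commit: constant padding, new entries filled with 0.
    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="constant")(attention_mask)
    attention_mask = nn.Pad2D([0, 1, 0, 0], value=0)(attention_mask)
else:
    # Original branch: replicate the last row, then pad the new column with -1e4.
    attention_mask = nn.Pad2D([0, 0, 0, 1], mode="replicate")(attention_mask)
    attention_mask = nn.Pad2D([0, 1, 0, 0], value=-1e4)(attention_mask)

print(attention_mask.shape)   # [1, 1, 3, 3]: the mask grew by one step
print(attention_mask.numpy())

The unchanged tail of the hunk then marks the newest position as attended for integer masks via attention_mask[:, :, -1, -1] = 1; the commit itself does not state why NPU needs the constant/0 variant, so that rationale is not reproduced here.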