Skip to content

Commit 336c241

Browse files
committed
Update attention_mask padding to support 3-D masks of shape [1, seq_len, seq_len]
1 parent 23d94f8 commit 336c241

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

paddlenlp/transformers/tokenizer_utils_base.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3189,8 +3189,16 @@ def _pad(
31893189

31903190
if self.padding_side == "right":
31913191
if return_attention_mask:
3192-
3193-
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
3192+
if len(encoded_inputs["attention_mask"].shape) > 2:
3193+
# 3-D attention_mask of shape [1, seq_len, seq_len]: zero-pad the end of both seq_len axes
3194+
encoded_inputs["attention_mask"] = np.pad(
3195+
encoded_inputs["attention_mask"],
3196+
pad_width=[(0, 0), (0, difference), (0, difference)],
3197+
mode="constant",
3198+
constant_values=0,
3199+
)
3200+
else:
3201+
encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
31943202
if "token_type_ids" in encoded_inputs:
31953203
encoded_inputs["token_type_ids"] = (
31963204
encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
@@ -3209,7 +3217,16 @@ def _pad(
32093217
encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
32103218
elif self.padding_side == "left":
32113219
if return_attention_mask:
3212-
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
3220+
if len(encoded_inputs["attention_mask"].shape) > 2:
3221+
# 3-D attention_mask of shape [1, seq_len, seq_len]: zero-pad the start of both seq_len axes
3222+
encoded_inputs["attention_mask"] = np.pad(
3223+
encoded_inputs["attention_mask"],
3224+
pad_width=[(0, 0), (difference, 0), (difference, 0)],
3225+
mode="constant",
3226+
constant_values=0,
3227+
)
3228+
else:
3229+
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
32133230
if "token_type_ids" in encoded_inputs:
32143231
encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
32153232
"token_type_ids"

0 commit comments

Comments
 (0)