@@ -3189,8 +3189,16 @@ def _pad(
 
             if self.padding_side == "right":
                 if return_attention_mask:
-
-                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+                    if len(encoded_inputs["attention_mask"].shape) > 2:
+                        # attention_mask shape: [1, seq_len, seq_len]
+                        encoded_inputs["attention_mask"] = np.pad(
+                            encoded_inputs["attention_mask"],
+                            pad_width=[(0, 0), (0, difference), (0, difference)],
+                            mode="constant",
+                            constant_values=0,
+                        )
+                    else:
+                        encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
                 if "token_type_ids" in encoded_inputs:
                     encoded_inputs["token_type_ids"] = (
                         encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
@@ -3209,7 +3217,16 @@ def _pad(
                 encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
             elif self.padding_side == "left":
                 if return_attention_mask:
-                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+                    if len(encoded_inputs["attention_mask"].shape) > 2:
+                        # attention_mask shape: [1, seq_len, seq_len]
+                        encoded_inputs["attention_mask"] = np.pad(
+                            encoded_inputs["attention_mask"],
+                            pad_width=[(0, 0), (difference, 0), (difference, 0)],
+                            mode="constant",
+                            constant_values=0,
+                        )
+                    else:
+                        encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
                 if "token_type_ids" in encoded_inputs:
                     encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                         "token_type_ids"
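
The two hunks are mirror images: when the attention mask is a NumPy array with more than two dimensions (a per-position [1, seq_len, seq_len] mask rather than the usual 1-D list), np.pad grows both sequence axes by `difference`, appending zeros for right padding and prepending them for left padding; the `else` branch preserves the old 1-D list behavior. A minimal standalone sketch of that np.pad behavior, using illustrative values rather than the tokenizer's internals:

import numpy as np

# Illustrative values, not the tokenizer's state.
seq_len, difference = 3, 2
mask = np.ones((1, seq_len, seq_len), dtype=np.int64)

# padding_side == "right": zeros are appended on both sequence axes,
# so the original seq_len x seq_len block stays in the top-left corner.
right = np.pad(mask, pad_width=[(0, 0), (0, difference), (0, difference)], mode="constant", constant_values=0)

# padding_side == "left": zeros are prepended instead, pushing the
# original block to the bottom-right corner.
left = np.pad(mask, pad_width=[(0, 0), (difference, 0), (difference, 0)], mode="constant", constant_values=0)

print(right.shape)   # (1, 5, 5)
print(right[0, 0])   # [1 1 1 0 0] -- original row, zeros appended
print(left[0, -1])   # [0 0 1 1 1] -- original row, zeros prepended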