Commit a960406

[FlaxWav2Vec2Model] Fix bug in attention mask (#16725)
* [FlaxWav2Vec2Model] Fix bug in attention mask
* more fixes
* add (Flax)SpeechEncoderDecoderModel PT-FX cross-test
1 parent 6adefba commit a960406

2 files changed (+24, -17 lines)


src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py

Lines changed: 18 additions & 17 deletions
@@ -920,7 +920,7 @@ def __call__(
     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
     ):
-        return self.module._get_feat_extract_output_lengths(input_lengths)
+        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
 
 
 class FlaxWav2Vec2Module(nn.Module):
@@ -956,15 +956,10 @@ def __call__(
 
         # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
-            # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
-
-            attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)
-
-            # these two operations makes sure that all values
-            # before the output lengths indices are attended to
-            attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
-            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )
 
         hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
         if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
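
Aside (not part of the diff): the helper called above receives the sample-level attention mask and first reduces it to per-example input lengths; the removed lines did this explicitly with attention_mask.sum(-1). A minimal sketch of that reduction step, with toy values:

import jax.numpy as jnp

# two raw waveforms of 10 samples each; the second is padded after 6
attention_mask = jnp.array([[1] * 10, [1] * 6 + [0] * 4], dtype="i4")
input_lengths = attention_mask.sum(-1)
print(input_lengths)  # [10 6] -- these lengths then go through the conv-length formula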
@@ -1034,12 +1029,10 @@ def _get_feature_vector_attention_mask(
         batch_size = attention_mask.shape[0]
 
         attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
-        # these two operations makes sure that all values before the output lengths idxs are attended to
-        idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
-        attention_mask = attention_mask.at[idx].set(1)
-        attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
-
-        attention_mask = jnp.array(attention_mask, dtype=bool)
+        # these two operations makes sure that all values
+        # before the output lengths indices are attended to
+        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
+        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
         return attention_mask
 
 
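The two-step flip/cumsum trick in this helper is compact enough to misread, so here is a standalone sketch of what it computes, with toy lengths (not taken from the diff):

import jax.numpy as jnp

output_lengths = jnp.array([2, 4])  # valid frames per batch item
mask = jnp.zeros((2, 5), dtype="i4")

# set a single 1 at the last valid index of each row
mask = mask.at[jnp.arange(mask.shape[0]), output_lengths - 1].set(1)
# [[0, 1, 0, 0, 0],
#  [0, 0, 0, 1, 0]]

# reverse cumulative sum turns each marker into 1s at every earlier position
mask = jnp.flip(jnp.flip(mask, -1).cumsum(-1), -1).astype("bool")
# [[ True,  True, False, False, False],
#  [ True,  True,  True,  True, False]]
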
@@ -1286,11 +1279,15 @@ def __call__(
             attentions=outputs.attentions,
         )
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
         """
         Computes the output length of the convolutional layers
         """
 
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -1299,6 +1296,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
 
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
         return input_lengths
 
 
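To make the loop above concrete, here is the formula run by hand with the usual Wav2Vec2 feature-encoder settings (conv_kernel=(10, 3, 3, 3, 3, 2, 2), conv_stride=(5, 2, 2, 2, 2, 2, 2), num_adapter_layers=3, adapter_stride=2; these config values are assumed defaults, not shown in this diff):

def _conv_out_length(input_length, kernel_size, stride):
    # 1D conv output-length formula, per the torch.nn.Conv1d docs
    return (input_length - kernel_size) // stride + 1

length = 16000  # 1 second of 16 kHz audio
for kernel_size, stride in zip((10, 3, 3, 3, 3, 2, 2), (5, 2, 2, 2, 2, 2, 2)):
    length = _conv_out_length(length, kernel_size, stride)
print(length)  # 49 pre-adapter feature frames

# the new add_adapter branch downsamples further, with kernel fixed at 1
for _ in range(3):
    length = _conv_out_length(length, 1, 2)
print(length)  # 7 -- which is why the __call__ fix above pins add_adapter=False:
               # extract_features still has the 49-frame, pre-adapter time axis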

tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py

Lines changed: 6 additions & 0 deletions
@@ -539,6 +539,12 @@ def test_pt_flax_equivalence(self):
         self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
         self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
 
+        # check `add_adapter` works as expected
+        config.add_adapter = True
+        self.assertTrue(config.add_adapter)
+        self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
+        self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
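
To exercise the new cross-test locally, an invocation along these lines should work (the RUN_PT_FLAX_CROSS_TESTS flag gates PT-Flax cross-tests in this repo; the exact command is assumed, not taken from the diff):

RUN_PT_FLAX_CROSS_TESTS=true python -m pytest \
    tests/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py \
    -k test_pt_flax_equivalence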
