@@ -920,7 +920,7 @@ def __call__(
     def _get_feat_extract_output_lengths(
         self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
     ):
-        return self.module._get_feat_extract_output_lengths(input_lengths)
+        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)


 class FlaxWav2Vec2Module(nn.Module):
@@ -956,15 +956,10 @@ def __call__(

         # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
-            # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
-
-            attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)
-
-            # these two operations makes sure that all values
-            # before the output lengths indices are attended to
-            attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
-            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(
+                extract_features.shape[1], attention_mask, add_adapter=False
+            )

         hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
         if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
@@ -1034,12 +1029,10 @@ def _get_feature_vector_attention_mask(
         batch_size = attention_mask.shape[0]

         attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
-        # these two operations makes sure that all values before the output lengths idxs are attended to
-        idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
-        attention_mask = attention_mask.at[idx].set(1)
-        attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
-
-        attention_mask = jnp.array(attention_mask, dtype=bool)
+        # these two operations makes sure that all values
+        # before the output lengths indices are attended to
+        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
+        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
         return attention_mask

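For reference, a minimal standalone sketch (not part of the commit) of the flip/cumsum/flip pattern that the retained lines of _get_feature_vector_attention_mask rely on; the batch size, lengths, and frame count below are made up for illustration:

import jax.numpy as jnp

output_lengths = jnp.array([3, 5])   # hypothetical reduced lengths per batch item
feature_vector_length = 6            # hypothetical number of feature frames
batch_size = output_lengths.shape[0]

mask = jnp.zeros((batch_size, feature_vector_length), dtype="i4")
# write a single 1 at the last valid frame index of each row
mask = mask.at[jnp.arange(batch_size), output_lengths - 1].set(1)
# reverse, cumulative-sum, reverse again: every position up to that index becomes non-zero
mask = jnp.flip(jnp.flip(mask, -1).cumsum(-1), -1).astype("bool")
# mask is now:
# [[ True  True  True False False False]
#  [ True  True  True  True  True False]]

The single 1 at each row's last valid index is smeared backwards by the reversed cumulative sum, so all frames up to and including that index end up attended to.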
@@ -1286,11 +1279,15 @@ def __call__(
             attentions=outputs.attentions,
         )

-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
         """
         Computes the output length of the convolutional layers
         """

+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -1299,6 +1296,10 @@ def _conv_out_length(input_length, kernel_size, stride):
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
         return input_lengths

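For reference, a minimal standalone sketch (not part of the commit) of how the amended _get_feat_extract_output_lengths reduces a raw input length, first through the feature-extractor convolutions and then, when add_adapter is enabled, through the adapter layers. The kernel, stride, and adapter values below are assumed illustrative defaults, not read from any particular config:

def _conv_out_length(input_length, kernel_size, stride):
    # standard 1D convolution output length (no padding, dilation=1)
    return (input_length - kernel_size) // stride + 1

conv_kernel = (10, 3, 3, 3, 3, 2, 2)   # assumed config.conv_kernel
conv_stride = (5, 2, 2, 2, 2, 2, 2)    # assumed config.conv_stride
num_adapter_layers = 3                 # assumed config.num_adapter_layers
adapter_stride = 2                     # assumed config.adapter_stride

length = 16000                         # e.g. one second of 16 kHz audio
for kernel_size, stride in zip(conv_kernel, conv_stride):
    length = _conv_out_length(length, kernel_size, stride)
# length == 49 feature frames before the adapter

for _ in range(num_adapter_layers):
    # each adapter layer acts like a kernel_size=1 convolution with its own stride
    length = _conv_out_length(length, 1, adapter_stride)
# length == 7 frames after three stride-2 adapter layers

This is why the add_adapter flag matters: the adapter layers further downsample the sequence, so the computed output lengths (and any attention mask derived from them) differ depending on whether the adapter is applied.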