@@ -226,11 +226,15 @@ def setup(self):
226
226
else :
227
227
self .enc_to_dec_proj = None
228
228
229
- def _get_feat_extract_output_lengths (self , input_lengths : Union [jnp .ndarray , int ]):
229
+ def _get_feat_extract_output_lengths (
230
+ self , input_lengths : Union [jnp .ndarray , int ], add_adapter : Optional [bool ] = None
231
+ ):
230
232
"""
231
233
Computes the output length of the convolutional layers
232
234
"""
233
235
236
+ add_adapter = self .config .encoder .add_adapter if add_adapter is None else add_adapter
237
+
234
238
def _conv_out_length (input_length , kernel_size , stride ):
235
239
# 1D convolutional layer output length formula taken
236
240
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -239,6 +243,10 @@ def _conv_out_length(input_length, kernel_size, stride):
239
243
for kernel_size , stride in zip (self .config .encoder .conv_kernel , self .config .encoder .conv_stride ):
240
244
input_lengths = _conv_out_length (input_lengths , kernel_size , stride )
241
245
246
+ if add_adapter :
247
+ for _ in range (self .config .encoder .num_adapter_layers ):
248
+ input_lengths = _conv_out_length (input_lengths , 1 , self .config .encoder .adapter_stride )
249
+
242
250
return input_lengths
243
251
244
252
def _get_encoder_module (self ):
@@ -432,8 +440,10 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_
432
440
)
433
441
return unfreeze (init_variables ["cache" ])
434
442
435
- def _get_feat_extract_output_lengths (self , input_lengths : Union [jnp .ndarray , int ]):
436
- return self .module ._get_feat_extract_output_lengths (input_lengths )
443
+ def _get_feat_extract_output_lengths (
444
+ self , input_lengths : Union [jnp .ndarray , int ], add_adapter : Optional [bool ] = None
445
+ ):
446
+ return self .module ._get_feat_extract_output_lengths (input_lengths , add_adapter = add_adapter )
437
447
438
448
@add_start_docstrings (SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING )
439
449
@replace_return_docstrings (output_type = FlaxBaseModelOutput , config_class = _CONFIG_FOR_DOC )
0 commit comments