Skip to content

Commit b67ead2

Browse files
sanchit-gandhielusenji
authored andcommitted
[FlaxSpeechEncoderDecoder] Fix input shape bug in weights init (huggingface#16728)
* [FlaxSpeechEncoderDecoder] Fix input shape bug in weights init * make style
1 parent e56de4f commit b67ead2

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,15 @@ def setup(self):
226226
else:
227227
self.enc_to_dec_proj = None
228228

229-
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
229+
def _get_feat_extract_output_lengths(
230+
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
231+
):
230232
"""
231233
Computes the output length of the convolutional layers
232234
"""
233235

236+
add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter
237+
234238
def _conv_out_length(input_length, kernel_size, stride):
235239
# 1D convolutional layer output length formula taken
236240
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -239,6 +243,10 @@ def _conv_out_length(input_length, kernel_size, stride):
239243
for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
240244
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
241245

246+
if add_adapter:
247+
for _ in range(self.config.encoder.num_adapter_layers):
248+
input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)
249+
242250
return input_lengths
243251

244252
def _get_encoder_module(self):
@@ -432,8 +440,10 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_
432440
)
433441
return unfreeze(init_variables["cache"])
434442

435-
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
436-
return self.module._get_feat_extract_output_lengths(input_lengths)
443+
def _get_feat_extract_output_lengths(
444+
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
445+
):
446+
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
437447

438448
@add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
439449
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)

0 commit comments

Comments
 (0)