[Flax(Speech)EncoderDecoder] Fix bug in decoder_module
#17036
Merged
The current use of `decoder_module` assumes that `encoder_hidden_states` is the fourth positional argument of the decoder's call method. This is indeed true of the two current Flax decoder models, `FlaxGPT2LMHeadModel` and `FlaxBartForCausalLM`. However, other decoder models, such as the work-in-progress `FlaxBertForCausalLM`, may have additional positional arguments (such as `token_type_ids` or `head_mask`) before `encoder_hidden_states`. To handle this more general case, we should not assume that `encoder_hidden_states` is the fourth positional argument, and should instead pass it as a keyword argument.
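To illustrate the failure mode, here is a minimal sketch using hypothetical, simplified call signatures (not the actual Transformers code): one decoder in the GPT-2/BART style where `encoder_hidden_states` is the fourth parameter, and one in the BERT style where extra parameters push it later. Passing positionally silently misroutes the encoder states; binding by keyword works for both.

```python
# Hypothetical simplified decoder call signatures, for illustration only.

def gpt2_style_decoder(input_ids, attention_mask=None, position_ids=None,
                       encoder_hidden_states=None):
    # encoder_hidden_states is the 4th positional parameter here
    return {"enc": encoder_hidden_states}

def bert_style_decoder(input_ids, attention_mask=None, token_type_ids=None,
                       position_ids=None, head_mask=None,
                       encoder_hidden_states=None):
    # extra parameters shift encoder_hidden_states past the 4th slot
    return {"enc": encoder_hidden_states, "position_ids": position_ids}

enc = "ENC_STATES"

# Positional call: correct for the GPT-2-style signature...
assert gpt2_style_decoder("ids", "mask", None, enc)["enc"] == enc

# ...but the same call against the BERT-style signature silently feeds the
# encoder states into position_ids, and encoder_hidden_states stays None:
out = bert_style_decoder("ids", "mask", None, enc)
assert out["enc"] is None
assert out["position_ids"] == enc

# Binding by keyword is robust to both signatures:
assert gpt2_style_decoder("ids", encoder_hidden_states=enc)["enc"] == enc
assert bert_style_decoder("ids", encoder_hidden_states=enc)["enc"] == enc
```

Because the argument is bound by name rather than position, any future decoder can add or reorder positional parameters without breaking the encoder-decoder wiring.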