Conversation

sanchit-gandhi (Contributor)

The current use of decoder_module assumes that encoder_hidden_states is the fourth positional argument of the decoder's call method. This holds for the two existing Flax decoder models, FlaxGPT2LMHeadModel and FlaxBartForCausalLM. However, other decoder models, such as the work-in-progress FlaxBertForCausalLM, may have additional positional arguments (such as token_type_ids or head_mask) before encoder_hidden_states. To handle the general case, we should not assume that encoder_hidden_states is the fourth positional argument, and should instead pass it as a keyword argument.
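To illustrate the failure mode, here is a minimal sketch with two hypothetical decoder signatures (these are simplified stand-ins, not the actual transformers call signatures): one with encoder_hidden_states in the fourth position, and one BERT-style signature with extra positional parameters before it.

```python
# Hypothetical GPT-2-style decoder: encoder_hidden_states is the 4th parameter.
def gpt2_style_decoder(input_ids, attention_mask, position_ids,
                       encoder_hidden_states=None, **kwargs):
    return encoder_hidden_states

# Hypothetical BERT-style decoder: extra positional parameters
# (token_type_ids, position_ids, head_mask) precede encoder_hidden_states.
def bert_style_decoder(input_ids, attention_mask, token_type_ids,
                       position_ids=None, head_mask=None,
                       encoder_hidden_states=None, **kwargs):
    return encoder_hidden_states

# Buggy pattern: pass encoder_hidden_states as the 4th positional argument.
# Correct for the GPT-2-style signature...
assert gpt2_style_decoder("ids", "mask", "pos", "enc") == "enc"
# ...but for the BERT-style signature it silently binds to position_ids,
# and encoder_hidden_states stays None:
assert bert_style_decoder("ids", "mask", "pos", "enc") is None

# Fixed pattern: pass it by keyword, which is robust to either signature.
assert gpt2_style_decoder("ids", "mask", "pos",
                          encoder_hidden_states="enc") == "enc"
assert bert_style_decoder("ids", "mask", "pos",
                          encoder_hidden_states="enc") == "enc"
```

Note that the positional call does not raise an error for the BERT-style signature; the argument is simply bound to the wrong parameter, which is why passing by keyword is the safer convention.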

HuggingFaceDocBuilderDev commented May 2, 2022

The documentation is not available anymore as the PR was closed or merged.

patrickvonplaten (Contributor) left a comment:

Great!

@sanchit-gandhi sanchit-gandhi merged commit 93b802c into huggingface:main May 2, 2022
@sanchit-gandhi sanchit-gandhi deleted the flax-speech-encoder-decoder branch May 2, 2022 11:06
stevhliu pushed a commit to stevhliu/transformers that referenced this pull request May 3, 2022
…#17036)

* [FlaxSpeechEncoderDecoder] Fix bug in `decoder_module`

* [FlaxEncoderDecoder] Fix bug in `decoder_module`
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…#17036)

* [FlaxSpeechEncoderDecoder] Fix bug in `decoder_module`

* [FlaxEncoderDecoder] Fix bug in `decoder_module`