
Conversation

@JohnGiorgi (Contributor) commented Mar 29, 2022

What does this PR do?

Certain Seq2Seq models (e.g. LED-based models such as PRIMERA) need to pass the global_attention_mask to model.generate() so that global attention is computed for particular tokens when decoding. This does not currently happen in Seq2SeqTrainer, but it can easily be added by looking for global_attention_mask in the provided inputs and adding it to gen_kwargs, much the same way the regular attention_mask is currently handled. This PR does exactly that.
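
A minimal sketch of the kind of addition described, mirroring how attention_mask is already forwarded in Seq2SeqTrainer.prediction_step (the surrounding gen_kwargs setup and the use of inputs["input_ids"] are paraphrased and may differ slightly from the actual trainer_seq2seq.py):

    # Sketch: inside Seq2SeqTrainer.prediction_step, where the kwargs for generate() are assembled.
    if "attention_mask" in inputs:
        gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
    # New: also forward global_attention_mask so LED-style models compute global attention.
    if "global_attention_mask" in inputs:
        gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)

    generated_tokens = self.model.generate(
        inputs["input_ids"],
        **gen_kwargs,
    )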

Other changes

  • Fixed a small typo in one of the comments in transformers/src/transformers/trainer_seq2seq.py.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sgugger, @patrickvonplaten

If global_attention_mask is found in the model's inputs (used by certain
models, like LED) in the prediction_step method of Seq2SeqTrainer,
it is added to the gen_kwargs, which are passed to model.generate().
This allows us to properly set the global attention when decoding.
@HuggingFaceDocBuilderDev commented Mar 29, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger requested a review from patil-suraj on March 29, 2022 at 22:01
@sgugger (Collaborator) commented Mar 29, 2022

Not expert enough in generate to review this and @patrickvonplaten is on vacation, so waiting for @patil-suraj review :-)

Thanks a lot for your PR!

@patrickvonplaten (Contributor) left a comment

This looks good to me - this will indeed enable generation for LED.

If @sgugger is ok with adding this somewhat model-specific line to the Trainer, the PR is good to go for me.

@sgugger merged commit b33ab4e into huggingface:main on Apr 5, 2022
@caesar-one (Contributor) commented Apr 6, 2022

Hi @JohnGiorgi and @patrickvonplaten,

when using model.generate(...), the model still doesn't receive the global_attention_mask for me. I think it would also be appropriate to change the LEDForConditionalGeneration.prepare_inputs_for_generation(...) method by adding support for global_attention_mask, with something like:

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past=None,
        attention_mask=None,
        global_attention_mask=None,  ### ADDED ###
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "global_attention_mask": global_attention_mask,  ### ADDED ###
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }
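
For context, a minimal usage sketch of how the mask would then flow through generation once prepare_inputs_for_generation forwards it (the checkpoint allenai/led-base-16384 and the choice of global tokens here are illustrative):

    import torch
    from transformers import LEDForConditionalGeneration, LEDTokenizer

    tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
    model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")

    inputs = tokenizer("A very long document ...", return_tensors="pt")
    # Put global attention on the first (<s>) token, as recommended for LED.
    global_attention_mask = torch.zeros_like(inputs["attention_mask"])
    global_attention_mask[:, 0] = 1

    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        global_attention_mask=global_attention_mask,
        num_beams=2,
        max_length=64,
    )
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))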

If this is correct, should I open a new pull request for it? Thanks

@patrickvonplaten (Contributor) commented

Good point @caesar-one! Yes, it would be nice if you could open a new PR for this.
