Skip to content

Conversation

stancld
Copy link
Contributor

@stancld stancld commented Apr 14, 2022

What does this PR do?

Fixes #16681

This PR adds PyTorch and Flax implementation of the LongT5 model. (TensorFlow implementation is omitted for this PR as it requires another round of reviews. However, I'm willing to work on the TF side in another PR as well)

This PR adds LongT5 model according to the original Google's paper LongT5: Efficient Text-To-Text Transformer for Long Sequences.

PyTorch implementation

  • Local Attention
  • Transient-Global Attention

Flax implementation

  • Local Attention
  • Transient-Global Attention

t5x - HF equivalence

Model equivalence is investigated in my repo here.

  • Local Attention (looks promising right now)
  • Transient-Global Attention (it looks like there's a problem with the calculation of side_position_bias)

Other features

  • Compatibility with a standard T5 model checkpoints

Original checkpoints converted to the HF format can be temporarily found on the HF hub:

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests for the Local model?
  • Did you write any new necessary tests for the TGlobal model?
  • Did you update the results of slow/tooslow tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@patrickvonplaten @patil-suraj

More information will be added

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Apr 15, 2022

The documentation is not available anymore as the PR was closed or merged.

@stancld stancld changed the title [WIP] Add LongT5 model Add LongT5 model Apr 19, 2022
@stancld
Copy link
Contributor Author

stancld commented Apr 19, 2022

Hello @patrickvonplaten and @patil-suraj. I will need to go through the code a few times again to polish it. However, as this is the very first time for me to add the new model into transformers, I'd like to kindly ask you if you can provide me with preliminary feedback just to know if there's anything missing or so :]

As indicated in the PR description, there's a glitch regarding the calculation side_position_bias. I'll try to think about this part more.

Q1: Afaik, LongT5 uses the same tokenizer as T5. Is it, therefore, right not to add any new tokenizer?
Q2: Is it okay to have a single model class both for a model with local attention and transient-global attention, or is it prefered to split this into separate classes?

@stancld stancld marked this pull request as ready for review April 19, 2022 18:26
@patrickvonplaten
Copy link
Contributor

Sorry for being a bit late here - answering tomorrow!

@PhungVanDuy
Copy link
Contributor

@stancld Do you have any plans for a LongT5 release? I'm really looking forward to being able to replace the LED model with LongT5. Thank you so much for your effort.

@PhungVanDuy
Copy link
Contributor

@stancld I tried to use this PR for seq2seq training, I got this bug, can you check this ?

My code:

    tokenizer = AutoTokenizer.from_pretrained("t5-large")
    model = AutoModelForSeq2SeqLM.from_pretrained("Stancld/LongT5-Local-Base")

    rouge = load_metric("rouge")
    bleu = load_metric("bleu")

    train_dataset = Seq2SeqDataset("../data/train.pkl", tokenizer)
    val_dataset = Seq2SeqDataset("../data/val.pkl", tokenizer)

    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )

This error that I got:

Traceback (most recent call last):
  File "seq2seq_train.py", line 105, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.8/site-packages/transformers-4.19.0.dev0-py3.8.egg/transformers/trainer.py", line 1428, in train
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.8/site-packages/transformers-4.19.0.dev0-py3.8.egg/transformers/trainer.py", line 2019, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/lib/python3.8/site-packages/transformers-4.19.0.dev0-py3.8.egg/transformers/trainer.py", line 2051, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 797, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers-4.19.0.dev0-py3.8.egg/transformers/models/longt5/modeling_longt5.py", line 2285, in forward
    encoder_outputs = self.encoder(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers-4.19.0.dev0-py3.8.egg/transformers/models/longt5/modeling_longt5.py", line 1636, in forward
    extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
  File "/opt/conda/lib/python3.8/site-packages/transformers-4.19.0.dev0-py3.8.egg/transformers/models/longt5/modeling_longt5.py", line 204, in _get_local_attention_mask
    local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
  File "/opt/conda/lib/python3.8/site-packages/transformers-4.19.0.dev0-py3.8.egg/transformers/models/longt5/modeling_longt5.py", line 189, in _mask_local_attention_mask
    return torch.logical_and(local_attention_mask, locality_mask)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@stancld
Copy link
Contributor Author

stancld commented Apr 21, 2022

@PhungVanDuy - Thanks for the pointer! I haven't tested my code on a GPU before. Should work fine with the new commit :]

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @stancld ! Didn't thoroughly reviewed it, but from a first look it looks good, and nothing seems to be missing in terms of implementation.

Q1: Afaik, LongT5 uses the same tokenizer as T5. Is it, therefore, right not to add any new tokenizer?

Yes, in this case there is no need to add a new tokenizer. There are quite a few models in the lib that share tokenizers for ex: GP2, GPTNeo/GPTNeoX, GPT-j etc.

Q2: Is it okay to have a single model class both for a model with local attention and transient-global attention, or is it prefered to split this into separate classes?

IMO it's not necessary split this into several classes, the attention type can be specified in the config.

@marksverdhei
Copy link

marksverdhei commented Apr 21, 2022

Is this model supposed to be able to load any T5 checkpoint using from_pretrained()? If not from pre-trained, does this PR provide any other ways to do it?
Although I might not have fully understood how to use the model yet, I wanted to check out if it was supported,
but it seems that it is not able to load any of the attention weights:

model = LongT5ForConditionalGeneration.from_pretrained("allenai/unifiedqa-t5-base")

You are using a model of type t5 to instantiate a model of type longt5. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at allenai/unifiedqa-t5-base were not used when initializing LongT5ForConditionalGeneration: ['encoder.block.0.layer.0.SelfAttention.v.weight', 'encoder.block.2.layer.0.SelfAttention.k.weight', 'encoder.block.2.layer.0.SelfAttention.v.weight', 'encoder.block.4.layer.0.SelfAttention.o.weight', 'encoder.block.11.layer.0.SelfAttention.o.weight', 'encoder.block.3.layer.0.SelfAttention.k.weight', 'encoder.block.10.layer.0.SelfAttention.o.weight', 'encoder.block.4.layer.0.SelfAttention.v.weight', 'encoder.block.8.layer.0.SelfAttention.k.weight',
...

@stancld
Copy link
Contributor Author

stancld commented Apr 21, 2022

Is this model supposed to be able to load any T5 checkpoint using from_pretrained()? If not from pre-trained, does this PR provide any other ways to do it? Although I might not have fully understood how to use the model yet, I wanted to check out if it was supported, but it seems that it is not able to load any of the attention weights:

model = LongT5ForConditionalGeneration.from_pretrained("allenai/unifiedqa-t5-base")

You are using a model of type t5 to instantiate a model of type longt5. This is not supported for all configurations of models and can yield errors.
Some weights of the model checkpoint at allenai/unifiedqa-t5-base were not used when initializing LongT5ForConditionalGeneration: ['encoder.block.0.layer.0.SelfAttention.v.weight', 'encoder.block.2.layer.0.SelfAttention.k.weight', 'encoder.block.2.layer.0.SelfAttention.v.weight', 'encoder.block.4.layer.0.SelfAttention.o.weight', 'encoder.block.11.layer.0.SelfAttention.o.weight', 'encoder.block.3.layer.0.SelfAttention.k.weight', 'encoder.block.10.layer.0.SelfAttention.o.weight', 'encoder.block.4.layer.0.SelfAttention.v.weight', 'encoder.block.8.layer.0.SelfAttention.k.weight',
...

It's a good point to ensure compatibility with T5 checkpoints. I have some ideas in my mind on how to make this possible, but maybe let's wait for some code review first. But thanks a lot for pointing this out! :]

@PhungVanDuy
Copy link
Contributor

@stancld Thank for you quick fix, I also found some same error with Large model, I also fixed it, but I guess still have few problem with model, when I tried to train a seq2seq model with same code base above, I got this logs.

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.0009765625                                                                                                                                                 [55/1886]
  0%|                                                                                                                                                                                                 | 10/338508 [00:32<281:22:30,  2.99s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.000244140625
  0%|                                                                                                                                                                                                 | 11/338508 [00:35<279:30:47,  2.97s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 3.0517578125e-05
  0%|                                                                                                                                                                                                 | 12/338508 [00:38<287:14:53,  3.05s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 3.814697265625e-06
  0%|                                                                                                                                                                                                 | 13/338508 [00:41<286:40:39,  3.05s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 9.5367431640625e-07
  0%|                                                                                                                                                                                                 | 14/338508 [00:44<283:07:13,  3.01s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.1920928955078125e-07
  0%|                                                                                                                                                                                                 | 15/338508 [00:47<281:09:21,  2.99s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 5.960464477539063e-08
  0%|                                                                                                                                                                                                 | 16/338508 [00:50<279:44:32,  2.98s/it]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 3.725290298461914e-09
  0%|                                                                                                                                                                                                 | 17/338508 [00:53<280:01:09,  2.98s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4.656612873077393e-10
  0%|                                                                                                                                                                                                 | 18/338508 [00:56<279:24:12,  2.97s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 5.820766091346741e-11
  0%|                                                                                                                                                                                                 | 19/338508 [00:59<281:12:31,  2.99s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 3.637978807091713e-12
{'loss': 4615.4496, 'learning_rate': 9.99940917201366e-05, 'epoch': 0.0}
  0%|                                                                                                                                                                                                 | 20/338508 [01:02<285:15:29,  3.03s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.8189894035458565e-12
  0%|                                                                                                                                                                                                 | 21/338508 [01:05<284:58:54,  3.03s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4.547473508864641e-13
  0%|                                                                                                                                                                                                 | 22/338508 [01:08<282:52:19,  3.01s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.1368683772161603e-13
  0%|                                                                                                                                                                                                 | 23/338508 [01:11<281:55:16,  3.00s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.4210854715202004e-14
  0%|                                                                                                                                                                                                 | 24/338508 [01:14<281:53:37,  3.00s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.7763568394002505e-15
  0%|                                                                                                                                                                                                 | 25/338508 [01:21<396:36:35,  4.22s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2.220446049250313e-16
  0%|                                                                                                                                                                                                 | 26/338508 [01:24<361:24:43,  3.84s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 5.551115123125783e-17
  0%|                                                                                                                                                                                                 | 27/338508 [01:27<336:14:37,  3.58s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 6.938893903907228e-18
  0%|                                                                                                                                                                                                 | 28/338508 [01:30<319:43:38,  3.40s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8.673617379884035e-19
  0%|                                                                                                                                                                                                 | 29/338508 [01:33<306:31:46,  3.26s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.0842021724855044e-19
  0%|                                                                                                                                                                                                 | 30/338508 [01:36<299:47:18,  3.19s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 5.421010862427522e-20
  0%|                                                                                                                                                                                                 | 31/338508 [01:39<293:40:56,  3.12s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2.710505431213761e-20
  0%|                                                                                                                                                                                                 | 32/338508 [01:42<288:47:18,  3.07s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.3552527156068805e-20
  0%|                                                                                                                                                                                                 | 33/338508 [01:45<288:19:34,  3.07s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1.6940658945086007e-21
  0%|                                                                                                                                                                                                 | 34/338508 [01:48<288:35:55,  3.07s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2.117582368135751e-22
  0%|                                                                                                                                                                                                 | 35/338508 [01:51<291:35:38,  3.10s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 5.293955920339377e-23
  0%|                                                                                                                                                                                                 | 36/338508 [01:54<286:56:54,  3.05s/it]
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 6.617444900424222e-24
  0%|                                                                                                                                                                                                 | 37/338508 [01:57<289:17:22,  3.08s/it```

@patrickvonplaten
Copy link
Contributor

Hey @stancld, @patrickvonplaten ! Finally, was able to find and fix the subtle bugs. Here the changes I made.

In both local and TGlobal attention

  • set masked value to -1e10 to match the encoder output.

The logits will still match even with -1e4 but the encoder_outputs won't. Since the encoder can handle really large text I think it's important that the encouder_outputs match, if someone wants to use only the encoder. wdyt @patrickvonplaten ?

In Local attention

  • Fixed computing relative_position to match the encoder outputs and original implementation.
  • In the ported model, the lm_head weights didn't match. To verify the outputs we need to set the correct weights for lm_head

In TGlobal attention

  • Fix global_segment_ids:
    In the original flax codebase the global_segment_ids are always either 1 or 0.
    The global_segment_ids are set 0 for orphan tokens and padded tokens and the rest are set to 1. Instead of 0 to _sequence_block_ids_max.
  • fix global_block_ids.
    The global_block_ids are not computed correctly when seq_length >= 16384 and attention_mask is passed with 0's in it. This change explicitly sets the padded position to -1 to match the original implementation.

If you want to verify:

  • Follow the instructions in @stancld's repo.
  • Set the activation_dtype in LongT5 gin configs to float32 to compare the outputs with torch.

https://github.com/google/flaxformer/blob/826c45c9cc14cee0f906b7c4b6d041f08f8ece5d/flaxformer/t5x/configs/longt5/architectures/longt5_1_1_transient_global_flaxformer.gin#L37

https://github.com/google/flaxformer/blob/main/flaxformer/t5x/configs/longt5/architectures/longt5_1_1_flaxformer.gin#L37

import gin
import jax.numpy as jnp
import numpy as np
import torch

import t5x

from transformers import AutoModelForSeq2SeqLM, FlaxAutoModelForSeq2SeqLM

# modify this path according to your setup
home = "/home/suraj_huggingface_co/longt5-debug/longt5-eval"

config_file = f"{home}/flaxformer/t5x/configs/longt5/models/longt5_1_1_transient_global_base.gin"
checkpoint_dir = f"{home}/google-checkpoints/LongT5-TGlobal-Base"
hf_model_path = "Stancld/LongT5-TGlobal-Base"


# Parse config file
with open(config_file) as bindings:
    gin.parse_config(bindings)
gin.finalize()

# Get model
model_config_ref = gin.query_parameter("%MODEL")
model = model_config_ref.scoped_configurable_fn()


# Load checkpoint
t5x_checkpoint = t5x.checkpoints.load_t5x_checkpoint(checkpoint_dir)

pt_model = AutoModelForSeq2SeqLM.from_pretrained(hf_model_path)

# for local attention model set the correct weights for `lm_head`
# pt_model.lm_head.weight.data = torch.from_numpy(t5x_checkpoint["target"]["decoder"]["logits_dense"]["kernel"].T)


enc_seq_length = 2048
seq_length = 10

enc_shape = [2, enc_seq_length]
shape = [2, seq_length]

encoder_input_tokens = np.ones(enc_shape, dtype=np.int32)
decoder_input_tokens = np.ones(shape, dtype=np.int32)
decoder_target_tokens = np.ones(shape, dtype=np.int32)
attention_mask =  np.ones(enc_shape, dtype=np.int32)

# # add some zeros as padding tokens
import random
mask_idx = random.randrange(10, enc_seq_length)
encoder_input_tokens[0, mask_idx:] = 0
attention_mask[0, mask_idx:] = 0

mask_idx = random.randrange(10, enc_seq_length)
encoder_input_tokens[1, mask_idx:] = 0
attention_mask[1, mask_idx:] = 0

decoder_input_tokens[:, seq_length-2:] = 0
decoder_target_tokens[:, seq_length-2:] = 0


# Run forward pass
print("~~~~~~~~~~ FlaxForrmer ~~~~~~~~~~~~")
t5x_logits, mod_vars = model.module.apply(
    {"params": t5x_checkpoint["target"]},
    encoder_input_tokens=encoder_input_tokens,
    decoder_input_tokens=decoder_input_tokens,
    decoder_target_tokens=decoder_target_tokens,
    enable_dropout=False,
    mutable='intermediates'
)


print("~~~~~~~~~ HF PyTorch ~~~~~~~~~~~~~")
with torch.no_grad():
    pt_output = pt_model(
        # encoder_outputs=(torch.from_numpy(encoder_output).float(),),
        input_ids=torch.from_numpy(encoder_input_tokens).long(),
        attention_mask=torch.from_numpy(attention_mask).long(),
        decoder_input_ids=torch.from_numpy(decoder_target_tokens).long(),
        output_hidden_states = True,
        output_attentions = True,
    )

# print(pt_output.shape)
print("~~~~~~~~~~~~~~~~~~~~~~")


# verify if `logits` match
np.allclose(pt_output.logits.numpy()[:, :-mask_idx, ...], t5x_logits[:, :-mask_idx, ...], atol=1e-3)

Let me know, if you try it find some issues with it. Now going to check the flax implementation. Once that's done will ping you for review @patrickvonplaten :)

That looks good to me!

@patrickvonplaten
Copy link
Contributor

@stancld looks like the pipeline failures are not flaky and related to this PR - do you need help with solving them? They can be a bit tricky!

@patil-suraj
Copy link
Contributor

Will take a look at pipeline failures.

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @stancld , this looks good for merge now. Just left some nits

  • we should update the path of the model to google org.
  • Add example in doc using the finetuned checkpoint
  • update summarization test
  • remove the translation tests.

@patrickvonplaten could you also take a final look here ?

@stancld stancld mentioned this pull request Jun 13, 2022
4 tasks
@patrickvonplaten patrickvonplaten merged commit a72f1c9 into huggingface:main Jun 13, 2022
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jun 16, 2022
* Initial commit

* Make some fixes

* Make PT model full forward pass

* Drop TF & Flax implementation, fix copies etc

* Add Flax model and update some corresponding stuff

* Drop some TF things

* Update config and flax local attn

* Add encoder_attention_type to config

* .

* Update docs

* Do some cleansing

* Fix some issues -> make style; add some docs

* Fix position_bias + mask addition + Update tests

* Fix repo consistency

* Fix model consistency by removing flax operation over attn_mask

* [WIP] Add PT TGlobal LongT5

* .

* [WIP] Add flax tglobal model

* [WIP] Update flax model to use the right attention type in the encoder

* Fix flax tglobal model forward pass

* Make the use of global_relative_attention_bias

* Add test suites for TGlobal model

* Fix minor bugs, clean code

* Fix pt-flax equivalence though not convinced with correctness

* Fix LocalAttn implementation to match the original impl. + update READMEs

* Few updates

* Update: [Flax] improve large model init and loading huggingface#16148

* Add ckpt conversion script accoring to huggingface#16853 + handle torch device placement

* Minor updates to conversion script.

* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM

* gpu support + dtype fix

* Apply some suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* * Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies

* Remove caching logic for local & tglobal attention

* Apply another batch of suggestions from code review

* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx

* Fix converting script + revert config file change

* Revert "Remove caching logic for local & tglobal attention"

This reverts commit 2a61982.

* Stash caching logic in Flax model

* Make side relative bias used always

* Drop caching logic in PT model

* Return side bias as it was

* Drop all remaining model parallel logic

* Remove clamp statements

* Move test files to the proper place

* Update docs with new version of hf-doc-builder

* Fix test imports

* Make some minor improvements

* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray

* Fix TGlobal for ONNX conversion + update docs

* fix _make_global_fixed_block_ids and masked neg  value

* update flax model

* style and quality

* fix imports

* remove load_tf_weights_in_longt5 from init and fix copies

* add slow test for TGlobal model

* typo fix

* Drop obsolete is_parallelizable and one warning

* Update __init__ files to fix repo-consistency

* fix pipeline test

* Fix some device placements

* [wip]: Update tests -- need to generate summaries to update expected_summary

* Fix quality

* Update LongT5 model card

* Update (slow) summarization tests

* make style

* rename checkpoitns

* finish

* fix flax tests

Co-authored-by: phungvanduy <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: patil-suraj <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

LongT5: Efficient Text-To-Text Transformer for Long Sequences