Add LongT5 model #16792
Conversation
The documentation is not available anymore as the PR was closed or merged.
Hello @patrickvonplaten and @patil-suraj. I will need to go through the code a few times again to polish it. However, as this is the very first time I'm adding a new model to the library, early feedback would be very welcome. As indicated in the PR description, there's a glitch regarding the calculation of `side_position_bias`.

Q1: Afaik, LongT5 uses the same tokenizer as T5. Is it, therefore, right not to add any new tokenizer?

Q2: Is it okay to have a single model class both for a model with local attention and transient-global attention, or is it preferred to split this into separate classes?
Sorry for being a bit late here - answering tomorrow!
@stancld Do you have any plans for a LongT5 release? I'm really looking forward to being able to replace the LED model with LongT5. Thank you so much for your effort.
@stancld I tried to use this PR for seq2seq training and ran into a bug. Can you check it? My code:

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer
from datasets import load_metric

# Seq2SeqDataset, training_args and compute_metrics are defined elsewhere in my script.
tokenizer = AutoTokenizer.from_pretrained("t5-large")
model = AutoModelForSeq2SeqLM.from_pretrained("Stancld/LongT5-Local-Base")

rouge = load_metric("rouge")
bleu = load_metric("bleu")

train_dataset = Seq2SeqDataset("../data/train.pkl", tokenizer)
val_dataset = Seq2SeqDataset("../data/val.pkl", tokenizer)

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
```

This is the error that I got:
@PhungVanDuy - Thanks for the pointer! I hadn't tested my code on a GPU before. It should work fine with the new commit :]
Hey @stancld! I haven't reviewed it thoroughly yet, but from a first look it looks good, and nothing seems to be missing in terms of implementation.
> Q1: Afaik, LongT5 uses the same tokenizer as T5. Is it, therefore, right not to add any new tokenizer?
Yes, in this case there is no need to add a new tokenizer. There are quite a few models in the lib that share tokenizers, for example GPT-2, GPT-Neo/GPT-NeoX, GPT-J, etc.
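For concreteness, a minimal sketch of what the shared-tokenizer setup looks like from the user side; the `google/long-t5-local-base` name assumes the final google-org checkpoint location requested later in this review (the temporary checkpoints lived under `Stancld/`):

```python
from transformers import AutoTokenizer

# LongT5 reuses the T5 SentencePiece tokenizer, so no new tokenizer class is added;
# AutoTokenizer resolves both checkpoints to the same T5 tokenizer implementation.
t5_tok = AutoTokenizer.from_pretrained("t5-base")
longt5_tok = AutoTokenizer.from_pretrained("google/long-t5-local-base")

print(type(t5_tok).__name__, type(longt5_tok).__name__)  # the same T5 tokenizer class twice
```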
> Q2: Is it okay to have a single model class both for a model with local attention and transient-global attention, or is it preferred to split this into separate classes?
IMO it's not necessary to split this into several classes; the attention type can be specified in the config, as sketched below.
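A rough illustration of that design, assuming the `encoder_attention_type` config field added in this PR and its two variants:

```python
from transformers import LongT5Config, LongT5Model

# One model class; the encoder attention variant is selected via the config.
local_model = LongT5Model(LongT5Config(encoder_attention_type="local"))
tglobal_model = LongT5Model(LongT5Config(encoder_attention_type="transient-global"))
```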
Is this model supposed to be able to load any T5 checkpoint using `from_pretrained`?
It's a good point to ensure compatibility with T5 checkpoints; see the sketch below for one way this could be checked.
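A hedged sketch of such a compatibility check (not part of this PR): load a plain T5 checkpoint into the LongT5 architecture and inspect the loading info, so any weights that don't line up are surfaced explicitly.

```python
from transformers import LongT5ForConditionalGeneration

# Attempt to load a vanilla T5 checkpoint into LongT5; parameters that do not
# match the LongT5 architecture show up in the returned loading info instead
# of failing silently.
model, loading_info = LongT5ForConditionalGeneration.from_pretrained(
    "t5-small", output_loading_info=True
)
print("missing:", loading_info["missing_keys"])
print("unexpected:", loading_info["unexpected_keys"])
```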
@stancld Thanks for your quick fix. I also found the same error with the Large model and fixed it as well, but I guess there are still a few problems with the model. When I tried to train a seq2seq model with the same code base as above, I got these logs:
That looks good to me!
@stancld looks like the pipeline failures are not flaky and are related to this PR - do you need help with solving them? They can be a bit tricky!
I will take a look at the pipeline failures.
Hey @stancld, this looks good to merge now. I just left some nits:

- We should update the path of the model to the `google` org.
- Add an example in the docs using the fine-tuned checkpoint (see the sketch below).
- Update the summarization test.
- Remove the translation tests.

@patrickvonplaten could you also take a final look here?
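Regarding the doc-example nit, a sketch of what such a snippet might look like; the checkpoint name below is an assumption (a PubMed-finetuned LongT5-TGlobal checkpoint), and any summarization-finetuned LongT5 checkpoint would work the same way:

```python
import torch
from transformers import AutoTokenizer, LongT5ForConditionalGeneration

# Checkpoint name assumed for illustration (PubMed-finetuned LongT5-TGlobal).
ckpt = "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = LongT5ForConditionalGeneration.from_pretrained(ckpt)

long_article = "..."  # a long biomedical article, up to ~16k tokens for this checkpoint
inputs = tokenizer(long_article, return_tensors="pt")
with torch.no_grad():
    summary_ids = model.generate(**inputs, max_length=128, num_beams=2)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```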
* Initial commit
* Make some fixes
* Make PT model full forward pass
* Drop TF & Flax implementation, fix copies etc
* Add Flax model and update some corresponding stuff
* Drop some TF things
* Update config and flax local attn
* Add encoder_attention_type to config
* .
* Update docs
* Do some cleansing
* Fix some issues -> make style; add some docs
* Fix position_bias + mask addition + Update tests
* Fix repo consistency
* Fix model consistency by removing flax operation over attn_mask
* [WIP] Add PT TGlobal LongT5
* .
* [WIP] Add flax tglobal model
* [WIP] Update flax model to use the right attention type in the encoder
* Fix flax tglobal model forward pass
* Make the use of global_relative_attention_bias
* Add test suites for TGlobal model
* Fix minor bugs, clean code
* Fix pt-flax equivalence though not convinced with correctness
* Fix LocalAttn implementation to match the original impl. + update READMEs
* Few updates
* Update: [Flax] improve large model init and loading huggingface#16148
* Add ckpt conversion script accoring to huggingface#16853 + handle torch device placement
* Minor updates to conversion script.
* Typo: AutoModelForSeq2SeqLM -> FlaxAutoModelForSeq2SeqLM
* gpu support + dtype fix
* Apply some suggestions from code review

Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>

* Remove (de)parallelize stuff
* Edit shape comments
* Update README.md
* make fix-copies
* Remove caching logic for local & tglobal attention
* Apply another batch of suggestions from code review
* Add missing checkpoints
* Format converting scripts
* Drop (de)parallelize links from longT5 mdx
* Fix converting script + revert config file change
* Revert "Remove caching logic for local & tglobal attention" (this reverts commit 2a61982)
* Stash caching logic in Flax model
* Make side relative bias used always
* Drop caching logic in PT model
* Return side bias as it was
* Drop all remaining model parallel logic
* Remove clamp statements
* Move test files to the proper place
* Update docs with new version of hf-doc-builder
* Fix test imports
* Make some minor improvements
* Add missing checkpoints to docs
* Make TGlobal model compatible with torch.onnx.export
* Replace some np.ndarray with jnp.ndarray
* Fix TGlobal for ONNX conversion + update docs
* fix _make_global_fixed_block_ids and masked neg value
* update flax model
* style and quality
* fix imports
* remove load_tf_weights_in_longt5 from init and fix copies
* add slow test for TGlobal model
* typo fix
* Drop obsolete is_parallelizable and one warning
* Update __init__ files to fix repo-consistency
* fix pipeline test
* Fix some device placements
* [wip]: Update tests -- need to generate summaries to update expected_summary
* Fix quality
* Update LongT5 model card
* Update (slow) summarization tests
* make style
* rename checkpoitns
* finish
* fix flax tests

Co-authored-by: phungvanduy <[email protected]>
Co-authored-by: Sylvain Gugger <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: patil-suraj <[email protected]>
What does this PR do?
Fixes #16681
This PR adds PyTorch and Flax implementations of the LongT5 model. (A TensorFlow implementation is omitted from this PR as it would require another round of reviews; however, I'm willing to work on the TF side in another PR as well.) The model is added according to the original Google paper LongT5: Efficient Text-To-Text Transformer for Long Sequences.

- PyTorch implementation
- Flax implementation

t5x - HF equivalence

Model equivalence is investigated in my repo here (there's currently a glitch regarding the calculation of `side_position_bias`).

Other features

Original checkpoints converted to the HF format can be temporarily found on the HF hub:
Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@patrickvonplaten @patil-suraj
More information will be added.