📉 FFD packing #3521
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Great work!
Awesome work!
I believe there are two things that still need to be implemented in future PRs to make packing comparable to non-packing:
- Correct sequence IDs: currently the sequence IDs are created on the fly in the modelling code, so we do not account for the actual positions of the packed sequences. These could be precomputed and passed to the model.
- 4D attention masks: many transformers implementations now support 4D attention masks in their model signature. This would mitigate potential issues from cross-attention between packed sequences.
To provide more details on the attention masks: flash attention, for example, requires cumulative_seqlens, which correspond to the start and end indices of each sequence in the (unflattened) batch; these can be passed as flash_attn_kwargs.
Links to examples in the Qwen modelling code, for reference:
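Since those links aren't reproduced here, below is a rough, hypothetical sketch (not the Qwen or TRL code) of how position IDs and cumulative sequence lengths could be precomputed from the lengths of the sequences packed into one row:

```python
# Illustrative sketch only (not the Qwen/TRL implementation): given the lengths
# of the sequences packed into one row, precompute position_ids that restart at
# 0 for each sequence, and the cumulative sequence lengths used by varlen
# flash-attention kernels to prevent cross-attention between packed sequences.
import torch

def precompute_packing_metadata(seq_lens):
    # e.g. seq_lens = [3, 2] -> position_ids = [0, 1, 2, 0, 1]
    position_ids = torch.cat([torch.arange(l) for l in seq_lens])
    # e.g. seq_lens = [3, 2] -> cu_seqlens = [0, 3, 5] (sequence boundaries)
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(seq_lens), dim=0)
    return position_ids, cu_seqlens

position_ids, cu_seqlens = precompute_packing_metadata([3, 2])
print(position_ids)  # tensor([0, 1, 2, 0, 1])
print(cu_seqlens)    # tensor([0, 3, 5], dtype=torch.int32)
```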
I could take a look at passing kwargs to the flash attn kernel + making packing more efficient today! Great PR, long awaited.
This PR introduces a new packing strategy, FFD (First Fit Decreasing).
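For readers not familiar with it, first fit decreasing is a classic bin-packing heuristic: sort the sequences by decreasing length, then place each one into the first bin (packed row) with enough remaining capacity. A minimal, self-contained sketch over sequence lengths (illustrative only, not the exact implementation in this PR):

```python
# Minimal first-fit-decreasing (FFD) sketch over sequence lengths.
# Illustrative only; not the exact implementation in this PR.

def ffd_pack(seq_lens, max_length):
    """Group sequence lengths into bins of capacity `max_length` using FFD."""
    bins = []          # remaining capacity of each bin
    assignments = []   # indices of the sequences placed in each bin
    # Visit sequences by decreasing length ("Decreasing" in FFD).
    for idx in sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True):
        length = seq_lens[idx]
        if length > max_length:
            continue  # cannot fit in any bin: such sequences are discarded
        # Place the sequence in the first bin where it fits ("First Fit").
        for b, capacity in enumerate(bins):
            if length <= capacity:
                bins[b] -= length
                assignments[b].append(idx)
                break
        else:
            # No existing bin can hold it: open a new bin.
            bins.append(max_length - length)
            assignments.append([idx])
    return assignments

# Example: packing lengths [7, 2, 5, 4, 3] into bins of capacity 8
print(ffd_pack([7, 2, 5, 4, 3], max_length=8))  # [[0], [2, 4], [3, 1]]
```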
Advantages:
Drawbacks:
- Sequences longer than `max_length` are discarded
Benchmark
Speed
Time to pack a dataset containing 100k rows (hardly correlated to `max_length`).
So it's way slower (~30 times) but still very reasonable.
Code used
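The snippet behind "Code used" is collapsed in the original PR and not reproduced here. As a stand-in, a timing harness along these lines (reusing the toy `ffd_pack` sketch above, with random lengths instead of a real 100k-row dataset) gives an idea of how such a measurement could be made:

```python
# Hypothetical timing harness (not the author's benchmark code).
# Uses the toy ffd_pack sketch from above and random lengths instead of a real
# dataset; the row count is scaled down from 100k so the pure-Python toy
# implementation finishes quickly.
import random
import time

random.seed(0)
seq_lens = [random.randint(1, 1024) for _ in range(10_000)]

start = time.perf_counter()
bins = ffd_pack(seq_lens, max_length=2048)
elapsed = time.perf_counter() - start
print(f"packed {len(seq_lens)} rows into {len(bins)} bins in {elapsed:.2f}s")
```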
Padding tokens efficiency
I compared the number of padding tokens (the fewer the better) that we ended up with for different datasets and different sequence lengths.
Code used
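Again, the collapsed snippet isn't reproduced here; a hypothetical way to count padding tokens for a given packing (assuming each packed row is padded up to `max_length`, and reusing the toy `ffd_pack` output from earlier):

```python
# Hypothetical sketch (not the author's benchmark code): count padding tokens
# for a packing, assuming every packed row is padded up to max_length.

def count_padding_tokens(bins, seq_lens, max_length):
    """`bins` maps each packed row to the indices of the sequences it holds."""
    return sum(max_length - sum(seq_lens[i] for i in bin_indices) for bin_indices in bins)

# Example, using the toy ffd_pack output from earlier:
seq_lens = [7, 2, 5, 4, 3]
bins = [[0], [2, 4], [3, 1]]  # ffd_pack(seq_lens, max_length=8)
print(count_padding_tokens(bins, seq_lens, max_length=8))  # 3 padding tokens
```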