📦 Packing with flash attn kwargs to avoid cross-contamination #3526
Conversation
Super cool! Can you change the base branch? I think it doesn't require a transformers change if we play with the signature columns, let me check.
I missed the signature columns, so now it should work :) Updated the plots; the difference is now visible.
@qgallouedec this needs to be re-opened against the main branch now.
Wait, it wasn't automatically rebased? That's what usually happens. 🤨
trl/data_utils.py
Outdated
```diff
@@ -478,6 +479,67 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
     return examples


@numba.njit(["(int32[:], int32)", "(int64[:], int64)"], cache=True)
```
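For readers unfamiliar with the strategy this helper supports, below is a minimal, illustrative sketch of first-fit-decreasing (FFD) packing in plain NumPy. The function name and signature are made up for illustration; this is not the jitted helper added in `trl/data_utils.py`.

```python
import numpy as np

def ffd_pack(lengths: np.ndarray, seq_length: int) -> list[list[int]]:
    """Toy first-fit-decreasing packing: group sequence indices into bins of capacity seq_length."""
    order = np.argsort(lengths)[::-1]   # longest sequences first
    bins: list[list[int]] = []          # indices of the sequences placed in each packed sample
    space: list[int] = []               # remaining capacity of each packed sample
    for idx in order:
        length = int(lengths[idx])
        for b, free in enumerate(space):
            if length <= free:          # first bin with enough room wins
                bins[b].append(int(idx))
                space[b] -= length
                break
        else:                           # no bin fits: open a new one
            bins.append([int(idx)])
            space.append(seq_length - length)
    return bins

print(ffd_pack(np.array([5, 3, 7, 2, 6]), seq_length=8))  # e.g. [[2], [4, 3], [0, 1]]
```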
Do you mind reverting this change, so that we keep the PRs separate?
Actually, that might not be that easy; your two PRs are quite intertwined, right?
Yep, but I will think about reverting. Sorry that I haven't done that before; it should be possible.
done!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
trl/data_utils.py
Outdated
```python
            packed_columns.append(column)
            packed_columns.append(sequence_lengths)
        else:
            packed_columns.append(column)
```
This should work as well?
Suggested change:
```diff
-            packed_columns.append(column)
-            packed_columns.append(sequence_lengths)
-        else:
-            packed_columns.append(column)
+            packed_columns.append(sequence_lengths)
+            packed_columns.append(column)
```
Actually, it seems like `sequence_lengths` can be appended more than once if you have more than one column that matches `pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type)` (which is almost always the case).
Ah yeah, that makes sense. I also think the ordering matters, right? We could just make sure we add it once, at the end, I guess?
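As a rough sketch of the "add it once, at the end" idea (illustrative only; the function and the `seq_lengths` column name are hypothetical, not the PR's actual code):

```python
import pyarrow as pa

def with_seq_lengths(table: pa.Table, sequence_lengths: pa.Array) -> pa.Table:
    # Collect the packed columns inside the per-column loop...
    packed_names, packed_columns = [], []
    for name, column in zip(table.column_names, table.columns):
        packed_names.append(name)
        packed_columns.append(column)
    # ...but append the lengths column exactly once, after the loop, so having several
    # list-typed columns (input_ids, attention_mask, labels, ...) cannot duplicate it.
    packed_names.append("seq_lengths")
    packed_columns.append(sequence_lengths)
    return pa.table(dict(zip(packed_names, packed_columns)))
```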
updated
trl/trainer/sft_trainer.py
Outdated
```python
# For the packing case with FFD, we need to store sequence_length returned by the data collator with flattening
if self.args.packing and self.args.packing_strategy == "ffd" and self.args.padding_free:
    self._signature_columns.append("sequence_length")
```
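For context, a self-contained illustration of why the append is needed, assuming the usual transformers `Trainer` behavior of dropping dataset columns whose names are not in `self._signature_columns` (which are derived from the model's forward signature). The helper and `fake_forward` below are hypothetical:

```python
import inspect

def kept_columns(forward_fn, extra_columns=()):
    # Columns the trainer keeps: the forward signature's parameter names,
    # plus anything explicitly whitelisted (e.g. "sequence_length" for FFD packing).
    cols = list(inspect.signature(forward_fn).parameters)
    cols += [c for c in extra_columns if c not in cols]
    return cols

def fake_forward(input_ids=None, attention_mask=None, position_ids=None, labels=None):
    pass

print(kept_columns(fake_forward))                                     # "sequence_length" would be dropped
print(kept_columns(fake_forward, extra_columns=["sequence_length"]))  # kept after appending it
```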
nice!
It seems like it doesn't work, I'm not sure why exactly. It seems easier to directly build…
trl/trainer/sft_trainer.py
Outdated
```diff
@@ -293,7 +293,7 @@ def __init__(
 if args.padding_free:
     if data_collator is not None:
         raise ValueError("Passing a custom data collator is not supported when using padding-free.")
-    if args.packing:
+    if args.packing and args.packing_strategy != "ffd":
```
Suggested change:
```diff
-    if args.packing and args.packing_strategy != "ffd":
+    if args.packing:
```
Actually, padding_free is a different method (which is less relevant once we have FFD with flash-attn).
You mean that if we use packing with flash attention, we are inherently already using padding-free, since we pass position_ids?
We could probably change the warning to say something like "when FFD packing is enabled, the model accepts position_ids, which works the same way as padding-free, so specifying padding_free has no effect".
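To make the position_ids point concrete, here is a small sketch of how a varlen flash-attention path can recover sequence boundaries from position_ids alone (an illustration of the idea, not transformers' actual implementation):

```python
import torch
import torch.nn.functional as F

# A flattened batch of three packed sequences of lengths 3, 2 and 4:
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])

# Every reset to 0 marks the start of a new sequence, so cumulative sequence
# lengths (cu_seqlens) can be rebuilt and no attention crosses a boundary,
# even though everything lives in a single row -- i.e. padding-free by construction.
starts = torch.nonzero(position_ids[0] == 0).flatten()
ends = torch.tensor([position_ids.shape[1]])
lengths = torch.diff(torch.cat([starts, ends]))
cu_seqlens = F.pad(torch.cumsum(lengths, dim=0), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 5, 9])
```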
trl/trainer/sft_trainer.py
Outdated
```python
if args.packing and model.config._attn_implementation != "flash_attention_2":
    warnings.warn(
        "You are using packing with padding-free training, but the attention implementation is not set to "
        "'flash_attention_2'. Packing flattens batches into a single sequence, and 'flash_attention_2' is "
        "the only known attention mechanism that reliably supports this. Using other implementations may "
        "lead to unexpected behavior. To ensure compatibility, set `attn_implementation='flash_attention_2'` "
        "in the model configuration."
    )
data_collator = DataCollatorWithFlattening(
    return_flash_attn_kwargs=False,
    return_position_ids=True,
)
```
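For reference, a hedged usage sketch of `DataCollatorWithFlattening` as configured above; the exact output keys and shapes may vary with the transformers version:

```python
from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening(return_position_ids=True, return_flash_attn_kwargs=False)
features = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "labels": [4, 5]},
]
batch = collator(features)
# The examples are concatenated into a single row, and position_ids restart at 0
# at each example boundary -- the signal the flash-attn varlen path needs.
print(batch["input_ids"].shape)  # expected: torch.Size([1, 5])
print(batch["position_ids"])     # expected: tensor([[0, 1, 2, 0, 1]])
```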
I suggest dedenting this
Not sure what would happen if the user does not specify the flash_attention_2 implementation; right now we just won't use padding-free with position_ids.
If we dedent, do we enable DataCollatorWithFlattening in all cases?
Ah, I see the problem. After the recent attention refactoring, all models still accept position_ids as an argument. After investigating, here's the attention flow.
The problem, however, is that for some models attention_mask is always not None. I have a local hack for going into the flash_attn_varlen branch inside…
In addition, adding sequence_length is incorrect after resolving the conflict, @qgallouedec.
Reverting helped.
For reference: thepowerfuldeez#1
False alarm. I just tested training with the aforementioned branch + transformers main, and we are calling flash_attention_varlen using position_ids. Nice and elegant solution: preparing position_ids during packing and flattening with the existing data collator. However, this approach uses more GPU memory, which is surprising.
I'm realising that using position_ids when the attention implementation is not flash-attn hurts the results quite a lot. I'm adding a way to choose whether to use position_ids or not.
Now it looks good! (Purple is the annoying setting where position_ids are passed even when flash-attn is disabled; not the case anymore after 9f4d9ee.)
Thanks a lot @thepowerfuldeez, I'll merge now, but feel free to make additional comments.
Hi @qgallouedec @thepowerfuldeez! I just had a clarification question about how packing is implemented with max_length. If there is a datum that can only partially fit, will it be forced in with truncation, or will it belong to the next sequence? I know the non-packing max_length behavior is truncation. If it doesn't truncate, then how does packing handle whole datums that are longer than the set max_length?
Requires huggingface/transformers#38536
Requires #3521
Modifies packing so that the flash attention kernel is aware of sequence boundaries, which leads to improved sparsity and quality. Works only for SFT.
Loss plots:
(Blue is the new version, brown is the old one)
Achieved by this training config:
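The original config is not reproduced above; as a rough, hedged sketch of the kind of SFT setup this PR targets (model and dataset names are placeholders, parameter names assume a recent TRL version):

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

config = SFTConfig(
    output_dir="sft-ffd-packing",
    packing=True,                  # pack several examples into each sample
    packing_strategy="ffd",        # first-fit-decreasing packing (the focus of this PR)
    padding_free=True,             # flatten the batch and rely on position_ids
    max_length=2048,
    model_init_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"},
)

trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=config, train_dataset=dataset)
trainer.train()
```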