
📦 Packing with flash attn kwargs to avoid cross-contamination #3526


Merged

Conversation

thepowerfuldeez
Contributor

@thepowerfuldeez thepowerfuldeez commented Jun 2, 2025

Requires huggingface/transformers#38536
Requires #3521

Modifies packing so that the flash attention kernel is aware of sequence boundaries, which leads to improved sparsity and quality. Works only for SFT.
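
For intuition, here is a simplified sketch (illustrative only, not code from this PR) of the attention pattern that sequence boundaries imply: packed samples share one row, but each sub-sequence only attends within its own causal block. Flash attention obtains the same effect from its varlen kwargs / position_ids without materializing such a mask.

import torch

def packed_causal_mask(seq_lengths: list[int]) -> torch.Tensor:
    # Block-diagonal causal mask for one packed row, e.g. lengths [2, 3]:
    # tokens of the second sample cannot attend to tokens of the first.
    total = sum(seq_lengths)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for length in seq_lengths:
        block = torch.tril(torch.ones(length, length, dtype=torch.bool))
        mask[start:start + length, start:start + length] = block
        start += length
    return mask

print(packed_causal_mask([2, 3]).int())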

Loss plots:
[image: loss curves]

(Blue is the new version, brown is the old one)

Achieved by this training config:

training_args = SFTConfig(
    run_name=args.run_name,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    learning_rate=1e-5,
    weight_decay=1e-7,
    max_grad_norm=1.0,
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_length=args.max_length,
    output_dir=args.output_dir,
    padding_free=True,
    eval_steps=0.05,
    logging_steps=0.01,
    save_steps=0.15,
    save_total_limit=3,
    save_strategy="steps",
    eval_strategy="steps",
    completion_only_loss=True,
    eos_token="<|im_end|>",
    include_tokens_per_second=True,
    use_liger_kernel=True,
    report_to=["wandb"],
    dataloader_num_workers=8,
    packing=True,
    eval_packing=False,
    packing_strategy="ffd",
)

@qgallouedec
Member

Super cool! Can you change the base branch?

I think it doesn't require a transformers change if we play with the signature columns, let me check

@thepowerfuldeez
Contributor Author

I missed the signature columns, so now it should work :) I updated the plots; now the difference is visible.
I haven't found any data collators in trl, so I believe we would need to modify transformers in that case!

@thepowerfuldeez thepowerfuldeez changed the base branch from main to ffd_pack June 2, 2025 18:49
@qgallouedec qgallouedec deleted the branch huggingface:main June 2, 2025 20:15
@qgallouedec qgallouedec closed this Jun 2, 2025
@thepowerfuldeez
Contributor Author

@qgallouedec needs to re-open with main branch now

@qgallouedec
Member

Wait, it wasn't automatically rebased? That's what usually happens.🤨

@qgallouedec qgallouedec reopened this Jun 2, 2025
@qgallouedec qgallouedec changed the base branch from ffd_pack to main June 2, 2025 22:25
@@ -478,6 +479,67 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
return examples


@numba.njit(["(int32[:], int32)", "(int64[:], int64)"], cache=True)
Member

Do you mind reverting this change, so that we keep the PRs separate?

Member

Actually, that might not be that easy; your two PRs are quite intertwined, right?

Contributor Author

Yep, but I will think about reverting; sorry that I haven't done that before.
It should be possible.

Contributor Author

done!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 560 to 563
packed_columns.append(column)
packed_columns.append(sequence_lengths)
else:
packed_columns.append(column)
Member

This should work as well?

Suggested change
- packed_columns.append(column)
- packed_columns.append(sequence_lengths)
- else:
- packed_columns.append(column)
+ packed_columns.append(sequence_lengths)
+ packed_columns.append(column)

Member

@qgallouedec qgallouedec Jun 3, 2025

Actually, it seems like sequence_lengths can be appended more than once if you have more than one column that matches if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): (which is almost always the case)

Contributor Author

Ah yeah, makes sense, and I also think the ordering matters, right? We could just ensure that we add it once, and at the end, I guess?

Contributor Author

updated
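
For illustration, a minimal sketch of the "add it once, at the end" idea (hypothetical helper, not the code in this PR):

import pyarrow as pa
import pyarrow.compute as pc

def append_sequence_lengths(packed_columns: dict[str, pa.Array]) -> dict[str, pa.Array]:
    # Hypothetical helper: derive sequence lengths from the first list-typed
    # column and add them as a single extra column, after everything else,
    # so they cannot be appended several times when many columns are list-typed.
    sequence_lengths = None
    for column in packed_columns.values():
        if pa.types.is_list(column.type) or pa.types.is_large_list(column.type):
            sequence_lengths = pc.list_value_length(column)
            break
    out = dict(packed_columns)
    if sequence_lengths is not None:
        out["sequence_length"] = sequence_lengths  # added exactly once, at the end
    return out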

Comment on lines 695 to 697
# For the packing case with FFD, we need to store sequence_length returned by the data collator with flattening
if self.args.packing and self.args.packing_strategy == "ffd" and self.args.padding_free:
self._signature_columns.append("sequence_length")
Member

nice!

@qgallouedec
Member

It seems like it doesn't work, and I'm not sure why exactly. It seems easier to directly build position_ids from packing, I think; I'm trying this approach in another branch.
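
For illustration, a minimal sketch of that idea (not the code from the other branch): build position_ids that restart at 0 at every packed-sequence boundary.

import torch

def position_ids_from_lengths(seq_lengths: list[int]) -> torch.Tensor:
    # Positions restart at 0 for each packed sub-sequence,
    # e.g. lengths [3, 2] -> tensor([0, 1, 2, 0, 1]).
    return torch.cat([torch.arange(length) for length in seq_lengths])

print(position_ids_from_lengths([3, 2]))  # tensor([0, 1, 2, 0, 1])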

@@ -293,7 +293,7 @@ def __init__(
if args.padding_free:
if data_collator is not None:
raise ValueError("Passing a custom data collator is not supported when using padding-free.")
- if args.packing:
+ if args.packing and args.packing_strategy != "ffd":
Member

Suggested change
- if args.packing and args.packing_strategy != "ffd":
+ if args.packing:

Actually, padding_free is a different method (which is less relevant once we have FFD with flash-attn).

Contributor Author

You mean that if we use packing with flash attn, we are inherently already using padding-free, since we pass position_ids?
We could probably change the warning to say something like "when FFD packing is enabled, the model accepts position_ids, which works the same way as padding-free, so specifying padding_free has no effect".

Comment on lines 316 to 327
if args.packing and model.config._attn_implementation != "flash_attention_2":
warnings.warn(
"You are using packing with padding-free training, but the attention implementation is not set to "
"'flash_attention_2'. Packing flattens batches into a single sequence, and 'flash_attention_2' is "
"the only known attention mechanism that reliably supports this. Using other implementations may "
"lead to unexpected behavior. To ensure compatibility, set `attn_implementation='flash_attention_2'` "
"in the model configuration."
)
data_collator = DataCollatorWithFlattening(
return_flash_attn_kwargs=False,
return_position_ids=True,
)
Member

I suggest dedenting this

Contributor Author

Not sure what would happen if the user does not specify the flash attention 2 implementation; right now we just won't use padding-free with position ids.
If we dedent, we enable DataCollatorWithFlattening in all cases?

@thepowerfuldeez
Contributor Author

Ah, I see the problem. After the recent Attention refactoring, all models still accept position_ids as an argument.
And even if we pass position_ids to model.forward, it doesn't go to FA2 directly.

After investigating, here's the attention flow:

  1. Inside the model class, the attention interface is set as
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  2. For FA2, this resolves to flash_attention_forward inside transformers.integrations.flash_attention
  3. flash_attention_forward does not accept position_ids directly, only through kwargs
  4. flash_attention_forward calls _flash_attention_forward inside transformers.modeling_flash_attention_utils
  5. If position_ids is provided here and attention_mask is not provided, we will generate padding-free masks: link

The problem, however, is that for some models attention_mask is never None.
For example, for Qwen3 we pass attention_mask depending on the attention type (full / sliding), and it's always a causal attention mask.

I have a hack locally for going into the flash_attn_varlen branch inside flash_attention_forward if we have position_ids (so that position_ids take priority over attention_mask).
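
For illustration, a simplified sketch (not the actual transformers code) of what step 5 boils down to: recover the sequence boundaries from the points where position_ids reset to 0, which gives the cumulative sequence lengths the varlen kernel needs.

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids for one flattened batch, e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3].
    # Every reset to 0 marks the start of a new packed sub-sequence.
    starts = torch.nonzero(position_ids == 0).flatten()
    total = torch.tensor([position_ids.numel()], dtype=starts.dtype)
    # Cumulative sequence lengths, in the format flash_attn_varlen_func expects.
    return torch.cat([starts, total]).to(torch.int32)

print(cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])))
# tensor([0, 3, 5, 9], dtype=torch.int32)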

@thepowerfuldeez
Contributor Author

thepowerfuldeez commented Jun 5, 2025

In addition, the added sequence_length is incorrect after resolving the conflict, @qgallouedec.
I got

ArrowInvalid: Column 1 named sequence_length expected length 436 but got length 874

Reverting helped

@qgallouedec
Member

For reference: thepowerfuldeez#1

@thepowerfuldeez
Contributor Author

False alarm. I just tested training with the aforementioned branch + transformers main, and we are calling flash_attention_varlen using position_ids. Nice and elegant solution: preparing position_ids during packing and flattening with the existing data collator.

However, this approach uses more GPU memory, which is surprising.

@qgallouedec
Member

Experiments

Packing vs no packing

3.7x speedup

[screenshots: training curves, packing vs. no packing]
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

dataset = load_dataset("trl-lib/tldr", split="train[:2000]")

def concat(example):
    # batched map: concatenate each prompt with its completion element-wise
    return {"text": [p + c for p, c in zip(example["prompt"], example["completion"])]}
dataset = dataset.map(concat, batched=True, remove_columns=dataset.column_names)

packing = True
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=SFTConfig(
        model_init_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"},
        packing=packing,
        run_name="packing" if packing else "no_packing",
        gradient_accumulation_steps=1 if packing else 6,  # to keep tokens per optimization step comparable
    ),
    train_dataset=dataset,
)
trainer.train()

Flash attn kwargs vs no flash attn kwargs

The curves are very close, which seems normal to me.

The advantage here is that you avoid contamination between samples, which is not trivial to observe.

[screenshot: loss curves, with vs. without flash attn kwargs]

For this one, I just tweaked the code to remove position_ids.

@qgallouedec
Member

I'm realising that using position_ids when the attention implementation is not flash-attn hurts the results quite a lot. I'm adding a way to use position_ids or not.

@qgallouedec
Member

Now it looks good! (Purple is the annoying setting where position_ids are passed even when flash-attn is disabled; not the case anymore after 9f4d9ee.)

[screenshot: loss curves]

@qgallouedec qgallouedec changed the title Packing with flash attn kwargs 📦 Packing with flash attn kwargs to avoid cross-contamination Jun 6, 2025
@qgallouedec
Member

Thanks a lot @thepowerfuldeez, I'll merge now, but feel free to make additional comments

@jiosephlee

jiosephlee commented Jul 5, 2025

Hi @qgallouedec @thepowerfuldeez!

I just wanted a clarification on how packing is implemented with max_length. If a sample only partially fits, will it be forced in with truncation, or will it go into the next sequence? I know the non-packing max_length behavior is truncation. If it doesn't truncate, how does packing handle whole samples that are longer than the set max_length?
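
For context on what the ffd packing strategy refers to, a toy first-fit-decreasing sketch over sequence lengths (illustrative only; it doesn't settle the truncation question above):

def ffd_pack(lengths: list[int], capacity: int) -> list[list[int]]:
    # First-fit decreasing: sort lengths in descending order, place each one
    # into the first bin where it still fits, otherwise open a new bin.
    bins: list[list[int]] = []
    for length in sorted(lengths, reverse=True):
        for b in bins:
            if sum(b) + length <= capacity:
                b.append(length)
                break
        else:
            bins.append([length])  # new bin; how over-long samples are handled is not shown here
    return bins

print(ffd_pack([7, 2, 5, 3, 6], capacity=10))  # [[7, 3], [6, 2], [5]]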
