Merged

Commits (31 total; the diff below shows changes from 12 commits)
- c0925be: new pack strat (qgallouedec, Jun 1, 2025)
- 8eef518: sft (qgallouedec, Jun 2, 2025)
- 7e12779: improve implementation, deprecated pack_examples and test (qgallouedec, Jun 2, 2025)
- 22cad01: fix test (qgallouedec, Jun 2, 2025)
- bfae377: use numba for ffd packing (thepowerfuldeez, Jun 2, 2025)
- 4a73543: update packing algorithm to support sequence_length (thepowerfuldeez, Jun 2, 2025)
- d72c49b: Merge branch 'main' into ffd_pack (qgallouedec, Jun 2, 2025)
- 9c8bc00: fix signature columns and correct setting of return flash attention k… (thepowerfuldeez, Jun 2, 2025)
- aae6b35: Merge branch 'ffd_pack' into packing_with_flash_attn_kwargs (thepowerfuldeez, Jun 2, 2025)
- 75196e1: Merge branch 'main' of github.com:huggingface/trl into packing_with_f… (thepowerfuldeez, Jun 2, 2025)
- bb4951d: fix merge (thepowerfuldeez, Jun 2, 2025)
- a533912: Empty commit (qgallouedec, Jun 2, 2025)
- bf04c38: revert numba related changes (thepowerfuldeez, Jun 3, 2025)
- c3d4076: Merge branch 'main' into packing_with_flash_attn_kwargs (thepowerfuldeez, Jun 3, 2025)
- 8b6d5a9: fix adding sequence lengths (thepowerfuldeez, Jun 3, 2025)
- 3448505: Merge branch 'packing_with_flash_attn_kwargs' of github.com:thepowerf… (thepowerfuldeez, Jun 3, 2025)
- f994a38: Merge branch 'main' into packing_with_flash_attn_kwargs (thepowerfuldeez, Jun 4, 2025)
- 4b36271: resolve conflict (qgallouedec, Jun 4, 2025)
- 9da9c1a: Merge main (qgallouedec, Jun 4, 2025)
- d2f5d93: Merge branch 'main' into packing_with_flash_attn_kwargs (qgallouedec, Jun 4, 2025)
- 15ed05a: position_ids in pack (qgallouedec, Jun 5, 2025)
- b51a865: collate with position_ids (qgallouedec, Jun 5, 2025)
- a4c39ef: Add padding-free option to DataCollatorForLanguageModeling and refact… (qgallouedec, Jun 5, 2025)
- 8cb93b2: signature columns must be position_ids (thepowerfuldeez, Jun 5, 2025)
- 69f21ff: Merge branch 'position_ids_in_pack' of github.com:huggingface/trl int… (thepowerfuldeez, Jun 5, 2025)
- ac131a5: do not remove position_ids when using liger, add comment (thepowerfuldeez, Jun 5, 2025)
- 115137a: improve collator tests (qgallouedec, Jun 6, 2025)
- 30fed95: Drop DataCollatorWithFlattening (qgallouedec, Jun 6, 2025)
- a24c0ee: clarify sft arg (qgallouedec, Jun 6, 2025)
- b97159e: return_position_ids (qgallouedec, Jun 6, 2025)
- 9f4d9ee: return_position_ids (qgallouedec, Jun 6, 2025)
132 changes: 85 additions & 47 deletions trl/data_utils.py
@@ -17,6 +17,7 @@
from collections.abc import Sequence
from typing import Any, Callable, Optional, TypeVar, Union

import numba
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
@@ -478,6 +479,67 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
return examples


@numba.njit(["(int32[:], int32)", "(int64[:], int64)"], cache=True)
Member: do you mind reverting this change, so that we keep the PRs separate?

Member: actually, that might not be that easy; your two PRs are quite intertwined, right?

Contributor Author: yep, but I will think about reverting, sorry that I haven't done that before. This should be possible.

Contributor Author: done!

def _pack_sequences_ffd_core(seq_lens: np.ndarray, seq_length: int) -> tuple[np.ndarray, np.ndarray]:
"""First Fit Decreasing bin packing algorithm.

Args:
seq_lens: Array of sequence lengths to pack
seq_length: Target sequence length for each bin (maximum capacity)

Returns:
tuple of (bin_assignments, bin_sizes) where:
- bin_assignments[i] is the bin index for sequence i
- bin_sizes[j] is the total packed length of bin j
"""
n_sequences = len(seq_lens)
sorted_indices = np.argsort(-seq_lens)

# seq_idx -> bin_idx (initialize to -1)
bin_assignments = np.full(n_sequences, -1, dtype=seq_lens.dtype)
# bin_idx -> remaining space
bin_remaining_space = np.empty(n_sequences, dtype=seq_lens.dtype)
bin_count = 0

for seq_idx in sorted_indices:
seq_len = seq_lens[seq_idx]

# Find best‐fit bin in a single loop:
best_bin_idx = -1
best_waste = seq_length + 1

for bin_idx in range(bin_count):
remaining = bin_remaining_space[bin_idx]
if remaining >= seq_len:
waste = remaining - seq_len
if waste < best_waste:
best_waste = waste
best_bin_idx = bin_idx
if waste == 0:
# perfect fit—no need to keep searching
break

if best_bin_idx >= 0:
bin_assignments[seq_idx] = best_bin_idx
bin_remaining_space[best_bin_idx] -= seq_len
else:
# create a new bin
bin_assignments[seq_idx] = bin_count
bin_remaining_space[bin_count] = seq_length - seq_len
bin_count += 1

bin_sizes = seq_length - bin_remaining_space[:bin_count]
return bin_assignments, bin_sizes


# Warm up the function at module import to avoid first-use compilation
try:
# Pre-warm with small arrays to trigger compilation
_pack_sequences_ffd_core(np.array([10, 20, 30], dtype=np.int32), 50)
_pack_sequences_ffd_core(np.array([10, 20, 30], dtype=np.int64), 50)
except:
pass # Ignore any compilation errors during import
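For readers unfamiliar with the strategy, here is a minimal pure-Python sketch of decreasing-order bin packing on a toy input. It is illustrative only and independent of the numba kernel above; note that `_pack_sequences_ffd_core` places each sequence in the best-fitting open bin, while this sketch uses plain first-fit placement.

```python
import numpy as np

def ffd_pack_sketch(seq_lens, capacity):
    """Place each sequence, longest first, into the first open bin with room."""
    order = np.argsort(-np.asarray(seq_lens))
    remaining = []     # remaining capacity per bin
    assignment = {}    # sequence index -> bin index
    for idx in order.tolist():
        length = seq_lens[idx]
        for bin_idx, space in enumerate(remaining):
            if space >= length:
                remaining[bin_idx] -= length
                assignment[idx] = bin_idx
                break
        else:  # no open bin fits: open a new one
            remaining.append(capacity - length)
            assignment[idx] = len(remaining) - 1
    return assignment, [capacity - r for r in remaining]

# Lengths [5, 3, 2, 4] with capacity 6 pack into three bins:
# bin 0 holds the length-5 sequence, bin 1 holds lengths 4 and 2, bin 2 holds length 3.
print(ffd_pack_sketch([5, 3, 2, 4], 6))  # ({0: 0, 3: 1, 1: 2, 2: 1}, [5, 6, 3])
```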


def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
"""Pack sequences in a pyarrow Table using First Fit Decreasing strategy."""
packed_columns = []
@@ -498,66 +560,41 @@ def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
truncated_lens = np.minimum(seq_lens, seq_length)
truncated_ends = starts + truncated_lens

# Create sequences list with truncated values
sequences = list(zip(truncated_lens, starts, truncated_ends))

# Sort by length (decreasing) for First Fit Decreasing
sequences.sort(key=lambda x: x[0], reverse=True)

# Optimized bin packing using a priority queue approach
bins_by_remaining = defaultdict(list) # remaining_space -> [bin_indices]
bins = [] # [(current_length, seq_indices)]

for i, (seq_len, _start, _end) in enumerate(sequences):
# Find bins with enough space using the dictionary
placed = False
for remaining in range(seq_len, seq_length + 1):
if bins_by_remaining[remaining]:
# Use the first available bin with this remaining space
bin_idx = bins_by_remaining[remaining].pop()
current_len, seq_indices = bins[bin_idx]

# Update bin
new_len = current_len + seq_len
new_remaining = seq_length - new_len
bins[bin_idx] = (new_len, seq_indices + [i])

# Update the remaining space mapping
if new_remaining > 0:
bins_by_remaining[new_remaining].append(bin_idx)

placed = True
break

# If no bin fits, create new bin
if not placed:
bin_idx = len(bins)
bins.append((seq_len, [i]))
remaining = seq_length - seq_len
if remaining > 0:
bins_by_remaining[remaining].append(bin_idx)
bin_assignments, bin_sizes = _pack_sequences_ffd_core(truncated_lens, seq_length)

# Reconstruct packed values more efficiently
values_numpy = values.to_numpy()
packed_values = []
new_offsets = [0]

for _, seq_indices in bins:
packed_values: list[np.dtype] = []
new_offsets: list[int] = [0]
sequence_lengths: list[list[int]] = []

# Group sequences by bin assignment and concatenate them
for bin_idx in range(len(bin_sizes)):
# Find all sequences assigned to this bin
seq_indices = np.where(bin_assignments == bin_idx)[0]
seq_lens_in_bin = []
for seq_idx in seq_indices:
_, start, end = sequences[seq_idx]
start = starts[seq_idx]
end = truncated_ends[seq_idx]
packed_values.extend(values_numpy[start:end])
seq_lens_in_bin.append(end - start)
sequence_lengths.append(seq_lens_in_bin)
new_offsets.append(len(packed_values))

dtype = offsets.type.to_pandas_dtype()
new_offsets = np.array(new_offsets, dtype=dtype)
packed_values = pa.array(packed_values, type=values.type)
sequence_lengths = pa.array(sequence_lengths, type=pa.list_(pa.int32()))
column = type(column).from_arrays(new_offsets, packed_values)
packed_columns.append(column)
return pa.Table.from_arrays(packed_columns, names=examples.column_names)
packed_columns.append(column)
packed_columns.append(sequence_lengths)
else:
packed_columns.append(column)
Member: This should work as well? Suggested change: replace

packed_columns.append(column)
packed_columns.append(sequence_lengths)
else:
packed_columns.append(column)

with

packed_columns.append(sequence_lengths)
packed_columns.append(column)

Member (qgallouedec, Jun 3, 2025): actually it seems like sequence_lengths can be appended more than once, if you have more than one column that matches `if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):` (which is almost always the case).

Contributor Author: ah yes, that makes sense, and I also think the ordering matters, right? We could just ensure that we add it once, and at the end, I guess?

Contributor Author: updated

return pa.Table.from_arrays(packed_columns, names=examples.column_names + ["sequence_length"])
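As a sanity check on the shape of the output, the following standalone snippet (not from the PR) builds the same kind of packed list column plus a sequence_length column by hand; it mirrors the structure that _pack_ffd assembles from offsets and values.

```python
import numpy as np
import pyarrow as pa

# Nine token ids packed into two rows via explicit offsets [0, 4, 9].
values = pa.array([1, 2, 3, 9, 6, 7, 8, 4, 5], type=pa.int64())
offsets = np.array([0, 4, 9], dtype=np.int32)
packed = pa.ListArray.from_arrays(offsets, values)

# Per-row lengths of the original sequences that ended up in each packed row.
sequence_length = pa.array([[3, 1], [3, 2]], type=pa.list_(pa.int32()))

table = pa.Table.from_arrays([packed, sequence_length],
                             names=["input_ids", "sequence_length"])
print(table.to_pydict())
# {'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]], 'sequence_length': [[3, 1], [3, 2]]}
```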


def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
"""Pack sequences in a pyarrow Table using a wrapped strategy."""
"""Pack sequences in a pyarrow Table using fixed-length packing."""
packed_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
@@ -611,7 +648,8 @@ def pack_dataset(
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="ffd")
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]],
'sequence_length': [[3, 1], [3, 2]]}
```
"""
if map_kwargs is None:
7 changes: 7 additions & 0 deletions trl/trainer/sft_config.py
@@ -146,6 +146,13 @@ class SFTConfig(TrainingArguments):
"`'wrapped'`."
},
)
packing_strategy: str = field(
default="ffd",
metadata={
"help": "Strategy for packing sequences. Can be either `'ffd'` (first-fit decreasing, default), or "
"`'fixed'`."
},
)
padding_free: bool = field(
default=False,
metadata={
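Before moving on to the trainer changes, here is a hedged usage sketch (not part of the diff) of how the new option would be set from user code. The fields `packing`, `packing_strategy`, and `max_length` are those referenced in this PR; `output_dir` and `per_device_train_batch_size` come from the inherited TrainingArguments.

```python
from trl import SFTConfig

# Illustrative configuration only; exact defaults may differ between versions.
config = SFTConfig(
    output_dir="sft-output",
    max_length=2048,                # packing requires a max_length
    packing=True,                   # pack several examples into each row
    packing_strategy="ffd",         # first-fit-decreasing bins (new default in this PR)
    per_device_train_batch_size=2,  # batch size > 1 recommended for padding-free training
)
```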
29 changes: 23 additions & 6 deletions trl/trainer/sft_trainer.py
@@ -294,7 +294,7 @@ def __init__(
if args.padding_free:
if data_collator is not None:
raise ValueError("Passing a custom data collator is not supported when using padding-free.")
if args.packing:
if args.packing and args.packing_strategy != "ffd":
Member: Suggested change: replace `if args.packing and args.packing_strategy != "ffd":` with `if args.packing:`. Actually, padding_free is a different method (which is less relevant once we have FFD with flash-attn).

Contributor Author: You mean that if we use packing with flash-attn, we are inherently already using padding-free, because we pass position_ids? We could probably change the warning to say something like "when FFD packing is enabled, the model receives position_ids, which works the same way as padding-free, so specifying padding_free has no effect".

warnings.warn(
"You are passing `packing=True` and `padding_free=True` which is not recommended. Please refer "
"to the documentation to understand why this is not recommended."
@@ -314,7 +314,18 @@
of 1 annihilates the benefits of padding-free training. Please consider increasing the batch size
"to at least 2."
)
data_collator = DataCollatorWithFlattening()
if args.packing and model.config._attn_implementation != "flash_attention_2":
warnings.warn(
"You are using packing with padding-free training, but the attention implementation is not set to "
"'flash_attention_2'. Packing flattens batches into a single sequence, and 'flash_attention_2' is "
"the only known attention mechanism that reliably supports this. Using other implementations may "
"lead to unexpected behavior. To ensure compatibility, set `attn_implementation='flash_attention_2'` "
"in the model configuration."
)
data_collator = DataCollatorWithFlattening(
return_flash_attn_kwargs=False,
return_position_ids=True,
)
Member: I suggest dedenting this.

Contributor Author: Not sure what would happen if the user does not specify the flash_attention_2 implementation; right now we just won't use padding-free with position_ids. If we dedent, do we enable DataCollatorWithFlattening in all cases?
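To make the discussion above concrete, here is a standalone sketch (not the TRL collator itself) of what returning position_ids for a flattened, padding-free batch means: examples are concatenated into a single row and position_ids restart at 0 for every packed sequence, which FlashAttention-2 can use to keep attention from crossing sequence boundaries.

```python
import torch

def flatten_with_position_ids(examples):
    """Concatenate examples into one row; position_ids restart at 0 for each
    original sequence, marking the packing boundaries."""
    input_ids, position_ids = [], []
    for ids in examples:
        input_ids.extend(ids)
        position_ids.extend(range(len(ids)))
    return {
        "input_ids": torch.tensor([input_ids]),
        "position_ids": torch.tensor([position_ids]),
    }

batch = flatten_with_position_ids([[101, 7, 8, 102], [101, 9, 102]])
print(batch["position_ids"])  # tensor([[0, 1, 2, 3, 0, 1, 2]])
```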


if args.completion_only_loss is None:
first_example = next(iter(train_dataset))
@@ -659,15 +670,18 @@ def tokenize(example, processing_class, dataset_text_field, add_special_tokens):
raise ValueError("When packing is enabled, `max_length` can't be `None`.")
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Packing {dataset_name} dataset"
dataset = dataset.select_columns("input_ids")
dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
dataset: Dataset = dataset.select_columns("input_ids")
dataset: Dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
elif args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
dataset: Dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
# For Liger kernel, ensure only input_ids is present
if args.use_liger_kernel:
dataset = dataset.select_columns("input_ids")
if "sequence_length" in dataset.column_names:
dataset: Dataset = dataset.select_columns(["input_ids", "sequence_length"])
else:
dataset: Dataset = dataset.select_columns("input_ids")

return dataset

@@ -678,6 +692,9 @@ def _set_signature_columns_if_needed(self):
# dataset. So we need to override the default signature columns to include "completion_mask" as well.
if self._signature_columns is None:
self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]
# For the packing case with FFD, we need to store sequence_length returned by the data collator with flattening
if self.args.packing and self.args.packing_strategy == "ffd" and self.args.padding_free:
self._signature_columns.append("sequence_length")
Member: nice!


def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""