📦 Packing with flash attn kwargs to avoid cross-contamination #3526
@@ -17,6 +17,7 @@
from collections.abc import Sequence
from typing import Any, Callable, Optional, TypeVar, Union

import numba
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
@@ -478,6 +479,67 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
    return examples


@numba.njit(["(int32[:], int32)", "(int64[:], int64)"], cache=True)
def _pack_sequences_ffd_core(seq_lens: np.ndarray, seq_length: int) -> tuple[np.ndarray, np.ndarray]:
    """First Fit Decreasing bin packing algorithm.

    Args:
        seq_lens: Array of sequence lengths to pack
        seq_length: Target sequence length for each bin (maximum capacity)

    Returns:
        tuple of (bin_assignments, bin_sizes) where:
        - bin_assignments[i] is the bin index for sequence i
        - bin_sizes[j] is the total packed length of bin j
    """
    n_sequences = len(seq_lens)
    sorted_indices = np.argsort(-seq_lens)

    # seq_idx -> bin_idx (initialize to -1)
    bin_assignments = np.full(n_sequences, -1, dtype=seq_lens.dtype)
    # bin_idx -> remaining space
    bin_remaining_space = np.empty(n_sequences, dtype=seq_lens.dtype)
    bin_count = 0

    for seq_idx in sorted_indices:
        seq_len = seq_lens[seq_idx]

        # Find best-fit bin in a single loop:
        best_bin_idx = -1
        best_waste = seq_length + 1

        for bin_idx in range(bin_count):
            remaining = bin_remaining_space[bin_idx]
            if remaining >= seq_len:
                waste = remaining - seq_len
                if waste < best_waste:
                    best_waste = waste
                    best_bin_idx = bin_idx
                    if waste == 0:
                        # perfect fit, no need to keep searching
                        break

        if best_bin_idx >= 0:
            bin_assignments[seq_idx] = best_bin_idx
            bin_remaining_space[best_bin_idx] -= seq_len
        else:
            # create a new bin
            bin_assignments[seq_idx] = bin_count
            bin_remaining_space[bin_count] = seq_length - seq_len
            bin_count += 1

    bin_sizes = seq_length - bin_remaining_space[:bin_count]
    return bin_assignments, bin_sizes


# Warm up the function at module import to avoid first-use compilation
try:
    # Pre-warm with small arrays to trigger compilation
    _pack_sequences_ffd_core(np.array([10, 20, 30], dtype=np.int32), 50)
    _pack_sequences_ffd_core(np.array([10, 20, 30], dtype=np.int64), 50)
except:
    pass  # Ignore any compilation errors during import


def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
    """Pack sequences in a pyarrow Table using First Fit Decreasing strategy."""
    packed_columns = []
@@ -498,66 +560,41 @@ def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
            truncated_lens = np.minimum(seq_lens, seq_length)
            truncated_ends = starts + truncated_lens

            # Create sequences list with truncated values
            sequences = list(zip(truncated_lens, starts, truncated_ends))

            # Sort by length (decreasing) for First Fit Decreasing
            sequences.sort(key=lambda x: x[0], reverse=True)

            # Optimized bin packing using a priority queue approach
            bins_by_remaining = defaultdict(list)  # remaining_space -> [bin_indices]
            bins = []  # [(current_length, seq_indices)]

            for i, (seq_len, _start, _end) in enumerate(sequences):
                # Find bins with enough space using the dictionary
                placed = False
                for remaining in range(seq_len, seq_length + 1):
                    if bins_by_remaining[remaining]:
                        # Use the first available bin with this remaining space
                        bin_idx = bins_by_remaining[remaining].pop()
                        current_len, seq_indices = bins[bin_idx]

                        # Update bin
                        new_len = current_len + seq_len
                        new_remaining = seq_length - new_len
                        bins[bin_idx] = (new_len, seq_indices + [i])

                        # Update the remaining space mapping
                        if new_remaining > 0:
                            bins_by_remaining[new_remaining].append(bin_idx)

                        placed = True
                        break

                # If no bin fits, create new bin
                if not placed:
                    bin_idx = len(bins)
                    bins.append((seq_len, [i]))
                    remaining = seq_length - seq_len
                    if remaining > 0:
                        bins_by_remaining[remaining].append(bin_idx)
            bin_assignments, bin_sizes = _pack_sequences_ffd_core(truncated_lens, seq_length)

            # Reconstruct packed values more efficiently
            values_numpy = values.to_numpy()
            packed_values = []
            new_offsets = [0]

            for _, seq_indices in bins:
            packed_values: list[np.dtype] = []
            new_offsets: list[int] = [0]
            sequence_lengths: list[list[int]] = []

            # Group sequences by bin assignment and concatenate them
            for bin_idx in range(len(bin_sizes)):
                # Find all sequences assigned to this bin
                seq_indices = np.where(bin_assignments == bin_idx)[0]
                seq_lens_in_bin = []
                for seq_idx in seq_indices:
                    _, start, end = sequences[seq_idx]
                    start = starts[seq_idx]
                    end = truncated_ends[seq_idx]
                    packed_values.extend(values_numpy[start:end])
                    seq_lens_in_bin.append(end - start)
                sequence_lengths.append(seq_lens_in_bin)
                new_offsets.append(len(packed_values))

            dtype = offsets.type.to_pandas_dtype()
            new_offsets = np.array(new_offsets, dtype=dtype)
            packed_values = pa.array(packed_values, type=values.type)
            sequence_lengths = pa.array(sequence_lengths, type=pa.list_(pa.int32()))
            column = type(column).from_arrays(new_offsets, packed_values)
        packed_columns.append(column)
    return pa.Table.from_arrays(packed_columns, names=examples.column_names)
            packed_columns.append(column)
            packed_columns.append(sequence_lengths)
        else:
            packed_columns.append(column)
Review comment: This should work as well? Suggested change: [...]
Reply: actually it seems like [...]
Reply: ah yea, makes sense, and I also think ordering matters, right? We could just ensure that we add it once, at the end, I guess?
Reply: updated
    return pa.Table.from_arrays(packed_columns, names=examples.column_names + ["sequence_length"])


def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
    """Pack sequences in a pyarrow Table using a wrapped strategy."""
    """Pack sequences in a pyarrow Table using fixed-length packing."""
    packed_columns = []
    for column in examples.columns:
        if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
@@ -611,7 +648,8 @@ def pack_dataset(
    >>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="ffd")
    >>> packed_dataset[:]
    {'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
     'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
     'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]],
     'sequence_length': [[3, 1], [3, 2]]}
    ```
    """
    if map_kwargs is None:
@@ -294,7 +294,7 @@ def __init__(
        if args.padding_free:
            if data_collator is not None:
                raise ValueError("Passing a custom data collator is not supported when using padding-free.")
            if args.packing:
            if args.packing and args.packing_strategy != "ffd":
Review comment: Suggested change: [...] Actually, padding_free is a different method (which is less relevant once we have ffd with flash-attn).
Reply: You mean that if we use packing with flash attn, we are inherently already using padding-free? If we pass position_ids [...]
                warnings.warn(
                    "You are passing `packing=True` and `padding_free=True` which is not recommended. Please refer "
                    "to the documentation to understand why this is not recommended."

@@ -314,7 +314,18 @@
                    "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size "
                    "to at least 2."
                )
            data_collator = DataCollatorWithFlattening()
            if args.packing and model.config._attn_implementation != "flash_attention_2":
                warnings.warn(
                    "You are using packing with padding-free training, but the attention implementation is not set to "
                    "'flash_attention_2'. Packing flattens batches into a single sequence, and 'flash_attention_2' is "
                    "the only known attention mechanism that reliably supports this. Using other implementations may "
                    "lead to unexpected behavior. To ensure compatibility, set `attn_implementation='flash_attention_2'` "
                    "in the model configuration."
                )
            data_collator = DataCollatorWithFlattening(
                return_flash_attn_kwargs=False,
                return_position_ids=True,
            )
Review comment: I suggest dedenting this.
Reply: Not sure what would happen if the user does not specify the flash_attention_2 implementation; right now we just won't use padding-free with position ids.
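For readers wondering why the collator switches to `return_position_ids=True`: with FFD packing, each row of the batch holds several independent sequences, and the position ids restart at every sequence boundary. Below is a rough, hypothetical sketch of that layout (illustrative only, not the actual collator output; it assumes a packed row whose `sequence_length` entry is `[3, 2]`, and all token values are made up):

```python
import torch

# Hypothetical illustration: one packed (flattened) row containing two sequences.
input_ids = torch.tensor([[101, 102, 103, 201, 202]])
sequence_length = [3, 2]  # per-sequence lengths inside the packed row

# Position ids restart at 0 for every sequence in the packed row: [0, 1, 2, 0, 1].
position_ids = torch.cat([torch.arange(n) for n in sequence_length]).unsqueeze(0)
print(position_ids)  # tensor([[0, 1, 2, 0, 1]])

# Flash attention 2 can recover the sequence boundaries from these restarts,
# so tokens from one packed sequence never attend to tokens from another.
```

This is the sense in which FFD packing with flash attention already behaves padding-free, which is what the earlier review exchange about `padding_free` is getting at.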
        if args.completion_only_loss is None:
            first_example = next(iter(train_dataset))

@@ -659,15 +670,18 @@ def tokenize(example, processing_class, dataset_text_field, add_special_tokens):
                raise ValueError("When packing is enabled, `max_length` can't be `None`.")
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Packing {dataset_name} dataset"
            dataset = dataset.select_columns("input_ids")
            dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
            dataset: Dataset = dataset.select_columns("input_ids")
            dataset: Dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs)
        elif args.max_length is not None:
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
            dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
            dataset: Dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
        # For Liger kernel, ensure only input_ids is present
        if args.use_liger_kernel:
            dataset = dataset.select_columns("input_ids")
            if "sequence_length" in dataset.column_names:
                dataset: Dataset = dataset.select_columns(["input_ids", "sequence_length"])
            else:
                dataset: Dataset = dataset.select_columns("input_ids")

        return dataset
@@ -678,6 +692,9 @@ def _set_signature_columns_if_needed(self):
        # dataset. So we need to override the default signature columns to include "completion_mask" as well.
        if self._signature_columns is None:
            self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]
            # For the packing case with FFD, we need to store sequence_length returned by the data collator with flattening
            if self.args.packing and self.args.packing_strategy == "ffd" and self.args.padding_free:
                self._signature_columns.append("sequence_length")
Review comment: nice!
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
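Relating this back to the PR title: the per-bin `sequence_length` lists are the raw material for the cumulative-sequence-length kwargs that variable-length flash attention consumes. The sketch below only illustrates that conversion; the helper name is hypothetical, the kwarg names assume the `FlashAttentionKwargs` convention in recent transformers releases, and none of this is claimed to be how the PR or the collator actually implements it:

```python
import torch

def to_flash_attn_kwargs(sequence_length: list[list[int]]) -> dict:
    # Illustrative helper (hypothetical, not from the PR): flatten the per-bin
    # lengths and turn them into cumulative sequence lengths for varlen attention.
    flat_lens = [n for per_bin in sequence_length for n in per_bin]
    cu_seqlens = torch.zeros(len(flat_lens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(flat_lens, dtype=torch.int32), dim=0)
    return {
        "cu_seq_lens_q": cu_seqlens,
        "cu_seq_lens_k": cu_seqlens,
        "max_length_q": max(flat_lens),
        "max_length_k": max(flat_lens),
    }

# Using the docstring example above, sequence_length == [[3, 1], [3, 2]]:
print(to_flash_attn_kwargs([[3, 1], [3, 2]])["cu_seq_lens_q"])
# tensor([0, 3, 4, 7, 9], dtype=torch.int32)
```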
Review comment: Do you mind reverting this change, so that we keep the PRs separate?
Reply: Actually, that might not be that easy; your 2 PRs are quite intertwined, right?
Reply: Yep, but I will think about reverting; sorry that I haven't done that before. This should be possible.
Reply: done!