Skip to content

Enable Numba for FFD packing algorithm #3524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 70 additions & 41 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Sequence
from typing import Any, Callable, Optional, TypeVar, Union

import numba
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
Expand Down Expand Up @@ -478,6 +479,67 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
return examples


@numba.njit(["(int32[:], int32)", "(int64[:], int64)"], cache=True)
def _pack_sequences_ffd_core(seq_lens: np.ndarray, seq_length: int) -> tuple[np.ndarray, np.ndarray]:
    """Assign sequences to fixed-capacity bins, visiting the longest sequences first.

    Sequences are processed in decreasing order of length. Each sequence is placed in the already-open bin
    that would be left with the least free space after the placement (i.e. a best-fit choice among the open
    bins, not the first bin that fits). If no open bin has enough room, a new bin is opened for it.

    Args:
        seq_lens: 1-D integer array (int32 or int64) of sequence lengths to pack.
        seq_length: Capacity of each bin (maximum total length per bin).

    Returns:
        Tuple `(bin_assignments, bin_sizes)` where:
        - `bin_assignments[i]` is the index of the bin that sequence `i` was placed in.
        - `bin_sizes[b]` is the total length of the sequences placed in bin `b`; its length is the number of
          bins that were opened.
    """
    n_sequences = len(seq_lens)
    # Negating the lengths makes the ascending argsort yield a longest-first visiting order.
    sorted_indices = np.argsort(-seq_lens)

    # seq_idx -> bin_idx (initialized to -1, every entry is overwritten in the loop below)
    bin_assignments = np.full(n_sequences, -1, dtype=seq_lens.dtype)
    # bin_idx -> remaining space; at most one bin per sequence, only the first `bin_count` entries are valid
    bin_remaining_space = np.empty(n_sequences, dtype=seq_lens.dtype)
    bin_count = 0

    for seq_idx in sorted_indices:
        seq_len = seq_lens[seq_idx]

        # Scan the open bins once, keeping the one that leaves the least unused space.
        best_bin_idx = -1
        # Sentinel strictly larger than the free space any bin can have.
        best_waste = seq_length + 1

        for bin_idx in range(bin_count):
            remaining = bin_remaining_space[bin_idx]
            if remaining >= seq_len:
                waste = remaining - seq_len
                if waste < best_waste:
                    best_waste = waste
                    best_bin_idx = bin_idx
                    if waste == 0:
                        # Perfect fit: no other bin can do better, stop searching.
                        break

        if best_bin_idx >= 0:
            bin_assignments[seq_idx] = best_bin_idx
            bin_remaining_space[best_bin_idx] -= seq_len
        else:
            # No open bin can hold this sequence: open a new one.
            bin_assignments[seq_idx] = bin_count
            bin_remaining_space[bin_count] = seq_length - seq_len
            bin_count += 1

    # Used space per bin = capacity - remaining space, restricted to the bins actually opened.
    bin_sizes = seq_length - bin_remaining_space[:bin_count]
    return bin_assignments, bin_sizes


# Warm up the function at module import to avoid first-use compilation: call it once per dtype listed in
# its njit signatures, using tiny inputs.
# NOTE(review): the decorator is given explicit signatures, which Numba compiles eagerly at decoration
# time, so these calls may be redundant -- confirm before removing.
try:
    _pack_sequences_ffd_core(np.array([10, 20, 30], dtype=np.int32), 50)
    _pack_sequences_ffd_core(np.array([10, 20, 30], dtype=np.int64), 50)
except Exception:
    # Best effort only: a failed warm-up must not break the import. Catching `Exception` rather than using
    # a bare `except` lets KeyboardInterrupt and SystemExit propagate normally.
    pass


def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
"""Pack sequences in a pyarrow Table using First Fit Decreasing strategy."""
packed_columns = []
Expand All @@ -498,53 +560,20 @@ def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
truncated_lens = np.minimum(seq_lens, seq_length)
truncated_ends = starts + truncated_lens

# Create sequences list with truncated values
sequences = list(zip(truncated_lens, starts, truncated_ends))

# Sort by length (decreasing) for First Fit Decreasing
sequences.sort(key=lambda x: x[0], reverse=True)

# Optimized bin packing using a priority queue approach
bins_by_remaining = defaultdict(list) # remaining_space -> [bin_indices]
bins = [] # [(current_length, seq_indices)]

for i, (seq_len, _start, _end) in enumerate(sequences):
# Find bins with enough space using the dictionary
placed = False
for remaining in range(seq_len, seq_length + 1):
if bins_by_remaining[remaining]:
# Use the first available bin with this remaining space
bin_idx = bins_by_remaining[remaining].pop()
current_len, seq_indices = bins[bin_idx]

# Update bin
new_len = current_len + seq_len
new_remaining = seq_length - new_len
bins[bin_idx] = (new_len, seq_indices + [i])

# Update the remaining space mapping
if new_remaining > 0:
bins_by_remaining[new_remaining].append(bin_idx)

placed = True
break

# If no bin fits, create new bin
if not placed:
bin_idx = len(bins)
bins.append((seq_len, [i]))
remaining = seq_length - seq_len
if remaining > 0:
bins_by_remaining[remaining].append(bin_idx)
bin_assignments, bin_sizes = _pack_sequences_ffd_core(truncated_lens, seq_length)

# Reconstruct packed values more efficiently
values_numpy = values.to_numpy()
packed_values = []
new_offsets = [0]

for _, seq_indices in bins:
# Group sequences by bin assignment and concatenate them
for bin_idx in range(len(bin_sizes)):
# Find all sequences assigned to this bin
seq_indices = np.where(bin_assignments == bin_idx)[0]
for seq_idx in seq_indices:
_, start, end = sequences[seq_idx]
start = starts[seq_idx]
end = truncated_ends[seq_idx]
packed_values.extend(values_numpy[start:end])
new_offsets.append(len(packed_values))

Expand All @@ -557,7 +586,7 @@ def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:


def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
"""Pack sequences in a pyarrow Table using a wrapped strategy."""
"""Pack sequences in a pyarrow Table using fixed-length packing."""
packed_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
Expand Down
7 changes: 7 additions & 0 deletions trl/trainer/sft_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,13 @@ class SFTConfig(TrainingArguments):
"`'wrapped'`."
},
)
packing_strategy: str = field(
default="ffd",
metadata={
"help": "Strategy for packing sequences. Can be either `'ffd'` (first-fit decreasing, default), or "
"`'fixed'`."
},
)
padding_free: bool = field(
default=False,
metadata={
Expand Down