
Enable Numba for FFD packing algorithm #3524


Closed

Conversation

thepowerfuldeez
Contributor

What does this PR do?

Related to #3521: enables Numba compilation for FFD packing.

Improves speed by 12.5%, tested on 10K, 100K, and 1M sample datasets.
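For context, FFD (first-fit decreasing) sorts the sequences by length, longest first, and drops each one into the first bin that still has enough room, opening a new bin when none fits. Below is a minimal sketch of what a Numba-compiled FFD kernel could look like; it is illustrative only (the function name ffd_pack and its signature are made up here, not the actual trl implementation):

import numpy as np
from numba import njit

@njit
def ffd_pack(lengths, capacity):
    # First-fit decreasing: visit sequences from longest to shortest.
    # Assumes every sequence fits into a single bin (length <= capacity).
    order = np.argsort(lengths)[::-1]
    bin_remaining = np.empty(lengths.shape[0], dtype=np.int64)  # spare capacity per bin
    bin_of_item = np.empty(lengths.shape[0], dtype=np.int64)    # bin index assigned to each sequence
    num_bins = 0
    for idx in order:
        length = lengths[idx]
        placed = False
        for b in range(num_bins):  # scan bins in order; take the first one that fits
            if bin_remaining[b] >= length:
                bin_remaining[b] -= length
                bin_of_item[idx] = b
                placed = True
                break
        if not placed:  # nothing fits: open a new bin
            bin_remaining[num_bins] = capacity - length
            bin_of_item[idx] = num_bins
            num_bins += 1
    return bin_of_item, num_bins

# Example: the first call pays the JIT compilation cost, later calls are fast.
lengths = np.array([300, 120, 90, 250, 60], dtype=np.int64)
bins, n_bins = ffd_pack(lengths, 512)

The naive inner scan makes this O(n × num_bins) in the worst case; @njit only removes the Python-level loop overhead, which is presumably where the modest speedup comes from.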

Code to test for speed
import timeit
import numpy as np
from datasets import Dataset
from trl.data_utils import pack_dataset


def create_test_dataset(num_samples: int, random_seed: int = 42) -> Dataset:
  """Create a test dataset with realistic sequence length distribution."""
  np.random.seed(random_seed)

  seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples).astype(int)
  seq_lengths = np.clip(seq_lengths, 10, 500)

  print("Dataset statistics:")
  print(f"  - Samples: {num_samples:,}")
  print(f"  - Mean sequence length: {np.mean(seq_lengths):.1f}")
  print(f"  - Median sequence length: {np.median(seq_lengths):.1f}")
  print(f"  - Max sequence length: {np.max(seq_lengths)}")
  print(f"  - Min sequence length: {np.min(seq_lengths)}")

  # Create dataset with multiple columns to test multi-column packing
  examples = {
      "input_ids": [list(range(length)) for length in seq_lengths],
      "attention_mask": [[1] * length for length in seq_lengths],
      "labels": [list(range(100, 100 + length)) for length in seq_lengths],
  }

  return Dataset.from_dict(examples)


def benchmark_packing_strategies(dataset: Dataset, seq_length: int = 256) -> dict:
  """Benchmark different packing strategies on a dataset."""
  results = {}

  print(f"\nBenchmarking packing strategies on {len(dataset):,} samples")
  print(f"Target sequence length: {seq_length}")
  print("=" * 60)

  # Test baseline FFD strategy
  print("Testing baseline FFD...")
  n_tries = 3
  time_baseline = timeit.timeit(
      lambda: pack_dataset(dataset, seq_length, strategy="ffd"),
      number=n_tries,
  )
  results["baseline_ffd"] = time_baseline
  print(f"  Time: {time_baseline:.2f} seconds")
  print(f"  Throughput: {len(dataset) * n_tries / time_baseline:,.0f} samples/second")

  # Test Fixed strategy for comparison
  print("\nTesting Fixed strategy (baseline)...")
  time_fixed = timeit.timeit(
      lambda: pack_dataset(dataset, seq_length, strategy="fixed"), number=n_tries
  )
  results["fixed"] = time_fixed
  print(f"  Time: {time_fixed:.2f} seconds")
  print(f"  Throughput: {len(dataset) * n_tries / time_fixed:,.0f} samples/second")

  return results


def run_comprehensive_benchmarks():
  """Run comprehensive benchmarks on 10k, 100k and 1M datasets."""
  print("=" * 80)
  print("COMPREHENSIVE PARALLEL PACKING BENCHMARKS")
  print("=" * 80)

  # Test 10k dataset
  print("\n🔥 TESTING 10K DATASET")
  print("=" * 40)
  dataset_10k = create_test_dataset(10_000)
  results_10k = benchmark_packing_strategies(dataset_10k)

  # Test 100k dataset
  print("\n🔥 TESTING 100K DATASET")
  print("=" * 40)
  dataset_100k = create_test_dataset(100_000)
  results_100k = benchmark_packing_strategies(dataset_100k)

  # Test 1M dataset (if enough RAM)
  try:
      print("\n\n🚀 TESTING 1M DATASET")
      print("=" * 40)
      dataset_1m = create_test_dataset(1_000_000)
      results_1m = benchmark_packing_strategies(dataset_1m)

  except MemoryError:
      print("❌ Not enough memory for 1M dataset, skipping...")
      results_1m = None

  # Summary
  print("\n" + "=" * 80)
  print("📊 BENCHMARK SUMMARY")
  print("=" * 80)

  print("\n10K Dataset Results:")
  for strategy, time_taken in results_10k.items():
      throughput = 10_000 * 3 / time_taken
      print(f"  {strategy:20}: {time_taken:6.2f}s ({throughput:8,.0f} samples/sec)")

  print("\n100K Dataset Results:")
  for strategy, time_taken in results_100k.items():
      throughput = 100_000 * 3 / time_taken
      print(f"  {strategy:20}: {time_taken:6.2f}s ({throughput:8,.0f} samples/sec)")

  if results_1m:
      print("\n1M Dataset Results:")
      for strategy, time_taken in results_1m.items():
          throughput = 1_000_000 * 3 / time_taken
          print(
              f"  {strategy:20}: {time_taken:6.2f}s ({throughput:8,.0f} samples/sec)"
          )


if __name__ == "__main__":
  run_comprehensive_benchmarks()

📊 BENCHMARK SUMMARY (Before)

10K Dataset Results:
  baseline_ffd :  0.67s (   44,914 samples/sec)
  fixed        :  0.02s (1,668,828 samples/sec)

100K Dataset Results:
  baseline_ffd :  6.72s (   44,646 samples/sec)
  fixed        :  0.26s (1,149,577 samples/sec)

1M Dataset Results:
  baseline_ffd : 67.98s (   44,132 samples/sec)
  fixed        :  3.42s (  876,439 samples/sec)

📊 BENCHMARK SUMMARY (After)

10K Dataset Results:
  baseline_ffd :  0.60s (   50,087 samples/sec)
  fixed        :  0.02s (1,718,645 samples/sec)

100K Dataset Results:
  baseline_ffd :  5.98s (   50,152 samples/sec)
  fixed        :  0.26s (1,140,803 samples/sec)

1M Dataset Results:
  baseline_ffd : 60.42s (   49,650 samples/sec)
  fixed        :  3.45s (  870,240 samples/sec)

Speedup: 12.5%

Code to verify correctness
import argparse
from pathlib import Path
import sys
import numpy as np
from datasets import Dataset, load_from_disk
from trl.data_utils import pack_dataset


def create_test_dataset(num_samples: int, random_seed: int = 42) -> Dataset:
  """Create a test dataset with realistic sequence length distribution."""
  np.random.seed(random_seed)

  seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples).astype(int)
  seq_lengths = np.clip(seq_lengths, 10, 500)

  print("Dataset statistics:")
  print(f"  - Samples: {num_samples:,}")
  print(f"  - Mean sequence length: {np.mean(seq_lengths):.1f}")
  print(f"  - Median sequence length: {np.median(seq_lengths):.1f}")
  print(f"  - Max sequence length: {np.max(seq_lengths)}")
  print(f"  - Min sequence length: {np.min(seq_lengths)}")

  # Create dataset with multiple columns to test multi-column packing
  examples = {
      "input_ids": [list(range(length)) for length in seq_lengths],
      "attention_mask": [[1] * length for length in seq_lengths],
      "labels": [list(range(100, 100 + length)) for length in seq_lengths],
  }

  return Dataset.from_dict(examples)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--state", default="old")
  parser.add_argument("--seq_length", default=256, type=int)
  parser.add_argument("--num_samples", default=10000, type=int)
  args = parser.parse_args()

  if not Path("packed1").exists() and args.state == "old":
      dataset = create_test_dataset(args.num_samples)
      seq_length = args.seq_length
      packed = pack_dataset(dataset, seq_length, strategy="ffd")
      packed.save_to_disk("packed1")
      sys.exit()

  if not Path("packed2").exists() and args.state == "new":
      dataset = create_test_dataset(args.num_samples)
      seq_length = args.seq_length
      packed = pack_dataset(dataset, seq_length, strategy="ffd")
      packed.save_to_disk("packed2")
      sys.exit()

  dataset1 = load_from_disk("packed1")
  dataset2 = load_from_disk("packed2")
  for i in range(len(dataset1)):
      if dataset1[i] != dataset2[i]:
          for k in dataset1[i]:
              if dataset1[i][k] != dataset2[i][k]:
                  print(f"Mismatch at index {i} for key {k}")
                  print(dataset1[i][k], len(dataset1[i][k]))
                  print(dataset2[i][k], len(dataset2[i][k]))
          break

Usage:

# on old branch
python test_packing_compare.py --state old --seq_length 512 --num_samples 100000
# on this branch
python test_packing_compare.py --state new --seq_length 512 --num_samples 100000
# show difference
python test_packing_compare.py

Right now the number of packed samples matches between the two branches; however, on this branch the packing is slightly more exact, with one additional (last) sequence fitted into some of the bins.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@thepowerfuldeez
Contributor Author

@qgallouedec

@qgallouedec
Member

qgallouedec commented Jun 2, 2025

Nice! The only concern I have is that numba is not a dep of trl. And I don't think that adding it just for this really makes sense

@thepowerfuldeez
Contributor Author

Nice! The only concern I have is that numba is not a dep of trl. And I don't think that adding it just for this really makes sense

Hm, honestly I haven't installed it in a fresh venv; let me check whether it's a dependency of some lib that's part of trl's requirements.

@thepowerfuldeez
Contributor Author

Okay, it seems like numba is part of vllm's dependencies, so if you install trl[vllm] you will get it installed as well.
In that case it's relatively safe to add to the main requirements, what do you think?

@qgallouedec qgallouedec deleted the branch huggingface:main June 2, 2025 20:15
@qgallouedec qgallouedec closed this Jun 2, 2025
@thepowerfuldeez
Contributor Author

@qgallouedec this needs to be re-opened against the main branch now, thanks!

@qgallouedec qgallouedec reopened this Jun 2, 2025
@qgallouedec qgallouedec changed the base branch from ffd_pack to main June 2, 2025 22:24
@mariosasko mariosasko mentioned this pull request Jun 4, 2025
@mariosasko
Contributor

Hi, I just opened a PR with a faster FFD implementation that doesn't introduce new dependencies and doesn't require warmup.
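(Side note on the warmup point: a numba.njit function is compiled the first time it is called with a given argument type, so a jitted FFD kernel like the hypothetical ffd_pack sketched above would typically be primed once on a tiny input before any timing, roughly like this:)

# Hypothetical warmup: the first call triggers JIT compilation (can take on the order of a second),
# subsequent calls with the same argument types reuse the compiled code.
_ = ffd_pack(np.array([1], dtype=np.int64), 16)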

@qgallouedec
Member

qgallouedec commented Jun 4, 2025

Relatively safe to add to the main requirements, in that case, what do you think?

It could be safe, but here's what I think: adding a dependency would make sense if FFD were prohibitively slow, because then it would be necessary. Here, it's just going from “very fast” to “very, very fast,” so adding a dependency doesn't seem justified to me.
