⚡ Pack 300 times faster, truncate 100 times faster #3009

Merged: 11 commits merged into huggingface:main from fast-pack-truncate on Mar 22, 2025

Conversation

@mariosasko (Contributor) commented Mar 4, 2025

What does this PR do?

Adds fast packing/truncation logic that operates directly on PyArrow arrays to avoid expensive Python-to-PyArrow and PyArrow-to-Python conversions. This makes these steps almost instantaneous regardless of the input dataset's size.
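
For intuition, here is a minimal illustration (not code from this PR) of why working directly on the Arrow representation is cheap: a tokenized list column is stored as one flat values buffer plus an offsets array, so packing and truncation reduce to rewriting offsets instead of materializing Python lists.

import pyarrow as pa

# A list column is a flat values buffer plus an offsets array that marks
# where each example starts and ends.
column = pa.array([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
print(column.values)   # flat buffer: [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(column.offsets)  # example boundaries: [0, 3, 5, 9]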

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@lhoestq (Member) left a comment

Hi Mario! :) Just added a few comments, but overall LGTM.

"""
if map_kwargs is None:
map_kwargs = {}
if isinstance(dataset, Dataset):

Member:

I guess it should also work for DatasetDict?

Member:

No need to make it work with DatasetDict here; we apply potentially different preprocessing depending on the split. See:

preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
    train_dataset = self._prepare_dataset(
        train_dataset, processing_class, args, args.packing, formatting_func, "train"
    )
    if eval_dataset is not None:
        packing = args.packing if args.eval_packing is None else args.eval_packing
        if isinstance(eval_dataset, dict):
            eval_dataset = {
                key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                for key, dataset in eval_dataset.items()
            }
        else:
            eval_dataset = self._prepare_dataset(
                eval_dataset, processing_class, args, packing, formatting_func, "eval"
            )

Comment on lines 511 to 519
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
    if isinstance(column, pa.ChunkedArray):
        column = column.combine_chunks()
    num_elements = len(column.values)
    dtype = column.offsets.type.to_pandas_dtype()  # np.int32 or np.int64
    offsets = np.arange(0, num_elements + 1, seq_length, dtype=dtype)
    if offsets[-1] != num_elements:
        offsets = np.concatenate([offsets, [num_elements]])
    column = type(column).from_arrays(offsets, column.values)

Member:

Is there a pyarrow.compute function you can use here instead to simplify this?

The functions in pyarrow.compute generally copy the data, but since .combine_chunks() also copies the data, it might not affect performance much.

@mariosasko (Contributor, Author) commented Mar 16, 2025

The compute API is very limited, so I think this is the best we can do :)

The only way to avoid the copy would be by packing the individual .chunks (without crossing the boundaries), but this would then make the operation dependent on the underlying chunking and lead to more padding afterward (when collating the batches), so I think the copy is justified here.

Member:

Too bad pyarrow.compute is missing something like that... anyway, let's go with this approach then :) My only remaining question is whether this approach works with sliced arrays?

@mariosasko (Contributor, Author):

Made the change to handle that, but I don't remember in what scenario we receive sliced arrays as input, so I didn't add a test case.

Member:

LGTM! I think it can happen when training on a small portion of the train set, like --dataset_train_split "train[:100]".
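
For illustration, a small hedged sketch (toy data, not a test from this PR) of what a sliced array looks like at the Arrow level; the exact behavior of .values on sliced arrays is a pyarrow subtlety, but the gist is that the slice is not applied to the underlying buffer, so naive offset math would be wrong.

import pyarrow as pa

# A sliced list array keeps the full underlying values buffer; its offsets
# describe only a window of it, so len(column.values) no longer matches the data.
arr = pa.array([[1, 2], [3, 4, 5], [6]])
sliced = arr.slice(1)       # drop the first example
print(sliced.to_pylist())   # [[3, 4, 5], [6]]
print(len(sliced.values))   # 6: the full buffer, including the dropped [1, 2]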

return dataset


def truncate_dataset(

Member:

Same comments for truncate.

@mariosasko (Contributor, Author):
pyarrow.compute.list_slice reuses the .values buffers and only modifies the .offsets, so no unnecessary copies here.
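
For reference, a minimal hedged example of that call on toy data (not the PR's actual code path), showing the truncation semantics:

import pyarrow as pa
import pyarrow.compute as pc

# Truncate every sequence to its first 4 tokens; per the comment above,
# list_slice rewrites the offsets and reuses the existing values buffer.
column = pa.chunked_array([[[1, 2, 3, 4, 5, 6], [7, 8]]])
truncated = pc.list_slice(column, 0, 4)
print(truncated.to_pylist())  # [[1, 2, 3, 4], [7, 8]]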

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lhoestq (Member) left a comment

LGTM!

@qgallouedec (Member) commented Mar 22, 2025

Benchmark packing

import timeit
import numpy as np
from datasets import Dataset
from trl.data_utils import pack_examples, pack_dataset

# Create a larger dataset with sequence lengths following a gamma distribution
num_samples = 10_000

# Generate sequence lengths following a gamma distribution
seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples) # mean will be 100
seq_lengths = np.clip(seq_lengths, 10, None).astype(int)  # Clip to [10, inf)

# Generate input sequences with random lengths based on gamma distribution
examples = {
    "input_ids": [list(range(length)) for length in seq_lengths],
    "attention_mask": [[1] * length for length in seq_lengths],
}

dataset = Dataset.from_dict(examples)
max_length = 128  # Set a fixed packing length

# Benchmark pack_dataset
time_pack_dataset = timeit.timeit(lambda: pack_dataset(dataset, max_length), number=10)

# Benchmark dataset.map with pack_examples
time_pack_examples = timeit.timeit(
    lambda: dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": max_length}), number=10
)

print(f"pack_dataset time: {time_pack_dataset:.4f} seconds")
print(f"dataset.map(pack_examples) time: {time_pack_examples:.4f} seconds")
pack_dataset time: 0.0667 seconds
dataset.map(pack_examples) time: 19.3734 seconds
Speedup: 290.46x

@qgallouedec (Member) commented Mar 22, 2025

Benchmark truncate

import timeit
import numpy as np
from datasets import Dataset
from trl.data_utils import truncate_dataset


def truncate_examples(example, max_length):
    return {key: example[key][:max_length] for key in ["input_ids", "attention_mask"]}


# Create a larger dataset with sequence lengths following a gamma distribution
num_samples = 10_000

# Generate sequence lengths following a gamma distribution
seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples)  # mean will be 100
seq_lengths = np.clip(seq_lengths, 10, None).astype(int)  # Clip to [10, inf)

# Generate input sequences with random lengths based on gamma distribution
examples = {
    "input_ids": [list(range(length)) for length in seq_lengths],
    "attention_mask": [[1] * length for length in seq_lengths],
}

dataset = Dataset.from_dict(examples)
max_length = 128  # Set a fixed truncation length

# Benchmark truncate_dataset
time_truncate_dataset = timeit.timeit(lambda: truncate_dataset(dataset, max_length), number=10)

# Benchmark dataset.map with truncate_examples
time_truncate_examples = timeit.timeit(
    lambda: dataset.map(truncate_examples, batched=True, fn_kwargs={"max_length": max_length}), number=10
)

print(f"truncate_dataset time: {time_truncate_dataset:.4f} seconds")
print(f"dataset.map(truncate_examples) time: {time_truncate_examples:.4f} seconds")
print(f"Speedup: {time_truncate_examples / time_truncate_dataset:.2f}x")
truncate_dataset time: 0.0611 seconds
dataset.map(truncate_examples) time: 6.3807 seconds
Speedup: 104.47x

@qgallouedec changed the title from "Fast packing and truncation" to "⚡ Pack 300 times faster, truncate 100 times faster" on Mar 22, 2025

@qgallouedec (Member) left a comment

Thanks @mariosasko @lhoestq 🔥🔥

@qgallouedec merged commit 7511aa4 into huggingface:main on Mar 22, 2025
7 of 13 checks passed
@mariosasko deleted the fast-pack-truncate branch on March 25, 2025
kashif pushed a commit to kashif/trl that referenced this pull request Mar 28, 2025
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025