⚡ Pack 300 times faster, truncate 100 times faster #3009
Conversation
Hi Mario! :) Just added a few comments, but overall LGTM.
""" | ||
if map_kwargs is None: | ||
map_kwargs = {} | ||
if isinstance(dataset, Dataset): |
I guess it should also work for DatasetDict?
No need to make it work with DatasetDict here; we apply potentially different preprocessing depending on the split, see
trl/trl/trainer/sft_trainer.py
Lines 185 to 200 in e3244d2
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
    train_dataset = self._prepare_dataset(
        train_dataset, processing_class, args, args.packing, formatting_func, "train"
    )
    if eval_dataset is not None:
        packing = args.packing if args.eval_packing is None else args.eval_packing
        if isinstance(eval_dataset, dict):
            eval_dataset = {
                key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
                for key, dataset in eval_dataset.items()
            }
        else:
            eval_dataset = self._prepare_dataset(
                eval_dataset, processing_class, args, packing, formatting_func, "eval"
            )
trl/data_utils.py
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
    if isinstance(column, pa.ChunkedArray):
        column = column.combine_chunks()
    num_elements = len(column.values)
    dtype = column.offsets.type.to_pandas_dtype()  # np.int32 or np.int64
    offsets = np.arange(0, num_elements + 1, seq_length, dtype=dtype)
    if offsets[-1] != num_elements:
        offsets = np.concatenate([offsets, [num_elements]])
    column = type(column).from_arrays(offsets, column.values)
Is there a pyarrow.compute function you can use here instead to simplify this? The functions from pyarrow.compute generally do a copy of the data, but since .combine_chunks() also copies the data, it might not affect performance that much.
The compute API is very limited, so I think this is the best we can do :) The only way to avoid the copy would be to pack the individual .chunks (without crossing their boundaries), but that would make the operation dependent on the underlying chunking and lead to more padding afterward (when collating the batches), so I think the copy is justified here.
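For reference, here is a minimal standalone sketch of the offsets trick being discussed, on a toy pa.ListArray (illustrative data, not the actual PR code):
import numpy as np
import pyarrow as pa

# Toy list column: three sequences, 9 tokens in the flat values buffer
column = pa.array([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
seq_length = 4

# The flat values buffer is reused as-is; only the offsets are rewritten so that
# each packed example covers a contiguous window of at most seq_length tokens.
num_elements = len(column.values)                        # 9
offsets = np.arange(0, num_elements + 1, seq_length)     # [0, 4, 8]
if offsets[-1] != num_elements:
    offsets = np.concatenate([offsets, [num_elements]])  # [0, 4, 8, 9]

packed = pa.ListArray.from_arrays(pa.array(offsets, type=pa.int32()), column.values)
print(packed.to_pylist())  # [[1, 2, 3, 4], [5, 6, 7, 8], [9]]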
Too bad pyarrow.compute is missing something like that... Anyway, let's go with this approach then :) My only remaining question is whether this approach works with sliced arrays?
Made the change to handle that, but I don't remember in what scenario we receive sliced arrays as input, so I haven't added a test case.
LGTM! I think it can happen when training on a small portion of the train set, like --dataset_train_split "train[:100]".
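For anyone following along, a toy illustration of why sliced arrays need special care here (assuming my reading of the pyarrow API is right):
import pyarrow as pa

column = pa.array([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
sliced = column.slice(1)       # zero-copy view on the last two lists

print(len(sliced))             # 2
print(len(sliced.values))      # 9 -> .values still exposes the full underlying buffer
print(len(sliced.flatten()))   # 6 -> .flatten() only keeps the elements the slice refers to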
    return dataset


def truncate_dataset(
Same comments for truncate.
pyarrow.compute.list_slice reuses the .values buffers and only modifies the .offsets, so no unnecessary copies here.
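A minimal illustration of what list_slice does on toy data (not the PR code itself):
import pyarrow as pa
import pyarrow.compute as pc

column = pa.array([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
# Truncate every sequence to at most 2 elements
truncated = pc.list_slice(column, start=0, stop=2)
print(truncated.to_pylist())   # [[1, 2], [4, 5], [6, 7]]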
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM!
Benchmark packing
import timeit
import numpy as np
from datasets import Dataset
from trl.data_utils import pack_examples, pack_dataset
# Create a larger dataset with sequence lengths following a gamma distribution
num_samples = 10_000
# Generate sequence lengths following a gamma distribution
seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples) # mean will be 100
seq_lengths = np.clip(seq_lengths, 10, None).astype(int) # Clip to [10, inf)
# Generate input sequences with random lengths based on gamma distribution
examples = {
"input_ids": [list(range(length)) for length in seq_lengths],
"attention_mask": [[1] * length for length in seq_lengths],
}
dataset = Dataset.from_dict(examples)
max_length = 128 # Set a fixed packing length
# Benchmark pack_dataset
time_pack_dataset = timeit.timeit(lambda: pack_dataset(dataset, max_length), number=10)
# Benchmark dataset.map with pack_examples
time_pack_examples = timeit.timeit(
lambda: dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": max_length}), number=10
)
print(f"pack_dataset time: {time_pack_dataset:.4f} seconds")
print(f"dataset.map(pack_examples) time: {time_pack_examples:.4f} seconds")
Benchmark truncate
import timeit
import numpy as np
from datasets import Dataset
from trl.data_utils import truncate_dataset
def truncate_examples(example, max_length):
    # Batched map: truncate each sequence in the batch to at most max_length tokens
    return {key: [seq[:max_length] for seq in example[key]] for key in ["input_ids", "attention_mask"]}
# Create a larger dataset with sequence lengths following a gamma distribution
num_samples = 10_000
# Generate sequence lengths following a gamma distribution
seq_lengths = np.random.gamma(shape=5, scale=20, size=num_samples) # mean will be 100
seq_lengths = np.clip(seq_lengths, 10, None).astype(int) # Clip to [10, inf)
# Generate input sequences with random lengths based on gamma distribution
examples = {
"input_ids": [list(range(length)) for length in seq_lengths],
"attention_mask": [[1] * length for length in seq_lengths],
}
dataset = Dataset.from_dict(examples)
max_length = 128 # Set a fixed truncation length
# Benchmark truncate_dataset
time_truncate_dataset = timeit.timeit(lambda: truncate_dataset(dataset, max_length), number=10)
# Benchmark dataset.map with truncate_examples
time_truncate_examples = timeit.timeit(
lambda: dataset.map(truncate_examples, batched=True, fn_kwargs={"max_length": max_length}), number=10
)
print(f"truncate_dataset time: {time_truncate_dataset:.4f} seconds")
print(f"dataset.map(truncate_examples) time: {time_truncate_examples:.4f} seconds")
print(f"Speedup: {time_truncate_examples / time_truncate_dataset:.2f}x")
Thanks @mariosasko @lhoestq 🔥🔥
…into pr/mariosasko/3009
Co-authored-by: Quentin Gallouédec <[email protected]>
What does this PR do?
Adds fast packing/truncation logic that operates directly on PyArrow arrays to avoid expensive Python-to-PyArrow and PyArrow-to-Python conversions. This makes these steps almost instantaneous regardless of the input dataset's size.
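A minimal usage sketch of the two helpers (arguments passed positionally as in the benchmarks below, since exact keyword names may differ):
from datasets import Dataset
from trl.data_utils import pack_dataset, truncate_dataset

ds = Dataset.from_dict({
    "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]],
    "attention_mask": [[1, 1, 1], [1, 1], [1, 1, 1, 1]],
})

packed = pack_dataset(ds, 4)         # concatenate and re-split into blocks of at most 4 tokens
truncated = truncate_dataset(ds, 2)  # cut every sequence to at most 2 tokens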
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.