generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
⚡ Pack 300 times faster, truncate 100 times faster #3009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
60d54d1
Fast truncation and packing
mariosasko 987e050
Nit
mariosasko 1e027a1
Fix typo
mariosasko cc57863
Fix conflict
mariosasko 00cc012
Nit
mariosasko 8f4498f
Remove if
mariosasko 7e7854a
Handle sliced arrays
mariosasko b12c2cf
Merge branch 'main' of github.com:huggingface/trl into fast-pack-trun…
mariosasko 0f1b7bb
add an extra test
qgallouedec d9b33d4
Merge branch 'main' into fast-pack-truncate
qgallouedec d777a2d
Merge branch 'fast-pack-truncate' of https://github.com/mariosasko/tr…
qgallouedec File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,8 +12,13 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import functools | ||
from typing import Any, Callable, Optional, Sequence, TypeVar, Union | ||
|
||
import numpy as np | ||
import pyarrow as pa | ||
import pyarrow.compute as pc | ||
import pyarrow.types | ||
from datasets import Dataset, DatasetDict | ||
from transformers import PreTrainedTokenizerBase | ||
|
||
|
@@ -466,6 +471,132 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, | |
return examples | ||
|
||
|
||
def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType: | ||
r""" | ||
Pack sequences in a dataset into chunks of size `seq_length`. | ||
|
||
Args: | ||
dataset (`Dataset` or `DatasetDict`): | ||
Dataset to pack | ||
seq_length (`int`): | ||
Target sequence length to pack to. | ||
map_kwargs (`dict` or `None`, *optional*, defaults to `None`): | ||
Additional keyword arguments to pass to the dataset's map method when packing examples. | ||
|
||
Returns: | ||
`Dataset` or `DatasetDict`: The dataset with packed sequences. The number of examples may | ||
decrease as sequences are combined. | ||
|
||
Example: | ||
```python | ||
>>> from datasets import Dataset | ||
>>> examples = { | ||
... "input_ids": [[1, 2], [3, 4], [5, 6], [7]], | ||
... "attention_mask": [[1, 1], [0, 1], [1, 1], [1]], | ||
... } | ||
>>> dataset = Dataset.from_dict(examples) | ||
>>> packed_dataset = pack_dataset(dataset, seq_length=4) | ||
>>> packed_dataset[:] | ||
{'input_ids': [[1, 2, 3, 4], [5, 6, 7]], | ||
'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]} | ||
``` | ||
""" | ||
if map_kwargs is None: | ||
map_kwargs = {} | ||
if isinstance(dataset, Dataset): | ||
# Fast packing with pyarrow | ||
def pack(examples): | ||
packed_columns = [] | ||
for column in examples.columns: | ||
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): | ||
if isinstance(column, pa.ChunkedArray): | ||
column = column.combine_chunks() | ||
offsets, values = column.offsets, column.values | ||
values = values[offsets[0].as_py() : offsets[-1].as_py()] | ||
num_elements = len(values) | ||
dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64 | ||
offsets = np.arange(0, num_elements, seq_length, dtype=dtype) | ||
offsets = np.concatenate((offsets, [num_elements])) | ||
column = type(column).from_arrays(offsets, values) | ||
packed_columns.append(column) | ||
return pa.Table.from_arrays(packed_columns, names=examples.column_names) | ||
|
||
dataset = dataset.with_format("arrow") | ||
dataset = dataset.map(pack, batched=True, **map_kwargs) | ||
dataset = dataset.with_format(None) | ||
else: | ||
dataset = dataset.map( | ||
functools.partial(pack_examples, seq_length=seq_length), | ||
batched=True, | ||
**map_kwargs, | ||
) | ||
return dataset | ||
|
||
|
||
def truncate_dataset( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comments for truncate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None | ||
) -> DatasetType: | ||
r""" | ||
Truncate sequences in a dataset to a specifed `max_length`. | ||
|
||
Args: | ||
dataset (`Dataset` or `DatasetDict`): | ||
Dataset to truncate. | ||
seq_length (`int`): | ||
Maximum sequence length to truncate to. | ||
map_kwargs (`dict` or `None`, *optional*, defaults to `None`): | ||
Additional keyword arguments to pass to the dataset's map method when truncating examples. | ||
|
||
Returns: | ||
`Dataset` or `DatasetDict`: The dataset with truncated sequences. | ||
|
||
Example: | ||
```python | ||
>>> from datasets import Dataset | ||
>>> examples = { | ||
... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], | ||
... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], | ||
... } | ||
>>> dataset = Dataset.from_dict(examples) | ||
>>> truncated_dataset = truncate_dataset(dataset, max_length=2) | ||
>>> truncated_dataset[:] | ||
{'input_ids': [[1, 2], [4, 5], [8]], | ||
'attention_mask': [[0, 1], [0, 0], [1]]} | ||
``` | ||
""" | ||
if map_kwargs is None: | ||
map_kwargs = {} | ||
if isinstance(dataset, Dataset): | ||
# Fast truncation with pyarrow | ||
def truncate(examples): | ||
truncated_columns = [] | ||
for column in examples.columns: | ||
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): | ||
column = pc.list_slice(column, 0, max_length) | ||
truncated_columns.append(column) | ||
return pa.Table.from_arrays(truncated_columns, names=examples.column_names) | ||
|
||
dataset = dataset.with_format("arrow") | ||
dataset = dataset.map(truncate, batched=True, **map_kwargs) | ||
dataset = dataset.with_format(None) | ||
else: | ||
|
||
def truncate(examples): | ||
truncated_examples = {} | ||
for key, column in examples.items(): | ||
if column and isinstance(column[0], list): | ||
column = [val[:max_length] for val in column] | ||
truncated_examples[key] = column | ||
return truncated_examples | ||
|
||
dataset = dataset.map( | ||
truncate, | ||
batched=True, | ||
**map_kwargs, | ||
) | ||
return dataset | ||
|
||
|
||
def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]: | ||
""" | ||
Convert a conversational dataset with fields `from` and `value` to ChatML format. | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it should also work for DatasetDict ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to make it work with DatasetDict here, we apply a potentially different preprocessing depending on the split, see
trl/trl/trainer/sft_trainer.py
Lines 185 to 200 in e3244d2