Skip to content
Merged
8 changes: 8 additions & 0 deletions docs/source/data_utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,11 @@
## pack_examples

[[autodoc]] pack_examples

## pack_dataset

[[autodoc]] pack_dataset

## truncate_dataset

[[autodoc]] truncate_dataset
86 changes: 83 additions & 3 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_dataset,
pack_examples,
truncate_dataset,
unpair_preference_dataset,
)

Expand Down Expand Up @@ -395,7 +397,7 @@ def test_maybe_extract_prompt_standard_already_explicit(self):


class TestPackExamples(unittest.TestCase):
def test_pack_examples_larger_chunks(self):
def test_larger_chunks(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
Expand All @@ -408,7 +410,7 @@ def test_pack_examples_larger_chunks(self):
result = pack_examples(examples, seq_length)
self.assertEqual(result, expected_output)

def test_pack_examples_smaller_chunks(self):
def test_smaller_chunks(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
Expand All @@ -421,7 +423,7 @@ def test_pack_examples_smaller_chunks(self):
result = pack_examples(examples, seq_length)
self.assertEqual(result, expected_output)

def test_pack_with_dataset(self):
def test_with_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
Expand All @@ -436,6 +438,84 @@ def test_pack_with_dataset(self):
self.assertEqual(dataset.to_dict(), expected_output)


class TestPackDataset(unittest.TestCase):
    """Tests for `pack_dataset` on a `Dataset` and on its iterable counterpart."""

    def test_with_dataset(self):
        """Packing a `Dataset` with `seq_length=3` yields rows of at most three items per column."""
        raw_columns = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        expected_columns = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        packed = pack_dataset(Dataset.from_dict(raw_columns), 3)
        self.assertEqual(packed.to_dict(), expected_columns)

    def test_with_iterable_dataset(self):
        """Packing the iterable version of the same data gives the same packed columns."""
        raw_columns = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        expected_columns = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        streamed = Dataset.from_dict(raw_columns).to_iterable_dataset()
        packed = pack_dataset(streamed, 3)
        # Take the first batch, sized to the number of input rows, and compare it column by column.
        first_column = next(iter(raw_columns))
        row_count = len(raw_columns[first_column])
        first_batch = next(iter(packed.batch(batch_size=row_count)))
        self.assertEqual(first_batch, expected_columns)


class TestTruncateExamples(unittest.TestCase):
    """Tests for `truncate_dataset` on a `Dataset`, an iterable dataset, and a dataset with a non-list column."""

    def test_with_dataset(self):
        """Each list in a `Dataset` is cut to its first `max_length` items; shorter lists are kept whole."""
        raw_columns = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        expected_columns = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        truncated = truncate_dataset(Dataset.from_dict(raw_columns), 2)
        self.assertEqual(truncated.to_dict(), expected_columns)

    def test_with_iterable_dataset(self):
        """Truncating the iterable version of the same data gives the same columns."""
        raw_columns = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        expected_columns = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
        }
        streamed = Dataset.from_dict(raw_columns).to_iterable_dataset()
        truncated = truncate_dataset(streamed, 2)
        # Take the first batch, sized to the number of input rows, and compare it column by column.
        first_column = next(iter(raw_columns))
        row_count = len(raw_columns[first_column])
        first_batch = next(iter(truncated.batch(batch_size=row_count)))
        self.assertEqual(first_batch, expected_columns)

    def test_with_extra_column(self):
        """A string column next to the list columns comes back unchanged."""
        raw_columns = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
            "my_column": ["a", "b", "c"],
        }
        expected_columns = {
            "input_ids": [[1, 2], [4, 5], [8]],
            "attention_mask": [[0, 1], [0, 0], [1]],
            "my_column": ["a", "b", "c"],
        }
        truncated = truncate_dataset(Dataset.from_dict(raw_columns), 2)
        self.assertEqual(truncated.to_dict(), expected_columns)


class TestMaybeConvertToChatML(unittest.TestCase):
def test_with_conversations_key(self):
# Particular case where the key is "conversations": we rename it to "messages"
Expand Down
4 changes: 4 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
"maybe_convert_to_chatml",
"maybe_extract_prompt",
"maybe_unpair_preference_dataset",
"pack_dataset",
"pack_examples",
"truncate_dataset",
"unpair_preference_dataset",
],
"environment": ["TextEnvironment", "TextHistory"],
Expand Down Expand Up @@ -130,7 +132,9 @@
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_dataset,
pack_examples,
truncate_dataset,
unpair_preference_dataset,
)
from .environment import TextEnvironment, TextHistory
Expand Down
131 changes: 131 additions & 0 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Callable, Optional, Sequence, TypeVar, Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.types
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizerBase

Expand Down Expand Up @@ -466,6 +471,132 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
return examples


def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
    r"""
    Pack sequences in a dataset into chunks of size `seq_length`.

    Within each mapped batch, the values of every list-typed column are laid end to end and re-split into
    consecutive chunks of `seq_length` elements; the final chunk of a batch may be shorter. A `Dataset` is packed
    with a pyarrow-based fast path. Any other input (for example the iterable dataset returned by
    `Dataset.to_iterable_dataset()`) is packed by mapping `pack_examples` over it.

    Args:
        dataset (`Dataset` or `DatasetDict`):
            Dataset to pack.
        seq_length (`int`):
            Target sequence length to pack to.
        map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
            Additional keyword arguments to pass to the dataset's map method when packing examples.

    Returns:
        `Dataset` or `DatasetDict`: The dataset with packed sequences. The number of examples may
        decrease as sequences are combined.

    Example:
    ```python
    >>> from datasets import Dataset
    >>> examples = {
    ...     "input_ids": [[1, 2], [3, 4], [5, 6], [7]],
    ...     "attention_mask": [[1, 1], [0, 1], [1, 1], [1]],
    ... }
    >>> dataset = Dataset.from_dict(examples)
    >>> packed_dataset = pack_dataset(dataset, seq_length=4)
    >>> packed_dataset[:]
    {'input_ids': [[1, 2, 3, 4], [5, 6, 7]],
     'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]}
    ```
    """
    if map_kwargs is None:
        map_kwargs = {}
    # NOTE(review): per the PR discussion, no dedicated `DatasetDict` branch is added here because the trainer
    # prepares each split separately (and possibly differently) before calling this function.
    if isinstance(dataset, Dataset):
        # Fast packing with pyarrow
        def pack(examples):
            packed_columns = []
            for column in examples.columns:
                if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
                    if isinstance(column, pa.ChunkedArray):
                        column = column.combine_chunks()
                    offsets, values = column.offsets, column.values
                    # Keep only the values this (possibly sliced) column actually references.
                    values = values[offsets[0].as_py() : offsets[-1].as_py()]
                    num_elements = len(values)
                    dtype = offsets.type.to_pandas_dtype()  # np.int32 or np.int64
                    offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
                    # Append the closing offset with the same dtype so the offsets are not promoted to another
                    # integer width by the concatenation.
                    offsets = np.concatenate((offsets, np.array([num_elements], dtype=dtype)))
                    column = type(column).from_arrays(offsets, values)
                # Non-list columns are appended unchanged.
                packed_columns.append(column)
            return pa.Table.from_arrays(packed_columns, names=examples.column_names)

        dataset = dataset.with_format("arrow")
        dataset = dataset.map(pack, batched=True, **map_kwargs)
        dataset = dataset.with_format(None)
    else:
        dataset = dataset.map(
            functools.partial(pack_examples, seq_length=seq_length),
            batched=True,
            **map_kwargs,
        )
    return dataset


def truncate_dataset(
    dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None
) -> DatasetType:
    r"""
    Truncate sequences in a dataset to a specified `max_length`.

    Only list-valued columns are truncated; every other column is returned unchanged. A `Dataset` is truncated
    with a pyarrow-based fast path. Any other input (for example the iterable dataset returned by
    `Dataset.to_iterable_dataset()`) is truncated with plain Python slicing, where a column counts as list-valued
    when the first value of the batch is a `list`.

    Args:
        dataset (`Dataset` or `DatasetDict`):
            Dataset to truncate.
        max_length (`int`):
            Maximum sequence length to truncate to.
        map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
            Additional keyword arguments to pass to the dataset's map method when truncating examples.

    Returns:
        `Dataset` or `DatasetDict`: The dataset with truncated sequences.

    Example:
    ```python
    >>> from datasets import Dataset
    >>> examples = {
    ...     "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
    ...     "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
    ... }
    >>> dataset = Dataset.from_dict(examples)
    >>> truncated_dataset = truncate_dataset(dataset, max_length=2)
    >>> truncated_dataset[:]
    {'input_ids': [[1, 2], [4, 5], [8]],
     'attention_mask': [[0, 1], [0, 0], [1]]}
    ```
    """
    if map_kwargs is None:
        map_kwargs = {}
    if isinstance(dataset, Dataset):
        # Fast truncation with pyarrow
        def truncate(examples):
            truncated_columns = []
            for column in examples.columns:
                if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
                    # NOTE(review): the PR author states that `list_slice` reuses the values buffer and only
                    # rewrites the offsets, so no extra copy is expected here - confirm against pyarrow docs.
                    column = pc.list_slice(column, 0, max_length)
                # Non-list columns are appended unchanged.
                truncated_columns.append(column)
            return pa.Table.from_arrays(truncated_columns, names=examples.column_names)

        dataset = dataset.with_format("arrow")
        dataset = dataset.map(truncate, batched=True, **map_kwargs)
        dataset = dataset.with_format(None)
    else:

        def truncate(examples):
            truncated_examples = {}
            for key, column in examples.items():
                # The first value of the batch decides whether the column is treated as list-valued.
                if column and isinstance(column[0], list):
                    column = [val[:max_length] for val in column]
                truncated_examples[key] = column
            return truncated_examples

        dataset = dataset.map(
            truncate,
            batched=True,
            **map_kwargs,
        )
    return dataset


def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
"""
Convert a conversational dataset with fields `from` and `value` to ChatML format.
Expand Down
23 changes: 9 additions & 14 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from ..data_utils import (
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
pack_dataset,
truncate_dataset,
)
from .sft_config import SFTConfig
from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16

Expand Down Expand Up @@ -470,22 +476,11 @@ def tokenize(example, processing_class, dataset_text_field):
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Packing {dataset_name} dataset"
dataset = dataset.select_columns("input_ids")
dataset = dataset.map(
pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs
)
dataset = pack_dataset(dataset, args.max_length, map_kwargs)
elif args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"

def truncate(example, max_length):
return {key: example[key][:max_length] for key in ["input_ids", "attention_mask"]}

dataset = dataset.map(
truncate,
fn_kwargs={"max_length": args.max_length},
**map_kwargs,
)

dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
# For Liger kernel, ensure only input_ids is present
if args.use_liger_kernel:
dataset = dataset.select_columns("input_ids")
Expand Down
Loading