Merged
8 changes: 8 additions & 0 deletions docs/source/data_utils.md
@@ -35,3 +35,11 @@
## pack_examples

[[autodoc]] pack_examples

## pack_dataset

[[autodoc]] pack_dataset

## truncate_dataset

[[autodoc]] truncate_dataset
50 changes: 49 additions & 1 deletion tests/test_data_utils.py
@@ -27,7 +27,9 @@
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_dataset,
pack_examples,
truncate_dataset,
unpair_preference_dataset,
)

@@ -432,9 +434,55 @@ def test_pack_with_dataset(self):
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": seq_length})
dataset = pack_dataset(dataset, seq_length)
self.assertEqual(dataset.to_dict(), expected_output)

def test_pack_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
seq_length = 3
expected_output = {
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length)
num_examples = len(examples[next(iter(examples))])
self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)


class TestTruncateExamples(unittest.TestCase):
def test_truncate_with_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
self.assertEqual(dataset.to_dict(), expected_output)

def test_truncate_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
max_length = 2
expected_output = {
"input_ids": [[1, 2], [4, 5], [8]],
"attention_mask": [[0, 1], [0, 0], [1]],
}
dataset = truncate_dataset(dataset, max_length)
num_examples = len(examples[next(iter(examples))])
self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)


class TestMaybeConvertToChatML(unittest.TestCase):
def test_with_conversations_key(self):
4 changes: 4 additions & 0 deletions trl/__init__.py
@@ -29,7 +29,9 @@
"maybe_convert_to_chatml",
"maybe_extract_prompt",
"maybe_unpair_preference_dataset",
"pack_dataset",
"pack_examples",
"truncate_dataset",
"unpair_preference_dataset",
],
"environment": ["TextEnvironment", "TextHistory"],
@@ -130,7 +132,9 @@
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_dataset,
pack_examples,
truncate_dataset,
unpair_preference_dataset,
)
from .environment import TextEnvironment, TextHistory
131 changes: 131 additions & 0 deletions trl/data_utils.py
@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Any, Callable, Optional, Sequence, TypeVar, Union

import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.types
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizerBase

@@ -466,6 +471,132 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
return examples


def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
r"""
Pack sequences in a dataset into chunks of size `seq_length`.

Args:
dataset (`Dataset` or `DatasetDict`):
Dataset to pack.
seq_length (`int`):
Target sequence length to pack to.
map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
Additional keyword arguments to pass to the dataset's map method when packing examples.

Returns:
`Dataset` or `DatasetDict`: The dataset with packed sequences. The number of examples may
decrease as sequences are combined.

Example:
```python
>>> from datasets import Dataset
>>> examples = {
... "input_ids": [[1, 2], [3, 4], [5, 6], [7]],
... "attention_mask": [[1, 1], [0, 1], [1, 1], [1]],
... }
>>> dataset = Dataset.from_dict(examples)
>>> packed_dataset = pack_dataset(dataset, seq_length=4)
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 4], [5, 6, 7]],
'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]}
```
"""
if map_kwargs is None:
map_kwargs = {}
if isinstance(dataset, Dataset):
Member:
I guess it should also work for `DatasetDict`?

Member:
No need to make it work with `DatasetDict` here; we apply potentially different preprocessing depending on the split, see:

preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
train_dataset = self._prepare_dataset(
train_dataset, processing_class, args, args.packing, formatting_func, "train"
)
if eval_dataset is not None:
packing = args.packing if args.eval_packing is None else args.eval_packing
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(
eval_dataset, processing_class, args, packing, formatting_func, "eval"
)

# Fast packing with pyarrow
def pack(examples):
packed_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
offsets, values = column.offsets, column.values
values = values[offsets[0].as_py() : offsets[-1].as_py()]
num_elements = len(values)
dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64
offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
offsets = np.concatenate((offsets, [num_elements]))
column = type(column).from_arrays(offsets, values)
packed_columns.append(column)
return pa.Table.from_arrays(packed_columns, names=examples.column_names)

dataset = dataset.with_format("arrow")
dataset = dataset.map(pack, batched=True, **map_kwargs)
dataset = dataset.with_format(None)
else:
dataset = dataset.map(
functools.partial(pack_examples, seq_length=seq_length),
batched=True,
**map_kwargs,
)
return dataset
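
For reference, a minimal standalone sketch (not part of the diff) of the offsets trick the `pack` helper above relies on: re-chunking a list column only rewrites the offsets array, while the flat values buffer is reused rather than copied.

```python
import numpy as np
import pyarrow as pa

# A toy list column: [[1, 2], [3, 4], [5, 6], [7]] -> flat values [1, 2, 3, 4, 5, 6, 7]
column = pa.array([[1, 2], [3, 4], [5, 6], [7]])
values = column.values            # Int64Array([1, 2, 3, 4, 5, 6, 7])
num_elements = len(values)

# Rebuild offsets every `seq_length` elements instead of copying values around.
seq_length = 4
offsets = np.arange(0, num_elements, seq_length, dtype=np.int32)
offsets = np.concatenate((offsets, np.array([num_elements], dtype=np.int32)))  # [0, 4, 7]

packed = pa.ListArray.from_arrays(pa.array(offsets), values)
print(packed.to_pylist())  # [[1, 2, 3, 4], [5, 6, 7]]
```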


def truncate_dataset(
Member:
same comments for truncate

Contributor Author:
`pyarrow.compute.list_slice` reuses the `.values` buffers and only modifies the `.offsets`, so no unnecessary copies here.
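
A quick standalone illustration of the `list_slice` behaviour referenced here (a sketch, not part of the diff):

```python
import pyarrow as pa
import pyarrow.compute as pc

column = pa.array([[1, 2, 3], [4, 5, 6, 7], [8]])
# Keep at most the first two elements of every list; the kernel works row-wise.
print(pc.list_slice(column, 0, 2).to_pylist())  # [[1, 2], [4, 5], [8]]
```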

dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None
) -> DatasetType:
r"""
Truncate sequences in a dataset to a specified `max_length`.

Args:
dataset (`Dataset` or `DatasetDict`):
Dataset to truncate.
max_length (`int`):
Maximum sequence length to truncate to.
map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
Additional keyword arguments to pass to the dataset's map method when truncating examples.

Returns:
`Dataset` or `DatasetDict`: The dataset with truncated sequences.

Example:
```python
>>> from datasets import Dataset
>>> examples = {
... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
... }
>>> dataset = Dataset.from_dict(examples)
>>> truncated_dataset = truncate_dataset(dataset, max_length=2)
>>> truncated_dataset[:]
{'input_ids': [[1, 2], [4, 5], [8]],
'attention_mask': [[0, 1], [0, 0], [1]]}
```
"""
if map_kwargs is None:
map_kwargs = {}
if isinstance(dataset, Dataset):
# Fast truncation with pyarrow
def truncate(examples):
truncated_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
column = pc.list_slice(column, 0, max_length)
truncated_columns.append(column)
return pa.Table.from_arrays(truncated_columns, names=examples.column_names)

dataset = dataset.with_format("arrow")
dataset = dataset.map(truncate, batched=True, **map_kwargs)
dataset = dataset.with_format(None)
else:

def truncate(examples):
truncated_examples = {}
for key, column in examples.items():
if column and isinstance(column[0], list):
column = [val[:max_length] for val in column]
truncated_examples[key] = column
return truncated_examples

dataset = dataset.map(
truncate,
batched=True,
**map_kwargs,
)
return dataset


def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
"""
Convert a conversational dataset with fields `from` and `value` to ChatML format.
23 changes: 9 additions & 14 deletions trl/trainer/sft_trainer.py
@@ -43,7 +43,13 @@
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from ..data_utils import (
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
pack_dataset,
truncate_dataset,
)
from .sft_config import SFTConfig
from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16

@@ -470,22 +476,11 @@ def tokenize(example, processing_class, dataset_text_field):
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Packing {dataset_name} dataset"
dataset = dataset.select_columns("input_ids")
dataset = dataset.map(
pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs
)
dataset = pack_dataset(dataset, args.max_length, map_kwargs)
elif args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"

def truncate(example, max_length):
return {key: example[key][:max_length] for key in ["input_ids", "attention_mask"]}

dataset = dataset.map(
truncate,
fn_kwargs={"max_length": args.max_length},
**map_kwargs,
)

dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
# For Liger kernel, ensure only input_ids is present
if args.use_liger_kernel:
dataset = dataset.select_columns("input_ids")