Skip to content

Commit aa111fd

Browse files
authored
Merge branch 'main' into main
2 parents 5d7bade + a0a5317 commit aa111fd

File tree

13 files changed

+298
-23
lines changed

13 files changed

+298
-23
lines changed

.github/workflows/tests_latest.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
steps:
1818
- name: Git checkout
1919
uses: actions/checkout@v4
20-
with: { ref: v0.15-release }
20+
with: { ref: v0.16-release }
2121
- name: Set up Python 3.12
2222
uses: actions/setup-python@v5
2323
with:

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,4 @@ checklink/cookies.txt
142142
# wandb files
143143
nbs/wandb/
144144
examples/notebooks/wandb/
145-
wandb/
145+
wandb/

CITATION.cff

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ keywords:
3131
- pytorch
3232
- transformers
3333
license: Apache-2.0
34-
version: 0.15
34+
version: 0.16

docs/source/data_utils.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,11 @@
3535
## pack_examples
3636

3737
[[autodoc]] pack_examples
38+
39+
## pack_dataset
40+
41+
[[autodoc]] pack_dataset
42+
43+
## truncate_dataset
44+
45+
[[autodoc]] truncate_dataset

docs/source/reducing_memory_usage.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Onl
136136
If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:
137137

138138
<hfoptions id="ds3_gather_for_generation">
139+
<hfoption id="GRPO">
140+
141+
```python
142+
from trl import GRPOConfig
143+
144+
training_args = GRPOConfig(..., ds3_gather_for_generation=False)
145+
```
146+
147+
</hfoption>
139148
<hfoption id="Online DPO">
140149

141150
```python

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from setuptools import find_packages, setup
7070

7171

72-
__version__ = "0.16.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
72+
__version__ = "0.17.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
7373

7474
REQUIRED_PKGS = [
7575
"accelerate>=0.34.0",

tests/test_data_utils.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
maybe_convert_to_chatml,
2828
maybe_extract_prompt,
2929
maybe_unpair_preference_dataset,
30+
pack_dataset,
3031
pack_examples,
32+
truncate_dataset,
3133
unpair_preference_dataset,
3234
)
3335

@@ -395,7 +397,7 @@ def test_maybe_extract_prompt_standard_already_explicit(self):
395397

396398

397399
class TestPackExamples(unittest.TestCase):
398-
def test_pack_examples_larger_chunks(self):
400+
def test_larger_chunks(self):
399401
examples = {
400402
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
401403
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
@@ -408,7 +410,7 @@ def test_pack_examples_larger_chunks(self):
408410
result = pack_examples(examples, seq_length)
409411
self.assertEqual(result, expected_output)
410412

411-
def test_pack_examples_smaller_chunks(self):
413+
def test_smaller_chunks(self):
412414
examples = {
413415
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
414416
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
@@ -421,7 +423,7 @@ def test_pack_examples_smaller_chunks(self):
421423
result = pack_examples(examples, seq_length)
422424
self.assertEqual(result, expected_output)
423425

424-
def test_pack_with_dataset(self):
426+
def test_with_dataset(self):
425427
examples = {
426428
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
427429
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
@@ -436,6 +438,84 @@ def test_pack_with_dataset(self):
436438
self.assertEqual(dataset.to_dict(), expected_output)
437439

438440

441+
class TestPackDataset(unittest.TestCase):
442+
def test_with_dataset(self):
443+
examples = {
444+
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
445+
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
446+
}
447+
dataset = Dataset.from_dict(examples)
448+
seq_length = 3
449+
expected_output = {
450+
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
451+
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
452+
}
453+
dataset = pack_dataset(dataset, seq_length)
454+
self.assertEqual(dataset.to_dict(), expected_output)
455+
456+
def test_with_iterable_dataset(self):
457+
examples = {
458+
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
459+
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
460+
}
461+
dataset = Dataset.from_dict(examples).to_iterable_dataset()
462+
seq_length = 3
463+
expected_output = {
464+
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
465+
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
466+
}
467+
dataset = pack_dataset(dataset, seq_length)
468+
num_examples = len(examples[next(iter(examples))])
469+
self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
470+
471+
472+
class TestTruncateExamples(unittest.TestCase):
473+
def test_with_dataset(self):
474+
examples = {
475+
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
476+
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
477+
}
478+
dataset = Dataset.from_dict(examples)
479+
max_length = 2
480+
expected_output = {
481+
"input_ids": [[1, 2], [4, 5], [8]],
482+
"attention_mask": [[0, 1], [0, 0], [1]],
483+
}
484+
dataset = truncate_dataset(dataset, max_length)
485+
self.assertEqual(dataset.to_dict(), expected_output)
486+
487+
def test_with_iterable_dataset(self):
488+
examples = {
489+
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
490+
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
491+
}
492+
dataset = Dataset.from_dict(examples).to_iterable_dataset()
493+
max_length = 2
494+
expected_output = {
495+
"input_ids": [[1, 2], [4, 5], [8]],
496+
"attention_mask": [[0, 1], [0, 0], [1]],
497+
}
498+
dataset = truncate_dataset(dataset, max_length)
499+
num_examples = len(examples[next(iter(examples))])
500+
self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)
501+
502+
def test_with_extra_column(self):
503+
examples = {
504+
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
505+
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
506+
"my_column": ["a", "b", "c"],
507+
}
508+
dataset = Dataset.from_dict(examples)
509+
max_length = 2
510+
expected_output = {
511+
"input_ids": [[1, 2], [4, 5], [8]],
512+
"attention_mask": [[0, 1], [0, 0], [1]],
513+
"my_column": ["a", "b", "c"],
514+
}
515+
dataset = truncate_dataset(dataset, max_length)
516+
self.assertEqual(dataset.to_dict(), expected_output)
517+
518+
439519
class TestMaybeConvertToChatML(unittest.TestCase):
440520
def test_with_conversations_key(self):
441521
# Particular case where the key is "conversations": we rename it to "messages"

tests/test_grpo_trainer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,3 +914,34 @@ def test_training_vllm_with_additional_generation_kwargs(self):
914914
for n, param in previous_trainable_params.items():
915915
new_param = trainer.model.get_parameter(n)
916916
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
917+
918+
def test_training_no_scale_rewards(self):
919+
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
920+
921+
with tempfile.TemporaryDirectory() as tmp_dir:
922+
training_args = GRPOConfig(
923+
output_dir=tmp_dir,
924+
learning_rate=0.1, # increase the learning rate to speed up the test
925+
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
926+
num_generations=3, # reduce the number of generations to reduce memory usage
927+
max_completion_length=32, # reduce the completion length to reduce memory usage
928+
scale_rewards=False,
929+
report_to="none",
930+
)
931+
trainer = GRPOTrainer(
932+
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
933+
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
934+
args=training_args,
935+
train_dataset=dataset,
936+
)
937+
938+
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
939+
940+
trainer.train()
941+
942+
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
943+
944+
# Check that the params have changed
945+
for n, param in previous_trainable_params.items():
946+
new_param = trainer.model.get_parameter(n)
947+
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

trl/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
__version__ = "0.16.0.dev0"
15+
__version__ = "0.17.0.dev0"
1616

1717
from typing import TYPE_CHECKING
1818

@@ -29,7 +29,9 @@
2929
"maybe_convert_to_chatml",
3030
"maybe_extract_prompt",
3131
"maybe_unpair_preference_dataset",
32+
"pack_dataset",
3233
"pack_examples",
34+
"truncate_dataset",
3335
"unpair_preference_dataset",
3436
],
3537
"environment": ["TextEnvironment", "TextHistory"],
@@ -130,7 +132,9 @@
130132
maybe_convert_to_chatml,
131133
maybe_extract_prompt,
132134
maybe_unpair_preference_dataset,
135+
pack_dataset,
133136
pack_examples,
137+
truncate_dataset,
134138
unpair_preference_dataset,
135139
)
136140
from .environment import TextEnvironment, TextHistory

trl/data_utils.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
1516
from typing import Any, Callable, Optional, Sequence, TypeVar, Union
1617

18+
import numpy as np
19+
import pyarrow as pa
20+
import pyarrow.compute as pc
21+
import pyarrow.types
1722
from datasets import Dataset, DatasetDict
1823
from transformers import PreTrainedTokenizerBase
1924

@@ -466,6 +471,132 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
466471
return examples
467472

468473

474+
def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
475+
r"""
476+
Pack sequences in a dataset into chunks of size `seq_length`.
477+
478+
Args:
479+
dataset (`Dataset` or `DatasetDict`):
480+
Dataset to pack
481+
seq_length (`int`):
482+
Target sequence length to pack to.
483+
map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
484+
Additional keyword arguments to pass to the dataset's map method when packing examples.
485+
486+
Returns:
487+
`Dataset` or `DatasetDict`: The dataset with packed sequences. The number of examples may
488+
decrease as sequences are combined.
489+
490+
Example:
491+
```python
492+
>>> from datasets import Dataset
493+
>>> examples = {
494+
... "input_ids": [[1, 2], [3, 4], [5, 6], [7]],
495+
... "attention_mask": [[1, 1], [0, 1], [1, 1], [1]],
496+
... }
497+
>>> dataset = Dataset.from_dict(examples)
498+
>>> packed_dataset = pack_dataset(dataset, seq_length=4)
499+
>>> packed_dataset[:]
500+
{'input_ids': [[1, 2, 3, 4], [5, 6, 7]],
501+
'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]}
502+
```
503+
"""
504+
if map_kwargs is None:
505+
map_kwargs = {}
506+
if isinstance(dataset, Dataset):
507+
# Fast packing with pyarrow
508+
def pack(examples):
509+
packed_columns = []
510+
for column in examples.columns:
511+
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
512+
if isinstance(column, pa.ChunkedArray):
513+
column = column.combine_chunks()
514+
offsets, values = column.offsets, column.values
515+
values = values[offsets[0].as_py() : offsets[-1].as_py()]
516+
num_elements = len(values)
517+
dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64
518+
offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
519+
offsets = np.concatenate((offsets, [num_elements]))
520+
column = type(column).from_arrays(offsets, values)
521+
packed_columns.append(column)
522+
return pa.Table.from_arrays(packed_columns, names=examples.column_names)
523+
524+
dataset = dataset.with_format("arrow")
525+
dataset = dataset.map(pack, batched=True, **map_kwargs)
526+
dataset = dataset.with_format(None)
527+
else:
528+
dataset = dataset.map(
529+
functools.partial(pack_examples, seq_length=seq_length),
530+
batched=True,
531+
**map_kwargs,
532+
)
533+
return dataset
534+
535+
536+
def truncate_dataset(
537+
dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None
538+
) -> DatasetType:
539+
r"""
540+
Truncate sequences in a dataset to a specified `max_length`.
541+
542+
Args:
543+
dataset (`Dataset` or `DatasetDict`):
544+
Dataset to truncate.
545+
max_length (`int`):
546+
Maximum sequence length to truncate to.
547+
map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
548+
Additional keyword arguments to pass to the dataset's map method when truncating examples.
549+
550+
Returns:
551+
`Dataset` or `DatasetDict`: The dataset with truncated sequences.
552+
553+
Example:
554+
```python
555+
>>> from datasets import Dataset
556+
>>> examples = {
557+
... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
558+
... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
559+
... }
560+
>>> dataset = Dataset.from_dict(examples)
561+
>>> truncated_dataset = truncate_dataset(dataset, max_length=2)
562+
>>> truncated_dataset[:]
563+
{'input_ids': [[1, 2], [4, 5], [8]],
564+
'attention_mask': [[0, 1], [0, 0], [1]]}
565+
```
566+
"""
567+
if map_kwargs is None:
568+
map_kwargs = {}
569+
if isinstance(dataset, Dataset):
570+
# Fast truncation with pyarrow
571+
def truncate(examples):
572+
truncated_columns = []
573+
for column in examples.columns:
574+
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
575+
column = pc.list_slice(column, 0, max_length)
576+
truncated_columns.append(column)
577+
return pa.Table.from_arrays(truncated_columns, names=examples.column_names)
578+
579+
dataset = dataset.with_format("arrow")
580+
dataset = dataset.map(truncate, batched=True, **map_kwargs)
581+
dataset = dataset.with_format(None)
582+
else:
583+
584+
def truncate(examples):
585+
truncated_examples = {}
586+
for key, column in examples.items():
587+
if column and isinstance(column[0], list):
588+
column = [val[:max_length] for val in column]
589+
truncated_examples[key] = column
590+
return truncated_examples
591+
592+
dataset = dataset.map(
593+
truncate,
594+
batched=True,
595+
**map_kwargs,
596+
)
597+
return dataset
598+
599+
469600
def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
470601
"""
471602
Convert a conversational dataset with fields `from` and `value` to ChatML format.

0 commit comments

Comments
 (0)