
Commit 9510b36

Merge branch 'main' into grpo_config_extend
2 parents 3e44f00 + 5a0cebc commit 9510b36

18 files changed, +395 −113 lines

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.9.7
+    rev: v0.11.3
     hooks:
       - id: ruff
         types_or: [ python, pyi ]

docs/source/grpo_trainer.md

Lines changed: 7 additions & 8 deletions
@@ -115,14 +115,13 @@ When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifi
 ## Logged metrics
 
 - `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
-- `mean_completion_length`: The average length of generated completions.
-- `min_completion_length`: The maximum length of generated completions.
-- `max_completion_length`: The minimun length of generated completions.
-- `mean_terminated_completion_length`: The average length of generated completions that terminate with EOS.
-- `min_terminated_completion_length`: The maximum length of generated completions that terminate with EOS.
-- `max_terminated_completion_length`: The minimun length of generated completions that terminate with EOS.
-- `max_terminated_completion_length`: The minimun length of generated completions that terminate with EOS.
-- `clipped_completions_ratio` : The ratio of trucated (clipped) completions.
+- `completions/mean_length`: The average length of generated completions.
+- `completions/min_length`: The minimum length of generated completions.
+- `completions/max_length`: The maximum length of generated completions.
+- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
+- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
+- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
+- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
 - `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
 - `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
 - `reward`: The overall average reward after applying reward weights.
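
Reviewer note: beyond the `completions/*` namespacing, the old list had the min/max descriptions swapped and duplicated the `max_terminated_completion_length` entry; the renamed list fixes both. As a minimal sketch (a helper of our own, not part of this PR, assuming only the standard `trainer.state.log_history` buffer that transformers' Trainer keeps), the new keys can be read back after a run:

def completion_length_stats(trainer):
    # Collect the per-log-step values recorded under the renamed metric keys.
    keys = ("completions/mean_length", "completions/min_length", "completions/max_length")
    return [
        {key: entry[key] for key in keys}
        for entry in trainer.state.log_history
        if all(key in entry for key in keys)
    ]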

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 [tool.ruff]
-target-version = "py37"
+target-version = "py39"
 line-length = 119
 
 [tool.ruff.lint]

setup.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@
     "deepspeed": ["deepspeed>=0.14.4"],
     "diffusers": ["diffusers>=0.18.0"],
     "judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
-    "liger": ["liger-kernel>=0.5.5"],
+    "liger": ["liger-kernel>=0.5.6"],
     "mergekit": ["mergekit>=0.0.5.1"],
     "peft": ["peft>=0.8.0"],
     "quantization": ["bitsandbytes"],

tests/slow/test_grpo_slow.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import torch
+from accelerate.utils.memory import release_memory
+from datasets import load_dataset
+from parameterized import parameterized
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.testing_utils import require_liger_kernel, require_torch_accelerator
+
+from trl import GRPOConfig, GRPOTrainer
+
+from .testing_constants import MODELS_TO_TEST
+
+
+@require_torch_accelerator
+class GRPOTrainerSlowTester(unittest.TestCase):
+    def setUp(self):
+        self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test")
+        self.max_length = 128
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    @parameterized.expand(MODELS_TO_TEST)
+    @require_liger_kernel
+    def test_training_with_liger_grpo_loss(self, model_name):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                per_device_train_batch_size=3,
+                num_generations=3,
+                use_liger_loss=True,
+                max_completion_length=self.max_length,
+                report_to="none",
+                logging_strategy="no",
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(model_name)
+            tokenizer = AutoTokenizer.from_pretrained(model_name)
+            tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
+
+            trainer = GRPOTrainer(
+                model=model,
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=self.train_dataset,
+                eval_dataset=self.eval_dataset,
+                processing_class=tokenizer,
+            )
+            from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
+
+            assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss)
+
+            previous_trainable_params = {n: param.clone() for n, param in model.named_parameters()}
+
+            trainer.train()
+
+            for n, param in previous_trainable_params.items():
+                new_param = model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+            release_memory(model, trainer)
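
Reviewer note: the new slow test uses the usual snapshot-train-compare pattern to prove that training actually updates weights. A standalone helper of our own (not part of the PR) capturing the same idea:

import torch

def assert_params_updated(model, train_fn):
    # Snapshot every parameter, run training, then fail if any tensor is unchanged.
    before = {name: param.clone() for name, param in model.named_parameters()}
    train_fn()
    for name, old in before.items():
        new = model.get_parameter(name)
        assert not torch.equal(old, new), f"Parameter {name} has not changed."

Used as assert_params_updated(model, trainer.train). The `use_liger_loss=True` path the test exercises also lines up with setup.py above raising the liger-kernel floor to >=0.5.6.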

tests/slow/test_sft_slow.py

Lines changed: 0 additions & 8 deletions
@@ -106,7 +106,6 @@ def test_sft_trainer_transformers(self, model_name, packing):
 
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,
@@ -141,7 +140,6 @@ def test_sft_trainer_peft(self, model_name, packing):
 
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,
@@ -178,7 +176,6 @@ def test_sft_trainer_transformers_mp(self, model_name, packing):
 
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,
@@ -214,7 +211,6 @@ def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_chec
 
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,
@@ -251,7 +247,6 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient
 
         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,
@@ -295,7 +290,6 @@ def test_sft_trainer_transformers_mp_gc_device_map(
 
         model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,
@@ -335,7 +329,6 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
 
         model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,
@@ -381,7 +374,6 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
 
         if tokenizer.chat_template is None:
             model, tokenizer = setup_chat_format(model, tokenizer)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
         trainer = SFTTrainer(
             model,

tests/test_sft_trainer.py

Lines changed: 29 additions & 1 deletion
@@ -33,6 +33,7 @@
 
 from trl import SFTConfig, SFTTrainer
 from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
+from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
 
 
 def formatting_prompts_func(example):
@@ -59,14 +60,41 @@ def formatting_prompts_func_batched(example):
 from PIL import Image as PILImage
 
 
+class TestDataCollatorForLanguageModeling(unittest.TestCase):
+    def test_collate_padding(self):
+        collator = DataCollatorForLanguageModeling(pad_token_id=0)
+        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
+        output = collator(examples)
+
+        expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 0]])
+        expected_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
+        expected_labels = torch.tensor([[1, 2, 3], [4, 5, -100]])
+
+        self.assertEqual(output["input_ids"].tolist(), expected_input_ids.tolist())
+        self.assertEqual(output["attention_mask"].tolist(), expected_attention_mask.tolist())
+        self.assertEqual(output["labels"].tolist(), expected_labels.tolist())
+
+    def test_collate_no_padding(self):
+        collator = DataCollatorForLanguageModeling(pad_token_id=0)
+        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6]}]
+        output = collator(examples)
+
+        expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
+        expected_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]])
+        expected_labels = torch.tensor([[1, 2, 3], [4, 5, 6]])
+
+        self.assertEqual(output["input_ids"].tolist(), expected_input_ids.tolist())
+        self.assertEqual(output["attention_mask"].tolist(), expected_attention_mask.tolist())
+        self.assertEqual(output["labels"].tolist(), expected_labels.tolist())
+
+
 class SFTTrainerTester(unittest.TestCase):
     r""" """
 
     def setUp(self):
         self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.dummy_dataset = Dataset.from_dict(
             {
                 "question": [

trl/data_utils.py

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import functools
-from typing import Any, Callable, Optional, Sequence, TypeVar, Union
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, TypeVar, Union
 
 import numpy as np
 import pyarrow as pa
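
Reviewer note: here, and in trl/extras/profiling.py below, `Sequence` and `Generator` move from `typing` to `collections.abc`; the `typing` aliases have been deprecated since Python 3.9 (PEP 585), which matches the `target-version = "py39"` bump in pyproject.toml above. Illustrative only, the post-3.9 spelling:

from collections.abc import Generator, Sequence

def chunk(items: Sequence[int], size: int) -> Generator[list[int], None, None]:
    # Yield consecutive chunks of at most `size` items.
    for start in range(0, len(items), size):
        yield list(items[start : start + size])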

trl/extras/profiling.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 import contextlib
 import functools
 import time
-from typing import Generator
+from collections.abc import Generator
 
 from transformers import Trainer, is_wandb_available

trl/extras/vllm_client.py

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ def generate(
         min_p: float = 0.0,
         max_tokens: int = 16,
         guided_decoding_regex: Optional[str] = None,
-    ) -> list[list[str]]:
+    ) -> list[list[int]]:
         """
         Generates model completions for the provided prompts.
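
Reviewer note: the corrected annotation records what the method actually returns — one list of token IDs per prompt, not decoded strings — so callers decode themselves. A hypothetical usage sketch; the model name, the default constructor arguments, and a running TRL vLLM server are assumptions, not from this diff:

from transformers import AutoTokenizer
from trl.extras.vllm_client import VLLMClient

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
client = VLLMClient()  # assumes a vLLM server is already running at the default host/port
completion_ids = client.generate(["The capital of France is"], max_tokens=16)  # list[list[int]]
completions = [tokenizer.decode(ids, skip_special_tokens=True) for ids in completion_ids]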
