Commit 9bc6206

Move PRMTrainer to trl.experimental.prm (huggingface#4483)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent f7ac974 commit 9bc6206
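
In practice, the user-facing change in this commit is the import path. A minimal migration sketch (the removal version quoted here is the one stated in the updated docs below):

```python
# Before this commit (deprecated, slated for removal in TRL 0.29.0):
# from trl import PRMConfig, PRMTrainer

# After this commit:
from trl.experimental.prm import PRMConfig, PRMTrainer
```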

12 files changed (+463, -407 lines)


docs/source/_toctree.yml

Lines changed: 2 additions & 2 deletions
@@ -64,8 +64,6 @@
     title: GRPO
   - local: kto_trainer
     title: KTO
-  - local: prm_trainer
-    title: PRM
   - local: reward_trainer
     title: Reward
   - local: rloo_trainer
@@ -119,6 +117,8 @@
     title: PAPO
   - local: ppo_trainer
     title: PPO
+  - local: prm_trainer
+    title: PRM
   - local: xpo_trainer
     title: XPO
   - local: openenv

docs/source/dataset_formats.md

Lines changed: 1 addition & 1 deletion
@@ -391,7 +391,6 @@ Choosing the right dataset type depends on the task you are working on and the s
 | [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
 | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
 | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
-| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
 | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
 | [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
 | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
@@ -401,6 +400,7 @@ Choosing the right dataset type depends on the task you are working on and the s
 | [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
 | [`experimental.orpo.ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
 | [`experimental.ppo.PPOTrainer`] | Tokenized language modeling |
+| [`experimental.prm.PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
 | [`experimental.xpo.XPOTrainer`] | [Prompt-only](#prompt-only) |
 
 ## Using any dataset with TRL: preprocessing and conversion
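
For context on the table row that moved: the stepwise-supervision format expected by the PRM trainer pairs a prompt with a list of reasoning steps and one boolean label per step. Below is a rough sketch of such an example; the field names follow TRL's documented stepwise-supervision format and the values are made up.

```python
# Hypothetical stepwise-supervision record: one correctness label per step.
example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": [
        "The fractional parts are 0.8 and 0.11.",  # labelled correct
        "Since 11 > 8, 9.11 is larger than 9.8.",  # labelled incorrect
    ],
    "labels": [True, False],
}
```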

docs/source/example_overview.md

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
 | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`experimental.orpo.ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
 | [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
 | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
-| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
+| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`experimental.prm.PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
 | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
 | [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. |
 | [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model. |

docs/source/index.md

Lines changed: 1 addition & 1 deletion
@@ -31,8 +31,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
 
 ### Reward modeling
 
-- [`PRMTrainer`]
 - [`RewardTrainer`]
+- [`experimental.prm.PRMTrainer`] 🧪
 
 </div>
 <div style="flex: 1; min-width: 0;">

docs/source/prm_trainer.md

Lines changed: 6 additions & 3 deletions
@@ -2,6 +2,9 @@
 
 [![model badge](https://img.shields.io/badge/All_models-PRM-blue)](https://huggingface.co/models?other=prm,trl)
 
+> [!TIP]
+> PRMTrainer has been moved to `trl.experimental.prm.PRMTrainer`. The `trl.trainer` version is deprecated and will be removed in TRL 0.29.0. Please update your imports to use `trl.experimental.prm.PRMTrainer` instead. See [issue #4467](https://github.com/huggingface/trl/issues/4467) for more information.
+
 > [!WARNING]
 > PRM Trainer is an experimental API which is subject to change at any time.
 
@@ -31,7 +34,7 @@ Below is the script to train the model:
 ```python
 # train_prm.py
 from datasets import load_dataset
-from trl import PRMConfig, PRMTrainer
+from trl.experimental.prm import PRMConfig, PRMTrainer
 from transformers import AutoModelForTokenClassification, AutoTokenizer
 
 model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
@@ -112,11 +115,11 @@ accelerate launch examples/scripts/prm.py \
 
 ## PRMTrainer
 
-[[autodoc]] PRMTrainer
+[[autodoc]] experimental.prm.PRMTrainer
 - train
 - save_model
 - push_to_hub
 
 ## PRMConfig
 
-[[autodoc]] PRMConfig
+[[autodoc]] experimental.prm.PRMConfig
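
The hunk above only shows the top of the `train_prm.py` snippet. For orientation, a minimal end-to-end sketch with the new import path could look like the following; the dataset name, `output_dir`, and the trainer/`train()` calls are assumptions, since those lines fall outside this diff.

```python
# Sketch of a complete train_prm.py using the experimental import path.
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer

from trl.experimental.prm import PRMConfig, PRMTrainer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# Assumed stepwise-supervision dataset; any dataset in that format should work.
train_dataset = load_dataset("trl-lib/math_shepherd", split="train")

training_args = PRMConfig(output_dir="Qwen2-0.5B-PRM")  # placeholder output directory
trainer = PRMTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```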

examples/scripts/prm.py

Lines changed: 1 addition & 2 deletions
@@ -58,13 +58,12 @@
 
 from trl import (
     ModelConfig,
-    PRMConfig,
-    PRMTrainer,
     ScriptArguments,
     get_kbit_device_map,
     get_peft_config,
     get_quantization_config,
 )
+from trl.experimental.prm import PRMConfig, PRMTrainer
 
 
 logger = logging.get_logger(__name__)

tests/test_prm_trainer.py renamed to tests/experimental/test_prm_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -20,9 +20,9 @@
 from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase
 from transformers.utils import is_peft_available
 
-from trl import PRMConfig, PRMTrainer
+from trl.experimental.prm import PRMConfig, PRMTrainer
 
-from .testing_utils import TrlTestCase, require_peft
+from ..testing_utils import TrlTestCase, require_peft
 
 
 if is_peft_available():

trl/experimental/prm/__init__.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .prm_config import PRMConfig
+from .prm_trainer import PRMTrainer
+
+
+__all__ = ["PRMConfig", "PRMTrainer"]

trl/experimental/prm/prm_config.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass, field
+
+from transformers import TrainingArguments
+
+
+@dataclass
+class PRMConfig(TrainingArguments):
+    r"""
+    Configuration class for the [`experimental.prm.PRMTrainer`].
+
+    This class includes only the parameters that are specific to PRM training. For a full list of training arguments,
+    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
+    differ from those in [`~transformers.TrainingArguments`].
+
+    Using [`~transformers.HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        max_length (`int` or `None`, *optional*, defaults to `1024`):
+            Maximum length of the sequences (prompt + completion) used for truncation.
+        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
+            Maximum length of the prompt used for truncation.
+        max_completion_length (`int`, *optional*):
+            Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
+        disable_dropout (`bool`, *optional*, defaults to `True`):
+            Whether to disable dropout in the model.
+        step_separator (`str`, *optional*, defaults to `"\n"`):
+            Separator used to separate each step of the reasoning process.
+        train_on_last_step_only (`bool`, *optional*, defaults to `False`):
+            Whether to train only on the last step.
+        dataset_num_proc (`int`, *optional*):
+            Number of processes to use for processing the dataset.
+    """
+
+    # Parameters whose default values are overridden from TrainingArguments
+    learning_rate: float = field(
+        default=1e-5,
+        metadata={"help": "The initial learning rate for AdamW."},
+    )
+    logging_steps: float = field(
+        default=10,
+        metadata={
+            "help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
+            "will be interpreted as ratio of total training steps."
+        },
+    )
+    gradient_checkpointing: bool = field(
+        default=True,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
+    bf16: bool | None = field(
+        default=None,
+        metadata={
+            "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
+            "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
+            "`fp16` is not set."
+        },
+    )
+
+    max_length: int | None = field(
+        default=1024,
+        metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
+    )
+    max_prompt_length: int | None = field(
+        default=512,
+        metadata={"help": "Maximum length of the prompt used for truncation."},
+    )
+    max_completion_length: int | None = field(
+        default=None,
+        metadata={
+            "help": "Maximum length of the completion used for truncation. The completion is the concatenation of the "
+            "steps."
+        },
+    )
+    disable_dropout: bool = field(
+        default=True,
+        metadata={"help": "Whether to disable dropout in the model and reference model."},
+    )
+    step_separator: str = field(
+        default="\n",
+        metadata={"help": "Separator used to separate each step of the reasoning process."},
+    )
+    train_on_last_step_only: bool = field(
+        default=False,
+        metadata={"help": "Whether to train only on the last step."},
+    )
+    dataset_num_proc: int | None = field(
+        default=None,
+        metadata={"help": "Number of processes to use for processing the dataset."},
+    )
+
+    def __post_init__(self):
+        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
+
+        super().__post_init__()
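
As a quick illustration of the new config file, the PRM-specific knobs can be set explicitly while the overridden defaults (learning rate 1e-5, gradient checkpointing, bf16 when fp16 is unset) are left alone. This is only a sketch; the `output_dir` value is a placeholder.

```python
from trl.experimental.prm import PRMConfig

# Minimal sketch: set only the PRM-specific fields, keep the defaults above.
config = PRMConfig(
    output_dir="prm-checkpoints",   # placeholder path
    max_length=1024,                # prompt + completion truncation
    max_prompt_length=512,
    step_separator="\n",            # how reasoning steps are joined
    train_on_last_step_only=False,  # supervise every step, not just the last
)
print(config.learning_rate)  # 1e-05, overridden from the TrainingArguments default
```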
