
Commit 7359ddc

🎀 New default: beta=0.0 for GRPO (#3516)
1 parent 0844936 · commit 7359ddc

3 files changed (+15, −9 lines)


docs/source/grpo_trainer.md

Lines changed: 7 additions & 1 deletion
@@ -103,6 +103,12 @@ Note that compared to the original formulation in [DeepSeekMath: Pushing the Lim
 
 </Tip>
 
+<Tip>
+
+Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value.
+
+</Tip>
+
 In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
 
 $$
@@ -126,7 +132,7 @@ $$
 l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
 $$
 
-The DAPO paper highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
+The [DAPO paper](https://huggingface.co/papers/2503.14476) highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
 
 $$
 \mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
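
The tip added above points users at `beta` in [`GRPOConfig`] to restore the KL penalty. A minimal sketch of that opt-in path is shown below; the model id, dataset, and reward function are placeholders chosen for illustration and are not part of this commit:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Toy reward function for illustration: favors shorter completions.
def reward_short(completions, **kwargs):
    return [-float(len(completion)) for completion in completions]

dataset = load_dataset("trl-lib/tldr", split="train")  # example prompt dataset

training_args = GRPOConfig(
    output_dir="grpo-with-kl",
    beta=0.04,  # non-zero value: the KL term is applied and the reference model is loaded
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # example model id
    reward_funcs=reward_short,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

Leaving `beta` unset now gives the new default of `0.0`, matching the behaviour described in the tip.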

tests/test_grpo_trainer.py

Lines changed: 2 additions & 2 deletions
@@ -809,12 +809,12 @@ def test_training_with_sync_ref_model(self):
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
 
-    def test_beta_zero_no_ref_model_and_no_kl(self):
+    def test_training_beta_non_zero(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = GRPOConfig(
                 output_dir=tmp_dir,
-                beta=0.0,  # set beta to 0 to test the case where the reference model is not used
+                beta=0.1,  # set beta to non-zero value to test the case where the reference model is used
                 learning_rate=0.1,  # increase the learning rate to speed up the test
                 per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                 num_generations=3,  # reduce the number of generations to reduce memory usage
trl/trainer/grpo_config.py

Lines changed: 6 additions & 6 deletions
@@ -131,9 +131,9 @@ class GRPOConfig(TrainingArguments):
         learning_rate (`float`, *optional*, defaults to `1e-6`):
             Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
             [`~transformers.TrainingArguments`].
-        beta (`float`, *optional*, defaults to `0.04`):
-            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
-            speed, but may be numerically unstable for long training runs.
+        beta (`float`, *optional*, defaults to `0.0`):
+            KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and improving
+            training speed.
         num_iterations (`int`, *optional*, defaults to `1`):
             Number of iterations per batch (denoted as μ in the algorithm).
         epsilon (`float`, *optional*, defaults to `0.2`):
@@ -388,10 +388,10 @@ class GRPOConfig(TrainingArguments):
         },
     )
     beta: float = field(
-        default=0.04,
+        default=0.0,
         metadata={
-            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
-            "training speed, but may be numerically unstable for long training runs."
+            "help": "KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and "
+            "improving training speed."
         },
     )
     num_iterations: int = field(
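
As a quick illustration of the docstring change, the new default is visible directly on the config object; this is a small sketch and the `output_dir` value is just a placeholder:

```python
from trl import GRPOConfig

# New default after this commit: beta=0.0, i.e. no KL penalty.
config = GRPOConfig(output_dir="tmp-grpo")
assert config.beta == 0.0

# Opting back in only requires a non-zero coefficient, e.g. the old default.
config_with_kl = GRPOConfig(output_dir="tmp-grpo", beta=0.04)
assert config_with_kl.beta == 0.04
```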
