🎀 New default: beta=0.0 for GRPO #3516

Merged: 3 commits merged on May 30, 2025
8 changes: 7 additions & 1 deletion docs/source/grpo_trainer.md
@@ -103,6 +103,12 @@ Note that compared to the original formulation in [DeepSeekMath: Pushing the Lim

</Tip>

<Tip>

Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value.

</Tip>
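To make the last sentence of the tip concrete, here is a minimal sketch of re-enabling the KL penalty by passing a non-zero `beta` to [`GRPOConfig`]. It follows the usual TRL quick-start pattern; the model name, dataset, and reward function are illustrative placeholders, not part of this PR:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Hypothetical reward function (placeholder): prefers completions close to 20 characters.
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

dataset = load_dataset("trl-lib/tldr", split="train")

# A non-zero beta re-enables the KL divergence term (and loads a reference model).
training_args = GRPOConfig(output_dir="GRPO-with-KL", beta=0.04)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```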

In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:

$$
@@ -126,7 +132,7 @@
l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
$$

The DAPO paper highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
The [DAPO paper](https://huggingface.co/papers/2503.14476) highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:

$$
\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
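To make the distinction in the changed paragraph concrete, here is a rough sketch of the two normalizations (an illustration, not the trainer's actual implementation), where `l` holds the per-token objective l_{i,t} with shape (G, T) and `mask` marks real completion tokens:

```python
import torch

def grpo_sample_level_loss(l: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average within each completion first, then across the group:
    # every sample has equal weight regardless of its length.
    per_sample = (l * mask).sum(dim=1) / mask.sum(dim=1)
    return -per_sample.mean()

def dapo_token_level_loss(l: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Normalize by the total token count of the group, so every token carries
    # the same weight and long completions are not averaged away.
    return -(l * mask).sum() / mask.sum()
```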
4 changes: 2 additions & 2 deletions tests/test_grpo_trainer.py
@@ -809,12 +809,12 @@ def test_training_with_sync_ref_model(self):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_beta_zero_no_ref_model_and_no_kl(self):
def test_training_beta_non_zero(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
beta=0.0, # set beta to 0 to test the case where the reference model is not used
beta=0.1, # set beta to non-zero value to test the case where the reference model is used
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
12 changes: 6 additions & 6 deletions trl/trainer/grpo_config.py
@@ -131,9 +131,9 @@ class GRPOConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `1e-6`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
speed, but may be numerically unstable for long training runs.
beta (`float`, *optional*, defaults to `0.0`):
KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and improving
training speed.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
@@ -388,10 +388,10 @@ class GRPOConfig(TrainingArguments):
},
)
beta: float = field(
default=0.04,
default=0.0,
metadata={
"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
"training speed, but may be numerically unstable for long training runs."
"help": "KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and "
"improving training speed."
},
)
num_iterations: int = field(
Expand Down
Loading
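For context on what the new `beta=0.0` default implies inside the trainer, a simplified sketch of the pattern described in the updated docstring (this is an assumption about the general shape of the logic, not the exact GRPOTrainer code): the frozen reference copy is only materialized when `beta` is non-zero.

```python
import copy
from typing import Optional

import torch.nn as nn

def maybe_create_ref_model(model: nn.Module, beta: float) -> Optional[nn.Module]:
    # Simplified sketch, not the actual TRL implementation: with beta=0.0 there is
    # no KL penalty, so no reference model is needed.
    if beta == 0.0:
        return None  # skipping the reference model saves memory and speeds up steps
    ref_model = copy.deepcopy(model)
    for param in ref_model.parameters():
        param.requires_grad_(False)  # the reference policy stays frozen
    return ref_model
```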