
Commit 5e2e9cb

🩺 Dr. GRPO loss (#3256)

Authored by qgallouedec and lewtun
Co-authored-by: lewtun <[email protected]>
1 parent 227df82 · commit 5e2e9cb

File tree: 4 files changed, +119 -14 lines

- docs/source/grpo_trainer.md
- tests/test_grpo_trainer.py
- trl/trainer/grpo_config.py
- trl/trainer/grpo_trainer.py

docs/source/grpo_trainer.md

Lines changed: 34 additions & 5 deletions
@@ -76,7 +76,7 @@ This approach gives the method its name: **Group Relative Policy Optimization (G
 
 <Tip>
 
-It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
+It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.14476) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
 
 </Tip>
 
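The `scale_rewards` switch mentioned in the tip is an ordinary [`GRPOConfig`] argument. A minimal illustrative sketch of disabling it (the output directory name is just a placeholder):

```python
from trl import GRPOConfig

# Keep advantages as r - mean(r) within each group, without dividing by std(r).
training_args = GRPOConfig(output_dir="grpo-no-reward-scaling", scale_rewards=False)
```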

@@ -92,26 +92,55 @@ $$
 The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

 $$
-\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
+\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
 $$

 where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.

 <Tip>

-Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf) that this introduces a response-level length bias.
+Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.14476) that this introduces a response-level length bias. More details in [loss types](#loss-types).

 </Tip>

 In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:

 $$
-\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
+\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
 $$

 where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
 When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective.

+#### Loss Types
+
+Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows:
+
+$$
+\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},
+$$
+
+where
+
+$$
+l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
+$$
+
+The DAPO paper highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
+
+$$
+\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
+$$
+
+Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.14476) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:
+
+$$
+\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
+$$
+
+This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].
+
 ## Logged metrics

 - `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
@@ -121,7 +150,7 @@ When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifi
 - `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
 - `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
 - `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
-- `completions/clipped_ratio` : The ratio of truncated (clipped) completions.
+- `completions/clipped_ratio` : The ratio of truncated (clipped) completions.
 - `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
 - `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
 - `reward`: The overall average reward after applying reward weights.
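To make the three normalizations described under "Loss Types" concrete, here is a small, self-contained PyTorch sketch of how a masked per-token loss is aggregated under each `loss_type`. It mirrors the branch added to `_compute_loss` in `trl/trainer/grpo_trainer.py` below; the tensor values are made up for illustration.

```python
import torch

# Per-token losses for a group of G=2 completions, padded to 4 tokens (illustrative values).
per_token_loss = torch.tensor([[0.5, 0.7, 0.2, 0.0],
                               [0.3, 0.4, 0.6, 0.1]])
# 0/1 mask over completion tokens (first completion has length 3, second has length 4).
completion_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                                [1.0, 1.0, 1.0, 1.0]])
max_completion_length = 8  # the constant L used by the "dr_grpo" normalization

# "grpo": normalize each sequence by its own length, then average over sequences.
grpo_loss = ((per_token_loss * completion_mask).sum(-1)
             / completion_mask.sum(-1).clamp(min=1.0)).mean()

# "bnpo": normalize by the number of active tokens in the (local) batch.
bnpo_loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)

# "dr_grpo": normalize by a global constant, batch_size * max_completion_length.
dr_grpo_loss = (per_token_loss * completion_mask).sum() / (
    per_token_loss.size(0) * max_completion_length
)

print(grpo_loss, bnpo_loss, dr_grpo_loss)
```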

tests/test_grpo_trainer.py

Lines changed: 32 additions & 0 deletions
@@ -176,6 +176,38 @@ def test_training(self, config_name):
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
 
+    @parameterized.expand([("bnpo",), ("dr_grpo",)])
+    def test_training_loss_types(self, loss_type):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                loss_type=loss_type,
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
     def test_training_with_eval(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
 

trl/trainer/grpo_config.py

Lines changed: 31 additions & 3 deletions
@@ -116,9 +116,21 @@ class GRPOConfig(TrainingArguments):
         scale_rewards (`bool`, *optional*, defaults to `True`):
             Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), the rewards
             are normalized by the standard deviation, ensuring they have unit variance. If `False`, no scaling is
-            applied. The [Dr. GRPO](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)
-            paper recommends not scaling the rewards, as scaling by the standard deviation introduces a question-level
-            difficulty bias.
+            applied. The [Dr. GRPO paper](https://huggingface.co/papers/2503.14476) recommends not scaling the rewards,
+            as scaling by the standard deviation introduces a question-level difficulty bias.
+        loss_type (`str`, *optional*, defaults to `"bnpo"`):
+            Specifies the loss formulation to use. Supported values are:
+
+            - `"grpo"`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to
+              length bias—this approach tends to prefer shorter completions with positive advantages and longer ones
+              with negative advantages.
+            - `"bnpo"`: Aggregates token-level losses by normalizing over the number of active tokens in the local
+              batch. Note that normalization is performed over the local batch only, so results may slightly vary
+              depending on the local batch size, despite a constant effective batch size. When using
+              `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
+            - `"dr_grpo"`: Aggregates token-level losses by normalizing with a global constant. This method was
+              introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.14476) to eliminate length bias.
+              The value of the constant corresponds to `max_completion_length`.
         mask_truncated_completions (`bool`, *optional*, defaults to `False`):
             When enabled, truncated completions are excluded from the loss calculation, preventing them from being
             incorrectly penalized and introducing noise during training. According to the
@@ -324,6 +336,22 @@ class GRPOConfig(TrainingArguments):
             "deviation introduces a question-level difficulty bias."
         },
     )
+    loss_type: str = field(
+        default="bnpo",
+        metadata={
+            "help": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`. "
+            "`'grpo'`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to "
+            "length bias—this approach tends to prefer shorter completions with positive advantages and longer ones "
+            "with negative advantages. "
+            "`'bnpo'`: Aggregates token-level losses by normalizing over the number of active tokens in the local "
+            "batch. Note that normalization is performed over the local batch only, so results may slightly vary "
+            "depending on the local batch size, despite a constant effective batch size. When using "
+            "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. "
+            "`'dr_grpo'`: Aggregates token-level losses by normalizing with a global constant. This method was "
+            "introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to "
+            "`max_completion_length`."
+        },
+    )
     mask_truncated_completions: bool = field(
         default=False,
         metadata={
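As a usage sketch of the new `loss_type` option, reusing the tiny test models and dataset from `tests/test_grpo_trainer.py` above (the output directory name is a placeholder):

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-dr-grpo",  # placeholder output directory
    loss_type="dr_grpo",        # normalize with a global constant instead of the sequence length
    max_completion_length=32,   # this value is the constant used by the dr_grpo normalization
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```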

trl/trainer/grpo_trainer.py

Lines changed: 22 additions & 6 deletions
@@ -416,6 +416,8 @@ def data_collator(features): # No data collation is needed in GRPO
         self.repetition_penalty = args.repetition_penalty
         self.use_vllm = args.use_vllm
         self.use_liger_loss = args.use_liger_loss
+        self.loss_type = args.loss_type
+        self.scale_rewards = args.scale_rewards
         self.mask_truncated_completions = args.mask_truncated_completions
 
         # Datasets
@@ -455,7 +457,13 @@ def data_collator(features): # No data collation is needed in GRPO
                     "Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`."
                 )
             if is_peft_model(model):
-                raise ValueError("Liger loss is not supported with a PEFT model.")
+                raise TypeError("Liger loss is not supported with a PEFT model.")
+
+            if self.loss_type != "bnpo":
+                raise ValueError(
+                    f"The provided loss type (`{self.loss_type}`) is not supported with `use_liger_loss`. Liger loss "
+                    "only supports `bnpo` for now."
+                )
 
             self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
                 beta=self.beta,
@@ -480,6 +488,7 @@ def data_collator(features): # No data collation is needed in GRPO
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
         self._total_train_tokens = 0
         self.log_completions = args.log_completions
+        self.wandb_log_unique_prompts = args.wandb_log_unique_prompts
         self.num_completions_to_print = args.num_completions_to_print
         # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
         # final optimization step.
@@ -757,7 +766,7 @@ def _generate_and_score_completions(
             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
 
         # Generate completions using either vLLM or regular generation
-        if self.args.use_vllm:
+        if self.use_vllm:
             # First, have main process load weights if needed
             if self.state.global_step != self._last_loaded_step:
                 self._move_model_to_vllm()
@@ -919,7 +928,7 @@ def _generate_and_score_completions(
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = rewards - mean_grouped_rewards
-        if self.args.scale_rewards:
+        if self.scale_rewards:
             advantages = advantages / (std_grouped_rewards + 1e-4)
 
         # Slice to keep only the local part of the data
@@ -1061,7 +1070,14 @@ def _compute_loss(self, model, inputs):
         if self.beta != 0.0:
             per_token_loss = per_token_loss + self.beta * per_token_kl
 
-        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+        if self.loss_type == "grpo":
+            loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+        elif self.loss_type == "bnpo":
+            loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
+        elif self.loss_type == "dr_grpo":
+            loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length)
+        else:
+            raise ValueError(f"Unknown loss type: {self.loss_type}")
 
         # Log the metrics
         mode = "eval" if self.control.should_evaluate else "train"
@@ -1102,7 +1118,7 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
         super().log(logs)
         self._metrics[mode].clear()
 
-        if self.accelerator.is_main_process:
+        if self.accelerator.is_main_process and self.log_completions:
             if is_rich_available():
                 print_prompt_completions_sample(
                     self._textual_logs["prompt"],
@@ -1122,7 +1138,7 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non
                     **self._textual_logs["rewards"],
                 }
                 df = pd.DataFrame(table)
-                if self.args.wandb_log_unique_prompts:
+                if self.wandb_log_unique_prompts:
                     df = df.drop_duplicates(subset=["prompt"])
                 wandb.log({"completions": wandb.Table(dataframe=df)})
11281144
