Commit dee3734

📊 Fix clip_ratio logging and better document logged values (#3145)
1 parent: 8037f18

3 files changed (+40 −12 lines)


docs/source/grpo_trainer.md

Lines changed: 12 additions & 7 deletions
@@ -102,13 +102,18 @@ When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies
 
 ## Logged metrics
 
-The GRPO Trainer logs the following metrics:
-
-- `completion_length`: The average completion length.
-- `reward/{reward_func_name}`: The reward computed by each reward function.
-- `reward`: The average reward.
-- `reward_std` : The average standard deviation within reward groups.
-- `kl` : The average KL divergence between the model and the reference model calculated on completions.
+- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
+- `completion_length`: The average length of generated completions.
+- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
+- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
+- `reward`: The overall average reward after applying reward weights.
+- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
+- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
+- `clip_ratio`: The fraction of tokens where the PPO objective is clipped to stay within the trust region:
+  $$
+  \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right)
+  $$
+  A higher value means more tokens were affected by clipping, limiting how much the policy can change.
 
 ## Customization
 
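
A minimal sketch of what the newly documented `clip_ratio` metric measures, using plain PyTorch; the `ratio`, `epsilon`, and `mask` tensors below are made-up illustrations, not values taken from the trainer:

```python
import torch

# Hypothetical per-token probability ratios pi_theta / pi_theta_old and a completion mask.
ratio = torch.tensor([0.70, 0.95, 1.00, 1.10, 1.35])
mask = torch.ones_like(ratio)
epsilon = 0.2

# A token counts as clipped when its ratio falls outside the trust region [1 - eps, 1 + eps].
is_clipped = (ratio < 1 - epsilon) | (ratio > 1 + epsilon)
clip_ratio = (is_clipped.float() * mask).sum() / mask.sum()
print(clip_ratio.item())  # ~0.4: 2 of the 5 tokens lie outside [0.8, 1.2]
```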

tests/test_grpo_trainer.py

Lines changed: 4 additions & 2 deletions
@@ -548,8 +548,10 @@ def reward_func2(completions, **kwargs):
 
             # Check that training logs contain both reward metrics
             self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
-            self.assertIn("rewards/reward_func1", trainer.state.log_history[-1])
-            self.assertIn("rewards/reward_func2", trainer.state.log_history[-1])
+            self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1])
+            self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1])
+            self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1])
+            self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1])
 
             # Check that the params have changed
             for n, param in previous_trainable_params.items():
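
For context on the assertions above, a sketch of the kind of reward-function pair the test exercises; the bodies are hypothetical and completions are assumed to be plain strings, only the `(completions, **kwargs)` signature comes from the hunk header:

```python
def reward_func1(completions, **kwargs):
    # Toy rule: reward longer completions.
    return [float(len(completion)) for completion in completions]


def reward_func2(completions, **kwargs):
    # Toy rule: reward completions that mention "answer".
    return [1.0 if "answer" in completion else 0.0 for completion in completions]


# With two reward functions, the trainer now logs four per-function keys:
# rewards/reward_func1/mean, rewards/reward_func1/std,
# rewards/reward_func2/mean, rewards/reward_func2/std
print(reward_func1(["short", "a longer completion"]))  # [5.0, 19.0]
print(reward_func2(["the answer is 42", "no idea"]))   # [1.0, 0.0]
```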

trl/trainer/grpo_trainer.py

Lines changed: 24 additions & 3 deletions
@@ -160,6 +160,25 @@ def __len__(self) -> int:
         return self.num_samples * self.mini_repeat_count * self.repeat_count
 
 
+# torch.nanstd doesn't exist, so we define it here
+def nanstd(tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.
+
+    Args:
+        tensor (`torch.Tensor`):
+            Input tensor of shape `(N,)`.
+
+    Returns:
+        `torch.Tensor`:
+            Standard deviation of the tensor, ignoring NaNs.
+    """
+    variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2)  # Compute variance ignoring NaNs
+    count = torch.sum(~torch.isnan(tensor))  # Count of non-NaN values
+    variance *= count / (count - 1)  # Bessel's correction
+    return torch.sqrt(variance)
+
+
 class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
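
A quick sanity check of the new helper, assuming it is importable from the module as defined above: on a 1D tensor with NaNs it should agree with `torch.std` over the non-NaN entries, while plain `torch.std` returns NaN.

```python
import torch

from trl.trainer.grpo_trainer import nanstd  # module-level helper defined above

rewards = torch.tensor([1.0, 2.0, float("nan"), 4.0])

print(nanstd(rewards))                            # tensor(1.5275)
print(torch.std(rewards[~torch.isnan(rewards)]))  # tensor(1.5275), Bessel-corrected by default
print(torch.std(rewards))                         # tensor(nan), which is why a NaN-aware helper is needed
```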
@@ -856,8 +875,10 @@ def _generate_and_score_completions(
                 reward_func_name = reward_func.__name__
             # Only calculate mean for samples where this reward function was applied (non-NaN values)
             mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
-            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
-        self._metrics[mode]["reward"].append(rewards.mean().item())
+            self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
+            std_rewards = nanstd(rewards_per_func[:, i]).item()
+            self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
+        self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
         self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
 
         if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
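
To see why NaN-aware statistics matter here, a toy `rewards_per_func` matrix (rows are samples, columns are reward functions; the names and values are made up), processed the same way as the columns in the loop above:

```python
import torch

from trl.trainer.grpo_trainer import nanstd  # NaN-aware std helper defined earlier in this file

# NaN marks samples that a given reward function was not applied to.
rewards_per_func = torch.tensor(
    [
        [1.0, float("nan")],
        [0.0, 2.0],
        [1.0, 4.0],
    ]
)

for i, name in enumerate(["reward_func1", "reward_func2"]):
    mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
    std_rewards = nanstd(rewards_per_func[:, i]).item()
    print(f"rewards/{name}/mean = {mean_rewards:.4f}  rewards/{name}/std = {std_rewards:.4f}")
# rewards/reward_func1/mean = 0.6667  rewards/reward_func1/std = 0.5774
# rewards/reward_func2/mean = 3.0000  rewards/reward_func2/std = 1.4142
```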
@@ -938,7 +959,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
             mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
             self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
 
-        is_clipped = (per_token_loss1 < per_token_loss2).float()
+        is_clipped = (coef_1 < (1 - self.epsilon_low)) | (coef_1 > (1 + self.epsilon_high))
         clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
         self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
         return loss
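
A small illustration of how the fixed check can differ from the old one. It assumes (not shown in this hunk) that `coef_1` is the raw probability ratio and that `per_token_loss1`/`per_token_loss2` are the unclipped and clipped PPO terms `coef * advantage`; the numbers are made up:

```python
import torch

epsilon_low, epsilon_high = 0.2, 0.2

coef_1 = torch.tensor([0.7, 1.0, 1.5])      # per-token probability ratios
advantages = torch.tensor([1.0, 1.0, 1.0])  # positive advantages for illustration

coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
per_token_loss1 = coef_1 * advantages
per_token_loss2 = coef_2 * advantages

old_is_clipped = (per_token_loss1 < per_token_loss2).float()
new_is_clipped = ((coef_1 < 1 - epsilon_low) | (coef_1 > 1 + epsilon_high)).float()

print(old_is_clipped)  # tensor([1., 0., 0.]) -- misses the ratio 1.5, which the clamp does limit here
print(new_is_clipped)  # tensor([1., 0., 1.]) -- flags every ratio outside [0.8, 1.2], matching the docs
```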
