19 changes: 12 additions & 7 deletions docs/source/grpo_trainer.md
@@ -102,13 +102,18 @@ When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies

## Logged metrics

The GRPO Trainer logs the following metrics:

- `completion_length`: The average completion length.
- `rewards/{reward_func_name}`: The reward computed by each reward function.
- `reward`: The average reward.
- `reward_std`: The average standard deviation within reward groups.
- `kl`: The average KL divergence between the model and the reference model calculated on completions.
- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `completion_length`: The average length of generated completions.
- `rewards/{reward_func_name}/mean`: The average reward from a specific reward function.
- `rewards/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio`: The fraction of tokens where the PPO objective is clipped to stay within the trust region:
$$
\text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right)
$$
A higher value means more tokens were affected by clipping, limiting how much the policy can change.
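
For intuition, here is a minimal sketch of how such a fraction can be measured from per-token probability ratios. It is illustrative only, not the trainer's internal code; the tensor names `ratio`, `completion_mask`, `eps_low`, and `eps_high` are assumptions for the example:

```python
import torch

# Per-token probability ratios pi_theta / pi_theta_old, and a mask that is 1
# for real completion tokens and 0 for padding (values chosen for illustration).
ratio = torch.tensor([0.75, 0.95, 1.10, 1.40])
completion_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
eps_low, eps_high = 0.2, 0.2

# A token counts as clipped when its ratio falls outside [1 - eps_low, 1 + eps_high].
is_clipped = (ratio < 1 - eps_low) | (ratio > 1 + eps_high)
clip_ratio = (is_clipped.float() * completion_mask).sum() / completion_mask.sum()
print(clip_ratio)  # tensor(0.3333): 1 of the 3 unmasked tokens falls outside [0.8, 1.2]
```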

## Customization

6 changes: 4 additions & 2 deletions tests/test_grpo_trainer.py
@@ -548,8 +548,10 @@ def reward_func2(completions, **kwargs):

# Check that training logs contain both reward metrics
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
self.assertIn("rewards/reward_func1", trainer.state.log_history[-1])
self.assertIn("rewards/reward_func2", trainer.state.log_history[-1])
self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1])
self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1])
self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1])
self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1])

# Check that the params have changed
for n, param in previous_trainable_params.items():
27 changes: 24 additions & 3 deletions trl/trainer/grpo_trainer.py
@@ -160,6 +160,25 @@ def __len__(self) -> int:
return self.num_samples * self.mini_repeat_count * self.repeat_count


# torch.nanstd doesn't exist, so we define it here
def nanstd(tensor: torch.Tensor) -> torch.Tensor:
"""
Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors.

Args:
tensor (`torch.Tensor`):
Input tensor of shape `(N,)`.

Returns:
`torch.Tensor`:
Standard deviation of the tensor, ignoring NaNs.
"""
variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) # Compute variance ignoring NaNs
count = torch.sum(~torch.isnan(tensor)) # Count of non-NaN values
variance *= count / (count - 1) # Bessel's correction
return torch.sqrt(variance)
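
# Illustrative usage (example values assumed here): for a 1D tensor with NaN
# entries, nanstd matches torch.std computed over the non-NaN values:
#   t = torch.tensor([1.0, 2.0, float("nan"), 4.0])
#   nanstd(t)                      # ~1.5275
#   torch.std(t[~torch.isnan(t)])  # ~1.5275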


class GRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
@@ -856,8 +875,10 @@ def _generate_and_score_completions(
reward_func_name = reward_func.__name__
# Only calculate mean for samples where this reward function was applied (non-NaN values)
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
self._metrics[mode]["reward"].append(rewards.mean().item())
self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
std_rewards = nanstd(rewards_per_func[:, i]).item()
self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards)
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
@@ -938,7 +959,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

is_clipped = (per_token_loss1 < per_token_loss2).float()
is_clipped = (coef_1 < (1 - self.epsilon_low)) | (coef_1 > (1 + self.epsilon_high))
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
return loss