📏 Completion length logging fix + remainder logging fix #3482

Merged: 11 commits, May 27, 2025
trl/trainer/grpo_trainer.py: 47 changes (24 additions, 23 deletions)
@@ -765,7 +765,6 @@ def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
# In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
# second row shows the second sampled batch, and so on.
#
- # | Accum step 0 |
# | GPU 0 | GPU 1 |
#
# global_step step <-───> num_generations=2
@@ -1110,6 +1109,9 @@ def _generate_and_score_completions(
[id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
]

+ # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
+ completion_lengths = completion_mask.sum(1)
+
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
truncated_completions = ~is_eos.any(dim=1)
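Note on ordering: `completion_lengths` is captured before the `mask_truncated_completions` branch zeroes out truncated rows, so the logged lengths reflect what was actually generated. A minimal standalone sketch with made-up tensors (not trainer code) showing what summing after masking would report instead:

import torch

# two completions with true lengths 4 and 2; only the second emits EOS
completion_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
is_eos = torch.tensor([[False, False, False, False], [False, True, False, False]])

completion_lengths = completion_mask.sum(1)    # tensor([4, 2]), the true lengths
truncated_completions = ~is_eos.any(dim=1)     # tensor([True, False])
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()

print(completion_lengths)      # tensor([4, 2]), what is now logged
print(completion_mask.sum(1))  # tensor([0, 2]), what summing after masking would log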
@@ -1213,26 +1215,25 @@ def _generate_and_score_completions(

# Log the metrics
if mode == "train":
- self.state.num_input_tokens_seen += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+ self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

- # log completion lengths, mean, min, max
- agg_completion_mask = self.accelerator.gather_for_metrics(completion_mask.sum(1))
- self._metrics[mode]["completions/mean_length"].append(agg_completion_mask.float().mean().item())
- self._metrics[mode]["completions/min_length"].append(agg_completion_mask.float().min().item())
- self._metrics[mode]["completions/max_length"].append(agg_completion_mask.float().max().item())
+ # Log completion lengths, mean, min, max
+ agg_completion_lengths = self.accelerator.gather(completion_lengths)
+ self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
+ self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
+ self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

- # identify sequences that terminated with EOS and log their lengths
- agg_terminated_with_eos = self.accelerator.gather_for_metrics(is_eos.any(dim=1))
- term_completion_mask = agg_completion_mask[agg_terminated_with_eos]
- clipped_completions_ratio = 1 - len(term_completion_mask) / len(agg_completion_mask)
+ # Identify sequences that terminated with EOS and log their lengths
+ agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
+ term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
+ clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
- if len(term_completion_mask) == 0:
-     # edge case where no completed sequences are found
-     term_completion_mask = torch.zeros(1, device=device)
- self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_mask.float().mean().item())
- self._metrics[mode]["completions/min_terminated_length"].append(term_completion_mask.float().min().item())
- self._metrics[mode]["completions/max_terminated_length"].append(term_completion_mask.float().max().item())
+ if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
+     term_completion_lengths = torch.zeros(1, device=device)
+ self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
+ self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
+ self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())

# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
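The gather_for_metrics to gather switch is the "remainder" part of this PR: `gather_for_metrics` applies end-of-dataloader logic that drops samples it considers duplicates added to pad the final batch, which can silently truncate these generation-time tensors, while plain `gather` concatenates every process's tensor unchanged. A toy, single-process walk-through of the clipped-ratio bookkeeping above, with hypothetical values (in training, `gather` first concatenates the per-GPU tensors):

import torch

agg_completion_lengths = torch.tensor([5, 9, 9, 3])                 # 9 = hit the generation cap
agg_terminated_with_eos = torch.tensor([True, False, False, True])  # which rows saw EOS

term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]  # tensor([5, 3])
clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)

print(clipped_completions_ratio)                      # 0.5, half the completions were cut off
print(term_completion_lengths.float().mean().item())  # 4.0, mean length of terminated ones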
@@ -1303,8 +1304,8 @@ def compute_liger_loss(self, unwrapped_model, inputs):

mode = "train" if self.model.training else "eval"
if self.beta != 0.0:
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item())
return loss

@profiling_decorator
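Same substitution in the Liger loss path. `mean_kl` and `clip_ratio` are 0-dim tensors here; `accelerator.gather` stacks 0-dim inputs into a `(num_processes,)` tensor, so `.mean()` averages one scalar per rank. A runnable sketch, assuming `accelerate` is installed (file name and values are hypothetical; run with `accelerate launch --num_processes 2 demo.py` to see two entries):

import torch
from accelerate import Accelerator

accelerator = Accelerator()
# a per-process scalar metric, different on each rank
mean_kl = torch.tensor(0.1 * (accelerator.process_index + 1), device=accelerator.device)

gathered = accelerator.gather(mean_kl)  # shape: (num_processes,)
accelerator.print(gathered, gathered.mean().item())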
@@ -1379,7 +1380,7 @@ def _compute_loss(self, model, inputs):

if self.beta != 0.0:
mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())
self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item())

# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
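The `nanmean` (rather than `mean`) here matters because a rank whose batch ends up with an all-zero `completion_mask` computes 0 / 0 = nan for `mean_kl`; after gathering, `nanmean` averages only the valid entries instead of letting one degenerate rank poison the metric. Toy values:

import torch

gathered_kl = torch.tensor([0.02, float("nan"), 0.04])  # middle rank had no completion tokens
print(gathered_kl.mean().item())     # nan
print(gathered_kl.nanmean().item())  # ~0.03, the NaN entry is skipped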
@@ -1390,13 +1391,13 @@ def _compute_loss(self, model, inputs):
high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum()
clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum()

- gathered_low_clip = self.accelerator.gather_for_metrics(low_clip)
+ gathered_low_clip = self.accelerator.gather(low_clip)
self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item())
self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item())
- gathered_high_clip = self.accelerator.gather_for_metrics(high_clip)
+ gathered_high_clip = self.accelerator.gather(high_clip)
self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item())
self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
- gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
+ gathered_clip_ratio = self.accelerator.gather(clip_ratio)
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
return loss
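`nanmin` and `nanmax` are TRL helpers (defined in trl.trainer.utils); PyTorch currently ships no built-in NaN-ignoring min/max reduction of this form. A sketch of how such helpers can be written, as an illustration rather than TRL's exact source:

import torch

def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    # Replace NaNs with +inf so they can never win the min; all-NaN stays NaN.
    if torch.isnan(tensor).all():
        return torch.tensor(float("nan"), device=tensor.device)
    return torch.where(torch.isnan(tensor), torch.full_like(tensor, float("inf")), tensor).min()

def nanmax(tensor: torch.Tensor) -> torch.Tensor:
    # Replace NaNs with -inf so they can never win the max; all-NaN stays NaN.
    if torch.isnan(tensor).all():
        return torch.tensor(float("nan"), device=tensor.device)
    return torch.where(torch.isnan(tensor), torch.full_like(tensor, float("-inf")), tensor).max()

print(nanmin(torch.tensor([0.2, float("nan"), 0.1])))  # tensor(0.1000)
print(nanmax(torch.tensor([0.2, float("nan"), 0.1])))  # tensor(0.2000)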
