📚 Accumulate completions for logging #3217

Merged · 1 commit · Apr 3, 2025

3 changes: 3 additions & 0 deletions trl/trainer/grpo_config.py
@@ -136,6 +136,9 @@ class GRPOConfig(TrainingArguments):
installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
num_completions_to_print (`int` or `None`, *optional*, defaults to `None`):
Number of completions to print with `rich`. If `None`, all completions are logged.
Review comment on `wandb_log_unique_prompts` (Member, Author): Missed from #3191

wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):
Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all
prompts are logged.
"""

# Parameters that control the model and reference model
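For reference, a minimal sketch of how the new option is used (the values below are illustrative, not taken from this PR): `wandb_log_unique_prompts=True` drops duplicate prompts from the logged `wandb` table, while `log_completions` and `num_completions_to_print` control what is printed with `rich`.

```python
from trl import GRPOConfig

# Illustrative configuration; output_dir and the numeric values are placeholders.
training_args = GRPOConfig(
    output_dir="grpo-output",
    logging_steps=10,                # completions are logged every 10 optimizer steps
    report_to="wandb",
    log_completions=True,            # enable completion logging
    num_completions_to_print=4,      # limit the rich console sample
    wandb_log_unique_prompts=True,   # keep only unique prompts in the wandb table
)
```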
81 changes: 54 additions & 27 deletions trl/trainer/grpo_trainer.py
@@ -904,36 +904,63 @@ def _generate_and_score_completions(
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

# Log completions when we complete a full gradient accumulation cycle
# For logging across the gradient accumulation steps, we need to accumulate the data
is_last_step_in_grad_accum = (
self._step % self.args.gradient_accumulation_steps == self.args.gradient_accumulation_steps - 1
)
should_log_completions = (
self.log_completions
and self.state.global_step % self.args.logging_steps == 0
and is_last_step_in_grad_accum
)

# Collect data for logging throughout the accumulation steps
Comment on lines +907 to +918

Member: I think `self.control.should_log` achieves the same goal. Let me check.

@qgallouedec (Member), Apr 2, 2025: So yes, it achieves the same goal, but this flag is only updated after `loss.backward`, so we can't use it here. I'm starting to think that it would make more sense to have all this code in `self.log`. I'm checking if it's possible.

Member: So yes, it's possible. I'll do it in a follow-up PR.

if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
prompts_to_log = gather_object(prompts_text)
completions_to_log = gather_object(completions_text)
rewards_to_log = {
reward_func_name: rewards_per_func[:, i] for i, reward_func_name in enumerate(reward_func_names)
}
if not hasattr(self, "_accumulated_prompts"):
self._accumulated_prompts = []
self._accumulated_completions = []
self._accumulated_rewards = {name: [] for name in reward_func_names}

# Gather and accumulate the data from this step
prompts_this_step = gather_object(prompts_text)
completions_this_step = gather_object(completions_text)

if self.accelerator.is_main_process:
if is_rich_available():
print_prompt_completions_sample(
prompts_to_log,
completions_to_log,
rewards_to_log,
self.state.global_step,
self.num_completions_to_print,
)
if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
import pandas as pd

# For logging
table = {
"step": [str(self.state.global_step)] * len(rewards),
"prompt": prompts_to_log,
"completion": completions_to_log,
"reward": rewards.tolist(),
}
df = pd.DataFrame(table)
if self.args.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])
wandb.log({"completions": wandb.Table(dataframe=df)})
self._accumulated_prompts.extend(prompts_this_step)
self._accumulated_completions.extend(completions_this_step)
for i, name in enumerate(reward_func_names):
self._accumulated_rewards[name].extend(rewards_per_func[:, i].tolist())

# Log the accumulated data when we finish a grad accumulation cycle
if should_log_completions and self.accelerator.is_main_process:
if is_rich_available():
print_prompt_completions_sample(
self._accumulated_prompts,
self._accumulated_completions,
self._accumulated_rewards,
self.state.global_step,
self.num_completions_to_print,
)

if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
import pandas as pd

table = {
"step": [str(self.state.global_step)] * len(self._accumulated_prompts),
"prompt": self._accumulated_prompts,
"completion": self._accumulated_completions,
"reward": [sum(rewards) for rewards in zip(*self._accumulated_rewards.values())],
}
df = pd.DataFrame(table)
if self.args.wandb_log_unique_prompts:
df = df.drop_duplicates(subset=["prompt"])
wandb.log({"completions": wandb.Table(dataframe=df)})

# Reset the accumulated data after logging
self._accumulated_prompts = []
self._accumulated_completions = []
self._accumulated_rewards = {name: [] for name in reward_func_names}

return {
"prompt_ids": prompt_ids,
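For context, an end-to-end usage sketch where the accumulation matters (the model, dataset, and reward function below are illustrative placeholders, not part of this PR): with `gradient_accumulation_steps > 1`, the completions produced at each accumulation step are now buffered and logged as a single table once the cycle finishes, rather than only those from the current batch.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Toy reward: favor shorter completions (illustrative only).
def reward_len(completions, **kwargs):
    return [-float(len(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")  # any prompt dataset works

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",        # placeholder
    gradient_accumulation_steps=4,       # completions from all 4 micro-steps end up in one table
    logging_steps=10,
    log_completions=True,
    wandb_log_unique_prompts=True,
    report_to="wandb",
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",    # placeholder model id
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

The `_accumulated_*` buffers introduced in this PR are reset after each log, so each table only contains the completions from the most recent accumulation cycle.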