📚 Accumulate completions for logging #3217


Merged · 1 commit · Apr 3, 2025

Conversation

lewtun
Member

@lewtun lewtun commented Apr 2, 2025

What does this PR do?

This PR fixes a bug where the rich and wandb tables did not log the complete set of unique prompts when gradient_accumulation_steps > 1. Previously, the tables were logged once per gradient accumulation step, which is confusing when inspecting the logs for unique prompts.

Here are two runs, with and without the fix:

These runs have:

  • gradient_accumulation_steps=4
  • num_generations=16
  • per_device_train_batch_size=4
  • num_gpus = 8

so we expect num_unique_prompts=8:

num_unique_prompts = num_gpus * per_device_train_batch_size * gradient_accumulation_steps / num_generations = 8
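As a quick sanity check, the arithmetic above can be verified directly (a standalone sketch, not TRL code):

```python
# Illustrative sketch of the unique-prompt arithmetic; not TRL code.
num_gpus = 8
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_generations = 16  # completions sampled per unique prompt

# Each optimizer step consumes num_gpus * batch_size * accum_steps samples;
# dividing by the generations per prompt gives the unique prompts per step.
num_unique_prompts = (
    num_gpus * per_device_train_batch_size * gradient_accumulation_steps
) // num_generations
print(num_unique_prompts)  # 8
```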

We see that without the fix, wandb displays only 2 unique prompts in the table (i.e. only the prompts from a single gradient accumulation step). With the fix, the prompts are grouped together per optimizer step, which incidentally makes scrolling the table faster in the UI, since there are fewer steps to query.
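The grouping can be pictured with a small hypothetical buffer (attribute and method names here are made up for illustration and do not match TRL's actual implementation): rows from each micro-batch are collected and flushed as one combined table per optimizer step.

```python
# Hypothetical sketch of per-step accumulation; names are illustrative only.
class CompletionBuffer:
    def __init__(self, accum_steps):
        self.accum_steps = accum_steps
        self._rows = []       # (prompt, completion) pairs for the current cycle
        self._micro_step = 0  # micro-batches seen in the current cycle

    def add(self, prompts, completions):
        """Collect rows from one micro-batch; return a full table once per cycle."""
        self._rows.extend(zip(prompts, completions))
        self._micro_step += 1
        if self._micro_step % self.accum_steps == 0:
            table, self._rows = self._rows, []
            return table  # one combined table per optimizer step
        return None  # still accumulating

buf = CompletionBuffer(accum_steps=2)
assert buf.add(["p1"], ["c1"]) is None  # first micro-batch: buffered
assert buf.add(["p2"], ["c2"]) == [("p1", "c1"), ("p2", "c2")]
```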

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -136,6 +136,9 @@ class GRPOConfig(TrainingArguments):
installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
num_completions_to_print (`int` or `None`, *optional*, defaults to `None`):
Number of completions to print with `rich`. If `None`, all completions are logged.
wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`):
lewtun (Member, author):
Missed from #3191

@lewtun lewtun requested a review from qgallouedec April 2, 2025 21:23
Comment on lines +907 to +918
# Log completions when we complete a full gradient accumulation cycle
# For logging across the gradient accumulation steps, we need to accumulate the data
is_last_step_in_grad_accum = (
self._step % self.args.gradient_accumulation_steps == self.args.gradient_accumulation_steps - 1
)
should_log_completions = (
self.log_completions
and self.state.global_step % self.args.logging_steps == 0
and is_last_step_in_grad_accum
)

# Collect data for logging throughout the accumulation steps
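The modulo check in the diff fires only on the final micro-batch of each accumulation cycle. A standalone sketch of that condition (illustrative, not the TRL code itself):

```python
# Sketch of the last-micro-step check; `step` counts micro-batches (0-based)
# and `accum` stands in for gradient_accumulation_steps.
def is_last_step_in_grad_accum(step, accum):
    # The final micro-batch of each cycle satisfies step % accum == accum - 1,
    # e.g. steps 3, 7, 11, ... when accum = 4.
    return step % accum == accum - 1

# With accum=4, only every 4th micro-step triggers logging:
flags = [is_last_step_in_grad_accum(s, 4) for s in range(8)]
# flags == [False, False, False, True, False, False, False, True]
```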
qgallouedec (Member):

I think self.control.should_log achieves the same goal. Let me check

@qgallouedec commented Apr 2, 2025:

So yes, it achieves the same goal, but this flag is only updated after loss.backward, so we can't use it here. I'm starting to think it would make more sense to have all this code in self.log. I'm checking if it's possible.

@qgallouedec:

So yes, it's possible. I'll do it in a follow-up PR.

@qgallouedec qgallouedec changed the title Accumulate completions for logging 📚 Accumulate completions for logging Apr 3, 2025
@qgallouedec qgallouedec merged commit 7eaca76 into main Apr 3, 2025
7 of 10 checks passed
@qgallouedec qgallouedec deleted the fix-completions-logging branch April 3, 2025 00:00
yxliu-TAMU pushed a commit to mincheolseong/ECEN743-GRPO-Project-Proposal that referenced this pull request Apr 20, 2025