📚 Accumulate completions for logging #3217
Changes from all commits
```diff
@@ -904,36 +904,63 @@ def _generate_and_score_completions(
         self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
         self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
 
+        # Log completions when we complete a full gradient accumulation cycle
+        # For logging across the gradient accumulation steps, we need to accumulate the data
+        is_last_step_in_grad_accum = (
+            self._step % self.args.gradient_accumulation_steps == self.args.gradient_accumulation_steps - 1
+        )
+        should_log_completions = (
+            self.log_completions
+            and self.state.global_step % self.args.logging_steps == 0
+            and is_last_step_in_grad_accum
+        )
+
+        # Collect data for logging throughout the accumulation steps
```
Comment on lines +907 to +918

> I think […]

> So yes, it achieves the same goal, but this flag is only updated after the […]

> So yes. It's possible, I'll do it in a follow-up PR.
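For reference, here is a minimal runnable sketch (not TRL's code) of the boundary check on lines +907 to +918 discussed above, assuming `_step` is the zero-based micro-batch counter used in the diff:

```python
# A minimal sketch (not TRL's code) of the boundary test introduced above,
# assuming `_step` is a zero-based micro-batch counter as in the diff.
gradient_accumulation_steps = 4  # hypothetical config value

for _step in range(8):
    is_last_step_in_grad_accum = (
        _step % gradient_accumulation_steps == gradient_accumulation_steps - 1
    )
    print(_step, is_last_step_in_grad_accum)
# Prints True only for _step 3 and 7: the final micro-batch of each
# accumulation cycle, which is when the accumulated completions are flushed.
```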
```diff
-        if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
-            prompts_to_log = gather_object(prompts_text)
-            completions_to_log = gather_object(completions_text)
-            rewards_to_log = {
-                reward_func_name: rewards_per_func[:, i] for i, reward_func_name in enumerate(reward_func_names)
-            }
+        if not hasattr(self, "_accumulated_prompts"):
+            self._accumulated_prompts = []
+            self._accumulated_completions = []
+            self._accumulated_rewards = {name: [] for name in reward_func_names}
+
+        # Gather and accumulate the data from this step
+        prompts_this_step = gather_object(prompts_text)
+        completions_this_step = gather_object(completions_text)
+
-            if self.accelerator.is_main_process:
-                if is_rich_available():
-                    print_prompt_completions_sample(
-                        prompts_to_log,
-                        completions_to_log,
-                        rewards_to_log,
-                        self.state.global_step,
-                        self.num_completions_to_print,
-                    )
-                if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
-                    import pandas as pd
-
-                    # For logging
-                    table = {
-                        "step": [str(self.state.global_step)] * len(rewards),
-                        "prompt": prompts_to_log,
-                        "completion": completions_to_log,
-                        "reward": rewards.tolist(),
-                    }
-                    df = pd.DataFrame(table)
-                    if self.args.wandb_log_unique_prompts:
-                        df = df.drop_duplicates(subset=["prompt"])
-                    wandb.log({"completions": wandb.Table(dataframe=df)})
+        self._accumulated_prompts.extend(prompts_this_step)
+        self._accumulated_completions.extend(completions_this_step)
+        for i, name in enumerate(reward_func_names):
+            self._accumulated_rewards[name].extend(rewards_per_func[:, i].tolist())
+
+        # Log the accumulated data when we finish a grad accumulation cycle
+        if should_log_completions and self.accelerator.is_main_process:
+            if is_rich_available():
+                print_prompt_completions_sample(
+                    self._accumulated_prompts,
+                    self._accumulated_completions,
+                    self._accumulated_rewards,
+                    self.state.global_step,
+                    self.num_completions_to_print,
+                )
+
+            if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
+                import pandas as pd
+
+                table = {
+                    "step": [str(self.state.global_step)] * len(self._accumulated_prompts),
+                    "prompt": self._accumulated_prompts,
+                    "completion": self._accumulated_completions,
+                    "reward": [sum(rewards) for rewards in zip(*self._accumulated_rewards.values())],
+                }
+                df = pd.DataFrame(table)
+                if self.args.wandb_log_unique_prompts:
+                    df = df.drop_duplicates(subset=["prompt"])
+                wandb.log({"completions": wandb.Table(dataframe=df)})
+
+            # Reset the accumulated data after logging
+            self._accumulated_prompts = []
+            self._accumulated_completions = []
+            self._accumulated_rewards = {name: [] for name in reward_func_names}
 
         return {
             "prompt_ids": prompt_ids,
```
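In the new wandb table, the `reward` column sums each completion's per-function rewards. A self-contained sketch, with hypothetical reward-function names and values, of how `zip(*self._accumulated_rewards.values())` regroups the per-function lists into per-completion tuples:

```python
# Self-contained sketch of the table's "reward" column (reward-function names
# and values here are hypothetical): zip(*dict.values()) regroups the
# per-function reward lists into per-completion tuples, and sum() totals them.
_accumulated_rewards = {
    "format_reward": [0.5, 1.0, 0.0],
    "accuracy_reward": [1.0, 0.0, 1.0],
}

reward = [sum(rewards) for rewards in zip(*_accumulated_rewards.values())]
print(reward)  # [1.5, 1.0, 1.0] -- one summed reward per logged completion
```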
> Missed from #3191
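Stepping back, the change amounts to an accumulate-then-flush logging pattern. A minimal, framework-free sketch under that reading (class and method names are illustrative, not TRL's API):

```python
# A minimal, framework-free sketch of the accumulate-then-flush pattern this PR
# introduces; class and method names are illustrative, not TRL's API.
class CompletionLogBuffer:
    def __init__(self, grad_accum_steps: int):
        self.grad_accum_steps = grad_accum_steps
        self._prompts: list[str] = []
        self._completions: list[str] = []

    def step(self, step_idx: int, prompts: list[str], completions: list[str]) -> None:
        # Buffers grow on every micro-step...
        self._prompts.extend(prompts)
        self._completions.extend(completions)
        # ...and are flushed only on the last micro-step of the cycle.
        if step_idx % self.grad_accum_steps == self.grad_accum_steps - 1:
            self._flush()

    def _flush(self) -> None:
        print(f"logging {len(self._completions)} completions")  # stand-in for rich/wandb logging
        self._prompts.clear()
        self._completions.clear()
```

This mirrors the diff's behavior: every micro-batch contributes to the buffers, so the logged table covers the whole optimizer step rather than only the last micro-batch.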