
Commit 65308cf

⏯️ Fix logging when resuming from checkpoint GRPO (#3185)
1 parent 1755e03 commit 65308cf

File tree

1 file changed: +3 −4 lines

trl/trainer/grpo_trainer.py

Lines changed: 3 additions & 4 deletions
@@ -446,7 +446,6 @@ def data_collator(features):  # No data collation is needed in GRPO
 
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
-        self._total_train_tokens = 0
         self.log_completions = args.log_completions
         self.num_completions_to_print = args.num_completions_to_print
 
@@ -501,7 +500,7 @@ def data_collator(features):  # No data collation is needed in GRPO
         # vLLM specific sampling arguments
         self.guided_decoding_regex = args.vllm_guided_decoding_regex
 
-        self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
+        self._last_loaded_step = -1  # tag to avoid useless loading during grad accumulation
 
         # When using vLLM, the main process is responsible for loading the model weights. This can cause process
         # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
@@ -894,8 +893,8 @@ def _generate_and_score_completions(
         mode = "eval" if self.control.should_evaluate else "train"
 
         if mode == "train":
-            self._total_train_tokens += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
-            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+            self.state.num_input_tokens_seen += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+            self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
 
         # log completion lengths, mean, min, max
         agg_completion_mask = self.accelerator.gather_for_metrics(completion_mask.sum(1))
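In short, the running token count moves off a trainer-local attribute and onto `self.state` (the `transformers.TrainerState`), which the `Trainer` writes to `trainer_state.json` inside every checkpoint and restores on `resume_from_checkpoint`; the logged `train/num_tokens` metric therefore keeps climbing across a resume instead of restarting at zero. Below is a minimal sketch of that round trip, not the trainer itself, assuming a `transformers` version whose `TrainerState` exposes the `num_input_tokens_seen` field; the paths and values are placeholders.

```python
# Sketch only: shows why a counter kept on TrainerState survives a resume,
# while the old trainer-local `_total_train_tokens` attribute would not.
import os
import tempfile

from transformers import TrainerState

checkpoint_dir = tempfile.mkdtemp()  # stand-in for e.g. output_dir/checkpoint-500
state_path = os.path.join(checkpoint_dir, "trainer_state.json")

# First run: tokens are accumulated on the state, then the checkpoint is written.
state = TrainerState()
state.num_input_tokens_seen += 1_024   # tokens counted before the save
state.save_to_json(state_path)         # Trainer saves this file with every checkpoint

# Resumed run: the Trainer reloads trainer_state.json, so the count continues.
resumed = TrainerState.load_from_json(state_path)
print(resumed.num_input_tokens_seen)   # 1024 -> num_tokens picks up where it left off

# The previous approach re-initialized a plain attribute in __init__,
# so after a resume the logged counter silently dropped back to zero.
fresh_total_train_tokens = 0
print(fresh_total_train_tokens)        # 0
```

The companion change to `self._last_loaded_step = -1` sets the sentinel to a value that `global_step` never takes, so the first generation step after the trainer is constructed always triggers a weight sync to vLLM rather than assuming vLLM already holds the current weights.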
