@@ -446,7 +446,6 @@ def data_collator(features): # No data collation is needed in GRPO
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
-        self._total_train_tokens = 0
         self.log_completions = args.log_completions
         self.num_completions_to_print = args.num_completions_to_print
@@ -501,7 +500,7 @@ def data_collator(features): # No data collation is needed in GRPO
         # vLLM specific sampling arguments
         self.guided_decoding_regex = args.vllm_guided_decoding_regex
 
-        self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
+        self._last_loaded_step = -1  # tag to avoid useless loading during grad accumulation
 
         # When using vLLM, the main process is responsible for loading the model weights. This can cause process
         # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
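A hedged sketch of the guard this sentinel feeds, modeled on GRPOTrainer's weight-sync logic (treat the helper name `_move_model_to_vllm` and the exact condition as illustrative for your TRL version): generation runs once per gradient-accumulation micro-step, but the policy weights only change once per optimizer step, so the sync is keyed to `global_step`. Initializing the tag to -1 rather than 0 means the guard also fires on the very first call, where `global_step == 0`, a case the old value silently skipped.

    def maybe_sync_weights(trainer):
        # Push policy weights into the vLLM engine at most once per optimizer
        # step; repeated micro-steps within one accumulation window are no-ops.
        if trainer.state.global_step != trainer._last_loaded_step:
            trainer._move_model_to_vllm()  # GRPOTrainer helper; illustrative here
            trainer._last_loaded_step = trainer.state.global_step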
@@ -894,8 +893,8 @@ def _generate_and_score_completions(
         mode = "eval" if self.control.should_evaluate else "train"
 
         if mode == "train":
-            self._total_train_tokens += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
-            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+            self.state.num_input_tokens_seen += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+            self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
 
         # log completion lengths, mean, min, max
         agg_completion_mask = self.accelerator.gather_for_metrics(completion_mask.sum(1))
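One practical consequence of moving the counter onto `self.state` (a hedged note, assuming Transformers' `TrainerState`, which exposes `num_input_tokens_seen` as a standard field and is serialized alongside every checkpoint): the token count now survives a checkpoint resume, whereas the removed private `_total_train_tokens` restarted at 0 on every run. A minimal sketch:

    from transformers import TrainerState

    # TrainerState is written to trainer_state.json with each checkpoint, so
    # counters stored on it are restored by resume_from_checkpoint.
    state = TrainerState()
    state.num_input_tokens_seen += 4096            # e.g. tokens gathered this step
    state.save_to_json("trainer_state.json")       # happens at checkpoint time
    resumed = TrainerState.load_from_json("trainer_state.json")
    assert resumed.num_input_tokens_seen == 4096   # counter persists across resume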