@@ -2460,6 +2460,7 @@ def _inner_training_loop(
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
         grad_norm: Optional[float] = None
+        learning_rate = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

         if args.eval_on_start:
@@ -2608,6 +2609,9 @@ def _inner_training_loop(

                     self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

+                    # get learning rate before update
+                    learning_rate = self._get_learning_rate()
+
                     if not self.accelerator.optimizer_step_was_skipped:
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -2618,7 +2622,14 @@ def _inner_training_loop(
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(
-                        tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
+                        tr_loss,
+                        grad_norm,
+                        model,
+                        trial,
+                        epoch,
+                        ignore_keys_for_eval,
+                        start_time,
+                        learning_rate=learning_rate,
                     )
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2644,7 +2655,9 @@ def _inner_training_loop(
                 self.control.should_training_stop = True

             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
+            self._maybe_log_save_evaluate(
+                tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
+            )

             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_xla_available():
@@ -3064,7 +3077,9 @@ def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False, limit_eva
                 ) from exc
         return metrics

-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
+    def _maybe_log_save_evaluate(
+        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
+    ):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_xla_available():
                 xm.mark_step()
@@ -3080,7 +3095,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
-            logs["learning_rate"] = self._get_learning_rate()
+            if learning_rate is not None:
+                logs["learning_rate"] = learning_rate
+            else:
+                logs["learning_rate"] = self._get_learning_rate()

             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
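For reference, a minimal standalone sketch of the fallback the diff adds to `_maybe_log_save_evaluate`: the learning rate captured just before the scheduler step is preferred, and the current value is only read when nothing was captured. This is illustration only, not Trainer code; the function and parameter names below are hypothetical.

```python
from typing import Callable, Optional


def resolve_logged_learning_rate(
    captured_lr: Optional[float], get_current_lr: Callable[[], float]
) -> float:
    """Prefer the learning rate captured before the scheduler update;
    fall back to reading the current value when none was captured."""
    if captured_lr is not None:
        return captured_lr
    return get_current_lr()


# Hypothetical values for illustration only.
print(resolve_logged_learning_rate(5e-5, lambda: 4.9e-5))  # 5e-05 (captured value wins)
print(resolve_logged_learning_rate(None, lambda: 4.9e-5))  # 4.9e-05 (fallback to current value)
```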