Commit 8c3dda3

SunMarc and soghomon-b authored and committed

Log the correct learning rate (huggingface#36973)

* fix learning rate log
* fix lr log
* add lr

1 parent ab26547 · commit 8c3dda3
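
The change captures the learning rate right after the optimizer step, before `self.lr_scheduler.step()` advances the schedule, and passes it through to `_maybe_log_save_evaluate`. Previously the rate was read at log time, after the scheduler had already moved to the next step's value. Below is a minimal sketch of the off-by-one this fixes, assuming a plain PyTorch optimizer with a `LambdaLR` schedule; the variable names are illustrative and not from the commit:

```python
import torch

# Minimal sketch (not from the commit): with a decaying schedule, the rate read
# after scheduler.step() belongs to the *next* step, not to the update that was
# just applied.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 0.5**step)

for step in range(3):
    param.grad = torch.ones(1)

    lr_used = optimizer.param_groups[0]["lr"]  # rate captured before the scheduler advances
    optimizer.step()
    scheduler.step()
    lr_after = optimizer.param_groups[0]["lr"]  # rate read after scheduler.step(), as before the fix

    print(f"step {step}: lr used = {lr_used:.4f}, lr read after scheduler.step() = {lr_after:.4f}")
```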


src/transformers/trainer.py: 22 additions & 4 deletions
@@ -2460,6 +2460,7 @@ def _inner_training_loop(
         self._globalstep_last_logged = self.state.global_step
         model.zero_grad()
         grad_norm: Optional[float] = None
+        learning_rate = None
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
 
         if args.eval_on_start:
@@ -2608,6 +2609,9 @@ def _inner_training_loop(
 
                     self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
 
+                    # get learning rate before update
+                    learning_rate = self._get_learning_rate()
+
                     if not self.accelerator.optimizer_step_was_skipped:
                         # Delay optimizer scheduling until metrics are generated
                         if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -2618,7 +2622,14 @@ def _inner_training_loop(
                     self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                     self.control = self.callback_handler.on_step_end(args, self.state, self.control)
                     self._maybe_log_save_evaluate(
-                        tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
+                        tr_loss,
+                        grad_norm,
+                        model,
+                        trial,
+                        epoch,
+                        ignore_keys_for_eval,
+                        start_time,
+                        learning_rate=learning_rate,
                     )
                 else:
                     self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
@@ -2644,7 +2655,9 @@ def _inner_training_loop(
                 self.control.should_training_stop = True
 
             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
+            self._maybe_log_save_evaluate(
+                tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
+            )
 
             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                 if is_torch_xla_available():
@@ -3064,7 +3077,9 @@ def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False, limit_eva
             ) from exc
         return metrics
 
-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time):
+    def _maybe_log_save_evaluate(
+        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
+    ):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             if is_torch_xla_available():
                 xm.mark_step()
@@ -3080,7 +3095,10 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
-            logs["learning_rate"] = self._get_learning_rate()
+            if learning_rate is not None:
+                logs["learning_rate"] = learning_rate
+            else:
+                logs["learning_rate"] = self._get_learning_rate()
 
             self._total_loss_scalar += tr_loss_scalar
             self._globalstep_last_logged = self.state.global_step
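
Because `learning_rate` defaults to `None`, any other caller of `_maybe_log_save_evaluate` keeps the old behavior of reading the rate at log time. Consumers of the log stream now receive the pre-update rate under the existing `learning_rate` key. A small hypothetical sketch of reading it from a callback, using the standard `TrainerCallback.on_log` hook (the class name is invented here, not part of the commit):

```python
from transformers import TrainerCallback

class LearningRateEcho(TrainerCallback):
    """Hypothetical callback: prints the learning rate the Trainer logs each step."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and "learning_rate" in logs:
            # After this commit, this value is captured before lr_scheduler.step(),
            # i.e. the rate used for the optimizer update at this global step.
            print(f"step {state.global_step}: learning_rate={logs['learning_rate']}")

# Usage (sketch): trainer = Trainer(..., callbacks=[LearningRateEcho()])
```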
