@@ -121,6 +121,12 @@ def on_train_end(self):
121121 return
122122 self ._teardown_already_run = True
123123
124+ # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
125+ # when a checkpoint was saved at the last step
126+ self .trainer .global_step -= 1
127+ self .check_checkpoint_callback (should_update = True , is_last = True )
128+ self .trainer .global_step += 1
129+
124130 # hook
125131 self .trainer .call_hook ("on_train_end" )
126132
@@ -139,6 +145,28 @@ def on_train_end(self):
139145 # reset bookkeeping
140146 self .trainer ._running_stage = None
141147
def check_checkpoint_callback(self, should_update, is_last=False):
    """Invoke ``on_validation_end`` on every checkpoint callback of the trainer.

    The call is a no-op unless ``should_update`` is truthy and
    ``self.trainer.checkpoint_connector.has_trained`` is truthy.

    Args:
        should_update: when falsy nothing is done (the trainer is not even
            inspected).
        is_last: when truthy, and at least one checkpoint callback has both
            ``save_last`` and ``verbose`` set, "Saving latest checkpoint..."
            is logged through ``rank_zero_info`` before the callbacks run.
    """
    # TODO bake this logic into the ModelCheckpoint callback
    if not (should_update and self.trainer.checkpoint_connector.has_trained):
        return

    trainer = self.trainer
    callbacks = trainer.checkpoint_callbacks

    # announce the final save once, and only if some callback asked for it
    if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
        rank_zero_info("Saving latest checkpoint...")

    model = trainer.lightning_module

    # the checkpoint callbacks are driven through their validation-end hook
    for cb in callbacks:
        cb.on_validation_end(trainer, model)
160+
def check_early_stopping_callback(self, should_update):
    """Invoke ``on_validation_end`` on every ``EarlyStopping`` callback of the trainer.

    The call is a no-op unless ``should_update`` is truthy and
    ``self.trainer.checkpoint_connector.has_trained`` is truthy.

    Args:
        should_update: when falsy nothing is done (the trainer is not even
            inspected).
    """
    # TODO bake this logic into the EarlyStopping callback
    if not (should_update and self.trainer.checkpoint_connector.has_trained):
        return

    trainer = self.trainer

    # only the EarlyStopping instances among the trainer's callbacks are driven
    stoppers = [cb for cb in trainer.callbacks if isinstance(cb, EarlyStopping)]
    model = trainer.lightning_module

    for cb in stoppers:
        cb.on_validation_end(trainer, model)
169+
142170 def on_train_epoch_start (self , epoch ):
143171
144172 # update training progress in trainer
@@ -534,14 +562,15 @@ def run_training_epoch(self):
534562 if (val_loop_called and not should_check_val ) or should_train_only :
535563 self .trainer .optimizer_connector .update_learning_rates (interval = 'epoch' )
536564
565+ if should_train_only :
566+ self .check_checkpoint_callback (True )
567+ self .check_early_stopping_callback (True )
568+
537569 if should_check_val :
538570 self .trainer .validating = True
539571 self .trainer .run_evaluation (on_epoch = True )
540572 self .trainer .training = True
541573
542- if should_train_only :
543- self .trainer .call_hook ('on_train_epoch_final_end' )
544-
545574 # increment the global step once
546575 # progress global step according to grads progress
547576 self .increment_accumulated_grad_global_step ()
0 commit comments