@@ -1133,6 +1133,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss():
                     tr_loss /= self.args.gradient_accumulation_steps
 
+                # assert if loss is invalid
+                self._check_loss_valid(tr_loss)
+
                 self.timers and self.timers("forward-backward").stop()
                 # Manually collect gradients
                 # Case 1: Use recompute and dp
@@ -1431,13 +1434,17 @@ def _print_timer(self):
         if timer_info or paddle_timer_info:
             logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}")
 
-    def _get_item_from_loss(self, loss):
+    def _check_loss_valid(self, loss):
         assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
         loss_value = loss.item()
         if not self.args.fp16:
             if not np.isfinite(loss_value).all():
                 err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
                 raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")
+
+    def _get_item_from_loss(self, loss):
+        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
+        loss_value = loss.item()
         return loss_value
 
     def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
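The diff above splits the NaN/Inf validation out of _get_item_from_loss into a new _check_loss_valid, which is now called once per accumulation step right after the loss is scaled. Below is a minimal standalone sketch of that split for illustration only; the module-level function names, the fp16 flag parameter, and the placeholder error messages are assumptions for the sketch and are not the trainer's actual definitions.

    import numpy as np
    import paddle

    # Placeholder messages; the real LOSS_NAN_ERROR / LOSS_INF_ERROR constants
    # are defined elsewhere in the trainer module.
    LOSS_NAN_ERROR = "Loss is NaN"
    LOSS_INF_ERROR = "Loss is Inf"

    def check_loss_valid(loss, fp16=False):
        # Raise on NaN/Inf losses; in fp16 mode overflow handling is left to the loss scaler.
        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
        loss_value = loss.item()
        if not fp16 and not np.isfinite(loss_value):
            err_msg = LOSS_NAN_ERROR if np.isnan(loss_value) else LOSS_INF_ERROR
            raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")

    def get_item_from_loss(loss):
        # Pure scalar extraction, no validity check (mirrors the slimmed-down _get_item_from_loss).
        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
        return loss.item()

    # Usage: a NaN loss now fails fast during the accumulation step instead of
    # only surfacing when the loss is logged.
    try:
        check_loss_valid(paddle.to_tensor(float("nan")))
    except ValueError as e:
        print(e)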