Skip to content

Commit 691ae01

Browse files
authored
[LLM] valid loss before optimizer step (#9255) (#9705)
1 parent 7197b79 commit 691ae01

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

paddlenlp/trainer/trainer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
11331133
if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss():
11341134
tr_loss /= self.args.gradient_accumulation_steps
11351135

1136+
# assert if loss is invalid
1137+
self._check_loss_valid(tr_loss)
1138+
11361139
self.timers and self.timers("forward-backward").stop()
11371140
# Maunally collect gradients
11381141
# Case 1: Use recompute and dp
@@ -1431,13 +1434,17 @@ def _print_timer(self):
14311434
if timer_info or paddle_timer_info:
14321435
logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}")
14331436

def _check_loss_valid(self, loss):
    """Abort training with a clear error when the loss has gone non-finite.

    Args:
        loss: scalar ``paddle.Tensor`` holding the accumulated training loss.

    Raises:
        ValueError: when the loss value is NaN or Inf (non-fp16 runs only).
    """
    # The loss must already be a materialized paddle tensor before we can
    # copy its value back to the host.
    assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
    loss_value = loss.item()
    # NOTE(review): fp16 runs skip the check — presumably because transient
    # inf/nan overflow is expected and handled by the loss scaler; confirm.
    if self.args.fp16:
        return
    if np.isfinite(loss_value).all():
        return
    err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
    raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")
def _get_item_from_loss(self, loss):
    """Return the Python scalar held by *loss* (a scalar ``paddle.Tensor``)."""
    # Only a live, initialized tensor can be read back to the host.
    assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
    return loss.item()
14421449

14431450
def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):

0 commit comments

Comments
 (0)