Merged
paddlenlp/trainer/trainer.py (9 changes: 8 additions & 1 deletion)
@@ -1014,6 +1014,9 @@ def fused_allreduce_gradients_no_sync(paramlist, hcg):
                 if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss():
                     tr_loss /= self.args.gradient_accumulation_steps

+                # raise early if the loss is invalid (NaN or Inf)
+                self._check_loss_valid(tr_loss)
+
                 self.timers and self.timers("forward-backward").stop()
                 # Manually collect gradients
                 # Case 1: Use recompute and dp
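For context on the hunk above: when delayed loss scaling is enabled, the trainer accumulates the micro-batch losses during gradient accumulation and divides once at the end of the accumulated step, which is where the new check is inserted. A toy sketch with hypothetical values (not PaddleNLP code):

# Toy illustration (hypothetical values, not PaddleNLP code) of delayed
# loss scaling: summing micro-batch losses and dividing once by the
# accumulation count equals averaging the loss per micro-batch.
micro_batch_losses = [2.0, 2.5, 1.5, 2.0]       # one scalar loss per micro-batch
gradient_accumulation_steps = len(micro_batch_losses)

tr_loss = sum(micro_batch_losses)               # accumulated, unscaled
tr_loss /= gradient_accumulation_steps          # the delayed scaling step above

assert tr_loss == 2.0                           # per-micro-batch average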
@@ -1297,13 +1300,17 @@ def _print_timer(self):
         if timer_info or paddle_timer_info:
             logger.info(f"[Profile global_step: {self.state.global_step}] {timer_info} {paddle_timer_info}")

-    def _get_item_from_loss(self, loss):
+    def _check_loss_valid(self, loss):
         assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
         loss_value = loss.item()
         if not self.args.fp16:
             if not np.isfinite(loss_value).all():
                 err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
                 raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")

+    def _get_item_from_loss(self, loss):
+        assert isinstance(loss, paddle.Tensor) and loss._is_initialized()
+        loss_value = loss.item()
         return loss_value

     def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
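For reference, a minimal self-contained sketch of what _check_loss_valid does, operating on a plain float instead of a paddle.Tensor. The LOSS_NAN_ERROR / LOSS_INF_ERROR strings below are placeholders for the constants trainer.py imports; the real messages may differ:

import math

LOSS_NAN_ERROR = "Loss is NaN"    # placeholder text, not the real constant
LOSS_INF_ERROR = "Loss is Inf"    # placeholder text, not the real constant

def check_loss_valid(loss_value, fp16=False):
    # Mirror of the method above: the check is skipped under fp16
    # (presumably because the fp16 loss scaler already handles overflow);
    # otherwise a non-finite loss raises with a NaN- or Inf-specific message.
    if fp16:
        return
    if not math.isfinite(loss_value):
        err_msg = LOSS_NAN_ERROR if math.isnan(loss_value) else LOSS_INF_ERROR
        raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")

check_loss_valid(0.73)                # passes silently
# check_loss_valid(float("nan"))      # raises ValueError with LOSS_NAN_ERROR
# check_loss_valid(float("inf"))      # raises ValueError with LOSS_INF_ERROR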