Skip to content

Commit 56dba6d

Browse files
[Bug fix] fix skip consumed_samples twice bug (#8980)
1 parent cf3a672 commit 56dba6d

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

paddlenlp/trainer/auto_trainer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from paddlenlp.trainer import Trainer
2828

29+
from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
2930
from ..utils.log import logger
3031
from .argparser import strtobool
3132
from .trainer import SCALER_NAME, SCHEDULER_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME
@@ -309,12 +310,23 @@ def _inner_training_loop(
309310

310311
# Skip past any already trained steps if resuming training
311312
# We use consumed_samples to reset the status
312-
if steps_trained_in_current_epoch > 0:
313+
if isinstance(train_dataloader._dataloader, paddle.io.DataLoader) and isinstance(
314+
train_dataloader._dataloader.batch_sampler, NlpDistributedBatchSampler
315+
):
316+
if step == 0:
317+
if steps_trained_progress_bar is not None:
318+
steps_trained_progress_bar.update(steps_trained_in_current_epoch)
319+
steps_trained_progress_bar.close()
320+
steps_trained_progress_bar = None
321+
self._load_rng_state(resume_from_checkpoint)
322+
step += steps_trained_in_current_epoch
323+
elif steps_trained_in_current_epoch > 0:
313324
steps_trained_in_current_epoch -= 1
314325
if steps_trained_progress_bar is not None:
315326
steps_trained_progress_bar.update(1)
316327
if steps_trained_in_current_epoch == 0:
317328
self._load_rng_state(resume_from_checkpoint)
329+
self.timers and self.timers("read-data").start()
318330
continue
319331
elif steps_trained_progress_bar is not None:
320332
steps_trained_progress_bar.close()

0 commit comments

Comments
 (0)