@@ -26,6 +26,7 @@

 from paddlenlp.trainer import Trainer

+from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler
 from ..utils.log import logger
 from .argparser import strtobool
 from .trainer import SCALER_NAME, SCHEDULER_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME
@@ -309,12 +310,23 @@ def _inner_training_loop( |

                 # Skip past any already trained steps if resuming training
                 # We use consumed_samples to reset the status
-                if steps_trained_in_current_epoch > 0:
+                if isinstance(train_dataloader._dataloader, paddle.io.DataLoader) and isinstance(
+                    train_dataloader._dataloader.batch_sampler, NlpDistributedBatchSampler
+                ):
+                    if step == 0:
+                        if steps_trained_progress_bar is not None:
+                            steps_trained_progress_bar.update(steps_trained_in_current_epoch)
+                            steps_trained_progress_bar.close()
+                            steps_trained_progress_bar = None
+                        self._load_rng_state(resume_from_checkpoint)
+                    step += steps_trained_in_current_epoch
+                elif steps_trained_in_current_epoch > 0:
                     steps_trained_in_current_epoch -= 1
                     if steps_trained_progress_bar is not None:
                         steps_trained_progress_bar.update(1)
                     if steps_trained_in_current_epoch == 0:
                         self._load_rng_state(resume_from_checkpoint)
+                    self.timers and self.timers("read-data").start()
                     continue
                 elif steps_trained_progress_bar is not None:
                     steps_trained_progress_bar.close()
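
The new branch works because a consumed_samples-aware sampler drops already-seen indices itself when restored from a checkpoint, so the loop only needs the single `step += steps_trained_in_current_epoch` adjustment instead of burning one `continue` per stale batch. Below is a minimal runnable sketch of that idea; the class name and fields are hypothetical simplifications for illustration, not paddlenlp's actual NlpDistributedBatchSampler:

    # Illustrative only: a batch sampler that resumes by offsetting past
    # samples already consumed before the checkpoint was written.
    class ConsumedSamplesBatchSampler:
        def __init__(self, num_samples, batch_size, consumed_samples=0):
            self.num_samples = num_samples
            self.batch_size = batch_size
            # Restored from the checkpoint; iteration starts directly here.
            self.consumed_samples = consumed_samples

        def __iter__(self):
            batch = []
            for idx in range(self.consumed_samples, self.num_samples):
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []

    # With consumed_samples=8 and batch_size=4, iteration resumes at sample 8;
    # there is no per-step "skip and continue" pass over batches 0 and 1.
    sampler = ConsumedSamplesBatchSampler(num_samples=16, batch_size=4, consumed_samples=8)
    assert list(sampler) == [[8, 9, 10, 11], [12, 13, 14, 15]]

The isinstance guard keeps the old behavior as a fallback: for any other sampler type, the `elif` path still skips one batch per loop iteration, updating the progress bar and restoring the RNG state once the skipped steps are exhausted.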