|
1 | 1 | import logging
|
| 2 | +import os |
2 | 3 | from collections import defaultdict
|
3 | 4 | from typing import Any, DefaultDict, Dict, List, Optional
|
4 | 5 |
|
@@ -509,3 +510,75 @@ def _reset_losses(self) -> None:
|
509 | 510 | """Reset the loss counters."""
|
510 | 511 | self.running_losses = defaultdict(float)
|
511 | 512 | self.running_counts = defaultdict(int)
|
| 513 | + |
| 514 | + def save(self, trainer_path: str) -> None: |
| 515 | + """Save the trainer config to the specified file path in json format. |
| 516 | +
|
| 517 | + Parameters |
| 518 | + ---------- |
| 519 | + trainer_path |
| 520 | + The path where trainer config and optimizer state should be saved. |
| 521 | + """ |
| 522 | + |
| 523 | + head, tail = os.path.split(trainer_path) |
| 524 | + |
| 525 | + if not os.path.exists(head): |
| 526 | + os.makedirs(os.path.dirname(head)) |
| 527 | + try: |
| 528 | + torch.save( |
| 529 | + { |
| 530 | + "trainer_config": self.config._asdict(), |
| 531 | + "optimizer_state_dict": self.optimizer.state_dict(), |
| 532 | + }, |
| 533 | + trainer_path, |
| 534 | + ) |
| 535 | + except BaseException: # pragma: no cover |
| 536 | + logging.warning("Saving failed... continuing anyway.") |
| 537 | + |
| 538 | + logging.info(f"[{self.name}] Trainer config saved in {trainer_path}") |
| 539 | + |
| 540 | + def load(self, trainer_path: str, model: Optional[MultitaskClassifier]) -> None: |
| 541 | + """Load trainer config and optimizer state from the specified json file path to the trainer object. The optimizer state is stored, too. However, it only makes sense if loaded with the correct model again. |
| 542 | +
|
| 543 | + Parameters |
| 544 | + ---------- |
| 545 | + trainer_path |
| 546 | + The path to the saved trainer config to be loaded |
| 547 | + model |
| 548 | + MultitaskClassifier for which the optimizer has been set. Parameters of optimizer must fit to model parameters. This model |
| 549 | + shall be the model which was fit by the stored Trainer. |
| 550 | +
|
| 551 | + Example |
| 552 | + ------- |
| 553 | + Saving model and corresponding trainer: |
| 554 | + >>> model.save('./my_saved_model_file') # doctest: +SKIP |
| 555 | + >>> trainer.save('./my_saved_trainer_file') # doctest: +SKIP |
| 556 | + Now we can resume training and load the saved model and trainer into new model and trainer objects: |
| 557 | + >>> new_model.load('./my_saved_model_file') # doctest: +SKIP |
| 558 | + >>> new_trainer.load('./my_saved_trainer_file', model=new_model) # doctest: +SKIP |
| 559 | + >>> new_trainer.fit(...) # doctest: +SKIP |
| 560 | + """ |
| 561 | + |
| 562 | + try: |
| 563 | + saved_state = torch.load(trainer_path) |
| 564 | + except BaseException: |
| 565 | + if not os.path.exists(trainer_path): |
| 566 | + logging.error("Loading failed... Trainer config does not exist.") |
| 567 | + else: |
| 568 | + logging.error( |
| 569 | + f"Loading failed... Cannot load trainer config from {trainer_path}" |
| 570 | + ) |
| 571 | + raise |
| 572 | + |
| 573 | + self.config = TrainerConfig(**saved_state["trainer_config"]) |
| 574 | + logging.info(f"[{self.name}] Trainer config loaded from {trainer_path}") |
| 575 | + |
| 576 | + if model is not None: |
| 577 | + try: |
| 578 | + self._set_optimizer(model) |
| 579 | + self.optimizer.load_state_dict(saved_state["optimizer_state_dict"]) |
| 580 | + logging.info(f"[{self.name}] Optimizer loaded from {trainer_path}") |
| 581 | + except BaseException: |
| 582 | + logging.error( |
| 583 | + "Loading the optimizer for your model failed. Optimizer state NOT loaded." |
| 584 | + ) |
0 commit comments