Commit e878d48

ptrcklv authored and vincentschen committed
Save optimizer state in Trainer (#1533)
Addresses #1416
1 parent e711035 commit e878d48
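
For context, the resume-training workflow this commit enables looks roughly like the sketch below, adapted from the docstring example in the diff; the model, tasks, and dataloaders objects are placeholders, not part of this commit.

    from snorkel.classification import MultitaskClassifier, Trainer

    # Save a fitted model together with its trainer (paths are placeholders).
    model.save("./my_saved_model_file")
    trainer.save("./my_saved_trainer_file")

    # Load both into fresh objects; load() restores the trainer config and,
    # because a model is passed, the optimizer state as well.
    new_model = MultitaskClassifier(tasks)
    new_model.load("./my_saved_model_file")
    new_trainer = Trainer()
    new_trainer.load("./my_saved_trainer_file", model=new_model)
    new_trainer.fit(new_model, dataloaders)  # continue training with the restored optimizer state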

File tree

2 files changed: 111 additions & 0 deletions

snorkel/classification/training/trainer.py

Lines changed: 73 additions & 0 deletions
@@ -1,4 +1,5 @@
 import logging
+import os
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional

@@ -509,3 +510,75 @@ def _reset_losses(self) -> None:
         """Reset the loss counters."""
         self.running_losses = defaultdict(float)
         self.running_counts = defaultdict(int)
+
+    def save(self, trainer_path: str) -> None:
+        """Save the trainer config and optimizer state to the specified file path.
+
+        Parameters
+        ----------
+        trainer_path
+            The path where the trainer config and optimizer state should be saved.
+        """
+        head, tail = os.path.split(trainer_path)
+
+        if not os.path.exists(head):
+            os.makedirs(head)
+        try:
+            torch.save(
+                {
+                    "trainer_config": self.config._asdict(),
+                    "optimizer_state_dict": self.optimizer.state_dict(),
+                },
+                trainer_path,
+            )
+        except BaseException:  # pragma: no cover
+            logging.warning("Saving failed... continuing anyway.")
+
+        logging.info(f"[{self.name}] Trainer config saved in {trainer_path}")
+
+    def load(self, trainer_path: str, model: Optional[MultitaskClassifier]) -> None:
+        """Load the trainer config and optimizer state from the specified file path.
+
+        The optimizer state is restored only when a model is given, and it is
+        only meaningful if that model is the one the stored Trainer was fit on.
+
+        Parameters
+        ----------
+        trainer_path
+            The path to the saved trainer config to be loaded.
+        model
+            The MultitaskClassifier for which the optimizer was set. The
+            optimizer's parameters must match the model's parameters, so this
+            should be the model that was fit by the stored Trainer.
+
+        Example
+        -------
+        Saving a model and its corresponding trainer:
+
+        >>> model.save('./my_saved_model_file')  # doctest: +SKIP
+        >>> trainer.save('./my_saved_trainer_file')  # doctest: +SKIP
+
+        We can then resume training by loading the saved model and trainer into
+        new model and trainer objects:
+
+        >>> new_model.load('./my_saved_model_file')  # doctest: +SKIP
+        >>> new_trainer.load('./my_saved_trainer_file', model=new_model)  # doctest: +SKIP
+        >>> new_trainer.fit(...)  # doctest: +SKIP
+        """
+        try:
+            saved_state = torch.load(trainer_path)
+        except BaseException:
+            if not os.path.exists(trainer_path):
+                logging.error("Loading failed... Trainer config does not exist.")
+            else:
+                logging.error(
+                    f"Loading failed... Cannot load trainer config from {trainer_path}"
+                )
+            raise
+
+        self.config = TrainerConfig(**saved_state["trainer_config"])
+        logging.info(f"[{self.name}] Trainer config loaded from {trainer_path}")
+
+        if model is not None:
+            try:
+                self._set_optimizer(model)
+                self.optimizer.load_state_dict(saved_state["optimizer_state_dict"])
+                logging.info(f"[{self.name}] Optimizer loaded from {trainer_path}")
+            except BaseException:
+                logging.error(
+                    "Loading the optimizer for your model failed. "
+                    "Optimizer state NOT loaded."
+                )

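Since save writes an ordinary torch checkpoint with exactly the two keys shown above, the file can be inspected directly; a minimal sketch, assuming a checkpoint was already written to the placeholder path:

    import torch

    saved_state = torch.load("./my_saved_trainer_file")
    print(saved_state["trainer_config"])               # plain dict from TrainerConfig._asdict()
    print(saved_state["optimizer_state_dict"].keys())  # dict_keys(['state', 'param_groups'])
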
test/classification/training/test_trainer.py

Lines changed: 38 additions & 0 deletions
@@ -1,3 +1,4 @@
+import collections.abc
 import copy
 import json
 import os
@@ -216,6 +217,43 @@ def test_warmup(self):
         trainer.fit(model, [dataloaders[0]])
         self.assertEqual(trainer.warmup_steps, 1)
 
+    def test_save_load(self):
+        non_base_config = {"n_epochs": 2, "progress_bar": False}
+        trainer1 = Trainer(**base_config, lr_scheduler="exponential")
+        trainer1.fit(model, [dataloaders[0]])
+        trainer2 = Trainer(**non_base_config, lr_scheduler="linear")
+        trainer3 = Trainer(**non_base_config, lr_scheduler="linear")
+
+        with tempfile.NamedTemporaryFile() as fd:
+            checkpoint_path = fd.name
+            trainer1.save(checkpoint_path)
+            trainer2.load(checkpoint_path, model=model)
+            trainer3.load(checkpoint_path, None)
+
+        self.assertEqual(trainer1.config, trainer2.config)
+        self.dict_check(
+            trainer1.optimizer.state_dict(), trainer2.optimizer.state_dict()
+        )
+
+        # continue training after load
+        trainer2.fit(model, [dataloaders[0]])
+
+        # loading without a model should restore the trainer config
+        # but not the optimizer state
+        self.assertEqual(trainer1.config, trainer3.config)
+        self.assertFalse(hasattr(trainer3, "optimizer"))
+        trainer3.fit(model, [dataloaders[0]])
+
+    def dict_check(self, dict1, dict2):
+        # recursively compare two (possibly nested) state dicts,
+        # comparing tensor leaves element-wise
+        for k in dict1.keys():
+            dict1_ = dict1[k]
+            dict2_ = dict2[k]
+            if isinstance(dict1_, collections.abc.Mapping):
+                self.dict_check(dict1_, dict2_)
+            elif isinstance(dict1_, torch.Tensor):
+                self.assertTrue(torch.eq(dict1_, dict2_).all())
+            else:
+                self.assertEqual(dict1_, dict2_)
+
 
 if __name__ == "__main__":
     unittest.main()
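
Assuming the repository's usual pytest tooling, the new test can be run on its own with:

    pytest test/classification/training/test_trainer.py -k test_save_load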
