diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index e48b51e975..2714620711 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -13,10 +13,11 @@
import unittest
from typing import Any, Dict, Iterable, List
from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

import torch
import torch.distributed as dist
+from torch import nn
from torch.distributed import launcher
from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
@@ -34,6 +35,7 @@
_override_knobs,
get_latest_checkpoint_path,
KnobOptions,
+ RestoreOptions,
TorchSnapshotSaver,
)
from torchtnt.framework.train import train
@@ -313,12 +315,44 @@ def test_save_restore_no_train_progress(self) -> None:
end_num_steps_completed = my_unit.train_progress.num_steps_completed
self.assertGreater(len(expected_paths), 0)
snapshot_cb.restore(
- expected_paths[0], my_unit, restore_train_progress=False
+ expected_paths[0],
+ my_unit,
+ restore_options=RestoreOptions(restore_train_progress=False),
)
restored_num_steps_completed = my_unit.train_progress.num_steps_completed
# no train progress was restored so the progress after restoration should be the same as the progress before restoration
        self.assertEqual(restored_num_steps_completed, end_num_steps_completed)

+    @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
+ def test_save_restore_no_optimizer_restore(
+ self, mock_torchsnapshot: MagicMock
+ ) -> None:
+ my_unit = DummyTrainUnit(input_dim=2)
+ restore_options = RestoreOptions(restore_optimizers=False)
+ TorchSnapshotSaver.restore(
+ path="path/to/snapshot", unit=my_unit, restore_options=restore_options
+ )
+ app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+ self.assertNotIn("optimizer", app_state)
+ TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit)
+ app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+ self.assertIn("optimizer", app_state)
+
+ @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
+ def test_save_restore_no_lr_scheduler_restore(
+ self, mock_torchsnapshot: MagicMock
+ ) -> None:
+ my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
+ restore_options = RestoreOptions(restore_lr_schedulers=False)
+ TorchSnapshotSaver.restore(
+ path="path/to/snapshot", unit=my_unit, restore_options=restore_options
+ )
+ app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+ self.assertNotIn("lr_scheduler", app_state)
+ TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit)
+ app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+ self.assertIn("lr_scheduler", app_state)
+
def test_save_on_train_end(self) -> None:
input_dim = 2
dataset_len = 10
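Reviewer note on the two new mock-based tests above: they patch the torchsnapshot module object inside torchsnapshot_saver and then read back the app_state dict that restore forwarded, via `mock_torchsnapshot.Snapshot().restore.call_args.args[0]`. This works because a MagicMock memoizes its return_value: every call to `Snapshot()` yields the same child mock, whether the call happens inside `TorchSnapshotSaver.restore` or in the test body. A minimal, self-contained sketch of just that mock behavior (the names and dict contents here are illustrative, not from the PR):

    from unittest.mock import MagicMock

    mock_torchsnapshot = MagicMock()

    # Code under test would construct a Snapshot and call restore on it.
    snapshot = mock_torchsnapshot.Snapshot(path="path/to/snapshot")
    snapshot.restore({"module": ..., "optimizer": ...})

    # The test re-derives the same handle: Snapshot() returns the memoized
    # return_value mock, so the restore call recorded above is visible here.
    app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
    assert "optimizer" in app_state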
diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py
index 9ff78243ef..7f860c93f8 100644
--- a/torchtnt/framework/callbacks/torchsnapshot_saver.py
+++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -78,6 +78,24 @@ class KnobOptions:
    max_per_rank_io_concurrency: Optional[int] = None


+@dataclass
+class RestoreOptions:
+ """
+ Options when restoring a snapshot.
+
+ Args:
+ restore_train_progress: Whether to restore the training progress state.
+ restore_eval_progress: Whether to restore the evaluation progress state.
+ restore_optimizers: Whether to restore the optimizer states.
+ restore_lr_schedulers: Whether to restore the lr scheduler states.
+ """
+
+ restore_train_progress: bool = True
+ restore_eval_progress: bool = True
+ restore_optimizers: bool = True
+ restore_lr_schedulers: bool = True
+
+
class TorchSnapshotSaver(Callback):
"""
    A callback which periodically saves the application state during training using `TorchSnapshot <https://pytorch.org/torchsnapshot/>`_.
@@ -274,9 +292,8 @@ def restore(
unit: AppStateMixin,
*,
train_dataloader: Optional[Iterable[TTrainData]] = None,
- restore_train_progress: bool = True,
- restore_eval_progress: bool = True,
process_group: Optional[dist.ProcessGroup] = None,
+ restore_options: Optional[RestoreOptions] = None,
storage_options: Optional[Dict[str, Any]] = None,
knob_options: Optional[KnobOptions] = None,
) -> None:
@@ -289,9 +306,8 @@ def restore(
path: Path of the snapshot to restore.
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
train_dataloader: An optional train dataloader to restore.
- restore_train_progress: Whether to restore the training progress state.
- restore_eval_progress: Whether to restore the evaluation progress state.
            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
+            restore_options: Controls which states are restored or skipped; see ``RestoreOptions``.
            storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/>`_. See each storage plugin's documentation for customizations.
knob_options: Additional keyword options for the snapshot knobs
"""
@@ -312,12 +328,23 @@ def restore(
rng_state = torchsnapshot.RNGState()
        app_state[_RNG_STATE_KEY] = rng_state

-        if not restore_train_progress:
+ restore_options = restore_options or RestoreOptions()
+ if not restore_options.restore_train_progress:
            app_state.pop(_TRAIN_PROGRESS_STATE_KEY, None)

-        if not restore_eval_progress:
+ if not restore_options.restore_eval_progress:
            app_state.pop(_EVAL_PROGRESS_STATE_KEY, None)

+        if not restore_options.restore_optimizers:
+            # remove all optimizer keys from app_state
+            for optim_key in unit.tracked_optimizers().keys():
+                app_state.pop(optim_key, None)
+
+        if not restore_options.restore_lr_schedulers:
+            # remove all lr scheduler keys from app_state
+            for lr_scheduler_key in unit.tracked_lr_schedulers().keys():
+                app_state.pop(lr_scheduler_key, None)
+
if train_dataloader is not None:
if not isinstance(train_dataloader, _TStateful):
rank_zero_warn(
@@ -346,9 +373,8 @@ def restore_from_latest(
unit: AppStateMixin,
*,
train_dataloader: Optional[Iterable[TTrainData]] = None,
- restore_train_progress: bool = True,
- restore_eval_progress: bool = True,
process_group: Optional[dist.ProcessGroup] = None,
+ restore_options: Optional[RestoreOptions] = None,
storage_options: Optional[Dict[str, Any]] = None,
knob_options: Optional[KnobOptions] = None,
) -> bool:
@@ -362,9 +388,8 @@ def restore_from_latest(
dirpath: Parent directory from which to get the latest snapshot.
unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
train_dataloader: An optional train dataloader to restore.
- restore_train_progress: Whether to restore the training progress state.
- restore_eval_progress: Whether to restore the evaluation progress state.
            process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
+            restore_options: Controls which states are restored or skipped; see ``RestoreOptions``.
            storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/>`_. See each storage plugin's documentation for customizations.
            knob_options: Additional keyword options for the snapshot knobs

@@ -379,9 +404,8 @@
path,
unit,
train_dataloader=train_dataloader,
- restore_train_progress=restore_train_progress,
- restore_eval_progress=restore_eval_progress,
process_group=process_group,
+ restore_options=restore_options,
storage_options=storage_options,
knob_options=knob_options,
)
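For context, a minimal usage sketch of the reworked restore API; my_unit stands in for any unit with tracked optimizers and LR schedulers, and the snapshot path is a placeholder:

    from torchtnt.framework.callbacks.torchsnapshot_saver import (
        RestoreOptions,
        TorchSnapshotSaver,
    )

    # Restore module weights and progress, but skip optimizer and LR-scheduler
    # state, e.g. when fine-tuning with a freshly configured optimizer.
    options = RestoreOptions(restore_optimizers=False, restore_lr_schedulers=False)
    TorchSnapshotSaver.restore(
        path="/tmp/snapshots/epoch_2",  # placeholder snapshot path
        unit=my_unit,                   # unit whose tracked states are restored
        restore_options=options,
    )

Passing no restore_options (or RestoreOptions() with its defaults) keeps the previous behavior of restoring everything.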