diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index e48b51e975..2714620711 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -13,10 +13,11 @@
 import unittest
 from typing import Any, Dict, Iterable, List
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import torch
 import torch.distributed as dist
+from torch import nn
 from torch.distributed import launcher
 from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
@@ -34,6 +35,7 @@
     _override_knobs,
     get_latest_checkpoint_path,
     KnobOptions,
+    RestoreOptions,
     TorchSnapshotSaver,
 )
 from torchtnt.framework.train import train
@@ -313,12 +315,44 @@ def test_save_restore_no_train_progress(self) -> None:
         end_num_steps_completed = my_unit.train_progress.num_steps_completed
         self.assertGreater(len(expected_paths), 0)
         snapshot_cb.restore(
-            expected_paths[0], my_unit, restore_train_progress=False
+            expected_paths[0],
+            my_unit,
+            restore_options=RestoreOptions(restore_train_progress=False),
         )
         restored_num_steps_completed = my_unit.train_progress.num_steps_completed
         # no train progress was restored so the progress after restoration should be the same as the progress before restoration
         self.assertEqual(restored_num_steps_completed, end_num_steps_completed)
 
+    @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
+    def test_save_restore_no_optimizer_restore(
+        self, mock_torchsnapshot: MagicMock
+    ) -> None:
+        my_unit = DummyTrainUnit(input_dim=2)
+        restore_options = RestoreOptions(restore_optimizers=False)
+        TorchSnapshotSaver.restore(
+            path="path/to/snapshot", unit=my_unit, restore_options=restore_options
+        )
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertNotIn("optimizer", app_state)
+        TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit)
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertIn("optimizer", app_state)
+
+    @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
+    def test_save_restore_no_lr_scheduler_restore(
+        self, mock_torchsnapshot: MagicMock
+    ) -> None:
+        my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
+        restore_options = RestoreOptions(restore_lr_schedulers=False)
+        TorchSnapshotSaver.restore(
+            path="path/to/snapshot", unit=my_unit, restore_options=restore_options
+        )
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertNotIn("lr_scheduler", app_state)
+        TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit)
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertIn("lr_scheduler", app_state)
+
     def test_save_on_train_end(self) -> None:
         input_dim = 2
         dataset_len = 10
diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py
index 9ff78243ef..7f860c93f8 100644
--- a/torchtnt/framework/callbacks/torchsnapshot_saver.py
+++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -78,6 +78,24 @@ class KnobOptions:
     max_per_rank_io_concurrency: Optional[int] = None
 
 
+@dataclass
+class RestoreOptions:
+    """
+    Options when restoring a snapshot.
+
+    Args:
+        restore_train_progress: Whether to restore the training progress state.
+        restore_eval_progress: Whether to restore the evaluation progress state.
+        restore_optimizers: Whether to restore the optimizer states.
+        restore_lr_schedulers: Whether to restore the lr scheduler states.
+    """
+
+    restore_train_progress: bool = True
+    restore_eval_progress: bool = True
+    restore_optimizers: bool = True
+    restore_lr_schedulers: bool = True
+
+
 class TorchSnapshotSaver(Callback):
     """
     A callback which periodically saves the application state during training using `TorchSnapshot `_.
@@ -274,9 +292,8 @@ def restore(
         unit: AppStateMixin,
         *,
         train_dataloader: Optional[Iterable[TTrainData]] = None,
-        restore_train_progress: bool = True,
-        restore_eval_progress: bool = True,
         process_group: Optional[dist.ProcessGroup] = None,
+        restore_options: Optional[RestoreOptions] = None,
         storage_options: Optional[Dict[str, Any]] = None,
         knob_options: Optional[KnobOptions] = None,
     ) -> None:
@@ -289,9 +306,8 @@
             path: Path of the snapshot to restore.
             unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
             train_dataloader: An optional train dataloader to restore.
-            restore_train_progress: Whether to restore the training progress state.
-            restore_eval_progress: Whether to restore the evaluation progress state.
             process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+            restore_options: Controls what to filter when restoring the state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.
             knob_options: Additional keyword options for the snapshot knobs
         """
@@ -312,12 +328,23 @@
         rng_state = torchsnapshot.RNGState()
         app_state[_RNG_STATE_KEY] = rng_state
 
-        if not restore_train_progress:
+        restore_options = restore_options or RestoreOptions()
+        if not restore_options.restore_train_progress:
             app_state.pop(_TRAIN_PROGRESS_STATE_KEY, None)
 
-        if not restore_eval_progress:
+        if not restore_options.restore_eval_progress:
             app_state.pop(_EVAL_PROGRESS_STATE_KEY, None)
 
+        if not restore_options.restore_optimizers:
+            # remove all optimizer keys from app_state
+            for optim_keys in unit.tracked_optimizers().keys():
+                app_state.pop(optim_keys, None)
+
+        if not restore_options.restore_lr_schedulers:
+            # remove all lr scheduler keys from app_state
+            for lr_scheduler_keys in unit.tracked_lr_schedulers().keys():
+                app_state.pop(lr_scheduler_keys, None)
+
         if train_dataloader is not None:
             if not isinstance(train_dataloader, _TStateful):
                 rank_zero_warn(
@@ -346,9 +373,8 @@ def restore_from_latest(
         unit: AppStateMixin,
         *,
         train_dataloader: Optional[Iterable[TTrainData]] = None,
-        restore_train_progress: bool = True,
-        restore_eval_progress: bool = True,
         process_group: Optional[dist.ProcessGroup] = None,
+        restore_options: Optional[RestoreOptions] = None,
        storage_options: Optional[Dict[str, Any]] = None,
         knob_options: Optional[KnobOptions] = None,
     ) -> bool:
@@ -362,9 +388,8 @@
             dirpath: Parent directory from which to get the latest snapshot.
             unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
             train_dataloader: An optional train dataloader to restore.
-            restore_train_progress: Whether to restore the training progress state.
-            restore_eval_progress: Whether to restore the evaluation progress state.
             process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world)
+            restore_options: Controls what to filter when restoring the state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.
             knob_options: Additional keyword options for the snapshot knobs
 
@@ -379,9 +404,8 @@
             path,
             unit,
             train_dataloader=train_dataloader,
-            restore_train_progress=restore_train_progress,
-            restore_eval_progress=restore_eval_progress,
             process_group=process_group,
+            restore_options=restore_options,
             storage_options=storage_options,
             knob_options=knob_options,
         )
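For reviewers, a minimal usage sketch of the `RestoreOptions` API this diff introduces. The helper name, snapshot path, and fine-tuning scenario are illustrative assumptions; only `RestoreOptions`, its four fields, and the `TorchSnapshotSaver.restore` static method come from the diff itself.

```python
from torchtnt.framework.callbacks.torchsnapshot_saver import (
    RestoreOptions,
    TorchSnapshotSaver,
)


def restore_weights_only(path: str, unit) -> None:
    """Hypothetical helper: restore module weights from a snapshot while
    skipping progress counters, optimizer state, and LR-scheduler state,
    e.g. when warm-starting fine-tuning from a pretrained checkpoint.

    `unit` is any AppStateMixin (TrainUnit, EvalUnit, or PredictUnit).
    """
    options = RestoreOptions(
        restore_train_progress=False,  # start step/epoch counters from zero
        restore_eval_progress=False,   # likewise for eval progress
        restore_optimizers=False,      # keep freshly initialized optimizer state
        restore_lr_schedulers=False,   # keep freshly initialized scheduler state
    )
    TorchSnapshotSaver.restore(path, unit, restore_options=options)
```

Omitting `restore_options` is equivalent to passing `RestoreOptions()` with every field `True`, so existing callers keep the previous behavior; the removed `restore_train_progress`/`restore_eval_progress` keyword arguments map onto the first two fields.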