From 1766dadfc4ee329ca226839bf9f1c619cd667f6b Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Fri, 27 Oct 2023 15:11:24 -0700 Subject: [PATCH] encapsulate snapshot restoration params in a struct Summary: # Context Users may want to opt out of restoring certain parts of the app state (for example, like optimizer and lr scheduler states when finetuning). This is currently not supported in torchsnapshot saver # This Diff * Adds `RestoreOptions` dataclass to encapsulate all restoration params * moves `restore_train_progress`, `restore_eval_progress` to the struct * Replaces all `restore` apis to take `RestoreOptions` struct, replacing the `restore_train_progress`, `restore_eval_progress` * Adds optimizer and lr_scheduler to `RestoreOptions` Differential Revision: D50757494 --- .../callbacks/test_torchsnapshot_saver.py | 38 ++++++++++++++- .../callbacks/torchsnapshot_saver.py | 48 ++++++++++++++----- 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py index e48b51e975..2714620711 100644 --- a/tests/framework/callbacks/test_torchsnapshot_saver.py +++ b/tests/framework/callbacks/test_torchsnapshot_saver.py @@ -13,10 +13,11 @@ import unittest from typing import Any, Dict, Iterable, List from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import torch import torch.distributed as dist +from torch import nn from torch.distributed import launcher from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq @@ -34,6 +35,7 @@ _override_knobs, get_latest_checkpoint_path, KnobOptions, + RestoreOptions, TorchSnapshotSaver, ) from torchtnt.framework.train import train @@ -313,12 +315,44 @@ def test_save_restore_no_train_progress(self) -> None: end_num_steps_completed = my_unit.train_progress.num_steps_completed self.assertGreater(len(expected_paths), 0) snapshot_cb.restore( - expected_paths[0], my_unit, restore_train_progress=False + expected_paths[0], + my_unit, + restore_options=RestoreOptions(restore_train_progress=False), ) restored_num_steps_completed = my_unit.train_progress.num_steps_completed # no train progress was restored so the progress after restoration should be the same as the progress before restoration self.assertEqual(restored_num_steps_completed, end_num_steps_completed) + @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot") + def test_save_restore_no_optimizer_restore( + self, mock_torchsnapshot: MagicMock + ) -> None: + my_unit = DummyTrainUnit(input_dim=2) + restore_options = RestoreOptions(restore_optimizers=False) + TorchSnapshotSaver.restore( + path="path/to/snapshot", unit=my_unit, restore_options=restore_options + ) + app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0] + self.assertNotIn("optimizer", app_state) + TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit) + app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0] + self.assertIn("optimizer", app_state) + + @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot") + def test_save_restore_no_lr_scheduler_restore( + self, mock_torchsnapshot: MagicMock + ) -> None: + my_unit = DummyAutoUnit(module=nn.Linear(2, 3)) + restore_options = RestoreOptions(restore_lr_schedulers=False) + TorchSnapshotSaver.restore( + path="path/to/snapshot", unit=my_unit, restore_options=restore_options + ) + app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0] + self.assertNotIn("lr_scheduler", app_state) + TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit) + app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0] + self.assertIn("lr_scheduler", app_state) + def test_save_on_train_end(self) -> None: input_dim = 2 dataset_len = 10 diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index 9ff78243ef..7f860c93f8 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -78,6 +78,24 @@ class KnobOptions: max_per_rank_io_concurrency: Optional[int] = None +@dataclass +class RestoreOptions: + """ + Options when restoring a snapshot. + + Args: + restore_train_progress: Whether to restore the training progress state. + restore_eval_progress: Whether to restore the evaluation progress state. + restore_optimizers: Whether to restore the optimizer states. + restore_lr_schedulers: Whether to restore the lr scheduler states. + """ + + restore_train_progress: bool = True + restore_eval_progress: bool = True + restore_optimizers: bool = True + restore_lr_schedulers: bool = True + + class TorchSnapshotSaver(Callback): """ A callback which periodically saves the application state during training using `TorchSnapshot `_. @@ -274,9 +292,8 @@ def restore( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, - restore_train_progress: bool = True, - restore_eval_progress: bool = True, process_group: Optional[dist.ProcessGroup] = None, + restore_options: Optional[RestoreOptions] = None, storage_options: Optional[Dict[str, Any]] = None, knob_options: Optional[KnobOptions] = None, ) -> None: @@ -289,9 +306,8 @@ def restore( path: Path of the snapshot to restore. unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. train_dataloader: An optional train dataloader to restore. - restore_train_progress: Whether to restore the training progress state. - restore_eval_progress: Whether to restore the evaluation progress state. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) + restore_options: Controls what to filter when restoring the state. storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations. knob_options: Additional keyword options for the snapshot knobs """ @@ -312,12 +328,23 @@ def restore( rng_state = torchsnapshot.RNGState() app_state[_RNG_STATE_KEY] = rng_state - if not restore_train_progress: + restore_options = restore_options or RestoreOptions() + if not restore_options.restore_train_progress: app_state.pop(_TRAIN_PROGRESS_STATE_KEY, None) - if not restore_eval_progress: + if not restore_options.restore_eval_progress: app_state.pop(_EVAL_PROGRESS_STATE_KEY, None) + if not restore_options.restore_optimizers: + # remove all optimizer keys from app_state + for optim_keys in unit.tracked_optimizers().keys(): + app_state.pop(optim_keys, None) + + if not restore_options.restore_lr_schedulers: + # remove all lr scheduler keys from app_state + for lr_scheduler_keys in unit.tracked_lr_schedulers().keys(): + app_state.pop(lr_scheduler_keys, None) + if train_dataloader is not None: if not isinstance(train_dataloader, _TStateful): rank_zero_warn( @@ -346,9 +373,8 @@ def restore_from_latest( unit: AppStateMixin, *, train_dataloader: Optional[Iterable[TTrainData]] = None, - restore_train_progress: bool = True, - restore_eval_progress: bool = True, process_group: Optional[dist.ProcessGroup] = None, + restore_options: Optional[RestoreOptions] = None, storage_options: Optional[Dict[str, Any]] = None, knob_options: Optional[KnobOptions] = None, ) -> bool: @@ -362,9 +388,8 @@ def restore_from_latest( dirpath: Parent directory from which to get the latest snapshot. unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore. train_dataloader: An optional train dataloader to restore. - restore_train_progress: Whether to restore the training progress state. - restore_eval_progress: Whether to restore the evaluation progress state. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) + restore_options: Controls what to filter when restoring the state. storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations. knob_options: Additional keyword options for the snapshot knobs @@ -379,9 +404,8 @@ def restore_from_latest( path, unit, train_dataloader=train_dataloader, - restore_train_progress=restore_train_progress, - restore_eval_progress=restore_eval_progress, process_group=process_group, + restore_options=restore_options, storage_options=storage_options, knob_options=knob_options, )