encapsulate snapshot restoration params in a struct
Summary:
# Context
Users may want to opt out of restoring certain parts of the app state (for example, optimizer and lr scheduler states when finetuning). This is currently not supported by the torchsnapshot saver.

# This Diff
* Adds a `RestoreOptions` dataclass to encapsulate all restoration params
* Moves `restore_train_progress` and `restore_eval_progress` into the struct
* Replaces the separate `restore_train_progress` and `restore_eval_progress` arguments of all `restore` APIs with a single `restore_options: RestoreOptions` argument
* Adds optimizer and lr scheduler flags to `RestoreOptions` (a usage sketch follows this list)
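
As a rough sketch of the resulting API (illustrative only: `my_unit` and the snapshot path are placeholders, and the import path is assumed from the test file below):

```python
from torchtnt.framework.callbacks.torchsnapshot_saver import (
    RestoreOptions,
    TorchSnapshotSaver,
)

# Finetuning example: keep module weights and progress counters, but skip
# the optimizer and lr scheduler states from the pretraining run.
options = RestoreOptions(restore_optimizers=False, restore_lr_schedulers=False)

TorchSnapshotSaver.restore(
    path="path/to/snapshot",  # placeholder snapshot path
    unit=my_unit,  # placeholder: a unit whose state should be restored
    restore_options=options,
)
```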

Differential Revision: D50757494
JKSenthil authored and facebook-github-bot committed Oct 27, 2023
1 parent c8e8ab9 commit 1766dad
Showing 2 changed files with 72 additions and 14 deletions.
38 changes: 36 additions & 2 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -13,10 +13,11 @@
 import unittest
 from typing import Any, Dict, Iterable, List
 from unittest import mock
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import torch
 import torch.distributed as dist
+from torch import nn
 from torch.distributed import launcher
 from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
 from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq
@@ -34,6 +34,7 @@
     _override_knobs,
     get_latest_checkpoint_path,
     KnobOptions,
+    RestoreOptions,
     TorchSnapshotSaver,
 )
 from torchtnt.framework.train import train
@@ -313,12 +315,44 @@ def test_save_restore_no_train_progress(self) -> None:
         end_num_steps_completed = my_unit.train_progress.num_steps_completed
         self.assertGreater(len(expected_paths), 0)
         snapshot_cb.restore(
-            expected_paths[0], my_unit, restore_train_progress=False
+            expected_paths[0],
+            my_unit,
+            restore_options=RestoreOptions(restore_train_progress=False),
         )
         restored_num_steps_completed = my_unit.train_progress.num_steps_completed
         # no train progress was restored so the progress after restoration should be the same as the progress before restoration
         self.assertEqual(restored_num_steps_completed, end_num_steps_completed)
 
+    @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
+    def test_save_restore_no_optimizer_restore(
+        self, mock_torchsnapshot: MagicMock
+    ) -> None:
+        my_unit = DummyTrainUnit(input_dim=2)
+        restore_options = RestoreOptions(restore_optimizers=False)
+        TorchSnapshotSaver.restore(
+            path="path/to/snapshot", unit=my_unit, restore_options=restore_options
+        )
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertNotIn("optimizer", app_state)
+        TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit)
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertIn("optimizer", app_state)
+
+    @patch("torchtnt.framework.callbacks.torchsnapshot_saver.torchsnapshot")
+    def test_save_restore_no_lr_scheduler_restore(
+        self, mock_torchsnapshot: MagicMock
+    ) -> None:
+        my_unit = DummyAutoUnit(module=nn.Linear(2, 3))
+        restore_options = RestoreOptions(restore_lr_schedulers=False)
+        TorchSnapshotSaver.restore(
+            path="path/to/snapshot", unit=my_unit, restore_options=restore_options
+        )
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertNotIn("lr_scheduler", app_state)
+        TorchSnapshotSaver.restore(path="path/to/snapshot", unit=my_unit)
+        app_state = mock_torchsnapshot.Snapshot().restore.call_args.args[0]
+        self.assertIn("lr_scheduler", app_state)
+
     def test_save_on_train_end(self) -> None:
         input_dim = 2
         dataset_len = 10
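
An aside on the test pattern above: because the tests patch the module-level `torchsnapshot` import, every `mock_torchsnapshot.Snapshot()` call returns the same child `MagicMock`, so the positional `app_state` dict passed to `restore` can be inspected via `call_args`. A minimal standalone illustration of that `MagicMock` behavior (names here are illustrative only):

```python
from unittest.mock import MagicMock

mock_mod = MagicMock()
mock_mod.Snapshot().restore({"optimizer": 1})  # call through the mock
# Snapshot() returns the same child mock each time, so call_args is visible here
app_state = mock_mod.Snapshot().restore.call_args.args[0]
assert "optimizer" in app_state
```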
48 changes: 36 additions & 12 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -78,6 +78,24 @@ class KnobOptions:
     max_per_rank_io_concurrency: Optional[int] = None
 
 
+@dataclass
+class RestoreOptions:
+    """
+    Options when restoring a snapshot.
+
+    Args:
+        restore_train_progress: Whether to restore the training progress state.
+        restore_eval_progress: Whether to restore the evaluation progress state.
+        restore_optimizers: Whether to restore the optimizer states.
+        restore_lr_schedulers: Whether to restore the lr scheduler states.
+    """
+
+    restore_train_progress: bool = True
+    restore_eval_progress: bool = True
+    restore_optimizers: bool = True
+    restore_lr_schedulers: bool = True
+
+
 class TorchSnapshotSaver(Callback):
     """
     A callback which periodically saves the application state during training using `TorchSnapshot <https://pytorch.org/torchsnapshot/>`_.
@@ -274,9 +292,8 @@ def restore(
         unit: AppStateMixin,
         *,
         train_dataloader: Optional[Iterable[TTrainData]] = None,
-        restore_train_progress: bool = True,
-        restore_eval_progress: bool = True,
         process_group: Optional[dist.ProcessGroup] = None,
+        restore_options: Optional[RestoreOptions] = None,
         storage_options: Optional[Dict[str, Any]] = None,
         knob_options: Optional[KnobOptions] = None,
     ) -> None:
@@ -289,9 +306,8 @@ def restore(
             path: Path of the snapshot to restore.
             unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
             train_dataloader: An optional train dataloader to restore.
-            restore_train_progress: Whether to restore the training progress state.
-            restore_eval_progress: Whether to restore the evaluation progress state.
             process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
+            restore_options: Controls what to filter when restoring the state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
             knob_options: Additional keyword options for the snapshot knobs
         """
@@ -312,12 +328,23 @@
         rng_state = torchsnapshot.RNGState()
         app_state[_RNG_STATE_KEY] = rng_state
 
-        if not restore_train_progress:
+        restore_options = restore_options or RestoreOptions()
+        if not restore_options.restore_train_progress:
             app_state.pop(_TRAIN_PROGRESS_STATE_KEY, None)
 
-        if not restore_eval_progress:
+        if not restore_options.restore_eval_progress:
             app_state.pop(_EVAL_PROGRESS_STATE_KEY, None)
 
+        if not restore_options.restore_optimizers:
+            # remove all optimizer keys from app_state
+            for optim_key in unit.tracked_optimizers().keys():
+                app_state.pop(optim_key, None)
+
+        if not restore_options.restore_lr_schedulers:
+            # remove all lr scheduler keys from app_state
+            for lr_scheduler_key in unit.tracked_lr_schedulers().keys():
+                app_state.pop(lr_scheduler_key, None)
+
         if train_dataloader is not None:
             if not isinstance(train_dataloader, _TStateful):
                 rank_zero_warn(
@@ -346,9 +373,8 @@ def restore_from_latest(
         unit: AppStateMixin,
         *,
         train_dataloader: Optional[Iterable[TTrainData]] = None,
-        restore_train_progress: bool = True,
-        restore_eval_progress: bool = True,
         process_group: Optional[dist.ProcessGroup] = None,
+        restore_options: Optional[RestoreOptions] = None,
         storage_options: Optional[Dict[str, Any]] = None,
         knob_options: Optional[KnobOptions] = None,
     ) -> bool:
@@ -362,9 +388,8 @@
             dirpath: Parent directory from which to get the latest snapshot.
             unit: An instance of :class:`~torchtnt.framework.unit.TrainUnit`, :class:`~torchtnt.framework.unit.EvalUnit`, or :class:`~torchtnt.framework.unit.PredictUnit` containing states to restore.
             train_dataloader: An optional train dataloader to restore.
-            restore_train_progress: Whether to restore the training progress state.
-            restore_eval_progress: Whether to restore the evaluation progress state.
             process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
+            restore_options: Controls what to filter when restoring the state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
             knob_options: Additional keyword options for the snapshot knobs
@@ -379,9 +404,8 @@
             path,
             unit,
             train_dataloader=train_dataloader,
-            restore_train_progress=restore_train_progress,
-            restore_eval_progress=restore_eval_progress,
             process_group=process_group,
+            restore_options=restore_options,
             storage_options=storage_options,
             knob_options=knob_options,
         )
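
For the companion entry point, a minimal sketch under the same assumptions (placeholder directory path and unit; per the docstring above, `restore_from_latest` returns whether a snapshot was found and restored):

```python
was_restored = TorchSnapshotSaver.restore_from_latest(
    dirpath="path/to/snapshot_dir",  # placeholder parent directory
    unit=my_unit,  # placeholder unit instance
    restore_options=RestoreOptions(restore_eval_progress=False),
)
if not was_restored:
    print("no snapshot found under dirpath; starting from scratch")
```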
