From 0f8bd4e635f64e0f1b359f75efd36eb65c92a3f6 Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Mon, 13 Nov 2023 11:51:48 -0800 Subject: [PATCH] add keep_latest_n checkpoint arg in TSS Reviewed By: galrotem Differential Revision: D51048523 --- .../callbacks/test_torchsnapshot_saver.py | 171 ++++++++++++++++++ .../callbacks/torchsnapshot_saver.py | 110 ++++++++++- 2 files changed, 277 insertions(+), 4 deletions(-) diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py index 18c871e439..7f214bdaf0 100644 --- a/tests/framework/callbacks/test_torchsnapshot_saver.py +++ b/tests/framework/callbacks/test_torchsnapshot_saver.py @@ -20,6 +20,7 @@ from torch import nn from torch.distributed import launcher from torch.utils.data import DataLoader +from torchsnapshot import Snapshot from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq @@ -33,7 +34,9 @@ ) from torchtnt.framework.callbacks.lambda_callback import Lambda from torchtnt.framework.callbacks.torchsnapshot_saver import ( + _delete_snapshot, _override_knobs, + _retrieve_checkpoint_dirpaths, get_latest_checkpoint_path, KnobOptions, RestoreOptions, @@ -669,6 +672,174 @@ def test_knob_override(self) -> None: with _override_knobs(KnobOptions(max_per_rank_io_concurrency=None)): self.assertNotIn(env_var, os.environ) + @patch("torchtnt.framework.callbacks.torchsnapshot_saver.get_filesystem") + def test_retrieve_checkpoint_dirpaths(self, mock_get_filesystem: MagicMock) -> None: + """ + Tests retrieving checkpoint directories from a given root directory + """ + paths = [ + {"name": "tmp/epoch_0_step_10", "type": "directory"}, + {"name": "tmp/epoch_1_step_10", "type": "directory"}, + {"name": "tmp/epoch_2_step_10", "type": "directory"}, + {"name": "tmp/epoch_0_step_5", "type": "directory"}, + {"name": "tmp/epoch_0_step_3", "type": "file"}, + ] + + mock_get_filesystem.return_value.ls.return_value = paths + returned_paths = _retrieve_checkpoint_dirpaths("foo") + self.assertEqual( + returned_paths, + [ + "tmp/epoch_0_step_5", + "tmp/epoch_0_step_10", + "tmp/epoch_1_step_10", + "tmp/epoch_2_step_10", + ], + ) + + def test_delete_snapshot(self) -> None: + """ + Tests removing checkpoint directories + """ + app_state = {"module": nn.Linear(2, 2)} + with tempfile.TemporaryDirectory() as temp_dir: + dirpath = os.path.join(temp_dir, "checkpoint") + Snapshot.take(dirpath, app_state=app_state) + self.assertTrue(os.path.exists(dirpath)) + # check that error is thrown if .snapshot_metadata is not found in the directory when deleting + with self.assertRaisesRegex( + RuntimeError, f"{temp_dir} does not contain .snapshot_metadata" + ): + _delete_snapshot(temp_dir) + _delete_snapshot(dirpath) + self.assertFalse(os.path.exists(dirpath)) + + def test_should_remove_snapshot(self) -> None: + """ + Tests the helper function that checks if snapshot should be removed or not + """ + tss = TorchSnapshotSaver("temp") + + # keep_last_n_checkpoints is toggled off + self.assertFalse(tss._should_remove_snapshot()) + + # not enough checkpoints are saved yet to be removed + tss._keep_last_n_checkpoints = 2 + tss._ckpt_dirpaths = ["bar"] + self.assertFalse(tss._should_remove_snapshot()) + + # enough checkpoints are there to remove + tss._keep_last_n_checkpoints = 2 + tss._ckpt_dirpaths = ["foo", "bar"] + self.assertTrue(tss._should_remove_snapshot()) + + @patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_snapshot") + def test_remove_snapshot(self, mock_delete_snapshot: MagicMock) -> None: + """ + Tests the helper function that removes snapshots and updates the checkpoint paths + """ + state = get_dummy_train_state() + tss = TorchSnapshotSaver("temp") + tss._ckpt_dirpaths = ["foo", "bar"] + tss._remove_snapshot(state) + + mock_delete_snapshot.assert_called_once() + self.assertEqual(len(tss._ckpt_dirpaths), 1) + self.assertEqual(tss._ckpt_dirpaths[0], "bar") + + @patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_snapshot") + def test_cleanup_surplus(self, mock_delete_snapshot: MagicMock) -> None: + """ + Tests surplus of checkpoints being cleaned up + """ + state = get_dummy_train_state() + unit = DummyTrainUnit(input_dim=2) + warning_messages = [] + with tempfile.TemporaryDirectory() as temp_dir: + tss = TorchSnapshotSaver(temp_dir, keep_last_n_checkpoints=1) + tss._ckpt_dirpaths = ["foo", "bar", "baz"] + + expected_warning_msg = " ".join( + [ + f"3 checkpoints found in {temp_dir}.", + f"Deleting {2} oldest", + "checkpoints to enforce ``keep_last_n_checkpoints`` argument.", + ] + ) + + with patch( + "torchtnt.framework.callbacks.torchsnapshot_saver.logging.Logger.warning", + warning_messages.append, + ): + tss.on_train_start(state, unit) + self.assertEqual(tss._ckpt_dirpaths, ["baz"]) + self.assertEqual(warning_messages[0], expected_warning_msg) + + tss = TorchSnapshotSaver(temp_dir) + tss._ckpt_dirpaths = ["foo", "bar", "baz"] + + tss.on_train_start(state, unit) + self.assertEqual(tss._ckpt_dirpaths, ["foo", "bar", "baz"]) + + def test_keep_last_n_checkpoints(self) -> None: + """ + Tests removing checkpoint directories + """ + unit = DummyTrainUnit(input_dim=2) + state = get_dummy_train_state() + with tempfile.TemporaryDirectory() as temp_dir: + tss = TorchSnapshotSaver( + temp_dir, + save_every_n_train_steps=1, + keep_last_n_checkpoints=2, + ) + + # take 10 steps + for _ in range(10): + unit.train_progress.increment_step() + tss.on_train_step_end(state, unit) + # TODO remove time.sleep to avoid potential flaky test + time.sleep(0.1) # sleep to ensure enough time to checkpoint + + dirs = os.listdir(temp_dir) + self.assertEqual(len(dirs), 2) + self.assertIn("epoch_0_step_9", dirs) + self.assertIn("epoch_0_step_10", dirs) + + def test_keep_last_n_checkpoints_e2e(self) -> None: + """ + Tests removing checkpoint directories e2e + """ + input_dim = 2 + dataset_len = 10 + batch_size = 2 + max_epochs = 2 + + my_unit = DummyTrainUnit(input_dim=input_dim) + dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size) + with tempfile.TemporaryDirectory() as temp_dir: + snapshot_cb = TorchSnapshotSaver( + temp_dir, + save_every_n_train_steps=2, + keep_last_n_checkpoints=1, + ) + # Artificially increase the step duration, otherwise torchsnapshot + # doesn't have the time to save all snapshots and will skip some. + slowdown = Lambda(on_train_step_end=lambda *_: time.sleep(0.1)) + + train( + my_unit, + dataloader, + max_epochs=max_epochs, + callbacks=[snapshot_cb, slowdown], + ) + dirs = os.listdir(temp_dir) + self.assertEqual(len(dirs), 1) + self.assertIn( + f"epoch_{max_epochs}_step_{dataset_len // batch_size * max_epochs}", + os.listdir(temp_dir), + ) + class DummyStatefulDataLoader: def __init__(self, dataloader: DataLoader) -> None: diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index 2caab8064f..89f4857b1d 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -83,6 +83,7 @@ class TorchSnapshotSaver(Callback): dirpath: Parent directory to save snapshots to. save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated. save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated. + keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If an excess of existing checkpoints are present, the oldest ones will be deleted to clean the difference. process_group: The process group on which the ranks will communicate on. default: ``None`` (the entire world) replicated: A glob-pattern of replicated key names that indicate which application state entries have the same state across all processes. For more information, see https://pytorch.org/torchsnapshot/main/api_reference.html#torchsnapshot.Snapshot.take. @@ -108,6 +109,7 @@ def __init__( *, save_every_n_train_steps: Optional[int] = None, save_every_n_epochs: Optional[int] = None, + keep_last_n_checkpoints: Optional[int] = None, process_group: Optional[dist.ProcessGroup] = None, replicated: Optional[List[str]] = None, storage_options: Optional[Dict[str, Any]] = None, @@ -122,9 +124,20 @@ def __init__( raise ValueError( f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}" ) + if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0: + raise ValueError( + f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}" + ) self._save_every_n_epochs = save_every_n_epochs self._save_every_n_train_steps = save_every_n_train_steps + + self._keep_last_n_checkpoints = keep_last_n_checkpoints + self._ckpt_dirpaths: List[str] = [] + if self._keep_last_n_checkpoints: + self._ckpt_dirpaths = _retrieve_checkpoint_dirpaths(dirpath) + self._process_group = process_group + self._pg_wrapper = PGWrapper(process_group) self._sync_dirpath_to_all_ranks(dirpath) self._replicated: Set[str] = set(replicated or []) @@ -156,6 +169,24 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None: app_state = _app_state(unit) _check_app_state_collision(app_state) + # clean up the difference if surplus of checkpoints exist + keep_last_n_checkpoints = self._keep_last_n_checkpoints + if ( + keep_last_n_checkpoints + and len(self._ckpt_dirpaths) > keep_last_n_checkpoints + ): + logger.warning( + " ".join( + [ + f"{len(self._ckpt_dirpaths)} checkpoints found in {self._dirpath}.", + f"Deleting {len(self._ckpt_dirpaths) - keep_last_n_checkpoints} oldest", + "checkpoints to enforce ``keep_last_n_checkpoints`` argument.", + ] + ) + ) + for _ in range(len(self._ckpt_dirpaths) - keep_last_n_checkpoints): + self._remove_snapshot(state) + def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: num_steps_completed = unit.train_progress.num_steps_completed save_every_n_train_steps = self._save_every_n_train_steps @@ -177,7 +208,14 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: with get_timing_context( state, f"{self.__class__.__name__}.take_async_snapshot" ): - self._async_snapshot(snapshot_path, app_state, wait=False) + checkpoint_success = self._async_snapshot( + snapshot_path, app_state, wait=False + ) + + if checkpoint_success: + if self._should_remove_snapshot(): + self._remove_snapshot(state) + self._ckpt_dirpaths.append(snapshot_path) def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: epoch = unit.train_progress.num_epochs_completed @@ -197,7 +235,14 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: with get_timing_context( state, f"{self.__class__.__name__}.take_async_snapshot" ): - self._async_snapshot(snapshot_path, app_state, wait=True) + checkpoint_success = self._async_snapshot( + snapshot_path, app_state, wait=True + ) + + if checkpoint_success: + if self._should_remove_snapshot(): + self._remove_snapshot(state) + self._ckpt_dirpaths.append(snapshot_path) def on_train_end(self, state: State, unit: TTrainUnit) -> None: app_state = _get_app_state(state, unit, self._replicated, intra_epoch=False) @@ -213,9 +258,20 @@ def on_train_end(self, state: State, unit: TTrainUnit) -> None: with get_timing_context( state, f"{self.__class__.__name__}.take_async_snapshot" ): - self._async_snapshot(snapshot_path, app_state, wait=True) + # TODO checkpoint is not truly successful + # since this is async checkpointed, so in + # future, add logic to set successful flag + # only when checkpoint is fully written + checkpoint_success = self._async_snapshot( + snapshot_path, app_state, wait=True + ) self._wait() + if checkpoint_success: + if self._should_remove_snapshot(): + self._remove_snapshot(state) + self._ckpt_dirpaths.append(snapshot_path) + def on_exception( self, state: State, @@ -224,6 +280,22 @@ def on_exception( ) -> None: self._wait() + def _should_remove_snapshot(self) -> bool: + keep_last_n_checkpoints = self._keep_last_n_checkpoints + return ( + keep_last_n_checkpoints is not None + and len(self._ckpt_dirpaths) >= keep_last_n_checkpoints + ) + + def _remove_snapshot(self, state: State) -> None: + # remove oldest snapshot directory + oldest_ckpt_path = self._ckpt_dirpaths.pop(0) + with get_timing_context(state, f"{self.__class__.__name__}.delete_snapshot"): + if self._pg_wrapper.get_rank() == 0: + # only delete on rank 0 + _delete_snapshot(oldest_ckpt_path) + self._pg_wrapper.barrier() + def _wait(self) -> None: if self._prev_snapshot is not None: self._prev_snapshot.wait() @@ -236,7 +308,7 @@ def _async_snapshot( if prev_snapshot.path == snapshot_path: # Snapshot for this step already has been saved. # This can happen if we call _async_snapshot twice at the same step. - return True + return False still_pending = not prev_snapshot.done() if still_pending and wait: prev_snapshot.wait() @@ -563,3 +635,33 @@ def _override_knobs( for mgr in knobs: stack.enter_context(mgr) yield + + +def _retrieve_checkpoint_dirpaths(dirpath: str) -> List[str]: + """ + Given a parent directory where checkpoints are saved, return the sorted checkpoint subdirectories + from oldest to newest. + + Args: + dirpath: parent directory where checkpoints are saved. + """ + fs = get_filesystem(dirpath) + + contents = fs.ls(dirpath, detail=True) + contents = [item["name"] for item in contents if item["type"] == "directory"] + ckpt_dirpaths = [] + for path in contents: + match = re.search(r"epoch_(\d+)_step_(\d+)", path) + if match: + ckpt_dirpaths.append(path) + + # sorts by epoch, then step + ckpt_dirpaths.sort(key=lambda x: (int(x.split("_")[1]), int(x.split("_")[3]))) + return ckpt_dirpaths + + +def _delete_snapshot(dirpath: str) -> None: + fs = get_filesystem(dirpath) + if not fs.exists(os.path.join(dirpath, ".snapshot_metadata")): + raise RuntimeError(f"{dirpath} does not contain .snapshot_metadata") + fs.rm(dirpath, recursive=True)