diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index 2e277b1a05..0d54fed8e2 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -21,7 +21,11 @@
 from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
 from torchtnt.framework.auto_unit import AutoUnit
-from torchtnt.framework.callbacks import Lambda, TorchSnapshotSaver
+from torchtnt.framework.callbacks.lambda_callback import Lambda
+from torchtnt.framework.callbacks.torchsnapshot_saver import (
+    _get_latest_checkpoint_path,
+    TorchSnapshotSaver,
+)
 from torchtnt.framework.state import State
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, PGWrapper
@@ -59,7 +63,6 @@ def test_save_every_n_train_steps(self) -> None:
             snapshot = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             # Artificially increase the step duration, otherwise torchsnapshot
             # doesn't have the time to save all snapshots and will skip some.
@@ -91,7 +94,6 @@ def test_save_every_n_train_epochs(self) -> None:
             snapshot = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_epochs=save_every_n_train_epochs,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot])
             self.assertTrue(
@@ -124,7 +126,6 @@ def test_save_restore(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -152,7 +153,6 @@ def test_save_restore_dataloader_state(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(
                 my_unit,
@@ -204,18 +204,18 @@ def test_restore_from_latest(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

             with mock.patch(
                 "torchtnt.framework.callbacks.torchsnapshot_saver.TorchSnapshotSaver.restore"
             ) as mock_restore:
-                snapshot_cb.restore_from_latest(temp_dir, my_unit)
+                restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
                 self.assertIn(
                     temp_dir + f"/epoch_{max_epochs}_step_{expected_steps_per_epoch}",
                     mock_restore.call_args.args,
                 )
+                self.assertTrue(restored)

     def test_restore_from_latest_empty_dir(self) -> None:
         input_dim = 2
@@ -226,17 +226,17 @@ def test_restore_from_latest_empty_dir(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )

             with self.assertLogs(level="WARNING") as log:
-                snapshot_cb.restore_from_latest(temp_dir, my_unit)
+                restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
                 self.assertEqual(
                     log.output,
                     [
                         f"WARNING:torchtnt.framework.callbacks.torchsnapshot_saver:Input dirpath doesn't contain any subdirectories: {temp_dir}"
                     ],
                 )
+                self.assertFalse(restored)

     def test_save_restore_no_train_progress(self) -> None:
         input_dim = 2
@@ -264,7 +264,6 @@ def test_save_restore_no_train_progress(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -293,7 +292,6 @@ def test_save_on_train_end(self) -> None:
             self.assertFalse(os.path.exists(os.path.join(temp_dir, expected_path)))
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -355,6 +353,7 @@ def _save_restore_fsdp() -> None:
         snapshot_cb = TorchSnapshotSaver(
             temp_dir,
             save_every_n_epochs=save_every_n_epochs,
+            replicated=["**"],
         )
         temp_dir = snapshot_cb.dirpath
         train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
@@ -396,14 +395,12 @@ def test_saver_invalid_args(self) -> None:

     def test_latest_checkpoint_path(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
-            self.assertIsNone(TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir))
+            self.assertIsNone(_get_latest_checkpoint_path(temp_dir))

         with tempfile.TemporaryDirectory() as temp_dir:
             latest_path = os.path.join(temp_dir, "epoch_0_step_0")
             os.mkdir(latest_path)
-            self.assertEqual(
-                TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), latest_path
-            )
+            self.assertEqual(_get_latest_checkpoint_path(temp_dir), latest_path)

         with tempfile.TemporaryDirectory() as temp_dir:
             path_1 = os.path.join(temp_dir, "epoch_0_step_0")
@@ -414,9 +411,7 @@ def test_latest_checkpoint_path(self) -> None:
             os.mkdir(path_3)
             path_4 = os.path.join(temp_dir, "epoch_700")
             os.mkdir(path_4)
-            self.assertEqual(
-                TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), path_3
-            )
+            self.assertEqual(_get_latest_checkpoint_path(temp_dir), path_3)

     @unittest.skipUnless(
         torch.distributed.is_available(), reason="Torch distributed is needed to run"
@@ -436,7 +431,7 @@ def _latest_checkpoint_path_distributed() -> None:
             temp_dir = tempfile.mkdtemp()
         else:
             temp_dir = ""
-        tc.assertIsNone(TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir))
+        tc.assertIsNone(_get_latest_checkpoint_path(temp_dir))

         if is_rank0:
             shutil.rmtree(temp_dir)  # delete temp directory
@@ -458,9 +453,7 @@ def _latest_checkpoint_path_distributed() -> None:
         path_container = [path_3] if is_rank0 else [None]
         pg.broadcast_object_list(path_container, 0)
         expected_path = path_container[0]
-        tc.assertEqual(
-            TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), expected_path
-        )
+        tc.assertEqual(_get_latest_checkpoint_path(temp_dir), expected_path)

         if is_rank0:
             shutil.rmtree(temp_dir)  # delete temp directory
diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py
index 7b3d1c61a9..9af602dc7f 100644
--- a/torchtnt/framework/callbacks/torchsnapshot_saver.py
+++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -226,7 +226,6 @@ def restore(
             restore_train_progress: Whether to restore the training progress state.
             restore_eval_progress: Whether to restore the evaluation progress state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.
-
         """
         _validate_snapshot_available()
@@ -268,7 +267,7 @@ def restore_from_latest(
         restore_train_progress: bool = True,
         restore_eval_progress: bool = True,
         storage_options: Optional[Dict[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """
         Given a parent directory where checkpoints are saved, restore the snapshot state from the latest checkpoint in the directory.

@@ -282,10 +281,13 @@ def restore_from_latest(
             restore_train_progress: Whether to restore the training progress state.
             restore_eval_progress: Whether to restore the evaluation progress state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot `_. See each storage plugin's documentation for customizations.
+
+        Returns:
+            True if the latest snapshot directory was found and successfully restored, otherwise False.
         """
-        path = TorchSnapshotSaver.get_latest_checkpoint_path(dirpath)
+        path = _get_latest_checkpoint_path(dirpath)
         if path is None:
-            return
+            return False
         TorchSnapshotSaver.restore(
             path,
             unit,
@@ -294,27 +296,28 @@ def restore_from_latest(
             restore_eval_progress=restore_eval_progress,
             storage_options=storage_options,
         )
+        return True


-    @staticmethod
-    def get_latest_checkpoint_path(dirpath: str) -> Optional[str]:
-        """Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory."""
-        ret = None
-        rank = get_global_rank()
-        # Do all filesystem reads from rank 0 only
-        if rank == 0:
-            ret = _latest_checkpoint_path(dirpath)
+def _get_latest_checkpoint_path(dirpath: str) -> Optional[str]:
+    """Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory."""

-        # If not running in a distributed setting, return as is
-        if not (dist.is_available() and dist.is_initialized()):
-            return ret
-
-        # Otherwise, broadcast result from rank 0 to all ranks
-        pg = PGWrapper(dist.group.WORLD)
-        path_container = [ret] if rank == 0 else [None]
-        pg.broadcast_object_list(path_container, 0)
-        val = path_container[0]
-        return val
+    ret = None
+    rank = get_global_rank()
+    # Do all filesystem reads from rank 0 only
+    if rank == 0:
+        ret = _latest_checkpoint_path(dirpath)
+
+    # If not running in a distributed setting, return as is
+    if not (dist.is_available() and dist.is_initialized()):
+        return ret
+
+    # Otherwise, broadcast result from rank 0 to all ranks
+    pg = PGWrapper(dist.group.WORLD)
+    path_container = [ret] if rank == 0 else [None]
+    pg.broadcast_object_list(path_container, 0)
+    val = path_container[0]
+    return val


 def _latest_checkpoint_path(dirpath: str) -> Optional[str]:
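A minimal usage sketch, not part of the patch above, showing how a caller might act on the boolean that restore_from_latest now returns; the fit_with_resume helper and the unit/dataloader arguments are hypothetical, while TorchSnapshotSaver, restore_from_latest, and train are used exactly as in the tests in this diff.

# Sketch only: assumes `unit` and `dataloader` are supplied by the caller,
# e.g. a TrainUnit subclass and an iterable of batches as in the tests above.
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
from torchtnt.framework.train import train

def fit_with_resume(unit, dataloader, dirpath: str, max_epochs: int) -> None:
    # Save a snapshot at the end of every epoch under `dirpath`.
    snapshot_cb = TorchSnapshotSaver(dirpath, save_every_n_epochs=1)
    # restore_from_latest now reports whether a checkpoint was actually restored,
    # so the caller no longer has to re-scan the directory to find out.
    restored = snapshot_cb.restore_from_latest(dirpath, unit)
    if not restored:
        print(f"no snapshot found under {dirpath}; training from scratch")
    train(unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])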