Return bool from snapshot restore_from_latest (#498)
Summary:
Pull Request resolved: #498

As the title says, restore_from_latest now returns a bool indicating to callers whether the states were restored. This allows callers to implement logic like:

```python
restored = TorchSnapshotSaver.restore_from_latest(...)

# no prior checkpoints, so initialize weights for the first attempt
if not restored:
    ...  # <initialization logic>
```
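
For a fuller picture, here is a sketch of how the returned flag can be wired into a small training script. It reuses the DummyTrainUnit and generate_random_dataloader helpers imported by the test file in this diff; their exact argument lists, and the placement of the initialization step, are illustrative assumptions rather than part of this change:

```python
import tempfile

from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.callbacks import TorchSnapshotSaver
from torchtnt.framework.train import train

# Toy unit and dataloader; argument values mirror how the tests below construct them.
my_unit = DummyTrainUnit(2)
dataloader = generate_random_dataloader(10, 2, 2)

with tempfile.TemporaryDirectory() as temp_dir:
    snapshot_cb = TorchSnapshotSaver(temp_dir, save_every_n_train_steps=2)

    # restore_from_latest now reports whether a prior snapshot was found and restored.
    restored = TorchSnapshotSaver.restore_from_latest(temp_dir, my_unit)
    if not restored:
        # No prior checkpoints: run one-time weight initialization here.
        pass

    train(my_unit, dataloader, max_epochs=1, callbacks=[snapshot_cb])
```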

Reviewed By: daniellepintz

Differential Revision: D48207346

fbshipit-source-id: 063640f85d6d98aa46c8759610284c723cb8a91e
ananthsub authored and facebook-github-bot committed Aug 10, 2023
1 parent 84219f0 commit ca175c5
Showing 2 changed files with 40 additions and 44 deletions.
37 changes: 15 additions & 22 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -21,7 +21,11 @@

from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.framework.auto_unit import AutoUnit
from torchtnt.framework.callbacks import Lambda, TorchSnapshotSaver
from torchtnt.framework.callbacks.lambda_callback import Lambda
from torchtnt.framework.callbacks.torchsnapshot_saver import (
_get_latest_checkpoint_path,
TorchSnapshotSaver,
)
from torchtnt.framework.state import State
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, PGWrapper
@@ -59,7 +63,6 @@ def test_save_every_n_train_steps(self) -> None:
snapshot = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=save_every_n_train_steps,
replicated=["**"],
)
# Artificially increase the step duration, otherwise torchsnapshot
# doesn't have the time to save all snapshots and will skip some.
@@ -91,7 +94,6 @@ def test_save_every_n_train_epochs(self) -> None:
snapshot = TorchSnapshotSaver(
temp_dir,
save_every_n_epochs=save_every_n_train_epochs,
replicated=["**"],
)
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot])
self.assertTrue(
@@ -124,7 +126,6 @@ def test_save_restore(self) -> None:
snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=save_every_n_train_steps,
replicated=["**"],
)
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -152,7 +153,6 @@ def test_save_restore_dataloader_state(self) -> None:
snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=save_every_n_train_steps,
replicated=["**"],
)
train(
my_unit,
@@ -204,18 +204,18 @@ def test_restore_from_latest(self) -> None:
snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=save_every_n_train_steps,
replicated=["**"],
)
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

with mock.patch(
"torchtnt.framework.callbacks.torchsnapshot_saver.TorchSnapshotSaver.restore"
) as mock_restore:
snapshot_cb.restore_from_latest(temp_dir, my_unit)
restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
self.assertIn(
temp_dir + f"/epoch_{max_epochs}_step_{expected_steps_per_epoch}",
mock_restore.call_args.args,
)
self.assertTrue(restored)

def test_restore_from_latest_empty_dir(self) -> None:
input_dim = 2
@@ -226,17 +226,17 @@ def test_restore_from_latest_empty_dir(self) -> None:
snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=save_every_n_train_steps,
replicated=["**"],
)

with self.assertLogs(level="WARNING") as log:
snapshot_cb.restore_from_latest(temp_dir, my_unit)
restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
self.assertEqual(
log.output,
[
f"WARNING:torchtnt.framework.callbacks.torchsnapshot_saver:Input dirpath doesn't contain any subdirectories: {temp_dir}"
],
)
self.assertFalse(restored)

def test_save_restore_no_train_progress(self) -> None:
input_dim = 2
@@ -264,7 +264,6 @@ def test_save_restore_no_train_progress(self) -> None:
snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=save_every_n_train_steps,
replicated=["**"],
)
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -293,7 +292,6 @@ def test_save_on_train_end(self) -> None:
self.assertFalse(os.path.exists(os.path.join(temp_dir, expected_path)))
snapshot_cb = TorchSnapshotSaver(
temp_dir,
replicated=["**"],
)
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -355,6 +353,7 @@ def _save_restore_fsdp() -> None:
snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_epochs=save_every_n_epochs,
replicated=["**"],
)
temp_dir = snapshot_cb.dirpath
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
@@ -396,14 +395,12 @@ def test_saver_invalid_args(self) -> None:

def test_latest_checkpoint_path(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
self.assertIsNone(TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir))
self.assertIsNone(_get_latest_checkpoint_path(temp_dir))

with tempfile.TemporaryDirectory() as temp_dir:
latest_path = os.path.join(temp_dir, "epoch_0_step_0")
os.mkdir(latest_path)
self.assertEqual(
TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), latest_path
)
self.assertEqual(_get_latest_checkpoint_path(temp_dir), latest_path)

with tempfile.TemporaryDirectory() as temp_dir:
path_1 = os.path.join(temp_dir, "epoch_0_step_0")
@@ -414,9 +411,7 @@ def test_latest_checkpoint_path(self) -> None:
os.mkdir(path_3)
path_4 = os.path.join(temp_dir, "epoch_700")
os.mkdir(path_4)
self.assertEqual(
TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), path_3
)
self.assertEqual(_get_latest_checkpoint_path(temp_dir), path_3)

@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
@@ -436,7 +431,7 @@ def _latest_checkpoint_path_distributed() -> None:
temp_dir = tempfile.mkdtemp()
else:
temp_dir = ""
tc.assertIsNone(TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir))
tc.assertIsNone(_get_latest_checkpoint_path(temp_dir))
if is_rank0:
shutil.rmtree(temp_dir) # delete temp directory

@@ -458,9 +453,7 @@ def _latest_checkpoint_path_distributed() -> None:
path_container = [path_3] if is_rank0 else [None]
pg.broadcast_object_list(path_container, 0)
expected_path = path_container[0]
tc.assertEqual(
TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), expected_path
)
tc.assertEqual(_get_latest_checkpoint_path(temp_dir), expected_path)

if is_rank0:
shutil.rmtree(temp_dir) # delete temp directory
47 changes: 25 additions & 22 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -226,7 +226,6 @@ def restore(
restore_train_progress: Whether to restore the training progress state.
restore_eval_progress: Whether to restore the evaluation progress state.
storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
"""

_validate_snapshot_available()
@@ -268,7 +267,7 @@ def restore_from_latest(
restore_train_progress: bool = True,
restore_eval_progress: bool = True,
storage_options: Optional[Dict[str, Any]] = None,
) -> None:
) -> bool:
"""
Given a parent directory where checkpoints are saved, restore the snapshot state from the latest checkpoint in the directory.
@@ -282,10 +281,13 @@
restore_train_progress: Whether to restore the training progress state.
restore_eval_progress: Whether to restore the evaluation progress state.
storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
Returns:
True if the latest snapshot directory was found and successfully restored, otherwise False.
"""
path = TorchSnapshotSaver.get_latest_checkpoint_path(dirpath)
path = _get_latest_checkpoint_path(dirpath)
if path is None:
return
return False
TorchSnapshotSaver.restore(
path,
unit,
@@ -294,27 +296,28 @@
restore_eval_progress=restore_eval_progress,
storage_options=storage_options,
)
return True

@staticmethod
def get_latest_checkpoint_path(dirpath: str) -> Optional[str]:
"""Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory."""

ret = None
rank = get_global_rank()
# Do all filesystem reads from rank 0 only
if rank == 0:
ret = _latest_checkpoint_path(dirpath)
def _get_latest_checkpoint_path(dirpath: str) -> Optional[str]:
"""Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory."""

# If not running in a distributed setting, return as is
if not (dist.is_available() and dist.is_initialized()):
return ret

# Otherwise, broadcast result from rank 0 to all ranks
pg = PGWrapper(dist.group.WORLD)
path_container = [ret] if rank == 0 else [None]
pg.broadcast_object_list(path_container, 0)
val = path_container[0]
return val
ret = None
rank = get_global_rank()
# Do all filesystem reads from rank 0 only
if rank == 0:
ret = _latest_checkpoint_path(dirpath)

# If not running in a distributed setting, return as is
if not (dist.is_available() and dist.is_initialized()):
return ret

# Otherwise, broadcast result from rank 0 to all ranks
pg = PGWrapper(dist.group.WORLD)
path_container = [ret] if rank == 0 else [None]
pg.broadcast_object_list(path_container, 0)
val = path_container[0]
return val


def _latest_checkpoint_path(dirpath: str) -> Optional[str]: