
Commit 732fbd6

ananthsub authored and facebook-github-bot committed
Return bool from snapshot restore_from_latest
Summary: As the title says, the return value indicates to callers whether the states were restored. This allows callers to write logic like:

```
restored = TorchSnapshotSaver.restore_from_latest(...)
# no prior checkpoints, so initialize weights for the first attempt
if not restored:
    <initialization logic>
```

Reviewed By: daniellepintz

Differential Revision: D48207346

fbshipit-source-id: 6ea20dced7ee433aa9ee27f52998e3855a9f217e
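Expanding the snippet above into a fuller, hedged sketch of the caller pattern: only `restore_from_latest` and its new bool return value come from this commit; the `initialize_weights` helper and the `unit`/`checkpoint_dirpath` arguments are illustrative placeholders, and the static-method call style is the one shown in the summary and tests.

```python
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver


def initialize_weights(unit) -> None:
    # Placeholder for project-specific first-attempt initialization.
    ...


def restore_or_initialize(unit, checkpoint_dirpath: str) -> bool:
    # restore_from_latest now returns False when no checkpoint subdirectory
    # exists under checkpoint_dirpath, and True after a successful restore.
    restored = TorchSnapshotSaver.restore_from_latest(checkpoint_dirpath, unit)
    if not restored:
        # No prior checkpoints, so initialize weights for the first attempt.
        initialize_weights(unit)
    return restored
```

The same call also works on a callback instance, as the updated tests do with `snapshot_cb.restore_from_latest(temp_dir, my_unit)`.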
1 parent a9ad674 commit 732fbd6

File tree

2 files changed (+40, -44 lines)


tests/framework/callbacks/test_torchsnapshot_saver.py

+15, -22 lines
```diff
@@ -21,7 +21,11 @@

 from torchtnt.framework._test_utils import DummyTrainUnit, generate_random_dataloader
 from torchtnt.framework.auto_unit import AutoUnit
-from torchtnt.framework.callbacks import Lambda, TorchSnapshotSaver
+from torchtnt.framework.callbacks.lambda_callback import Lambda
+from torchtnt.framework.callbacks.torchsnapshot_saver import (
+    _get_latest_checkpoint_path,
+    TorchSnapshotSaver,
+)
 from torchtnt.framework.state import State
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, PGWrapper
@@ -59,7 +63,6 @@ def test_save_every_n_train_steps(self) -> None:
             snapshot = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             # Artificially increase the step duration, otherwise torchsnapshot
             # doesn't have the time to save all snapshots and will skip some.
@@ -91,7 +94,6 @@ def test_save_every_n_train_epochs(self) -> None:
             snapshot = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_epochs=save_every_n_train_epochs,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot])
             self.assertTrue(
@@ -124,7 +126,6 @@ def test_save_restore(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -152,7 +153,6 @@ def test_save_restore_dataloader_state(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(
                 my_unit,
@@ -204,18 +204,18 @@ def test_restore_from_latest(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

             with mock.patch(
                 "torchtnt.framework.callbacks.torchsnapshot_saver.TorchSnapshotSaver.restore"
             ) as mock_restore:
-                snapshot_cb.restore_from_latest(temp_dir, my_unit)
+                restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
                 self.assertIn(
                     temp_dir + f"/epoch_{max_epochs}_step_{expected_steps_per_epoch}",
                     mock_restore.call_args.args,
                 )
+                self.assertTrue(restored)

     def test_restore_from_latest_empty_dir(self) -> None:
         input_dim = 2
@@ -226,17 +226,17 @@ def test_restore_from_latest_empty_dir(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )

             with self.assertLogs(level="WARNING") as log:
-                snapshot_cb.restore_from_latest(temp_dir, my_unit)
+                restored = snapshot_cb.restore_from_latest(temp_dir, my_unit)
                 self.assertEqual(
                     log.output,
                     [
                         f"WARNING:torchtnt.framework.callbacks.torchsnapshot_saver:Input dirpath doesn't contain any subdirectories: {temp_dir}"
                     ],
                 )
+                self.assertFalse(restored)

     def test_save_restore_no_train_progress(self) -> None:
         input_dim = 2
@@ -264,7 +264,6 @@ def test_save_restore_no_train_progress(self) -> None:
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
                 save_every_n_train_steps=save_every_n_train_steps,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -293,7 +292,6 @@ def test_save_on_train_end(self) -> None:
             self.assertFalse(os.path.exists(os.path.join(temp_dir, expected_path)))
             snapshot_cb = TorchSnapshotSaver(
                 temp_dir,
-                replicated=["**"],
             )
             train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

@@ -355,6 +353,7 @@ def _save_restore_fsdp() -> None:
         snapshot_cb = TorchSnapshotSaver(
             temp_dir,
             save_every_n_epochs=save_every_n_epochs,
+            replicated=["**"],
         )
         temp_dir = snapshot_cb.dirpath
         train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
@@ -396,14 +395,12 @@ def test_saver_invalid_args(self) -> None:

     def test_latest_checkpoint_path(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
-            self.assertIsNone(TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir))
+            self.assertIsNone(_get_latest_checkpoint_path(temp_dir))

         with tempfile.TemporaryDirectory() as temp_dir:
             latest_path = os.path.join(temp_dir, "epoch_0_step_0")
             os.mkdir(latest_path)
-            self.assertEqual(
-                TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), latest_path
-            )
+            self.assertEqual(_get_latest_checkpoint_path(temp_dir), latest_path)

         with tempfile.TemporaryDirectory() as temp_dir:
             path_1 = os.path.join(temp_dir, "epoch_0_step_0")
@@ -414,9 +411,7 @@ def test_latest_checkpoint_path(self) -> None:
             os.mkdir(path_3)
             path_4 = os.path.join(temp_dir, "epoch_700")
             os.mkdir(path_4)
-            self.assertEqual(
-                TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), path_3
-            )
+            self.assertEqual(_get_latest_checkpoint_path(temp_dir), path_3)

     @unittest.skipUnless(
         torch.distributed.is_available(), reason="Torch distributed is needed to run"
@@ -436,7 +431,7 @@ def _latest_checkpoint_path_distributed() -> None:
             temp_dir = tempfile.mkdtemp()
         else:
             temp_dir = ""
-        tc.assertIsNone(TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir))
+        tc.assertIsNone(_get_latest_checkpoint_path(temp_dir))
         if is_rank0:
             shutil.rmtree(temp_dir)  # delete temp directory

@@ -458,9 +453,7 @@ def _latest_checkpoint_path_distributed() -> None:
         path_container = [path_3] if is_rank0 else [None]
         pg.broadcast_object_list(path_container, 0)
         expected_path = path_container[0]
-        tc.assertEqual(
-            TorchSnapshotSaver.get_latest_checkpoint_path(temp_dir), expected_path
-        )
+        tc.assertEqual(_get_latest_checkpoint_path(temp_dir), expected_path)

         if is_rank0:
             shutil.rmtree(temp_dir)  # delete temp directory
```

torchtnt/framework/callbacks/torchsnapshot_saver.py

+25, -22 lines
```diff
@@ -226,7 +226,6 @@ def restore(
             restore_train_progress: Whether to restore the training progress state.
             restore_eval_progress: Whether to restore the evaluation progress state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
-
         """

         _validate_snapshot_available()
@@ -268,7 +267,7 @@ def restore_from_latest(
         restore_train_progress: bool = True,
         restore_eval_progress: bool = True,
         storage_options: Optional[Dict[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """
         Given a parent directory where checkpoints are saved, restore the snapshot state from the latest checkpoint in the directory.

@@ -282,10 +281,13 @@
             restore_train_progress: Whether to restore the training progress state.
             restore_eval_progress: Whether to restore the evaluation progress state.
             storage_options: Additional keyword options for the storage plugin to use, to be passed to `torchsnapshot.Snapshot <https://pytorch.org/torchsnapshot/stable/api_reference.html#torchsnapshot.Snapshot>`_. See each storage plugin's documentation for customizations.
+
+        Returns:
+            True if the latest snapshot directory was found and successfully restored, otherwise False.
         """
-        path = TorchSnapshotSaver.get_latest_checkpoint_path(dirpath)
+        path = _get_latest_checkpoint_path(dirpath)
         if path is None:
-            return
+            return False
         TorchSnapshotSaver.restore(
             path,
             unit,
@@ -294,27 +296,28 @@
             restore_eval_progress=restore_eval_progress,
             storage_options=storage_options,
         )
+        return True

-    @staticmethod
-    def get_latest_checkpoint_path(dirpath: str) -> Optional[str]:
-        """Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory."""

-        ret = None
-        rank = get_global_rank()
-        # Do all filesystem reads from rank 0 only
-        if rank == 0:
-            ret = _latest_checkpoint_path(dirpath)
+def _get_latest_checkpoint_path(dirpath: str) -> Optional[str]:
+    """Given a parent directory where checkpoints are saved, return the latest checkpoint subdirectory."""

-        # If not running in a distributed setting, return as is
-        if not (dist.is_available() and dist.is_initialized()):
-            return ret
-
-        # Otherwise, broadcast result from rank 0 to all ranks
-        pg = PGWrapper(dist.group.WORLD)
-        path_container = [ret] if rank == 0 else [None]
-        pg.broadcast_object_list(path_container, 0)
-        val = path_container[0]
-        return val
+    ret = None
+    rank = get_global_rank()
+    # Do all filesystem reads from rank 0 only
+    if rank == 0:
+        ret = _latest_checkpoint_path(dirpath)
+
+    # If not running in a distributed setting, return as is
+    if not (dist.is_available() and dist.is_initialized()):
+        return ret
+
+    # Otherwise, broadcast result from rank 0 to all ranks
+    pg = PGWrapper(dist.group.WORLD)
+    path_container = [ret] if rank == 0 else [None]
+    pg.broadcast_object_list(path_container, 0)
+    val = path_container[0]
+    return val


 def _latest_checkpoint_path(dirpath: str) -> Optional[str]:
```
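The former staticmethod becomes the module-level helper `_get_latest_checkpoint_path`, which reads the filesystem on rank 0 only and broadcasts the result to every rank, as the comments in the new function describe. Below is a minimal, self-contained sketch of that rank-0-read-then-broadcast pattern written against plain `torch.distributed` rather than torchtnt's `PGWrapper`; the function name is made up, and the lexicographic "latest" selection is a deliberate simplification of the real epoch/step parsing in `_latest_checkpoint_path`.

```python
import os
from typing import Optional

import torch.distributed as dist


def latest_subdir_synced(dirpath: str) -> Optional[str]:
    """Pick the newest checkpoint subdirectory on rank 0 and share it with all ranks."""
    in_dist = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if in_dist else 0

    result: Optional[str] = None
    if rank == 0:
        # Only rank 0 touches the filesystem.
        subdirs = [
            d for d in os.listdir(dirpath) if os.path.isdir(os.path.join(dirpath, d))
        ]
        # Naive stand-in for torchtnt's epoch/step parsing: lexicographic max.
        result = os.path.join(dirpath, max(subdirs)) if subdirs else None

    if not in_dist:
        # Single-process run: nothing to synchronize.
        return result

    # Broadcast rank 0's answer so every rank restores from the same path.
    container = [result] if rank == 0 else [None]
    dist.broadcast_object_list(container, src=0)
    return container[0]
```

Broadcasting the chosen path keeps all ranks consistent even if their local views of the checkpoint directory differ.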

0 commit comments
