add keep_latest_n checkpoint arg in TSS #619

Open · wants to merge 1 commit into base: master
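
This PR adds a ``keep_last_n_checkpoints`` argument to ``TorchSnapshotSaver`` so that only the most recent checkpoint directories are retained; any surplus found at train start is deleted. A minimal usage sketch (not part of the diff): ``my_unit`` and ``dataloader`` are placeholders, and the import paths are assumptions based on the file layout in the changes below.

import tempfile

from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
from torchtnt.framework.train import train

# my_unit is a TrainUnit and dataloader an iterable of batches (placeholders).
with tempfile.TemporaryDirectory() as temp_dir:
    snapshot_cb = TorchSnapshotSaver(
        temp_dir,
        save_every_n_train_steps=2,
        keep_last_n_checkpoints=1,  # keep only the most recent checkpoint directory
    )
    train(my_unit, dataloader, max_epochs=2, callbacks=[snapshot_cb])
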
171 changes: 171 additions & 0 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -20,6 +20,7 @@
from torch import nn
from torch.distributed import launcher
from torch.utils.data import DataLoader
from torchsnapshot import Snapshot
from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
from torchsnapshot.test_utils import assert_state_dict_eq, check_state_dict_eq

@@ -33,7 +34,9 @@
)
from torchtnt.framework.callbacks.lambda_callback import Lambda
from torchtnt.framework.callbacks.torchsnapshot_saver import (
_delete_snapshot,
_override_knobs,
_retrieve_checkpoint_dirpaths,
get_latest_checkpoint_path,
KnobOptions,
RestoreOptions,
@@ -669,6 +672,174 @@ def test_knob_override(self) -> None:
with _override_knobs(KnobOptions(max_per_rank_io_concurrency=None)):
self.assertNotIn(env_var, os.environ)

@patch("torchtnt.framework.callbacks.torchsnapshot_saver.get_filesystem")
def test_retrieve_checkpoint_dirpaths(self, mock_get_filesystem: MagicMock) -> None:
"""
Tests retrieving checkpoint directories from a given root directory
"""
paths = [
{"name": "tmp/epoch_0_step_10", "type": "directory"},
{"name": "tmp/epoch_1_step_10", "type": "directory"},
{"name": "tmp/epoch_2_step_10", "type": "directory"},
{"name": "tmp/epoch_0_step_5", "type": "directory"},
{"name": "tmp/epoch_0_step_3", "type": "file"},
]

mock_get_filesystem.return_value.ls.return_value = paths
returned_paths = _retrieve_checkpoint_dirpaths("foo")
self.assertEqual(
returned_paths,
[
"tmp/epoch_0_step_5",
"tmp/epoch_0_step_10",
"tmp/epoch_1_step_10",
"tmp/epoch_2_step_10",
],
)

def test_delete_snapshot(self) -> None:
"""
Tests removing checkpoint directories
"""
app_state = {"module": nn.Linear(2, 2)}
with tempfile.TemporaryDirectory() as temp_dir:
dirpath = os.path.join(temp_dir, "checkpoint")
Snapshot.take(dirpath, app_state=app_state)
self.assertTrue(os.path.exists(dirpath))
# check that error is thrown if .snapshot_metadata is not found in the directory when deleting
with self.assertRaisesRegex(
RuntimeError, f"{temp_dir} does not contain .snapshot_metadata"
):
_delete_snapshot(temp_dir)
_delete_snapshot(dirpath)
self.assertFalse(os.path.exists(dirpath))

def test_should_remove_snapshot(self) -> None:
"""
Tests the helper function that checks if snapshot should be removed or not
"""
tss = TorchSnapshotSaver("temp")

# keep_last_n_checkpoints is toggled off
self.assertFalse(tss._should_remove_snapshot())

# not enough checkpoints are saved yet to be removed
tss._keep_last_n_checkpoints = 2
tss._ckpt_dirpaths = ["bar"]
self.assertFalse(tss._should_remove_snapshot())

# enough checkpoints are there to remove
tss._keep_last_n_checkpoints = 2
tss._ckpt_dirpaths = ["foo", "bar"]
self.assertTrue(tss._should_remove_snapshot())

@patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_snapshot")
def test_remove_snapshot(self, mock_delete_snapshot: MagicMock) -> None:
"""
Tests the helper function that removes snapshots and updates the checkpoint paths
"""
state = get_dummy_train_state()
tss = TorchSnapshotSaver("temp")
tss._ckpt_dirpaths = ["foo", "bar"]
tss._remove_snapshot(state)

mock_delete_snapshot.assert_called_once()
self.assertEqual(len(tss._ckpt_dirpaths), 1)
self.assertEqual(tss._ckpt_dirpaths[0], "bar")

@patch("torchtnt.framework.callbacks.torchsnapshot_saver._delete_snapshot")
def test_cleanup_surplus(self, mock_delete_snapshot: MagicMock) -> None:
"""
Tests surplus of checkpoints being cleaned up
"""
state = get_dummy_train_state()
unit = DummyTrainUnit(input_dim=2)
warning_messages = []
with tempfile.TemporaryDirectory() as temp_dir:
tss = TorchSnapshotSaver(temp_dir, keep_last_n_checkpoints=1)
tss._ckpt_dirpaths = ["foo", "bar", "baz"]

expected_warning_msg = " ".join(
[
f"3 checkpoints found in {temp_dir}.",
f"Deleting {2} oldest",
"checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
]
)

with patch(
"torchtnt.framework.callbacks.torchsnapshot_saver.logging.Logger.warning",
warning_messages.append,
):
tss.on_train_start(state, unit)
self.assertEqual(tss._ckpt_dirpaths, ["baz"])
self.assertEqual(warning_messages[0], expected_warning_msg)

tss = TorchSnapshotSaver(temp_dir)
tss._ckpt_dirpaths = ["foo", "bar", "baz"]

tss.on_train_start(state, unit)
self.assertEqual(tss._ckpt_dirpaths, ["foo", "bar", "baz"])

def test_keep_last_n_checkpoints(self) -> None:
"""
        Tests that only the most recent ``keep_last_n_checkpoints`` checkpoint directories are kept
"""
unit = DummyTrainUnit(input_dim=2)
state = get_dummy_train_state()
with tempfile.TemporaryDirectory() as temp_dir:
tss = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=1,
keep_last_n_checkpoints=2,
)

# take 10 steps
for _ in range(10):
unit.train_progress.increment_step()
tss.on_train_step_end(state, unit)
# TODO remove time.sleep to avoid potential flaky test
time.sleep(0.1) # sleep to ensure enough time to checkpoint

dirs = os.listdir(temp_dir)
self.assertEqual(len(dirs), 2)
self.assertIn("epoch_0_step_9", dirs)
self.assertIn("epoch_0_step_10", dirs)

def test_keep_last_n_checkpoints_e2e(self) -> None:
"""
        Tests that only the most recent checkpoint directories are kept, end to end
"""
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2

my_unit = DummyTrainUnit(input_dim=input_dim)
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
with tempfile.TemporaryDirectory() as temp_dir:
snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_train_steps=2,
keep_last_n_checkpoints=1,
)
# Artificially increase the step duration, otherwise torchsnapshot
# doesn't have the time to save all snapshots and will skip some.
slowdown = Lambda(on_train_step_end=lambda *_: time.sleep(0.1))

train(
my_unit,
dataloader,
max_epochs=max_epochs,
callbacks=[snapshot_cb, slowdown],
)
dirs = os.listdir(temp_dir)
self.assertEqual(len(dirs), 1)
self.assertIn(
f"epoch_{max_epochs}_step_{dataset_len // batch_size * max_epochs}",
os.listdir(temp_dir),
)


class DummyStatefulDataLoader:
def __init__(self, dataloader: DataLoader) -> None:
110 changes: 106 additions & 4 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -83,6 +83,7 @@ class TorchSnapshotSaver(Callback):
dirpath: Parent directory to save snapshots to.
save_every_n_train_steps: Frequency of steps with which to save snapshots during the train epoch. If None, no intra-epoch snapshots are generated.
save_every_n_epochs: Frequency of epochs with which to save snapshots during training. If None, no end-of-epoch snapshots are generated.
        keep_last_n_checkpoints: Number of most recent checkpoints to keep. If None, all checkpoints are kept. If more checkpoints than this already exist in ``dirpath``, the oldest ones are deleted at train start to enforce the limit.
        process_group: The process group on which the ranks will communicate. default: ``None`` (the entire world)
replicated: A glob-pattern of replicated key names that indicate which application state entries have the same state across all processes.
For more information, see https://pytorch.org/torchsnapshot/main/api_reference.html#torchsnapshot.Snapshot.take.
@@ -108,6 +109,7 @@ def __init__(
*,
save_every_n_train_steps: Optional[int] = None,
save_every_n_epochs: Optional[int] = None,
keep_last_n_checkpoints: Optional[int] = None,
process_group: Optional[dist.ProcessGroup] = None,
replicated: Optional[List[str]] = None,
storage_options: Optional[Dict[str, Any]] = None,
@@ -122,9 +124,20 @@
raise ValueError(
f"Invalid value passed for save_every_n_epochs. Expected to receive either None or positive number, but received {save_every_n_epochs}"
)
if keep_last_n_checkpoints is not None and keep_last_n_checkpoints <= 0:
raise ValueError(
f"Invalid value passed for keep_last_n_checkpoints. Expected to receive either None or positive number, but received {keep_last_n_checkpoints}"
)
self._save_every_n_epochs = save_every_n_epochs
self._save_every_n_train_steps = save_every_n_train_steps

self._keep_last_n_checkpoints = keep_last_n_checkpoints
self._ckpt_dirpaths: List[str] = []
if self._keep_last_n_checkpoints:
self._ckpt_dirpaths = _retrieve_checkpoint_dirpaths(dirpath)

self._process_group = process_group
self._pg_wrapper = PGWrapper(process_group)
self._sync_dirpath_to_all_ranks(dirpath)
self._replicated: Set[str] = set(replicated or [])

@@ -156,6 +169,24 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None:
app_state = _app_state(unit)
_check_app_state_collision(app_state)

        # delete the oldest checkpoints if a surplus already exists in dirpath
keep_last_n_checkpoints = self._keep_last_n_checkpoints
if (
keep_last_n_checkpoints
and len(self._ckpt_dirpaths) > keep_last_n_checkpoints
):
logger.warning(
" ".join(
[
f"{len(self._ckpt_dirpaths)} checkpoints found in {self._dirpath}.",
f"Deleting {len(self._ckpt_dirpaths) - keep_last_n_checkpoints} oldest",
"checkpoints to enforce ``keep_last_n_checkpoints`` argument.",
]
)
)
for _ in range(len(self._ckpt_dirpaths) - keep_last_n_checkpoints):
self._remove_snapshot(state)

def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
num_steps_completed = unit.train_progress.num_steps_completed
save_every_n_train_steps = self._save_every_n_train_steps
@@ -177,7 +208,14 @@ def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
with get_timing_context(
state, f"{self.__class__.__name__}.take_async_snapshot"
):
self._async_snapshot(snapshot_path, app_state, wait=False)
checkpoint_success = self._async_snapshot(
snapshot_path, app_state, wait=False
)

if checkpoint_success:
if self._should_remove_snapshot():
self._remove_snapshot(state)
self._ckpt_dirpaths.append(snapshot_path)

def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
epoch = unit.train_progress.num_epochs_completed
@@ -197,7 +235,14 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
with get_timing_context(
state, f"{self.__class__.__name__}.take_async_snapshot"
):
self._async_snapshot(snapshot_path, app_state, wait=True)
checkpoint_success = self._async_snapshot(
snapshot_path, app_state, wait=True
)

if checkpoint_success:
if self._should_remove_snapshot():
self._remove_snapshot(state)
self._ckpt_dirpaths.append(snapshot_path)

def on_train_end(self, state: State, unit: TTrainUnit) -> None:
app_state = _get_app_state(state, unit, self._replicated, intra_epoch=False)
@@ -213,9 +258,20 @@ def on_train_end(self, state: State, unit: TTrainUnit) -> None:
with get_timing_context(
state, f"{self.__class__.__name__}.take_async_snapshot"
):
self._async_snapshot(snapshot_path, app_state, wait=True)
            # TODO: the checkpoint is not truly successful yet
            # since it is written asynchronously; in the future,
            # set the success flag only once the checkpoint is
            # fully written
checkpoint_success = self._async_snapshot(
snapshot_path, app_state, wait=True
)
self._wait()

if checkpoint_success:
if self._should_remove_snapshot():
self._remove_snapshot(state)
self._ckpt_dirpaths.append(snapshot_path)

def on_exception(
self,
state: State,
@@ -224,6 +280,22 @@
) -> None:
self._wait()

def _should_remove_snapshot(self) -> bool:
keep_last_n_checkpoints = self._keep_last_n_checkpoints
return (
keep_last_n_checkpoints is not None
and len(self._ckpt_dirpaths) >= keep_last_n_checkpoints
)

def _remove_snapshot(self, state: State) -> None:
# remove oldest snapshot directory
oldest_ckpt_path = self._ckpt_dirpaths.pop(0)
with get_timing_context(state, f"{self.__class__.__name__}.delete_snapshot"):
if self._pg_wrapper.get_rank() == 0:
# only delete on rank 0
_delete_snapshot(oldest_ckpt_path)
self._pg_wrapper.barrier()

def _wait(self) -> None:
if self._prev_snapshot is not None:
self._prev_snapshot.wait()
Expand All @@ -236,7 +308,7 @@ def _async_snapshot(
if prev_snapshot.path == snapshot_path:
# Snapshot for this step already has been saved.
# This can happen if we call _async_snapshot twice at the same step.
return True
return False
still_pending = not prev_snapshot.done()
if still_pending and wait:
prev_snapshot.wait()
@@ -563,3 +635,33 @@ def _override_knobs(
for mgr in knobs:
stack.enter_context(mgr)
yield


def _retrieve_checkpoint_dirpaths(dirpath: str) -> List[str]:
"""
Given a parent directory where checkpoints are saved, return the sorted checkpoint subdirectories
from oldest to newest.

Args:
dirpath: parent directory where checkpoints are saved.
"""
fs = get_filesystem(dirpath)

contents = fs.ls(dirpath, detail=True)
contents = [item["name"] for item in contents if item["type"] == "directory"]
ckpt_dirpaths = []
for path in contents:
match = re.search(r"epoch_(\d+)_step_(\d+)", path)
if match:
ckpt_dirpaths.append(path)

# sorts by epoch, then step
ckpt_dirpaths.sort(key=lambda x: (int(x.split("_")[1]), int(x.split("_")[3])))
return ckpt_dirpaths


def _delete_snapshot(dirpath: str) -> None:
fs = get_filesystem(dirpath)
if not fs.exists(os.path.join(dirpath, ".snapshot_metadata")):
raise RuntimeError(f"{dirpath} does not contain .snapshot_metadata")
fs.rm(dirpath, recursive=True)
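
For reference, a standalone sketch (not part of the diff) of the epoch-then-step ordering that ``_retrieve_checkpoint_dirpaths`` applies; the example paths are made up.

# Made-up paths; sort by epoch first, then step, both parsed as integers,
# mirroring the sort key used in _retrieve_checkpoint_dirpaths above.
paths = [
    "tmp/epoch_1_step_10",
    "tmp/epoch_0_step_10",
    "tmp/epoch_0_step_5",
    "tmp/epoch_2_step_10",
]
paths.sort(key=lambda x: (int(x.split("_")[1]), int(x.split("_")[3])))
assert paths == [
    "tmp/epoch_0_step_5",
    "tmp/epoch_0_step_10",
    "tmp/epoch_1_step_10",
    "tmp/epoch_2_step_10",
]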