From 5a655057fafe10d0a05ce06de550772c2ec3589f Mon Sep 17 00:00:00 2001
From: jbieniusiewi <152396322+jbieniusiewi@users.noreply.github.com>
Date: Thu, 8 Feb 2024 07:11:27 +0100
Subject: [PATCH] Unfinished checkpoints handling (#7952)

* Unfinished checkpoints handling + tests

Signed-off-by: Jacek Bieniusiewicz

* Fixed EMA checkpoint tests

Signed-off-by: Jacek Bieniusiewicz

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes after review

Signed-off-by: Jacek Bieniusiewicz

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Merged latest main

Signed-off-by: Jacek Bieniusiewicz

* Removed unused barrier_before and barrier_after params

Signed-off-by: Jacek Bieniusiewicz

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cosmetic change: removed redundant commas

Signed-off-by: Jacek Bieniusiewicz

---------

Signed-off-by: Jacek Bieniusiewicz
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>
---
 nemo/utils/callbacks/nemo_model_checkpoint.py | 163 ++++++++-
 nemo/utils/exp_manager.py                     |  16 +-
 tests/core/test_exp_manager.py                | 323 ++++++++++++++++++
 3 files changed, 498 insertions(+), 4 deletions(-)

diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py
index a290152907db..2c3325c56f82 100644
--- a/nemo/utils/callbacks/nemo_model_checkpoint.py
+++ b/nemo/utils/callbacks/nemo_model_checkpoint.py
@@ -38,6 +38,8 @@ class NeMoModelCheckpoint(ModelCheckpoint):
     Also contains func to save the EMA copy of the model.
     """
 
+    UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
+
     def __init__(
         self,
         always_save_nemo: bool = False,
@@ -139,6 +141,44 @@ def nemo_topk_check_previous_run(self):
         self.best_model_path = best_k_models[0]
         self.best_model_score = self.best_k_models[self.best_model_path]
 
+    def _remove_invalid_entries_from_topk(self):
+        # Removes invalid (incomplete or non-existent) checkpoints from topk checkpoints.
+        # This might be needed if the checkpointing was abruptly terminated.
+        def __is_ckpt_ok(ckpt_path: str) -> bool:
+            exists = (
+                os.path.isfile(ckpt_path)
+                or os.path.isfile(inject_model_parallel_rank(ckpt_path))
+                or os.path.isdir(ckpt_path.removesuffix('.ckpt'))
+            )
+            return exists and not self.is_checkpoint_unfinished(ckpt_path)
+
+        self.best_k_models = {k: v for k, v in self.best_k_models.items() if __is_ckpt_ok(k)}
+        if len(self.best_k_models) > 0:
+            reverse_arr = self.mode != "min"
+            best_k_models_arr = sorted(self.best_k_models, key=self.best_k_models.get, reverse=reverse_arr)
+            self.kth_best_model_path = best_k_models_arr[-1]
+            self.kth_value = self.best_k_models[self.kth_best_model_path]
+            self.best_model_path = best_k_models_arr[0]
+            self.best_model_score = self.best_k_models[self.best_model_path]
+        else:
+            self.kth_best_model_path = ""
+            self.kth_value = None
+            self.best_model_path = ""
+            self.best_model_score = None
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        super().load_state_dict(state_dict)
+        self._remove_invalid_entries_from_topk()
+
+    def setup(self, *args, **kwargs) -> None:
+        if is_global_rank_zero():
+            logging.debug("Removing unfinished checkpoints if any...")
+            NeMoModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)
+        # Ensure that all ranks continue with unfinished checkpoints removed
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+        super().setup(*args, **kwargs)
+
     def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
         if not self.always_save_nemo:
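The _remove_invalid_entries_from_topk hunk above prunes best_k_models of entries whose files are missing or still marked unfinished, then recomputes the best and k-th best bookkeeping. As a rough standalone illustration of that filter-then-recompute step (not part of the diff; best_k_models and mode mirror the callback's attributes, and the is_ok predicate stands in for the existence and marker checks):

# Standalone sketch of the top-k repair idea (an approximation, not the NeMo implementation).
from typing import Callable, Dict, Optional, Tuple


def prune_topk(
    best_k_models: Dict[str, float], mode: str, is_ok: Callable[[str], bool]
) -> Tuple[Dict[str, float], Optional[str], Optional[str]]:
    # Drop entries whose checkpoint files are missing or flagged as unfinished.
    kept = {path: score for path, score in best_k_models.items() if is_ok(path)}
    if not kept:
        return kept, None, None
    ordered = sorted(kept, key=kept.get, reverse=(mode != "min"))
    best_path, kth_path = ordered[0], ordered[-1]  # best first, k-th best last
    return kept, best_path, kth_path


# e.g. prune_topk({"a.ckpt": 0.9, "b.ckpt": 0.7}, "max", lambda p: p != "b.ckpt")
# returns ({"a.ckpt": 0.9}, "a.ckpt", "a.ckpt")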
@@ -257,7 +297,77 @@ def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]:
                 ema_callback = callback
         return ema_callback
 
+    @staticmethod
+    def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path:
+        """ Format the path to the unfinished checkpoint marker file.
+
+        If the marker file exists, the corresponding checkpoint is considered unfinished/incomplete.
+        NOTE: Marker path for the EMA checkpoint part is the same as for the original checkpoint.
+
+        Args:
+            checkpoint_path: Path to the checkpoint file or dir.
+                Does not need to exist.
+
+        Returns:
+            Path to the unfinished checkpoint marker file.
+        """
+        marker_filepath = str(uninject_model_parallel_rank(checkpoint_path))
+        marker_filepath = marker_filepath.removesuffix(".nemo")
+        marker_filepath = marker_filepath.removesuffix(".ckpt")
+        marker_filepath = marker_filepath.removesuffix("-EMA")
+        return Path(marker_filepath + NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX)
+
+    @staticmethod
+    def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool:
+        """ Check if the checkpoint is unfinished.
+
+        Args:
+            checkpoint_path: Path to the checkpoint file or dir.
+                Does not need to exist.
+
+        Returns:
+            True if the checkpoint is unfinished, False otherwise.
+        """
+        return NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path).exists()
+
+    @staticmethod
+    def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None:
+        """ Marks the given checkpoint as unfinished.
+
+        Args:
+            checkpoint_path: Path to the checkpoint file or dir.
+                Does not need to exist.
+            barrier_after: Synchronize ranks after writing the marker file.
+                Defaults to False.
+        """
+        if is_global_rank_zero():
+            marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
+            marker_path.parent.mkdir(parents=True, exist_ok=True)
+            marker_path.touch()
+        if barrier_after and torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
+    @staticmethod
+    def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None:
+        """Clear the unfinished marker for the given checkpoint.
+
+        Args:
+            checkpoint_path: Path to the checkpoint file or dir.
+                Does not need to exist.
+            barrier_before: Synchronize ranks before removing the marker file.
+                Defaults to False.
+        """
+        if barrier_before and torch.distributed.is_initialized():
+            torch.distributed.barrier()
+        if is_global_rank_zero():
+            marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
+            if marker_path.exists():
+                marker_path.unlink()
+
     def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
+        # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
+        # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
+        self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
         ema_callback = self._ema_callback(trainer)
         if ema_callback is not None:
             with ema_callback.save_original_optimizer_state(trainer):
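To make the marker naming rule concrete, here is a short usage sketch (illustration only, not part of the diff; it assumes a NeMo installation that provides this class, and the checkpoint paths are hypothetical). Rank directories and the .ckpt/.nemo/-EMA suffixes are stripped, so a model-parallel shard and its EMA twin resolve to the same marker file:

from nemo.utils.callbacks import NeMoModelCheckpoint

# A tensor-parallel rank file and its EMA twin share one marker at the parent level.
for ckpt in (
    "/results/checkpoints/mp_rank_00/megatron_gpt--step=1000-last.ckpt",  # hypothetical path
    "/results/checkpoints/megatron_gpt--step=1000-last-EMA.ckpt",         # hypothetical path
):
    print(NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(ckpt))
# Both print: /results/checkpoints/megatron_gpt--step=1000-last-unfinished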
+ """ + if is_global_rank_zero(): + marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path) + marker_path.parent.mkdir(parents=True, exist_ok=True) + marker_path.touch() + if barrier_after and torch.distributed.is_initialized(): + torch.distributed.barrier() + + @staticmethod + def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None: + """Clear unfinished marker for given checkpoint. + + Args: + checkpoint_path: Path to the checkpoint file or dir. + Does not need to exist. + barrier_before: Synchronize ranks before removing the marker file. + Defaults to False. + """ + if barrier_before and torch.distributed.is_initialized(): + torch.distributed.barrier() + if is_global_rank_zero(): + marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path) + if marker_path.exists(): + marker_path.unlink() + def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. + # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. + self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) ema_callback = self._ema_callback(trainer) if ema_callback is not None: with ema_callback.save_original_optimizer_state(trainer): @@ -271,14 +381,23 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) super()._save_checkpoint(trainer, filepath) else: super()._save_checkpoint(trainer, filepath) + # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker + # we don't want to remove the marker until all checkpointing is done. + self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True) def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str) -> None: + # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. + # if anything goes wrong during removal, we should be able to detect that data is incomplete. + self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) super()._remove_checkpoint(trainer, filepath) ema_callback = self._ema_callback(trainer) if ema_callback is not None: # remove EMA copy of the state dict as well. filepath = self._ema_format_filepath(filepath) super()._remove_checkpoint(trainer, filepath) + # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker + # we don't want to remove the marker until the checkpoint is actually removed. 
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index 4bde204f2976..db45701385e8 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -22,7 +22,7 @@
 from datetime import timedelta
 from pathlib import Path
 from shutil import copy, move
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Collection, Dict, List, Optional, Tuple, Union
 
 import pytorch_lightning
 import torch
@@ -564,6 +564,18 @@ def error_checks(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictC
         )
 
 
+def _filter_out_unfinished_checkpoints(checkpoint_paths: Collection[Union[Path, str]]) -> Collection[Union[Path, str]]:
+    res = []
+    for chkpt_path in checkpoint_paths:
+        if NeMoModelCheckpoint.is_checkpoint_unfinished(chkpt_path):
+            logging.warning(
+                f'Checkpoint {chkpt_path} has the unfinished marker set - skipped while looking for the last one.'
+            )
+        else:
+            res.append(chkpt_path)
+    return res
+
+
 def check_resume(
     trainer: 'pytorch_lightning.Trainer',
     log_dir: str,
@@ -604,7 +616,9 @@ def check_resume(
     last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")]
 
     end_checkpoints = end_dist_checkpoints if end_dist_checkpoints else list(checkpoint_dir.rglob("*end.ckpt"))
+    end_checkpoints = _filter_out_unfinished_checkpoints(end_checkpoints)
     last_checkpoints = last_dist_checkpoints if last_dist_checkpoints else list(checkpoint_dir.rglob("*last.ckpt"))
+    last_checkpoints = _filter_out_unfinished_checkpoints(last_checkpoints)
 
     if not checkpoint_dir.exists() or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0):
         if resume_ignore_no_checkpoint:
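In check_resume, the candidate *end.ckpt / *last.ckpt lists are now passed through _filter_out_unfinished_checkpoints before a resume path is chosen. A standalone approximation of that lookup (illustration only, not part of the diff; it ignores the distributed-checkpoint directories and the model-parallel rank handling of the real code):

from pathlib import Path
from typing import List

UNFINISHED_SUFFIX = "-unfinished"


def complete_last_checkpoints(checkpoint_dir: Path) -> List[Path]:
    def is_unfinished(ckpt: Path) -> bool:
        marker = ckpt.with_name(ckpt.name.removesuffix(".ckpt") + UNFINISHED_SUFFIX)
        return marker.exists()

    # Keep only "-last" checkpoints that are not flagged as unfinished.
    return [c for c in sorted(checkpoint_dir.rglob("*last.ckpt")) if not is_unfinished(c)]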
diff --git a/tests/core/test_exp_manager.py b/tests/core/test_exp_manager.py
index 7a8eec669d40..8073a75e14ca 100644
--- a/tests/core/test_exp_manager.py
+++ b/tests/core/test_exp_manager.py
@@ -27,6 +27,7 @@
 
 from nemo.constants import NEMO_ENV_VARNAME_VERSION
 from nemo.core.classes import ModelPT
+from nemo.utils.callbacks import NeMoModelCheckpoint
 from nemo.utils.exp_manager import (
     CheckpointMisconfigurationError,
     LoggerMisconfigurationError,
@@ -628,3 +629,325 @@ class CustomLoop(_TrainingEpochLoop):
         trainer.fit_loop.epoch_loop = loop
         with pytest.warns(UserWarning, match="Detected custom epoch loop"):
             exp_manager(trainer, {"explicit_log_dir": str(tmp_path)})
+
+    def _write_fake_checkpoint(self, path, isdir, add_unfinished_marker):
+        path = Path(path)
+        if isdir:
+            # fake distributed checkpoint
+            path.mkdir(parents=True, exist_ok=True)
+            (path / "dummy.txt").touch()
+        else:
+            # fake checkpoint file
+            path.parent.mkdir(parents=True, exist_ok=True)
+            path.touch()
+        if add_unfinished_marker:
+            NeMoModelCheckpoint.set_checkpoint_unfinished_marker(path)
+
+    @pytest.mark.unit
+    def test_skipped_unfinished_checkpoints_when_restoring(self, tmp_path):
+        """
+        Check if unfinished checkpoints are skipped during last checkpoint lookup.
+        Logic of the test:
+        - write multiple last checkpoints, some of them incomplete
+        - ensure that the last complete checkpoint is found
+        """
+
+        test_dir = tmp_path / "test"
+        checkpoints_dir = test_dir / "checkpoints"
+
+        self._write_fake_checkpoint(
+            checkpoints_dir / "megatron_gpt--val_loss=5.01-step=900-consumed_samples=1000.0.ckpt",
+            isdir=False,
+            add_unfinished_marker=False,
+        )  # not last
+        self._write_fake_checkpoint(
+            checkpoints_dir / "megatron_gpt--val_loss=5.01-step=900-consumed_samples=1000.0-last.ckpt",
+            isdir=False,
+            add_unfinished_marker=True,
+        )  # incomplete
+        self._write_fake_checkpoint(
+            checkpoints_dir
+            / "mp_rank_00"
+            / "megatron_gpt--val_loss=5.01-step=1100-consumed_samples=17600.0-last.ckpt",
+            isdir=False,
+            add_unfinished_marker=True,
+        )  # incomplete
+        self._write_fake_checkpoint(
+            checkpoints_dir
+            / "mp_rank_01"
+            / "megatron_gpt--val_loss=5.01-step=1100-consumed_samples=17600.0-last.ckpt",
+            isdir=False,
+            add_unfinished_marker=True,
+        )  # incomplete
+        self._write_fake_checkpoint(
+            checkpoints_dir
+            / "mp_rank_00"
+            / "megatron_gpt--val_loss=5.01-step=1000-consumed_samples=16000.0-last.ckpt",
+            isdir=False,
+            add_unfinished_marker=False,
+        )  # ok
+        self._write_fake_checkpoint(
+            checkpoints_dir
+            / "mp_rank_01"
+            / "megatron_gpt--val_loss=5.01-step=1000-consumed_samples=16000.0-last.ckpt",
+            isdir=False,
+            add_unfinished_marker=False,
+        )  # ok
+
+        restored_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
+        exp_manager(
+            restored_trainer, {"resume_if_exists": True, "explicit_log_dir": str(test_dir)},
+        )
+
+        # Check that the last complete (w/o unfinished marker) checkpoint was found
+        assert (
+            Path(restored_trainer.ckpt_path).name
+            == 'megatron_gpt--val_loss=5.01-step=1000-consumed_samples=16000.0-last.ckpt'
+        )
+
+    @pytest.mark.unit
+    def test_skipped_unfinished_dist_checkpoints_when_restoring(self, tmp_path):
+        """
+        Check if unfinished distributed checkpoints are skipped during last checkpoint lookup.
+        Logic of the test:
+        - write multiple last checkpoints, some of them incomplete
+        - ensure that the last complete checkpoint is found
+        """
+
+        test_dir = tmp_path / "test"
+        checkpoints_dir = test_dir / "checkpoints"
+
+        self._write_fake_checkpoint(
+            checkpoints_dir / "megatron_gpt--val_loss=5.01-step=1000-consumed_samples=16000.0",
+            isdir=True,
+            add_unfinished_marker=False,
+        )
+        self._write_fake_checkpoint(
+            checkpoints_dir / "megatron_gpt--val_loss=5.01-step=1000-consumed_samples=16000.0-last",
+            isdir=True,
+            add_unfinished_marker=False,
+        )
+        self._write_fake_checkpoint(
+            checkpoints_dir / "megatron_gpt--val_loss=5.01-step=1100-consumed_samples=17600.0",
+            isdir=True,
+            add_unfinished_marker=False,
+        )
+        self._write_fake_checkpoint(
+            checkpoints_dir / "megatron_gpt--val_loss=5.01-step=1100-consumed_samples=17600.0-last",
+            isdir=True,
+            add_unfinished_marker=True,
+        )
+
+        restored_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False)
+        exp_manager(
+            restored_trainer, {"resume_if_exists": True, "explicit_log_dir": str(test_dir)},
+        )
+
+        # Check that the last complete (w/o unfinished marker) checkpoint was found
+        assert (
+            Path(restored_trainer.ckpt_path).name
+            == 'megatron_gpt--val_loss=5.01-step=1000-consumed_samples=16000.0-last'
+        )
+
+    @pytest.mark.unit
+    def test_incomplete_checkpoints_cleanup(self, tmp_path):
+        """
+        Check if unfinished checkpoints are cleaned up when training starts.
+        Complete checkpoints should be left intact.
+        """
+        test_dir = tmp_path / "test"
+        checkpoints_dir = test_dir / "checkpoints"
+
+        complete_ckpts = {
+            checkpoints_dir / "step=1-epoch=0.ckpt",
+            checkpoints_dir / "step=2-epoch=0-last.ckpt",
+            checkpoints_dir / "mp_rank_00" / "step=3-epoch=0-last.ckpt",
+            checkpoints_dir / "tp_rank_00_pp_rank_000" / "step=4-epoch=0-last.ckpt",
+            checkpoints_dir / "tp_rank_00_pp_rank_001" / "step=4-epoch=0-last.ckpt",
+        }
+        for ckpt_filepath in complete_ckpts:
+            self._write_fake_checkpoint(ckpt_filepath, isdir=False, add_unfinished_marker=False)
+
+        incomplete_ckpts = {
+            checkpoints_dir / "step=11-epoch=1.ckpt",
+            checkpoints_dir / "step=12-epoch=1-last.ckpt",
+            checkpoints_dir / "mp_rank_00" / "step=13-epoch=1-last.ckpt",
+            checkpoints_dir / "tp_rank_00_pp_rank_000" / "step=14-epoch=1-last.ckpt",
+            checkpoints_dir / "tp_rank_00_pp_rank_001" / "step=14-epoch=1-last.ckpt",
+        }
+        for ckpt_filepath in incomplete_ckpts:
+            self._write_fake_checkpoint(ckpt_filepath, isdir=False, add_unfinished_marker=True)
+
+        # sanity check
+        remaining_ckpts = {f for f in (test_dir / "checkpoints").rglob("*.ckpt") if f.is_file()}
+        assert remaining_ckpts == (complete_ckpts | incomplete_ckpts)
+
+        # marker without corresponding checkpoint should be removed during cleanup in exp_manager
+        (checkpoints_dir / f"orphan-marker001-{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}").touch()
+
+        # unfinished checkpoint with EMA part, both parts should be removed
+        self._write_fake_checkpoint(
+            checkpoints_dir / "incomplete01-EMA.ckpt", isdir=False, add_unfinished_marker=False,
+        )
+        self._write_fake_checkpoint(checkpoints_dir / "incomplete01.ckpt", isdir=False, add_unfinished_marker=True)
+
+        # just EMA part - should be removed. NOTE marker path is the same for base part and for EMA part
+        self._write_fake_checkpoint(
+            checkpoints_dir / "incomplete02-EMA.ckpt", isdir=False, add_unfinished_marker=False,
+        )
+        (checkpoints_dir / f"incomplete02{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}").touch()
+
+        test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=1)
+
+        exp_manager(
+            test_trainer,
+            {"checkpoint_callback_params": {"save_top_k": 0, "save_last": False}, "explicit_log_dir": str(test_dir),},
+        )
+
+        model = ExampleModel()
+        test_trainer.fit(model)
+
+        remaining_ckpts = {f for f in (test_dir / "checkpoints").rglob("*.ckpt") if f.is_file()}
+        assert remaining_ckpts == complete_ckpts
+        remaining_markers = list(checkpoints_dir.rglob(f"*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}"))
+        assert remaining_markers == []
+ """ + test_dir = tmp_path / "test" + checkpoints_dir = test_dir / "checkpoints" + + complete_ckpts = { + checkpoints_dir / "step=1-epoch=0.ckpt", + checkpoints_dir / "step=2-epoch=0-last.ckpt", + checkpoints_dir / "mp_rank_00" / "step=3-epoch=0-last.ckpt", + checkpoints_dir / "tp_rank_00_pp_rank_000" / "step=4-epoch=0-last.ckpt", + checkpoints_dir / "tp_rank_00_pp_rank_001" / "step=4-epoch=0-last.ckpt", + } + for ckpt_filepath in complete_ckpts: + self._write_fake_checkpoint(ckpt_filepath, isdir=False, add_unfinished_marker=False) + + incomplete_ckpts = { + checkpoints_dir / "step=11-epoch=1.ckpt", + checkpoints_dir / "step=12-epoch=1-last.ckpt", + checkpoints_dir / "mp_rank_00" / "step=13-epoch=1-last.ckpt", + checkpoints_dir / "tp_rank_00_pp_rank_000" / "step=14-epoch=1-last.ckpt", + checkpoints_dir / "tp_rank_00_pp_rank_001" / "step=14-epoch=1-last.ckpt", + } + for ckpt_filepath in incomplete_ckpts: + self._write_fake_checkpoint(ckpt_filepath, isdir=False, add_unfinished_marker=True) + + # sanity check + remaining_ckpts = {f for f in (test_dir / "checkpoints").rglob("*.ckpt") if f.is_file()} + assert remaining_ckpts == (complete_ckpts | incomplete_ckpts) + + # marker without corresponding checkpoint should be removed during cleanup in exp_manager + (checkpoints_dir / f"orphan-marker001-{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}").touch() + + # unfinished checkpoint with EMA part, both parts should be removed + self._write_fake_checkpoint( + checkpoints_dir / "incomplete01-EMA.ckpt", isdir=False, add_unfinished_marker=False, + ) + self._write_fake_checkpoint(checkpoints_dir / "incomplete01.ckpt", isdir=False, add_unfinished_marker=True) + + # just EMA part - should be removed. NOTE marker path is the same for base part and for EMA part + self._write_fake_checkpoint( + checkpoints_dir / "incomplete02-EMA.ckpt", isdir=False, add_unfinished_marker=False, + ) + (checkpoints_dir / f"incomplete02{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}").touch() + + test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=1) + + exp_manager( + test_trainer, + {"checkpoint_callback_params": {"save_top_k": 0, "save_last": False}, "explicit_log_dir": str(test_dir),}, + ) + + model = ExampleModel() + test_trainer.fit(model) + + remaining_ckpts = {f for f in (test_dir / "checkpoints").rglob("*.ckpt") if f.is_file()} + assert remaining_ckpts == complete_ckpts + remaining_markers = list(checkpoints_dir.rglob(f"*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}")) + assert remaining_markers == [] + + @pytest.mark.unit + def test_incomplete_dist_checkpoints_cleanup(self, tmp_path): + """ + Check if unfinished distributed checkpoints are cleaned up when training starts. + Complete distributed checkpoints should be left intact. 
+ """ + + test_dir = tmp_path / "test" + checkpoints_dir = test_dir / "checkpoints" + + complete_dist_ckpts = { + checkpoints_dir / "step=5-epoch=0", + checkpoints_dir / "step=6-epoch=0-last", + } + for ckpt_dirpath in complete_dist_ckpts: + self._write_fake_checkpoint(ckpt_dirpath, isdir=True, add_unfinished_marker=False) + + incomplete_dist_ckpts = { + checkpoints_dir / "step=15-epoch=1", + checkpoints_dir / "step=16-epoch=1-last", + } + for ckpt_dirpath in incomplete_dist_ckpts: + self._write_fake_checkpoint(ckpt_dirpath, isdir=True, add_unfinished_marker=True) + + # marker without corresponding checkpoint should be removed during cleanup in exp_manager + (checkpoints_dir / f"orphan-marker001-{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}").touch() + + remaining_dist_ckpts = {f for f in (test_dir / "checkpoints").glob("*") if f.is_dir()} + assert remaining_dist_ckpts == (complete_dist_ckpts | incomplete_dist_ckpts) + + test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=1) + + exp_manager( + test_trainer, + {"checkpoint_callback_params": {"save_top_k": 0, "save_last": False}, "explicit_log_dir": str(test_dir),}, + ) + + model = ExampleModel() + test_trainer.fit(model) + + remaining_dist_ckpts = {f for f in (test_dir / "checkpoints").glob("*") if f.is_dir()} + assert remaining_dist_ckpts == complete_dist_ckpts + remaining_markers = list(checkpoints_dir.rglob(f"*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}")) + assert remaining_markers == [] + + _chkpt_path_and_marker_path_pairs = [ + ('a=1_b=1.c.d.e', f'a=1_b=1.c.d.e{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}'), + ('a=1_b=1.c.d.e-last', f'a=1_b=1.c.d.e-last{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}'), + ('.ckpt/a=1_b=1.c.d.e.ckpt', f'.ckpt/a=1_b=1.c.d.e{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}'), + ('.ckpt/a=1_b=1.c.d.e-EMA.ckpt', f'.ckpt/a=1_b=1.c.d.e{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}'), + ( + '.ckpt/a=1_b=1.c.d.e-last.ckpt', + f'.ckpt/a=1_b=1.c.d.e-last{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}', + ), + ( + '/tmp/mp_rank_00/a=1_b=1.c.d.e.ckpt', + f'/tmp/a=1_b=1.c.d.e{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}', + ), + ( + '/tmp/tp_rank_00_pp_rank_000/a=1_b=1.c.d.e.ckpt', + f'/tmp/a=1_b=1.c.d.e{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}', + ), + ('nemo/a=1_b=1.c.d.e.nemo', f'nemo/a=1_b=1.c.d.e{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}'), + ('nemo/a=1_b=1.c.d.e-last.nemo', f'nemo/a=1_b=1.c.d.e-last{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}'), + ] + + @pytest.mark.unit + @pytest.mark.parametrize("chkpt_path, expected_marker_path", _chkpt_path_and_marker_path_pairs) + def test_incomplete_checkpoints_marker_path(self, chkpt_path, expected_marker_path): + """ + Ensure that unfinished checkpoint marker path is correctly formed. + """ + marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(chkpt_path) + assert str(marker_path) == str(expected_marker_path) + + @pytest.mark.unit + def test_invalid_checkpoints_removed_from_topk(self, tmp_path): + """ + Ensure that invalid (unfinished, deleted) checkpoints are removed from topk when resuming. 
+
+    @pytest.mark.unit
+    def test_invalid_checkpoints_removed_from_topk(self, tmp_path):
+        """
+        Ensure that invalid (unfinished, deleted) checkpoints are removed from topk when resuming.
+        - Do a few training steps and save checkpoints
+        - Delete some checkpoints, mark some as unfinished
+        - Resume training and verify that topk checkpoints are correct
+        """
+        test_dir = tmp_path / "test"
+        checkpoints_dir = test_dir / "checkpoints"
+
+        test_trainer = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=7)
+        exp_manager(
+            test_trainer,
+            {
+                "checkpoint_callback_params": {
+                    "save_top_k": 3,
+                    "save_last": True,
+                    "mode": 'max',
+                    "monitor": 'epoch',
+                    "filename": f"{{epoch}}",
+                },
+                "explicit_log_dir": str(tmp_path / "test"),
+            },
+        )
+        model = ExampleModel()
+        test_trainer.fit(model)
+
+        ckpt_filenames = {f.name for f in checkpoints_dir.rglob("*.ckpt") if f.is_file()}
+        assert len(ckpt_filenames) == 4  # 3 top + 1 last
+        assert 'epoch=7-last.ckpt' in ckpt_filenames
+        assert 'epoch=6.ckpt' in ckpt_filenames
+        assert 'epoch=5.ckpt' in ckpt_filenames
+        assert 'epoch=4.ckpt' in ckpt_filenames
+
+        # Mark the 6th epoch checkpoint as unfinished and remove the 5th epoch checkpoint,
+        # so the last valid candidate for topk is the 4th epoch checkpoint
+        NeMoModelCheckpoint.set_checkpoint_unfinished_marker(checkpoints_dir / 'epoch=6.ckpt')
+        (checkpoints_dir / 'epoch=5.ckpt').unlink()
+
+        test_trainer2 = pl.Trainer(accelerator='cpu', enable_checkpointing=False, logger=False, max_epochs=9)
+        exp_manager(
+            test_trainer2,
+            {
+                "resume_if_exists": True,
+                "checkpoint_callback_params": {
+                    "save_top_k": 3,
+                    "save_last": True,
+                    "mode": 'max',
+                    "monitor": 'epoch',
+                    "filename": f"{{epoch}}",
+                },
+                "explicit_log_dir": str(tmp_path / "test"),
+            },
+        )
+        model = ExampleModel()
+        test_trainer2.fit(model)
+
+        ckpt_filenames = {f.name for f in checkpoints_dir.rglob("*.ckpt") if f.is_file()}
+        assert len(ckpt_filenames) == 4  # 3 top + 1 last
+        assert 'epoch=9-last.ckpt' in ckpt_filenames
+        assert 'epoch=8.ckpt' in ckpt_filenames
+        assert 'epoch=7.ckpt' in ckpt_filenames
+        assert 'epoch=4.ckpt' in ckpt_filenames
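Taken together, the tests exercise the user-visible behaviour: with resume_if_exists, a restarted run skips any "-last" checkpoint that still carries the unfinished marker, and partial checkpoints are cleaned up before training continues. A hedged usage sketch (the log directory and trainer settings below are hypothetical; only the config keys are taken from the tests in this patch):

import pytorch_lightning as pl
from nemo.utils.exp_manager import exp_manager

trainer = pl.Trainer(accelerator="cpu", max_epochs=10)
exp_manager(
    trainer,
    {
        "resume_if_exists": True,  # resume from the newest complete "-last" checkpoint, if any
        "explicit_log_dir": "/results/my_run",  # hypothetical log dir; unfinished checkpoints here get cleaned up
        "checkpoint_callback_params": {"save_top_k": 3, "save_last": True},
    },
)
# trainer.fit(model) would then continue from the restored step.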