Skip to content

Commit

Permalink
Unfinished checkpoints handling (#7952)
Browse files Browse the repository at this point in the history
* Unfinished checkpoints handling + tests

Signed-off-by: Jacek Bieniusiewicz <[email protected]>

* Fixed EMA checkpoint tests

Signed-off-by: Jacek Bieniusiewicz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes after review

Signed-off-by: Jacek Bieniusiewicz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Merged latest main

Signed-off-by: Jacek Bieniusiewicz <[email protected]>

* Removed not used barrier_before and barrier_after params

Signed-off-by: Jacek Bieniusiewicz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Cosmetic change: removed redundant commas

Signed-off-by: Jacek Bieniusiewicz <[email protected]>

---------

Signed-off-by: Jacek Bieniusiewicz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Abhishree Thittenamane <[email protected]>
  • Loading branch information
3 people authored Feb 8, 2024
1 parent 2bc2e97 commit 5a65505
Show file tree
Hide file tree
Showing 3 changed files with 498 additions and 4 deletions.
163 changes: 160 additions & 3 deletions nemo/utils/callbacks/nemo_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class NeMoModelCheckpoint(ModelCheckpoint):
Also contains func to save the EMA copy of the model.
"""

UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"

def __init__(
self,
always_save_nemo: bool = False,
Expand Down Expand Up @@ -139,6 +141,44 @@ def nemo_topk_check_previous_run(self):
self.best_model_path = best_k_models[0]
self.best_model_score = self.best_k_models[self.best_model_path]

def _remove_invalid_entries_from_topk(self):
    """Drop top-k entries whose checkpoint is missing or unfinished.

    This might be needed if checkpointing was abruptly terminated. After
    ``self.best_k_models`` is filtered, the best / k-th best bookkeeping
    attributes are recomputed from the surviving entries, or reset to their
    empty values when nothing is left.
    """

    def _is_usable(ckpt_path: str) -> bool:
        # Present means: a plain file, a file at the model-parallel-rank
        # injected path, or a directory named like the checkpoint without
        # its '.ckpt' suffix.
        if not (
            os.path.isfile(ckpt_path)
            or os.path.isfile(inject_model_parallel_rank(ckpt_path))
            or os.path.isdir(ckpt_path.removesuffix('.ckpt'))
        ):
            return False
        return not self.is_checkpoint_unfinished(ckpt_path)

    surviving = {}
    for path, score in self.best_k_models.items():
        if _is_usable(path):
            surviving[path] = score
    self.best_k_models = surviving

    if not surviving:
        # Nothing valid remains: reset the bookkeeping to its empty state.
        self.kth_best_model_path = ""
        self.kth_value = None
        self.best_model_path = ""
        self.best_model_score = None
        return

    # Best first: ascending scores for mode "min", descending otherwise.
    ranked = sorted(surviving, key=surviving.get, reverse=(self.mode != "min"))
    worst_path, best_path = ranked[-1], ranked[0]
    self.kth_best_model_path = worst_path
    self.kth_value = surviving[worst_path]
    self.best_model_path = best_path
    self.best_model_score = surviving[best_path]

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore the callback state, then prune invalid top-k entries.

    Args:
        state_dict: State passed through unchanged to the parent class.
    """
    # Let the parent restore its state first; the restored top-k table may
    # reference checkpoints that are missing or unfinished.
    super().load_state_dict(state_dict)
    self._remove_invalid_entries_from_topk()

def setup(self, *args, **kwargs) -> None:
    """Clean up unfinished checkpoints before running the parent setup.

    Global rank zero deletes unfinished checkpoints found in
    ``self.dirpath``; when ``torch.distributed`` is initialized, all ranks
    then meet at a barrier so they continue with the cleanup done.

    Args:
        *args: Forwarded unchanged to the parent ``setup``.
        **kwargs: Forwarded unchanged to the parent ``setup``.
    """
    if is_global_rank_zero():
        logging.debug("Removing unfinished checkpoints if any...")
        NeMoModelCheckpoint._remove_unfinished_checkpoints(self.dirpath)

    # Ensure that all ranks continue with unfinished checkpoints removed.
    distributed = torch.distributed
    if distributed.is_initialized():
        distributed.barrier()

    super().setup(*args, **kwargs)

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
if not self.always_save_nemo:
Expand Down Expand Up @@ -257,7 +297,77 @@ def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]:
ema_callback = callback
return ema_callback

@staticmethod
def format_checkpoint_unfinished_marker_path(checkpoint_path: Union[Path, str]) -> Path:
    """Build the path of the "unfinished" marker file for a checkpoint.

    If the marker file exists, the corresponding checkpoint is considered
    unfinished/incomplete.

    NOTE: The marker path for the EMA part of a checkpoint is the same as
    for the original checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file or dir.
            Does not need to exist.

    Returns:
        Path to the unfinished checkpoint marker file.
    """
    stem = str(uninject_model_parallel_rank(checkpoint_path))
    # Order matters: each suffix is stripped at most once, in sequence, so
    # the extension goes before the '-EMA' tag that precedes it.
    for suffix in (".nemo", ".ckpt", "-EMA"):
        stem = stem.removesuffix(suffix)
    return Path(stem + NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX)

@staticmethod
def is_checkpoint_unfinished(checkpoint_path: Union[Path, str]) -> bool:
    """Tell whether a checkpoint is flagged as unfinished.

    Args:
        checkpoint_path: Path to the checkpoint file or dir.
            Does not need to exist.

    Returns:
        True if the checkpoint's unfinished marker file exists,
        False otherwise.
    """
    marker = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
    return marker.exists()

@staticmethod
def set_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_after=False) -> None:
    """Mark the given checkpoint as unfinished.

    Only global rank zero writes the marker file; its parent directories
    are created when missing.

    Args:
        checkpoint_path: Path to the checkpoint file or dir.
            Does not need to exist.
        barrier_after: Synchronize ranks after writing the marker file.
            Defaults to False.
    """
    if is_global_rank_zero():
        marker = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
        marker.parent.mkdir(parents=True, exist_ok=True)
        marker.touch()

    if not barrier_after:
        return
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

@staticmethod
def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barrier_before=False) -> None:
    """Clear the unfinished marker for the given checkpoint.

    Only global rank zero deletes the marker file. A marker that is already
    absent is not an error.

    Args:
        checkpoint_path: Path to the checkpoint file or dir.
            Does not need to exist.
        barrier_before: Synchronize ranks before removing the marker file.
            Defaults to False.
    """
    if barrier_before and torch.distributed.is_initialized():
        torch.distributed.barrier()
    if is_global_rank_zero():
        marker_path = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path(checkpoint_path)
        # Delete in one step instead of exists() + unlink(): the marker may
        # vanish between the check and the removal (e.g. on a shared
        # filesystem), which would otherwise raise FileNotFoundError.
        marker_path.unlink(missing_ok=True)

def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None:
# barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
# if anything goes wrong during checkpointing, we should be able to detect that data is incomplete.
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
ema_callback = self._ema_callback(trainer)
if ema_callback is not None:
with ema_callback.save_original_optimizer_state(trainer):
Expand All @@ -271,14 +381,23 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
super()._save_checkpoint(trainer, filepath)
else:
super()._save_checkpoint(trainer, filepath)
# barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
# we don't want to remove the marker until all checkpointing is done.
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str) -> None:
    """Remove a checkpoint (and its EMA copy) under an unfinished marker.

    The marker is placed before anything is deleted and cleared only after
    the removal is done, so an interrupted removal can be detected as
    incomplete data.

    Args:
        trainer: Trainer passed through to the parent removal and used to
            look up the EMA callback.
        filepath: Path of the checkpoint to remove.
    """
    # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed.
    # if anything goes wrong during removal, we should be able to detect that data is incomplete.
    self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
    super()._remove_checkpoint(trainer, filepath)
    ema_callback = self._ema_callback(trainer)
    if ema_callback is not None:
        # remove EMA copy of the state dict as well.
        # Use a separate name: the marker below must be cleared for the very
        # same path it was set for, not for the derived EMA path.
        ema_filepath = self._ema_format_filepath(filepath)
        super()._remove_checkpoint(trainer, ema_filepath)
    # barrier_before=True, so all ranks synchronize before removing the unfinished checkpoint marker
    # we don't want to remove the marker until the checkpoint is actually removed.
    self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)

def _ema_format_filepath(self, filepath: str) -> str:
    """Return the path of the EMA copy of a checkpoint.

    The ``-EMA`` tag is inserted right before the last occurrence of
    ``self.FILE_EXTENSION``. Earlier occurrences (for example inside a
    directory name) are left untouched, so the EMA copy stays next to the
    original checkpoint.

    Args:
        filepath: Path of the original checkpoint.

    Returns:
        The EMA path, or ``filepath`` unchanged when it does not contain
        the extension.
    """
    extension = self.FILE_EXTENSION
    cut = filepath.rfind(extension)
    if cut == -1:
        return filepath
    # Rebuild as: <before>-EMA<extension><after the last extension>.
    return f'{filepath[:cut]}-EMA{extension}{filepath[cut + len(extension):]}'
Expand All @@ -292,8 +411,46 @@ def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool:
@property
def _saved_checkpoint_paths(self) -> Iterable[Path]:
    """Iterate over the saved checkpoints that are not marked unfinished.

    Distributed checkpoints are directories, so directories directly under
    ``self.dirpath`` take precedence; when there are none, ``*.ckpt`` files
    are searched recursively instead. Unfinished checkpoints are filtered
    out here; these should be deleted during the next cleanup.

    Returns:
        A lazy iterable of checkpoint paths.
    """
    root = Path(self.dirpath)
    # distributed checkpoints are directories so we check for them here
    dist_checkpoints = [d for d in root.glob("*") if d.is_dir()]
    candidates = dist_checkpoints if dist_checkpoints else list(root.rglob("*.ckpt"))
    return (p for p in candidates if not self.is_checkpoint_unfinished(p))

@staticmethod
def _remove_unfinished_checkpoints(checkpoint_dir: Union[Path, str]) -> None:
    """Delete unfinished checkpoints, and their markers, from the filesystem.

    Checkpoint files (``*.ckpt``, searched recursively) and directories
    directly under ``checkpoint_dir`` are removed when their marker file
    exists in ``checkpoint_dir``. All marker files are removed afterwards.

    Args:
        checkpoint_dir: Directory holding the checkpoints and marker files.

    Raises:
        AssertionError: If called on a rank other than global rank zero.
    """
    if not is_global_rank_zero():
        raise AssertionError("_remove_unfinished_checkpoints should run only on rank 0")

    root = Path(checkpoint_dir)
    marker_path_for = NeMoModelCheckpoint.format_checkpoint_unfinished_marker_path

    # Resolved paths of every marker file sitting directly in the directory.
    existing_markers = set()
    for candidate in root.glob(f"*{NeMoModelCheckpoint.UNFINISHED_CHECKPOINT_SUFFIX}"):
        if candidate.is_file():
            existing_markers.add(candidate.resolve())

    # The set is fully built before anything is deleted.
    for ckpt_filepath in {f.resolve() for f in root.rglob("*.ckpt")}:
        if marker_path_for(ckpt_filepath) in existing_markers:
            logging.warning(f'Removing unfinished checkpoint: {ckpt_filepath}')
            os.remove(ckpt_filepath)

    # some directories might be distributed checkpoints, we remove these if they have an unfinished marker
    for ckpt_dirpath in {d.resolve() for d in root.glob("*") if d.is_dir()}:
        if marker_path_for(ckpt_dirpath) in existing_markers:
            logging.warning(f'Removing unfinished dist checkpoint: {ckpt_dirpath}')
            shutil.rmtree(ckpt_dirpath)

    # Markers go last, after the checkpoint data they flag has been removed.
    for marker in existing_markers:
        os.remove(marker)
16 changes: 15 additions & 1 deletion nemo/utils/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from datetime import timedelta
from pathlib import Path
from shutil import copy, move
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Collection, Dict, List, Optional, Tuple, Union

import pytorch_lightning
import torch
Expand Down Expand Up @@ -564,6 +564,18 @@ def error_checks(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictC
)


def _filter_out_unfinished_checkpoints(checkpoint_paths: Collection[Union[Path, str]]) -> Collection[Union[Path, str]]:
    """Keep only the checkpoints that are not marked as unfinished.

    A warning is logged for every checkpoint that is skipped.

    Args:
        checkpoint_paths: Candidate checkpoint paths.

    Returns:
        A list with the finished checkpoints, in their original order.
    """
    finished = []
    for chkpt_path in checkpoint_paths:
        if not NeMoModelCheckpoint.is_checkpoint_unfinished(chkpt_path):
            finished.append(chkpt_path)
            continue
        logging.warning(
            f'Checkpoint {chkpt_path} has the unfinished marker set - skipped while looking for the last one.'
        )
    return finished


def check_resume(
trainer: 'pytorch_lightning.Trainer',
log_dir: str,
Expand Down Expand Up @@ -604,7 +616,9 @@ def check_resume(
last_dist_checkpoints = [d for d in dist_checkpoints if d.match("*last")]

end_checkpoints = end_dist_checkpoints if end_dist_checkpoints else list(checkpoint_dir.rglob("*end.ckpt"))
end_checkpoints = _filter_out_unfinished_checkpoints(end_checkpoints)
last_checkpoints = last_dist_checkpoints if last_dist_checkpoints else list(checkpoint_dir.rglob("*last.ckpt"))
last_checkpoints = _filter_out_unfinished_checkpoints(last_checkpoints)

if not checkpoint_dir.exists() or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0):
if resume_ignore_no_checkpoint:
Expand Down
Loading

0 comments on commit 5a65505

Please sign in to comment.