From 6ebcfaa38f2c1a7962e9490d0419c2efbf03d07b Mon Sep 17 00:00:00 2001 From: Abhishree Date: Tue, 7 Mar 2023 20:06:28 +0000 Subject: [PATCH 01/11] Adding basic preemption code Add preemption functionality in preemption_callback.py under utils Refactor the code to move NemoModelCheckpoint callback under callbacks Signed-off-by: Abhishree --- nemo/collections/common/callbacks/__init__.py | 1 + .../common/callbacks/nemomodelcheckpoint.py | 214 ++++++++++++++++ nemo/utils/exp_manager.py | 239 +----------------- nemo/utils/preemption_callback.py | 64 +++++ 4 files changed, 284 insertions(+), 234 deletions(-) create mode 100644 nemo/collections/common/callbacks/nemomodelcheckpoint.py create mode 100644 nemo/utils/preemption_callback.py diff --git a/nemo/collections/common/callbacks/__init__.py b/nemo/collections/common/callbacks/__init__.py index 0cf495d94696..13d956a937ab 100644 --- a/nemo/collections/common/callbacks/__init__.py +++ b/nemo/collections/common/callbacks/__init__.py @@ -14,3 +14,4 @@ from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback from nemo.collections.common.callbacks.ema import EMA +from nemo.collections.common.callbacks.nemomodelcheckpoint import NeMoModelCheckpoint diff --git a/nemo/collections/common/callbacks/nemomodelcheckpoint.py b/nemo/collections/common/callbacks/nemomodelcheckpoint.py new file mode 100644 index 000000000000..fa0151422e3c --- /dev/null +++ b/nemo/collections/common/callbacks/nemomodelcheckpoint.py @@ -0,0 +1,214 @@ +import os +import re +from copy import deepcopy +from pathlib import Path +from typing import Optional + +import pytorch_lightning +import torch +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_info + +from nemo.collections.common.callbacks import EMA +from nemo.utils import logging +from nemo.utils.app_state import AppState +from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank +from nemo.utils.get_rank import is_global_rank_zero + +class NeMoModelCheckpoint(ModelCheckpoint): + """ Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end + """ + + def __init__( + self, + always_save_nemo: bool = False, + save_nemo_on_train_end: bool = True, + save_best_model: bool = False, + postfix: str = ".nemo", + n_resume: bool = False, + model_parallel_size: int = None, + **kwargs, + ): + # Parse and store "extended" parameters: save_best model and postfix. + self.always_save_nemo = always_save_nemo + self.save_nemo_on_train_end = save_nemo_on_train_end + self.save_best_model = save_best_model + if self.save_best_model and not self.save_nemo_on_train_end: + logging.warning( + ( + "Found save_best_model is True and save_nemo_on_train_end is False. " + "Set save_nemo_on_train_end to True to automatically save the best model." + ) + ) + self.postfix = postfix + self.previous_best_path = "" + self.model_parallel_size = model_parallel_size + + # `prefix` is deprecated + if 'prefix' in kwargs: + self.prefix = kwargs.pop('prefix') + else: + self.prefix = "" + + # Call the parent class constructor with the remaining kwargs. + super().__init__(**kwargs) + + if self.save_top_k != -1 and n_resume: + logging.debug("Checking previous runs") + self.nemo_topk_check_previous_run() + + def nemo_topk_check_previous_run(self): + try: + self.best_k_models + self.kth_best_model_path + self.best_model_score + self.best_model_path + except AttributeError: + raise AttributeError("Lightning's ModelCheckpoint was updated. 
NeMoModelCheckpoint will need an update.") + self.best_k_models = {} + self.kth_best_model_path = "" + self.best_model_score = None + self.best_model_path = "" + + checkpoints = list(Path(self.dirpath).rglob("*.ckpt")) + for checkpoint in checkpoints: + if 'mp_rank' in str(checkpoint) or 'tp_rank' in str(checkpoint): + checkpoint = uninject_model_parallel_rank(checkpoint) + checkpoint = str(checkpoint) + if checkpoint[-10:] == '-last.ckpt': + continue + index = checkpoint.find(self.monitor) + len(self.monitor) + 1 # Find monitor in str + 1 for '=' + if index != -1: + match = re.search('[A-z]', checkpoint[index:]) + if match: + value = checkpoint[index : index + match.start() - 1] # -1 due to separator hypen + self.best_k_models[checkpoint] = float(value) + if len(self.best_k_models) < 1: + return # No saved checkpoints yet + + _reverse = False if self.mode == "min" else True + + best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse) + + ### This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are + ### instantiated after rank zero. models_to_delete should be 0 for all other ranks. + if self.model_parallel_size is not None: + models_to_delete = len(best_k_models) - self.model_parallel_size * self.save_top_k + else: + models_to_delete = len(best_k_models) - self.save_top_k + logging.debug(f'Number of models to delete: {models_to_delete}') + for _ in range(models_to_delete): + model = best_k_models.pop(-1) + self.best_k_models.pop(model) + self._del_model_without_trainer(model) + logging.debug(f"Removed checkpoint: {model}") + + self.kth_best_model_path = best_k_models[-1] + self.best_model_path = best_k_models[0] + self.best_model_score = self.best_k_models[self.best_model_path] + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + # output = None + output = super().on_save_checkpoint(trainer, pl_module, checkpoint) + if not self.always_save_nemo: + return output + else: + # Load the best model and then re-save it + app_state = AppState() + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + raise ValueError(f'always_save_nemo is not implemented for model parallel models.') + # since we are creating tarfile artifacts we need to update .nemo path + app_state.model_restore_path = os.path.abspath( + os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix)) + ) + if self.save_best_model: + if not os.path.exists(self.best_model_path): + return output + + if self.best_model_path == self.previous_best_path: + return output + + self.previous_model_path = self.best_model_path + old_state_dict = deepcopy(pl_module.state_dict()) + checkpoint = torch.load(self.best_model_path, map_location='cpu') + if 'state_dict' in checkpoint: + checkpoint = checkpoint['state_dict'] + # get a new instanace of the model + pl_module.load_state_dict(checkpoint, strict=True) + pl_module.save_to(save_path=app_state.model_restore_path) + pl_module.load_state_dict(old_state_dict, strict=True) + else: + pl_module.save_to(save_path=app_state.model_restore_path) + return output + + def on_train_end(self, trainer, pl_module): + if trainer.fast_dev_run: + return None + + # check if we need to save a last checkpoint manually as validation isn't always run based on the interval + if self.save_last and trainer.val_check_interval != 0: + should_save_last_checkpoint = False + if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0: + 
should_save_last_checkpoint = True + if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0: + should_save_last_checkpoint = True + if should_save_last_checkpoint: + monitor_candidates = self._monitor_candidates(trainer) + super()._save_last_checkpoint(trainer, monitor_candidates) + # Call parent on_train_end() to save the -last checkpoint + super().on_train_end(trainer, pl_module) + + # Load the best model and then re-save it + if self.save_best_model: + # wait for all processes + trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end") + if self.best_model_path == "": + logging.warning( + f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints " + "were found. Saving latest model instead." + ) + else: + self.best_model_path = trainer.strategy.broadcast(self.best_model_path) + trainer._checkpoint_connector.restore(self.best_model_path) + + if self.save_nemo_on_train_end: + pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix)) + + def _del_model_without_trainer(self, filepath: str) -> None: + app_state = AppState() + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + # filepath needs to be updated to include mp_rank + filepath = inject_model_parallel_rank(filepath) + + # each model parallel rank needs to remove its model + if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0): + try: + self._fs.rm(filepath) + logging.info(f"Removed checkpoint: {filepath}") + except: + logging.info(f"Tried to remove checkpoint: {filepath} but failed.") + + def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]: + ema_callback = None + for callback in trainer.callbacks: + if isinstance(callback, EMA): + ema_callback = callback + return ema_callback + + def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + ema_callback = self._ema_callback(trainer) + if ema_callback is not None: + with ema_callback.save_original_optimizer_state(trainer): + super()._save_checkpoint(trainer, filepath) + + # save EMA copy of the model as well. 
+ with ema_callback.save_ema_model(trainer): + filepath = self._ema_format_filepath(filepath) + if self.verbose: + rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") + super()._save_checkpoint(trainer, filepath) + else: + super()._save_checkpoint(trainer, filepath) + + def _ema_format_filepath(self, filepath: str) -> str: + return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 64ee66fb5d80..bfedb712014a 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -39,7 +39,7 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.utilities import rank_zero_info -from nemo.collections.common.callbacks import EMA +from nemo.collections.common.callbacks import EMA, NeMoModelCheckpoint from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION from nemo.utils import logging, timers from nemo.utils.app_state import AppState @@ -49,6 +49,7 @@ from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank +from nemo.utils.preemption_callback import PreemptionCallback class NotFoundError(NeMoBaseException): @@ -831,238 +832,6 @@ def configure_loggers( trainer._logger_connector.configure_logger(logger_list) -class NeMoModelCheckpoint(ModelCheckpoint): - """ Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end - """ - - def __init__( - self, - always_save_nemo: bool = False, - save_nemo_on_train_end: bool = True, - save_best_model: bool = False, - postfix: str = ".nemo", - n_resume: bool = False, - model_parallel_size: int = None, - **kwargs, - ): - # Parse and store "extended" parameters: save_best model and postfix. - self.always_save_nemo = always_save_nemo - self.save_nemo_on_train_end = save_nemo_on_train_end - self.save_best_model = save_best_model - if self.save_best_model and not self.save_nemo_on_train_end: - logging.warning( - ( - "Found save_best_model is True and save_nemo_on_train_end is False. " - "Set save_nemo_on_train_end to True to automatically save the best model." - ) - ) - self.postfix = postfix - self.previous_best_path = "" - self.model_parallel_size = model_parallel_size - - # `prefix` is deprecated - if 'prefix' in kwargs: - self.prefix = kwargs.pop('prefix') - else: - self.prefix = "" - - # Call the parent class constructor with the remaining kwargs. - super().__init__(**kwargs) - - if self.save_top_k != -1 and n_resume: - logging.debug("Checking previous runs") - self.nemo_topk_check_previous_run() - - def nemo_topk_check_previous_run(self): - try: - self.best_k_models - self.kth_best_model_path - self.best_model_score - self.best_model_path - except AttributeError: - raise AttributeError("Lightning's ModelCheckpoint was updated. 
NeMoModelCheckpoint will need an update.") - self.best_k_models = {} - self.kth_best_model_path = "" - self.best_model_score = None - self.best_model_path = "" - - checkpoints = list(path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path)) - for checkpoint in checkpoints: - if 'mp_rank' in str(checkpoint) or 'tp_rank' in str(checkpoint): - checkpoint = uninject_model_parallel_rank(checkpoint) - checkpoint = str(checkpoint) - if checkpoint[-10:] == '-last.ckpt': - continue - index = checkpoint.find(self.monitor) + len(self.monitor) + 1 # Find monitor in str + 1 for '=' - if index != -1: - match = re.search('[A-z]', checkpoint[index:]) - if match: - value = checkpoint[index : index + match.start() - 1] # -1 due to separator hypen - self.best_k_models[checkpoint] = float(value) - if len(self.best_k_models) < 1: - return # No saved checkpoints yet - - _reverse = False if self.mode == "min" else True - - best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse) - - ### This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are - ### instantiated after rank zero. models_to_delete should be 0 for all other ranks. - if self.model_parallel_size is not None: - models_to_delete = len(best_k_models) - self.model_parallel_size * self.save_top_k - else: - models_to_delete = len(best_k_models) - self.save_top_k - logging.debug(f'Number of models to delete: {models_to_delete}') - - # If EMA enabled, delete the additional EMA weights - ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths) - - for _ in range(models_to_delete): - model = best_k_models.pop(-1) - self.best_k_models.pop(model) - self._del_model_without_trainer(model) - if ema_enabled and self._fs.exists(self._ema_format_filepath(model)): - self._del_model_without_trainer(self._ema_format_filepath(model)) - logging.debug(f"Removed checkpoint: {model}") - - self.kth_best_model_path = best_k_models[-1] - self.best_model_path = best_k_models[0] - self.best_model_score = self.best_k_models[self.best_model_path] - - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - output = super().on_save_checkpoint(trainer, pl_module, checkpoint) - if not self.always_save_nemo: - return output - # Load the best model and then re-save it - app_state = AppState() - if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: - logging.warning(f'always_save_nemo will slow down training for model_parallel > 1.') - # since we are creating tarfile artifacts we need to update .nemo path - app_state.model_restore_path = os.path.abspath( - os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix)) - ) - if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: - maybe_injected_best_model_path = inject_model_parallel_rank(self.best_model_path) - else: - maybe_injected_best_model_path = self.best_model_path - - if self.save_best_model: - if not os.path.exists(maybe_injected_best_model_path): - return - - if self.best_model_path == self.previous_best_path: - return output - - self.previous_model_path = self.best_model_path - old_state_dict = deepcopy(pl_module.state_dict()) - checkpoint = torch.load(maybe_injected_best_model_path, map_location='cpu') - if 'state_dict' in checkpoint: - checkpoint = checkpoint['state_dict'] - # get a new instanace of the model - pl_module.load_state_dict(checkpoint, strict=True) - if torch.distributed.is_initialized(): - torch.distributed.barrier() - 
pl_module.save_to(save_path=app_state.model_restore_path) - logging.info(f"New best .nemo model saved to: {app_state.model_restore_path}") - pl_module.load_state_dict(old_state_dict, strict=True) - else: - if torch.distributed.is_initialized(): - torch.distributed.barrier() - pl_module.save_to(save_path=app_state.model_restore_path) - logging.info(f"New .nemo model saved to: {app_state.model_restore_path}") - return output - - def on_train_end(self, trainer, pl_module): - if trainer.fast_dev_run: - return None - - # check if we need to save a last checkpoint manually as validation isn't always run based on the interval - if self.save_last and trainer.val_check_interval != 0: - should_save_last_checkpoint = False - if isinstance(trainer.val_check_interval, float) and trainer.val_check_interval % trainer.global_step != 0: - should_save_last_checkpoint = True - if isinstance(trainer.val_check_interval, int) and trainer.global_step % trainer.val_check_interval != 0: - should_save_last_checkpoint = True - if should_save_last_checkpoint: - monitor_candidates = self._monitor_candidates(trainer) - super()._save_last_checkpoint(trainer, monitor_candidates) - # Call parent on_train_end() to save the -last checkpoint - super().on_train_end(trainer, pl_module) - - # Load the best model and then re-save it - if self.save_best_model: - # wait for all processes - trainer.strategy.barrier("SaveBestCheckpointConnector.resume_end") - if self.best_model_path == "": - logging.warning( - f"{self} was told to save the best checkpoint at the end of training, but no saved checkpoints " - "were found. Saving latest model instead." - ) - else: - self.best_model_path = trainer.strategy.broadcast(self.best_model_path) - trainer._checkpoint_connector.restore(self.best_model_path) - - if self.save_nemo_on_train_end: - pl_module.save_to(save_path=os.path.join(self.dirpath, self.prefix + self.postfix)) - - def _del_model_without_trainer(self, filepath: str) -> None: - app_state = AppState() - if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: - # filepath needs to be updated to include mp_rank - filepath = inject_model_parallel_rank(filepath) - - # each model parallel rank needs to remove its model - if is_global_rank_zero() or (app_state.model_parallel_size is not None and app_state.data_parallel_rank == 0): - try: - self._fs.rm(filepath) - logging.info(f"Removed checkpoint: {filepath}") - except: - logging.info(f"Tried to remove checkpoint: {filepath} but failed.") - - def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]: - ema_callback = None - for callback in trainer.callbacks: - if isinstance(callback, EMA): - ema_callback = callback - return ema_callback - - def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: - ema_callback = self._ema_callback(trainer) - if ema_callback is not None: - with ema_callback.save_original_optimizer_state(trainer): - super()._save_checkpoint(trainer, filepath) - - # save EMA copy of the model as well. 
- with ema_callback.save_ema_model(trainer): - filepath = self._ema_format_filepath(filepath) - if self.verbose: - rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}") - super()._save_checkpoint(trainer, filepath) - else: - super()._save_checkpoint(trainer, filepath) - - def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str) -> None: - super()._remove_checkpoint(trainer, filepath) - ema_callback = self._ema_callback(trainer) - if ema_callback is not None: - # remove EMA copy of the state dict as well. - filepath = self._ema_format_filepath(filepath) - super()._remove_checkpoint(trainer, filepath) - - def _ema_format_filepath(self, filepath: str) -> str: - return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') - - def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool: - return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints) - - def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool: - return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}') - - @property - def _saved_checkpoint_paths(self) -> Iterable[Path]: - return Path(self.dirpath).rglob("*.ckpt") - - def configure_checkpointing( trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', ): @@ -1122,7 +891,9 @@ def configure_checkpointing( if 'mp_rank' in checkpoint_callback.last_model_path or 'tp_rank' in checkpoint_callback.last_model_path: checkpoint_callback.last_model_path = uninject_model_parallel_rank(checkpoint_callback.last_model_path) trainer.callbacks.append(checkpoint_callback) - + # TODO create support_preemption arg, if True do the below + preemption_callback = PreemptionCallback(torch.device('cuda'), checkpoint_callback) + trainer.callbacks.append(preemption_callback) def check_slurm(trainer): try: diff --git a/nemo/utils/preemption_callback.py b/nemo/utils/preemption_callback.py new file mode 100644 index 000000000000..924728b3fcd7 --- /dev/null +++ b/nemo/utils/preemption_callback.py @@ -0,0 +1,64 @@ +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback, ModelCheckpoint +from nemo.collections.common.callbacks.nemomodelcheckpoint import NeMoModelCheckpoint +import signal +import torch +import sys + +class PreemptionCallback(Callback): + + def __init__(self, device, checkpoint_callback, sig=signal.SIGTERM): + self.sig = sig + self.device = device + self.checkpoint_callback = checkpoint_callback + + @property + def interrupted(self): + interrupted = torch.tensor(self._interrupted).int().to(self.device) + torch.distributed.broadcast(interrupted, 0) + interrupted = bool(interrupted.item()) + return interrupted + + def on_train_start(self, trainer, pl_module): + self._interrupted = False + self.released = False + self.original_handler = signal.getsignal(self.sig) + + def master_handler(signum, frame): + self.release() + self._interrupted = True + + def ignoring_handler(signum, frame): + self.release() + + self.private_rank = torch.distributed.get_rank() + if self.private_rank == 0: + signal.signal(self.sig, master_handler) + else: + signal.signal(self.sig, ignoring_handler) + + return self + + def on_train_end(self, trainer, pl_module): + self.release() + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int): + # check if the job was preempted + # NOTE: "timeout_handler.interrupted" is a property which triggers a + # distributed broadcast of "_interrupted" flag from rank 0 to all other + # ranks, to avoid 
performance overheads it's best to store the result in + # a regular local variable + interrupted = self.interrupted + if interrupted: + print("Received SIGTERM, exiting") + monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer) + self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) + sys.exit(0) + + def release(self): + if self.released: + return False + + signal.signal(self.sig, self.original_handler) + self.released = True + return True From a80c4eb4f4e7a88b5645b5833814d02027a944c2 Mon Sep 17 00:00:00 2001 From: Abhishree Date: Thu, 16 Mar 2023 03:35:25 +0000 Subject: [PATCH 02/11] Adding the following modifications 1) Rename nemo/collections/common/callbacks/nemomodelcheckpoint.py to nemo/utils/callbacks/nemo_model_checkpoint.py 2) Rename nemo/utils/preemption_callback.py to nemo/utils/callbacks/preemption.py 3) Add docstrings, headers, logging and check for torch distributed Signed-off-by: Abhishree --- nemo/collections/common/callbacks/__init__.py | 1 - nemo/utils/callbacks/__init__.py | 16 ++++++++ .../callbacks/nemo_model_checkpoint.py} | 14 +++++++ .../preemption.py} | 38 ++++++++++++++++--- nemo/utils/exp_manager.py | 5 +-- 5 files changed, 64 insertions(+), 10 deletions(-) create mode 100644 nemo/utils/callbacks/__init__.py rename nemo/{collections/common/callbacks/nemomodelcheckpoint.py => utils/callbacks/nemo_model_checkpoint.py} (94%) rename nemo/utils/{preemption_callback.py => callbacks/preemption.py} (56%) diff --git a/nemo/collections/common/callbacks/__init__.py b/nemo/collections/common/callbacks/__init__.py index 13d956a937ab..0cf495d94696 100644 --- a/nemo/collections/common/callbacks/__init__.py +++ b/nemo/collections/common/callbacks/__init__.py @@ -14,4 +14,3 @@ from nemo.collections.common.callbacks.callbacks import LogEpochTimeCallback from nemo.collections.common.callbacks.ema import EMA -from nemo.collections.common.callbacks.nemomodelcheckpoint import NeMoModelCheckpoint diff --git a/nemo/utils/callbacks/__init__.py b/nemo/utils/callbacks/__init__.py new file mode 100644 index 000000000000..011e72a01fdf --- /dev/null +++ b/nemo/utils/callbacks/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.utils.callbacks.preemption import PreemptionCallback +from nemo.utils.callbacks.nemo_model_checkpoint import NeMoModelCheckpoint \ No newline at end of file diff --git a/nemo/collections/common/callbacks/nemomodelcheckpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py similarity index 94% rename from nemo/collections/common/callbacks/nemomodelcheckpoint.py rename to nemo/utils/callbacks/nemo_model_checkpoint.py index fa0151422e3c..5b6dc78eda2b 100644 --- a/nemo/collections/common/callbacks/nemomodelcheckpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -1,3 +1,17 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import re from copy import deepcopy diff --git a/nemo/utils/preemption_callback.py b/nemo/utils/callbacks/preemption.py similarity index 56% rename from nemo/utils/preemption_callback.py rename to nemo/utils/callbacks/preemption.py index 924728b3fcd7..6b903771410d 100644 --- a/nemo/utils/preemption_callback.py +++ b/nemo/utils/callbacks/preemption.py @@ -1,11 +1,31 @@ -import pytorch_lightning as pl -from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from nemo.collections.common.callbacks.nemomodelcheckpoint import NeMoModelCheckpoint +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import signal import torch import sys +from pytorch_lightning.callbacks import Callback +from nemo.utils import logging class PreemptionCallback(Callback): + """ + PreemptionCallback class creates a callback that checks for preemption during training at the end of every step. + Upon preemption the callback provides function to gracefully exit the training immediately and also saves the state of training (to be able to start from the + same step without wasting any compute while resuming the next time). + + PreemptionCallback is always enabled. 
+ """ def __init__(self, device, checkpoint_callback, sig=signal.SIGTERM): self.sig = sig @@ -24,6 +44,11 @@ def on_train_start(self, trainer, pl_module): self.released = False self.original_handler = signal.getsignal(self.sig) + if pl_module.device.type == 'cuda': + assert torch.distributed.is_available() and torch.distributed.is_initialized(), "Preemption requires torch distributed to be initialized" + else: + logging.info("Preemption is supported only on GPUs") + def master_handler(signum, frame): self.release() self._interrupted = True @@ -48,13 +73,14 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) # distributed broadcast of "_interrupted" flag from rank 0 to all other # ranks, to avoid performance overheads it's best to store the result in # a regular local variable - interrupted = self.interrupted + #interrupted = self.interrupted(pl_module.device) + interrupted = self.interrupted() if interrupted: - print("Received SIGTERM, exiting") + logging.info("Received SIGTERM, exiting") monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer) self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) sys.exit(0) - + def release(self): if self.released: return False diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index bfedb712014a..d8295a94d3ed 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -39,7 +39,7 @@ from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.utilities import rank_zero_info -from nemo.collections.common.callbacks import EMA, NeMoModelCheckpoint +from nemo.collections.common.callbacks import EMA from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION from nemo.utils import logging, timers from nemo.utils.app_state import AppState @@ -49,7 +49,7 @@ from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank -from nemo.utils.preemption_callback import PreemptionCallback +from nemo.utils.callbacks import PreemptionCallback, NeMoModelCheckpoint class NotFoundError(NeMoBaseException): @@ -891,7 +891,6 @@ def configure_checkpointing( if 'mp_rank' in checkpoint_callback.last_model_path or 'tp_rank' in checkpoint_callback.last_model_path: checkpoint_callback.last_model_path = uninject_model_parallel_rank(checkpoint_callback.last_model_path) trainer.callbacks.append(checkpoint_callback) - # TODO create support_preemption arg, if True do the below preemption_callback = PreemptionCallback(torch.device('cuda'), checkpoint_callback) trainer.callbacks.append(preemption_callback) From 604b8c3c5d6320b4eab2eff2e0ba5951073be56a Mon Sep 17 00:00:00 2001 From: Abhishree Date: Thu, 16 Mar 2023 19:53:45 +0000 Subject: [PATCH 03/11] Minor edit in preemption.py Signed-off-by: Abhishree --- nemo/utils/callbacks/preemption.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index 6b903771410d..a25e70f1c85d 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -73,8 +73,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) # distributed broadcast of "_interrupted" flag from rank 0 to all other # ranks, to avoid performance overheads it's best to store the result in # a regular 
local variable - #interrupted = self.interrupted(pl_module.device) - interrupted = self.interrupted() + interrupted = self.interrupted if interrupted: logging.info("Received SIGTERM, exiting") monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer) From e418e165bfe1a5e26302b899cb8c41d9a57e507a Mon Sep 17 00:00:00 2001 From: Abhishree Date: Thu, 30 Mar 2023 16:21:26 +0000 Subject: [PATCH 04/11] Removing unused imports Signed-off-by: Abhishree --- nemo/utils/exp_manager.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index d8295a94d3ed..6abb63135ecf 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -14,17 +14,15 @@ import glob import os -import re import subprocess import sys import time import warnings -from copy import deepcopy from dataclasses import dataclass from datetime import timedelta from pathlib import Path from shutil import copy, move -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import pytorch_lightning import torch @@ -37,7 +35,6 @@ from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger, WandbLogger from pytorch_lightning.loops import TrainingEpochLoop from pytorch_lightning.strategies.ddp import DDPStrategy -from pytorch_lightning.utilities import rank_zero_info from nemo.collections.common.callbacks import EMA from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION @@ -48,7 +45,7 @@ from nemo.utils.get_rank import is_global_rank_zero from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams -from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank +from nemo.utils.model_utils import uninject_model_parallel_rank from nemo.utils.callbacks import PreemptionCallback, NeMoModelCheckpoint From e632930b18e26bff8bb02a98667c3e2118f00a61 Mon Sep 17 00:00:00 2001 From: Abhishree Date: Thu, 30 Mar 2023 17:08:56 +0000 Subject: [PATCH 05/11] Remove device arg from PreemptionCallback class Signed-off-by: Abhishree --- nemo/utils/callbacks/preemption.py | 5 ++--- nemo/utils/exp_manager.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index a25e70f1c85d..b51d58f033b4 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -27,14 +27,13 @@ class PreemptionCallback(Callback): PreemptionCallback is always enabled. 
""" - def __init__(self, device, checkpoint_callback, sig=signal.SIGTERM): + def __init__(self, checkpoint_callback, sig=signal.SIGTERM): self.sig = sig - self.device = device self.checkpoint_callback = checkpoint_callback @property def interrupted(self): - interrupted = torch.tensor(self._interrupted).int().to(self.device) + interrupted = torch.tensor(self._interrupted).int().to(torch.device('cuda')) torch.distributed.broadcast(interrupted, 0) interrupted = bool(interrupted.item()) return interrupted diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 6abb63135ecf..35c99537fb2d 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -888,7 +888,7 @@ def configure_checkpointing( if 'mp_rank' in checkpoint_callback.last_model_path or 'tp_rank' in checkpoint_callback.last_model_path: checkpoint_callback.last_model_path = uninject_model_parallel_rank(checkpoint_callback.last_model_path) trainer.callbacks.append(checkpoint_callback) - preemption_callback = PreemptionCallback(torch.device('cuda'), checkpoint_callback) + preemption_callback = PreemptionCallback(checkpoint_callback) trainer.callbacks.append(preemption_callback) def check_slurm(trainer): From 69a0c97a90ff6885266acdaf21b627ee18fa2ab8 Mon Sep 17 00:00:00 2001 From: Abhishree Date: Thu, 30 Mar 2023 21:55:10 +0000 Subject: [PATCH 06/11] Add more details in the NeMoModelCheckpointdocstring Signed-off-by: Abhishree --- nemo/utils/callbacks/nemo_model_checkpoint.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index 5b6dc78eda2b..50fd21b16ca8 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -30,7 +30,10 @@ from nemo.utils.get_rank import is_global_rank_zero class NeMoModelCheckpoint(ModelCheckpoint): - """ Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end + """ Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. + Extends Lightning's on_save_checkpoint func to save the .nemo file. Saves the .nemo file based + on the best checkpoint saved (according to the monitor value). + Also contains func to save the EMA copy of the model. """ def __init__( From 50cd904d008bfc7d96c6b0333bf947a1d4892da4 Mon Sep 17 00:00:00 2001 From: Abhishree Date: Mon, 3 Apr 2023 19:16:05 +0000 Subject: [PATCH 07/11] Add the following modifications: 1) Add boolean flag for createing preemption callback 2) Make sig arg in PreemptionCallback as None 3) Other minor modifications and code comments Signed-off-by: Abhishree --- nemo/utils/callbacks/preemption.py | 27 +++++++++++++++++++-------- nemo/utils/exp_manager.py | 14 ++++++++++---- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index b51d58f033b4..b81df7079b56 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -21,37 +21,48 @@ class PreemptionCallback(Callback): """ PreemptionCallback class creates a callback that checks for preemption during training at the end of every step. - Upon preemption the callback provides function to gracefully exit the training immediately and also saves the state of training (to be able to start from the - same step without wasting any compute while resuming the next time). 
+ Upon preemption the callback provides a function to gracefully exit the training immediately and also saves the state of training + (to be able to start from the same step without wasting any compute while resuming the next time). - PreemptionCallback is always enabled. + PreemptionCallback is always enabled by default via the arg create_preemption_callback under ExpManagerConfig. To disable please pass + create_preemption_callback: False in your config file. """ - def __init__(self, checkpoint_callback, sig=signal.SIGTERM): + def __init__(self, checkpoint_callback, sig=None): self.sig = sig + if self.sig is None: + self.sig = signal.SIGTERM self.checkpoint_callback = checkpoint_callback @property def interrupted(self): - interrupted = torch.tensor(self._interrupted).int().to(torch.device('cuda')) + interrupted = torch.tensor(self._interrupted, device=torch.cuda.current_device(), dtype=torch.int32) torch.distributed.broadcast(interrupted, 0) interrupted = bool(interrupted.item()) return interrupted def on_train_start(self, trainer, pl_module): + """ + Defines custom handlers at the beginning of training to be executed when the + preemption signal is received. + """ + # Bool var that's initialized to false and made True upon receving the preemption signal self._interrupted = False self.released = False self.original_handler = signal.getsignal(self.sig) + # Check if torch distributed is initialised, as its needed for broadcasting the preemption signal to all the ranks if pl_module.device.type == 'cuda': assert torch.distributed.is_available() and torch.distributed.is_initialized(), "Preemption requires torch distributed to be initialized" else: logging.info("Preemption is supported only on GPUs") + # Master handler executed only by rank 0 when the preemption siganal is received, to avoid deadlock conditions def master_handler(signum, frame): self.release() self._interrupted = True - + + # Handler executed by the non zero ranks def ignoring_handler(signum, frame): self.release() @@ -67,8 +78,8 @@ def on_train_end(self, trainer, pl_module): self.release() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int): - # check if the job was preempted - # NOTE: "timeout_handler.interrupted" is a property which triggers a + # check if the job was preempted at the end of every training step/iteration + # NOTE: "self.interrupted" is a property which triggers a # distributed broadcast of "_interrupted" flag from rank 0 to all other # ranks, to avoid performance overheads it's best to store the result in # a regular local variable diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 35c99537fb2d..0bfffac52b1e 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -155,6 +155,7 @@ class ExpManagerConfig: checkpoint_callback_params: Optional[CallbackParams] = CallbackParams() create_early_stopping_callback: Optional[bool] = False early_stopping_callback_params: Optional[EarlyStoppingParams] = EarlyStoppingParams() + create_preemption_callback: Optional[bool] = True # Additional exp_manager arguments files_to_copy: Optional[List[str]] = None # logs timing of train/val/test steps @@ -280,6 +281,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo Defaults to True. - create_early_stopping_callback (bool): Flag to decide if early stopping should be used to stop training. Default is False. See EarlyStoppingParams dataclass above. 
+ - create_preemption_callback (bool): Flag to decide whether to enable preemption callback to save checkpoints and exit training + immediately upon preemption. Default is True. - files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which copies no files. - log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False. @@ -439,7 +442,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo if cfg.create_checkpoint_callback: configure_checkpointing( - trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params + trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params, cfg.create_preemption_callback ) if cfg.disable_validation_on_resume: @@ -830,7 +833,7 @@ def configure_loggers( def configure_checkpointing( - trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', + trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', create_preemption_callback: bool ): """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint callback @@ -888,8 +891,11 @@ def configure_checkpointing( if 'mp_rank' in checkpoint_callback.last_model_path or 'tp_rank' in checkpoint_callback.last_model_path: checkpoint_callback.last_model_path = uninject_model_parallel_rank(checkpoint_callback.last_model_path) trainer.callbacks.append(checkpoint_callback) - preemption_callback = PreemptionCallback(checkpoint_callback) - trainer.callbacks.append(preemption_callback) + if create_preemption_callback: + ## By default PreemptionCallback handles SIGTERM. To handle other signals pass the signal in the call as below: + ## PreemptionCallback(checkpoint_callback, signal.SIGCHLD) + preemption_callback = PreemptionCallback(checkpoint_callback) + trainer.callbacks.append(preemption_callback) def check_slurm(trainer): try: From 3b56d55a821a5690b03f11f4af44ae6f88613c7b Mon Sep 17 00:00:00 2001 From: Abhishree Date: Wed, 5 Apr 2023 18:26:21 +0000 Subject: [PATCH 08/11] Modify torch cuda and distributed available checks to skip preemption if unavailable Signed-off-by: Abhishree --- nemo/utils/callbacks/preemption.py | 18 ++++++++++-------- nemo/utils/exp_manager.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index b81df7079b56..eb03be87805f 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -21,7 +21,7 @@ class PreemptionCallback(Callback): """ PreemptionCallback class creates a callback that checks for preemption during training at the end of every step. - Upon preemption the callback provides a function to gracefully exit the training immediately and also saves the state of training + Upon preemption the callback provides a function to gracefully exit the training immediately and also saves the current state in a checkpoint as *last.ckpt. (to be able to start from the same step without wasting any compute while resuming the next time). PreemptionCallback is always enabled by default via the arg create_preemption_callback under ExpManagerConfig. To disable please pass @@ -46,17 +46,19 @@ def on_train_start(self, trainer, pl_module): Defines custom handlers at the beginning of training to be executed when the preemption signal is received. 
""" + + # Check if torch distributed is initialised, as its needed for broadcasting the preemption signal to all the ranks + if not (torch.distributed.is_available() and torch.distributed.is_initialized()): + logging.info("Preemption requires torch distributed to be initialized, disabling preemption") + #Remove the callback from the callbacks list + trainer.callbacks.remove(self) + return + # Bool var that's initialized to false and made True upon receving the preemption signal self._interrupted = False self.released = False self.original_handler = signal.getsignal(self.sig) - # Check if torch distributed is initialised, as its needed for broadcasting the preemption signal to all the ranks - if pl_module.device.type == 'cuda': - assert torch.distributed.is_available() and torch.distributed.is_initialized(), "Preemption requires torch distributed to be initialized" - else: - logging.info("Preemption is supported only on GPUs") - # Master handler executed only by rank 0 when the preemption siganal is received, to avoid deadlock conditions def master_handler(signum, frame): self.release() @@ -85,7 +87,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) # a regular local variable interrupted = self.interrupted if interrupted: - logging.info("Received SIGTERM, exiting") + logging.info("Received SIGTERM, saving checkpoint and exiting") monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer) self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) sys.exit(0) diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 0bfffac52b1e..2efa8c7eb71d 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -892,10 +892,15 @@ def configure_checkpointing( checkpoint_callback.last_model_path = uninject_model_parallel_rank(checkpoint_callback.last_model_path) trainer.callbacks.append(checkpoint_callback) if create_preemption_callback: - ## By default PreemptionCallback handles SIGTERM. To handle other signals pass the signal in the call as below: - ## PreemptionCallback(checkpoint_callback, signal.SIGCHLD) - preemption_callback = PreemptionCallback(checkpoint_callback) - trainer.callbacks.append(preemption_callback) + # Check if cuda is avialable as preemption is supported only on GPUs + if torch.cuda.is_available(): + ## By default PreemptionCallback handles SIGTERM. 
To handle other signals pass the signal in the call as below: + ## PreemptionCallback(checkpoint_callback, signal.SIGCHLD) + preemption_callback = PreemptionCallback(checkpoint_callback) + trainer.callbacks.append(preemption_callback) + else: + logging.info("Preemption is supported only on GPUs, disabling preemption") + def check_slurm(trainer): try: From 772239d382655d9dc73c4a76e1d70f64f59b9e62 Mon Sep 17 00:00:00 2001 From: Abhishree Date: Thu, 6 Apr 2023 06:21:58 +0000 Subject: [PATCH 09/11] Add preemption_enabled flag to preemption.py Signed-off-by: Abhishree --- nemo/utils/callbacks/preemption.py | 69 +++++++++++++++--------------- 1 file changed, 35 insertions(+), 34 deletions(-) diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index eb03be87805f..07ab893c7b2d 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -33,6 +33,7 @@ def __init__(self, checkpoint_callback, sig=None): if self.sig is None: self.sig = signal.SIGTERM self.checkpoint_callback = checkpoint_callback + self.preemption_enabled = False @property def interrupted(self): @@ -46,51 +47,51 @@ def on_train_start(self, trainer, pl_module): Defines custom handlers at the beginning of training to be executed when the preemption signal is received. """ - + # Check if torch distributed is initialised, as its needed for broadcasting the preemption signal to all the ranks if not (torch.distributed.is_available() and torch.distributed.is_initialized()): logging.info("Preemption requires torch distributed to be initialized, disabling preemption") - #Remove the callback from the callbacks list - trainer.callbacks.remove(self) - return - - # Bool var that's initialized to false and made True upon receving the preemption signal - self._interrupted = False - self.released = False - self.original_handler = signal.getsignal(self.sig) + else: + self.preemption_enabled = True + # Bool var that's initialized to false and made True upon receving the preemption signal + self._interrupted = False + self.released = False + self.original_handler = signal.getsignal(self.sig) - # Master handler executed only by rank 0 when the preemption siganal is received, to avoid deadlock conditions - def master_handler(signum, frame): - self.release() - self._interrupted = True - - # Handler executed by the non zero ranks - def ignoring_handler(signum, frame): - self.release() + # Master handler executed only by rank 0 when the preemption siganal is received, to avoid deadlock conditions + def master_handler(signum, frame): + self.release() + self._interrupted = True + + # Handler executed by the non zero ranks + def ignoring_handler(signum, frame): + self.release() - self.private_rank = torch.distributed.get_rank() - if self.private_rank == 0: - signal.signal(self.sig, master_handler) - else: - signal.signal(self.sig, ignoring_handler) + self.private_rank = torch.distributed.get_rank() + if self.private_rank == 0: + signal.signal(self.sig, master_handler) + else: + signal.signal(self.sig, ignoring_handler) return self def on_train_end(self, trainer, pl_module): - self.release() + if self.preemption_enabled: + self.release() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int): - # check if the job was preempted at the end of every training step/iteration - # NOTE: "self.interrupted" is a property which triggers a - # distributed broadcast of "_interrupted" flag from rank 0 to all other - # ranks, to avoid performance overheads it's best to store the result 
in - # a regular local variable - interrupted = self.interrupted - if interrupted: - logging.info("Received SIGTERM, saving checkpoint and exiting") - monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer) - self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) - sys.exit(0) + if self.preemption_enabled: + # check if the job was preempted at the end of every training step/iteration + # NOTE: "self.interrupted" is a property which triggers a + # distributed broadcast of "_interrupted" flag from rank 0 to all other + # ranks, to avoid performance overheads it's best to store the result in + # a regular local variable + interrupted = self.interrupted + if interrupted: + logging.info("Received SIGTERM, saving checkpoint and exiting") + monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer) + self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) + sys.exit(0) def release(self): if self.released: From 0a329da3ece0722d4c230a81fbf78b89a447c478 Mon Sep 17 00:00:00 2001 From: Abhishree Date: Thu, 6 Apr 2023 19:11:01 +0000 Subject: [PATCH 10/11] Update nemo_model_checkpoint.py with the latest NemoModelCheckpoint class from exp_manager.py Signed-off-by: Abhishree --- nemo/utils/callbacks/nemo_model_checkpoint.py | 93 +++++++++++++------ nemo/utils/callbacks/preemption.py | 2 +- 2 files changed, 64 insertions(+), 31 deletions(-) diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index 50fd21b16ca8..b00b52bf648b 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -16,7 +16,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import pytorch_lightning import torch @@ -87,7 +87,7 @@ def nemo_topk_check_previous_run(self): self.best_model_score = None self.best_model_path = "" - checkpoints = list(Path(self.dirpath).rglob("*.ckpt")) + checkpoints = list(path for path in self._saved_checkpoint_paths if not self._is_ema_filepath(path)) for checkpoint in checkpoints: if 'mp_rank' in str(checkpoint) or 'tp_rank' in str(checkpoint): checkpoint = uninject_model_parallel_rank(checkpoint) @@ -114,10 +114,16 @@ def nemo_topk_check_previous_run(self): else: models_to_delete = len(best_k_models) - self.save_top_k logging.debug(f'Number of models to delete: {models_to_delete}') + + # If EMA enabled, delete the additional EMA weights + ema_enabled = self._has_ema_ckpts(self._saved_checkpoint_paths) + for _ in range(models_to_delete): model = best_k_models.pop(-1) self.best_k_models.pop(model) self._del_model_without_trainer(model) + if ema_enabled and self._fs.exists(self._ema_format_filepath(model)): + self._del_model_without_trainer(self._ema_format_filepath(model)) logging.debug(f"Removed checkpoint: {model}") self.kth_best_model_path = best_k_models[-1] @@ -125,38 +131,47 @@ def nemo_topk_check_previous_run(self): self.best_model_score = self.best_k_models[self.best_model_path] def on_save_checkpoint(self, trainer, pl_module, checkpoint): - # output = None output = super().on_save_checkpoint(trainer, pl_module, checkpoint) if not self.always_save_nemo: return output + # Load the best model and then re-save it + app_state = AppState() + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + logging.warning(f'always_save_nemo will slow down training for model_parallel > 1.') + # 
since we are creating tarfile artifacts we need to update .nemo path + app_state.model_restore_path = os.path.abspath( + os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix)) + ) + if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: + maybe_injected_best_model_path = inject_model_parallel_rank(self.best_model_path) else: - # Load the best model and then re-save it - app_state = AppState() - if app_state.model_parallel_size is not None and app_state.model_parallel_size > 1: - raise ValueError(f'always_save_nemo is not implemented for model parallel models.') - # since we are creating tarfile artifacts we need to update .nemo path - app_state.model_restore_path = os.path.abspath( - os.path.expanduser(os.path.join(self.dirpath, self.prefix + self.postfix)) - ) - if self.save_best_model: - if not os.path.exists(self.best_model_path): - return output - - if self.best_model_path == self.previous_best_path: - return output - - self.previous_model_path = self.best_model_path - old_state_dict = deepcopy(pl_module.state_dict()) - checkpoint = torch.load(self.best_model_path, map_location='cpu') - if 'state_dict' in checkpoint: - checkpoint = checkpoint['state_dict'] - # get a new instanace of the model - pl_module.load_state_dict(checkpoint, strict=True) - pl_module.save_to(save_path=app_state.model_restore_path) - pl_module.load_state_dict(old_state_dict, strict=True) - else: - pl_module.save_to(save_path=app_state.model_restore_path) - return output + maybe_injected_best_model_path = self.best_model_path + + if self.save_best_model: + if not os.path.exists(maybe_injected_best_model_path): + return + + if self.best_model_path == self.previous_best_path: + return output + + self.previous_model_path = self.best_model_path + old_state_dict = deepcopy(pl_module.state_dict()) + checkpoint = torch.load(maybe_injected_best_model_path, map_location='cpu') + if 'state_dict' in checkpoint: + checkpoint = checkpoint['state_dict'] + # get a new instanace of the model + pl_module.load_state_dict(checkpoint, strict=True) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + pl_module.save_to(save_path=app_state.model_restore_path) + logging.info(f"New best .nemo model saved to: {app_state.model_restore_path}") + pl_module.load_state_dict(old_state_dict, strict=True) + else: + if torch.distributed.is_initialized(): + torch.distributed.barrier() + pl_module.save_to(save_path=app_state.model_restore_path) + logging.info(f"New .nemo model saved to: {app_state.model_restore_path}") + return output def on_train_end(self, trainer, pl_module): if trainer.fast_dev_run: @@ -227,5 +242,23 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) else: super()._save_checkpoint(trainer, filepath) + def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str) -> None: + super()._remove_checkpoint(trainer, filepath) + ema_callback = self._ema_callback(trainer) + if ema_callback is not None: + # remove EMA copy of the state dict as well. 
+ filepath = self._ema_format_filepath(filepath) + super()._remove_checkpoint(trainer, filepath) + def _ema_format_filepath(self, filepath: str) -> str: return filepath.replace(self.FILE_EXTENSION, f'-EMA{self.FILE_EXTENSION}') + + def _has_ema_ckpts(self, checkpoints: Iterable[Path]) -> bool: + return any(self._is_ema_filepath(checkpoint_path) for checkpoint_path in checkpoints) + + def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool: + return str(filepath).endswith(f'-EMA{self.FILE_EXTENSION}') + + @property + def _saved_checkpoint_paths(self) -> Iterable[Path]: + return Path(self.dirpath).rglob("*.ckpt") \ No newline at end of file diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index 07ab893c7b2d..d241668ea09c 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -47,7 +47,7 @@ def on_train_start(self, trainer, pl_module): Defines custom handlers at the beginning of training to be executed when the preemption signal is received. """ - + # Check if torch distributed is initialised, as its needed for broadcasting the preemption signal to all the ranks if not (torch.distributed.is_available() and torch.distributed.is_initialized()): logging.info("Preemption requires torch distributed to be initialized, disabling preemption") From ae9217c5fdb38c1a558b87851d3b4888cc0130c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Apr 2023 19:13:48 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nemo/utils/callbacks/__init__.py | 2 +- nemo/utils/callbacks/nemo_model_checkpoint.py | 5 +++-- nemo/utils/callbacks/preemption.py | 11 +++++++---- nemo/utils/exp_manager.py | 16 +++++++++++++--- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/nemo/utils/callbacks/__init__.py b/nemo/utils/callbacks/__init__.py index 011e72a01fdf..6623657a2dc2 100644 --- a/nemo/utils/callbacks/__init__.py +++ b/nemo/utils/callbacks/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.utils.callbacks.nemo_model_checkpoint import NeMoModelCheckpoint from nemo.utils.callbacks.preemption import PreemptionCallback -from nemo.utils.callbacks.nemo_model_checkpoint import NeMoModelCheckpoint \ No newline at end of file diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index b00b52bf648b..a9292b7bb765 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -26,8 +26,9 @@ from nemo.collections.common.callbacks import EMA from nemo.utils import logging from nemo.utils.app_state import AppState -from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank from nemo.utils.get_rank import is_global_rank_zero +from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank + class NeMoModelCheckpoint(ModelCheckpoint): """ Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end. 
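
The EMA handling added above follows a simple filename convention: the EMA copy of a checkpoint gets "-EMA" inserted before the file extension, and the top-k pruning pass only ranks the regular checkpoints, deleting the EMA twin whenever its regular counterpart is dropped. A small standalone illustration of that convention (the filenames and helper names below are illustrative, not NeMo API):

from pathlib import Path

FILE_EXTENSION = ".ckpt"  # ModelCheckpoint's default extension

def ema_twin(filepath: str) -> str:
    # e.g. "step=1000.ckpt" -> "step=1000-EMA.ckpt"
    return filepath.replace(FILE_EXTENSION, f"-EMA{FILE_EXTENSION}")

def is_ema(filepath) -> bool:
    return str(filepath).endswith(f"-EMA{FILE_EXTENSION}")

# Only regular checkpoints take part in top-k ranking; when one is removed,
# its EMA twin is removed as well so the two files never drift apart.
saved = ["step=1000.ckpt", "step=1000-EMA.ckpt", "step=2000.ckpt", "step=2000-EMA.ckpt"]
ranked = [Path(p) for p in saved if not is_ema(p)]
assert ema_twin("step=1000.ckpt") == "step=1000-EMA.ckpt"
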
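
The preemption flow introduced by this series works as follows: rank 0 installs a SIGTERM handler at the start of training, the "interrupted" property broadcasts rank 0's flag to every rank, and at the end of a training step all ranks save the last checkpoint and exit together. The following is a condensed, self-contained sketch of that pattern under the assumption of an initialized torch.distributed process group; the class and method names are illustrative stand-ins, not the NeMo PreemptionCallback API.

import signal
import sys

import torch


class PreemptionSketch:
    def __init__(self, device):
        # NCCL needs a CUDA tensor for the broadcast, hence the explicit device.
        self.device = device
        self._interrupted = False

    def install_handlers(self):
        if torch.distributed.get_rank() == 0:
            # Only rank 0 reacts to SIGTERM; it records the signal in a local flag.
            def handler(signum, frame):
                self._interrupted = True
            signal.signal(signal.SIGTERM, handler)
        else:
            # Non-zero ranks ignore the signal and learn about it via broadcast.
            signal.signal(signal.SIGTERM, signal.SIG_IGN)

    @property
    def interrupted(self) -> bool:
        # Broadcasting rank 0's flag keeps every rank on the same decision. This is
        # a collective call, so read it once per step and reuse the local result.
        flag = torch.tensor(int(self._interrupted), device=self.device)
        torch.distributed.broadcast(flag, src=0)
        return bool(flag.item())

    def end_of_step(self, save_last_checkpoint):
        if self.interrupted:
            save_last_checkpoint()
            sys.exit(0)
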
@@ -261,4 +262,4 @@ def _is_ema_filepath(self, filepath: Union[Path, str]) -> bool: @property def _saved_checkpoint_paths(self) -> Iterable[Path]: - return Path(self.dirpath).rglob("*.ckpt") \ No newline at end of file + return Path(self.dirpath).rglob("*.ckpt") diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index d241668ea09c..e9b5f95022f3 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import signal -import torch +import signal import sys + +import torch from pytorch_lightning.callbacks import Callback + from nemo.utils import logging + class PreemptionCallback(Callback): """ PreemptionCallback class creates a callback that checks for preemption during training at the end of every step. @@ -62,7 +65,7 @@ def on_train_start(self, trainer, pl_module): def master_handler(signum, frame): self.release() self._interrupted = True - + # Handler executed by the non zero ranks def ignoring_handler(signum, frame): self.release() @@ -92,7 +95,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer) self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) sys.exit(0) - + def release(self): if self.released: return False diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index 2efa8c7eb71d..af3b25eb73bd 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -40,13 +40,13 @@ from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION from nemo.utils import logging, timers from nemo.utils.app_state import AppState +from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback from nemo.utils.env_var_parsing import get_envbool from nemo.utils.exceptions import NeMoBaseException from nemo.utils.get_rank import is_global_rank_zero from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams from nemo.utils.model_utils import uninject_model_parallel_rank -from nemo.utils.callbacks import PreemptionCallback, NeMoModelCheckpoint class NotFoundError(NeMoBaseException): @@ -442,7 +442,12 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo if cfg.create_checkpoint_callback: configure_checkpointing( - trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params, cfg.create_preemption_callback + trainer, + log_dir, + checkpoint_name, + cfg.resume_if_exists, + cfg.checkpoint_callback_params, + cfg.create_preemption_callback, ) if cfg.disable_validation_on_resume: @@ -833,7 +838,12 @@ def configure_loggers( def configure_checkpointing( - trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', create_preemption_callback: bool + trainer: 'pytorch_lightning.Trainer', + log_dir: Path, + name: str, + resume: bool, + params: 'DictConfig', + create_preemption_callback: bool, ): """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint callback
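
With create_preemption_callback threaded through configure_checkpointing, exp_manager can register the two callbacks together. The sketch below shows the intended wiring in rough form; it assumes PreemptionCallback is constructed with a reference to the checkpoint callback (its constructor is not part of the hunks shown here), and the real argument handling in exp_manager.py may differ.

from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback


def attach_checkpointing(trainer, checkpoint_params: dict, create_preemption_callback: bool):
    # Build the NeMo checkpoint callback from the exp_manager params and register it.
    checkpoint_callback = NeMoModelCheckpoint(**checkpoint_params)
    trainer.callbacks.append(checkpoint_callback)
    if create_preemption_callback:
        # The preemption callback saves through the checkpoint callback when SIGTERM
        # arrives, so it is given a handle to it (assumed constructor argument).
        trainer.callbacks.append(PreemptionCallback(checkpoint_callback))
    return checkpoint_callback
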