
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 5, 2023
1 parent c04f57a commit dbb20f4
Showing 4 changed files with 24 additions and 10 deletions.
2 changes: 1 addition & 1 deletion nemo/utils/callbacks/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.utils.callbacks.nemo_model_checkpoint import NeMoModelCheckpoint
from nemo.utils.callbacks.preemption import PreemptionCallback
from nemo.utils.callbacks.nemo_model_checkpoint import NeMoModelCheckpoint
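The auto-fix here only reorders the two re-exports alphabetically; both callbacks remain importable from the package namespace, which is how exp_manager.py further down consumes them. A minimal usage sketch (assuming a NeMo install that includes this commit):

from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback
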
3 changes: 2 additions & 1 deletion nemo/utils/callbacks/nemo_model_checkpoint.py
@@ -26,8 +26,9 @@
from nemo.collections.common.callbacks import EMA
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank


class NeMoModelCheckpoint(ModelCheckpoint):
""" Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.
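For context, the docstring above describes NeMoModelCheckpoint as a thin wrapper over Lightning's ModelCheckpoint that forces one more save when training ends. A minimal sketch of that idea only, not NeMo's actual implementation: the class name is invented, and the private _monitor_candidates / _save_last_checkpoint helpers are the same ones the preemption diff below calls on the checkpoint callback.

from pytorch_lightning.callbacks import ModelCheckpoint


class ForceSaveOnTrainEnd(ModelCheckpoint):  # hypothetical name, not NeMo's class
    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        # Gather the metrics ModelCheckpoint already tracks and write a final "last"
        # checkpoint so the run always leaves a usable checkpoint behind.
        monitor_candidates = self._monitor_candidates(trainer)
        self._save_last_checkpoint(trainer, monitor_candidates)
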
13 changes: 8 additions & 5 deletions nemo/utils/callbacks/preemption.py
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import signal
import torch
import signal
import sys

import torch
from pytorch_lightning.callbacks import Callback

from nemo.utils import logging


class PreemptionCallback(Callback):
"""
PreemptionCallback class creates a callback that checks for preemption during training at the end of every step.
@@ -50,7 +53,7 @@ def on_train_start(self, trainer, pl_module):
# Check if torch distributed is initialised, as its needed for broadcasting the preemption signal to all the ranks
if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
logging.info("Preemption requires torch distributed to be initialized, disabling preemption")
#Remove the callback from the callbacks list
# Remove the callback from the callbacks list
trainer.callbacks.remove(self)
return

@@ -63,7 +66,7 @@ def on_train_start(self, trainer, pl_module):
def master_handler(signum, frame):
self.release()
self._interrupted = True

# Handler executed by the non zero ranks
def ignoring_handler(signum, frame):
self.release()
@@ -91,7 +94,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int)
monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer)
self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
sys.exit(0)

def release(self):
if self.released:
return False
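The hunks above only touch import order, spacing, and comment style, but the surrounding code outlines the whole preemption flow: rank 0 installs a signal handler, the resulting flag is shared with every rank at the end of a training batch, a last checkpoint is written, and the job exits. Below is a self-contained sketch of that pattern; the names are illustrative rather than NeMo's API, and the real callback also coordinates with the checkpoint callback before exiting.

import signal
import sys

import torch
from pytorch_lightning.callbacks import Callback


class PreemptionSketch(Callback):  # hypothetical name, not NeMo's PreemptionCallback
    def __init__(self, sig=signal.SIGTERM):
        self.sig = sig
        self._interrupted = False

    def on_train_start(self, trainer, pl_module):
        # Broadcasting the preemption flag needs torch.distributed; without it, bow out.
        if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
            trainer.callbacks.remove(self)
            return

        def master_handler(signum, frame):  # rank 0 records the signal
            self._interrupted = True

        def ignoring_handler(signum, frame):  # other ranks ignore it
            pass

        is_master = torch.distributed.get_rank() == 0
        signal.signal(self.sig, master_handler if is_master else ignoring_handler)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Share rank 0's flag so every rank stops at the same step.
        flag = torch.tensor(int(self._interrupted), device=pl_module.device)
        torch.distributed.broadcast(flag, src=0)
        if bool(flag.item()):
            # The real callback saves a last checkpoint here before exiting.
            sys.exit(0)
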
16 changes: 13 additions & 3 deletions nemo/utils/exp_manager.py
@@ -40,13 +40,13 @@
from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
from nemo.utils import logging, timers
from nemo.utils.app_state import AppState
from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback
from nemo.utils.env_var_parsing import get_envbool
from nemo.utils.exceptions import NeMoBaseException
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.lightning_logger_patch import add_filehandlers_to_pl_logger
from nemo.utils.loggers import ClearMLLogger, ClearMLParams, DLLogger, DLLoggerParams, MLFlowParams
from nemo.utils.model_utils import uninject_model_parallel_rank
from nemo.utils.callbacks import PreemptionCallback, NeMoModelCheckpoint


class NotFoundError(NeMoBaseException):
@@ -442,7 +442,12 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo

if cfg.create_checkpoint_callback:
configure_checkpointing(
trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params, cfg.create_preemption_callback
trainer,
log_dir,
checkpoint_name,
cfg.resume_if_exists,
cfg.checkpoint_callback_params,
cfg.create_preemption_callback,
)

if cfg.disable_validation_on_resume:
@@ -833,7 +838,12 @@ def configure_loggers(


def configure_checkpointing(
trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', create_preemption_callback: bool
trainer: 'pytorch_lightning.Trainer',
log_dir: Path,
name: str,
resume: bool,
params: 'DictConfig',
create_preemption_callback: bool,
):
""" Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
callback
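The exp_manager changes are purely cosmetic (the long call and signature are wrapped one argument per line), but they show where the new create_preemption_callback flag enters: configure_checkpointing receives it alongside the existing checkpoint options. A hedged usage sketch follows; only create_checkpoint_callback and create_preemption_callback are taken from this diff, and the remaining keys and values are assumptions.

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.utils.exp_manager import exp_manager

trainer = pl.Trainer(max_epochs=1)
cfg = OmegaConf.create(
    {
        "exp_dir": "./nemo_experiments",      # assumed key/value
        "create_checkpoint_callback": True,   # shown in this diff
        "create_preemption_callback": True,   # new flag threaded into configure_checkpointing
    }
)
exp_manager(trainer, cfg)
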
