Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files: browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 14, 2023
1 parent 4f2287b commit b50e092
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
3 changes: 2 additions & 1 deletion nemo/collections/common/callbacks/nemomodelcheckpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from nemo.collections.common.callbacks import EMA
from nemo.utils import logging
from nemo.utils.app_state import AppState
from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank
from nemo.utils.get_rank import is_global_rank_zero
from nemo.utils.model_utils import inject_model_parallel_rank, uninject_model_parallel_rank


class NeMoModelCheckpoint(ModelCheckpoint):
""" Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end
Expand Down
1 change: 1 addition & 0 deletions nemo/utils/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ def configure_checkpointing(
preemption_callback = PreemptionCallback(torch.device('cuda'), checkpoint_callback)
trainer.callbacks.append(preemption_callback)


def check_slurm(trainer):
try:
return trainer.accelerator_connector.is_slurm_managing_tasks
Expand Down
10 changes: 6 additions & 4 deletions nemo/utils/preemption_callback.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import signal
import sys

import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from nemo.collections.common.callbacks.nemomodelcheckpoint import NeMoModelCheckpoint
import signal
import torch
import sys

class PreemptionCallback(Callback):

class PreemptionCallback(Callback):
def __init__(self, device, checkpoint_callback, sig=signal.SIGTERM):
self.sig = sig
self.device = device
Expand Down

0 comments on commit b50e092

Please sign in to comment.