Add support to track a user-specified list of signals for preemption
Signed-off-by: Abhishree <[email protected]>
athitten committed Apr 5, 2023
1 parent f3992c6 commit 4f3a280
Showing 2 changed files with 44 additions and 31 deletions.
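
The change lets the preemption callback watch any list of signals named in the config instead of only SIGTERM. As an illustrative sketch of the new argument (not part of the commit), a direct construction might look as follows; checkpoint_callback and trainer are assumed to already exist:

    from nemo.utils.callbacks.preemption import PreemptionCallback

    # Track SIGINT and SIGUSR1 instead of the default SIGTERM.
    preemption_callback = PreemptionCallback(checkpoint_callback, ['SIGINT', 'SIGUSR1'])
    trainer.callbacks.append(preemption_callback)

    # Passing None keeps the previous behaviour: only SIGTERM is tracked.
    default_callback = PreemptionCallback(checkpoint_callback, None)
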
64 changes: 38 additions & 26 deletions nemo/utils/callbacks/preemption.py
@@ -26,13 +26,20 @@ class PreemptionCallback(Callback):
PreemptionCallback is enabled by default via the arg create_preemption_callback under ExpManagerConfig. To disable it, pass
create_preemption_callback: False in your config file.
Args:
checkpoint_callback: NeMoModelCheckpoint object created in exp_manager.py
sig: A list of signals that need to be tracked. Preemption, along with saving the checkpoint, is carried out upon receiving any of the signals in this list.
It's a parameter defined in the exp_manager; if it's None, SIGTERM is tracked by default.
To handle specific signals, pass the signal names as a list in the config YAML file.
"""

def __init__(self, checkpoint_callback, sig=None):
self.sig = sig
if self.sig is None:
self.sig = signal.SIGTERM
def __init__(self, checkpoint_callback, sig):
if sig is None:
sig = ['SIGTERM']
self.sig = [getattr(signal, key) for key in sig]
self.checkpoint_callback = checkpoint_callback
self.preempted_signum = 1

@property
def interrupted(self):
@@ -46,36 +53,41 @@ def on_train_start(self, trainer, pl_module):
Defines custom handlers at the beginning of training to be executed when the
preemption signal is received.
"""
# Bool var that's initialized to False and set to True upon receiving the preemption signal
self._interrupted = False
self.released = False
self.original_handler = signal.getsignal(self.sig)

# Check if torch distributed is initialised, as it's needed for broadcasting the preemption signal to all the ranks
if pl_module.device.type == 'cuda':
assert torch.distributed.is_available() and torch.distributed.is_initialized(), "Preemption requires torch distributed to be initialized"
else:
logging.info("Preemption is supported only on GPUs")

# Master handler executed only by rank 0 when the preemption signal is received, to avoid deadlock conditions
def master_handler(signum, frame):
self.release()
self._interrupted = True

# Handler executed by the non-zero ranks
def ignoring_handler(signum, frame):
self.release()
# Loop through all signals and set up handlers
for sig in self.sig:
# Bool var that's initialized to False and set to True upon receiving the preemption signal
self._interrupted = False
self.released = False
self.original_handler = signal.getsignal(sig)

self.private_rank = torch.distributed.get_rank()
if self.private_rank == 0:
signal.signal(self.sig, master_handler)
else:
signal.signal(self.sig, ignoring_handler)
# Master handler executed only by rank 0 when the preemption signal is received, to avoid deadlock conditions
def master_handler(signum, frame):
self.release(sig)
self._interrupted = True
# Get the signum to print the signal name while logging
self.preempted_signum = signum

# Handler executed by the non-zero ranks
def ignoring_handler(signum, frame):
self.release(sig)

self.private_rank = torch.distributed.get_rank()
if self.private_rank == 0:
signal.signal(sig, master_handler)
else:
signal.signal(sig, ignoring_handler)

return self

def on_train_end(self, trainer, pl_module):
self.release()
for sig in self.sig:
self.release(sig)

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int):
# check if the job was preempted at the end of every training step/iteration
@@ -85,15 +97,15 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int)
# a regular local variable
interrupted = self.interrupted
if interrupted:
logging.info("Received SIGTERM, exiting")
logging.info(f"Received {signal.Signals(self.preempted_signum).name}, saving checkpoint and exiting")
monitor_candidates = self.checkpoint_callback._monitor_candidates(trainer)
self.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
sys.exit(0)

def release(self):
def release(self, sig):
if self.released:
return False

signal.signal(self.sig, self.original_handler)
signal.signal(sig, self.original_handler)
self.released = True
return True
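
The new code converts signal names from the config into signal-module constants, and converts the received signum back into a name for logging. A standalone sketch of those two conversions, using only the Python standard library:

    import signal

    names = ['SIGTERM', 'SIGINT']                 # as they would appear in the config
    sigs = [getattr(signal, n) for n in names]    # [Signals.SIGTERM, Signals.SIGINT]

    # Recover a readable name from the numeric value passed to a handler,
    # mirroring the logging call in on_train_batch_end.
    signum = sigs[0].value                        # 15 for SIGTERM on most platforms
    print(signal.Signals(signum).name)            # prints "SIGTERM"
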
11 changes: 6 additions & 5 deletions nemo/utils/exp_manager.py
@@ -156,6 +156,7 @@ class ExpManagerConfig:
create_early_stopping_callback: Optional[bool] = False
early_stopping_callback_params: Optional[EarlyStoppingParams] = EarlyStoppingParams()
create_preemption_callback: Optional[bool] = True
preemption_signal: Optional[List[str]] = None
# Additional exp_manager arguments
files_to_copy: Optional[List[str]] = None
# logs timing of train/val/test steps
@@ -283,6 +284,8 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
See EarlyStoppingParams dataclass above.
- create_preemption_callback (bool): Flag to decide whether to enable preemption callback to save checkpoints and exit training
immediately upon preemption. Default is True.
- preemption_signal (list): A list of signals to track in order to preempt training and save a checkpoint. Defaults to None, in which case
signal.SIGTERM is tracked. To track other signals, pass preemption_signal: ['SIGINT', 'SIGUSR1'] in the config file.
- files_to_copy (list): A list of files to copy to the experiment logging directory. Defaults to None which
copies no files.
- log_local_rank_0_only (bool): Whether to only create log files for local rank 0. Defaults to False.
@@ -442,7 +445,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo

if cfg.create_checkpoint_callback:
configure_checkpointing(
trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params, cfg.create_preemption_callback
trainer, log_dir, checkpoint_name, cfg.resume_if_exists, cfg.checkpoint_callback_params, cfg.create_preemption_callback, cfg.preemption_signal
)

if cfg.disable_validation_on_resume:
@@ -833,7 +836,7 @@ def configure_loggers(


def configure_checkpointing(
trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', create_preemption_callback: bool
trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, resume: bool, params: 'DictConfig', create_preemption_callback: bool, preemption_signal: List
):
""" Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
callback
@@ -892,9 +895,7 @@ def configure_checkpointing(
checkpoint_callback.last_model_path = uninject_model_parallel_rank(checkpoint_callback.last_model_path)
trainer.callbacks.append(checkpoint_callback)
if create_preemption_callback:
## By default PreemptionCallback handles SIGTERM. To handle other signals pass the signal in the call as below:
## PreemptionCallback(checkpoint_callback, signal.SIGCHLD)
preemption_callback = PreemptionCallback(checkpoint_callback)
preemption_callback = PreemptionCallback(checkpoint_callback, preemption_signal)
trainer.callbacks.append(preemption_callback)

def check_slurm(trainer):
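
In practice the feature is driven through the exp_manager config rather than by constructing the callback directly. A minimal sketch, assuming a pytorch_lightning Trainer named trainer exists and all other ExpManagerConfig fields are left at their defaults:

    from omegaconf import OmegaConf
    from nemo.utils.exp_manager import exp_manager

    cfg = OmegaConf.create({
        'create_preemption_callback': True,          # default; kept explicit here
        'preemption_signal': ['SIGINT', 'SIGUSR1'],  # overrides the SIGTERM default
    })
    exp_manager(trainer, cfg)
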
