From c7451b3ccf742b0e8971332caf2e041ceabd9fe8 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 17 Sep 2021 20:13:59 +0100
Subject: [PATCH] [Feat] Add graceful detection of signal to exit + SignalConnector and merge SlurmConnector. (#9566)

Co-authored-by: Sean Naren
---
 CHANGELOG.md                                  |   1 +
 pyproject.toml                                |   1 +
 .../connectors/checkpoint_connector.py        |   5 +-
 .../trainer/connectors/signal_connector.py    | 110 ++++++++++++++++++
 .../trainer/connectors/slurm_connector.py     |  60 ----------
 pytorch_lightning/trainer/trainer.py          |   8 +-
 .../connectors/test_signal_connector.py       |  53 +++++++++
 7 files changed, 172 insertions(+), 66 deletions(-)
 create mode 100644 pytorch_lightning/trainer/connectors/signal_connector.py
 delete mode 100644 pytorch_lightning/trainer/connectors/slurm_connector.py
 create mode 100644 tests/trainer/connectors/test_signal_connector.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c66330725d6a1..ce7d562d35c0b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -70,6 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added partial support for global random state fault-tolerance in map-style datasets ([#8950](https://github.com/PyTorchLightning/pytorch-lightning/pull/8950))
     * Converted state to tuple explicitly when setting Python random state ([#9401](https://github.com/PyTorchLightning/pytorch-lightning/pull/9401))
     * Added support for restarting an optimizer loop (multiple optimizers) ([#9537](https://github.com/PyTorchLightning/pytorch-lightning/pull/9537))
+    * Added mechanism to detect a signal has been sent so the Trainer can gracefully exit ([#9566](https://github.com/PyTorchLightning/pytorch-lightning/pull/9566))
 
 - Checkpoint saving & loading extensibility:
 
diff --git a/pyproject.toml b/pyproject.toml
index 9981d6827e33d..ed22f853107bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,7 @@ module = [
     "pytorch_lightning.loops.evaluation_loop",
     "pytorch_lightning.trainer.connectors.checkpoint_connector",
     "pytorch_lightning.trainer.connectors.logger_connector.*",
+    "pytorch_lightning.trainer.connectors.signal_connector",
     "pytorch_lightning.trainer.progress",
     "pytorch_lightning.tuner.auto_gpu_select",
     "pytorch_lightning.utilities.apply_func",
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 5376a81b658d6..b750b0f81b26f 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -274,14 +274,15 @@ def restore_lr_schedulers(self) -> None:
     # PRIVATE OPS
     # ----------------------------------
 
-    def hpc_save(self, folderpath: str, logger: LightningLoggerBase) -> str:
+    def hpc_save(self, folderpath: str, logger: Optional[LightningLoggerBase]) -> str:
         # make sure the checkpoint folder exists
         folderpath = str(folderpath)  # because the tests pass a path object
         fs = get_filesystem(folderpath)
         fs.makedirs(folderpath, exist_ok=True)
 
         # save logger to make sure we get all the metrics
-        logger.save()
+        if logger:
+            logger.finalize("finished")
 
         max_suffix = self.max_ckpt_version_in_folder(folderpath)
         ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py
new file mode 100644
index 0000000000000..8e21ffc6dd44c
--- /dev/null
+++ b/pytorch_lightning/trainer/connectors/signal_connector.py
@@ -0,0 +1,110 @@
+import logging
+import os
+import signal
+import sys
+from signal import Signals
+from subprocess import call
+from types import FrameType, FunctionType
+from typing import Callable, List, Union
+
+import pytorch_lightning as pl
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
+
+log = logging.getLogger(__name__)
+
+
+class HandlersCompose:
+    def __init__(self, signal_handlers: Union[List[Callable], Callable]):
+        if not isinstance(signal_handlers, list):
+            signal_handlers = [signal_handlers]
+        self.signal_handlers = signal_handlers
+
+    def __call__(self, signum: Signals, frame: FrameType) -> None:
+        for signal_handler in self.signal_handlers:
+            signal_handler(signum, frame)
+
+
+class SignalConnector:
+    def __init__(self, trainer: "pl.Trainer"):
+        self.trainer = trainer
+        self.trainer._terminate_gracefully = False
+
+    def register_signal_handlers(self) -> None:
+        sigusr1_handlers: List[Callable] = []
+        sigterm_handlers: List[Callable] = []
+
+        if _fault_tolerant_training():
+            sigusr1_handlers.append(self.fault_tolerant_sigusr1_handler_fn)
+
+        if self._is_on_slurm():
+            log.info("Set SLURM handle signals.")
+            sigusr1_handlers.append(self.slurm_sigusr1_handler_fn)
+
+        sigterm_handlers.append(self.sigterm_handler_fn)
+
+        # signal.SIGUSR1 doesn't seem available on windows
+        if not self._is_on_windows():
+            if not self._has_already_handler(signal.SIGUSR1):
+                signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))
+
+            if not self._has_already_handler(signal.SIGTERM):
+                signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
+
+    def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
+        if self.trainer.is_global_zero:
+            # save weights
+            log.info("handling SIGUSR1")
+            self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)
+
+            # find job id
+            job_id = os.environ["SLURM_JOB_ID"]
+            cmd = ["scontrol", "requeue", job_id]
+
+            # requeue job
+            log.info(f"requeing job {job_id}...")
+            try:
+                result = call(cmd)
+            except FileNotFoundError:
+                # This can occur if a subprocess call to `scontrol` is run outside a shell context
+                # Re-attempt call (now with shell context). If any error is raised, propagate to user.
+                # When running a shell command, it should be passed as a single string.
+                joint_cmd = [str(x) for x in cmd]
+                result = call(" ".join(joint_cmd), shell=True)
+
+            # print result text
+            if result == 0:
+                log.info(f"requeued exp {job_id}")
+            else:
+                log.warning("requeue failed...")
+
+            # close experiment to avoid issues
+            if self.trainer.logger:
+                self.trainer.logger.finalize("finished")
+
+    def fault_tolerant_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
+        self.trainer._terminate_gracefully = True
+
+    def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
+        log.info("bypassing sigterm")
+
+    def _is_on_slurm(self) -> bool:
+        # see if we're using slurm (not interactive)
+        on_slurm = False
+        try:
+            job_name = os.environ["SLURM_JOB_NAME"]
+            if job_name != "bash":
+                on_slurm = True
+        # todo: specify the possible exception
+        except Exception:
+            pass
+
+        return on_slurm
+
+    def _is_on_windows(self) -> bool:
+        return sys.platform == "win32"
+
+    def _has_already_handler(self, signum: Signals) -> bool:
+        try:
+            return isinstance(signal.getsignal(signum), FunctionType)
+        except AttributeError:
+            return False
diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py
deleted file mode 100644
index 053e1397ba2a2..0000000000000
--- a/pytorch_lightning/trainer/connectors/slurm_connector.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import logging
-import os
-import signal
-from subprocess import call
-
-log = logging.getLogger(__name__)
-
-
-class SLURMConnector:
-    def __init__(self, trainer):
-        self.trainer = trainer
-
-    def register_slurm_signal_handlers(self):
-        # see if we're using slurm (not interactive)
-        on_slurm = False
-        try:
-            job_name = os.environ["SLURM_JOB_NAME"]
-            if job_name != "bash":
-                on_slurm = True
-        # todo: specify the possible exception
-        except Exception:
-            pass
-
-        if on_slurm:
-            log.info("Set SLURM handle signals.")
-            signal.signal(signal.SIGUSR1, self.sig_handler)
-            signal.signal(signal.SIGTERM, self.term_handler)
-
-    def sig_handler(self, signum, frame):  # pragma: no-cover
-        if self.trainer.is_global_zero:
-            # save weights
-            log.info("handling SIGUSR1")
-            self.trainer.checkpoint_connector.hpc_save(self.trainer.weights_save_path, self.trainer.logger)
-
-            # find job id
-            job_id = os.environ["SLURM_JOB_ID"]
-            cmd = ["scontrol", "requeue", job_id]
-
-            # requeue job
-            log.info(f"requeing job {job_id}...")
-            try:
-                result = call(cmd)
-            except FileNotFoundError:
-                # This can occur if a subprocess call to `scontrol` is run outside a shell context
-                # Re-attempt call (now with shell context). If any error is raised, propagate to user.
-                # When running a shell command, it should be passed as a single string.
-                joint_cmd = [str(x) for x in cmd]
-                result = call(" ".join(joint_cmd), shell=True)
-
-            # print result text
-            if result == 0:
-                log.info(f"requeued exp {job_id}")
-            else:
-                log.warning("requeue failed...")
-
-            # close experiment to avoid issues
-            self.trainer.logger.close()
-
-    def term_handler(self, signum, frame):  # pragma: no-cover
-        log.info("bypassing sigterm")
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 35dc481a39146..4846b7c117921 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -52,7 +52,7 @@
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
 from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
-from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
+from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
 from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
@@ -383,7 +383,7 @@ def __init__(
         self.debugging_connector = DebuggingConnector(self)
         self.training_tricks_connector = TrainingTricksConnector(self)
         self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
-        self.slurm_connector = SLURMConnector(self)
+        self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)
 
         # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
@@ -1104,8 +1104,8 @@ def _pre_training_routine(self):
         # wait for all to join if on distributed
         self.accelerator.barrier("setup_training")
 
-        # register auto-resubmit when on SLURM
-        self.slurm_connector.register_slurm_signal_handlers()
+        # register signals
+        self.signal_connector.register_signal_handlers()
 
         self.checkpoint_connector.resume_end()
 
diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py
new file mode 100644
index 0000000000000..1ec62f3d3082f
--- /dev/null
+++ b/tests/trainer/connectors/test_signal_connector.py
@@ -0,0 +1,53 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import signal
+from time import sleep
+from unittest import mock
+
+import pytest
+
+from pytorch_lightning import Trainer
+from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+
+@pytest.mark.parametrize("register_handler", [False, True])
+@pytest.mark.parametrize("terminate_gracefully", [False, True])
+@RunIf(min_torch="1.7.0", skip_windows=True)
+def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir):
+
+    # hack to reset the signal
+    signal.signal(signal.SIGUSR1, 0)
+
+    if register_handler:
+
+        def handler(*_):
+            pass
+
+        signal.signal(signal.SIGUSR1, handler)
+
+    with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": str(int(terminate_gracefully))}):
+
+        class TestModel(BoringModel):
+            def training_step(self, batch, batch_idx):
+                if terminate_gracefully or register_handler:
+                    os.kill(os.getpid(), signal.SIGUSR1)
+                    sleep(0.1)
+                return super().training_step(batch, batch_idx)
+
+        model = TestModel()
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0)
+        trainer.fit(model)
+        assert trainer._terminate_gracefully == (False if register_handler else terminate_gracefully)
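
Editorial note, not part of the patch: the standalone sketch below illustrates, with only the standard library, the handler-composition pattern that `SignalConnector.register_signal_handlers` relies on. Several callables are chained behind a single signal and registered only when no user-defined handler is already installed, mirroring the `_has_already_handler` check. The two handler functions are hypothetical placeholders, and `signal.SIGUSR1` is not available on Windows.

import os
import signal
from time import sleep
from types import FrameType, FunctionType
from typing import Callable, List, Union


class HandlersCompose:
    """Chain several handlers behind one signal, as the connector above does."""

    def __init__(self, signal_handlers: Union[List[Callable], Callable]):
        if not isinstance(signal_handlers, list):
            signal_handlers = [signal_handlers]
        self.signal_handlers = signal_handlers

    def __call__(self, signum: int, frame: FrameType) -> None:
        # fan the signal out to every registered handler, in order
        for signal_handler in self.signal_handlers:
            signal_handler(signum, frame)


def save_checkpoint_handler(signum, frame):
    # placeholder: a real handler would save a checkpoint here
    print("a checkpoint would be saved here")


def requeue_job_handler(signum, frame):
    # placeholder: a real handler would requeue the SLURM job here
    print("the job would be requeued here")


if __name__ == "__main__":
    # register only if no plain-function handler is installed yet,
    # mirroring the `_has_already_handler` check in the connector
    if not isinstance(signal.getsignal(signal.SIGUSR1), FunctionType):
        signal.signal(signal.SIGUSR1, HandlersCompose([save_checkpoint_handler, requeue_job_handler]))

    os.kill(os.getpid(), signal.SIGUSR1)  # deliver SIGUSR1 to this process
    sleep(0.1)  # give the interpreter a moment to run the Python-level handlers
    print("main program continues after both handlers have run")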
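
A second hedged sketch of the `scontrol` requeue fallback used in `slurm_sigusr1_handler_fn` (and in the removed `sig_handler`): the command is first attempted as an argument list, and only re-issued through a shell, as a single string, when the executable cannot be resolved directly. The job id is a hypothetical placeholder; the connector reads it from the `SLURM_JOB_ID` environment variable.

import logging
from subprocess import call

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

job_id = "12345"  # hypothetical job id; the connector uses os.environ["SLURM_JOB_ID"]
cmd = ["scontrol", "requeue", job_id]

try:
    result = call(cmd)
except FileNotFoundError:
    # `scontrol` was not found as a direct executable; retry through a shell,
    # where the whole command must be passed as a single string
    result = call(" ".join(str(x) for x in cmd), shell=True)

if result == 0:
    log.info(f"requeued job {job_id}")
else:
    log.warning("requeue failed...")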