diff --git a/CHANGELOG.md b/CHANGELOG.md index c87e37a8d0aaf..0483db32b3b34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -199,6 +199,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611)) +- Fixed `SignalConnector._has_already_handler` check for callable type ([#10483](https://github.com/PyTorchLightning/pytorch-lightning/pull/10483)) + + - Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815)) diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 0659f45e558fc..563cf4d45dfc0 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -5,7 +5,7 @@ import threading from signal import Signals from subprocess import call -from types import FrameType, FunctionType +from types import FrameType from typing import Any, Callable, Dict, List, Set, Union import pytorch_lightning as pl @@ -138,10 +138,7 @@ def _is_on_windows() -> bool: @staticmethod def _has_already_handler(signum: Signals) -> bool: - try: - return isinstance(signal.getsignal(signum), FunctionType) - except AttributeError: - return False + return signal.getsignal(signum) is not signal.SIG_DFL @staticmethod def _register_signal(signum: Signals, handlers: HandlersCompose) -> None: diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index 07cc0648f9c3d..c27806a2b9a88 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -109,3 +109,27 @@ def test_signal_connector_in_thread(): with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]): assert future.exception() is None + + +def signal_handler(): + pass + + +class SignalHandlers: + def signal_handler(self): + pass + + +@pytest.mark.parametrize( + ["handler", "expected_return"], + [ + (signal.Handlers.SIG_IGN, True), + (signal.Handlers.SIG_DFL, False), + (signal_handler, True), + (SignalHandlers().signal_handler, True), + ], +) +def test_has_already_handler(handler, expected_return): + """Test that the SignalConnector detects whether a signal handler is already attached.""" + with mock.patch("pytorch_lightning.trainer.connectors.signal_connector.signal.getsignal", return_value=handler): + assert SignalConnector._has_already_handler(signal.SIGTERM) is expected_return