diff --git a/conftest.py b/conftest.py index c1e9ba4a9..1e9b5a141 100644 --- a/conftest.py +++ b/conftest.py @@ -31,6 +31,7 @@ import os import pathlib import shutil +import signal import sys import tempfile import typing as t @@ -206,6 +207,26 @@ def alloc_specs() -> t.Dict[str, t.Any]: return specs +def _reset_signal(signalnum: int): + """SmartSim will set/overwrite signals on occasion. This function will + return a generator that can be used as a fixture to automatically reset the + signal handler to what it was at the beginning of the test suite to keep + tests atomic. + """ + original = signal.getsignal(signalnum) + + def _reset(): + yield + signal.signal(signalnum, original) + + return _reset + + +_reset_signal_interrupt = pytest.fixture( + _reset_signal(signal.SIGINT), autouse=True, scope="function" +) + + @pytest.fixture def wlmutils() -> t.Type[WLMUtils]: return WLMUtils diff --git a/doc/changelog.rst b/doc/changelog.rst index 3e73101e1..411740e15 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -32,6 +32,7 @@ Description - Fix publishing of development docs - Update Experiment API typing - Minor enhancements to test suite +- Improve SmartSim experiment signal handlers Detailed Notes @@ -76,11 +77,17 @@ Detailed Notes undefined. (SmartSim-PR521_) - Remove previously deprecated behavior present in test suite on machines with Slurm and Open MPI. (SmartSim-PR520_) +- When calling ``Experiment.start`` SmartSim would register a signal handler + that would capture an interrupt signal (^C) to kill any jobs launched through + its ``JobManager``. This would replace the default (or user defined) signal + handler. SmartSim will now attempt to kill any launched jobs before calling + the previously registered signal handler. (SmartSim-PR535_) .. _SmartSim-PR537: https://github.com/CrayLabs/SmartSim/pull/537 .. _SmartSim-PR498: https://github.com/CrayLabs/SmartSim/pull/498 .. _SmartSim-PR460: https://github.com/CrayLabs/SmartSim/pull/460 .. _SmartSim-PR512: https://github.com/CrayLabs/SmartSim/pull/512 +.. _SmartSim-PR535: https://github.com/CrayLabs/SmartSim/pull/535 .. _SmartSim-PR529: https://github.com/CrayLabs/SmartSim/pull/529 .. _SmartSim-PR522: https://github.com/CrayLabs/SmartSim/pull/522 .. _SmartSim-PR521: https://github.com/CrayLabs/SmartSim/pull/521 diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 5c1de5cc2..989d66d2c 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -43,7 +43,11 @@ from smartsim._core.utils.network import get_ip_from_host from ..._core.launcher.step import Step -from ..._core.utils.helpers import unpack_colo_db_identifier, unpack_db_identifier +from ..._core.utils.helpers import ( + SignalInterceptionStack, + unpack_colo_db_identifier, + unpack_db_identifier, +) from ..._core.utils.redis import ( db_is_active, set_ml_model, @@ -71,6 +75,8 @@ from .manifest import LaunchedManifest, LaunchedManifestBuilder, Manifest if t.TYPE_CHECKING: + from types import FrameType + from ..utils.serialize import TStepLaunchMetaData @@ -113,8 +119,11 @@ def start( execution of all jobs. """ self._jobs.kill_on_interrupt = kill_on_interrupt + # register custom signal handler for ^C (SIGINT) - signal.signal(signal.SIGINT, self._jobs.signal_interrupt) + SignalInterceptionStack.get(signal.SIGINT).push_unique( + self._jobs.signal_interrupt + ) launched = self._launch(exp_name, exp_path, manifest) # start the job manager thread if not already started @@ -132,7 +141,7 @@ def start( # block until all non-database jobs are complete if block: # poll handles its own keyboard interrupt as - # it may be called seperately + # it may be called separately self.poll(5, True, kill_on_interrupt=kill_on_interrupt) @property diff --git a/smartsim/_core/control/jobmanager.py b/smartsim/_core/control/jobmanager.py index 89363d520..4910b8311 100644 --- a/smartsim/_core/control/jobmanager.py +++ b/smartsim/_core/control/jobmanager.py @@ -350,9 +350,9 @@ def set_db_hosts(self, orchestrator: Orchestrator) -> None: self.db_jobs[dbnode.name].hosts = dbnode.hosts def signal_interrupt(self, signo: int, _frame: t.Optional[FrameType]) -> None: + """Custom handler for whenever SIGINT is received""" if not signo: logger.warning("Received SIGINT with no signal number") - """Custom handler for whenever SIGINT is received""" if self.actively_monitoring and len(self) > 0: if self.kill_on_interrupt: for _, job in self().items(): diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index 9ae319883..b9e79e250 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -28,7 +28,9 @@ A file of helper functions for SmartSim """ import base64 +import collections.abc import os +import signal import typing as t import uuid from datetime import datetime @@ -38,6 +40,12 @@ from smartsim._core._install.builder import TRedisAIBackendStr as _TRedisAIBackendStr +if t.TYPE_CHECKING: + from types import FrameType + + +_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object] + def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]: """Unpack the unformatted database identifier @@ -302,3 +310,93 @@ def decode_cmd(encoded_cmd: str) -> t.List[str]: cleaned_cmd = decoded_cmd.decode("ascii").split("|") return cleaned_cmd + + +# TODO: Remove the ``type: ignore`` comment here when Python 3.8 support is dropped +# ``collections.abc.Collection`` is not subscriptable until Python 3.9 +@t.final +class SignalInterceptionStack(collections.abc.Collection): # type: ignore[type-arg] + """Registers a stack of unique callables to be called when a signal is + received before calling the original signal handler. + """ + + def __init__( + self, + signalnum: int, + callbacks: t.Optional[t.Iterable[_TSignalHandlerFn]] = None, + ) -> None: + """Set up a ``SignalInterceptionStack`` for particular signal number. + + .. note:: + This class typically should not be instanced directly as it will + change the registered signal handler regardless of if a signal + interception stack is already present. Instead, it is generally + best to create or get a signal interception stack for a particular + signal number via the `get` factory method. + + :param signalnum: The signal number to intercept + :type signalnum: int + :param callbacks: A iterable of functions to call upon receiving the signal + :type callbacks: t.Iterable[_TSignalHandlerFn] | None + """ + self._callbacks = list(callbacks) if callbacks else [] + self._original = signal.signal(signalnum, self) + + def __call__(self, signalnum: int, frame: t.Optional["FrameType"]) -> None: + """Handle the signal on which the interception stack was registered. + End by calling the originally registered signal hander (if present). + + :param frame: The current stack frame + :type frame: FrameType | None + """ + for fn in self: + fn(signalnum, frame) + if callable(self._original): + self._original(signalnum, frame) + + def __contains__(self, obj: object) -> bool: + return obj in self._callbacks + + def __iter__(self) -> t.Iterator[_TSignalHandlerFn]: + return reversed(self._callbacks) + + def __len__(self) -> int: + return len(self._callbacks) + + @classmethod + def get(cls, signalnum: int) -> "SignalInterceptionStack": + """Fetch an existing ``SignalInterceptionStack`` or create a new one + for a particular signal number. + + :param signalnum: The singal number of the signal interception stack + should be registered + :type signalnum: int + :returns: The existing or created signal interception stack + :rtype: SignalInterceptionStack + """ + handler = signal.getsignal(signalnum) + if isinstance(handler, cls): + return handler + return cls(signalnum, []) + + def push(self, fn: _TSignalHandlerFn) -> None: + """Add a callback to the signal interception stack. + + :param fn: A callable to add to the unique signal stack + :type fn: _TSignalHandlerFn + """ + self._callbacks.append(fn) + + def push_unique(self, fn: _TSignalHandlerFn) -> bool: + """Add a callback to the signal interception stack if and only if the + callback is not already present. + + :param fn: A callable to add to the unique signal stack + :type fn: _TSignalHandlerFn + :returns: True if the callback was added, False if the callback was + already present + :rtype: bool + """ + if did_push := fn not in self: + self.push(fn) + return did_push diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 025f53d32..523ed7191 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -24,6 +24,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import collections +import signal + import pytest from smartsim._core.utils import helpers @@ -68,3 +71,115 @@ def test_encode_raises_on_empty(): def test_decode_raises_on_empty(): with pytest.raises(ValueError): helpers.decode_cmd("") + + +class MockSignal: + def __init__(self): + self.signal_handlers = collections.defaultdict(lambda: signal.SIG_IGN) + + def signal(self, signalnum, handler): + orig = self.getsignal(signalnum) + self.signal_handlers[signalnum] = handler + return orig + + def getsignal(self, signalnum): + return self.signal_handlers[signalnum] + + +@pytest.fixture +def mock_signal(monkeypatch): + mock_signal = MockSignal() + monkeypatch.setattr(helpers, "signal", mock_signal) + yield mock_signal + + +def test_signal_intercept_stack_will_register_itself_with_callback_fn(mock_signal): + callback = lambda num, frame: ... + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push(callback) + assert isinstance(stack, helpers.SignalInterceptionStack) + assert stack is mock_signal.signal_handlers[signal.NSIG] + assert len(stack) == 1 + assert list(stack)[0] == callback + + +def test_signal_intercept_stack_keeps_track_of_previous_handlers(mock_signal): + default_handler = lambda num, frame: ... + mock_signal.signal_handlers[signal.NSIG] = default_handler + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push(lambda n, f: ...) + assert stack._original is default_handler + + +def test_signal_intercept_stacks_are_registered_per_signal_number(mock_signal): + handler = lambda num, frame: ... + stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG) + stack_1.push(handler) + stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG + 1) + stack_2.push(handler) + + assert mock_signal.signal_handlers[signal.NSIG] is stack_1 + assert mock_signal.signal_handlers[signal.NSIG + 1] is stack_2 + assert stack_1 is not stack_2 + assert list(stack_1) == list(stack_2) == [handler] + + +def test_signal_intercept_handlers_will_not_overwrite_if_handler_already_exists( + mock_signal, +): + handler_1 = lambda num, frame: ... + handler_2 = lambda num, frame: ... + stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG) + stack_1.push(handler_1) + stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG) + stack_2.push(handler_2) + assert stack_1 is stack_2 is mock_signal.signal_handlers[signal.NSIG] + assert list(stack_1) == [handler_2, handler_1] + + +def test_signal_intercept_stack_can_add_multiple_instances_of_the_same_handler( + mock_signal, +): + handler = lambda num, frame: ... + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push(handler) + stack.push(handler) + assert list(stack) == [handler, handler] + + +def test_signal_intercept_stack_enforces_that_unique_push_handlers_are_unique( + mock_signal, +): + handler = lambda num, frame: ... + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + assert stack.push_unique(handler) + assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(handler) + assert list(stack) == [handler] + + +def test_signal_intercept_stack_enforces_that_unique_push_method_handlers_are_unique( + mock_signal, +): + class C: + def fn(num, frame): ... + + c1 = C() + c2 = C() + stack = helpers.SignalInterceptionStack.get(signal.NSIG) + stack.push_unique(c1.fn) + assert helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c2.fn) + assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c1.fn) + assert list(stack) == [c2.fn, c1.fn] + + +def test_signal_handler_calls_functions_in_reverse_order(mock_signal): + called_list = [] + default = lambda num, frame: called_list.append("default") + handler_1 = lambda num, frame: called_list.append("handler_1") + handler_2 = lambda num, frame: called_list.append("handler_2") + + mock_signal.signal_handlers[signal.NSIG] = default + helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_1) + helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_2) + mock_signal.signal_handlers[signal.NSIG](signal.NSIG, None) + assert called_list == ["handler_2", "handler_1", "default"] diff --git a/tests/test_interrupt.py b/tests/test_interrupt.py index 28c48e0db..61dc5b8c0 100644 --- a/tests/test_interrupt.py +++ b/tests/test_interrupt.py @@ -65,20 +65,21 @@ def test_interrupt_blocked_jobs(test_dir): ) ensemble.set_path(test_dir) num_jobs = 1 + len(ensemble) - try: - pid = os.getpid() - keyboard_interrupt_thread = Thread( - name="sigint_thread", target=keyboard_interrupt, args=(pid,) - ) - keyboard_interrupt_thread.start() + pid = os.getpid() + keyboard_interrupt_thread = Thread( + name="sigint_thread", target=keyboard_interrupt, args=(pid,) + ) + keyboard_interrupt_thread.start() + + with pytest.raises(KeyboardInterrupt): exp.start(model, ensemble, block=True, kill_on_interrupt=True) - except KeyboardInterrupt: - time.sleep(2) # allow time for jobs to be stopped - active_jobs = exp._control._jobs.jobs - active_db_jobs = exp._control._jobs.db_jobs - completed_jobs = exp._control._jobs.completed - assert len(active_jobs) + len(active_db_jobs) == 0 - assert len(completed_jobs) == num_jobs + + time.sleep(2) # allow time for jobs to be stopped + active_jobs = exp._control._jobs.jobs + active_db_jobs = exp._control._jobs.db_jobs + completed_jobs = exp._control._jobs.completed + assert len(active_jobs) + len(active_db_jobs) == 0 + assert len(completed_jobs) == num_jobs def test_interrupt_multi_experiment_unblocked_jobs(test_dir): @@ -106,20 +107,22 @@ def test_interrupt_multi_experiment_unblocked_jobs(test_dir): ) ensemble.set_path(test_dir) jobs_per_experiment[i] = 1 + len(ensemble) - try: - pid = os.getpid() - keyboard_interrupt_thread = Thread( - name="sigint_thread", target=keyboard_interrupt, args=(pid,) - ) - keyboard_interrupt_thread.start() + + pid = os.getpid() + keyboard_interrupt_thread = Thread( + name="sigint_thread", target=keyboard_interrupt, args=(pid,) + ) + keyboard_interrupt_thread.start() + + with pytest.raises(KeyboardInterrupt): for experiment in experiments: experiment.start(model, ensemble, block=False, kill_on_interrupt=True) - time.sleep(9) # since jobs aren't blocked, wait for SIGINT - except KeyboardInterrupt: - time.sleep(2) # allow time for jobs to be stopped - for i, experiment in enumerate(experiments): - active_jobs = experiment._control._jobs.jobs - active_db_jobs = experiment._control._jobs.db_jobs - completed_jobs = experiment._control._jobs.completed - assert len(active_jobs) + len(active_db_jobs) == 0 - assert len(completed_jobs) == jobs_per_experiment[i] + keyboard_interrupt_thread.join() # since jobs aren't blocked, wait for SIGINT + + time.sleep(2) # allow time for jobs to be stopped + for i, experiment in enumerate(experiments): + active_jobs = experiment._control._jobs.jobs + active_db_jobs = experiment._control._jobs.db_jobs + completed_jobs = experiment._control._jobs.completed + assert len(active_jobs) + len(active_db_jobs) == 0 + assert len(completed_jobs) == jobs_per_experiment[i]