Skip to content

Commit

Permalink
Enhanced Signal Management (#535)
Browse files Browse the repository at this point in the history
Fixes unfalsifiable test that tests SmartSim's custom SIGINT signal
handler. Adds infrastructure to make the test pass again.

[ committed by @MattToast ]
[ reviewed by @ashao ]
  • Loading branch information
MattToast authored Apr 5, 2024
1 parent 619d64b commit 505de50
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 32 deletions.
21 changes: 21 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import os
import pathlib
import shutil
import signal
import sys
import tempfile
import typing as t
Expand Down Expand Up @@ -206,6 +207,26 @@ def alloc_specs() -> t.Dict[str, t.Any]:
return specs


def _reset_signal(signalnum: int):
    """Build a generator function that restores a signal handler on teardown.

    SmartSim occasionally installs or overwrites signal handlers. The handler
    registered for ``signalnum`` is captured once, at the moment this function
    is called, and the returned generator function re-registers that handler
    after its ``yield``. Wrapped as a fixture, this keeps tests atomic.

    :param signalnum: The signal number whose handler should be restored
    :returns: A generator function suitable for wrapping with
              ``pytest.fixture``
    """
    handler_at_setup = signal.getsignal(signalnum)

    def _restore_handler():
        # Nothing to set up; everything after the ``yield`` is teardown
        yield
        signal.signal(signalnum, handler_at_setup)

    return _restore_handler


# Autouse, function-scoped fixture: after every test the SIGINT handler is put
# back to whatever was registered when this module was imported.
_reset_signal_interrupt = pytest.fixture(autouse=True, scope="function")(
    _reset_signal(signal.SIGINT)
)


@pytest.fixture
def wlmutils() -> t.Type[WLMUtils]:
    """Expose the ``WLMUtils`` class itself (not an instance) to tests."""
    return WLMUtils
Expand Down
7 changes: 7 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Description
- Fix publishing of development docs
- Update Experiment API typing
- Minor enhancements to test suite
- Improve SmartSim experiment signal handlers

Detailed Notes

Expand Down Expand Up @@ -82,12 +83,18 @@ Detailed Notes
undefined. (SmartSim-PR521_)
- Remove previously deprecated behavior present in test suite on machines with
Slurm and Open MPI. (SmartSim-PR520_)
- When calling ``Experiment.start`` SmartSim would register a signal handler
that would capture an interrupt signal (^C) to kill any jobs launched through
its ``JobManager``. This would replace the default (or user defined) signal
handler. SmartSim will now attempt to kill any launched jobs before calling
the previously registered signal handler. (SmartSim-PR535_)

.. _SmartSim-PR538: https://github.com/CrayLabs/SmartSim/pull/538
.. _SmartSim-PR537: https://github.com/CrayLabs/SmartSim/pull/537
.. _SmartSim-PR498: https://github.com/CrayLabs/SmartSim/pull/498
.. _SmartSim-PR460: https://github.com/CrayLabs/SmartSim/pull/460
.. _SmartSim-PR512: https://github.com/CrayLabs/SmartSim/pull/512
.. _SmartSim-PR535: https://github.com/CrayLabs/SmartSim/pull/535
.. _SmartSim-PR529: https://github.com/CrayLabs/SmartSim/pull/529
.. _SmartSim-PR522: https://github.com/CrayLabs/SmartSim/pull/522
.. _SmartSim-PR521: https://github.com/CrayLabs/SmartSim/pull/521
Expand Down
15 changes: 12 additions & 3 deletions smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
from smartsim._core.utils.network import get_ip_from_host

from ..._core.launcher.step import Step
from ..._core.utils.helpers import unpack_colo_db_identifier, unpack_db_identifier
from ..._core.utils.helpers import (
SignalInterceptionStack,
unpack_colo_db_identifier,
unpack_db_identifier,
)
from ..._core.utils.redis import (
db_is_active,
set_ml_model,
Expand Down Expand Up @@ -71,6 +75,8 @@
from .manifest import LaunchedManifest, LaunchedManifestBuilder, Manifest

if t.TYPE_CHECKING:
from types import FrameType

from ..utils.serialize import TStepLaunchMetaData


Expand Down Expand Up @@ -113,8 +119,11 @@ def start(
execution of all jobs.
"""
self._jobs.kill_on_interrupt = kill_on_interrupt

# register custom signal handler for ^C (SIGINT)
signal.signal(signal.SIGINT, self._jobs.signal_interrupt)
SignalInterceptionStack.get(signal.SIGINT).push_unique(
self._jobs.signal_interrupt
)
launched = self._launch(exp_name, exp_path, manifest)

# start the job manager thread if not already started
Expand All @@ -132,7 +141,7 @@ def start(
# block until all non-database jobs are complete
if block:
# poll handles its own keyboard interrupt as
# it may be called seperately
# it may be called separately
self.poll(5, True, kill_on_interrupt=kill_on_interrupt)

@property
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/control/jobmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ def set_db_hosts(self, orchestrator: Orchestrator) -> None:
self.db_jobs[dbnode.name].hosts = dbnode.hosts

def signal_interrupt(self, signo: int, _frame: t.Optional[FrameType]) -> None:
"""Custom handler for whenever SIGINT is received"""
if not signo:
logger.warning("Received SIGINT with no signal number")
"""Custom handler for whenever SIGINT is received"""
if self.actively_monitoring and len(self) > 0:
if self.kill_on_interrupt:
for _, job in self().items():
Expand Down
98 changes: 98 additions & 0 deletions smartsim/_core/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
A file of helper functions for SmartSim
"""
import base64
import collections.abc
import os
import signal
import typing as t
import uuid
from datetime import datetime
Expand All @@ -38,6 +40,12 @@

from smartsim._core._install.builder import TRedisAIBackendStr as _TRedisAIBackendStr

if t.TYPE_CHECKING:
from types import FrameType


# Type of a signal handler callback: called with the signal number and the
# (possibly absent) current stack frame
_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object]


def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]:
"""Unpack the unformatted database identifier
Expand Down Expand Up @@ -302,3 +310,93 @@ def decode_cmd(encoded_cmd: str) -> t.List[str]:
cleaned_cmd = decoded_cmd.decode("ascii").split("|")

return cleaned_cmd


# TODO: Remove the ``type: ignore`` comment here when Python 3.8 support is dropped
#       ``collections.abc.Collection`` is not subscriptable until Python 3.9
@t.final
class SignalInterceptionStack(collections.abc.Collection):  # type: ignore[type-arg]
    """Stack of callbacks installed as the handler for a single signal.

    When the signal is received, the callbacks run from most recently pushed
    to least recently pushed, after which the handler that was registered
    before the stack was installed is called (provided it is callable).
    """

    def __init__(
        self,
        signalnum: int,
        callbacks: t.Optional[t.Iterable[_TSignalHandlerFn]] = None,
    ) -> None:
        """Install a new ``SignalInterceptionStack`` for a signal number.

        .. note::
            Instantiating this class directly always replaces the currently
            registered signal handler, even when that handler is already a
            signal interception stack. It is generally best to create or
            fetch the stack for a signal number through the ``get`` factory
            method instead.

        :param signalnum: The signal number to intercept
        :type signalnum: int
        :param callbacks: An iterable of functions to call upon receiving the
                          signal
        :type callbacks: t.Iterable[_TSignalHandlerFn] | None
        """
        self._callbacks: t.List[_TSignalHandlerFn] = []
        if callbacks:
            self._callbacks.extend(callbacks)
        # ``signal.signal`` returns the handler being replaced; keep it so it
        # can still be invoked once the callbacks have run
        self._original = signal.signal(signalnum, self)

    def __call__(self, signalnum: int, frame: t.Optional["FrameType"]) -> None:
        """Handle the signal on which the interception stack was registered.

        Every callback on the stack is called first; the originally registered
        signal handler is called last, and only if it is callable.

        :param signalnum: The number of the received signal
        :type signalnum: int
        :param frame: The current stack frame
        :type frame: FrameType | None
        """
        for callback in self:
            callback(signalnum, frame)
        previous_handler = self._original
        if callable(previous_handler):
            previous_handler(signalnum, frame)

    def __contains__(self, obj: object) -> bool:
        return obj in self._callbacks

    def __iter__(self) -> t.Iterator[_TSignalHandlerFn]:
        # Last pushed, first called
        return reversed(self._callbacks)

    def __len__(self) -> int:
        return len(self._callbacks)

    @classmethod
    def get(cls, signalnum: int) -> "SignalInterceptionStack":
        """Fetch the existing ``SignalInterceptionStack`` for a signal number,
        or create and register a new, empty one.

        :param signalnum: The signal number for which the signal interception
                          stack should be registered
        :type signalnum: int
        :returns: The existing or created signal interception stack
        :rtype: SignalInterceptionStack
        """
        registered = signal.getsignal(signalnum)
        return registered if isinstance(registered, cls) else cls(signalnum, [])

    def push(self, fn: _TSignalHandlerFn) -> None:
        """Add a callback to the top of the signal interception stack.

        No uniqueness check is made; the same callback may be pushed more
        than once.

        :param fn: A callable to add to the signal stack
        :type fn: _TSignalHandlerFn
        """
        self._callbacks.append(fn)

    def push_unique(self, fn: _TSignalHandlerFn) -> bool:
        """Add a callback to the signal interception stack if and only if the
        callback is not already present.

        :param fn: A callable to add to the unique signal stack
        :type fn: _TSignalHandlerFn
        :returns: True if the callback was added, False if the callback was
                  already present
        :rtype: bool
        """
        if fn in self:
            return False
        self.push(fn)
        return True
115 changes: 115 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import collections
import signal

import pytest

from smartsim._core.utils import helpers
Expand Down Expand Up @@ -68,3 +71,115 @@ def test_encode_raises_on_empty():
def test_decode_raises_on_empty():
with pytest.raises(ValueError):
helpers.decode_cmd("")


class MockSignal:
    """In-memory stand-in for the parts of the ``signal`` module used here.

    Handlers are stored in a dictionary keyed by signal number instead of
    being registered with the interpreter. A signal number that has never
    been set reports ``signal.SIG_IGN``.
    """

    def __init__(self):
        def _ignored():
            return signal.SIG_IGN

        self.signal_handlers = collections.defaultdict(_ignored)

    def signal(self, signalnum, handler):
        """Record ``handler`` for ``signalnum`` and return the one it replaces."""
        replaced = self.signal_handlers[signalnum]
        self.signal_handlers[signalnum] = handler
        return replaced

    def getsignal(self, signalnum):
        """Return the handler currently recorded for ``signalnum``."""
        return self.signal_handlers[signalnum]


@pytest.fixture
def mock_signal(monkeypatch):
    """Replace the ``signal`` attribute of ``helpers`` with a ``MockSignal``
    and hand that mock to the test.
    """
    fake_signal_module = MockSignal()
    monkeypatch.setattr(helpers, "signal", fake_signal_module)
    yield fake_signal_module


def test_signal_intercept_stack_will_register_itself_with_callback_fn(mock_signal):
    """A stack fetched with ``get`` is the registered handler and holds the
    pushed callback.
    """

    def callback(num, frame): ...

    signum = signal.NSIG
    stack = helpers.SignalInterceptionStack.get(signum)
    stack.push(callback)

    assert isinstance(stack, helpers.SignalInterceptionStack)
    assert mock_signal.signal_handlers[signum] is stack
    assert len(stack) == 1
    assert next(iter(stack)) == callback


def test_signal_intercept_stack_keeps_track_of_previous_handlers(mock_signal):
    """The handler registered before the stack is kept as ``_original``."""

    def previously_registered(num, frame): ...

    mock_signal.signal_handlers[signal.NSIG] = previously_registered
    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
    stack.push(lambda n, f: ...)

    assert stack._original is previously_registered


def test_signal_intercept_stacks_are_registered_per_signal_number(mock_signal):
    """Each signal number gets its own, independent interception stack."""

    def shared_handler(num, frame): ...

    first_signum = signal.NSIG
    second_signum = signal.NSIG + 1

    stack_1 = helpers.SignalInterceptionStack.get(first_signum)
    stack_1.push(shared_handler)
    stack_2 = helpers.SignalInterceptionStack.get(second_signum)
    stack_2.push(shared_handler)

    registered = mock_signal.signal_handlers
    assert registered[first_signum] is stack_1
    assert registered[second_signum] is stack_2
    assert stack_1 is not stack_2
    assert list(stack_1) == list(stack_2) == [shared_handler]


def test_signal_intercept_handlers_will_not_overwrite_if_handler_already_exists(
    mock_signal,
):
    """A second ``get`` for the same signal returns the already registered
    stack, and later pushes are iterated first.
    """

    def older(num, frame): ...

    def newer(num, frame): ...

    get_stack = helpers.SignalInterceptionStack.get
    stack_1 = get_stack(signal.NSIG)
    stack_1.push(older)
    stack_2 = get_stack(signal.NSIG)
    stack_2.push(newer)

    assert stack_1 is stack_2 is mock_signal.signal_handlers[signal.NSIG]
    assert list(stack_1) == [newer, older]


def test_signal_intercept_stack_can_add_multiple_instances_of_the_same_handler(
    mock_signal,
):
    """Plain ``push`` does not de-duplicate callbacks."""

    def repeated(num, frame): ...

    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
    for _ in range(2):
        stack.push(repeated)

    assert list(stack) == [repeated, repeated]


def test_signal_intercept_stack_enforces_that_unique_push_handlers_are_unique(
    mock_signal,
):
    """``push_unique`` adds a callback once and reports a repeat as ``False``."""

    def only_once(num, frame): ...

    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
    first_attempt = stack.push_unique(only_once)
    assert first_attempt

    same_stack = helpers.SignalInterceptionStack.get(signal.NSIG)
    second_attempt = same_stack.push_unique(only_once)
    assert not second_attempt

    assert list(stack) == [only_once]


def test_signal_intercept_stack_enforces_that_unique_push_method_handlers_are_unique(
    mock_signal,
):
    """``push_unique`` treats the same method bound to different instances as
    different handlers, and the same bound method as a duplicate.
    """

    class C:
        # ``self`` is required so that the bound method accepts the
        # ``(num, frame)`` arguments a signal handler is called with
        def fn(self, num, frame): ...

    c1 = C()
    c2 = C()
    stack = helpers.SignalInterceptionStack.get(signal.NSIG)
    stack.push_unique(c1.fn)
    assert helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c2.fn)
    assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c1.fn)
    assert list(stack) == [c2.fn, c1.fn]


def test_signal_handler_calls_functions_in_reverse_order(mock_signal):
    """Invoking the stack runs the newest callback first and the previously
    registered handler last.
    """
    call_order = []

    def recorder(label):
        # Build a handler that records its own label when called
        return lambda num, frame: call_order.append(label)

    mock_signal.signal_handlers[signal.NSIG] = recorder("default")
    helpers.SignalInterceptionStack.get(signal.NSIG).push(recorder("handler_1"))
    helpers.SignalInterceptionStack.get(signal.NSIG).push(recorder("handler_2"))

    registered_handler = mock_signal.signal_handlers[signal.NSIG]
    registered_handler(signal.NSIG, None)

    assert call_order == ["handler_2", "handler_1", "default"]
Loading

0 comments on commit 505de50

Please sign in to comment.