Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhanced Signal Management #535

Merged
merged 11 commits into from
Apr 5, 2024
21 changes: 21 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import os
import pathlib
import shutil
import signal
import sys
import tempfile
import typing as t
Expand Down Expand Up @@ -206,6 +207,26 @@ def alloc_specs() -> t.Dict[str, t.Any]:
return specs


def _reset_signal(signalnum: int):
"""SmartSim will set/overwrite signals on occasion. This function will
return a generator that can be used as a fixture to automatically reset the
signal handler to what it was at the beginning of the test suite to keep
tests atomic.
"""
original = signal.getsignal(signalnum)

def _reset():
yield
signal.signal(signalnum, original)

return _reset


_reset_signal_interrupt = pytest.fixture(
_reset_signal(signal.SIGINT), autouse=True, scope="function"
)


@pytest.fixture
def wlmutils() -> t.Type[WLMUtils]:
return WLMUtils
Expand Down
7 changes: 7 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Description
- Fix publishing of development docs
- Update Experiment API typing
- Minor enhancements to test suite
- Improve SmartSim experiment signal handlers

Detailed Notes

Expand Down Expand Up @@ -76,11 +77,17 @@ Detailed Notes
undefined. (SmartSim-PR521_)
- Remove previously deprecated behavior present in test suite on machines with
Slurm and Open MPI. (SmartSim-PR520_)
- When calling ``Experiment.start`` SmartSim would register a signal handler
that would capture an interrupt signal (^C) to kill any jobs launched through
its ``JobManager``. This would replace the default (or user defined) signal
handler. SmartSim will now attempt to kill any launched jobs before calling
the previously registered signal handler. (SmartSim-PR535_)

.. _SmartSim-PR537: https://github.com/CrayLabs/SmartSim/pull/537
.. _SmartSim-PR498: https://github.com/CrayLabs/SmartSim/pull/498
.. _SmartSim-PR460: https://github.com/CrayLabs/SmartSim/pull/460
.. _SmartSim-PR512: https://github.com/CrayLabs/SmartSim/pull/512
.. _SmartSim-PR535: https://github.com/CrayLabs/SmartSim/pull/535
.. _SmartSim-PR529: https://github.com/CrayLabs/SmartSim/pull/529
.. _SmartSim-PR522: https://github.com/CrayLabs/SmartSim/pull/522
.. _SmartSim-PR521: https://github.com/CrayLabs/SmartSim/pull/521
Expand Down
15 changes: 12 additions & 3 deletions smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
from smartsim._core.utils.network import get_ip_from_host

from ..._core.launcher.step import Step
from ..._core.utils.helpers import unpack_colo_db_identifier, unpack_db_identifier
from ..._core.utils.helpers import (
SignalInterceptionStack,
unpack_colo_db_identifier,
unpack_db_identifier,
)
from ..._core.utils.redis import (
db_is_active,
set_ml_model,
Expand Down Expand Up @@ -71,6 +75,8 @@
from .manifest import LaunchedManifest, LaunchedManifestBuilder, Manifest

if t.TYPE_CHECKING:
from types import FrameType

from ..utils.serialize import TStepLaunchMetaData


Expand Down Expand Up @@ -113,8 +119,11 @@ def start(
execution of all jobs.
"""
self._jobs.kill_on_interrupt = kill_on_interrupt

# register custom signal handler for ^C (SIGINT)
signal.signal(signal.SIGINT, self._jobs.signal_interrupt)
SignalInterceptionStack.get(signal.SIGINT).push_unique(
self._jobs.signal_interrupt
)
launched = self._launch(exp_name, exp_path, manifest)

# start the job manager thread if not already started
Expand All @@ -132,7 +141,7 @@ def start(
# block until all non-database jobs are complete
if block:
# poll handles its own keyboard interrupt as
# it may be called seperately
# it may be called separately
self.poll(5, True, kill_on_interrupt=kill_on_interrupt)

@property
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/control/jobmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ def set_db_hosts(self, orchestrator: Orchestrator) -> None:
self.db_jobs[dbnode.name].hosts = dbnode.hosts

def signal_interrupt(self, signo: int, _frame: t.Optional[FrameType]) -> None:
"""Custom handler for whenever SIGINT is received"""
if not signo:
logger.warning("Received SIGINT with no signal number")
"""Custom handler for whenever SIGINT is received"""
if self.actively_monitoring and len(self) > 0:
if self.kill_on_interrupt:
for _, job in self().items():
Expand Down
98 changes: 98 additions & 0 deletions smartsim/_core/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
A file of helper functions for SmartSim
"""
import base64
import collections.abc
import os
import signal
import typing as t
import uuid
from datetime import datetime
Expand All @@ -38,6 +40,12 @@

from smartsim._core._install.builder import TRedisAIBackendStr as _TRedisAIBackendStr

if t.TYPE_CHECKING:
from types import FrameType


_TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object]


def unpack_db_identifier(db_id: str, token: str) -> t.Tuple[str, str]:
"""Unpack the unformatted database identifier
Expand Down Expand Up @@ -302,3 +310,93 @@ def decode_cmd(encoded_cmd: str) -> t.List[str]:
cleaned_cmd = decoded_cmd.decode("ascii").split("|")

return cleaned_cmd


# TODO: Remove the ``type: ignore`` comment here when Python 3.8 support is dropped
# ``collections.abc.Collection`` is not subscriptable until Python 3.9
@t.final
class SignalInterceptionStack(collections.abc.Collection): # type: ignore[type-arg]
"""Registers a stack of unique callables to be called when a signal is
received before calling the original signal handler.
"""

def __init__(
self,
signalnum: int,
callbacks: t.Optional[t.Iterable[_TSignalHandlerFn]] = None,
) -> None:
"""Set up a ``SignalInterceptionStack`` for particular signal number.
.. note::
This class typically should not be instanced directly as it will
change the registered signal handler regardless of if a signal
interception stack is already present. Instead, it is generally
best to create or get a signal interception stack for a particular
signal number via the `get` factory method.
:param signalnum: The signal number to intercept
:type signalnum: int
:param callbacks: A iterable of functions to call upon receiving the signal
:type callbacks: t.Iterable[_TSignalHandlerFn] | None
"""
self._callbacks = list(callbacks) if callbacks else []
self._original = signal.signal(signalnum, self)

def __call__(self, signalnum: int, frame: t.Optional["FrameType"]) -> None:
"""Handle the signal on which the interception stack was registered.
End by calling the originally registered signal hander (if present).
:param frame: The current stack frame
:type frame: FrameType | None
"""
for fn in self:
fn(signalnum, frame)
if callable(self._original):
self._original(signalnum, frame)

def __contains__(self, obj: object) -> bool:
return obj in self._callbacks

def __iter__(self) -> t.Iterator[_TSignalHandlerFn]:
return reversed(self._callbacks)

def __len__(self) -> int:
return len(self._callbacks)

@classmethod
def get(cls, signalnum: int) -> "SignalInterceptionStack":
"""Fetch an existing ``SignalInterceptionStack`` or create a new one
for a particular signal number.
:param signalnum: The singal number of the signal interception stack
should be registered
:type signalnum: int
:returns: The existing or created signal interception stack
:rtype: SignalInterceptionStack
"""
handler = signal.getsignal(signalnum)
if isinstance(handler, cls):
return handler
return cls(signalnum, [])

def push(self, fn: _TSignalHandlerFn) -> None:
"""Add a callback to the signal interception stack.
:param fn: A callable to add to the unique signal stack
:type fn: _TSignalHandlerFn
"""
self._callbacks.append(fn)

def push_unique(self, fn: _TSignalHandlerFn) -> bool:
"""Add a callback to the signal interception stack if and only if the
callback is not already present.
:param fn: A callable to add to the unique signal stack
:type fn: _TSignalHandlerFn
:returns: True if the callback was added, False if the callback was
already present
:rtype: bool
"""
if did_push := fn not in self:
self.push(fn)
return did_push
115 changes: 115 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import collections
import signal

import pytest

from smartsim._core.utils import helpers
Expand Down Expand Up @@ -68,3 +71,115 @@ def test_encode_raises_on_empty():
def test_decode_raises_on_empty():
with pytest.raises(ValueError):
helpers.decode_cmd("")


class MockSignal:
def __init__(self):
self.signal_handlers = collections.defaultdict(lambda: signal.SIG_IGN)

def signal(self, signalnum, handler):
orig = self.getsignal(signalnum)
self.signal_handlers[signalnum] = handler
return orig

def getsignal(self, signalnum):
return self.signal_handlers[signalnum]


@pytest.fixture
def mock_signal(monkeypatch):
mock_signal = MockSignal()
monkeypatch.setattr(helpers, "signal", mock_signal)
yield mock_signal


def test_signal_intercept_stack_will_register_itself_with_callback_fn(mock_signal):
callback = lambda num, frame: ...
stack = helpers.SignalInterceptionStack.get(signal.NSIG)
stack.push(callback)
assert isinstance(stack, helpers.SignalInterceptionStack)
assert stack is mock_signal.signal_handlers[signal.NSIG]
assert len(stack) == 1
assert list(stack)[0] == callback


def test_signal_intercept_stack_keeps_track_of_previous_handlers(mock_signal):
default_handler = lambda num, frame: ...
mock_signal.signal_handlers[signal.NSIG] = default_handler
stack = helpers.SignalInterceptionStack.get(signal.NSIG)
stack.push(lambda n, f: ...)
assert stack._original is default_handler


def test_signal_intercept_stacks_are_registered_per_signal_number(mock_signal):
handler = lambda num, frame: ...
stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG)
stack_1.push(handler)
stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG + 1)
stack_2.push(handler)

assert mock_signal.signal_handlers[signal.NSIG] is stack_1
assert mock_signal.signal_handlers[signal.NSIG + 1] is stack_2
assert stack_1 is not stack_2
assert list(stack_1) == list(stack_2) == [handler]


def test_signal_intercept_handlers_will_not_overwrite_if_handler_already_exists(
mock_signal,
):
handler_1 = lambda num, frame: ...
handler_2 = lambda num, frame: ...
stack_1 = helpers.SignalInterceptionStack.get(signal.NSIG)
stack_1.push(handler_1)
stack_2 = helpers.SignalInterceptionStack.get(signal.NSIG)
stack_2.push(handler_2)
assert stack_1 is stack_2 is mock_signal.signal_handlers[signal.NSIG]
assert list(stack_1) == [handler_2, handler_1]


def test_signal_intercept_stack_can_add_multiple_instances_of_the_same_handler(
mock_signal,
):
handler = lambda num, frame: ...
stack = helpers.SignalInterceptionStack.get(signal.NSIG)
stack.push(handler)
stack.push(handler)
assert list(stack) == [handler, handler]


def test_signal_intercept_stack_enforces_that_unique_push_handlers_are_unique(
mock_signal,
):
handler = lambda num, frame: ...
stack = helpers.SignalInterceptionStack.get(signal.NSIG)
assert stack.push_unique(handler)
assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(handler)
assert list(stack) == [handler]


def test_signal_intercept_stack_enforces_that_unique_push_method_handlers_are_unique(
mock_signal,
):
class C:
def fn(num, frame): ...

c1 = C()
c2 = C()
stack = helpers.SignalInterceptionStack.get(signal.NSIG)
stack.push_unique(c1.fn)
assert helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c2.fn)
assert not helpers.SignalInterceptionStack.get(signal.NSIG).push_unique(c1.fn)
assert list(stack) == [c2.fn, c1.fn]


def test_signal_handler_calls_functions_in_reverse_order(mock_signal):
called_list = []
default = lambda num, frame: called_list.append("default")
handler_1 = lambda num, frame: called_list.append("handler_1")
handler_2 = lambda num, frame: called_list.append("handler_2")

mock_signal.signal_handlers[signal.NSIG] = default
helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_1)
helpers.SignalInterceptionStack.get(signal.NSIG).push(handler_2)
mock_signal.signal_handlers[signal.NSIG](signal.NSIG, None)
assert called_list == ["handler_2", "handler_1", "default"]
Loading
Loading