Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround asyncio signal handling on Unix #479

Merged
merged 4 commits into from
Jan 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 82 additions & 15 deletions launch/launch/utilities/signal_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,19 @@ class AsyncSafeSignalManager:
:func:`signal.signal`.
All signals received are forwarded to the previously setup file
descriptor, if any.

..warning::
Within (potentially nested) contexts, :func:`signal.set_wakeup_fd`
calls are intercepted such that the given file descriptor overrides
the previously setup file descriptor for the outermost manager.
This ensures the manager's chain of signal wakeup file descriptors
is not broken by third-party code or by asyncio itself in some platforms.
"""

__current = None # type: AsyncSafeSignalManager

__set_wakeup_fd = signal.set_wakeup_fd # type: Callable[[int], int]

def __init__(
self,
loop: asyncio.AbstractEventLoop
Expand All @@ -77,6 +88,7 @@ def __init__(

:param loop: event loop that will handle the signals.
"""
self.__parent = None # type: AsyncSafeSignalManager
self.__loop = loop # type: asyncio.AbstractEventLoop
self.__background_loop = None # type: Optional[asyncio.AbstractEventLoop]
self.__handlers = {} # type: dict
Expand All @@ -86,12 +98,31 @@ def __init__(
self.__rsock.setblocking(False)

def __enter__(self):
self.__add_signal_readers()
try:
self.__install_signal_writers()
except Exception:
self.__remove_signal_readers()
raise
self.__chain()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
try:
try:
self.__uninstall_signal_writers()
finally:
self.__remove_signal_readers()
finally:
self.__unchain()

def __add_signal_readers(self):
try:
self.__loop.add_reader(self.__rsock.fileno(), self.__handle_signal)
except NotImplementedError:
# Some event loops, like the asyncio.ProactorEventLoop
# on Windows, do not support asynchronous socket reads.
# So we emulate it.
# Emulate it.
self.__background_loop = asyncio.SelectorEventLoop()
self.__background_loop.add_reader(
self.__rsock.fileno(),
Expand All @@ -102,29 +133,65 @@ def run_background_loop():
asyncio.set_event_loop(self.__background_loop)
self.__background_loop.run_forever()

self.__background_thread = threading.Thread(target=run_background_loop)
self.__background_thread = threading.Thread(
target=run_background_loop, daemon=True)
self.__background_thread.start()
self.__prev_wakeup_handle = signal.set_wakeup_fd(self.__wsock.fileno())
if self.__prev_wakeup_handle != -1 and is_winsock_handle(self.__prev_wakeup_handle):
# On Windows, os.write will fail on a WinSock handle. There is no WinSock API
# in the standard library either. Thus we wrap it in a socket.socket instance.
self.__prev_wakeup_handle = socket.socket(fileno=self.__prev_wakeup_handle)
return self

def __exit__(self, type_, value, traceback):
if isinstance(self.__prev_wakeup_handle, socket.socket):
# Detach (Windows) socket and retrieve the raw OS handle.
prev_wakeup_handle = self.__prev_wakeup_handle.fileno()
self.__prev_wakeup_handle.detach()
self.__prev_wakeup_handle = prev_wakeup_handle
assert self.__wsock.fileno() == signal.set_wakeup_fd(self.__prev_wakeup_handle)
def __remove_signal_readers(self):
if self.__background_loop:
self.__background_loop.call_soon_threadsafe(self.__background_loop.stop)
self.__background_thread.join()
self.__background_loop.close()
self.__background_loop = None
else:
self.__loop.remove_reader(self.__rsock.fileno())

def __install_signal_writers(self):
prev_wakeup_handle = self.__set_wakeup_fd(self.__wsock.fileno())
try:
self.__chain_wakeup_handle(prev_wakeup_handle)
except Exception:
own_wakeup_handle = self.__set_wakeup_fd(prev_wakeup_handle)
assert self.__wsock.fileno() == own_wakeup_handle
raise

def __uninstall_signal_writers(self):
prev_wakeup_handle = self.__chain_wakeup_handle(-1)
own_wakeup_handle = self.__set_wakeup_fd(prev_wakeup_handle)
assert self.__wsock.fileno() == own_wakeup_handle

def __chain(self):
self.__parent = AsyncSafeSignalManager.__current
AsyncSafeSignalManager.__current = self
if self.__parent is None:
# Do not trust signal.set_wakeup_fd calls within context.
# Overwrite handle at the start of the managers' chain.
def modified_set_wakeup_fd(signum):
if threading.current_thread() is not threading.main_thread():
raise ValueError(
'set_wakeup_fd only works in main'
' thread of the main interpreter'
)
return self.__chain_wakeup_handle(signum)
signal.set_wakeup_fd = modified_set_wakeup_fd

def __unchain(self):
if self.__parent is None:
signal.set_wakeup_fd = self.__set_wakeup_fd
AsyncSafeSignalManager.__current = self.__parent

def __chain_wakeup_handle(self, wakeup_handle):
prev_wakeup_handle = self.__prev_wakeup_handle
if isinstance(prev_wakeup_handle, socket.socket):
# Detach (Windows) socket and retrieve the raw OS handle.
prev_wakeup_handle = prev_wakeup_handle.detach()
if wakeup_handle != -1 and is_winsock_handle(wakeup_handle):
# On Windows, os.write will fail on a WinSock handle. There is no WinSock API
# in the standard library either. Thus we wrap it in a socket.socket instance.
wakeup_handle = socket.socket(fileno=wakeup_handle)
self.__prev_wakeup_handle = wakeup_handle
return prev_wakeup_handle

def __handle_signal(self):
while True:
try:
Expand Down
16 changes: 12 additions & 4 deletions launch/test/launch/utilities/test_signal_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ def _wrapper(*args, **kwargs):
SIGNAL = signal.SIGUSR1
ANOTHER_SIGNAL = signal.SIGUSR2

if not hasattr(signal, 'raise_signal'):
# Only available for Python 3.8+
def raise_signal(signum):
import os
os.kill(os.getpid(), signum)
else:
raise_signal = signal.raise_signal


@cap_signals(SIGNAL, ANOTHER_SIGNAL)
def test_async_safe_signal_manager():
Expand All @@ -70,7 +78,7 @@ def test_async_safe_signal_manager():
manager.handle(ANOTHER_SIGNAL, got_another_signal.set_result)

# Verify signal handling is working
loop.call_soon(signal.raise_signal, SIGNAL)
loop.call_soon(raise_signal, SIGNAL)
loop.run_until_complete(asyncio.wait(
[got_signal, got_another_signal],
return_when=asyncio.FIRST_COMPLETED,
Expand All @@ -84,22 +92,22 @@ def test_async_safe_signal_manager():
manager.handle(SIGNAL, None)

# Verify signal handler is no longer there
loop.call_soon(signal.raise_signal, SIGNAL)
loop.call_soon(raise_signal, SIGNAL)
loop.run_until_complete(asyncio.wait(
[got_another_signal], timeout=1.0
))
assert not got_another_signal.done()

# Signal handling is (now) inactive outside context
loop.call_soon(signal.raise_signal, ANOTHER_SIGNAL)
loop.call_soon(raise_signal, ANOTHER_SIGNAL)
loop.run_until_complete(asyncio.wait(
[got_another_signal], timeout=1.0
))
assert not got_another_signal.done()

# Managers' context may be re-entered
with manager:
loop.call_soon(signal.raise_signal, ANOTHER_SIGNAL)
loop.call_soon(raise_signal, ANOTHER_SIGNAL)
loop.run_until_complete(asyncio.wait(
[got_another_signal], timeout=1.0
))
Expand Down