From 9e32e8ebade4cf8253c224628fae9adb895ce2a7 Mon Sep 17 00:00:00 2001 From: Max Fischer Date: Tue, 19 Mar 2024 11:51:51 +0100 Subject: [PATCH] Cooperative signal handling (#1600) * test desired signal behaviour * capture and restore signal handlers * ruff * checks * test asyncio handlers * add note on signal handler handling * remove legacy signal raising * test SIGBREAK on windows * remove test guard * include convered branch * Update docs/index.md * Update docs/index.md --------- Co-authored-by: Marcelo Trylesinski --- tests/test_server.py | 67 ++++++++++++++++++++++++++++++++++++++++++++ uvicorn/server.py | 39 +++++++++++++++++--------- 2 files changed, 92 insertions(+), 14 deletions(-) create mode 100644 tests/test_server.py diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 000000000..dac8fb026 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import asyncio +import contextlib +import signal +import sys +from typing import Callable, ContextManager, Generator + +import pytest + +from uvicorn.config import Config +from uvicorn.server import Server + + +# asyncio does NOT allow raising in signal handlers, so to detect +# raised signals raised a mutable `witness` receives the signal +@contextlib.contextmanager +def capture_signal_sync(sig: signal.Signals) -> Generator[list[int], None, None]: + """Replace `sig` handling with a normal exception via `signal""" + witness: list[int] = [] + original_handler = signal.signal(sig, lambda signum, frame: witness.append(signum)) + yield witness + signal.signal(sig, original_handler) + + +@contextlib.contextmanager +def capture_signal_async(sig: signal.Signals) -> Generator[list[int], None, None]: # pragma: py-win32 + """Replace `sig` handling with a normal exception via `asyncio""" + witness: list[int] = [] + original_handler = signal.getsignal(sig) + asyncio.get_running_loop().add_signal_handler(sig, witness.append, sig) + yield witness + signal.signal(sig, original_handler) + + +async def dummy_app(scope, receive, send): # pragma: py-win32 + pass + + +if sys.platform == "win32": + signals = [signal.SIGBREAK] + signal_captures = [capture_signal_sync] +else: + signals = [signal.SIGTERM, signal.SIGINT] + signal_captures = [capture_signal_sync, capture_signal_async] + + +@pytest.mark.anyio +@pytest.mark.parametrize("exception_signal", signals) +@pytest.mark.parametrize("capture_signal", signal_captures) +async def test_server_interrupt( + exception_signal: signal.Signals, capture_signal: Callable[[signal.Signals], ContextManager[None]] +): # pragma: py-win32 + """Test interrupting a Server that is run explicitly inside asyncio""" + + async def interrupt_running(srv: Server): + while not srv.started: + await asyncio.sleep(0.01) + signal.raise_signal(exception_signal) + + server = Server(Config(app=dummy_app, loop="asyncio")) + asyncio.create_task(interrupt_running(server)) + with capture_signal(exception_signal) as witness: + await server.serve() + assert witness + # set by the server's graceful exit handler + assert server.should_exit diff --git a/uvicorn/server.py b/uvicorn/server.py index c7645f3ce..bfce1b1b1 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import logging import os import platform @@ -11,7 +12,7 @@ import time from email.utils import formatdate from types import FrameType -from typing import TYPE_CHECKING, Sequence, Union +from typing import TYPE_CHECKING, Generator, Sequence, Union import click @@ -57,11 +58,17 @@ def __init__(self, config: Config) -> None: self.force_exit = False self.last_notified = 0.0 + self._captured_signals: list[int] = [] + def run(self, sockets: list[socket.socket] | None = None) -> None: self.config.setup_event_loop() return asyncio.run(self.serve(sockets=sockets)) async def serve(self, sockets: list[socket.socket] | None = None) -> None: + with self.capture_signals(): + await self._serve(sockets) + + async def _serve(self, sockets: list[socket.socket] | None = None) -> None: process_id = os.getpid() config = self.config @@ -70,8 +77,6 @@ async def serve(self, sockets: list[socket.socket] | None = None) -> None: self.lifespan = config.lifespan_class(config) - self.install_signal_handlers() - message = "Started server process [%d]" color_message = "Started server process [" + click.style("%d", fg="cyan") + "]" logger.info(message, process_id, extra={"color_message": color_message}) @@ -302,22 +307,28 @@ async def _wait_tasks_to_complete(self) -> None: for server in self.servers: await server.wait_closed() - def install_signal_handlers(self) -> None: + @contextlib.contextmanager + def capture_signals(self) -> Generator[None, None, None]: + # Signals can only be listened to from the main thread. if threading.current_thread() is not threading.main_thread(): - # Signals can only be listened to from the main thread. + yield return - - loop = asyncio.get_event_loop() - + # always use signal.signal, even if loop.add_signal_handler is available + # this allows to restore previous signal handlers later on + original_handlers = {sig: signal.signal(sig, self.handle_exit) for sig in HANDLED_SIGNALS} try: - for sig in HANDLED_SIGNALS: - loop.add_signal_handler(sig, self.handle_exit, sig, None) - except NotImplementedError: # pragma: no cover - # Windows - for sig in HANDLED_SIGNALS: - signal.signal(sig, self.handle_exit) + yield + finally: + for sig, handler in original_handlers.items(): + signal.signal(sig, handler) + # If we did gracefully shut down due to a signal, try to + # trigger the expected behaviour now; multiple signals would be + # done LIFO, see https://stackoverflow.com/questions/48434964 + for captured_signal in reversed(self._captured_signals): + signal.raise_signal(captured_signal) def handle_exit(self, sig: int, frame: FrameType | None) -> None: + self._captured_signals.append(sig) if self.should_exit and sig == signal.SIGINT: self.force_exit = True else: