-
Notifications
You must be signed in to change notification settings - Fork 8
feat: add graceful shutdown with signal handling #174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import functools | ||
| import os | ||
| import signal | ||
| import threading | ||
| import time | ||
| import traceback | ||
|
|
@@ -39,6 +40,9 @@ | |
| SKIP_EMPTY_LOGS = os.getenv("ISOLATE_SKIP_EMPTY_LOGS") == "1" | ||
| MAX_GRPC_WAIT_TIMEOUT = float(os.getenv("ISOLATE_MAX_GRPC_WAIT_TIMEOUT", "10.0")) | ||
|
|
||
| # Graceful shutdown timeout in seconds (default: 60 minutes = 3600 seconds) | ||
| SHUTDOWN_GRACE_PERIOD = int(os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "3600")) | ||
|
|
||
| # Whether to inherit all the packages from the current environment or not. | ||
| INHERIT_FROM_LOCAL = os.getenv("ISOLATE_INHERIT_FROM_LOCAL") == "1" | ||
|
|
||
|
|
@@ -107,6 +111,8 @@ def terminate(self) -> None: | |
|
|
||
| @dataclass | ||
| class BridgeManager: | ||
| """Manages a pool of reusable gRPC agents for isolated Python environments.""" | ||
|
|
||
| _agent_access_lock: threading.Lock = field(default_factory=threading.Lock) | ||
| _agents: dict[tuple[Any, ...], list[RunnerAgent]] = field( | ||
| default_factory=lambda: defaultdict(list) | ||
|
|
@@ -150,8 +156,12 @@ def _allocate_new_agent( | |
|
|
||
| bound_context = ExitStack() | ||
| stub = bound_context.enter_context( | ||
| connection._establish_bridge(max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT) | ||
| connection._establish_bridge( | ||
| max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT, | ||
| shutdown_grace_period=SHUTDOWN_GRACE_PERIOD, | ||
| ) | ||
| ) | ||
|
|
||
| return RunnerAgent(stub, queue, bound_context) | ||
|
|
||
| def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]: | ||
|
|
@@ -163,7 +173,7 @@ def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]: | |
| def __enter__(self) -> BridgeManager: | ||
| return self | ||
|
|
||
| def __exit__(self, *exc_info: Any) -> None: | ||
| def __exit__(self, *_: Any) -> None: | ||
| for agents in self._agents.values(): | ||
| for agent in agents: | ||
| agent.terminate() | ||
|
|
@@ -194,13 +204,20 @@ def stream_logs(self) -> bool: | |
|
|
||
| @dataclass | ||
| class IsolateServicer(definitions.IsolateServicer): | ||
| """gRPC service handler for executing Python functions in isolated environments. | ||
|
|
||
| Orchestrates task execution by creating environments, managing agent connections | ||
| via BridgeManager, and streaming real-time results back to clients. Handles | ||
| graceful shutdown with signal propagation and agent cleanup.""" | ||
|
|
||
| bridge_manager: BridgeManager | ||
| default_settings: IsolateSettings = field(default_factory=IsolateSettings) | ||
| background_tasks: dict[str, RunTask] = field(default_factory=dict) | ||
|
|
||
| _thread_pool: futures.ThreadPoolExecutor = field( | ||
| default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS) | ||
| ) | ||
| _shutting_down: bool = field(default=False) | ||
|
|
||
| def _run_task(self, task: RunTask) -> Iterator[definitions.PartialRunResult]: | ||
| messages: Queue[definitions.PartialRunResult] = Queue() | ||
|
|
@@ -472,10 +489,30 @@ def abort_with_msg( | |
| return None | ||
|
|
||
| def cancel_tasks(self): | ||
| """Cancel all tasks with optional graceful shutdown""" | ||
| tasks_copy = self.background_tasks.copy() | ||
| for task in tasks_copy.values(): | ||
| task.cancel() | ||
|
|
||
| def initiate_shutdown(self) -> None: | ||
| if self._shutting_down: | ||
| return | ||
| self._shutting_down = True | ||
|
|
||
| print(f"Initiating shutdown with grace period: {SHUTDOWN_GRACE_PERIOD}s") | ||
| # Collect all active agents from running tasks | ||
| shutdown_threads = [] | ||
| for task in self.background_tasks.values(): | ||
| if task.agent is not None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and if it is None you should probably cancel it directly |
||
| thread = threading.Thread(target=task.agent.terminate) | ||
| thread.start() | ||
| shutdown_threads.append(thread) | ||
|
|
||
| # Wait for all agents to terminate | ||
| for thread in shutdown_threads: | ||
| thread.join() | ||
| print("All active agents have been shut down") | ||
|
|
||
|
|
||
| def _proxy_to_queue( | ||
| queue: Queue, | ||
|
|
@@ -583,9 +620,10 @@ def wrapper(method_impl): | |
| def _wrapper(request: Any, context: grpc.ServicerContext) -> Any: | ||
| def termination() -> None: | ||
| if is_run: | ||
| print("Stopping server since run is finished") | ||
| # Stop the server after the Run task is finished | ||
| self.server.stop(grace=0.1) | ||
| # Trigger graceful shutdown instead of immediate stop | ||
| self.servicer.initiate_shutdown() | ||
| if self._server: | ||
| self._server.stop(grace=0.1) | ||
|
|
||
| elif is_submit: | ||
| # Wait until the task_id is assigned | ||
|
|
@@ -606,11 +644,11 @@ def termination() -> None: | |
| "Task future was not assigned in time." | ||
| ) | ||
|
|
||
| def _stop(*args): | ||
| def _stop(*_): | ||
| # Small sleep to make sure the cancellation is processed | ||
| time.sleep(0.1) | ||
| print("Stopping server since the task is finished") | ||
| self.server.stop(grace=0.1) | ||
| self.servicer.initiate_shutdown() | ||
|
|
||
| # Add a callback which will stop the server | ||
| # after the task is finished | ||
|
|
@@ -629,6 +667,16 @@ def _stop(*args): | |
| return wrap_server_method_handler(wrapper, handler) | ||
|
|
||
|
|
||
| def register_signal_handlers(servicer: IsolateServicer, server: grpc.Server) -> None: | ||
| def handle_signal(signum, frame): | ||
| print(f"Received signal {signum}, initiating shutdown...") | ||
| servicer.initiate_shutdown() | ||
| server.stop(grace=0.1) | ||
|
|
||
| signal.signal(signal.SIGINT, handle_signal) | ||
| signal.signal(signal.SIGTERM, handle_signal) | ||
|
Comment on lines
+670
to
+677
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @chamini2 Is this what you're suggesting here #174 (comment)? |
||
|
|
||
|
|
||
| def main(argv: list[str] | None = None) -> None: | ||
| parser = ArgumentParser() | ||
| parser.add_argument("--host", default="0.0.0.0") | ||
|
|
@@ -664,18 +712,18 @@ def main(argv: list[str] | None = None) -> None: | |
|
|
||
| with BridgeManager() as bridge_manager: | ||
| servicer = IsolateServicer(bridge_manager) | ||
|
|
||
| register_signal_handlers(servicer, server) | ||
| for interceptor in interceptors: | ||
| interceptor.register_servicer(servicer) | ||
|
|
||
| definitions.register_isolate(servicer, server) | ||
| health.register_health(HealthServicer(), server) | ||
|
|
||
| server.add_insecure_port("[::]:50001") | ||
| print("Started listening at localhost:50001") | ||
| server.add_insecure_port(f"[::]:{options.port}") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This change was necessary to support end to end testing concurrently. |
||
| print(f"Started listening at localhost:{options.port}") | ||
|
|
||
| server.start() | ||
| server.wait_for_termination() | ||
| print("Server shutdown complete") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,15 @@ | ||
| import operator | ||
| import subprocess | ||
| import sys | ||
| import threading | ||
| import time | ||
| import traceback | ||
| from contextlib import ExitStack | ||
| from dataclasses import replace | ||
| from functools import partial | ||
| from pathlib import Path | ||
| from typing import Any, List | ||
| from unittest.mock import Mock | ||
|
|
||
| import pytest | ||
| from isolate.backends import BaseEnvironment, EnvironmentConnection | ||
|
|
@@ -209,3 +215,79 @@ def make_venv( | |
| env = VirtualPythonEnvironment(requirements + [f"{REPO_DIR}[grpc]"]) | ||
| env.apply_settings(IsolateSettings(Path(tmp_path))) | ||
| return env | ||
|
|
||
| def test_process_termination(self): | ||
| local_env = LocalPythonEnvironment() | ||
| connection = LocalPythonGRPC(local_env, local_env.create()) | ||
|
|
||
| # Mock terminate_proc on the connection instance before starting agent | ||
| connection.terminate_proc = Mock(wraps=connection.terminate_proc) | ||
|
|
||
| # Use ExitStack to manage the start_agent context manager | ||
| bound_context = ExitStack() | ||
| bound_context.enter_context(connection.start_agent(0.1)) | ||
|
|
||
| # Confirm with active context terminate_proc is not called yet | ||
| connection.terminate_proc.assert_not_called() | ||
|
|
||
| # Explicitly close the context to trigger cleanup | ||
| bound_context.close() | ||
|
|
||
| # After closing context, terminate_proc should have been called | ||
| connection.terminate_proc.assert_called_once() | ||
|
|
||
| def test_terminate_proc(self): | ||
| """Test that LocalPythonGRPC.terminate_proc() successfully kills | ||
| a process with SIGTERM.""" | ||
|
|
||
| local_env = LocalPythonEnvironment() | ||
| connection = LocalPythonGRPC(local_env, local_env.create()) | ||
| # Create a process that responds to SIGTERM | ||
| code = """import signal, time, sys | ||
| while True: | ||
| time.sleep(0.1)""" | ||
|
|
||
| proc = subprocess.Popen([sys.executable, "-c", code]) | ||
| # Verify process is running initially | ||
| assert proc.poll() is None, "Process should be running initially" | ||
|
|
||
| # terminate_proc() should send SIGTERM and wait for process to die | ||
| connection.terminate_proc(proc, shutdown_grace_period=3.0) | ||
|
|
||
| # Process should be terminated after terminate_proc() returns | ||
| assert proc.poll() is not None, "Process should be terminated by SIGTERM" | ||
|
|
||
| def test_force_terminate(self): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good one |
||
| """Test that LocalPythonGRPC.force_terminate() kills a process immediately.""" | ||
| local_env = LocalPythonEnvironment() | ||
| connection = LocalPythonGRPC(local_env, local_env.create()) | ||
|
|
||
| # Start a process that ignores SIGTERM | ||
| code = """import signal, time, os | ||
| # Set up signal handler to ignore SIGTERM | ||
| signal.signal(signal.SIGTERM, signal.SIG_IGN) | ||
| # Signal that setup is complete | ||
| print("R", flush=True) | ||
| while True: | ||
| time.sleep(0.1)""" | ||
|
|
||
| proc = subprocess.Popen( | ||
| [sys.executable, "-c", code], | ||
| stdout=subprocess.PIPE, | ||
| ) | ||
| assert proc.poll() is None, "Process should be running initially" | ||
|
|
||
| # Wait for "READY" signal from process because signal handling requires it | ||
| assert proc.stdout is not None | ||
| ready_line = proc.stdout.read(1) | ||
| assert ready_line == b"R" | ||
|
|
||
| # run terminate_proc in background thread since it will block | ||
| # waiting for the process to terminate (which it won't since it ignores SIGTERM) | ||
| threading.Thread(target=connection.terminate_proc, args=(proc, 0.5)).start() | ||
| time.sleep(0.35) | ||
| assert proc.poll() is None, "Process should ignore SIGTERM initially" | ||
| time.sleep(0.35) | ||
| assert ( | ||
| proc.poll() is not None | ||
| ), "Process should be terminated by force_terminate" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking good