diff --git a/src/isolate/connections/grpc/_base.py b/src/isolate/connections/grpc/_base.py index a4d3bfc..1b8331d 100644 --- a/src/isolate/connections/grpc/_base.py +++ b/src/isolate/connections/grpc/_base.py @@ -1,4 +1,5 @@ import socket +import subprocess from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path @@ -27,7 +28,9 @@ class AgentError(Exception): class GRPCExecutionBase(EnvironmentConnection): """A customizable gRPC-based execution backend.""" - def start_agent(self) -> ContextManager[Tuple[str, grpc.ChannelCredentials]]: + def start_agent( + self, shutdown_grace_period: float + ) -> ContextManager[Tuple[str, grpc.ChannelCredentials]]: """Starts the gRPC agent and returns the address it is listening on and the required credentials to connect to it.""" raise NotImplementedError @@ -37,8 +40,9 @@ def _establish_bridge( self, *, max_wait_timeout: float = 20.0, + shutdown_grace_period: float = 0.1, ) -> Iterator[definitions.AgentStub]: - with self.start_agent() as (address, credentials): + with self.start_agent(shutdown_grace_period) as (address, credentials): with grpc.secure_channel( address, credentials, @@ -113,8 +117,13 @@ def run( class LocalPythonGRPC(PythonExecutionBase[str], GRPCExecutionBase): + """A gRPC-based execution backend that runs Python code in a local + environment.""" + @contextmanager - def start_agent(self) -> Iterator[Tuple[str, grpc.ChannelCredentials]]: + def start_agent( + self, shutdown_grace_period: float + ) -> Iterator[Tuple[str, grpc.ChannelCredentials]]: def find_free_port() -> Tuple[str, int]: """Find a free port in the system.""" with socket.socket() as _temp_socket: @@ -129,8 +138,21 @@ def find_free_port() -> Tuple[str, int]: yield address, grpc.local_channel_credentials() finally: if process is not None: - # TODO: should we check the status code here? - process.terminate() + self.terminate_proc(process, shutdown_grace_period) + + def terminate_proc( + self, proc: subprocess.Popen, shutdown_grace_period: float + ) -> None: + if not proc or proc.poll() is not None: + return + try: + print(f"Terminating agent PID {proc.pid}") + proc.terminate() + proc.wait(timeout=shutdown_grace_period) + except subprocess.TimeoutExpired: + # Process didn't die within timeout + print(f"killing agent PID {proc.pid}") + proc.kill() def get_python_cmd( self, diff --git a/src/isolate/server/server.py b/src/isolate/server/server.py index 12a9931..bb486b9 100644 --- a/src/isolate/server/server.py +++ b/src/isolate/server/server.py @@ -2,6 +2,7 @@ import functools import os +import signal import threading import time import traceback @@ -39,6 +40,9 @@ SKIP_EMPTY_LOGS = os.getenv("ISOLATE_SKIP_EMPTY_LOGS") == "1" MAX_GRPC_WAIT_TIMEOUT = float(os.getenv("ISOLATE_MAX_GRPC_WAIT_TIMEOUT", "10.0")) +# Graceful shutdown timeout in seconds (default: 60 minutes = 3600 seconds) +SHUTDOWN_GRACE_PERIOD = int(os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "3600")) + # Whether to inherit all the packages from the current environment or not. INHERIT_FROM_LOCAL = os.getenv("ISOLATE_INHERIT_FROM_LOCAL") == "1" @@ -107,6 +111,8 @@ def terminate(self) -> None: @dataclass class BridgeManager: + """Manages a pool of reusable gRPC agents for isolated Python environments.""" + _agent_access_lock: threading.Lock = field(default_factory=threading.Lock) _agents: dict[tuple[Any, ...], list[RunnerAgent]] = field( default_factory=lambda: defaultdict(list) @@ -150,8 +156,12 @@ def _allocate_new_agent( bound_context = ExitStack() stub = bound_context.enter_context( - connection._establish_bridge(max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT) + connection._establish_bridge( + max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT, + shutdown_grace_period=SHUTDOWN_GRACE_PERIOD, + ) ) + return RunnerAgent(stub, queue, bound_context) def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]: @@ -163,7 +173,7 @@ def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]: def __enter__(self) -> BridgeManager: return self - def __exit__(self, *exc_info: Any) -> None: + def __exit__(self, *_: Any) -> None: for agents in self._agents.values(): for agent in agents: agent.terminate() @@ -194,6 +204,12 @@ def stream_logs(self) -> bool: @dataclass class IsolateServicer(definitions.IsolateServicer): + """gRPC service handler for executing Python functions in isolated environments. + + Orchestrates task execution by creating environments, managing agent connections + via BridgeManager, and streaming real-time results back to clients. Handles + graceful shutdown with signal propagation and agent cleanup.""" + bridge_manager: BridgeManager default_settings: IsolateSettings = field(default_factory=IsolateSettings) background_tasks: dict[str, RunTask] = field(default_factory=dict) @@ -201,6 +217,7 @@ class IsolateServicer(definitions.IsolateServicer): _thread_pool: futures.ThreadPoolExecutor = field( default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS) ) + _shutting_down: bool = field(default=False) def _run_task(self, task: RunTask) -> Iterator[definitions.PartialRunResult]: messages: Queue[definitions.PartialRunResult] = Queue() @@ -472,10 +489,30 @@ def abort_with_msg( return None def cancel_tasks(self): + """Cancel all tasks with optional graceful shutdown""" tasks_copy = self.background_tasks.copy() for task in tasks_copy.values(): task.cancel() + def initiate_shutdown(self) -> None: + if self._shutting_down: + return + self._shutting_down = True + + print(f"Initiating shutdown with grace period: {SHUTDOWN_GRACE_PERIOD}s") + # Collect all active agents from running tasks + shutdown_threads = [] + for task in self.background_tasks.values(): + if task.agent is not None: + thread = threading.Thread(target=task.agent.terminate) + thread.start() + shutdown_threads.append(thread) + + # Wait for all agents to terminate + for thread in shutdown_threads: + thread.join() + print("All active agents have been shut down") + def _proxy_to_queue( queue: Queue, @@ -583,9 +620,10 @@ def wrapper(method_impl): def _wrapper(request: Any, context: grpc.ServicerContext) -> Any: def termination() -> None: if is_run: - print("Stopping server since run is finished") - # Stop the server after the Run task is finished - self.server.stop(grace=0.1) + # Trigger graceful shutdown instead of immediate stop + self.servicer.initiate_shutdown() + if self._server: + self._server.stop(grace=0.1) elif is_submit: # Wait until the task_id is assigned @@ -606,11 +644,11 @@ def termination() -> None: "Task future was not assigned in time." ) - def _stop(*args): + def _stop(*_): # Small sleep to make sure the cancellation is processed time.sleep(0.1) print("Stopping server since the task is finished") - self.server.stop(grace=0.1) + self.servicer.initiate_shutdown() # Add a callback which will stop the server # after the task is finished @@ -629,6 +667,16 @@ def _stop(*args): return wrap_server_method_handler(wrapper, handler) +def register_signal_handlers(servicer: IsolateServicer, server: grpc.Server) -> None: + def handle_signal(signum, frame): + print(f"Received signal {signum}, initiating shutdown...") + servicer.initiate_shutdown() + server.stop(grace=0.1) + + signal.signal(signal.SIGINT, handle_signal) + signal.signal(signal.SIGTERM, handle_signal) + + def main(argv: list[str] | None = None) -> None: parser = ArgumentParser() parser.add_argument("--host", default="0.0.0.0") @@ -664,18 +712,18 @@ def main(argv: list[str] | None = None) -> None: with BridgeManager() as bridge_manager: servicer = IsolateServicer(bridge_manager) - + register_signal_handlers(servicer, server) for interceptor in interceptors: interceptor.register_servicer(servicer) - definitions.register_isolate(servicer, server) health.register_health(HealthServicer(), server) - server.add_insecure_port("[::]:50001") - print("Started listening at localhost:50001") + server.add_insecure_port(f"[::]:{options.port}") + print(f"Started listening at localhost:{options.port}") server.start() server.wait_for_termination() + print("Server shutdown complete") if __name__ == "__main__": diff --git a/tests/test_connections.py b/tests/test_connections.py index e81aa65..ba3b4eb 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -1,9 +1,15 @@ import operator +import subprocess +import sys +import threading +import time import traceback +from contextlib import ExitStack from dataclasses import replace from functools import partial from pathlib import Path from typing import Any, List +from unittest.mock import Mock import pytest from isolate.backends import BaseEnvironment, EnvironmentConnection @@ -209,3 +215,79 @@ def make_venv( env = VirtualPythonEnvironment(requirements + [f"{REPO_DIR}[grpc]"]) env.apply_settings(IsolateSettings(Path(tmp_path))) return env + + def test_process_termination(self): + local_env = LocalPythonEnvironment() + connection = LocalPythonGRPC(local_env, local_env.create()) + + # Mock terminate_proc on the connection instance before starting agent + connection.terminate_proc = Mock(wraps=connection.terminate_proc) + + # Use ExitStack to manage the start_agent context manager + bound_context = ExitStack() + bound_context.enter_context(connection.start_agent(0.1)) + + # Confirm with active context terminate_proc is not called yet + connection.terminate_proc.assert_not_called() + + # Explicitly close the context to trigger cleanup + bound_context.close() + + # After closing context, terminate_proc should have been called + connection.terminate_proc.assert_called_once() + + def test_terminate_proc(self): + """Test that LocalPythonGRPC.terminate_proc() successfully kills + a process with SIGTERM.""" + + local_env = LocalPythonEnvironment() + connection = LocalPythonGRPC(local_env, local_env.create()) + # Create a process that responds to SIGTERM + code = """import signal, time, sys +while True: + time.sleep(0.1)""" + + proc = subprocess.Popen([sys.executable, "-c", code]) + # Verify process is running initially + assert proc.poll() is None, "Process should be running initially" + + # terminate_proc() should send SIGTERM and wait for process to die + connection.terminate_proc(proc, shutdown_grace_period=3.0) + + # Process should be terminated after terminate_proc() returns + assert proc.poll() is not None, "Process should be terminated by SIGTERM" + + def test_force_terminate(self): + """Test that LocalPythonGRPC.force_terminate() kills a process immediately.""" + local_env = LocalPythonEnvironment() + connection = LocalPythonGRPC(local_env, local_env.create()) + + # Start a process that ignores SIGTERM + code = """import signal, time, os +# Set up signal handler to ignore SIGTERM +signal.signal(signal.SIGTERM, signal.SIG_IGN) +# Signal that setup is complete +print("R", flush=True) +while True: + time.sleep(0.1)""" + + proc = subprocess.Popen( + [sys.executable, "-c", code], + stdout=subprocess.PIPE, + ) + assert proc.poll() is None, "Process should be running initially" + + # Wait for "READY" signal from process because signal handling requires it + assert proc.stdout is not None + ready_line = proc.stdout.read(1) + assert ready_line == b"R" + + # run terminate_proc in background thread since it will block + # waiting for the process to terminate (which it won't since it ignores SIGTERM) + threading.Thread(target=connection.terminate_proc, args=(proc, 0.5)).start() + time.sleep(0.35) + assert proc.poll() is None, "Process should ignore SIGTERM initially" + time.sleep(0.35) + assert ( + proc.poll() is not None + ), "Process should be terminated by force_terminate" diff --git a/tests/test_server.py b/tests/test_server.py index d41b55e..17e1078 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -20,6 +20,7 @@ IsolateServicer, ServerBoundInterceptor, SingleTaskInterceptor, + register_signal_handlers, ) REPO_DIR = Path(__file__).parent.parent @@ -62,6 +63,9 @@ def make_server( with BridgeManager() as bridge: servicer = IsolateServicer(bridge, test_settings) + # Set up signal handlers (needed for graceful shutdown) + register_signal_handlers(servicer, server) + for interceptor in interceptors: interceptor.register_servicer(servicer) diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py new file mode 100644 index 0000000..3204aae --- /dev/null +++ b/tests/test_shutdown.py @@ -0,0 +1,147 @@ +"""End-to-end tests for graceful shutdown behavior of IsolateServicer.""" + +import functools +import os +import signal +import subprocess +import sys +import threading +import time +from unittest.mock import Mock + +import grpc +import pytest +from isolate.server.definitions.server_pb2 import BoundFunction, EnvironmentDefinition +from isolate.server.definitions.server_pb2_grpc import IsolateStub +from isolate.server.interface import to_serialized_object +from isolate.server.server import BridgeManager, IsolateServicer, RunnerAgent, RunTask + + +def create_run_request(func, stream_logs=True): + """Convert a Python function into a BoundFunction request for stub.Run().""" + bound_function = functools.partial(func) + serialized_function = to_serialized_object(bound_function, method="cloudpickle") + + env_def = EnvironmentDefinition() + env_def.kind = "local" + + request = BoundFunction() + request.function.CopyFrom(serialized_function) + request.environments.append(env_def) + request.stream_logs = stream_logs + + return request + + +@pytest.fixture +def servicer(): + """Create a real IsolateServicer instance for testing.""" + with BridgeManager() as bridge_manager: + servicer = IsolateServicer(bridge_manager) + yield servicer + + +@pytest.fixture +def isolate_server_subprocess(): + """Set up a gRPC server with the IsolateServicer for testing.""" + + os.environ["ISOLATE_SHUTDOWN_GRACE_PERIOD"] = "2" + + # Find a free port + import socket + + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + process = subprocess.Popen( + [ + sys.executable, + "-m", + "isolate.server.server", + "--single-use", + "--port", + str(port), + ] + ) + + time.sleep(0.1) # Wait for server to start + yield process, port + + # Cleanup + if process.poll() is None: + process.terminate() + process.wait(timeout=5) + + +def test_shutdown_with_terminate(servicer): + """Test shutdown confirms that terminate is called + on servicer background tasks by initiate_shutdown""" + task = RunTask(request=Mock()) + servicer.background_tasks["TEST_BLOCKING"] = task + task.agent = RunnerAgent(Mock(), Mock(), Mock(), Mock()) + task.agent.terminate = Mock(wraps=task.agent.terminate) + servicer.initiate_shutdown() # default grace period + # force_terminate is tested within runner_agent testing + task.agent.terminate.assert_called_once() # agent should be terminated + + +def test_exit_on_client_close(isolate_server_subprocess): + """Connect with grpc client, run a task and then close the client.""" + process, port = isolate_server_subprocess + channel = grpc.insecure_channel(f"localhost:{port}") + stub = IsolateStub(channel) + + def fn(): + while True: + pass + + responses = stub.Run(create_run_request(fn)) + + def consume_responses(): + try: + for response in responses: + print("Received response:", response) + except grpc.RpcError: + # Expected when connection is closed + pass + + response_thread = threading.Thread(target=consume_responses, daemon=True) + response_thread.start() + + # Give task time to start + time.sleep(0.5) + + # there is a running grpc client connected to an isolate servicer which is + # emitting responses from an agent running a infinite loop + assert process.poll() is None, "Server should be running while client is connected" + + # Close the channel to simulate client disconnect + channel.close() + + # Give time for the channel close to propagate and trigger termination + time.sleep(1.0) + + try: + # Wait for server process to exit + process.wait(timeout=3) + except subprocess.TimeoutExpired: + raise AssertionError("Server did not shut down after client disconnect") + + assert ( + process.poll() is not None + ), "Server should have shut down after client disconnect" + + +def test_sigterm_termination(isolate_server_subprocess): + """Test that the server shuts down gracefully on SIGTERM.""" + process, port = isolate_server_subprocess + # Send SIGTERM to the current process + assert process.poll() is None, "Server should be running initially" + os.kill(process.pid, signal.SIGTERM) + process.wait(timeout=5) + assert process.poll() is not None, "Server should have shut down after SIGTERM" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])