Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/isolate/connections/grpc/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import socket
import subprocess
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -27,7 +28,9 @@ class AgentError(Exception):
class GRPCExecutionBase(EnvironmentConnection):
"""A customizable gRPC-based execution backend."""

def start_agent(self) -> ContextManager[Tuple[str, grpc.ChannelCredentials]]:
def start_agent(
self, shutdown_grace_period: float
) -> ContextManager[Tuple[str, grpc.ChannelCredentials]]:
"""Starts the gRPC agent and returns the address it is listening on and
the required credentials to connect to it."""
raise NotImplementedError
Expand All @@ -37,8 +40,9 @@ def _establish_bridge(
self,
*,
max_wait_timeout: float = 20.0,
shutdown_grace_period: float = 0.1,
) -> Iterator[definitions.AgentStub]:
with self.start_agent() as (address, credentials):
with self.start_agent(shutdown_grace_period) as (address, credentials):
with grpc.secure_channel(
address,
credentials,
Expand Down Expand Up @@ -113,8 +117,13 @@ def run(


class LocalPythonGRPC(PythonExecutionBase[str], GRPCExecutionBase):
"""A gRPC-based execution backend that runs Python code in a local
environment."""

@contextmanager
def start_agent(self) -> Iterator[Tuple[str, grpc.ChannelCredentials]]:
def start_agent(
self, shutdown_grace_period: float
) -> Iterator[Tuple[str, grpc.ChannelCredentials]]:
def find_free_port() -> Tuple[str, int]:
"""Find a free port in the system."""
with socket.socket() as _temp_socket:
Expand All @@ -129,8 +138,21 @@ def find_free_port() -> Tuple[str, int]:
yield address, grpc.local_channel_credentials()
finally:
if process is not None:
# TODO: should we check the status code here?
process.terminate()
self.terminate_proc(process, shutdown_grace_period)

def terminate_proc(
    self, proc: subprocess.Popen, shutdown_grace_period: float
) -> None:
    """Stop the agent process, escalating from terminate() to kill().

    A termination request is sent first and the process is given up to
    ``shutdown_grace_period`` seconds to exit. If it is still running after
    that, it is killed and then reaped.

    Args:
        proc: The agent process to stop. Nothing is done when it is falsy
            or has already exited.
        shutdown_grace_period: Seconds to wait after ``terminate()`` before
            falling back to ``kill()``.
    """
    if not proc or proc.poll() is not None:
        return
    try:
        print(f"Terminating agent PID {proc.pid}")
        proc.terminate()
        proc.wait(timeout=shutdown_grace_period)
    except subprocess.TimeoutExpired:
        # Process didn't die within the grace period.
        print(f"killing agent PID {proc.pid}")
        proc.kill()
        # Reap the killed child: without this wait it lingers as a zombie
        # and ``proc.returncode`` is never populated.
        proc.wait()
Comment on lines +148 to +155
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking good


def get_python_cmd(
self,
Expand Down
70 changes: 59 additions & 11 deletions src/isolate/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import functools
import os
import signal
import threading
import time
import traceback
Expand Down Expand Up @@ -39,6 +40,9 @@
SKIP_EMPTY_LOGS = os.getenv("ISOLATE_SKIP_EMPTY_LOGS") == "1"
MAX_GRPC_WAIT_TIMEOUT = float(os.getenv("ISOLATE_MAX_GRPC_WAIT_TIMEOUT", "10.0"))

# Graceful shutdown timeout in seconds (default: 60 minutes = 3600 seconds).
# Parsed as a float, like MAX_GRPC_WAIT_TIMEOUT, so fractional values such as
# "0.5" are accepted instead of raising ValueError at import time.
SHUTDOWN_GRACE_PERIOD = float(os.getenv("ISOLATE_SHUTDOWN_GRACE_PERIOD", "3600"))

# Whether to inherit all the packages from the current environment or not.
INHERIT_FROM_LOCAL = os.getenv("ISOLATE_INHERIT_FROM_LOCAL") == "1"

Expand Down Expand Up @@ -107,6 +111,8 @@ def terminate(self) -> None:

@dataclass
class BridgeManager:
"""Manages a pool of reusable gRPC agents for isolated Python environments."""

_agent_access_lock: threading.Lock = field(default_factory=threading.Lock)
_agents: dict[tuple[Any, ...], list[RunnerAgent]] = field(
default_factory=lambda: defaultdict(list)
Expand Down Expand Up @@ -150,8 +156,12 @@ def _allocate_new_agent(

bound_context = ExitStack()
stub = bound_context.enter_context(
connection._establish_bridge(max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT)
connection._establish_bridge(
max_wait_timeout=MAX_GRPC_WAIT_TIMEOUT,
shutdown_grace_period=SHUTDOWN_GRACE_PERIOD,
)
)

return RunnerAgent(stub, queue, bound_context)

def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]:
Expand All @@ -163,7 +173,7 @@ def _identify(self, connection: LocalPythonGRPC) -> tuple[Any, ...]:
def __enter__(self) -> BridgeManager:
return self

def __exit__(self, *exc_info: Any) -> None:
def __exit__(self, *_: Any) -> None:
for agents in self._agents.values():
for agent in agents:
agent.terminate()
Expand Down Expand Up @@ -194,13 +204,20 @@ def stream_logs(self) -> bool:

@dataclass
class IsolateServicer(definitions.IsolateServicer):
"""gRPC service handler for executing Python functions in isolated environments.

Orchestrates task execution by creating environments, managing agent connections
via BridgeManager, and streaming real-time results back to clients. Handles
graceful shutdown with signal propagation and agent cleanup."""

bridge_manager: BridgeManager
default_settings: IsolateSettings = field(default_factory=IsolateSettings)
background_tasks: dict[str, RunTask] = field(default_factory=dict)

_thread_pool: futures.ThreadPoolExecutor = field(
default_factory=lambda: futures.ThreadPoolExecutor(max_workers=MAX_THREADS)
)
_shutting_down: bool = field(default=False)

def _run_task(self, task: RunTask) -> Iterator[definitions.PartialRunResult]:
messages: Queue[definitions.PartialRunResult] = Queue()
Expand Down Expand Up @@ -472,10 +489,30 @@ def abort_with_msg(
return None

def cancel_tasks(self):
    """Cancel every task currently registered in ``background_tasks``."""
    # Work on a snapshot so the registry may change while tasks are cancelled.
    snapshot = dict(self.background_tasks)
    for pending in snapshot.values():
        pending.cancel()

def initiate_shutdown(self) -> None:
    """Terminate the agents of all registered tasks and wait for them to exit.

    Only the first call does any work; later calls return immediately
    because ``_shutting_down`` is already set. Each agent's ``terminate()``
    runs in its own thread so the agents shut down concurrently, and this
    method blocks until every one of those threads has finished.
    """
    if self._shutting_down:
        return
    self._shutting_down = True

    print(f"Initiating shutdown with grace period: {SHUTDOWN_GRACE_PERIOD}s")
    # Snapshot the registry first (as cancel_tasks does): if another thread
    # adds or removes a task while we iterate the live dict, the loop would
    # raise RuntimeError and abort the shutdown half-way.
    tasks = list(self.background_tasks.values())

    # Terminate all active agents from the registered tasks in parallel.
    shutdown_threads = []
    for task in tasks:
        if task.agent is None:
            # NOTE(review): a reviewer suggested that tasks without an agent
            # should probably be cancelled directly rather than skipped.
            # TODO: confirm the intended behavior and implement it.
            continue
        thread = threading.Thread(target=task.agent.terminate)
        thread.start()
        shutdown_threads.append(thread)

    # Wait for all agents to terminate
    for thread in shutdown_threads:
        thread.join()
    print("All active agents have been shut down")


def _proxy_to_queue(
queue: Queue,
Expand Down Expand Up @@ -583,9 +620,10 @@ def wrapper(method_impl):
def _wrapper(request: Any, context: grpc.ServicerContext) -> Any:
def termination() -> None:
if is_run:
print("Stopping server since run is finished")
# Stop the server after the Run task is finished
self.server.stop(grace=0.1)
# Trigger graceful shutdown instead of immediate stop
self.servicer.initiate_shutdown()
if self._server:
self._server.stop(grace=0.1)

elif is_submit:
# Wait until the task_id is assigned
Expand All @@ -606,11 +644,11 @@ def termination() -> None:
"Task future was not assigned in time."
)

def _stop(*args):
def _stop(*_):
# Small sleep to make sure the cancellation is processed
time.sleep(0.1)
print("Stopping server since the task is finished")
self.server.stop(grace=0.1)
self.servicer.initiate_shutdown()

# Add a callback which will stop the server
# after the task is finished
Expand All @@ -629,6 +667,16 @@ def _stop(*args):
return wrap_server_method_handler(wrapper, handler)


def register_signal_handlers(servicer: IsolateServicer, server: grpc.Server) -> None:
    """Install one shared handler for SIGINT and SIGTERM.

    When either signal arrives, the handler asks ``servicer`` to initiate
    its shutdown and then stops ``server`` with a 0.1 second grace period.
    """

    def _shutdown_on_signal(signum, _frame):
        print(f"Received signal {signum}, initiating shutdown...")
        servicer.initiate_shutdown()
        server.stop(grace=0.1)

    for handled in (signal.SIGINT, signal.SIGTERM):
        signal.signal(handled, _shutdown_on_signal)
Comment on lines +670 to +677
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chamini2 Is this what you're suggesting here #174 (comment)?



def main(argv: list[str] | None = None) -> None:
parser = ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
Expand Down Expand Up @@ -664,18 +712,18 @@ def main(argv: list[str] | None = None) -> None:

with BridgeManager() as bridge_manager:
servicer = IsolateServicer(bridge_manager)

register_signal_handlers(servicer, server)
for interceptor in interceptors:
interceptor.register_servicer(servicer)

definitions.register_isolate(servicer, server)
health.register_health(HealthServicer(), server)

server.add_insecure_port("[::]:50001")
print("Started listening at localhost:50001")
server.add_insecure_port(f"[::]:{options.port}")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was necessary to support running end-to-end tests concurrently.

print(f"Started listening at localhost:{options.port}")

server.start()
server.wait_for_termination()
print("Server shutdown complete")


if __name__ == "__main__":
Expand Down
82 changes: 82 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import operator
import subprocess
import sys
import threading
import time
import traceback
from contextlib import ExitStack
from dataclasses import replace
from functools import partial
from pathlib import Path
from typing import Any, List
from unittest.mock import Mock

import pytest
from isolate.backends import BaseEnvironment, EnvironmentConnection
Expand Down Expand Up @@ -209,3 +215,79 @@ def make_venv(
env = VirtualPythonEnvironment(requirements + [f"{REPO_DIR}[grpc]"])
env.apply_settings(IsolateSettings(Path(tmp_path)))
return env

def test_process_termination(self):
    """Check that leaving the ``start_agent`` context calls ``terminate_proc``."""
    local_env = LocalPythonEnvironment()
    connection = LocalPythonGRPC(local_env, local_env.create())

    # Wrap terminate_proc on the connection instance before starting the
    # agent, so calls are recorded while the real method still runs.
    connection.terminate_proc = Mock(wraps=connection.terminate_proc)

    # A ``with`` block guarantees the agent is cleaned up even when the
    # assertion inside it fails; a manually closed ExitStack would leak
    # the agent process in that case.
    with connection.start_agent(0.1):
        # While the context is active, terminate_proc must not have run.
        connection.terminate_proc.assert_not_called()

    # Leaving the context triggers the cleanup exactly once.
    connection.terminate_proc.assert_called_once()

def test_terminate_proc(self):
    """Test that LocalPythonGRPC.terminate_proc() stops a process that
    honours SIGTERM within the grace period."""

    local_env = LocalPythonEnvironment()
    connection = LocalPythonGRPC(local_env, local_env.create())
    # Create a process that installs no SIGTERM handler of its own
    code = "import signal, time, sys\nwhile True:\n    time.sleep(0.1)"

    proc = subprocess.Popen([sys.executable, "-c", code])
    try:
        # Verify process is running initially
        assert proc.poll() is None, "Process should be running initially"

        # terminate_proc() should send SIGTERM and wait for process to die
        connection.terminate_proc(proc, shutdown_grace_period=3.0)

        # Process should be terminated after terminate_proc() returns
        assert proc.poll() is not None, "Process should be terminated by SIGTERM"
    finally:
        # Never leave the endlessly looping child behind if a check failed.
        if proc.poll() is None:
            proc.kill()
        proc.wait()

def test_force_terminate(self):
    """Test that LocalPythonGRPC.terminate_proc() kills a process that
    ignores SIGTERM once the grace period has elapsed."""
    local_env = LocalPythonEnvironment()
    connection = LocalPythonGRPC(local_env, local_env.create())

    # Start a process that ignores SIGTERM
    code = "\n".join(
        [
            "import signal, time, os",
            "# Set up signal handler to ignore SIGTERM",
            "signal.signal(signal.SIGTERM, signal.SIG_IGN)",
            "# Signal that setup is complete",
            'print("R", flush=True)',
            "while True:",
            "    time.sleep(0.1)",
        ]
    )

    proc = subprocess.Popen(
        [sys.executable, "-c", code],
        stdout=subprocess.PIPE,
    )
    try:
        assert proc.poll() is None, "Process should be running initially"

        # Wait for the "R" readiness marker: SIGTERM must only be sent after
        # the child has installed its SIG_IGN handler.
        assert proc.stdout is not None
        ready_line = proc.stdout.read(1)
        assert ready_line == b"R"

        # Run terminate_proc in a background thread since it blocks for the
        # whole grace period (the child ignores SIGTERM).
        terminator = threading.Thread(
            target=connection.terminate_proc, args=(proc, 0.5)
        )
        terminator.start()
        time.sleep(0.35)
        assert proc.poll() is None, "Process should ignore SIGTERM initially"

        # Join the worker and wait on the child instead of sleeping a fixed
        # amount, so the final check does not depend on scheduler timing.
        terminator.join(timeout=5)
        assert not terminator.is_alive(), "terminate_proc should have returned"
        proc.wait(timeout=5)
        assert (
            proc.poll() is not None
        ), "Process should be killed after the grace period"
    finally:
        # Never leave the SIGTERM-ignoring child behind if a check failed.
        if proc.poll() is None:
            proc.kill()
        proc.wait()
        if proc.stdout is not None:
            proc.stdout.close()
4 changes: 4 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
IsolateServicer,
ServerBoundInterceptor,
SingleTaskInterceptor,
register_signal_handlers,
)

REPO_DIR = Path(__file__).parent.parent
Expand Down Expand Up @@ -62,6 +63,9 @@ def make_server(
with BridgeManager() as bridge:
servicer = IsolateServicer(bridge, test_settings)

# Set up signal handlers (needed for graceful shutdown)
register_signal_handlers(servicer, server)

for interceptor in interceptors:
interceptor.register_servicer(servicer)

Expand Down
Loading