Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions src/isolate/connections/grpc/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@

from __future__ import annotations

import asyncio
import os
import signal
import sys
import traceback
from argparse import ArgumentParser
from concurrent import futures
from dataclasses import dataclass
from typing import (
Any,
AsyncIterator,
Iterable,
Iterator,
)

import grpc
import grpc.aio
from grpc import ServicerContext, StatusCode

try:
Expand All @@ -33,7 +35,6 @@
from isolate.backends.common import sha256_digest_of
from isolate.connections.common import SerializationError, serialize_object
from isolate.connections.grpc import definitions
from isolate.connections.grpc.configuration import get_default_options
from isolate.connections.grpc.interface import from_grpc


Expand All @@ -49,11 +50,11 @@ def __init__(self, log_fd: int | None = None):
self._run_cache: dict[str, Any] = {}
self._log = sys.stdout if log_fd is None else os.fdopen(log_fd, "w")

def Run(
async def Run(
self,
request: definitions.FunctionCall,
context: ServicerContext,
) -> Iterator[definitions.PartialRunResult]:
) -> AsyncIterator[definitions.PartialRunResult]:
self.log(f"A connection has been established: {context.peer()}!")
server_version = os.getenv("ISOLATE_SERVER_VERSION") or "unknown"
self.log(f"Isolate info: server {server_version}, agent {agent_version}")
Expand Down Expand Up @@ -87,7 +88,8 @@ def Run(
)
raise AbortException("The setup function has thrown an error.")
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
self.abort_with_msg(context, exc.message)
return
else:
assert not was_it_raised
self._run_cache[cache_key] = result
Expand All @@ -107,7 +109,8 @@ def Run(
stringized_tb,
)
except AbortException as exc:
return self.abort_with_msg(context, exc.message)
self.abort_with_msg(context, exc.message)
return

def execute_function(
self,
Expand Down Expand Up @@ -205,14 +208,10 @@ def abort_with_msg(
return None


def create_server(address: str) -> grpc.Server:
def create_server(address: str) -> grpc.aio.Server:
"""Create a new (temporary) gRPC server listening on the given
address."""
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=1),
maximum_concurrent_rpcs=1,
options=get_default_options(),
)
server = grpc.aio.server()

# Local server credentials allow us to ensure that the
# connection is established by a local process.
Expand All @@ -221,29 +220,80 @@ def create_server(address: str) -> grpc.Server:
return server


def run_agent(address: str, log_fd: int | None = None) -> int:
def _setup_signal_handler(
    shutdown_event: asyncio.Event, servicer: AgentServicer
) -> None:
    """Arrange for SIGTERM to request a graceful shutdown.

    Must be called while an event loop is running. When SIGTERM arrives,
    a message is written through ``servicer.log`` and ``shutdown_event``
    is set so that whoever is waiting on it can stop the server.

    Args:
        shutdown_event: Event that is set once SIGTERM has been received.
        servicer: Object whose ``log`` method records the shutdown notice.
    """
    loop = asyncio.get_running_loop()

    def request_shutdown() -> None:
        servicer.log("Received SIGTERM, initiating graceful shutdown")
        shutdown_event.set()

    try:
        loop.add_signal_handler(signal.SIGTERM, request_shutdown)
    except NotImplementedError:
        # loop.add_signal_handler is only available on Unix event loops.
        # Fall back to a plain signal handler that hands the work over to
        # the loop, instead of crashing the agent during startup.
        signal.signal(
            signal.SIGTERM,
            lambda _signum, _frame: loop.call_soon_threadsafe(request_shutdown),
        )


async def _handle_shutdown(
    server: grpc.aio.Server, servicer: AgentServicer, grace_period: int
) -> None:
    """Block until the server terminates or SIGTERM asks it to stop.

    Two things are awaited concurrently: the server terminating on its own
    and the shutdown event set by the SIGTERM handler. If the signal wins,
    the server is stopped with ``grace_period`` seconds for in-flight RPCs.

    Args:
        server: The started gRPC asyncio server to supervise.
        servicer: Agent servicer; only its ``log`` method is used here.
        grace_period: Seconds granted to active RPCs after SIGTERM.
    """
    shutdown_event = asyncio.Event()
    _setup_signal_handler(shutdown_event, servicer)

    # Wait for either server termination or shutdown signal
    termination_task = asyncio.create_task(server.wait_for_termination())
    shutdown_task = asyncio.create_task(shutdown_event.wait())
    waiters = (termination_task, shutdown_task)

    try:
        done, _pending = await asyncio.wait(
            waiters, return_when=asyncio.FIRST_COMPLETED
        )

        # If shutdown signal received, stop the server gracefully
        if shutdown_task in done:
            servicer.log(
                f"Shutting down gRPC server with {grace_period}s grace period..."
            )
            await server.stop(grace=grace_period)
    finally:
        try:
            # Clean up the waiters on every exit path (including when this
            # coroutine is cancelled or stop() raised) so no task is left
            # pending and no task exception goes unretrieved. Cancelling
            # an already finished task is a no-op.
            for task in waiters:
                task.cancel()
            await asyncio.gather(*waiters, return_exceptions=True)
        finally:
            # Ensure server is fully stopped
            await server.stop(grace=0)


async def run_agent(address: str, log_fd: int | None = None) -> int:
    """Run the agent servicer on the given address.

    The shutdown grace period (in seconds) is read from the
    ``ISOLATE_SHUTDOWN_GRACE_PERIOD`` environment variable and defaults to
    3600. A malformed value falls back to the default and a negative value
    is treated as zero.

    Args:
        address: Address the gRPC server listens on.
        log_fd: Optional file descriptor handed to the servicer for logs.

    Returns:
        ``0`` once the server has shut down.
    """
    default_grace_period = 3600

    server = create_server(address)
    servicer = AgentServicer(log_fd=log_fd)

    # Register the agent service
    definitions.register_agent(servicer, server)

    # Get shutdown grace period from environment or use default
    raw_grace_period = os.getenv(
        "ISOLATE_SHUTDOWN_GRACE_PERIOD", str(default_grace_period)
    )
    try:
        grace_period = max(int(raw_grace_period), 0)
    except ValueError:
        # A bad environment value should not prevent the agent from starting.
        servicer.log(
            f"Invalid ISOLATE_SHUTDOWN_GRACE_PERIOD {raw_grace_period!r}, "
            f"falling back to {default_grace_period}s"
        )
        grace_period = default_grace_period

    await server.start()
    await _handle_shutdown(server, servicer, grace_period)
    return 0


async def main() -> int:
    """Parse the command line and run the agent until it shuts down.

    Expects a positional ``address`` and an optional ``--log-fd`` integer.

    Returns:
        The status code returned by ``run_agent``.
    """
    parser = ArgumentParser()
    parser.add_argument("address", type=str)
    parser.add_argument("--log-fd", type=int)

    options = parser.parse_args()
    return await run_agent(options.address, log_fd=options.log_fd)


if __name__ == "__main__":
    # main() is a coroutine function: it has to be driven by an event loop,
    # and its return value is used as the process exit status.
    sys.exit(asyncio.run(main()))
112 changes: 111 additions & 1 deletion tests/test_connections.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import operator
import os
import signal
import socket
import subprocess
import sys
import threading
Expand All @@ -11,13 +14,15 @@
from typing import Any, List
from unittest.mock import Mock

import grpc
import pytest
from isolate.backends import BaseEnvironment, EnvironmentConnection
from isolate.backends.local import LocalPythonEnvironment
from isolate.backends.settings import IsolateSettings
from isolate.backends.virtualenv import VirtualPythonEnvironment
from isolate.connections import LocalPythonGRPC, PythonIPC
from isolate.connections.common import is_agent
from isolate.connections.common import is_agent, serialize_object
from isolate.connections.grpc import definitions

REPO_DIR = Path(__file__).parent.parent
assert (
Expand Down Expand Up @@ -291,3 +296,108 @@ def test_force_terminate(self):
assert (
proc.poll() is not None
), "Process should be terminated by force_terminate"

def test_agent_function_receives_signal(self):
    """Test that an agent function properly receives and handles SIGTERM."""

    def find_free_port() -> str:
        # Ask the OS for an unused port on the loopback interface, which is
        # where the agent is reached by the client below.
        with socket.socket() as sock:
            sock.bind(("127.0.0.1", 0))
            host, port = sock.getsockname()
            return f"{host}:{port}"

    # This function will be passed to a test agent subprocess. It sets an exit
    # code that can be used to confirm it received a signal.
    def fn_with_sigterm_handler():
        def handle_sigterm(_signum, _frame):
            print(
                f"SIGTERM in execute_function received by PID {os.getpid()}",
                flush=True,
            )
            sys.exit(42)  # Use specific exit code to verify handler was called

        signal.signal(signal.SIGTERM, handle_sigterm)

        # Keep running until SIGTERM is received
        for _ in range(100):  # Max 10 seconds (100 * 0.1s)
            time.sleep(0.1)

        # If we exit the loop without receiving SIGTERM, exit with different code
        sys.exit(99)

    # Start agent process directly with short grace period for testing
    address = find_free_port()
    agent_proc = subprocess.Popen(
        [
            sys.executable,
            "-m",
            "isolate.connections.grpc.agent",
            address,
        ],
        cwd=REPO_DIR,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        env={**os.environ, "ISOLATE_SHUTDOWN_GRACE_PERIOD": "1"},
    )

    try:
        # start a client to the agent service
        with grpc.secure_channel(
            address, grpc.local_channel_credentials()
        ) as channel:
            # Wait until the agent accepts connections rather than sleeping
            # for a fixed amount of time.
            try:
                grpc.channel_ready_future(channel).result(timeout=10)
            except grpc.FutureTimeoutError:
                pytest.fail("Agent did not start accepting connections in time")
            assert agent_proc.poll() is None, "Agent should be running"

            stub = definitions.AgentStub(channel)

            method = "cloudpickle"
            function = definitions.SerializedObject(
                method=method,
                definition=serialize_object(method, fn_with_sigterm_handler),
                was_it_raised=False,
                stringized_traceback=None,
            )
            function_call = definitions.FunctionCall(function=function)

            def run_function():
                try:
                    for _ in stub.Run(function_call):
                        pass  # Consume responses until agent exits
                except grpc.RpcError:
                    # Expected: the stream breaks once the agent goes away.
                    pass

            # Start client receiver function in background thread
            run_thread = threading.Thread(target=run_function)
            run_thread.daemon = True
            run_thread.start()

            # Give the function time to set up signal handler
            time.sleep(0.5)

            # Send SIGTERM to the agent process using the PID we have
            os.kill(agent_proc.pid, signal.SIGTERM)

            # Wait for signal handling and server shutdown, collecting output
            try:
                output, _ = agent_proc.communicate(timeout=5)
            except subprocess.TimeoutExpired:
                # If still running, kill it and read output
                agent_proc.kill()
                output, _ = agent_proc.communicate()
                pytest.fail(
                    f"Agent should have shut down gracefully, output: {output}"
                )

            # Check that both SIGTERM messages were received
            assert (
                "SIGTERM in execute_function" in output
            ), f"Expected user function SIGTERM message in output, got: {output}"

            assert (
                "Received SIGTERM, initiating graceful shutdown" in output
            ), f"Expected agent SIGTERM message in output, got: {output}"

    finally:
        if agent_proc.poll() is None:
            agent_proc.kill()  # Force kill if still running
        # Reap the child and release the pipe so nothing leaks from the test
        agent_proc.wait()
        if agent_proc.stdout is not None:
            agent_proc.stdout.close()