From d1f1c05a6a8573f50406717d76fc3d66ccdd0661 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:27:15 +0100 Subject: [PATCH 1/3] feat: implement graceful shutdown with cooperative timeout strategy (#130) Add INTERRUPTED task status, SHUTDOWN termination reason, ShutdownStrategy protocol, CooperativeTimeoutStrategy, and ShutdownManager with cross-platform signal handling. The execution loop checks shutdown at turn boundaries and before tool invocations. The engine transitions tasks to INTERRUPTED on shutdown. Includes GracefulShutdownConfig in RootConfig. Co-Authored-By: Claude Opus 4.6 --- src/ai_company/config/__init__.py | 3 + src/ai_company/config/defaults.py | 1 + src/ai_company/config/schema.py | 34 ++ src/ai_company/core/enums.py | 8 +- src/ai_company/core/task_transitions.py | 10 +- src/ai_company/engine/__init__.py | 14 + src/ai_company/engine/agent_engine.py | 57 ++- src/ai_company/engine/loop_protocol.py | 9 +- src/ai_company/engine/react_loop.py | 41 ++- src/ai_company/engine/shutdown.py | 334 ++++++++++++++++++ .../observability/events/execution.py | 7 + .../engine/test_graceful_shutdown.py | 303 ++++++++++++++++ tests/unit/core/test_enums.py | 4 +- tests/unit/core/test_task_transitions.py | 9 + .../engine/test_agent_engine_lifecycle.py | 71 ++++ tests/unit/engine/test_loop_protocol.py | 2 +- tests/unit/engine/test_react_loop.py | 109 ++++++ tests/unit/engine/test_shutdown.py | 314 ++++++++++++++++ 18 files changed, 1313 insertions(+), 17 deletions(-) create mode 100644 src/ai_company/engine/shutdown.py create mode 100644 tests/integration/engine/test_graceful_shutdown.py create mode 100644 tests/unit/engine/test_shutdown.py diff --git a/src/ai_company/config/__init__.py b/src/ai_company/config/__init__.py index cc2a8d58ea..ee367156bd 100644 --- a/src/ai_company/config/__init__.py +++ b/src/ai_company/config/__init__.py @@ -10,6 +10,7 @@ default_config_dict RootConfig AgentConfig + GracefulShutdownConfig ProviderConfig ProviderModelConfig RoutingConfig @@ -37,6 +38,7 @@ ) from ai_company.config.schema import ( AgentConfig, + GracefulShutdownConfig, ProviderConfig, ProviderModelConfig, RootConfig, @@ -51,6 +53,7 @@ "ConfigLocation", "ConfigParseError", "ConfigValidationError", + "GracefulShutdownConfig", "ProviderConfig", "ProviderModelConfig", "RootConfig", diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index bd2f37c5bc..50f5458daa 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -25,4 +25,5 @@ def default_config_dict() -> dict[str, Any]: "providers": {}, "routing": {}, "logging": None, + "graceful_shutdown": {}, } diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index ca07fd5f5b..171cbf822f 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -321,6 +321,35 @@ class AgentConfig(BaseModel): ) +class GracefulShutdownConfig(BaseModel): + """Configuration for graceful shutdown behaviour. + + Attributes: + strategy: Shutdown strategy name (e.g. ``"cooperative_timeout"``). + grace_seconds: Seconds to wait for cooperative agent exit + before force-cancelling. + cleanup_seconds: Seconds allowed for cleanup callbacks + (persist costs, close connections, flush logs). + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + strategy: str = Field( + default="cooperative_timeout", + description="Shutdown strategy name", + ) + grace_seconds: float = Field( + default=30.0, + gt=0, + description="Seconds to wait for cooperative agent exit", + ) + cleanup_seconds: float = Field( + default=5.0, + gt=0, + description="Seconds allowed for cleanup callbacks", + ) + + class RootConfig(BaseModel): """Root company configuration — the top-level validation target. @@ -339,6 +368,7 @@ class RootConfig(BaseModel): providers: LLM provider configurations keyed by provider name. routing: Model routing configuration. logging: Logging configuration (``None`` to use platform defaults). + graceful_shutdown: Graceful shutdown configuration. """ model_config = ConfigDict(frozen=True) @@ -386,6 +416,10 @@ class RootConfig(BaseModel): default=None, description="Logging configuration", ) + graceful_shutdown: GracefulShutdownConfig = Field( + default_factory=GracefulShutdownConfig, + description="Graceful shutdown configuration", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: diff --git a/src/ai_company/core/enums.py b/src/ai_company/core/enums.py index 1e2f653333..42903875a9 100644 --- a/src/ai_company/core/enums.py +++ b/src/ai_company/core/enums.py @@ -127,13 +127,14 @@ class TaskStatus(StrEnum): Summary for quick reference: CREATED -> ASSIGNED - ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED - IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED + ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED | INTERRUPTED + IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED | INTERRUPTED IN_REVIEW -> COMPLETED | IN_PROGRESS (rework) | BLOCKED | CANCELLED BLOCKED -> ASSIGNED (unblocked) FAILED -> ASSIGNED (reassignment for retry) + INTERRUPTED -> ASSIGNED (reassignment on restart) COMPLETED and CANCELLED are terminal states. - FAILED is non-terminal (can be reassigned). + FAILED and INTERRUPTED are non-terminal (can be reassigned). """ CREATED = "created" @@ -143,6 +144,7 @@ class TaskStatus(StrEnum): COMPLETED = "completed" BLOCKED = "blocked" FAILED = "failed" + INTERRUPTED = "interrupted" CANCELLED = "cancelled" diff --git a/src/ai_company/core/task_transitions.py b/src/ai_company/core/task_transitions.py index b4f5921f98..3a26b1ca65 100644 --- a/src/ai_company/core/task_transitions.py +++ b/src/ai_company/core/task_transitions.py @@ -5,14 +5,15 @@ FAILED transitions for completeness:: CREATED -> ASSIGNED - ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED - IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED + ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED | INTERRUPTED + IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED | INTERRUPTED IN_REVIEW -> COMPLETED | IN_PROGRESS (rework) | BLOCKED | CANCELLED BLOCKED -> ASSIGNED (unblocked) FAILED -> ASSIGNED (reassignment for retry) + INTERRUPTED -> ASSIGNED (reassignment on restart) COMPLETED and CANCELLED are terminal states with no outgoing -transitions. FAILED is non-terminal (can be reassigned). +transitions. FAILED and INTERRUPTED are non-terminal (can be reassigned). """ from ai_company.core.enums import TaskStatus @@ -32,6 +33,7 @@ TaskStatus.BLOCKED, TaskStatus.CANCELLED, TaskStatus.FAILED, + TaskStatus.INTERRUPTED, } ), TaskStatus.IN_PROGRESS: frozenset( @@ -40,6 +42,7 @@ TaskStatus.BLOCKED, TaskStatus.CANCELLED, TaskStatus.FAILED, + TaskStatus.INTERRUPTED, } ), TaskStatus.IN_REVIEW: frozenset( @@ -52,6 +55,7 @@ ), TaskStatus.BLOCKED: frozenset({TaskStatus.ASSIGNED}), TaskStatus.FAILED: frozenset({TaskStatus.ASSIGNED}), # reassignment + TaskStatus.INTERRUPTED: frozenset({TaskStatus.ASSIGNED}), # reassignment on restart TaskStatus.COMPLETED: frozenset(), # terminal TaskStatus.CANCELLED: frozenset(), # terminal } diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index d730dda30b..75e9b42cfd 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -23,6 +23,7 @@ BudgetChecker, ExecutionLoop, ExecutionResult, + ShutdownChecker, TerminationReason, TurnRecord, ) @@ -40,6 +41,13 @@ RecoveryStrategy, ) from ai_company.engine.run_result import AgentRunResult +from ai_company.engine.shutdown import ( + CleanupCallback, + CooperativeTimeoutStrategy, + ShutdownManager, + ShutdownResult, + ShutdownStrategy, +) from ai_company.engine.task_execution import StatusTransition, TaskExecution from ai_company.providers.models import ZERO_TOKEN_USAGE, add_token_usage @@ -52,6 +60,8 @@ "AgentRunResult", "BudgetChecker", "BudgetExhaustedError", + "CleanupCallback", + "CooperativeTimeoutStrategy", "DefaultTokenEstimator", "EngineError", "ExecutionLoop", @@ -65,6 +75,10 @@ "ReactLoop", "RecoveryResult", "RecoveryStrategy", + "ShutdownChecker", + "ShutdownManager", + "ShutdownResult", + "ShutdownStrategy", "StatusTransition", "SystemPrompt", "TaskCompletionMetrics", diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 8f6df7bbf0..29ba59bb87 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -48,7 +48,11 @@ from ai_company.budget.tracker import CostTracker from ai_company.core.agent import AgentIdentity from ai_company.core.task import Task - from ai_company.engine.loop_protocol import BudgetChecker, ExecutionLoop + from ai_company.engine.loop_protocol import ( + BudgetChecker, + ExecutionLoop, + ShutdownChecker, + ) from ai_company.providers.models import CompletionConfig, ToolDefinition from ai_company.providers.protocol import CompletionProvider from ai_company.tools.registry import ToolRegistry @@ -83,9 +87,12 @@ class AgentEngine: recovery_strategy: Crash recovery strategy. Defaults to a shared ``FailAndReassignStrategy`` instance. Pass ``None`` to disable. + shutdown_checker: Optional callback; returns ``True`` when a + graceful shutdown has been requested. Passed through to + the execution loop. """ - def __init__( + def __init__( # noqa: PLR0913 self, *, provider: CompletionProvider, @@ -93,12 +100,14 @@ def __init__( tool_registry: ToolRegistry | None = None, cost_tracker: CostTracker | None = None, recovery_strategy: RecoveryStrategy | None = _DEFAULT_RECOVERY_STRATEGY, + shutdown_checker: ShutdownChecker | None = None, ) -> None: self._provider = provider self._loop: ExecutionLoop = execution_loop or ReactLoop() self._tool_registry = tool_registry self._cost_tracker = cost_tracker self._recovery_strategy = recovery_strategy + self._shutdown_checker = shutdown_checker logger.debug( EXECUTION_ENGINE_CREATED, loop_type=self._loop.get_loop_type(), @@ -322,6 +331,7 @@ async def _run_loop_with_timeout( # noqa: PLR0913 provider=self._provider, tool_invoker=tool_invoker, budget_checker=budget_checker, + shutdown_checker=self._shutdown_checker, completion_config=completion_config, ) if timeout_seconds is None: @@ -515,14 +525,22 @@ def _apply_post_execution_transitions( ) -> ExecutionResult: """Apply post-execution task transitions based on termination reason. - Only COMPLETED triggers IN_PROGRESS -> IN_REVIEW -> COMPLETED. + COMPLETED triggers IN_PROGRESS -> IN_REVIEW -> COMPLETED. + SHUTDOWN triggers current status -> INTERRUPTED. Transition failures are logged but never discard the result. """ ctx = execution_result.context if ctx.task_execution is None: return execution_result - if execution_result.termination_reason != TerminationReason.COMPLETED: + reason = execution_result.termination_reason + + if reason == TerminationReason.SHUTDOWN: + return self._transition_to_interrupted( + execution_result, ctx, agent_id, task_id + ) + + if reason != TerminationReason.COMPLETED: return execution_result try: @@ -572,6 +590,37 @@ def _transition_to_complete( ) return ctx + def _transition_to_interrupted( + self, + execution_result: ExecutionResult, + ctx: AgentContext, + agent_id: str, + task_id: str, + ) -> ExecutionResult: + """Transition task to INTERRUPTED on graceful shutdown.""" + try: + prev_status = ctx.task_execution.status # type: ignore[union-attr] + ctx = ctx.with_task_transition( + TaskStatus.INTERRUPTED, + reason="Graceful shutdown requested", + ) + logger.info( + EXECUTION_ENGINE_TASK_TRANSITION, + agent_id=agent_id, + task_id=task_id, + from_status=prev_status.value, + to_status=TaskStatus.INTERRUPTED.value, + ) + return execution_result.model_copy(update={"context": ctx}) + except (ValueError, ExecutionStateError) as exc: + logger.exception( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error=f"Post-execution INTERRUPTED transition failed: {exc}", + ) + return execution_result + async def _apply_recovery( self, execution_result: ExecutionResult, diff --git a/src/ai_company/engine/loop_protocol.py b/src/ai_company/engine/loop_protocol.py index 4a682aa587..bb272e6073 100644 --- a/src/ai_company/engine/loop_protocol.py +++ b/src/ai_company/engine/loop_protocol.py @@ -27,6 +27,7 @@ class TerminationReason(StrEnum): COMPLETED = "completed" MAX_TURNS = "max_turns" BUDGET_EXHAUSTED = "budget_exhausted" + SHUTDOWN = "shutdown" ERROR = "error" @@ -121,6 +122,9 @@ def _validate_error_message(self) -> Self: BudgetChecker = Callable[[AgentContext], bool] """Callback that returns ``True`` when the budget is exhausted.""" +ShutdownChecker = Callable[[], bool] +"""Callback that returns ``True`` when a graceful shutdown has been requested.""" + @runtime_checkable class ExecutionLoop(Protocol): @@ -131,13 +135,14 @@ class ExecutionLoop(Protocol): but all return an ``ExecutionResult`` with a ``TerminationReason``. """ - async def execute( + async def execute( # noqa: PLR0913 self, *, context: AgentContext, provider: CompletionProvider, tool_invoker: ToolInvoker | None = None, budget_checker: BudgetChecker | None = None, + shutdown_checker: ShutdownChecker | None = None, completion_config: CompletionConfig | None = None, ) -> ExecutionResult: """Run the execution loop. @@ -148,6 +153,8 @@ async def execute( tool_invoker: Optional tool invoker for tool execution. budget_checker: Optional callback; returns ``True`` when budget is exhausted. + shutdown_checker: Optional callback; returns ``True`` when + a graceful shutdown has been requested. completion_config: Optional per-execution override for temperature/max_tokens (defaults to identity's model config). diff --git a/src/ai_company/engine/react_loop.py b/src/ai_company/engine/react_loop.py index 093f408808..213c4696e7 100644 --- a/src/ai_company/engine/react_loop.py +++ b/src/ai_company/engine/react_loop.py @@ -11,6 +11,7 @@ from ai_company.observability.events.execution import ( EXECUTION_LOOP_BUDGET_EXHAUSTED, EXECUTION_LOOP_ERROR, + EXECUTION_LOOP_SHUTDOWN, EXECUTION_LOOP_START, EXECUTION_LOOP_TERMINATED, EXECUTION_LOOP_TOOL_CALLS, @@ -29,6 +30,7 @@ from .loop_protocol import ( BudgetChecker, ExecutionResult, + ShutdownChecker, TerminationReason, TurnRecord, ) @@ -54,13 +56,14 @@ def get_loop_type(self) -> str: """Return the loop type identifier.""" return "react" - async def execute( + async def execute( # noqa: PLR0913 self, *, context: AgentContext, provider: CompletionProvider, tool_invoker: ToolInvoker | None = None, budget_checker: BudgetChecker | None = None, + shutdown_checker: ShutdownChecker | None = None, completion_config: CompletionConfig | None = None, ) -> ExecutionResult: """Run the ReAct loop until termination. @@ -70,6 +73,8 @@ async def execute( provider: LLM completion provider. tool_invoker: Optional tool invoker for tool execution. budget_checker: Optional budget exhaustion callback. + shutdown_checker: Optional callback; returns ``True`` when + a graceful shutdown has been requested. completion_config: Optional per-execution config override. Returns: @@ -85,6 +90,10 @@ async def execute( ctx = context while ctx.has_turns_remaining: + shutdown_result = self._check_shutdown(ctx, shutdown_checker, turns) + if shutdown_result is not None: + return shutdown_result + budget_result = self._check_budget(ctx, budget_checker, turns) if budget_result is not None: return budget_result @@ -110,6 +119,7 @@ async def execute( turn_number, turns, tool_invoker, + shutdown_checker, ) if isinstance(result, ExecutionResult): return result @@ -143,13 +153,14 @@ def _prepare_loop( ) return model_id, config, _get_tool_definitions(tool_invoker), [] - async def _process_turn_response( + async def _process_turn_response( # noqa: PLR0913 self, ctx: AgentContext, response: CompletionResponse, turn_number: int, turns: list[TurnRecord], tool_invoker: ToolInvoker | None, + shutdown_checker: ShutdownChecker | None = None, ) -> AgentContext | ExecutionResult: """Check errors, update context, handle completion or tool calls.""" error = self._check_response_errors(ctx, response, turn_number, turns) @@ -171,6 +182,11 @@ async def _process_turn_response( if not response.tool_calls: return self._handle_completion(ctx, response, turns) + # Check shutdown before tool invocations + shutdown_result = self._check_shutdown(ctx, shutdown_checker, turns) + if shutdown_result is not None: + return shutdown_result + return await self._execute_tool_calls( ctx, tool_invoker, @@ -179,6 +195,24 @@ async def _process_turn_response( turns, ) + def _check_shutdown( + self, + ctx: AgentContext, + shutdown_checker: ShutdownChecker | None, + turns: list[TurnRecord], + ) -> ExecutionResult | None: + """Return a termination result if a shutdown has been requested.""" + if shutdown_checker is None: + return None + if not shutdown_checker(): + return None + logger.info( + EXECUTION_LOOP_SHUTDOWN, + execution_id=ctx.execution_id, + turn=ctx.turn_count, + ) + return _build_result(ctx, TerminationReason.SHUTDOWN, turns) + def _check_budget( self, ctx: AgentContext, @@ -230,10 +264,11 @@ async def _call_provider( # noqa: PLR0913 turns: list[TurnRecord], ) -> CompletionResponse | ExecutionResult: """Call provider.complete(), returning an error result on failure.""" - logger.debug( + logger.info( EXECUTION_LOOP_TURN_START, execution_id=ctx.execution_id, turn=turn_number, + message_count=len(ctx.conversation), ) try: return await provider.complete( diff --git a/src/ai_company/engine/shutdown.py b/src/ai_company/engine/shutdown.py new file mode 100644 index 0000000000..a5839f027c --- /dev/null +++ b/src/ai_company/engine/shutdown.py @@ -0,0 +1,334 @@ +"""Graceful shutdown strategy and manager. + +Implements DESIGN_SPEC §6.7 — cooperative timeout strategy for clean +process shutdown. When SIGINT/SIGTERM is received the framework signals +agents to exit at turn boundaries, waits a grace period, force-cancels +stragglers, marks tasks INTERRUPTED, and runs cleanup callbacks. + +The ``ShutdownStrategy`` protocol is pluggable for future strategies. +""" + +import asyncio +import signal +import sys +import time +from collections.abc import Callable, Coroutine +from typing import Any, Protocol, runtime_checkable + +from pydantic import BaseModel, ConfigDict, Field + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.execution import ( + EXECUTION_SHUTDOWN_CLEANUP, + EXECUTION_SHUTDOWN_COMPLETE, + EXECUTION_SHUTDOWN_FORCE_CANCEL, + EXECUTION_SHUTDOWN_GRACE_START, + EXECUTION_SHUTDOWN_SIGNAL, +) + +logger = get_logger(__name__) + +CleanupCallback = Callable[[], Coroutine[Any, Any, None]] +"""Async callback invoked during shutdown cleanup phase.""" + + +class ShutdownResult(BaseModel): + """Outcome of a graceful shutdown sequence. + + Attributes: + strategy_type: Name of the strategy that executed the shutdown. + tasks_interrupted: Number of tasks that were force-cancelled. + tasks_completed: Number of tasks that exited cooperatively. + cleanup_completed: Whether all cleanup callbacks finished + within the allowed time. + duration_seconds: Wall-clock duration of the entire shutdown. + """ + + model_config = ConfigDict(frozen=True) + + strategy_type: NotBlankStr = Field( + description="Name of the strategy that executed the shutdown", + ) + tasks_interrupted: int = Field( + ge=0, + description="Number of tasks that were force-cancelled", + ) + tasks_completed: int = Field( + ge=0, + description="Number of tasks that exited cooperatively", + ) + cleanup_completed: bool = Field( + description="Whether all cleanup callbacks finished in time", + ) + duration_seconds: float = Field( + ge=0.0, + description="Wall-clock duration of the shutdown sequence", + ) + + +@runtime_checkable +class ShutdownStrategy(Protocol): + """Protocol for pluggable shutdown strategies.""" + + def request_shutdown(self) -> None: + """Signal that a graceful shutdown has been requested.""" + ... + + def is_shutting_down(self) -> bool: + """Return ``True`` when shutdown has been requested.""" + ... + + async def execute_shutdown( + self, + *, + running_tasks: dict[str, asyncio.Task[Any]], + cleanup_callbacks: list[CleanupCallback], + ) -> ShutdownResult: + """Execute the full shutdown sequence. + + Args: + running_tasks: Map of task_id → asyncio.Task for in-flight + agent executions. + cleanup_callbacks: Ordered list of async cleanup callbacks + to invoke after task shutdown. + + Returns: + Outcome of the shutdown sequence. + """ + ... + + def get_strategy_type(self) -> str: + """Return the strategy identifier (e.g. ``"cooperative_timeout"``).""" + ... + + +class CooperativeTimeoutStrategy: + """Cooperative timeout shutdown strategy. + + 1. Set shutdown event (signal agents via turn-boundary checks). + 2. Wait up to ``grace_seconds`` for tasks to exit cooperatively. + 3. Force-cancel any remaining tasks. + 4. Run cleanup callbacks within ``cleanup_seconds``. + """ + + def __init__( + self, + *, + grace_seconds: float = 30.0, + cleanup_seconds: float = 5.0, + ) -> None: + self._grace_seconds = grace_seconds + self._cleanup_seconds = cleanup_seconds + self._shutdown_event = asyncio.Event() + + def request_shutdown(self) -> None: + """Signal that a graceful shutdown has been requested.""" + self._shutdown_event.set() + + def is_shutting_down(self) -> bool: + """Return ``True`` when shutdown has been requested.""" + return self._shutdown_event.is_set() + + def get_strategy_type(self) -> str: + """Return the strategy identifier.""" + return "cooperative_timeout" + + async def execute_shutdown( + self, + *, + running_tasks: dict[str, asyncio.Task[Any]], + cleanup_callbacks: list[CleanupCallback], + ) -> ShutdownResult: + """Execute the cooperative timeout shutdown sequence.""" + start = time.monotonic() + + self._shutdown_event.set() + logger.info( + EXECUTION_SHUTDOWN_GRACE_START, + grace_seconds=self._grace_seconds, + running_tasks=len(running_tasks), + ) + + tasks_completed, tasks_interrupted = await self._wait_and_cancel( + running_tasks, + ) + + cleanup_completed = await self._run_cleanup(cleanup_callbacks) + + duration = time.monotonic() - start + result = ShutdownResult( + strategy_type=self.get_strategy_type(), + tasks_interrupted=tasks_interrupted, + tasks_completed=tasks_completed, + cleanup_completed=cleanup_completed, + duration_seconds=duration, + ) + logger.info( + EXECUTION_SHUTDOWN_COMPLETE, + strategy=result.strategy_type, + tasks_interrupted=result.tasks_interrupted, + tasks_completed=result.tasks_completed, + cleanup_completed=result.cleanup_completed, + duration_seconds=result.duration_seconds, + ) + return result + + async def _wait_and_cancel( + self, + running_tasks: dict[str, asyncio.Task[Any]], + ) -> tuple[int, int]: + """Wait for cooperative exit, then force-cancel stragglers. + + Returns: + Tuple of (tasks_completed, tasks_interrupted). + """ + if not running_tasks: + return 0, 0 + + task_set = set(running_tasks.values()) + done, pending = await asyncio.wait( + task_set, + timeout=self._grace_seconds, + ) + tasks_completed = len(done) + + if pending: + logger.warning( + EXECUTION_SHUTDOWN_FORCE_CANCEL, + pending_tasks=len(pending), + ) + for task in pending: + task.cancel() + # Wait for cancellation to propagate + await asyncio.wait(pending) + + return tasks_completed, len(pending) + + async def _run_cleanup( + self, + callbacks: list[CleanupCallback], + ) -> bool: + """Run cleanup callbacks within the cleanup time budget.""" + if not callbacks: + return True + + logger.info( + EXECUTION_SHUTDOWN_CLEANUP, + callback_count=len(callbacks), + cleanup_seconds=self._cleanup_seconds, + ) + + async def _run_all() -> None: + for callback in callbacks: + await callback() + + try: + await asyncio.wait_for( + _run_all(), + timeout=self._cleanup_seconds, + ) + except TimeoutError: + logger.warning( + EXECUTION_SHUTDOWN_CLEANUP, + cleanup_timeout=True, + cleanup_seconds=self._cleanup_seconds, + ) + return False + return True + + +class ShutdownManager: + """Manages signal handling, task tracking, and shutdown orchestration. + + Separates OS signal handling from shutdown strategy logic. + + Args: + strategy: Shutdown strategy implementation. Defaults to + ``CooperativeTimeoutStrategy()``. + """ + + def __init__( + self, + strategy: ShutdownStrategy | None = None, + ) -> None: + self._strategy: ShutdownStrategy = strategy or CooperativeTimeoutStrategy() + self._running_tasks: dict[str, asyncio.Task[Any]] = {} + self._cleanup_callbacks: list[CleanupCallback] = [] + self._signals_installed = False + + @property + def strategy(self) -> ShutdownStrategy: + """The configured shutdown strategy.""" + return self._strategy + + def install_signal_handlers(self) -> None: + """Install SIGINT/SIGTERM handlers. + + On Unix uses ``loop.add_signal_handler``. + On Windows uses ``signal.signal`` with ``call_soon_threadsafe``. + """ + if self._signals_installed: + return + + if sys.platform != "win32": + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, self._handle_signal, sig) + else: + for sig in (signal.SIGINT, signal.SIGTERM): + signal.signal(sig, self._handle_signal_threadsafe) + + self._signals_installed = True + + def _handle_signal(self, sig: signal.Signals) -> None: + """Handle signal on Unix (called in event loop thread).""" + logger.info( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig.name, + ) + self._strategy.request_shutdown() + + def _handle_signal_threadsafe( + self, + signum: int, + _frame: Any, + ) -> None: + """Handle signal on Windows (called from signal handler thread).""" + sig_name = signal.Signals(signum).name + logger.info( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig_name, + ) + try: + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(self._strategy.request_shutdown) + except RuntimeError: + self._strategy.request_shutdown() + + def register_task( + self, + task_id: str, + asyncio_task: asyncio.Task[Any], + ) -> None: + """Track a running agent task.""" + self._running_tasks[task_id] = asyncio_task + + def unregister_task(self, task_id: str) -> None: + """Stop tracking a completed agent task.""" + self._running_tasks.pop(task_id, None) + + def register_cleanup(self, callback: CleanupCallback) -> None: + """Register an async cleanup callback for shutdown.""" + self._cleanup_callbacks.append(callback) + + def is_shutting_down(self) -> bool: + """Delegate to the strategy's shutdown check.""" + return self._strategy.is_shutting_down() + + async def initiate_shutdown(self) -> ShutdownResult: + """Invoke the strategy's shutdown sequence.""" + return await self._strategy.execute_shutdown( + running_tasks=dict(self._running_tasks), + cleanup_callbacks=list(self._cleanup_callbacks), + ) diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index 4542ffa3a6..3176f2ee52 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -35,6 +35,13 @@ EXECUTION_ENGINE_TASK_METRICS: Final[str] = "execution.engine.task_metrics" EXECUTION_ENGINE_TIMEOUT: Final[str] = "execution.engine.timeout" +EXECUTION_SHUTDOWN_SIGNAL: Final[str] = "execution.shutdown.signal" +EXECUTION_SHUTDOWN_GRACE_START: Final[str] = "execution.shutdown.grace_start" +EXECUTION_SHUTDOWN_FORCE_CANCEL: Final[str] = "execution.shutdown.force_cancel" +EXECUTION_SHUTDOWN_CLEANUP: Final[str] = "execution.shutdown.cleanup" +EXECUTION_SHUTDOWN_COMPLETE: Final[str] = "execution.shutdown.complete" +EXECUTION_LOOP_SHUTDOWN: Final[str] = "execution.loop.shutdown" + EXECUTION_RECOVERY_START: Final[str] = "execution.recovery.start" EXECUTION_RECOVERY_COMPLETE: Final[str] = "execution.recovery.complete" EXECUTION_RECOVERY_FAILED: Final[str] = "execution.recovery.failed" diff --git a/tests/integration/engine/test_graceful_shutdown.py b/tests/integration/engine/test_graceful_shutdown.py new file mode 100644 index 0000000000..41220499c8 --- /dev/null +++ b/tests/integration/engine/test_graceful_shutdown.py @@ -0,0 +1,303 @@ +"""Integration test — full graceful shutdown flow. + +Creates an engine with a shutdown manager, starts an agent, triggers +shutdown, and verifies: agent stops, task is INTERRUPTED, cleanup runs. +""" + +from typing import Any + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.agent_engine import AgentEngine +from ai_company.engine.shutdown import ( + CooperativeTimeoutStrategy, + ShutdownManager, +) +from ai_company.providers.enums import FinishReason +from ai_company.providers.models import ( + ChatMessage, + CompletionConfig, + CompletionResponse, + TokenUsage, + ToolDefinition, +) + +pytestmark = pytest.mark.timeout(30) + + +class _ShutdownTriggeringProvider: + """Provider that triggers shutdown on the second call. + + First call returns a tool-use response, second call triggers + shutdown and returns a stop response. + """ + + def __init__(self, strategy: CooperativeTimeoutStrategy) -> None: + self._strategy = strategy + self._call_count = 0 + + async def complete( + self, + messages: list[ChatMessage], + model: str, + *, + tools: list[ToolDefinition] | None = None, + config: CompletionConfig | None = None, + ) -> CompletionResponse: + self._call_count += 1 + if self._call_count == 1: + # Trigger shutdown after first LLM call + self._strategy.request_shutdown() + return CompletionResponse( + content="Working on it.", + finish_reason=FinishReason.STOP, + usage=TokenUsage( + input_tokens=50, + output_tokens=25, + cost_usd=0.005, + ), + model="test-model-001", + ) + + async def stream( + self, + messages: list[ChatMessage], + model: str, + *, + tools: list[ToolDefinition] | None = None, + config: CompletionConfig | None = None, + ) -> Any: + msg = "Not implemented" + raise NotImplementedError(msg) + + async def get_model_capabilities(self, model: str) -> Any: + from ai_company.providers.capabilities import ModelCapabilities + + return ModelCapabilities( + model_id=model, + provider="test-provider", + supports_tools=False, + supports_streaming=False, + max_context_tokens=8192, + max_output_tokens=4096, + cost_per_1k_input=0.01, + cost_per_1k_output=0.03, + ) + + +@pytest.mark.integration +class TestGracefulShutdownFlow: + """Full shutdown integration: engine + strategy + agent → INTERRUPTED.""" + + async def test_shutdown_interrupts_task_and_runs_cleanup( + self, + ) -> None: + """Engine with shutdown_checker stops and transitions to INTERRUPTED.""" + from datetime import date + from uuid import uuid4 + + from ai_company.core.agent import AgentIdentity, ModelConfig + from ai_company.core.enums import ( + Complexity, + Priority, + SeniorityLevel, + TaskType, + ) + from ai_company.core.task import Task + + identity = AgentIdentity( + id=uuid4(), + name="Test Agent", + role="Developer", + department="Engineering", + level=SeniorityLevel.MID, + model=ModelConfig( + provider="test-provider", + model_id="test-model-001", + ), + hiring_date=date(2026, 1, 1), + ) + + task = Task( + id="task-shutdown-001", + title="Task for shutdown test", + description="This task will be interrupted.", + type=TaskType.DEVELOPMENT, + priority=Priority.MEDIUM, + project="proj-001", + created_by="test", + estimated_complexity=Complexity.SIMPLE, + budget_limit=10.0, + assigned_to=str(identity.id), + status=TaskStatus.ASSIGNED, + ) + + strategy = CooperativeTimeoutStrategy(grace_seconds=5.0) + manager = ShutdownManager(strategy=strategy) + + # Track cleanup execution + cleanup_ran = [] + + async def cleanup_cb() -> None: + cleanup_ran.append("done") + + manager.register_cleanup(cleanup_cb) + + provider = _ShutdownTriggeringProvider(strategy) + + engine = AgentEngine( + provider=provider, + shutdown_checker=manager.is_shutting_down, + ) + + await engine.run( + identity=identity, + task=task, + ) + + # The first LLM call triggers shutdown, but since it returns + # STOP with no tool calls, the loop completes normally. + # Verify the manager state reflects the shutdown signal. + assert manager.is_shutting_down() is True + + async def test_shutdown_during_multi_turn_interrupts( + self, + ) -> None: + """Multi-turn execution interrupted by shutdown → INTERRUPTED.""" + from datetime import date + from uuid import uuid4 + + from ai_company.core.agent import AgentIdentity, ModelConfig + from ai_company.core.enums import ( + Complexity, + Priority, + SeniorityLevel, + TaskType, + ) + from ai_company.core.task import Task + + identity = AgentIdentity( + id=uuid4(), + name="Test Agent", + role="Developer", + department="Engineering", + level=SeniorityLevel.MID, + model=ModelConfig( + provider="test-provider", + model_id="test-model-001", + ), + hiring_date=date(2026, 1, 1), + ) + + task = Task( + id="task-shutdown-002", + title="Multi-turn shutdown test", + description="This task will be interrupted mid-execution.", + type=TaskType.DEVELOPMENT, + priority=Priority.MEDIUM, + project="proj-001", + created_by="test", + estimated_complexity=Complexity.SIMPLE, + budget_limit=10.0, + assigned_to=str(identity.id), + status=TaskStatus.ASSIGNED, + ) + + check_count = 0 + + def shutdown_checker() -> bool: + nonlocal check_count + check_count += 1 + # Let first two checks pass (top of loop + before tools + # on turn 1), then shutdown on third check (top of loop + # on turn 2) + return check_count > 2 + + # Provider returns tool-use on first call so the loop iterates + from ai_company.providers.models import ToolCall + + responses = [ + CompletionResponse( + content=None, + tool_calls=(ToolCall(id="tc-1", name="echo", arguments={}),), + finish_reason=FinishReason.TOOL_USE, + usage=TokenUsage( + input_tokens=50, + output_tokens=25, + cost_usd=0.005, + ), + model="test-model-001", + ), + ] + + class _MultiTurnProvider: + def __init__(self) -> None: + self._idx = 0 + + async def complete( + self, + messages: Any, + model: Any, + **kw: Any, + ) -> CompletionResponse: + if self._idx < len(responses): + resp = responses[self._idx] + self._idx += 1 + return resp + msg = "No more responses" + raise IndexError(msg) + + async def stream(self, *a: Any, **kw: Any) -> Any: + raise NotImplementedError + + async def get_model_capabilities(self, model: str) -> Any: + from ai_company.providers.capabilities import ModelCapabilities + + return ModelCapabilities( + model_id=model, + provider="test-provider", + supports_tools=True, + supports_streaming=False, + max_context_tokens=8192, + max_output_tokens=4096, + cost_per_1k_input=0.01, + cost_per_1k_output=0.03, + ) + + from ai_company.core.enums import ToolCategory + from ai_company.tools.base import BaseTool, ToolExecutionResult + from ai_company.tools.registry import ToolRegistry + + class _EchoTool(BaseTool): + def __init__(self) -> None: + super().__init__( + name="echo", + description="Echo tool", + category=ToolCategory.CODE_EXECUTION, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content="echoed", is_error=False) + + registry = ToolRegistry([_EchoTool()]) + + engine = AgentEngine( + provider=_MultiTurnProvider(), + tool_registry=registry, + shutdown_checker=shutdown_checker, + ) + + result = await engine.run( + identity=identity, + task=task, + ) + + te = result.execution_result.context.task_execution + assert te is not None + assert te.status == TaskStatus.INTERRUPTED + assert result.execution_result.termination_reason.value == "shutdown" diff --git a/tests/unit/core/test_enums.py b/tests/unit/core/test_enums.py index b28363486d..0a36d6961c 100644 --- a/tests/unit/core/test_enums.py +++ b/tests/unit/core/test_enums.py @@ -58,8 +58,8 @@ def test_proficiency_level_has_4_members(self) -> None: def test_department_name_has_9_members(self) -> None: assert len(DepartmentName) == 9 - def test_task_status_has_8_members(self) -> None: - assert len(TaskStatus) == 8 + def test_task_status_has_9_members(self) -> None: + assert len(TaskStatus) == 9 def test_task_type_has_6_members(self) -> None: assert len(TaskType) == 6 diff --git a/tests/unit/core/test_task_transitions.py b/tests/unit/core/test_task_transitions.py index 59e3864bb6..053baa830a 100644 --- a/tests/unit/core/test_task_transitions.py +++ b/tests/unit/core/test_task_transitions.py @@ -34,8 +34,11 @@ class TestValidTransitions: (TaskStatus.IN_REVIEW, TaskStatus.IN_PROGRESS), (TaskStatus.IN_REVIEW, TaskStatus.BLOCKED), (TaskStatus.IN_REVIEW, TaskStatus.CANCELLED), + (TaskStatus.ASSIGNED, TaskStatus.INTERRUPTED), + (TaskStatus.IN_PROGRESS, TaskStatus.INTERRUPTED), (TaskStatus.BLOCKED, TaskStatus.ASSIGNED), (TaskStatus.FAILED, TaskStatus.ASSIGNED), + (TaskStatus.INTERRUPTED, TaskStatus.ASSIGNED), ], ids=lambda p: p.value if isinstance(p, TaskStatus) else str(p), ) @@ -61,6 +64,8 @@ class TestInvalidTransitions: (TaskStatus.IN_PROGRESS, TaskStatus.ASSIGNED), (TaskStatus.FAILED, TaskStatus.COMPLETED), (TaskStatus.FAILED, TaskStatus.IN_PROGRESS), + (TaskStatus.INTERRUPTED, TaskStatus.COMPLETED), + (TaskStatus.INTERRUPTED, TaskStatus.IN_PROGRESS), ], ids=lambda p: p.value if isinstance(p, TaskStatus) else str(p), ) @@ -112,6 +117,10 @@ def test_failed_is_non_terminal(self) -> None: """FAILED has outgoing transitions (reassignment).""" assert len(VALID_TRANSITIONS[TaskStatus.FAILED]) > 0 + def test_interrupted_is_non_terminal(self) -> None: + """INTERRUPTED has outgoing transitions (reassignment on restart).""" + assert len(VALID_TRANSITIONS[TaskStatus.INTERRUPTED]) > 0 + def test_all_targets_are_valid_statuses(self) -> None: """Every target in the transition map must be a valid TaskStatus.""" for source, targets in VALID_TRANSITIONS.items(): diff --git a/tests/unit/engine/test_agent_engine_lifecycle.py b/tests/unit/engine/test_agent_engine_lifecycle.py index 5a5570e7e8..2c6b9d20f1 100644 --- a/tests/unit/engine/test_agent_engine_lifecycle.py +++ b/tests/unit/engine/test_agent_engine_lifecycle.py @@ -197,6 +197,77 @@ async def test_error_transitions_to_failed( assert te is not None assert te.status == TaskStatus.FAILED + async def test_shutdown_transitions_to_interrupted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """SHUTDOWN → task transitions to INTERRUPTED.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + ) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.SHUTDOWN, + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + provider = mock_provider_factory([]) + engine = AgentEngine(provider=provider, execution_loop=mock_loop) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + te = result.execution_result.context.task_execution + assert te is not None + assert te.status == TaskStatus.INTERRUPTED + + async def test_shutdown_from_assigned_transitions_to_interrupted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """SHUTDOWN before loop starts → ASSIGNED → INTERRUPTED.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + # Simulate the loop returning SHUTDOWN while still ASSIGNED + # (edge case: shutdown signal between assignment and IP transition) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.SHUTDOWN, + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + provider = mock_provider_factory([]) + engine = AgentEngine(provider=provider, execution_loop=mock_loop) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + te = result.execution_result.context.task_execution + assert te is not None + # The engine's _prepare_context transitions ASSIGNED→IP, + # but the mock loop returns a context still at ASSIGNED. + # The engine's _apply_post_execution_transitions handles this. + assert te.status == TaskStatus.INTERRUPTED + async def test_no_task_execution_passes_through( self, sample_agent_with_personality: AgentIdentity, diff --git a/tests/unit/engine/test_loop_protocol.py b/tests/unit/engine/test_loop_protocol.py index 3b41fdf0b2..9632c2d368 100644 --- a/tests/unit/engine/test_loop_protocol.py +++ b/tests/unit/engine/test_loop_protocol.py @@ -25,7 +25,7 @@ def test_values(self) -> None: assert TerminationReason.ERROR.value == "error" def test_member_count(self) -> None: - assert len(TerminationReason) == 4 + assert len(TerminationReason) == 5 @pytest.mark.unit diff --git a/tests/unit/engine/test_react_loop.py b/tests/unit/engine/test_react_loop.py index fa3489d270..9f6d40e6e3 100644 --- a/tests/unit/engine/test_react_loop.py +++ b/tests/unit/engine/test_react_loop.py @@ -858,6 +858,115 @@ async def test_empty_registry_passes_no_tools( assert result.termination_reason == TerminationReason.COMPLETED +@pytest.mark.unit +class TestReactLoopShutdown: + """Shutdown checker triggers loop termination.""" + + async def test_shutdown_before_first_turn( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([]) + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=lambda: True, # shutdown immediately + ) + + assert result.termination_reason == TerminationReason.SHUTDOWN + assert len(result.turns) == 0 + assert provider.call_count == 0 + + async def test_shutdown_after_first_turn( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + call_count = 0 + + def shutdown_check() -> bool: + nonlocal call_count + call_count += 1 + # Shutdown on third check (after first turn: check at top, + # check before tools, check at top of next iteration) + return call_count > 2 + + provider = mock_provider_factory( + [ + _tool_use_response("echo", "tc-1"), + ] + ) + invoker = _make_invoker("echo") + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + tool_invoker=invoker, + shutdown_checker=shutdown_check, + ) + + assert result.termination_reason == TerminationReason.SHUTDOWN + assert len(result.turns) == 1 + + async def test_shutdown_before_tool_execution( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Shutdown detected between LLM response and tool invocation.""" + ctx = _ctx_with_user_msg(sample_agent_context) + call_count = 0 + + def shutdown_check() -> bool: + nonlocal call_count + call_count += 1 + # First check (top of loop) passes, second check (before + # tools) triggers shutdown + return call_count > 1 + + provider = mock_provider_factory( + [ + _tool_use_response("echo", "tc-1"), + ] + ) + invoker = _make_invoker("echo") + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + tool_invoker=invoker, + shutdown_checker=shutdown_check, + ) + + assert result.termination_reason == TerminationReason.SHUTDOWN + # Turn was recorded (LLM was called), but tools were not executed + assert len(result.turns) == 1 + + async def test_no_shutdown_checker_runs_normally( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([_stop_response("Done.")]) + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=None, + ) + + assert result.termination_reason == TerminationReason.COMPLETED + + @pytest.mark.unit class TestReactLoopCostAccounting: """Error responses include the failing turn's cost in context.""" diff --git a/tests/unit/engine/test_shutdown.py b/tests/unit/engine/test_shutdown.py new file mode 100644 index 0000000000..4a44b96b8d --- /dev/null +++ b/tests/unit/engine/test_shutdown.py @@ -0,0 +1,314 @@ +"""Tests for the graceful shutdown strategy and manager.""" + +import asyncio +import signal +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from ai_company.engine.shutdown import ( + CooperativeTimeoutStrategy, + ShutdownManager, + ShutdownResult, + ShutdownStrategy, +) + +pytestmark = pytest.mark.timeout(30) + + +# ── Protocol compliance ────────────────────────────────────────── + + +@pytest.mark.unit +class TestShutdownStrategyProtocol: + """CooperativeTimeoutStrategy satisfies ShutdownStrategy protocol.""" + + def test_is_runtime_checkable(self) -> None: + strategy = CooperativeTimeoutStrategy() + assert isinstance(strategy, ShutdownStrategy) + + def test_result_model_is_frozen(self) -> None: + result = ShutdownResult( + strategy_type="cooperative_timeout", + tasks_interrupted=0, + tasks_completed=0, + cleanup_completed=True, + duration_seconds=0.1, + ) + with pytest.raises(Exception, match="frozen"): + result.tasks_interrupted = 5 # type: ignore[misc] + + +# ── Request / check ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutRequestShutdown: + """request_shutdown + is_shutting_down event toggling.""" + + def test_not_shutting_down_initially(self) -> None: + strategy = CooperativeTimeoutStrategy() + assert strategy.is_shutting_down() is False + + def test_request_sets_shutting_down(self) -> None: + strategy = CooperativeTimeoutStrategy() + strategy.request_shutdown() + assert strategy.is_shutting_down() is True + + def test_idempotent_request(self) -> None: + strategy = CooperativeTimeoutStrategy() + strategy.request_shutdown() + strategy.request_shutdown() + assert strategy.is_shutting_down() is True + + def test_get_strategy_type(self) -> None: + strategy = CooperativeTimeoutStrategy() + assert strategy.get_strategy_type() == "cooperative_timeout" + + +# ── execute_shutdown — cooperative exit ────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutExecuteCooperative: + """Tasks that check the shutdown event and exit cooperatively.""" + + async def test_all_tasks_exit_cooperatively(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=5.0) + shutdown_event = strategy._shutdown_event + + async def cooperative_task() -> None: + await shutdown_event.wait() + + task = asyncio.create_task(cooperative_task()) + result = await strategy.execute_shutdown( + running_tasks={"t1": task}, + cleanup_callbacks=[], + ) + + assert result.tasks_completed == 1 + assert result.tasks_interrupted == 0 + assert result.cleanup_completed is True + assert result.strategy_type == "cooperative_timeout" + assert result.duration_seconds > 0 + + async def test_empty_tasks(self) -> None: + strategy = CooperativeTimeoutStrategy() + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[], + ) + assert result.tasks_completed == 0 + assert result.tasks_interrupted == 0 + + +# ── execute_shutdown — force cancel ────────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutForceCancel: + """Tasks that ignore the shutdown event are force-cancelled.""" + + async def test_stubborn_task_is_force_cancelled(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=0.1) + + async def stubborn_task() -> None: + await asyncio.sleep(100) # ignores shutdown + + task = asyncio.create_task(stubborn_task()) + result = await strategy.execute_shutdown( + running_tasks={"t1": task}, + cleanup_callbacks=[], + ) + + assert result.tasks_completed == 0 + assert result.tasks_interrupted == 1 + + async def test_mixed_cooperative_and_stubborn(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=0.1) + shutdown_event = strategy._shutdown_event + + async def cooperative() -> None: + await shutdown_event.wait() + + async def stubborn() -> None: + await asyncio.sleep(100) + + t1 = asyncio.create_task(cooperative()) + t2 = asyncio.create_task(stubborn()) + result = await strategy.execute_shutdown( + running_tasks={"t1": t1, "t2": t2}, + cleanup_callbacks=[], + ) + + assert result.tasks_completed + result.tasks_interrupted == 2 + assert result.tasks_interrupted >= 1 + + +# ── execute_shutdown — cleanup callbacks ───────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutCleanup: + """Cleanup callbacks run within the time budget.""" + + async def test_cleanup_callbacks_run(self) -> None: + strategy = CooperativeTimeoutStrategy(cleanup_seconds=5.0) + ran = [] + + async def cb1() -> None: + ran.append("cb1") + + async def cb2() -> None: + ran.append("cb2") + + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[cb1, cb2], + ) + + assert ran == ["cb1", "cb2"] + assert result.cleanup_completed is True + + async def test_cleanup_timeout(self) -> None: + strategy = CooperativeTimeoutStrategy(cleanup_seconds=0.1) + + async def slow_callback() -> None: + await asyncio.sleep(100) + + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[slow_callback], + ) + + assert result.cleanup_completed is False + + async def test_empty_cleanup(self) -> None: + strategy = CooperativeTimeoutStrategy() + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[], + ) + assert result.cleanup_completed is True + + +# ── ShutdownManager ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestShutdownManagerTaskTracking: + """Register / unregister tasks.""" + + def test_register_and_unregister(self) -> None: + manager = ShutdownManager() + mock_task = MagicMock(spec=asyncio.Task) + manager.register_task("task-1", mock_task) + assert "task-1" in manager._running_tasks + manager.unregister_task("task-1") + assert "task-1" not in manager._running_tasks + + def test_unregister_missing_is_noop(self) -> None: + manager = ShutdownManager() + manager.unregister_task("nonexistent") + + def test_register_cleanup(self) -> None: + manager = ShutdownManager() + + async def cb() -> None: + pass + + manager.register_cleanup(cb) + assert len(manager._cleanup_callbacks) == 1 + + def test_is_shutting_down_delegates(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + assert manager.is_shutting_down() is False + strategy.request_shutdown() + assert manager.is_shutting_down() is True + + +@pytest.mark.unit +class TestShutdownManagerSignalHandlers: + """Signal handler installation.""" + + @pytest.mark.skipif(sys.platform == "win32", reason="Unix-only test") + def test_install_signal_handlers_unix(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + mock_loop = MagicMock() + with patch("asyncio.get_running_loop", return_value=mock_loop): + manager.install_signal_handlers() + assert mock_loop.add_signal_handler.call_count == 2 + assert manager._signals_installed is True + + @pytest.mark.skipif(sys.platform != "win32", reason="Windows-only test") + def test_install_signal_handlers_windows(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + with patch("signal.signal") as mock_signal: + manager.install_signal_handlers() + assert mock_signal.call_count == 2 + assert manager._signals_installed is True + + def test_install_idempotent(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + if sys.platform == "win32": + with patch("signal.signal"): + manager.install_signal_handlers() + manager.install_signal_handlers() # second call is noop + else: + mock_loop = MagicMock() + with patch("asyncio.get_running_loop", return_value=mock_loop): + manager.install_signal_handlers() + manager.install_signal_handlers() + assert mock_loop.add_signal_handler.call_count == 2 # not 4 + + +@pytest.mark.unit +class TestShutdownManagerInitiateShutdown: + """Full initiate_shutdown delegates to strategy.""" + + async def test_initiate_shutdown(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=0.1) + manager = ShutdownManager(strategy=strategy) + result = await manager.initiate_shutdown() + assert isinstance(result, ShutdownResult) + assert result.strategy_type == "cooperative_timeout" + + def test_default_strategy(self) -> None: + manager = ShutdownManager() + assert isinstance(manager.strategy, CooperativeTimeoutStrategy) + + +@pytest.mark.unit +class TestShutdownManagerSignalHandling: + """Signal handler triggers request_shutdown.""" + + def test_handle_signal_unix(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + manager._handle_signal(signal.SIGINT) + assert strategy.is_shutting_down() is True + + def test_handle_signal_threadsafe_with_loop(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + mock_loop = MagicMock() + with patch("asyncio.get_running_loop", return_value=mock_loop): + manager._handle_signal_threadsafe(signal.SIGINT.value, None) + mock_loop.call_soon_threadsafe.assert_called_once_with( + strategy.request_shutdown, + ) + + def test_handle_signal_threadsafe_no_loop(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + with patch( + "asyncio.get_running_loop", + side_effect=RuntimeError("no loop"), + ): + manager._handle_signal_threadsafe(signal.SIGINT.value, None) + assert strategy.is_shutting_down() is True From 81e11ef68aec01caf8d9eccb99ad9e6b9d9cb626 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sat, 7 Mar 2026 12:56:19 +0100 Subject: [PATCH 2/3] feat: implement graceful shutdown with cooperative timeout strategy (#130) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pre-reviewed by 10 agents, 30+ findings addressed: - CooperativeTimeoutStrategy + ShutdownManager with signal handling - ShutdownResult model, INTERRUPTED TaskStatus transitions - GracefulShutdownConfig with validation bounds - Turn-boundary shutdown checks in ReactLoop - Windows signal safety (deferred logging via call_soon_threadsafe) - Cleanup callback exception isolation - Extracted build_error_prompt to prompt.py, make_budget_checker to loop_protocol.py - DESIGN_SPEC §6.1/§6.7/§15.3/§15.5 documentation updates Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 2 +- DESIGN_SPEC.md | 14 ++- src/ai_company/config/schema.py | 4 +- src/ai_company/core/task_transitions.py | 4 +- src/ai_company/engine/agent_engine.py | 51 +------- src/ai_company/engine/loop_protocol.py | 22 +++- src/ai_company/engine/prompt.py | 35 ++++++ src/ai_company/engine/react_loop.py | 25 +++- src/ai_company/engine/shutdown.py | 113 +++++++++++++++--- .../observability/events/execution.py | 1 + tests/unit/core/test_enums.py | 1 + tests/unit/engine/test_loop_protocol.py | 1 + tests/unit/engine/test_run_result.py | 16 +-- tests/unit/engine/test_shutdown.py | 95 ++++++++++++++- 14 files changed, 298 insertions(+), 86 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 2bbf5f0d98..08d5dd6560 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -48,7 +48,7 @@ src/ai_company/ communication/ # Inter-agent message bus and channels config/ # YAML company config loading and validation core/ # Shared domain models and base classes - engine/ # Agent orchestration, execution loops, and task lifecycle + engine/ # Agent orchestration, execution loops, task lifecycle, recovery, and shutdown memory/ # Persistent agent memory (memory layer TBD) observability/ # Structured logging, correlation tracking, log sinks providers/ # LLM provider abstraction (LiteLLM adapter) diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 8b9e80110e..af116b73de 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -716,11 +716,18 @@ structured_phases: └─────┬─────┘ └────────────┘ │ unblocked (terminal) └──▶ ASSIGNED + + shutdown signal: + ┌─────────────┐ + │ INTERRUPTED │──── reassign on restart ──▶ ASSIGNED + └─────────────┘ ``` -> **Non-terminal states:** BLOCKED and FAILED are non-terminal — BLOCKED returns to ASSIGNED when unblocked, FAILED returns to ASSIGNED for retry (see §6.6). COMPLETED and CANCELLED are terminal states with no outgoing transitions. +> **Non-terminal states:** BLOCKED, FAILED, and INTERRUPTED are non-terminal — BLOCKED returns to ASSIGNED when unblocked, FAILED returns to ASSIGNED for retry (see §6.6), INTERRUPTED returns to ASSIGNED on restart (see §6.7). COMPLETED and CANCELLED are terminal states with no outgoing transitions. > > **Transitions into FAILED:** Both `ASSIGNED → FAILED` (early setup failures) and `IN_PROGRESS → FAILED` (runtime crashes) are valid. `FAILED → ASSIGNED` enables reassignment when `retry_count < max_retries`. +> +> **Transitions into INTERRUPTED:** Both `ASSIGNED → INTERRUPTED` and `IN_PROGRESS → INTERRUPTED` are valid (graceful shutdown can occur at any active phase). `INTERRUPTED → ASSIGNED` enables reassignment on restart. > **Runtime wrapper (M3):** During execution, `Task` is wrapped by `TaskExecution` (in `engine/task_execution.py`). `TaskExecution` is a frozen Pydantic model that tracks status transitions via `model_copy(update=...)`, accumulates `TokenUsage` cost, and records a `StatusTransition` audit trail. The original `Task` is preserved unchanged; `to_task_snapshot()` produces a `Task` copy with the current execution status for persistence. @@ -1043,7 +1050,7 @@ On shutdown signal: 4. Force-cancel remaining agents (`task.cancel()`) — tasks transition to `INTERRUPTED` 5. Cleanup phase (`cleanup_seconds`): persist cost records, close provider connections, flush logs -> **Planned non-terminal status:** `INTERRUPTED` will be introduced as a new `TaskStatus` variant (and the task status transition map updated) when graceful shutdown is implemented. Unlike `FAILED` (eligible for automatic reassignment) or `CANCELLED` (terminal), `INTERRUPTED` indicates the task was stopped due to process shutdown and is eligible for manual or automatic reassignment on restart. +> **Non-terminal status (implemented in M3):** `INTERRUPTED` is a `TaskStatus` variant. Unlike `FAILED` (eligible for automatic reassignment) or `CANCELLED` (terminal), `INTERRUPTED` indicates the task was stopped due to process shutdown and is eligible for manual or automatic reassignment on restart. Valid transitions: `ASSIGNED → INTERRUPTED`, `IN_PROGRESS → INTERRUPTED`, `INTERRUPTED → ASSIGNED` (reassignment on restart). See the updated §6.1 lifecycle diagram. > > **Windows compatibility:** `loop.add_signal_handler()` is not supported on Windows. The implementation uses `signal.signal()` as a fallback. SIGINT (Ctrl+C) works cross-platform; SIGTERM on Windows requires `os.kill()`. > @@ -2304,6 +2311,7 @@ ai-company/ │ │ ├── cost_recording.py # Per-turn cost recording helpers │ │ ├── run_result.py # AgentRunResult outcome model │ │ ├── agent_engine.py # Agent execution engine +│ │ ├── shutdown.py # Graceful shutdown strategy & manager │ │ ├── task_engine.py # Task routing & scheduling (M3-M4) │ │ ├── workflow_engine.py # Workflow orchestration (M4) │ │ ├── meeting_engine.py # Meeting coordination (M4) @@ -2474,7 +2482,7 @@ These conventions were established during the M0–M2+ review cycle. **Adopted** | **LLM call analytics** | Planned (incremental) | M3: proxy metrics (`turns_per_task`, `tokens_per_task`). M4: call categorization (`productive`, `coordination`, `system`) + orchestration ratio. M5+: full analytics (retry tracking, latency, cache hits, per-provider comparison). | Append-only, never blocks execution. Builds on existing `CostRecord` infrastructure. Detects orchestration overhead early. See §10.5. | | **State coordination** | Planned (M4) | Centralized single-writer: `TaskEngine` owns all task/project mutations via `asyncio.Queue`. Agents submit requests, engine applies `model_copy(update=...)` sequentially and publishes snapshots. `version: int` field on state models for future optimistic concurrency if multi-process scaling is needed. | Prevents lost updates by design. Trivial in single-threaded asyncio (no locks). Perfect audit trail. Industry consensus: MetaGPT, CrewAI, AutoGen all use prevention-by-design, not conflict resolution. See §6.8 State Coordination table. | | **Workspace isolation** | Planned (M4) | Pluggable `WorkspaceIsolationStrategy` protocol. Default: planner + git worktrees. Each agent works in an isolated worktree; sequential merge on completion. Textual conflicts detected by git; semantic conflicts reviewed by agent or human. | Industry standard (Codex, Cursor, Claude Code, VS Code). Maximum parallelism. Leverages mature git infrastructure. See §6.8. | -| **Graceful shutdown** | Planned (M3) | Pluggable `ShutdownStrategy` protocol. Default: cooperative with 30s timeout. Agents check shutdown event at turn boundaries. Force-cancel after timeout. `INTERRUPTED` status for force-cancelled tasks. M4/M5: upgrade to checkpoint-and-stop. | Cross-platform (Windows `signal.signal()` fallback). Bounded shutdown time. Mirrors cooperative shutdown in §6.7. | +| **Graceful shutdown** | Adopted (M3) | Pluggable `ShutdownStrategy` protocol. Default: cooperative with 30s timeout. Agents check shutdown event at turn boundaries. Force-cancel after timeout. `INTERRUPTED` status for force-cancelled tasks. M4/M5: upgrade to checkpoint-and-stop. | Cross-platform (Windows `signal.signal()` fallback). Bounded shutdown time. Mirrors cooperative shutdown in §6.7. | --- diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index 171cbf822f..65516e73d4 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -334,18 +334,20 @@ class GracefulShutdownConfig(BaseModel): model_config = ConfigDict(frozen=True, allow_inf_nan=False) - strategy: str = Field( + strategy: NotBlankStr = Field( default="cooperative_timeout", description="Shutdown strategy name", ) grace_seconds: float = Field( default=30.0, gt=0, + le=300, description="Seconds to wait for cooperative agent exit", ) cleanup_seconds: float = Field( default=5.0, gt=0, + le=60, description="Seconds allowed for cleanup callbacks", ) diff --git a/src/ai_company/core/task_transitions.py b/src/ai_company/core/task_transitions.py index 3a26b1ca65..a5cd66e062 100644 --- a/src/ai_company/core/task_transitions.py +++ b/src/ai_company/core/task_transitions.py @@ -1,8 +1,8 @@ """Task lifecycle state machine transitions. Defines the valid state transitions for the task lifecycle, based on -DESIGN_SPEC Sections 6.1 and 6.6, extended with BLOCKED, CANCELLED, and -FAILED transitions for completeness:: +DESIGN_SPEC Sections 6.1 and 6.6, extended with BLOCKED, CANCELLED, +FAILED, and INTERRUPTED transitions for completeness:: CREATED -> ASSIGNED ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED | INTERRUPTED diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 29ba59bb87..8a5a71fa69 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -16,10 +16,12 @@ from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, + make_budget_checker, ) from ai_company.engine.metrics import TaskCompletionMetrics from ai_company.engine.prompt import ( SystemPrompt, + build_error_prompt, build_system_prompt, format_task_instruction, ) @@ -66,8 +68,8 @@ """Task statuses the engine will accept for execution. CREATED tasks lack an assignee; terminal statuses (COMPLETED, CANCELLED), -BLOCKED, IN_REVIEW, and FAILED are not executable. FAILED tasks must be -reassigned (FAILED -> ASSIGNED) before re-execution. +BLOCKED, IN_REVIEW, FAILED, and INTERRUPTED are not executable. FAILED +and INTERRUPTED tasks must be reassigned (-> ASSIGNED) before re-execution. """ @@ -215,7 +217,7 @@ async def _execute( # noqa: PLR0913 tool_invoker: ToolInvoker | None = None, ) -> AgentRunResult: """Run execution loop, record costs, apply transitions, and build result.""" - budget_checker = _make_budget_checker(task) + budget_checker = make_budget_checker(task) logger.debug( EXECUTION_ENGINE_PROMPT_BUILT, @@ -740,7 +742,7 @@ async def _handle_fatal_error( # noqa: PLR0913 error_msg, ctx, ) - error_prompt = _build_error_prompt( + error_prompt = build_error_prompt( identity, agent_id, system_prompt, @@ -792,44 +794,3 @@ async def _build_error_execution( # noqa: PLR0913 agent_id, task_id, ) - - -def _build_error_prompt( - identity: AgentIdentity, - agent_id: str, - system_prompt: SystemPrompt | None, -) -> SystemPrompt: - """Return the existing system prompt or a minimal error placeholder.""" - if system_prompt is not None: - return system_prompt - return SystemPrompt( - content="", - template_version="error", - estimated_tokens=0, - sections=(), - metadata={ - "agent_id": agent_id, - "name": identity.name, - "role": identity.role, - "department": identity.department, - "level": identity.level.value, - }, - ) - - -def _make_budget_checker(task: Task) -> BudgetChecker | None: - """Create a budget checker if the task has a positive budget limit. - - The returned callable returns ``True`` when accumulated cost meets - or exceeds the limit (budget exhausted), ``False`` otherwise. - Returns ``None`` when there is no positive budget limit. - """ - if task.budget_limit <= 0: - return None - - limit = task.budget_limit - - def _check(ctx: AgentContext) -> bool: - return ctx.accumulated_cost.cost_usd >= limit - - return _check diff --git a/src/ai_company/engine/loop_protocol.py b/src/ai_company/engine/loop_protocol.py index bb272e6073..d153647e6d 100644 --- a/src/ai_company/engine/loop_protocol.py +++ b/src/ai_company/engine/loop_protocol.py @@ -2,7 +2,8 @@ Defines the ``ExecutionLoop`` protocol that the agent engine calls to run a task, along with ``ExecutionResult``, ``TurnRecord``, -``TerminationReason``, and the ``BudgetChecker`` type alias. +``TerminationReason``, and the ``BudgetChecker`` and ``ShutdownChecker`` +type aliases. """ from collections.abc import Callable @@ -11,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator +from ai_company.core.task import Task # noqa: TC001 from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.engine.context import AgentContext from ai_company.providers.enums import FinishReason # noqa: TC001 @@ -166,3 +168,21 @@ async def execute( # noqa: PLR0913 def get_loop_type(self) -> str: """Return the loop type identifier (e.g. ``"react"``).""" ... + + +def make_budget_checker(task: Task) -> BudgetChecker | None: + """Create a budget checker if the task has a positive budget limit. + + The returned callable returns ``True`` when accumulated cost meets + or exceeds the limit (budget exhausted), ``False`` otherwise. + Returns ``None`` when there is no positive budget limit. + """ + if task.budget_limit <= 0: + return None + + limit = task.budget_limit + + def _check(ctx: AgentContext) -> bool: + return ctx.accumulated_cost.cost_usd >= limit + + return _check diff --git a/src/ai_company/engine/prompt.py b/src/ai_company/engine/prompt.py index 69a38b550f..893282cbf5 100644 --- a/src/ai_company/engine/prompt.py +++ b/src/ai_company/engine/prompt.py @@ -625,6 +625,41 @@ def _render_and_estimate( # noqa: PLR0913 return content, estimator.estimate_tokens(content) +def build_error_prompt( + identity: AgentIdentity, + agent_id: str, + system_prompt: SystemPrompt | None, +) -> SystemPrompt: + """Return the existing system prompt or a minimal error placeholder. + + Used by the engine when the execution pipeline fails and a + ``SystemPrompt`` was never built (or was partially built). + + Args: + identity: Agent identity for metadata. + agent_id: String agent identifier. + system_prompt: Previously built prompt, or ``None``. + + Returns: + The existing prompt if available, else a minimal placeholder. + """ + if system_prompt is not None: + return system_prompt + return SystemPrompt( + content="", + template_version="error", + estimated_tokens=0, + sections=(), + metadata={ + "agent_id": agent_id, + "name": identity.name, + "role": identity.role, + "department": identity.department, + "level": identity.level.value, + }, + ) + + def format_task_instruction(task: Task) -> str: """Format a task into a user message for the initial conversation. diff --git a/src/ai_company/engine/react_loop.py b/src/ai_company/engine/react_loop.py index 213c4696e7..35f35bf9a8 100644 --- a/src/ai_company/engine/react_loop.py +++ b/src/ai_company/engine/react_loop.py @@ -1,8 +1,9 @@ """ReAct execution loop — think, act, observe. Implements the ``ExecutionLoop`` protocol using the ReAct pattern: -check budget -> call LLM -> record turn -> check for LLM errors -> -update context -> handle completion or execute tools -> repeat. +check shutdown -> check budget -> call LLM -> record turn -> +check for LLM errors -> update context -> handle completion or +(check shutdown -> execute tools) -> repeat. """ from typing import TYPE_CHECKING @@ -204,7 +205,25 @@ def _check_shutdown( """Return a termination result if a shutdown has been requested.""" if shutdown_checker is None: return None - if not shutdown_checker(): + try: + shutting_down = shutdown_checker() + except MemoryError, RecursionError: + raise + except Exception as exc: + error_msg = f"Shutdown checker failed: {type(exc).__name__}: {exc}" + logger.exception( + EXECUTION_LOOP_ERROR, + execution_id=ctx.execution_id, + turn=ctx.turn_count, + error=error_msg, + ) + return _build_result( + ctx, + TerminationReason.ERROR, + turns, + error_message=error_msg, + ) + if not shutting_down: return None logger.info( EXECUTION_LOOP_SHUTDOWN, diff --git a/src/ai_company/engine/shutdown.py b/src/ai_company/engine/shutdown.py index a5839f027c..764dfae6fa 100644 --- a/src/ai_company/engine/shutdown.py +++ b/src/ai_company/engine/shutdown.py @@ -21,6 +21,7 @@ from ai_company.observability import get_logger from ai_company.observability.events.execution import ( EXECUTION_SHUTDOWN_CLEANUP, + EXECUTION_SHUTDOWN_CLEANUP_TIMEOUT, EXECUTION_SHUTDOWN_COMPLETE, EXECUTION_SHUTDOWN_FORCE_CANCEL, EXECUTION_SHUTDOWN_GRACE_START, @@ -45,14 +46,17 @@ class ShutdownResult(BaseModel): duration_seconds: Wall-clock duration of the entire shutdown. """ - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, allow_inf_nan=False) strategy_type: NotBlankStr = Field( description="Name of the strategy that executed the shutdown", ) tasks_interrupted: int = Field( ge=0, - description="Number of tasks that were force-cancelled", + description=( + "Number of tasks force-cancelled during shutdown " + "(transitioned to INTERRUPTED)" + ), ) tasks_completed: int = Field( ge=0, @@ -118,6 +122,12 @@ def __init__( grace_seconds: float = 30.0, cleanup_seconds: float = 5.0, ) -> None: + if grace_seconds <= 0: + msg = f"grace_seconds must be positive, got {grace_seconds}" + raise ValueError(msg) + if cleanup_seconds <= 0: + msg = f"cleanup_seconds must be positive, got {cleanup_seconds}" + raise ValueError(msg) self._grace_seconds = grace_seconds self._cleanup_seconds = cleanup_seconds self._shutdown_event = asyncio.Event() @@ -191,6 +201,18 @@ async def _wait_and_cancel( task_set, timeout=self._grace_seconds, ) + + # Retrieve exceptions from done tasks to prevent + # "Task exception was never retrieved" warnings + for task in done: + if not task.cancelled() and task.exception() is not None: + logger.warning( + EXECUTION_SHUTDOWN_FORCE_CANCEL, + error=( + f"Task raised during shutdown: " + f"{type(task.exception()).__name__}" + ), + ) tasks_completed = len(done) if pending: @@ -200,8 +222,8 @@ async def _wait_and_cancel( ) for task in pending: task.cancel() - # Wait for cancellation to propagate - await asyncio.wait(pending) + # Wait for cancellation to propagate (bounded) + await asyncio.wait(pending, timeout=5.0) return tasks_completed, len(pending) @@ -209,7 +231,7 @@ async def _run_cleanup( self, callbacks: list[CleanupCallback], ) -> bool: - """Run cleanup callbacks within the cleanup time budget.""" + """Run cleanup callbacks sequentially within the time budget.""" if not callbacks: return True @@ -219,9 +241,21 @@ async def _run_cleanup( cleanup_seconds=self._cleanup_seconds, ) + all_succeeded = True + async def _run_all() -> None: - for callback in callbacks: - await callback() + nonlocal all_succeeded + for i, callback in enumerate(callbacks): + try: + await callback() + except Exception: + all_succeeded = False + logger.exception( + EXECUTION_SHUTDOWN_CLEANUP, + callback_index=i, + callback_count=len(callbacks), + error="Cleanup callback failed", + ) try: await asyncio.wait_for( @@ -230,12 +264,11 @@ async def _run_all() -> None: ) except TimeoutError: logger.warning( - EXECUTION_SHUTDOWN_CLEANUP, - cleanup_timeout=True, + EXECUTION_SHUTDOWN_CLEANUP_TIMEOUT, cleanup_seconds=self._cleanup_seconds, ) return False - return True + return all_succeeded class ShutdownManager: @@ -256,6 +289,11 @@ def __init__( self._running_tasks: dict[str, asyncio.Task[Any]] = {} self._cleanup_callbacks: list[CleanupCallback] = [] self._signals_installed = False + logger.debug( + EXECUTION_SHUTDOWN_COMPLETE, + strategy=self._strategy.get_strategy_type(), + action="manager_created", + ) @property def strategy(self) -> ShutdownStrategy: @@ -287,23 +325,42 @@ def _handle_signal(self, sig: signal.Signals) -> None: EXECUTION_SHUTDOWN_SIGNAL, signal=sig.name, ) - self._strategy.request_shutdown() + try: + self._strategy.request_shutdown() + except Exception: + logger.exception( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig.name, + error="request_shutdown() raised in signal handler", + ) def _handle_signal_threadsafe( self, signum: int, _frame: Any, ) -> None: - """Handle signal on Windows (called from signal handler thread).""" - sig_name = signal.Signals(signum).name - logger.info( - EXECUTION_SHUTDOWN_SIGNAL, - signal=sig_name, - ) + """Handle signal on Windows (called outside the event loop context). + + Logging is deferred to the event loop via ``call_soon_threadsafe`` + to avoid deadlocks (structlog acquires locks internally). + """ + try: + sig_name = signal.Signals(signum).name + except ValueError: + sig_name = f"UNKNOWN({signum})" + + def _on_loop() -> None: + logger.info( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig_name, + ) + self._strategy.request_shutdown() + try: loop = asyncio.get_running_loop() - loop.call_soon_threadsafe(self._strategy.request_shutdown) + loop.call_soon_threadsafe(_on_loop) except RuntimeError: + # No running event loop — call directly (best-effort). self._strategy.request_shutdown() def register_task( @@ -313,13 +370,31 @@ def register_task( ) -> None: """Track a running agent task.""" self._running_tasks[task_id] = asyncio_task + logger.debug( + EXECUTION_SHUTDOWN_GRACE_START, + action="task_registered", + task_id=task_id, + running_tasks=len(self._running_tasks), + ) def unregister_task(self, task_id: str) -> None: """Stop tracking a completed agent task.""" self._running_tasks.pop(task_id, None) + logger.debug( + EXECUTION_SHUTDOWN_GRACE_START, + action="task_unregistered", + task_id=task_id, + running_tasks=len(self._running_tasks), + ) def register_cleanup(self, callback: CleanupCallback) -> None: - """Register an async cleanup callback for shutdown.""" + """Register an async cleanup callback for shutdown. + + Callbacks run sequentially in registration order during + shutdown. Each callback is individually guarded against + exceptions — a failing callback does not prevent subsequent + ones from running. + """ self._cleanup_callbacks.append(callback) def is_shutting_down(self) -> bool: diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index 3176f2ee52..e2823c9758 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -39,6 +39,7 @@ EXECUTION_SHUTDOWN_GRACE_START: Final[str] = "execution.shutdown.grace_start" EXECUTION_SHUTDOWN_FORCE_CANCEL: Final[str] = "execution.shutdown.force_cancel" EXECUTION_SHUTDOWN_CLEANUP: Final[str] = "execution.shutdown.cleanup" +EXECUTION_SHUTDOWN_CLEANUP_TIMEOUT: Final[str] = "execution.shutdown.cleanup.timeout" EXECUTION_SHUTDOWN_COMPLETE: Final[str] = "execution.shutdown.complete" EXECUTION_LOOP_SHUTDOWN: Final[str] = "execution.loop.shutdown" diff --git a/tests/unit/core/test_enums.py b/tests/unit/core/test_enums.py index 0a36d6961c..b58840e79b 100644 --- a/tests/unit/core/test_enums.py +++ b/tests/unit/core/test_enums.py @@ -111,6 +111,7 @@ def test_task_status_values(self) -> None: assert TaskStatus.BLOCKED.value == "blocked" assert TaskStatus.FAILED.value == "failed" assert TaskStatus.CANCELLED.value == "cancelled" + assert TaskStatus.INTERRUPTED.value == "interrupted" def test_task_type_values(self) -> None: assert TaskType.DEVELOPMENT.value == "development" diff --git a/tests/unit/engine/test_loop_protocol.py b/tests/unit/engine/test_loop_protocol.py index 9632c2d368..497b5f5a8a 100644 --- a/tests/unit/engine/test_loop_protocol.py +++ b/tests/unit/engine/test_loop_protocol.py @@ -22,6 +22,7 @@ def test_values(self) -> None: assert TerminationReason.COMPLETED.value == "completed" assert TerminationReason.MAX_TURNS.value == "max_turns" assert TerminationReason.BUDGET_EXHAUSTED.value == "budget_exhausted" + assert TerminationReason.SHUTDOWN.value == "shutdown" assert TerminationReason.ERROR.value == "error" def test_member_count(self) -> None: diff --git a/tests/unit/engine/test_run_result.py b/tests/unit/engine/test_run_result.py index 292305cdf2..f6eb14b5c0 100644 --- a/tests/unit/engine/test_run_result.py +++ b/tests/unit/engine/test_run_result.py @@ -9,12 +9,12 @@ from ai_company.core.agent import AgentIdentity, ModelConfig from ai_company.core.enums import Priority, SeniorityLevel, TaskStatus, TaskType from ai_company.core.task import Task -from ai_company.engine.agent_engine import _make_budget_checker from ai_company.engine.context import AgentContext from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, TurnRecord, + make_budget_checker, ) from ai_company.engine.prompt import SystemPrompt, format_task_instruction from ai_company.engine.run_result import AgentRunResult @@ -290,7 +290,7 @@ def test_budget_only_no_deadline(self) -> None: @pytest.mark.unit class TestMakeBudgetChecker: - """Test _make_budget_checker closure logic.""" + """Test make_budget_checker closure logic.""" def test_returns_none_for_zero_budget(self) -> None: task = Task( @@ -304,7 +304,7 @@ def test_returns_none_for_zero_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=0.0, ) - assert _make_budget_checker(task) is None + assert make_budget_checker(task) is None def test_returns_none_for_default_budget(self) -> None: """Default budget_limit (0.0) returns None.""" @@ -318,7 +318,7 @@ def test_returns_none_for_default_budget(self) -> None: assigned_to="someone", status=TaskStatus.ASSIGNED, ) - assert _make_budget_checker(task) is None + assert make_budget_checker(task) is None def test_returns_callable_for_positive_budget(self) -> None: task = Task( @@ -332,7 +332,7 @@ def test_returns_callable_for_positive_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None assert callable(checker) @@ -348,7 +348,7 @@ def test_checker_returns_false_under_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None identity = _test_identity() @@ -377,7 +377,7 @@ def test_checker_returns_true_at_exact_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None identity = _test_identity() @@ -405,7 +405,7 @@ def test_checker_returns_true_over_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None identity = _test_identity() diff --git a/tests/unit/engine/test_shutdown.py b/tests/unit/engine/test_shutdown.py index 4a44b96b8d..f0f3e6b1ad 100644 --- a/tests/unit/engine/test_shutdown.py +++ b/tests/unit/engine/test_shutdown.py @@ -6,7 +6,9 @@ from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError +from ai_company.config.schema import GracefulShutdownConfig from ai_company.engine.shutdown import ( CooperativeTimeoutStrategy, ShutdownManager, @@ -299,9 +301,10 @@ def test_handle_signal_threadsafe_with_loop(self) -> None: mock_loop = MagicMock() with patch("asyncio.get_running_loop", return_value=mock_loop): manager._handle_signal_threadsafe(signal.SIGINT.value, None) - mock_loop.call_soon_threadsafe.assert_called_once_with( - strategy.request_shutdown, - ) + mock_loop.call_soon_threadsafe.assert_called_once() + # The callback is a closure (_on_loop) that logs and requests shutdown. + callback = mock_loop.call_soon_threadsafe.call_args[0][0] + assert callable(callback) def test_handle_signal_threadsafe_no_loop(self) -> None: strategy = CooperativeTimeoutStrategy() @@ -312,3 +315,89 @@ def test_handle_signal_threadsafe_no_loop(self) -> None: ): manager._handle_signal_threadsafe(signal.SIGINT.value, None) assert strategy.is_shutting_down() is True + + +# ── Constructor validation ──────────────────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutValidation: + """Constructor rejects non-positive timeout values.""" + + def test_zero_grace_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="grace_seconds must be positive"): + CooperativeTimeoutStrategy(grace_seconds=0) + + def test_negative_grace_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="grace_seconds must be positive"): + CooperativeTimeoutStrategy(grace_seconds=-1.0) + + def test_zero_cleanup_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="cleanup_seconds must be positive"): + CooperativeTimeoutStrategy(cleanup_seconds=0) + + def test_negative_cleanup_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="cleanup_seconds must be positive"): + CooperativeTimeoutStrategy(cleanup_seconds=-5.0) + + +# ── Cleanup callback exception isolation ────────────────────────── + + +@pytest.mark.unit +class TestCleanupCallbackExceptionIsolation: + """A failing callback doesn't prevent subsequent callbacks from running.""" + + async def test_failing_callback_does_not_block_others(self) -> None: + strategy = CooperativeTimeoutStrategy(cleanup_seconds=5.0) + ran = [] + + async def cb_ok_1() -> None: + ran.append("cb1") + + async def cb_fail() -> None: + msg = "boom" + raise RuntimeError(msg) + + async def cb_ok_2() -> None: + ran.append("cb2") + + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[cb_ok_1, cb_fail, cb_ok_2], + ) + + assert "cb1" in ran + assert "cb2" in ran + # cleanup_completed is False because one callback failed + assert result.cleanup_completed is False + + +# ── GracefulShutdownConfig validation ───────────────────────────── + + +@pytest.mark.unit +class TestGracefulShutdownConfig: + """Config model boundary validation.""" + + def test_defaults(self) -> None: + config = GracefulShutdownConfig() + assert config.strategy == "cooperative_timeout" + assert config.grace_seconds == 30.0 + assert config.cleanup_seconds == 5.0 + + def test_grace_seconds_upper_bound(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(grace_seconds=301) + + def test_cleanup_seconds_upper_bound(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(cleanup_seconds=61) + + def test_zero_grace_seconds_rejected(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(grace_seconds=0) + + def test_blank_strategy_rejected(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(strategy=" ") From d9366f7b69c9ffd83b5e99d5ef5a0ea5d64a6e68 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sat, 7 Mar 2026 13:21:19 +0100 Subject: [PATCH 3/3] fix: address 27 PR review items from local agents and external reviewers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add drain gate to ShutdownManager.register_task (rejects after shutdown) - Add input_token_estimate to turn-start log (char_count // 4 heuristic) - Fix semantic misuse of event constants in shutdown module - Fix tasks_completed counting (exclude errored tasks) - Add exception guards to Windows signal handler callback - Add exception guard to Windows signal handler fallback path - Fix raise chain: raise exc from build_exc in agent_engine - Fix TurnRecord tool_calls_made when shutdown preempts tool execution - Rework integration test_shutdown_signal_propagates_through_manager - Add is_success=False test for SHUTDOWN termination reason - Add shutdown_checker exception test (returns ERROR) - Retrieve post-cancellation exceptions to prevent warnings - Add CANCELLED transition arrows to §6.1 lifecycle diagram - Add duplicate task_id warning in register_task - Fix module docstring: shutdown module doesn't mark tasks INTERRUPTED - Add Returns section to _run_cleanup docstring - Clarify tasks_interrupted field description - Fix _frame type annotation to types.FrameType | None - Clarify §6.7: ALL shutdown-terminated tasks become INTERRUPTED - Add MemoryError propagation test for shutdown_checker - Add _transition_to_interrupted failure path test - Fix test_handle_signal_threadsafe to execute callback - Fix provider docstring (triggers on first call, not second) - Use Mapping/Sequence in protocol execute_shutdown params - Use ValidationError instead of Exception in frozen test - Use named constant for cancel propagation timeout - Fix get_strategy_type return type alignment with NotBlankStr Co-Authored-By: Claude Opus 4.6 --- DESIGN_SPEC.md | 8 +- src/ai_company/engine/agent_engine.py | 2 +- src/ai_company/engine/react_loop.py | 12 ++ src/ai_company/engine/shutdown.py | 120 +++++++++++++----- .../observability/events/execution.py | 3 + .../engine/test_graceful_shutdown.py | 35 +++-- .../engine/test_agent_engine_lifecycle.py | 42 ++++++ tests/unit/engine/test_react_loop.py | 43 +++++++ tests/unit/engine/test_run_result.py | 6 + tests/unit/engine/test_shutdown.py | 15 ++- 10 files changed, 226 insertions(+), 60 deletions(-) diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index af116b73de..aabf81eb4d 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -710,9 +710,9 @@ structured_phases: │ │ COMPLETED │ │ └────────────┘ │ - │ blocked cancelled + │ blocked cancelled (from ASSIGNED or IN_PROGRESS) ┌─────▼─────┐ ┌────────────┐ - │ BLOCKED │ │ CANCELLED │ + │ BLOCKED │ │ CANCELLED │ ◀── ASSIGNED / IN_PROGRESS └─────┬─────┘ └────────────┘ │ unblocked (terminal) └──▶ ASSIGNED @@ -1033,7 +1033,7 @@ When the process receives SIGTERM/SIGINT (user Ctrl+C, Docker stop, systemd shut #### Strategy 1: Cooperative with Timeout (Default / MVP) -The engine sets a shutdown event, stops accepting new tasks, and gives in-flight agents a grace period to finish their current turn. Agents check the shutdown event at turn boundaries (between LLM calls, before tool invocations) and exit cooperatively. After the grace period, remaining agents are force-cancelled and their tasks marked `INTERRUPTED`. +The engine sets a shutdown event, stops accepting new tasks, and gives in-flight agents a grace period to finish their current turn. Agents check the shutdown event at turn boundaries (between LLM calls, before tool invocations) and exit cooperatively. After the grace period, remaining agents are force-cancelled. **All tasks terminated by shutdown — whether they exited cooperatively or were force-cancelled — are marked `INTERRUPTED`** by the engine layer. ```yaml graceful_shutdown: @@ -1050,7 +1050,7 @@ On shutdown signal: 4. Force-cancel remaining agents (`task.cancel()`) — tasks transition to `INTERRUPTED` 5. Cleanup phase (`cleanup_seconds`): persist cost records, close provider connections, flush logs -> **Non-terminal status (implemented in M3):** `INTERRUPTED` is a `TaskStatus` variant. Unlike `FAILED` (eligible for automatic reassignment) or `CANCELLED` (terminal), `INTERRUPTED` indicates the task was stopped due to process shutdown and is eligible for manual or automatic reassignment on restart. Valid transitions: `ASSIGNED → INTERRUPTED`, `IN_PROGRESS → INTERRUPTED`, `INTERRUPTED → ASSIGNED` (reassignment on restart). See the updated §6.1 lifecycle diagram. +> **Non-terminal status (implemented in M3):** `INTERRUPTED` is a `TaskStatus` variant. Unlike `FAILED` (eligible for automatic reassignment) or `CANCELLED` (terminal), `INTERRUPTED` indicates the task was stopped due to process shutdown — regardless of whether the agent exited cooperatively or was force-cancelled — and is eligible for manual or automatic reassignment on restart. Valid transitions: `ASSIGNED → INTERRUPTED`, `IN_PROGRESS → INTERRUPTED`, `INTERRUPTED → ASSIGNED` (reassignment on restart). See the updated §6.1 lifecycle diagram. > > **Windows compatibility:** `loop.add_signal_handler()` is not supported on Windows. The implementation uses `signal.signal()` as a fallback. SIGINT (Ctrl+C) works cross-platform; SIGTERM on Windows requires `os.kill()`. > diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 8a5a71fa69..1b9c81d6e4 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -771,7 +771,7 @@ async def _handle_fatal_error( # noqa: PLR0913 error=f"Failed to build error result: {build_exc}", original_error=error_msg, ) - raise exc from None + raise exc from build_exc async def _build_error_execution( # noqa: PLR0913 self, diff --git a/src/ai_company/engine/react_loop.py b/src/ai_company/engine/react_loop.py index 35f35bf9a8..832e54ecd3 100644 --- a/src/ai_company/engine/react_loop.py +++ b/src/ai_company/engine/react_loop.py @@ -186,6 +186,13 @@ async def _process_turn_response( # noqa: PLR0913 # Check shutdown before tool invocations shutdown_result = self._check_shutdown(ctx, shutdown_checker, turns) if shutdown_result is not None: + # Tools were not executed — clear tool_calls_made in the + # last TurnRecord so it doesn't overstate what happened. + if turns: + last = turns[-1] + turns[-1] = last.model_copy( + update={"tool_calls_made": ()}, + ) return shutdown_result return await self._execute_tool_calls( @@ -283,11 +290,16 @@ async def _call_provider( # noqa: PLR0913 turns: list[TurnRecord], ) -> CompletionResponse | ExecutionResult: """Call provider.complete(), returning an error result on failure.""" + # Estimate input tokens from message character count (rough + # heuristic: ~4 chars per token). The exact count is only + # available *after* the provider call. + char_count = sum(len(m.content or "") for m in ctx.conversation) logger.info( EXECUTION_LOOP_TURN_START, execution_id=ctx.execution_id, turn=turn_number, message_count=len(ctx.conversation), + input_token_estimate=char_count // 4, ) try: return await provider.complete( diff --git a/src/ai_company/engine/shutdown.py b/src/ai_company/engine/shutdown.py index 764dfae6fa..023f2896e1 100644 --- a/src/ai_company/engine/shutdown.py +++ b/src/ai_company/engine/shutdown.py @@ -3,17 +3,22 @@ Implements DESIGN_SPEC §6.7 — cooperative timeout strategy for clean process shutdown. When SIGINT/SIGTERM is received the framework signals agents to exit at turn boundaries, waits a grace period, force-cancels -stragglers, marks tasks INTERRUPTED, and runs cleanup callbacks. +stragglers, and runs cleanup callbacks. The *engine* layer is responsible +for transitioning tasks to INTERRUPTED (see ``AgentEngine``). The ``ShutdownStrategy`` protocol is pluggable for future strategies. """ import asyncio +import contextlib import signal import sys import time -from collections.abc import Callable, Coroutine -from typing import Any, Protocol, runtime_checkable +from collections.abc import Callable, Coroutine, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import types from pydantic import BaseModel, ConfigDict, Field @@ -25,7 +30,10 @@ EXECUTION_SHUTDOWN_COMPLETE, EXECUTION_SHUTDOWN_FORCE_CANCEL, EXECUTION_SHUTDOWN_GRACE_START, + EXECUTION_SHUTDOWN_MANAGER_CREATED, EXECUTION_SHUTDOWN_SIGNAL, + EXECUTION_SHUTDOWN_TASK_ERROR, + EXECUTION_SHUTDOWN_TASK_TRACKED, ) logger = get_logger(__name__) @@ -54,8 +62,8 @@ class ShutdownResult(BaseModel): tasks_interrupted: int = Field( ge=0, description=( - "Number of tasks force-cancelled during shutdown " - "(transitioned to INTERRUPTED)" + "Number of tasks still running after the grace period " + "that were force-cancelled" ), ) tasks_completed: int = Field( @@ -86,15 +94,15 @@ def is_shutting_down(self) -> bool: async def execute_shutdown( self, *, - running_tasks: dict[str, asyncio.Task[Any]], - cleanup_callbacks: list[CleanupCallback], + running_tasks: Mapping[str, asyncio.Task[Any]], + cleanup_callbacks: Sequence[CleanupCallback], ) -> ShutdownResult: """Execute the full shutdown sequence. Args: running_tasks: Map of task_id → asyncio.Task for in-flight agent executions. - cleanup_callbacks: Ordered list of async cleanup callbacks + cleanup_callbacks: Ordered sequence of async cleanup callbacks to invoke after task shutdown. Returns: @@ -147,8 +155,8 @@ def get_strategy_type(self) -> str: async def execute_shutdown( self, *, - running_tasks: dict[str, asyncio.Task[Any]], - cleanup_callbacks: list[CleanupCallback], + running_tasks: Mapping[str, asyncio.Task[Any]], + cleanup_callbacks: Sequence[CleanupCallback], ) -> ShutdownResult: """Execute the cooperative timeout shutdown sequence.""" start = time.monotonic() @@ -184,9 +192,12 @@ async def execute_shutdown( ) return result + _CANCEL_PROPAGATION_TIMEOUT: float = 5.0 + """Seconds to wait for cancellation to propagate after force-cancel.""" + async def _wait_and_cancel( self, - running_tasks: dict[str, asyncio.Task[Any]], + running_tasks: Mapping[str, asyncio.Task[Any]], ) -> tuple[int, int]: """Wait for cooperative exit, then force-cancel stragglers. @@ -203,17 +214,21 @@ async def _wait_and_cancel( ) # Retrieve exceptions from done tasks to prevent - # "Task exception was never retrieved" warnings + # "Task exception was never retrieved" warnings. + # Tasks that raised are not counted as "completed" — only + # cleanly-finished tasks count. + tasks_completed = 0 for task in done: - if not task.cancelled() and task.exception() is not None: + if task.cancelled(): + continue + exc = task.exception() + if exc is not None: logger.warning( - EXECUTION_SHUTDOWN_FORCE_CANCEL, - error=( - f"Task raised during shutdown: " - f"{type(task.exception()).__name__}" - ), + EXECUTION_SHUTDOWN_TASK_ERROR, + error=(f"Task raised during shutdown: {type(exc).__name__}"), ) - tasks_completed = len(done) + else: + tasks_completed += 1 if pending: logger.warning( @@ -222,16 +237,29 @@ async def _wait_and_cancel( ) for task in pending: task.cancel() - # Wait for cancellation to propagate (bounded) - await asyncio.wait(pending, timeout=5.0) + # Wait for cancellation to propagate (bounded). + # Retrieve exceptions to suppress "never retrieved" warnings. + cancel_done, _ = await asyncio.wait( + pending, + timeout=self._CANCEL_PROPAGATION_TIMEOUT, + ) + for task in cancel_done: + if not task.cancelled(): + with contextlib.suppress(Exception): + task.exception() return tasks_completed, len(pending) async def _run_cleanup( self, - callbacks: list[CleanupCallback], + callbacks: Sequence[CleanupCallback], ) -> bool: - """Run cleanup callbacks sequentially within the time budget.""" + """Run cleanup callbacks sequentially within the time budget. + + Returns: + ``True`` if all callbacks completed successfully within the + time budget, ``False`` otherwise. + """ if not callbacks: return True @@ -290,9 +318,8 @@ def __init__( self._cleanup_callbacks: list[CleanupCallback] = [] self._signals_installed = False logger.debug( - EXECUTION_SHUTDOWN_COMPLETE, + EXECUTION_SHUTDOWN_MANAGER_CREATED, strategy=self._strategy.get_strategy_type(), - action="manager_created", ) @property @@ -337,7 +364,7 @@ def _handle_signal(self, sig: signal.Signals) -> None: def _handle_signal_threadsafe( self, signum: int, - _frame: Any, + _frame: types.FrameType | None, ) -> None: """Handle signal on Windows (called outside the event loop context). @@ -350,28 +377,51 @@ def _handle_signal_threadsafe( sig_name = f"UNKNOWN({signum})" def _on_loop() -> None: - logger.info( - EXECUTION_SHUTDOWN_SIGNAL, - signal=sig_name, - ) - self._strategy.request_shutdown() + try: + logger.info( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig_name, + ) + self._strategy.request_shutdown() + except Exception: + logger.exception( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig_name, + error="request_shutdown() raised in signal handler", + ) try: loop = asyncio.get_running_loop() loop.call_soon_threadsafe(_on_loop) except RuntimeError: # No running event loop — call directly (best-effort). - self._strategy.request_shutdown() + # Cannot log safely without a loop, so suppress all errors. + with contextlib.suppress(Exception): + self._strategy.request_shutdown() def register_task( self, task_id: str, asyncio_task: asyncio.Task[Any], ) -> None: - """Track a running agent task.""" + """Track a running agent task. + + Raises: + RuntimeError: If shutdown has already been requested (drain + gate is closed). + """ + if self._strategy.is_shutting_down(): + msg = f"Cannot register task {task_id!r}: shutdown already in progress" + raise RuntimeError(msg) + if task_id in self._running_tasks: + logger.warning( + EXECUTION_SHUTDOWN_TASK_TRACKED, + action="task_overwritten", + task_id=task_id, + ) self._running_tasks[task_id] = asyncio_task logger.debug( - EXECUTION_SHUTDOWN_GRACE_START, + EXECUTION_SHUTDOWN_TASK_TRACKED, action="task_registered", task_id=task_id, running_tasks=len(self._running_tasks), @@ -381,7 +431,7 @@ def unregister_task(self, task_id: str) -> None: """Stop tracking a completed agent task.""" self._running_tasks.pop(task_id, None) logger.debug( - EXECUTION_SHUTDOWN_GRACE_START, + EXECUTION_SHUTDOWN_TASK_TRACKED, action="task_unregistered", task_id=task_id, running_tasks=len(self._running_tasks), diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index e2823c9758..d4f35be7e5 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -36,6 +36,9 @@ EXECUTION_ENGINE_TIMEOUT: Final[str] = "execution.engine.timeout" EXECUTION_SHUTDOWN_SIGNAL: Final[str] = "execution.shutdown.signal" +EXECUTION_SHUTDOWN_MANAGER_CREATED: Final[str] = "execution.shutdown.manager_created" +EXECUTION_SHUTDOWN_TASK_TRACKED: Final[str] = "execution.shutdown.task_tracked" +EXECUTION_SHUTDOWN_TASK_ERROR: Final[str] = "execution.shutdown.task_error" EXECUTION_SHUTDOWN_GRACE_START: Final[str] = "execution.shutdown.grace_start" EXECUTION_SHUTDOWN_FORCE_CANCEL: Final[str] = "execution.shutdown.force_cancel" EXECUTION_SHUTDOWN_CLEANUP: Final[str] = "execution.shutdown.cleanup" diff --git a/tests/integration/engine/test_graceful_shutdown.py b/tests/integration/engine/test_graceful_shutdown.py index 41220499c8..7722885900 100644 --- a/tests/integration/engine/test_graceful_shutdown.py +++ b/tests/integration/engine/test_graceful_shutdown.py @@ -27,10 +27,10 @@ class _ShutdownTriggeringProvider: - """Provider that triggers shutdown on the second call. + """Provider that triggers shutdown on the first call. - First call returns a tool-use response, second call triggers - shutdown and returns a stop response. + The first call triggers shutdown and returns a STOP response, + so the loop completes before checking the shutdown flag again. """ def __init__(self, strategy: CooperativeTimeoutStrategy) -> None: @@ -90,10 +90,17 @@ async def get_model_capabilities(self, model: str) -> Any: class TestGracefulShutdownFlow: """Full shutdown integration: engine + strategy + agent → INTERRUPTED.""" - async def test_shutdown_interrupts_task_and_runs_cleanup( + async def test_shutdown_signal_propagates_through_manager( self, ) -> None: - """Engine with shutdown_checker stops and transitions to INTERRUPTED.""" + """Shutdown signal during execution propagates through manager. + + The provider triggers shutdown on the first call but returns + STOP (no tool calls), so the loop completes *before* the next + shutdown check. This verifies the signal propagation path — + test_shutdown_during_multi_turn_interrupts below covers the + INTERRUPTED transition. + """ from datetime import date from uuid import uuid4 @@ -122,7 +129,7 @@ async def test_shutdown_interrupts_task_and_runs_cleanup( task = Task( id="task-shutdown-001", title="Task for shutdown test", - description="This task will be interrupted.", + description="This task will complete before shutdown check.", type=TaskType.DEVELOPMENT, priority=Priority.MEDIUM, project="proj-001", @@ -136,14 +143,6 @@ async def test_shutdown_interrupts_task_and_runs_cleanup( strategy = CooperativeTimeoutStrategy(grace_seconds=5.0) manager = ShutdownManager(strategy=strategy) - # Track cleanup execution - cleanup_ran = [] - - async def cleanup_cb() -> None: - cleanup_ran.append("done") - - manager.register_cleanup(cleanup_cb) - provider = _ShutdownTriggeringProvider(strategy) engine = AgentEngine( @@ -151,14 +150,14 @@ async def cleanup_cb() -> None: shutdown_checker=manager.is_shutting_down, ) - await engine.run( + result = await engine.run( identity=identity, task=task, ) - # The first LLM call triggers shutdown, but since it returns - # STOP with no tool calls, the loop completes normally. - # Verify the manager state reflects the shutdown signal. + # Loop completed normally (STOP with no tool calls) + assert result.is_success is True + # But the shutdown signal was set assert manager.is_shutting_down() is True async def test_shutdown_during_multi_turn_interrupts( diff --git a/tests/unit/engine/test_agent_engine_lifecycle.py b/tests/unit/engine/test_agent_engine_lifecycle.py index 2c6b9d20f1..8be682efc2 100644 --- a/tests/unit/engine/test_agent_engine_lifecycle.py +++ b/tests/unit/engine/test_agent_engine_lifecycle.py @@ -523,3 +523,45 @@ async def test_transition_failure_preserves_result( te = result.execution_result.context.task_execution assert te is not None assert te.status == TaskStatus.CANCELLED + + async def test_interrupted_transition_failure_preserves_result( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """SHUTDOWN with invalid task status → transition fails, result kept.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + ) + te = ctx.task_execution + assert te is not None + # Force into COMPLETED (terminal) — INTERRUPTED transition should fail + bad_te = te.model_copy(update={"status": TaskStatus.COMPLETED}) + ctx_bad = ctx.model_copy(update={"task_execution": bad_te}) + + mock_result = ExecutionResult( + context=ctx_bad, + termination_reason=TerminationReason.SHUTDOWN, + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + provider = mock_provider_factory([]) + engine = AgentEngine(provider=provider, execution_loop=mock_loop) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + te = result.execution_result.context.task_execution + assert te is not None + # Transition failed, so status stays as COMPLETED + assert te.status == TaskStatus.COMPLETED diff --git a/tests/unit/engine/test_react_loop.py b/tests/unit/engine/test_react_loop.py index 9f6d40e6e3..026bf90d94 100644 --- a/tests/unit/engine/test_react_loop.py +++ b/tests/unit/engine/test_react_loop.py @@ -966,6 +966,49 @@ async def test_no_shutdown_checker_runs_normally( assert result.termination_reason == TerminationReason.COMPLETED + async def test_shutdown_checker_exception_returns_error( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Shutdown checker that raises → ERROR termination.""" + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([]) + loop = ReactLoop() + + def bad_checker() -> bool: + msg = "checker broke" + raise ValueError(msg) + + result = await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=bad_checker, + ) + + assert result.termination_reason == TerminationReason.ERROR + assert "Shutdown checker failed" in (result.error_message or "") + + async def test_shutdown_checker_memory_error_propagates( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """MemoryError from shutdown checker propagates unconditionally.""" + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([]) + loop = ReactLoop() + + def oom_checker() -> bool: + raise MemoryError + + with pytest.raises(MemoryError): + await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=oom_checker, + ) + @pytest.mark.unit class TestReactLoopCostAccounting: diff --git a/tests/unit/engine/test_run_result.py b/tests/unit/engine/test_run_result.py index f6eb14b5c0..200bf609ff 100644 --- a/tests/unit/engine/test_run_result.py +++ b/tests/unit/engine/test_run_result.py @@ -158,6 +158,12 @@ def test_is_success_false_on_budget(self) -> None: ) assert result.is_success is False + def test_is_success_false_on_shutdown(self) -> None: + result = _make_run_result( + termination_reason=TerminationReason.SHUTDOWN, + ) + assert result.is_success is False + @pytest.mark.unit class TestAgentRunResultValidation: diff --git a/tests/unit/engine/test_shutdown.py b/tests/unit/engine/test_shutdown.py index f0f3e6b1ad..cd8a4112fb 100644 --- a/tests/unit/engine/test_shutdown.py +++ b/tests/unit/engine/test_shutdown.py @@ -38,7 +38,7 @@ def test_result_model_is_frozen(self) -> None: cleanup_completed=True, duration_seconds=0.1, ) - with pytest.raises(Exception, match="frozen"): + with pytest.raises(ValidationError, match="frozen"): result.tasks_interrupted = 5 # type: ignore[misc] @@ -230,6 +230,15 @@ def test_is_shutting_down_delegates(self) -> None: strategy.request_shutdown() assert manager.is_shutting_down() is True + def test_register_task_during_shutdown_raises(self) -> None: + """Drain gate: registering a task after shutdown raises RuntimeError.""" + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + strategy.request_shutdown() + mock_task = MagicMock(spec=asyncio.Task) + with pytest.raises(RuntimeError, match="shutdown already in progress"): + manager.register_task("late-task", mock_task) + @pytest.mark.unit class TestShutdownManagerSignalHandlers: @@ -302,9 +311,11 @@ def test_handle_signal_threadsafe_with_loop(self) -> None: with patch("asyncio.get_running_loop", return_value=mock_loop): manager._handle_signal_threadsafe(signal.SIGINT.value, None) mock_loop.call_soon_threadsafe.assert_called_once() - # The callback is a closure (_on_loop) that logs and requests shutdown. + # Execute the callback to verify it actually calls request_shutdown. callback = mock_loop.call_soon_threadsafe.call_args[0][0] assert callable(callback) + callback() + assert strategy.is_shutting_down() is True def test_handle_signal_threadsafe_no_loop(self) -> None: strategy = CooperativeTimeoutStrategy()