diff --git a/CLAUDE.md b/CLAUDE.md index 2bbf5f0d98..08d5dd6560 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -48,7 +48,7 @@ src/ai_company/ communication/ # Inter-agent message bus and channels config/ # YAML company config loading and validation core/ # Shared domain models and base classes - engine/ # Agent orchestration, execution loops, and task lifecycle + engine/ # Agent orchestration, execution loops, task lifecycle, recovery, and shutdown memory/ # Persistent agent memory (memory layer TBD) observability/ # Structured logging, correlation tracking, log sinks providers/ # LLM provider abstraction (LiteLLM adapter) diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 8b9e80110e..aabf81eb4d 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -710,17 +710,24 @@ structured_phases: │ │ COMPLETED │ │ └────────────┘ │ - │ blocked cancelled + │ blocked cancelled (from ASSIGNED or IN_PROGRESS) ┌─────▼─────┐ ┌────────────┐ - │ BLOCKED │ │ CANCELLED │ + │ BLOCKED │ │ CANCELLED │ ◀── ASSIGNED / IN_PROGRESS └─────┬─────┘ └────────────┘ │ unblocked (terminal) └──▶ ASSIGNED + + shutdown signal: + ┌─────────────┐ + │ INTERRUPTED │──── reassign on restart ──▶ ASSIGNED + └─────────────┘ ``` -> **Non-terminal states:** BLOCKED and FAILED are non-terminal — BLOCKED returns to ASSIGNED when unblocked, FAILED returns to ASSIGNED for retry (see §6.6). COMPLETED and CANCELLED are terminal states with no outgoing transitions. +> **Non-terminal states:** BLOCKED, FAILED, and INTERRUPTED are non-terminal — BLOCKED returns to ASSIGNED when unblocked, FAILED returns to ASSIGNED for retry (see §6.6), INTERRUPTED returns to ASSIGNED on restart (see §6.7). COMPLETED and CANCELLED are terminal states with no outgoing transitions. > > **Transitions into FAILED:** Both `ASSIGNED → FAILED` (early setup failures) and `IN_PROGRESS → FAILED` (runtime crashes) are valid. `FAILED → ASSIGNED` enables reassignment when `retry_count < max_retries`. +> +> **Transitions into INTERRUPTED:** Both `ASSIGNED → INTERRUPTED` and `IN_PROGRESS → INTERRUPTED` are valid (graceful shutdown can occur at any active phase). `INTERRUPTED → ASSIGNED` enables reassignment on restart. > **Runtime wrapper (M3):** During execution, `Task` is wrapped by `TaskExecution` (in `engine/task_execution.py`). `TaskExecution` is a frozen Pydantic model that tracks status transitions via `model_copy(update=...)`, accumulates `TokenUsage` cost, and records a `StatusTransition` audit trail. The original `Task` is preserved unchanged; `to_task_snapshot()` produces a `Task` copy with the current execution status for persistence. @@ -1026,7 +1033,7 @@ When the process receives SIGTERM/SIGINT (user Ctrl+C, Docker stop, systemd shut #### Strategy 1: Cooperative with Timeout (Default / MVP) -The engine sets a shutdown event, stops accepting new tasks, and gives in-flight agents a grace period to finish their current turn. Agents check the shutdown event at turn boundaries (between LLM calls, before tool invocations) and exit cooperatively. After the grace period, remaining agents are force-cancelled and their tasks marked `INTERRUPTED`. +The engine sets a shutdown event, stops accepting new tasks, and gives in-flight agents a grace period to finish their current turn. Agents check the shutdown event at turn boundaries (between LLM calls, before tool invocations) and exit cooperatively. After the grace period, remaining agents are force-cancelled. **All tasks terminated by shutdown — whether they exited cooperatively or were force-cancelled — are marked `INTERRUPTED`** by the engine layer. ```yaml graceful_shutdown: @@ -1043,7 +1050,7 @@ On shutdown signal: 4. Force-cancel remaining agents (`task.cancel()`) — tasks transition to `INTERRUPTED` 5. Cleanup phase (`cleanup_seconds`): persist cost records, close provider connections, flush logs -> **Planned non-terminal status:** `INTERRUPTED` will be introduced as a new `TaskStatus` variant (and the task status transition map updated) when graceful shutdown is implemented. Unlike `FAILED` (eligible for automatic reassignment) or `CANCELLED` (terminal), `INTERRUPTED` indicates the task was stopped due to process shutdown and is eligible for manual or automatic reassignment on restart. +> **Non-terminal status (implemented in M3):** `INTERRUPTED` is a `TaskStatus` variant. Unlike `FAILED` (eligible for automatic reassignment) or `CANCELLED` (terminal), `INTERRUPTED` indicates the task was stopped due to process shutdown — regardless of whether the agent exited cooperatively or was force-cancelled — and is eligible for manual or automatic reassignment on restart. Valid transitions: `ASSIGNED → INTERRUPTED`, `IN_PROGRESS → INTERRUPTED`, `INTERRUPTED → ASSIGNED` (reassignment on restart). See the updated §6.1 lifecycle diagram. > > **Windows compatibility:** `loop.add_signal_handler()` is not supported on Windows. The implementation uses `signal.signal()` as a fallback. SIGINT (Ctrl+C) works cross-platform; SIGTERM on Windows requires `os.kill()`. > @@ -2304,6 +2311,7 @@ ai-company/ │ │ ├── cost_recording.py # Per-turn cost recording helpers │ │ ├── run_result.py # AgentRunResult outcome model │ │ ├── agent_engine.py # Agent execution engine +│ │ ├── shutdown.py # Graceful shutdown strategy & manager │ │ ├── task_engine.py # Task routing & scheduling (M3-M4) │ │ ├── workflow_engine.py # Workflow orchestration (M4) │ │ ├── meeting_engine.py # Meeting coordination (M4) @@ -2474,7 +2482,7 @@ These conventions were established during the M0–M2+ review cycle. **Adopted** | **LLM call analytics** | Planned (incremental) | M3: proxy metrics (`turns_per_task`, `tokens_per_task`). M4: call categorization (`productive`, `coordination`, `system`) + orchestration ratio. M5+: full analytics (retry tracking, latency, cache hits, per-provider comparison). | Append-only, never blocks execution. Builds on existing `CostRecord` infrastructure. Detects orchestration overhead early. See §10.5. | | **State coordination** | Planned (M4) | Centralized single-writer: `TaskEngine` owns all task/project mutations via `asyncio.Queue`. Agents submit requests, engine applies `model_copy(update=...)` sequentially and publishes snapshots. `version: int` field on state models for future optimistic concurrency if multi-process scaling is needed. | Prevents lost updates by design. Trivial in single-threaded asyncio (no locks). Perfect audit trail. Industry consensus: MetaGPT, CrewAI, AutoGen all use prevention-by-design, not conflict resolution. See §6.8 State Coordination table. | | **Workspace isolation** | Planned (M4) | Pluggable `WorkspaceIsolationStrategy` protocol. Default: planner + git worktrees. Each agent works in an isolated worktree; sequential merge on completion. Textual conflicts detected by git; semantic conflicts reviewed by agent or human. | Industry standard (Codex, Cursor, Claude Code, VS Code). Maximum parallelism. Leverages mature git infrastructure. See §6.8. | -| **Graceful shutdown** | Planned (M3) | Pluggable `ShutdownStrategy` protocol. Default: cooperative with 30s timeout. Agents check shutdown event at turn boundaries. Force-cancel after timeout. `INTERRUPTED` status for force-cancelled tasks. M4/M5: upgrade to checkpoint-and-stop. | Cross-platform (Windows `signal.signal()` fallback). Bounded shutdown time. Mirrors cooperative shutdown in §6.7. | +| **Graceful shutdown** | Adopted (M3) | Pluggable `ShutdownStrategy` protocol. Default: cooperative with 30s timeout. Agents check shutdown event at turn boundaries. Force-cancel after timeout. `INTERRUPTED` status for force-cancelled tasks. M4/M5: upgrade to checkpoint-and-stop. | Cross-platform (Windows `signal.signal()` fallback). Bounded shutdown time. Mirrors cooperative shutdown in §6.7. | --- diff --git a/src/ai_company/config/__init__.py b/src/ai_company/config/__init__.py index cc2a8d58ea..ee367156bd 100644 --- a/src/ai_company/config/__init__.py +++ b/src/ai_company/config/__init__.py @@ -10,6 +10,7 @@ default_config_dict RootConfig AgentConfig + GracefulShutdownConfig ProviderConfig ProviderModelConfig RoutingConfig @@ -37,6 +38,7 @@ ) from ai_company.config.schema import ( AgentConfig, + GracefulShutdownConfig, ProviderConfig, ProviderModelConfig, RootConfig, @@ -51,6 +53,7 @@ "ConfigLocation", "ConfigParseError", "ConfigValidationError", + "GracefulShutdownConfig", "ProviderConfig", "ProviderModelConfig", "RootConfig", diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index bd2f37c5bc..50f5458daa 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -25,4 +25,5 @@ def default_config_dict() -> dict[str, Any]: "providers": {}, "routing": {}, "logging": None, + "graceful_shutdown": {}, } diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index ca07fd5f5b..65516e73d4 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -321,6 +321,37 @@ class AgentConfig(BaseModel): ) +class GracefulShutdownConfig(BaseModel): + """Configuration for graceful shutdown behaviour. + + Attributes: + strategy: Shutdown strategy name (e.g. ``"cooperative_timeout"``). + grace_seconds: Seconds to wait for cooperative agent exit + before force-cancelling. + cleanup_seconds: Seconds allowed for cleanup callbacks + (persist costs, close connections, flush logs). + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + strategy: NotBlankStr = Field( + default="cooperative_timeout", + description="Shutdown strategy name", + ) + grace_seconds: float = Field( + default=30.0, + gt=0, + le=300, + description="Seconds to wait for cooperative agent exit", + ) + cleanup_seconds: float = Field( + default=5.0, + gt=0, + le=60, + description="Seconds allowed for cleanup callbacks", + ) + + class RootConfig(BaseModel): """Root company configuration — the top-level validation target. @@ -339,6 +370,7 @@ class RootConfig(BaseModel): providers: LLM provider configurations keyed by provider name. routing: Model routing configuration. logging: Logging configuration (``None`` to use platform defaults). + graceful_shutdown: Graceful shutdown configuration. """ model_config = ConfigDict(frozen=True) @@ -386,6 +418,10 @@ class RootConfig(BaseModel): default=None, description="Logging configuration", ) + graceful_shutdown: GracefulShutdownConfig = Field( + default_factory=GracefulShutdownConfig, + description="Graceful shutdown configuration", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: diff --git a/src/ai_company/core/enums.py b/src/ai_company/core/enums.py index 1e2f653333..42903875a9 100644 --- a/src/ai_company/core/enums.py +++ b/src/ai_company/core/enums.py @@ -127,13 +127,14 @@ class TaskStatus(StrEnum): Summary for quick reference: CREATED -> ASSIGNED - ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED - IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED + ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED | INTERRUPTED + IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED | INTERRUPTED IN_REVIEW -> COMPLETED | IN_PROGRESS (rework) | BLOCKED | CANCELLED BLOCKED -> ASSIGNED (unblocked) FAILED -> ASSIGNED (reassignment for retry) + INTERRUPTED -> ASSIGNED (reassignment on restart) COMPLETED and CANCELLED are terminal states. - FAILED is non-terminal (can be reassigned). + FAILED and INTERRUPTED are non-terminal (can be reassigned). """ CREATED = "created" @@ -143,6 +144,7 @@ class TaskStatus(StrEnum): COMPLETED = "completed" BLOCKED = "blocked" FAILED = "failed" + INTERRUPTED = "interrupted" CANCELLED = "cancelled" diff --git a/src/ai_company/core/task_transitions.py b/src/ai_company/core/task_transitions.py index b4f5921f98..a5cd66e062 100644 --- a/src/ai_company/core/task_transitions.py +++ b/src/ai_company/core/task_transitions.py @@ -1,18 +1,19 @@ """Task lifecycle state machine transitions. Defines the valid state transitions for the task lifecycle, based on -DESIGN_SPEC Sections 6.1 and 6.6, extended with BLOCKED, CANCELLED, and -FAILED transitions for completeness:: +DESIGN_SPEC Sections 6.1 and 6.6, extended with BLOCKED, CANCELLED, +FAILED, and INTERRUPTED transitions for completeness:: CREATED -> ASSIGNED - ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED - IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED + ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED | INTERRUPTED + IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED | INTERRUPTED IN_REVIEW -> COMPLETED | IN_PROGRESS (rework) | BLOCKED | CANCELLED BLOCKED -> ASSIGNED (unblocked) FAILED -> ASSIGNED (reassignment for retry) + INTERRUPTED -> ASSIGNED (reassignment on restart) COMPLETED and CANCELLED are terminal states with no outgoing -transitions. FAILED is non-terminal (can be reassigned). +transitions. FAILED and INTERRUPTED are non-terminal (can be reassigned). """ from ai_company.core.enums import TaskStatus @@ -32,6 +33,7 @@ TaskStatus.BLOCKED, TaskStatus.CANCELLED, TaskStatus.FAILED, + TaskStatus.INTERRUPTED, } ), TaskStatus.IN_PROGRESS: frozenset( @@ -40,6 +42,7 @@ TaskStatus.BLOCKED, TaskStatus.CANCELLED, TaskStatus.FAILED, + TaskStatus.INTERRUPTED, } ), TaskStatus.IN_REVIEW: frozenset( @@ -52,6 +55,7 @@ ), TaskStatus.BLOCKED: frozenset({TaskStatus.ASSIGNED}), TaskStatus.FAILED: frozenset({TaskStatus.ASSIGNED}), # reassignment + TaskStatus.INTERRUPTED: frozenset({TaskStatus.ASSIGNED}), # reassignment on restart TaskStatus.COMPLETED: frozenset(), # terminal TaskStatus.CANCELLED: frozenset(), # terminal } diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index d730dda30b..75e9b42cfd 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -23,6 +23,7 @@ BudgetChecker, ExecutionLoop, ExecutionResult, + ShutdownChecker, TerminationReason, TurnRecord, ) @@ -40,6 +41,13 @@ RecoveryStrategy, ) from ai_company.engine.run_result import AgentRunResult +from ai_company.engine.shutdown import ( + CleanupCallback, + CooperativeTimeoutStrategy, + ShutdownManager, + ShutdownResult, + ShutdownStrategy, +) from ai_company.engine.task_execution import StatusTransition, TaskExecution from ai_company.providers.models import ZERO_TOKEN_USAGE, add_token_usage @@ -52,6 +60,8 @@ "AgentRunResult", "BudgetChecker", "BudgetExhaustedError", + "CleanupCallback", + "CooperativeTimeoutStrategy", "DefaultTokenEstimator", "EngineError", "ExecutionLoop", @@ -65,6 +75,10 @@ "ReactLoop", "RecoveryResult", "RecoveryStrategy", + "ShutdownChecker", + "ShutdownManager", + "ShutdownResult", + "ShutdownStrategy", "StatusTransition", "SystemPrompt", "TaskCompletionMetrics", diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 8f6df7bbf0..1b9c81d6e4 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -16,10 +16,12 @@ from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, + make_budget_checker, ) from ai_company.engine.metrics import TaskCompletionMetrics from ai_company.engine.prompt import ( SystemPrompt, + build_error_prompt, build_system_prompt, format_task_instruction, ) @@ -48,7 +50,11 @@ from ai_company.budget.tracker import CostTracker from ai_company.core.agent import AgentIdentity from ai_company.core.task import Task - from ai_company.engine.loop_protocol import BudgetChecker, ExecutionLoop + from ai_company.engine.loop_protocol import ( + BudgetChecker, + ExecutionLoop, + ShutdownChecker, + ) from ai_company.providers.models import CompletionConfig, ToolDefinition from ai_company.providers.protocol import CompletionProvider from ai_company.tools.registry import ToolRegistry @@ -62,8 +68,8 @@ """Task statuses the engine will accept for execution. CREATED tasks lack an assignee; terminal statuses (COMPLETED, CANCELLED), -BLOCKED, IN_REVIEW, and FAILED are not executable. FAILED tasks must be -reassigned (FAILED -> ASSIGNED) before re-execution. +BLOCKED, IN_REVIEW, FAILED, and INTERRUPTED are not executable. FAILED +and INTERRUPTED tasks must be reassigned (-> ASSIGNED) before re-execution. """ @@ -83,9 +89,12 @@ class AgentEngine: recovery_strategy: Crash recovery strategy. Defaults to a shared ``FailAndReassignStrategy`` instance. Pass ``None`` to disable. + shutdown_checker: Optional callback; returns ``True`` when a + graceful shutdown has been requested. Passed through to + the execution loop. """ - def __init__( + def __init__( # noqa: PLR0913 self, *, provider: CompletionProvider, @@ -93,12 +102,14 @@ def __init__( tool_registry: ToolRegistry | None = None, cost_tracker: CostTracker | None = None, recovery_strategy: RecoveryStrategy | None = _DEFAULT_RECOVERY_STRATEGY, + shutdown_checker: ShutdownChecker | None = None, ) -> None: self._provider = provider self._loop: ExecutionLoop = execution_loop or ReactLoop() self._tool_registry = tool_registry self._cost_tracker = cost_tracker self._recovery_strategy = recovery_strategy + self._shutdown_checker = shutdown_checker logger.debug( EXECUTION_ENGINE_CREATED, loop_type=self._loop.get_loop_type(), @@ -206,7 +217,7 @@ async def _execute( # noqa: PLR0913 tool_invoker: ToolInvoker | None = None, ) -> AgentRunResult: """Run execution loop, record costs, apply transitions, and build result.""" - budget_checker = _make_budget_checker(task) + budget_checker = make_budget_checker(task) logger.debug( EXECUTION_ENGINE_PROMPT_BUILT, @@ -322,6 +333,7 @@ async def _run_loop_with_timeout( # noqa: PLR0913 provider=self._provider, tool_invoker=tool_invoker, budget_checker=budget_checker, + shutdown_checker=self._shutdown_checker, completion_config=completion_config, ) if timeout_seconds is None: @@ -515,14 +527,22 @@ def _apply_post_execution_transitions( ) -> ExecutionResult: """Apply post-execution task transitions based on termination reason. - Only COMPLETED triggers IN_PROGRESS -> IN_REVIEW -> COMPLETED. + COMPLETED triggers IN_PROGRESS -> IN_REVIEW -> COMPLETED. + SHUTDOWN triggers current status -> INTERRUPTED. Transition failures are logged but never discard the result. """ ctx = execution_result.context if ctx.task_execution is None: return execution_result - if execution_result.termination_reason != TerminationReason.COMPLETED: + reason = execution_result.termination_reason + + if reason == TerminationReason.SHUTDOWN: + return self._transition_to_interrupted( + execution_result, ctx, agent_id, task_id + ) + + if reason != TerminationReason.COMPLETED: return execution_result try: @@ -572,6 +592,37 @@ def _transition_to_complete( ) return ctx + def _transition_to_interrupted( + self, + execution_result: ExecutionResult, + ctx: AgentContext, + agent_id: str, + task_id: str, + ) -> ExecutionResult: + """Transition task to INTERRUPTED on graceful shutdown.""" + try: + prev_status = ctx.task_execution.status # type: ignore[union-attr] + ctx = ctx.with_task_transition( + TaskStatus.INTERRUPTED, + reason="Graceful shutdown requested", + ) + logger.info( + EXECUTION_ENGINE_TASK_TRANSITION, + agent_id=agent_id, + task_id=task_id, + from_status=prev_status.value, + to_status=TaskStatus.INTERRUPTED.value, + ) + return execution_result.model_copy(update={"context": ctx}) + except (ValueError, ExecutionStateError) as exc: + logger.exception( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error=f"Post-execution INTERRUPTED transition failed: {exc}", + ) + return execution_result + async def _apply_recovery( self, execution_result: ExecutionResult, @@ -691,7 +742,7 @@ async def _handle_fatal_error( # noqa: PLR0913 error_msg, ctx, ) - error_prompt = _build_error_prompt( + error_prompt = build_error_prompt( identity, agent_id, system_prompt, @@ -720,7 +771,7 @@ async def _handle_fatal_error( # noqa: PLR0913 error=f"Failed to build error result: {build_exc}", original_error=error_msg, ) - raise exc from None + raise exc from build_exc async def _build_error_execution( # noqa: PLR0913 self, @@ -743,44 +794,3 @@ async def _build_error_execution( # noqa: PLR0913 agent_id, task_id, ) - - -def _build_error_prompt( - identity: AgentIdentity, - agent_id: str, - system_prompt: SystemPrompt | None, -) -> SystemPrompt: - """Return the existing system prompt or a minimal error placeholder.""" - if system_prompt is not None: - return system_prompt - return SystemPrompt( - content="", - template_version="error", - estimated_tokens=0, - sections=(), - metadata={ - "agent_id": agent_id, - "name": identity.name, - "role": identity.role, - "department": identity.department, - "level": identity.level.value, - }, - ) - - -def _make_budget_checker(task: Task) -> BudgetChecker | None: - """Create a budget checker if the task has a positive budget limit. - - The returned callable returns ``True`` when accumulated cost meets - or exceeds the limit (budget exhausted), ``False`` otherwise. - Returns ``None`` when there is no positive budget limit. - """ - if task.budget_limit <= 0: - return None - - limit = task.budget_limit - - def _check(ctx: AgentContext) -> bool: - return ctx.accumulated_cost.cost_usd >= limit - - return _check diff --git a/src/ai_company/engine/loop_protocol.py b/src/ai_company/engine/loop_protocol.py index 4a682aa587..d153647e6d 100644 --- a/src/ai_company/engine/loop_protocol.py +++ b/src/ai_company/engine/loop_protocol.py @@ -2,7 +2,8 @@ Defines the ``ExecutionLoop`` protocol that the agent engine calls to run a task, along with ``ExecutionResult``, ``TurnRecord``, -``TerminationReason``, and the ``BudgetChecker`` type alias. +``TerminationReason``, and the ``BudgetChecker`` and ``ShutdownChecker`` +type aliases. """ from collections.abc import Callable @@ -11,6 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator +from ai_company.core.task import Task # noqa: TC001 from ai_company.core.types import NotBlankStr # noqa: TC001 from ai_company.engine.context import AgentContext from ai_company.providers.enums import FinishReason # noqa: TC001 @@ -27,6 +29,7 @@ class TerminationReason(StrEnum): COMPLETED = "completed" MAX_TURNS = "max_turns" BUDGET_EXHAUSTED = "budget_exhausted" + SHUTDOWN = "shutdown" ERROR = "error" @@ -121,6 +124,9 @@ def _validate_error_message(self) -> Self: BudgetChecker = Callable[[AgentContext], bool] """Callback that returns ``True`` when the budget is exhausted.""" +ShutdownChecker = Callable[[], bool] +"""Callback that returns ``True`` when a graceful shutdown has been requested.""" + @runtime_checkable class ExecutionLoop(Protocol): @@ -131,13 +137,14 @@ class ExecutionLoop(Protocol): but all return an ``ExecutionResult`` with a ``TerminationReason``. """ - async def execute( + async def execute( # noqa: PLR0913 self, *, context: AgentContext, provider: CompletionProvider, tool_invoker: ToolInvoker | None = None, budget_checker: BudgetChecker | None = None, + shutdown_checker: ShutdownChecker | None = None, completion_config: CompletionConfig | None = None, ) -> ExecutionResult: """Run the execution loop. @@ -148,6 +155,8 @@ async def execute( tool_invoker: Optional tool invoker for tool execution. budget_checker: Optional callback; returns ``True`` when budget is exhausted. + shutdown_checker: Optional callback; returns ``True`` when + a graceful shutdown has been requested. completion_config: Optional per-execution override for temperature/max_tokens (defaults to identity's model config). @@ -159,3 +168,21 @@ async def execute( def get_loop_type(self) -> str: """Return the loop type identifier (e.g. ``"react"``).""" ... + + +def make_budget_checker(task: Task) -> BudgetChecker | None: + """Create a budget checker if the task has a positive budget limit. + + The returned callable returns ``True`` when accumulated cost meets + or exceeds the limit (budget exhausted), ``False`` otherwise. + Returns ``None`` when there is no positive budget limit. + """ + if task.budget_limit <= 0: + return None + + limit = task.budget_limit + + def _check(ctx: AgentContext) -> bool: + return ctx.accumulated_cost.cost_usd >= limit + + return _check diff --git a/src/ai_company/engine/prompt.py b/src/ai_company/engine/prompt.py index 69a38b550f..893282cbf5 100644 --- a/src/ai_company/engine/prompt.py +++ b/src/ai_company/engine/prompt.py @@ -625,6 +625,41 @@ def _render_and_estimate( # noqa: PLR0913 return content, estimator.estimate_tokens(content) +def build_error_prompt( + identity: AgentIdentity, + agent_id: str, + system_prompt: SystemPrompt | None, +) -> SystemPrompt: + """Return the existing system prompt or a minimal error placeholder. + + Used by the engine when the execution pipeline fails and a + ``SystemPrompt`` was never built (or was partially built). + + Args: + identity: Agent identity for metadata. + agent_id: String agent identifier. + system_prompt: Previously built prompt, or ``None``. + + Returns: + The existing prompt if available, else a minimal placeholder. + """ + if system_prompt is not None: + return system_prompt + return SystemPrompt( + content="", + template_version="error", + estimated_tokens=0, + sections=(), + metadata={ + "agent_id": agent_id, + "name": identity.name, + "role": identity.role, + "department": identity.department, + "level": identity.level.value, + }, + ) + + def format_task_instruction(task: Task) -> str: """Format a task into a user message for the initial conversation. diff --git a/src/ai_company/engine/react_loop.py b/src/ai_company/engine/react_loop.py index 093f408808..832e54ecd3 100644 --- a/src/ai_company/engine/react_loop.py +++ b/src/ai_company/engine/react_loop.py @@ -1,8 +1,9 @@ """ReAct execution loop — think, act, observe. Implements the ``ExecutionLoop`` protocol using the ReAct pattern: -check budget -> call LLM -> record turn -> check for LLM errors -> -update context -> handle completion or execute tools -> repeat. +check shutdown -> check budget -> call LLM -> record turn -> +check for LLM errors -> update context -> handle completion or +(check shutdown -> execute tools) -> repeat. """ from typing import TYPE_CHECKING @@ -11,6 +12,7 @@ from ai_company.observability.events.execution import ( EXECUTION_LOOP_BUDGET_EXHAUSTED, EXECUTION_LOOP_ERROR, + EXECUTION_LOOP_SHUTDOWN, EXECUTION_LOOP_START, EXECUTION_LOOP_TERMINATED, EXECUTION_LOOP_TOOL_CALLS, @@ -29,6 +31,7 @@ from .loop_protocol import ( BudgetChecker, ExecutionResult, + ShutdownChecker, TerminationReason, TurnRecord, ) @@ -54,13 +57,14 @@ def get_loop_type(self) -> str: """Return the loop type identifier.""" return "react" - async def execute( + async def execute( # noqa: PLR0913 self, *, context: AgentContext, provider: CompletionProvider, tool_invoker: ToolInvoker | None = None, budget_checker: BudgetChecker | None = None, + shutdown_checker: ShutdownChecker | None = None, completion_config: CompletionConfig | None = None, ) -> ExecutionResult: """Run the ReAct loop until termination. @@ -70,6 +74,8 @@ async def execute( provider: LLM completion provider. tool_invoker: Optional tool invoker for tool execution. budget_checker: Optional budget exhaustion callback. + shutdown_checker: Optional callback; returns ``True`` when + a graceful shutdown has been requested. completion_config: Optional per-execution config override. Returns: @@ -85,6 +91,10 @@ async def execute( ctx = context while ctx.has_turns_remaining: + shutdown_result = self._check_shutdown(ctx, shutdown_checker, turns) + if shutdown_result is not None: + return shutdown_result + budget_result = self._check_budget(ctx, budget_checker, turns) if budget_result is not None: return budget_result @@ -110,6 +120,7 @@ async def execute( turn_number, turns, tool_invoker, + shutdown_checker, ) if isinstance(result, ExecutionResult): return result @@ -143,13 +154,14 @@ def _prepare_loop( ) return model_id, config, _get_tool_definitions(tool_invoker), [] - async def _process_turn_response( + async def _process_turn_response( # noqa: PLR0913 self, ctx: AgentContext, response: CompletionResponse, turn_number: int, turns: list[TurnRecord], tool_invoker: ToolInvoker | None, + shutdown_checker: ShutdownChecker | None = None, ) -> AgentContext | ExecutionResult: """Check errors, update context, handle completion or tool calls.""" error = self._check_response_errors(ctx, response, turn_number, turns) @@ -171,6 +183,18 @@ async def _process_turn_response( if not response.tool_calls: return self._handle_completion(ctx, response, turns) + # Check shutdown before tool invocations + shutdown_result = self._check_shutdown(ctx, shutdown_checker, turns) + if shutdown_result is not None: + # Tools were not executed — clear tool_calls_made in the + # last TurnRecord so it doesn't overstate what happened. + if turns: + last = turns[-1] + turns[-1] = last.model_copy( + update={"tool_calls_made": ()}, + ) + return shutdown_result + return await self._execute_tool_calls( ctx, tool_invoker, @@ -179,6 +203,42 @@ async def _process_turn_response( turns, ) + def _check_shutdown( + self, + ctx: AgentContext, + shutdown_checker: ShutdownChecker | None, + turns: list[TurnRecord], + ) -> ExecutionResult | None: + """Return a termination result if a shutdown has been requested.""" + if shutdown_checker is None: + return None + try: + shutting_down = shutdown_checker() + except MemoryError, RecursionError: + raise + except Exception as exc: + error_msg = f"Shutdown checker failed: {type(exc).__name__}: {exc}" + logger.exception( + EXECUTION_LOOP_ERROR, + execution_id=ctx.execution_id, + turn=ctx.turn_count, + error=error_msg, + ) + return _build_result( + ctx, + TerminationReason.ERROR, + turns, + error_message=error_msg, + ) + if not shutting_down: + return None + logger.info( + EXECUTION_LOOP_SHUTDOWN, + execution_id=ctx.execution_id, + turn=ctx.turn_count, + ) + return _build_result(ctx, TerminationReason.SHUTDOWN, turns) + def _check_budget( self, ctx: AgentContext, @@ -230,10 +290,16 @@ async def _call_provider( # noqa: PLR0913 turns: list[TurnRecord], ) -> CompletionResponse | ExecutionResult: """Call provider.complete(), returning an error result on failure.""" - logger.debug( + # Estimate input tokens from message character count (rough + # heuristic: ~4 chars per token). The exact count is only + # available *after* the provider call. + char_count = sum(len(m.content or "") for m in ctx.conversation) + logger.info( EXECUTION_LOOP_TURN_START, execution_id=ctx.execution_id, turn=turn_number, + message_count=len(ctx.conversation), + input_token_estimate=char_count // 4, ) try: return await provider.complete( diff --git a/src/ai_company/engine/shutdown.py b/src/ai_company/engine/shutdown.py new file mode 100644 index 0000000000..023f2896e1 --- /dev/null +++ b/src/ai_company/engine/shutdown.py @@ -0,0 +1,459 @@ +"""Graceful shutdown strategy and manager. + +Implements DESIGN_SPEC §6.7 — cooperative timeout strategy for clean +process shutdown. When SIGINT/SIGTERM is received the framework signals +agents to exit at turn boundaries, waits a grace period, force-cancels +stragglers, and runs cleanup callbacks. The *engine* layer is responsible +for transitioning tasks to INTERRUPTED (see ``AgentEngine``). + +The ``ShutdownStrategy`` protocol is pluggable for future strategies. +""" + +import asyncio +import contextlib +import signal +import sys +import time +from collections.abc import Callable, Coroutine, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +if TYPE_CHECKING: + import types + +from pydantic import BaseModel, ConfigDict, Field + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.execution import ( + EXECUTION_SHUTDOWN_CLEANUP, + EXECUTION_SHUTDOWN_CLEANUP_TIMEOUT, + EXECUTION_SHUTDOWN_COMPLETE, + EXECUTION_SHUTDOWN_FORCE_CANCEL, + EXECUTION_SHUTDOWN_GRACE_START, + EXECUTION_SHUTDOWN_MANAGER_CREATED, + EXECUTION_SHUTDOWN_SIGNAL, + EXECUTION_SHUTDOWN_TASK_ERROR, + EXECUTION_SHUTDOWN_TASK_TRACKED, +) + +logger = get_logger(__name__) + +CleanupCallback = Callable[[], Coroutine[Any, Any, None]] +"""Async callback invoked during shutdown cleanup phase.""" + + +class ShutdownResult(BaseModel): + """Outcome of a graceful shutdown sequence. + + Attributes: + strategy_type: Name of the strategy that executed the shutdown. + tasks_interrupted: Number of tasks that were force-cancelled. + tasks_completed: Number of tasks that exited cooperatively. + cleanup_completed: Whether all cleanup callbacks finished + within the allowed time. + duration_seconds: Wall-clock duration of the entire shutdown. + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + strategy_type: NotBlankStr = Field( + description="Name of the strategy that executed the shutdown", + ) + tasks_interrupted: int = Field( + ge=0, + description=( + "Number of tasks still running after the grace period " + "that were force-cancelled" + ), + ) + tasks_completed: int = Field( + ge=0, + description="Number of tasks that exited cooperatively", + ) + cleanup_completed: bool = Field( + description="Whether all cleanup callbacks finished in time", + ) + duration_seconds: float = Field( + ge=0.0, + description="Wall-clock duration of the shutdown sequence", + ) + + +@runtime_checkable +class ShutdownStrategy(Protocol): + """Protocol for pluggable shutdown strategies.""" + + def request_shutdown(self) -> None: + """Signal that a graceful shutdown has been requested.""" + ... + + def is_shutting_down(self) -> bool: + """Return ``True`` when shutdown has been requested.""" + ... + + async def execute_shutdown( + self, + *, + running_tasks: Mapping[str, asyncio.Task[Any]], + cleanup_callbacks: Sequence[CleanupCallback], + ) -> ShutdownResult: + """Execute the full shutdown sequence. + + Args: + running_tasks: Map of task_id → asyncio.Task for in-flight + agent executions. + cleanup_callbacks: Ordered sequence of async cleanup callbacks + to invoke after task shutdown. + + Returns: + Outcome of the shutdown sequence. + """ + ... + + def get_strategy_type(self) -> str: + """Return the strategy identifier (e.g. ``"cooperative_timeout"``).""" + ... + + +class CooperativeTimeoutStrategy: + """Cooperative timeout shutdown strategy. + + 1. Set shutdown event (signal agents via turn-boundary checks). + 2. Wait up to ``grace_seconds`` for tasks to exit cooperatively. + 3. Force-cancel any remaining tasks. + 4. Run cleanup callbacks within ``cleanup_seconds``. + """ + + def __init__( + self, + *, + grace_seconds: float = 30.0, + cleanup_seconds: float = 5.0, + ) -> None: + if grace_seconds <= 0: + msg = f"grace_seconds must be positive, got {grace_seconds}" + raise ValueError(msg) + if cleanup_seconds <= 0: + msg = f"cleanup_seconds must be positive, got {cleanup_seconds}" + raise ValueError(msg) + self._grace_seconds = grace_seconds + self._cleanup_seconds = cleanup_seconds + self._shutdown_event = asyncio.Event() + + def request_shutdown(self) -> None: + """Signal that a graceful shutdown has been requested.""" + self._shutdown_event.set() + + def is_shutting_down(self) -> bool: + """Return ``True`` when shutdown has been requested.""" + return self._shutdown_event.is_set() + + def get_strategy_type(self) -> str: + """Return the strategy identifier.""" + return "cooperative_timeout" + + async def execute_shutdown( + self, + *, + running_tasks: Mapping[str, asyncio.Task[Any]], + cleanup_callbacks: Sequence[CleanupCallback], + ) -> ShutdownResult: + """Execute the cooperative timeout shutdown sequence.""" + start = time.monotonic() + + self._shutdown_event.set() + logger.info( + EXECUTION_SHUTDOWN_GRACE_START, + grace_seconds=self._grace_seconds, + running_tasks=len(running_tasks), + ) + + tasks_completed, tasks_interrupted = await self._wait_and_cancel( + running_tasks, + ) + + cleanup_completed = await self._run_cleanup(cleanup_callbacks) + + duration = time.monotonic() - start + result = ShutdownResult( + strategy_type=self.get_strategy_type(), + tasks_interrupted=tasks_interrupted, + tasks_completed=tasks_completed, + cleanup_completed=cleanup_completed, + duration_seconds=duration, + ) + logger.info( + EXECUTION_SHUTDOWN_COMPLETE, + strategy=result.strategy_type, + tasks_interrupted=result.tasks_interrupted, + tasks_completed=result.tasks_completed, + cleanup_completed=result.cleanup_completed, + duration_seconds=result.duration_seconds, + ) + return result + + _CANCEL_PROPAGATION_TIMEOUT: float = 5.0 + """Seconds to wait for cancellation to propagate after force-cancel.""" + + async def _wait_and_cancel( + self, + running_tasks: Mapping[str, asyncio.Task[Any]], + ) -> tuple[int, int]: + """Wait for cooperative exit, then force-cancel stragglers. + + Returns: + Tuple of (tasks_completed, tasks_interrupted). + """ + if not running_tasks: + return 0, 0 + + task_set = set(running_tasks.values()) + done, pending = await asyncio.wait( + task_set, + timeout=self._grace_seconds, + ) + + # Retrieve exceptions from done tasks to prevent + # "Task exception was never retrieved" warnings. + # Tasks that raised are not counted as "completed" — only + # cleanly-finished tasks count. + tasks_completed = 0 + for task in done: + if task.cancelled(): + continue + exc = task.exception() + if exc is not None: + logger.warning( + EXECUTION_SHUTDOWN_TASK_ERROR, + error=(f"Task raised during shutdown: {type(exc).__name__}"), + ) + else: + tasks_completed += 1 + + if pending: + logger.warning( + EXECUTION_SHUTDOWN_FORCE_CANCEL, + pending_tasks=len(pending), + ) + for task in pending: + task.cancel() + # Wait for cancellation to propagate (bounded). + # Retrieve exceptions to suppress "never retrieved" warnings. + cancel_done, _ = await asyncio.wait( + pending, + timeout=self._CANCEL_PROPAGATION_TIMEOUT, + ) + for task in cancel_done: + if not task.cancelled(): + with contextlib.suppress(Exception): + task.exception() + + return tasks_completed, len(pending) + + async def _run_cleanup( + self, + callbacks: Sequence[CleanupCallback], + ) -> bool: + """Run cleanup callbacks sequentially within the time budget. + + Returns: + ``True`` if all callbacks completed successfully within the + time budget, ``False`` otherwise. + """ + if not callbacks: + return True + + logger.info( + EXECUTION_SHUTDOWN_CLEANUP, + callback_count=len(callbacks), + cleanup_seconds=self._cleanup_seconds, + ) + + all_succeeded = True + + async def _run_all() -> None: + nonlocal all_succeeded + for i, callback in enumerate(callbacks): + try: + await callback() + except Exception: + all_succeeded = False + logger.exception( + EXECUTION_SHUTDOWN_CLEANUP, + callback_index=i, + callback_count=len(callbacks), + error="Cleanup callback failed", + ) + + try: + await asyncio.wait_for( + _run_all(), + timeout=self._cleanup_seconds, + ) + except TimeoutError: + logger.warning( + EXECUTION_SHUTDOWN_CLEANUP_TIMEOUT, + cleanup_seconds=self._cleanup_seconds, + ) + return False + return all_succeeded + + +class ShutdownManager: + """Manages signal handling, task tracking, and shutdown orchestration. + + Separates OS signal handling from shutdown strategy logic. + + Args: + strategy: Shutdown strategy implementation. Defaults to + ``CooperativeTimeoutStrategy()``. + """ + + def __init__( + self, + strategy: ShutdownStrategy | None = None, + ) -> None: + self._strategy: ShutdownStrategy = strategy or CooperativeTimeoutStrategy() + self._running_tasks: dict[str, asyncio.Task[Any]] = {} + self._cleanup_callbacks: list[CleanupCallback] = [] + self._signals_installed = False + logger.debug( + EXECUTION_SHUTDOWN_MANAGER_CREATED, + strategy=self._strategy.get_strategy_type(), + ) + + @property + def strategy(self) -> ShutdownStrategy: + """The configured shutdown strategy.""" + return self._strategy + + def install_signal_handlers(self) -> None: + """Install SIGINT/SIGTERM handlers. + + On Unix uses ``loop.add_signal_handler``. + On Windows uses ``signal.signal`` with ``call_soon_threadsafe``. + """ + if self._signals_installed: + return + + if sys.platform != "win32": + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, self._handle_signal, sig) + else: + for sig in (signal.SIGINT, signal.SIGTERM): + signal.signal(sig, self._handle_signal_threadsafe) + + self._signals_installed = True + + def _handle_signal(self, sig: signal.Signals) -> None: + """Handle signal on Unix (called in event loop thread).""" + logger.info( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig.name, + ) + try: + self._strategy.request_shutdown() + except Exception: + logger.exception( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig.name, + error="request_shutdown() raised in signal handler", + ) + + def _handle_signal_threadsafe( + self, + signum: int, + _frame: types.FrameType | None, + ) -> None: + """Handle signal on Windows (called outside the event loop context). + + Logging is deferred to the event loop via ``call_soon_threadsafe`` + to avoid deadlocks (structlog acquires locks internally). + """ + try: + sig_name = signal.Signals(signum).name + except ValueError: + sig_name = f"UNKNOWN({signum})" + + def _on_loop() -> None: + try: + logger.info( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig_name, + ) + self._strategy.request_shutdown() + except Exception: + logger.exception( + EXECUTION_SHUTDOWN_SIGNAL, + signal=sig_name, + error="request_shutdown() raised in signal handler", + ) + + try: + loop = asyncio.get_running_loop() + loop.call_soon_threadsafe(_on_loop) + except RuntimeError: + # No running event loop — call directly (best-effort). + # Cannot log safely without a loop, so suppress all errors. + with contextlib.suppress(Exception): + self._strategy.request_shutdown() + + def register_task( + self, + task_id: str, + asyncio_task: asyncio.Task[Any], + ) -> None: + """Track a running agent task. + + Raises: + RuntimeError: If shutdown has already been requested (drain + gate is closed). + """ + if self._strategy.is_shutting_down(): + msg = f"Cannot register task {task_id!r}: shutdown already in progress" + raise RuntimeError(msg) + if task_id in self._running_tasks: + logger.warning( + EXECUTION_SHUTDOWN_TASK_TRACKED, + action="task_overwritten", + task_id=task_id, + ) + self._running_tasks[task_id] = asyncio_task + logger.debug( + EXECUTION_SHUTDOWN_TASK_TRACKED, + action="task_registered", + task_id=task_id, + running_tasks=len(self._running_tasks), + ) + + def unregister_task(self, task_id: str) -> None: + """Stop tracking a completed agent task.""" + self._running_tasks.pop(task_id, None) + logger.debug( + EXECUTION_SHUTDOWN_TASK_TRACKED, + action="task_unregistered", + task_id=task_id, + running_tasks=len(self._running_tasks), + ) + + def register_cleanup(self, callback: CleanupCallback) -> None: + """Register an async cleanup callback for shutdown. + + Callbacks run sequentially in registration order during + shutdown. Each callback is individually guarded against + exceptions — a failing callback does not prevent subsequent + ones from running. + """ + self._cleanup_callbacks.append(callback) + + def is_shutting_down(self) -> bool: + """Delegate to the strategy's shutdown check.""" + return self._strategy.is_shutting_down() + + async def initiate_shutdown(self) -> ShutdownResult: + """Invoke the strategy's shutdown sequence.""" + return await self._strategy.execute_shutdown( + running_tasks=dict(self._running_tasks), + cleanup_callbacks=list(self._cleanup_callbacks), + ) diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index 4542ffa3a6..d4f35be7e5 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -35,6 +35,17 @@ EXECUTION_ENGINE_TASK_METRICS: Final[str] = "execution.engine.task_metrics" EXECUTION_ENGINE_TIMEOUT: Final[str] = "execution.engine.timeout" +EXECUTION_SHUTDOWN_SIGNAL: Final[str] = "execution.shutdown.signal" +EXECUTION_SHUTDOWN_MANAGER_CREATED: Final[str] = "execution.shutdown.manager_created" +EXECUTION_SHUTDOWN_TASK_TRACKED: Final[str] = "execution.shutdown.task_tracked" +EXECUTION_SHUTDOWN_TASK_ERROR: Final[str] = "execution.shutdown.task_error" +EXECUTION_SHUTDOWN_GRACE_START: Final[str] = "execution.shutdown.grace_start" +EXECUTION_SHUTDOWN_FORCE_CANCEL: Final[str] = "execution.shutdown.force_cancel" +EXECUTION_SHUTDOWN_CLEANUP: Final[str] = "execution.shutdown.cleanup" +EXECUTION_SHUTDOWN_CLEANUP_TIMEOUT: Final[str] = "execution.shutdown.cleanup.timeout" +EXECUTION_SHUTDOWN_COMPLETE: Final[str] = "execution.shutdown.complete" +EXECUTION_LOOP_SHUTDOWN: Final[str] = "execution.loop.shutdown" + EXECUTION_RECOVERY_START: Final[str] = "execution.recovery.start" EXECUTION_RECOVERY_COMPLETE: Final[str] = "execution.recovery.complete" EXECUTION_RECOVERY_FAILED: Final[str] = "execution.recovery.failed" diff --git a/tests/integration/engine/test_graceful_shutdown.py b/tests/integration/engine/test_graceful_shutdown.py new file mode 100644 index 0000000000..7722885900 --- /dev/null +++ b/tests/integration/engine/test_graceful_shutdown.py @@ -0,0 +1,302 @@ +"""Integration test — full graceful shutdown flow. + +Creates an engine with a shutdown manager, starts an agent, triggers +shutdown, and verifies: agent stops, task is INTERRUPTED, cleanup runs. +""" + +from typing import Any + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.agent_engine import AgentEngine +from ai_company.engine.shutdown import ( + CooperativeTimeoutStrategy, + ShutdownManager, +) +from ai_company.providers.enums import FinishReason +from ai_company.providers.models import ( + ChatMessage, + CompletionConfig, + CompletionResponse, + TokenUsage, + ToolDefinition, +) + +pytestmark = pytest.mark.timeout(30) + + +class _ShutdownTriggeringProvider: + """Provider that triggers shutdown on the first call. + + The first call triggers shutdown and returns a STOP response, + so the loop completes before checking the shutdown flag again. + """ + + def __init__(self, strategy: CooperativeTimeoutStrategy) -> None: + self._strategy = strategy + self._call_count = 0 + + async def complete( + self, + messages: list[ChatMessage], + model: str, + *, + tools: list[ToolDefinition] | None = None, + config: CompletionConfig | None = None, + ) -> CompletionResponse: + self._call_count += 1 + if self._call_count == 1: + # Trigger shutdown after first LLM call + self._strategy.request_shutdown() + return CompletionResponse( + content="Working on it.", + finish_reason=FinishReason.STOP, + usage=TokenUsage( + input_tokens=50, + output_tokens=25, + cost_usd=0.005, + ), + model="test-model-001", + ) + + async def stream( + self, + messages: list[ChatMessage], + model: str, + *, + tools: list[ToolDefinition] | None = None, + config: CompletionConfig | None = None, + ) -> Any: + msg = "Not implemented" + raise NotImplementedError(msg) + + async def get_model_capabilities(self, model: str) -> Any: + from ai_company.providers.capabilities import ModelCapabilities + + return ModelCapabilities( + model_id=model, + provider="test-provider", + supports_tools=False, + supports_streaming=False, + max_context_tokens=8192, + max_output_tokens=4096, + cost_per_1k_input=0.01, + cost_per_1k_output=0.03, + ) + + +@pytest.mark.integration +class TestGracefulShutdownFlow: + """Full shutdown integration: engine + strategy + agent → INTERRUPTED.""" + + async def test_shutdown_signal_propagates_through_manager( + self, + ) -> None: + """Shutdown signal during execution propagates through manager. + + The provider triggers shutdown on the first call but returns + STOP (no tool calls), so the loop completes *before* the next + shutdown check. This verifies the signal propagation path — + test_shutdown_during_multi_turn_interrupts below covers the + INTERRUPTED transition. + """ + from datetime import date + from uuid import uuid4 + + from ai_company.core.agent import AgentIdentity, ModelConfig + from ai_company.core.enums import ( + Complexity, + Priority, + SeniorityLevel, + TaskType, + ) + from ai_company.core.task import Task + + identity = AgentIdentity( + id=uuid4(), + name="Test Agent", + role="Developer", + department="Engineering", + level=SeniorityLevel.MID, + model=ModelConfig( + provider="test-provider", + model_id="test-model-001", + ), + hiring_date=date(2026, 1, 1), + ) + + task = Task( + id="task-shutdown-001", + title="Task for shutdown test", + description="This task will complete before shutdown check.", + type=TaskType.DEVELOPMENT, + priority=Priority.MEDIUM, + project="proj-001", + created_by="test", + estimated_complexity=Complexity.SIMPLE, + budget_limit=10.0, + assigned_to=str(identity.id), + status=TaskStatus.ASSIGNED, + ) + + strategy = CooperativeTimeoutStrategy(grace_seconds=5.0) + manager = ShutdownManager(strategy=strategy) + + provider = _ShutdownTriggeringProvider(strategy) + + engine = AgentEngine( + provider=provider, + shutdown_checker=manager.is_shutting_down, + ) + + result = await engine.run( + identity=identity, + task=task, + ) + + # Loop completed normally (STOP with no tool calls) + assert result.is_success is True + # But the shutdown signal was set + assert manager.is_shutting_down() is True + + async def test_shutdown_during_multi_turn_interrupts( + self, + ) -> None: + """Multi-turn execution interrupted by shutdown → INTERRUPTED.""" + from datetime import date + from uuid import uuid4 + + from ai_company.core.agent import AgentIdentity, ModelConfig + from ai_company.core.enums import ( + Complexity, + Priority, + SeniorityLevel, + TaskType, + ) + from ai_company.core.task import Task + + identity = AgentIdentity( + id=uuid4(), + name="Test Agent", + role="Developer", + department="Engineering", + level=SeniorityLevel.MID, + model=ModelConfig( + provider="test-provider", + model_id="test-model-001", + ), + hiring_date=date(2026, 1, 1), + ) + + task = Task( + id="task-shutdown-002", + title="Multi-turn shutdown test", + description="This task will be interrupted mid-execution.", + type=TaskType.DEVELOPMENT, + priority=Priority.MEDIUM, + project="proj-001", + created_by="test", + estimated_complexity=Complexity.SIMPLE, + budget_limit=10.0, + assigned_to=str(identity.id), + status=TaskStatus.ASSIGNED, + ) + + check_count = 0 + + def shutdown_checker() -> bool: + nonlocal check_count + check_count += 1 + # Let first two checks pass (top of loop + before tools + # on turn 1), then shutdown on third check (top of loop + # on turn 2) + return check_count > 2 + + # Provider returns tool-use on first call so the loop iterates + from ai_company.providers.models import ToolCall + + responses = [ + CompletionResponse( + content=None, + tool_calls=(ToolCall(id="tc-1", name="echo", arguments={}),), + finish_reason=FinishReason.TOOL_USE, + usage=TokenUsage( + input_tokens=50, + output_tokens=25, + cost_usd=0.005, + ), + model="test-model-001", + ), + ] + + class _MultiTurnProvider: + def __init__(self) -> None: + self._idx = 0 + + async def complete( + self, + messages: Any, + model: Any, + **kw: Any, + ) -> CompletionResponse: + if self._idx < len(responses): + resp = responses[self._idx] + self._idx += 1 + return resp + msg = "No more responses" + raise IndexError(msg) + + async def stream(self, *a: Any, **kw: Any) -> Any: + raise NotImplementedError + + async def get_model_capabilities(self, model: str) -> Any: + from ai_company.providers.capabilities import ModelCapabilities + + return ModelCapabilities( + model_id=model, + provider="test-provider", + supports_tools=True, + supports_streaming=False, + max_context_tokens=8192, + max_output_tokens=4096, + cost_per_1k_input=0.01, + cost_per_1k_output=0.03, + ) + + from ai_company.core.enums import ToolCategory + from ai_company.tools.base import BaseTool, ToolExecutionResult + from ai_company.tools.registry import ToolRegistry + + class _EchoTool(BaseTool): + def __init__(self) -> None: + super().__init__( + name="echo", + description="Echo tool", + category=ToolCategory.CODE_EXECUTION, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content="echoed", is_error=False) + + registry = ToolRegistry([_EchoTool()]) + + engine = AgentEngine( + provider=_MultiTurnProvider(), + tool_registry=registry, + shutdown_checker=shutdown_checker, + ) + + result = await engine.run( + identity=identity, + task=task, + ) + + te = result.execution_result.context.task_execution + assert te is not None + assert te.status == TaskStatus.INTERRUPTED + assert result.execution_result.termination_reason.value == "shutdown" diff --git a/tests/unit/core/test_enums.py b/tests/unit/core/test_enums.py index b28363486d..b58840e79b 100644 --- a/tests/unit/core/test_enums.py +++ b/tests/unit/core/test_enums.py @@ -58,8 +58,8 @@ def test_proficiency_level_has_4_members(self) -> None: def test_department_name_has_9_members(self) -> None: assert len(DepartmentName) == 9 - def test_task_status_has_8_members(self) -> None: - assert len(TaskStatus) == 8 + def test_task_status_has_9_members(self) -> None: + assert len(TaskStatus) == 9 def test_task_type_has_6_members(self) -> None: assert len(TaskType) == 6 @@ -111,6 +111,7 @@ def test_task_status_values(self) -> None: assert TaskStatus.BLOCKED.value == "blocked" assert TaskStatus.FAILED.value == "failed" assert TaskStatus.CANCELLED.value == "cancelled" + assert TaskStatus.INTERRUPTED.value == "interrupted" def test_task_type_values(self) -> None: assert TaskType.DEVELOPMENT.value == "development" diff --git a/tests/unit/core/test_task_transitions.py b/tests/unit/core/test_task_transitions.py index 59e3864bb6..053baa830a 100644 --- a/tests/unit/core/test_task_transitions.py +++ b/tests/unit/core/test_task_transitions.py @@ -34,8 +34,11 @@ class TestValidTransitions: (TaskStatus.IN_REVIEW, TaskStatus.IN_PROGRESS), (TaskStatus.IN_REVIEW, TaskStatus.BLOCKED), (TaskStatus.IN_REVIEW, TaskStatus.CANCELLED), + (TaskStatus.ASSIGNED, TaskStatus.INTERRUPTED), + (TaskStatus.IN_PROGRESS, TaskStatus.INTERRUPTED), (TaskStatus.BLOCKED, TaskStatus.ASSIGNED), (TaskStatus.FAILED, TaskStatus.ASSIGNED), + (TaskStatus.INTERRUPTED, TaskStatus.ASSIGNED), ], ids=lambda p: p.value if isinstance(p, TaskStatus) else str(p), ) @@ -61,6 +64,8 @@ class TestInvalidTransitions: (TaskStatus.IN_PROGRESS, TaskStatus.ASSIGNED), (TaskStatus.FAILED, TaskStatus.COMPLETED), (TaskStatus.FAILED, TaskStatus.IN_PROGRESS), + (TaskStatus.INTERRUPTED, TaskStatus.COMPLETED), + (TaskStatus.INTERRUPTED, TaskStatus.IN_PROGRESS), ], ids=lambda p: p.value if isinstance(p, TaskStatus) else str(p), ) @@ -112,6 +117,10 @@ def test_failed_is_non_terminal(self) -> None: """FAILED has outgoing transitions (reassignment).""" assert len(VALID_TRANSITIONS[TaskStatus.FAILED]) > 0 + def test_interrupted_is_non_terminal(self) -> None: + """INTERRUPTED has outgoing transitions (reassignment on restart).""" + assert len(VALID_TRANSITIONS[TaskStatus.INTERRUPTED]) > 0 + def test_all_targets_are_valid_statuses(self) -> None: """Every target in the transition map must be a valid TaskStatus.""" for source, targets in VALID_TRANSITIONS.items(): diff --git a/tests/unit/engine/test_agent_engine_lifecycle.py b/tests/unit/engine/test_agent_engine_lifecycle.py index 5a5570e7e8..8be682efc2 100644 --- a/tests/unit/engine/test_agent_engine_lifecycle.py +++ b/tests/unit/engine/test_agent_engine_lifecycle.py @@ -197,6 +197,77 @@ async def test_error_transitions_to_failed( assert te is not None assert te.status == TaskStatus.FAILED + async def test_shutdown_transitions_to_interrupted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """SHUTDOWN → task transitions to INTERRUPTED.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + ) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.SHUTDOWN, + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + provider = mock_provider_factory([]) + engine = AgentEngine(provider=provider, execution_loop=mock_loop) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + te = result.execution_result.context.task_execution + assert te is not None + assert te.status == TaskStatus.INTERRUPTED + + async def test_shutdown_from_assigned_transitions_to_interrupted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """SHUTDOWN before loop starts → ASSIGNED → INTERRUPTED.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + # Simulate the loop returning SHUTDOWN while still ASSIGNED + # (edge case: shutdown signal between assignment and IP transition) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.SHUTDOWN, + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + provider = mock_provider_factory([]) + engine = AgentEngine(provider=provider, execution_loop=mock_loop) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + te = result.execution_result.context.task_execution + assert te is not None + # The engine's _prepare_context transitions ASSIGNED→IP, + # but the mock loop returns a context still at ASSIGNED. + # The engine's _apply_post_execution_transitions handles this. + assert te.status == TaskStatus.INTERRUPTED + async def test_no_task_execution_passes_through( self, sample_agent_with_personality: AgentIdentity, @@ -452,3 +523,45 @@ async def test_transition_failure_preserves_result( te = result.execution_result.context.task_execution assert te is not None assert te.status == TaskStatus.CANCELLED + + async def test_interrupted_transition_failure_preserves_result( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """SHUTDOWN with invalid task status → transition fails, result kept.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + ) + te = ctx.task_execution + assert te is not None + # Force into COMPLETED (terminal) — INTERRUPTED transition should fail + bad_te = te.model_copy(update={"status": TaskStatus.COMPLETED}) + ctx_bad = ctx.model_copy(update={"task_execution": bad_te}) + + mock_result = ExecutionResult( + context=ctx_bad, + termination_reason=TerminationReason.SHUTDOWN, + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + provider = mock_provider_factory([]) + engine = AgentEngine(provider=provider, execution_loop=mock_loop) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + te = result.execution_result.context.task_execution + assert te is not None + # Transition failed, so status stays as COMPLETED + assert te.status == TaskStatus.COMPLETED diff --git a/tests/unit/engine/test_loop_protocol.py b/tests/unit/engine/test_loop_protocol.py index 3b41fdf0b2..497b5f5a8a 100644 --- a/tests/unit/engine/test_loop_protocol.py +++ b/tests/unit/engine/test_loop_protocol.py @@ -22,10 +22,11 @@ def test_values(self) -> None: assert TerminationReason.COMPLETED.value == "completed" assert TerminationReason.MAX_TURNS.value == "max_turns" assert TerminationReason.BUDGET_EXHAUSTED.value == "budget_exhausted" + assert TerminationReason.SHUTDOWN.value == "shutdown" assert TerminationReason.ERROR.value == "error" def test_member_count(self) -> None: - assert len(TerminationReason) == 4 + assert len(TerminationReason) == 5 @pytest.mark.unit diff --git a/tests/unit/engine/test_react_loop.py b/tests/unit/engine/test_react_loop.py index fa3489d270..026bf90d94 100644 --- a/tests/unit/engine/test_react_loop.py +++ b/tests/unit/engine/test_react_loop.py @@ -858,6 +858,158 @@ async def test_empty_registry_passes_no_tools( assert result.termination_reason == TerminationReason.COMPLETED +@pytest.mark.unit +class TestReactLoopShutdown: + """Shutdown checker triggers loop termination.""" + + async def test_shutdown_before_first_turn( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([]) + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=lambda: True, # shutdown immediately + ) + + assert result.termination_reason == TerminationReason.SHUTDOWN + assert len(result.turns) == 0 + assert provider.call_count == 0 + + async def test_shutdown_after_first_turn( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + call_count = 0 + + def shutdown_check() -> bool: + nonlocal call_count + call_count += 1 + # Shutdown on third check (after first turn: check at top, + # check before tools, check at top of next iteration) + return call_count > 2 + + provider = mock_provider_factory( + [ + _tool_use_response("echo", "tc-1"), + ] + ) + invoker = _make_invoker("echo") + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + tool_invoker=invoker, + shutdown_checker=shutdown_check, + ) + + assert result.termination_reason == TerminationReason.SHUTDOWN + assert len(result.turns) == 1 + + async def test_shutdown_before_tool_execution( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Shutdown detected between LLM response and tool invocation.""" + ctx = _ctx_with_user_msg(sample_agent_context) + call_count = 0 + + def shutdown_check() -> bool: + nonlocal call_count + call_count += 1 + # First check (top of loop) passes, second check (before + # tools) triggers shutdown + return call_count > 1 + + provider = mock_provider_factory( + [ + _tool_use_response("echo", "tc-1"), + ] + ) + invoker = _make_invoker("echo") + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + tool_invoker=invoker, + shutdown_checker=shutdown_check, + ) + + assert result.termination_reason == TerminationReason.SHUTDOWN + # Turn was recorded (LLM was called), but tools were not executed + assert len(result.turns) == 1 + + async def test_no_shutdown_checker_runs_normally( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([_stop_response("Done.")]) + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=None, + ) + + assert result.termination_reason == TerminationReason.COMPLETED + + async def test_shutdown_checker_exception_returns_error( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Shutdown checker that raises → ERROR termination.""" + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([]) + loop = ReactLoop() + + def bad_checker() -> bool: + msg = "checker broke" + raise ValueError(msg) + + result = await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=bad_checker, + ) + + assert result.termination_reason == TerminationReason.ERROR + assert "Shutdown checker failed" in (result.error_message or "") + + async def test_shutdown_checker_memory_error_propagates( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """MemoryError from shutdown checker propagates unconditionally.""" + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([]) + loop = ReactLoop() + + def oom_checker() -> bool: + raise MemoryError + + with pytest.raises(MemoryError): + await loop.execute( + context=ctx, + provider=provider, + shutdown_checker=oom_checker, + ) + + @pytest.mark.unit class TestReactLoopCostAccounting: """Error responses include the failing turn's cost in context.""" diff --git a/tests/unit/engine/test_run_result.py b/tests/unit/engine/test_run_result.py index 292305cdf2..200bf609ff 100644 --- a/tests/unit/engine/test_run_result.py +++ b/tests/unit/engine/test_run_result.py @@ -9,12 +9,12 @@ from ai_company.core.agent import AgentIdentity, ModelConfig from ai_company.core.enums import Priority, SeniorityLevel, TaskStatus, TaskType from ai_company.core.task import Task -from ai_company.engine.agent_engine import _make_budget_checker from ai_company.engine.context import AgentContext from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, TurnRecord, + make_budget_checker, ) from ai_company.engine.prompt import SystemPrompt, format_task_instruction from ai_company.engine.run_result import AgentRunResult @@ -158,6 +158,12 @@ def test_is_success_false_on_budget(self) -> None: ) assert result.is_success is False + def test_is_success_false_on_shutdown(self) -> None: + result = _make_run_result( + termination_reason=TerminationReason.SHUTDOWN, + ) + assert result.is_success is False + @pytest.mark.unit class TestAgentRunResultValidation: @@ -290,7 +296,7 @@ def test_budget_only_no_deadline(self) -> None: @pytest.mark.unit class TestMakeBudgetChecker: - """Test _make_budget_checker closure logic.""" + """Test make_budget_checker closure logic.""" def test_returns_none_for_zero_budget(self) -> None: task = Task( @@ -304,7 +310,7 @@ def test_returns_none_for_zero_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=0.0, ) - assert _make_budget_checker(task) is None + assert make_budget_checker(task) is None def test_returns_none_for_default_budget(self) -> None: """Default budget_limit (0.0) returns None.""" @@ -318,7 +324,7 @@ def test_returns_none_for_default_budget(self) -> None: assigned_to="someone", status=TaskStatus.ASSIGNED, ) - assert _make_budget_checker(task) is None + assert make_budget_checker(task) is None def test_returns_callable_for_positive_budget(self) -> None: task = Task( @@ -332,7 +338,7 @@ def test_returns_callable_for_positive_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None assert callable(checker) @@ -348,7 +354,7 @@ def test_checker_returns_false_under_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None identity = _test_identity() @@ -377,7 +383,7 @@ def test_checker_returns_true_at_exact_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None identity = _test_identity() @@ -405,7 +411,7 @@ def test_checker_returns_true_over_budget(self) -> None: status=TaskStatus.ASSIGNED, budget_limit=5.0, ) - checker = _make_budget_checker(task) + checker = make_budget_checker(task) assert checker is not None identity = _test_identity() diff --git a/tests/unit/engine/test_shutdown.py b/tests/unit/engine/test_shutdown.py new file mode 100644 index 0000000000..cd8a4112fb --- /dev/null +++ b/tests/unit/engine/test_shutdown.py @@ -0,0 +1,414 @@ +"""Tests for the graceful shutdown strategy and manager.""" + +import asyncio +import signal +import sys +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from ai_company.config.schema import GracefulShutdownConfig +from ai_company.engine.shutdown import ( + CooperativeTimeoutStrategy, + ShutdownManager, + ShutdownResult, + ShutdownStrategy, +) + +pytestmark = pytest.mark.timeout(30) + + +# ── Protocol compliance ────────────────────────────────────────── + + +@pytest.mark.unit +class TestShutdownStrategyProtocol: + """CooperativeTimeoutStrategy satisfies ShutdownStrategy protocol.""" + + def test_is_runtime_checkable(self) -> None: + strategy = CooperativeTimeoutStrategy() + assert isinstance(strategy, ShutdownStrategy) + + def test_result_model_is_frozen(self) -> None: + result = ShutdownResult( + strategy_type="cooperative_timeout", + tasks_interrupted=0, + tasks_completed=0, + cleanup_completed=True, + duration_seconds=0.1, + ) + with pytest.raises(ValidationError, match="frozen"): + result.tasks_interrupted = 5 # type: ignore[misc] + + +# ── Request / check ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutRequestShutdown: + """request_shutdown + is_shutting_down event toggling.""" + + def test_not_shutting_down_initially(self) -> None: + strategy = CooperativeTimeoutStrategy() + assert strategy.is_shutting_down() is False + + def test_request_sets_shutting_down(self) -> None: + strategy = CooperativeTimeoutStrategy() + strategy.request_shutdown() + assert strategy.is_shutting_down() is True + + def test_idempotent_request(self) -> None: + strategy = CooperativeTimeoutStrategy() + strategy.request_shutdown() + strategy.request_shutdown() + assert strategy.is_shutting_down() is True + + def test_get_strategy_type(self) -> None: + strategy = CooperativeTimeoutStrategy() + assert strategy.get_strategy_type() == "cooperative_timeout" + + +# ── execute_shutdown — cooperative exit ────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutExecuteCooperative: + """Tasks that check the shutdown event and exit cooperatively.""" + + async def test_all_tasks_exit_cooperatively(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=5.0) + shutdown_event = strategy._shutdown_event + + async def cooperative_task() -> None: + await shutdown_event.wait() + + task = asyncio.create_task(cooperative_task()) + result = await strategy.execute_shutdown( + running_tasks={"t1": task}, + cleanup_callbacks=[], + ) + + assert result.tasks_completed == 1 + assert result.tasks_interrupted == 0 + assert result.cleanup_completed is True + assert result.strategy_type == "cooperative_timeout" + assert result.duration_seconds > 0 + + async def test_empty_tasks(self) -> None: + strategy = CooperativeTimeoutStrategy() + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[], + ) + assert result.tasks_completed == 0 + assert result.tasks_interrupted == 0 + + +# ── execute_shutdown — force cancel ────────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutForceCancel: + """Tasks that ignore the shutdown event are force-cancelled.""" + + async def test_stubborn_task_is_force_cancelled(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=0.1) + + async def stubborn_task() -> None: + await asyncio.sleep(100) # ignores shutdown + + task = asyncio.create_task(stubborn_task()) + result = await strategy.execute_shutdown( + running_tasks={"t1": task}, + cleanup_callbacks=[], + ) + + assert result.tasks_completed == 0 + assert result.tasks_interrupted == 1 + + async def test_mixed_cooperative_and_stubborn(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=0.1) + shutdown_event = strategy._shutdown_event + + async def cooperative() -> None: + await shutdown_event.wait() + + async def stubborn() -> None: + await asyncio.sleep(100) + + t1 = asyncio.create_task(cooperative()) + t2 = asyncio.create_task(stubborn()) + result = await strategy.execute_shutdown( + running_tasks={"t1": t1, "t2": t2}, + cleanup_callbacks=[], + ) + + assert result.tasks_completed + result.tasks_interrupted == 2 + assert result.tasks_interrupted >= 1 + + +# ── execute_shutdown — cleanup callbacks ───────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutCleanup: + """Cleanup callbacks run within the time budget.""" + + async def test_cleanup_callbacks_run(self) -> None: + strategy = CooperativeTimeoutStrategy(cleanup_seconds=5.0) + ran = [] + + async def cb1() -> None: + ran.append("cb1") + + async def cb2() -> None: + ran.append("cb2") + + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[cb1, cb2], + ) + + assert ran == ["cb1", "cb2"] + assert result.cleanup_completed is True + + async def test_cleanup_timeout(self) -> None: + strategy = CooperativeTimeoutStrategy(cleanup_seconds=0.1) + + async def slow_callback() -> None: + await asyncio.sleep(100) + + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[slow_callback], + ) + + assert result.cleanup_completed is False + + async def test_empty_cleanup(self) -> None: + strategy = CooperativeTimeoutStrategy() + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[], + ) + assert result.cleanup_completed is True + + +# ── ShutdownManager ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestShutdownManagerTaskTracking: + """Register / unregister tasks.""" + + def test_register_and_unregister(self) -> None: + manager = ShutdownManager() + mock_task = MagicMock(spec=asyncio.Task) + manager.register_task("task-1", mock_task) + assert "task-1" in manager._running_tasks + manager.unregister_task("task-1") + assert "task-1" not in manager._running_tasks + + def test_unregister_missing_is_noop(self) -> None: + manager = ShutdownManager() + manager.unregister_task("nonexistent") + + def test_register_cleanup(self) -> None: + manager = ShutdownManager() + + async def cb() -> None: + pass + + manager.register_cleanup(cb) + assert len(manager._cleanup_callbacks) == 1 + + def test_is_shutting_down_delegates(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + assert manager.is_shutting_down() is False + strategy.request_shutdown() + assert manager.is_shutting_down() is True + + def test_register_task_during_shutdown_raises(self) -> None: + """Drain gate: registering a task after shutdown raises RuntimeError.""" + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + strategy.request_shutdown() + mock_task = MagicMock(spec=asyncio.Task) + with pytest.raises(RuntimeError, match="shutdown already in progress"): + manager.register_task("late-task", mock_task) + + +@pytest.mark.unit +class TestShutdownManagerSignalHandlers: + """Signal handler installation.""" + + @pytest.mark.skipif(sys.platform == "win32", reason="Unix-only test") + def test_install_signal_handlers_unix(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + mock_loop = MagicMock() + with patch("asyncio.get_running_loop", return_value=mock_loop): + manager.install_signal_handlers() + assert mock_loop.add_signal_handler.call_count == 2 + assert manager._signals_installed is True + + @pytest.mark.skipif(sys.platform != "win32", reason="Windows-only test") + def test_install_signal_handlers_windows(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + with patch("signal.signal") as mock_signal: + manager.install_signal_handlers() + assert mock_signal.call_count == 2 + assert manager._signals_installed is True + + def test_install_idempotent(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + if sys.platform == "win32": + with patch("signal.signal"): + manager.install_signal_handlers() + manager.install_signal_handlers() # second call is noop + else: + mock_loop = MagicMock() + with patch("asyncio.get_running_loop", return_value=mock_loop): + manager.install_signal_handlers() + manager.install_signal_handlers() + assert mock_loop.add_signal_handler.call_count == 2 # not 4 + + +@pytest.mark.unit +class TestShutdownManagerInitiateShutdown: + """Full initiate_shutdown delegates to strategy.""" + + async def test_initiate_shutdown(self) -> None: + strategy = CooperativeTimeoutStrategy(grace_seconds=0.1) + manager = ShutdownManager(strategy=strategy) + result = await manager.initiate_shutdown() + assert isinstance(result, ShutdownResult) + assert result.strategy_type == "cooperative_timeout" + + def test_default_strategy(self) -> None: + manager = ShutdownManager() + assert isinstance(manager.strategy, CooperativeTimeoutStrategy) + + +@pytest.mark.unit +class TestShutdownManagerSignalHandling: + """Signal handler triggers request_shutdown.""" + + def test_handle_signal_unix(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + manager._handle_signal(signal.SIGINT) + assert strategy.is_shutting_down() is True + + def test_handle_signal_threadsafe_with_loop(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + mock_loop = MagicMock() + with patch("asyncio.get_running_loop", return_value=mock_loop): + manager._handle_signal_threadsafe(signal.SIGINT.value, None) + mock_loop.call_soon_threadsafe.assert_called_once() + # Execute the callback to verify it actually calls request_shutdown. + callback = mock_loop.call_soon_threadsafe.call_args[0][0] + assert callable(callback) + callback() + assert strategy.is_shutting_down() is True + + def test_handle_signal_threadsafe_no_loop(self) -> None: + strategy = CooperativeTimeoutStrategy() + manager = ShutdownManager(strategy=strategy) + with patch( + "asyncio.get_running_loop", + side_effect=RuntimeError("no loop"), + ): + manager._handle_signal_threadsafe(signal.SIGINT.value, None) + assert strategy.is_shutting_down() is True + + +# ── Constructor validation ──────────────────────────────────────── + + +@pytest.mark.unit +class TestCooperativeTimeoutValidation: + """Constructor rejects non-positive timeout values.""" + + def test_zero_grace_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="grace_seconds must be positive"): + CooperativeTimeoutStrategy(grace_seconds=0) + + def test_negative_grace_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="grace_seconds must be positive"): + CooperativeTimeoutStrategy(grace_seconds=-1.0) + + def test_zero_cleanup_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="cleanup_seconds must be positive"): + CooperativeTimeoutStrategy(cleanup_seconds=0) + + def test_negative_cleanup_seconds_rejected(self) -> None: + with pytest.raises(ValueError, match="cleanup_seconds must be positive"): + CooperativeTimeoutStrategy(cleanup_seconds=-5.0) + + +# ── Cleanup callback exception isolation ────────────────────────── + + +@pytest.mark.unit +class TestCleanupCallbackExceptionIsolation: + """A failing callback doesn't prevent subsequent callbacks from running.""" + + async def test_failing_callback_does_not_block_others(self) -> None: + strategy = CooperativeTimeoutStrategy(cleanup_seconds=5.0) + ran = [] + + async def cb_ok_1() -> None: + ran.append("cb1") + + async def cb_fail() -> None: + msg = "boom" + raise RuntimeError(msg) + + async def cb_ok_2() -> None: + ran.append("cb2") + + result = await strategy.execute_shutdown( + running_tasks={}, + cleanup_callbacks=[cb_ok_1, cb_fail, cb_ok_2], + ) + + assert "cb1" in ran + assert "cb2" in ran + # cleanup_completed is False because one callback failed + assert result.cleanup_completed is False + + +# ── GracefulShutdownConfig validation ───────────────────────────── + + +@pytest.mark.unit +class TestGracefulShutdownConfig: + """Config model boundary validation.""" + + def test_defaults(self) -> None: + config = GracefulShutdownConfig() + assert config.strategy == "cooperative_timeout" + assert config.grace_seconds == 30.0 + assert config.cleanup_seconds == 5.0 + + def test_grace_seconds_upper_bound(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(grace_seconds=301) + + def test_cleanup_seconds_upper_bound(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(cleanup_seconds=61) + + def test_zero_grace_seconds_rejected(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(grace_seconds=0) + + def test_blank_strategy_rejected(self) -> None: + with pytest.raises(ValidationError): + GracefulShutdownConfig(strategy=" ")