diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 7bdfa34e85..18354db8ed 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -693,29 +693,35 @@ structured_phases: │ CREATED │ └─────┬─────┘ │ assignment - ┌─────▼─────┐ - ┌──────│ ASSIGNED │ - │ └─────┬─────┘ - │ │ agent starts + ┌─────▼─────┐ ┌──────────┐ + ┌──────│ ASSIGNED │──────────▶│ FAILED │ + │ └─────┬─────┘◀───┐ └────┬─────┘ + │ │ starts │ reassign │ + │ ┌─────▼─────┐ │ ┌────▼─────┐ + │ │IN_PROGRESS │───┼─────▶│ (retry) │ + │ └─────┬─────┘ │ └──────────┘ + │ │ ◀── (rework) + │ │ agent done │ ┌─────▼─────┐ - │ │IN_PROGRESS │◀──── (rework) - │ └─────┬─────┘ │ - │ │ agent done │ - │ ┌─────▼─────┐ │ - │ │ IN_REVIEW │───────┘ + │ │ IN_REVIEW │ │ └─────┬─────┘ │ │ approved │ ┌─────▼─────┐ │ │ COMPLETED │ │ └────────────┘ │ - │ blocked / cancelled - ┌─────▼─────┐ - │ BLOCKED / │ - │ CANCELLED │ - └────────────┘ + │ blocked cancelled + ┌─────▼─────┐ ┌────────────┐ + │ BLOCKED │ │ CANCELLED │ + └─────┬─────┘ └────────────┘ + │ unblocked (terminal) + └──▶ ASSIGNED ``` +> **Non-terminal states:** BLOCKED and FAILED are non-terminal — BLOCKED returns to ASSIGNED when unblocked, FAILED returns to ASSIGNED for retry (see §6.6). COMPLETED and CANCELLED are terminal states with no outgoing transitions. +> +> **Transitions into FAILED:** Both `ASSIGNED → FAILED` (early setup failures) and `IN_PROGRESS → FAILED` (runtime crashes) are valid. `FAILED → ASSIGNED` enables reassignment when `retry_count < max_retries`. + > **Runtime wrapper (M3):** During execution, `Task` is wrapped by `TaskExecution` (in `engine/task_execution.py`). `TaskExecution` is a frozen Pydantic model that tracks status transitions via `model_copy(update=...)`, accumulates `TokenUsage` cost, and records a `StatusTransition` audit trail. The original `Task` is preserved unchanged; `to_task_snapshot()` produces a `Task` copy with the current execution status for persistence. ### 6.2 Task Definition @@ -748,6 +754,7 @@ task: task_structure: "parallel" # sequential, parallel, mixed (M4 — see §6.9) budget_limit: 2.00 # max USD for this task deadline: null + max_retries: 1 # max reassignment attempts after failure (0 = no retry) status: "assigned" ``` @@ -952,11 +959,28 @@ When an agent execution fails unexpectedly (unhandled exception, OOM, process ki > **MVP: Fail-and-Reassign only (Strategy 1).** Checkpoint Recovery is M4/M5. +**`RecoveryStrategy` protocol:** + +| Method | Signature | Description | +|--------|-----------|-------------| +| `recover` | `async def recover(*, task_execution: TaskExecution, error_message: str, context: AgentContext) -> RecoveryResult` | Apply recovery to a failed task execution | +| `get_strategy_type` | `def get_strategy_type() -> str` | Return strategy type identifier (must not be empty) | + +**`RecoveryResult` model (frozen):** + +| Field | Type | Description | +|-------|------|-------------| +| `task_execution` | `TaskExecution` | Updated execution after recovery (typically `FAILED`) | +| `strategy_type` | `NotBlankStr` | Strategy identifier | +| `context_snapshot` | `AgentContextSnapshot` | Redacted snapshot (turn count, accumulated cost, message count, max turns — no message contents) | +| `error_message` | `NotBlankStr` | Error that triggered recovery | +| `can_reassign` | `bool` (computed) | `retry_count < task.max_retries` | + #### Strategy 1: Fail-and-Reassign (Default / MVP) The engine catches the failure at its outermost boundary, logs a redacted `AgentContext` snapshot (turn count, accumulated cost — excluding message contents to avoid leaking sensitive prompts/tool outputs), transitions the task to `FAILED`, and makes it available for reassignment (manual or automatic via the task router). -> **New non-terminal state:** `FAILED` is a new `TaskStatus` variant to be added alongside `CANCELLED`. The §6.1 lifecycle diagram and `TaskStatus` enum will be updated when crash recovery is implemented in M3. `FAILED` differs from `CANCELLED` (which is terminal) in that failed tasks are eligible for automatic reassignment. +> **Non-terminal state (implemented in M3):** `FAILED` is a `TaskStatus` variant alongside `CANCELLED`. `FAILED` differs from `CANCELLED` (which is terminal) in that failed tasks are eligible for automatic reassignment. Valid transitions: `IN_PROGRESS → FAILED`, `ASSIGNED → FAILED` (early setup failures), `FAILED → ASSIGNED` (reassignment). See the updated §6.1 lifecycle diagram. ```yaml crash_recovery: @@ -967,10 +991,12 @@ crash_recovery: - All progress is lost on crash — acceptable for short single-agent tasks in the MVP On crash: -1. Catch exception at the engine boundary (outermost `try/except` in the execution loop) -2. Log at ERROR with redacted `AgentContext` snapshot (turn count, accumulated cost, tool call history — message contents excluded) +1. Catch exception at the `AgentEngine` boundary (outermost `try/except` in `AgentEngine.run()`) +2. Log at ERROR with redacted `AgentContextSnapshot` (turn count, accumulated cost, message count, max turns — message contents excluded) 3. Transition `TaskExecution` → `FAILED` with the exception as the failure reason -4. Task becomes available for reassignment via the task router +4. `RecoveryResult.can_reassign` reports whether `retry_count < max_retries` + +> **M3 limitation:** The `can_reassign` flag is computed and returned in `RecoveryResult`, but automated reassignment is not yet implemented — the task router (§6.4) will consume this in a later milestone. The caller (task router) is responsible for incrementing `retry_count` when creating the next `TaskExecution`. #### Strategy 2: Checkpoint Recovery (Planned — M4/M5) @@ -2272,6 +2298,8 @@ ai-company/ │ │ ├── loop_protocol.py # ExecutionLoop protocol + result models │ │ ├── metrics.py # TaskCompletionMetrics proxy overhead model │ │ ├── react_loop.py # ReAct loop implementation +│ │ ├── recovery.py # Crash recovery strategies (RecoveryStrategy protocol) +│ │ ├── cost_recording.py # Per-turn cost recording helpers │ │ ├── run_result.py # AgentRunResult outcome model │ │ ├── agent_engine.py # Agent execution engine │ │ ├── task_engine.py # Task routing & scheduling (M3-M4) diff --git a/src/ai_company/core/enums.py b/src/ai_company/core/enums.py index edc71275ee..1e2f653333 100644 --- a/src/ai_company/core/enums.py +++ b/src/ai_company/core/enums.py @@ -127,11 +127,13 @@ class TaskStatus(StrEnum): Summary for quick reference: CREATED -> ASSIGNED - ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED - IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED + ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED + IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED IN_REVIEW -> COMPLETED | IN_PROGRESS (rework) | BLOCKED | CANCELLED BLOCKED -> ASSIGNED (unblocked) + FAILED -> ASSIGNED (reassignment for retry) COMPLETED and CANCELLED are terminal states. + FAILED is non-terminal (can be reassigned). """ CREATED = "created" @@ -140,6 +142,7 @@ class TaskStatus(StrEnum): IN_REVIEW = "in_review" COMPLETED = "completed" BLOCKED = "blocked" + FAILED = "failed" CANCELLED = "cancelled" diff --git a/src/ai_company/core/task.py b/src/ai_company/core/task.py index d34662e5ef..bcbd1ca776 100644 --- a/src/ai_company/core/task.py +++ b/src/ai_company/core/task.py @@ -58,6 +58,7 @@ class Task(BaseModel): estimated_complexity: Task complexity estimate. budget_limit: Maximum USD spend for this task. deadline: Optional deadline (ISO 8601 string or ``None``). + max_retries: Max reassignment attempts after failure (default 1). status: Current lifecycle status. """ @@ -112,6 +113,11 @@ class Task(BaseModel): default=None, description="Optional deadline (ISO 8601 string)", ) + max_retries: int = Field( + default=1, + ge=0, + description="Max reassignment attempts after failure", + ) status: TaskStatus = Field( default=TaskStatus.CREATED, description="Current lifecycle status", @@ -153,8 +159,8 @@ def _validate_assignment_consistency(self) -> Self: ``CREATED`` status must have ``assigned_to=None``. Statuses beyond ``CREATED`` (``ASSIGNED``, ``IN_PROGRESS``, ``IN_REVIEW``, - ``COMPLETED``) require ``assigned_to`` to be set. ``BLOCKED`` - and ``CANCELLED`` may or may not have an assignee. + ``COMPLETED``) require ``assigned_to`` to be set. ``BLOCKED``, + ``FAILED``, and ``CANCELLED`` may or may not have an assignee. """ requires_assignee = { TaskStatus.ASSIGNED, diff --git a/src/ai_company/core/task_transitions.py b/src/ai_company/core/task_transitions.py index a34dc5c636..b4f5921f98 100644 --- a/src/ai_company/core/task_transitions.py +++ b/src/ai_company/core/task_transitions.py @@ -1,17 +1,18 @@ """Task lifecycle state machine transitions. Defines the valid state transitions for the task lifecycle, based on -DESIGN_SPEC Section 6.1 and extended with BLOCKED and CANCELLED -transitions from IN_PROGRESS and IN_REVIEW for completeness:: +DESIGN_SPEC Sections 6.1 and 6.6, extended with BLOCKED, CANCELLED, and +FAILED transitions for completeness:: CREATED -> ASSIGNED - ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED - IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED + ASSIGNED -> IN_PROGRESS | BLOCKED | CANCELLED | FAILED + IN_PROGRESS -> IN_REVIEW | BLOCKED | CANCELLED | FAILED IN_REVIEW -> COMPLETED | IN_PROGRESS (rework) | BLOCKED | CANCELLED BLOCKED -> ASSIGNED (unblocked) + FAILED -> ASSIGNED (reassignment for retry) COMPLETED and CANCELLED are terminal states with no outgoing -transitions. +transitions. FAILED is non-terminal (can be reassigned). """ from ai_company.core.enums import TaskStatus @@ -30,6 +31,7 @@ TaskStatus.IN_PROGRESS, TaskStatus.BLOCKED, TaskStatus.CANCELLED, + TaskStatus.FAILED, } ), TaskStatus.IN_PROGRESS: frozenset( @@ -37,6 +39,7 @@ TaskStatus.IN_REVIEW, TaskStatus.BLOCKED, TaskStatus.CANCELLED, + TaskStatus.FAILED, } ), TaskStatus.IN_REVIEW: frozenset( @@ -48,6 +51,7 @@ } ), TaskStatus.BLOCKED: frozenset({TaskStatus.ASSIGNED}), + TaskStatus.FAILED: frozenset({TaskStatus.ASSIGNED}), # reassignment TaskStatus.COMPLETED: frozenset(), # terminal TaskStatus.CANCELLED: frozenset(), # terminal } diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index 08e39c58a8..d730dda30b 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -34,6 +34,11 @@ build_system_prompt, ) from ai_company.engine.react_loop import ReactLoop +from ai_company.engine.recovery import ( + FailAndReassignStrategy, + RecoveryResult, + RecoveryStrategy, +) from ai_company.engine.run_result import AgentRunResult from ai_company.engine.task_execution import StatusTransition, TaskExecution from ai_company.providers.models import ZERO_TOKEN_USAGE, add_token_usage @@ -52,11 +57,14 @@ "ExecutionLoop", "ExecutionResult", "ExecutionStateError", + "FailAndReassignStrategy", "LoopExecutionError", "MaxTurnsExceededError", "PromptBuildError", "PromptTokenEstimator", "ReactLoop", + "RecoveryResult", + "RecoveryStrategy", "StatusTransition", "SystemPrompt", "TaskCompletionMetrics", diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 80edf73d89..8e0b119e1e 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -7,17 +7,15 @@ import asyncio import contextlib import time -from datetime import UTC, datetime from typing import TYPE_CHECKING -from ai_company.budget.cost_record import CostRecord from ai_company.core.enums import AgentStatus, TaskStatus from ai_company.engine.context import DEFAULT_MAX_TURNS, AgentContext +from ai_company.engine.cost_recording import record_execution_costs from ai_company.engine.errors import ExecutionStateError from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, - TurnRecord, ) from ai_company.engine.metrics import TaskCompletionMetrics from ai_company.engine.prompt import ( @@ -26,13 +24,11 @@ format_task_instruction, ) from ai_company.engine.react_loop import ReactLoop +from ai_company.engine.recovery import FailAndReassignStrategy, RecoveryStrategy from ai_company.engine.run_result import AgentRunResult from ai_company.observability import get_logger from ai_company.observability.events.execution import ( EXECUTION_ENGINE_COMPLETE, - EXECUTION_ENGINE_COST_FAILED, - EXECUTION_ENGINE_COST_RECORDED, - EXECUTION_ENGINE_COST_SKIPPED, EXECUTION_ENGINE_CREATED, EXECUTION_ENGINE_ERROR, EXECUTION_ENGINE_INVALID_INPUT, @@ -41,6 +37,7 @@ EXECUTION_ENGINE_TASK_METRICS, EXECUTION_ENGINE_TASK_TRANSITION, EXECUTION_ENGINE_TIMEOUT, + EXECUTION_RECOVERY_FAILED, ) from ai_company.providers.enums import MessageRole from ai_company.providers.models import ChatMessage @@ -58,11 +55,15 @@ logger = get_logger(__name__) +_DEFAULT_RECOVERY_STRATEGY = FailAndReassignStrategy() +"""Module-level default instance for the recovery strategy.""" + _EXECUTABLE_STATUSES = frozenset({TaskStatus.ASSIGNED, TaskStatus.IN_PROGRESS}) """Task statuses the engine will accept for execution. -CREATED tasks lack an assignee; terminal statuses (COMPLETED, CANCELLED) -and BLOCKED/IN_REVIEW are not executable. +CREATED tasks lack an assignee; terminal statuses (COMPLETED, CANCELLED), +BLOCKED, IN_REVIEW, and FAILED are not executable. FAILED tasks must be +reassigned (FAILED -> ASSIGNED) before re-execution. """ @@ -79,6 +80,9 @@ class AgentEngine: tool_registry: Optional tools available to the agent. cost_tracker: Optional cost recording service. When ``None``, cost recording is skipped silently. + recovery_strategy: Crash recovery strategy. Defaults to a + shared ``FailAndReassignStrategy`` instance. Pass ``None`` + to disable. """ def __init__( @@ -88,11 +92,13 @@ def __init__( execution_loop: ExecutionLoop | None = None, tool_registry: ToolRegistry | None = None, cost_tracker: CostTracker | None = None, + recovery_strategy: RecoveryStrategy | None = _DEFAULT_RECOVERY_STRATEGY, ) -> None: self._provider = provider self._loop: ExecutionLoop = execution_loop or ReactLoop() self._tool_registry = tool_registry self._cost_tracker = cost_tracker + self._recovery_strategy = recovery_strategy logger.debug( EXECUTION_ENGINE_CREATED, loop_type=self._loop.get_loop_type(), @@ -194,7 +200,7 @@ async def run( # noqa: PLR0913 ) raise except Exception as exc: - return self._handle_fatal_error( + return await self._handle_fatal_error( exc=exc, identity=identity, task=task, @@ -245,13 +251,26 @@ async def _execute( # noqa: PLR0913 timeout_seconds=timeout_seconds, ) - await self._record_costs(execution_result, identity, agent_id, task_id) + await record_execution_costs( + execution_result, + identity, + agent_id, + task_id, + tracker=self._cost_tracker, + ) execution_result = self._apply_post_execution_transitions( execution_result, agent_id, task_id, ) + if execution_result.termination_reason == TerminationReason.ERROR: + execution_result = await self._apply_recovery( + execution_result, + agent_id, + task_id, + ) + duration = time.monotonic() - start result = AgentRunResult( execution_result=execution_result, @@ -531,7 +550,7 @@ def _apply_post_execution_transitions( from_status=prev_status.value, to_status=TaskStatus.IN_REVIEW.value, ) - # TODO(M4): Replace auto-complete with review gate + # TODO(M4): Replace auto-complete with review gate (§6.5) prev_status = ctx.task_execution.status # type: ignore[union-attr] ctx = ctx.with_task_transition( TaskStatus.COMPLETED, @@ -555,6 +574,49 @@ def _apply_post_execution_transitions( return execution_result.model_copy(update={"context": ctx}) + async def _apply_recovery( + self, + execution_result: ExecutionResult, + agent_id: str, + task_id: str, + ) -> ExecutionResult: + """Invoke the configured recovery strategy on error outcomes. + + The default strategy transitions the task to FAILED; other + strategies may behave differently. If no strategy is set or + no task execution exists, returns the result unchanged. + Recovery failures are logged but never block the error result. + """ + if self._recovery_strategy is None: + return execution_result + ctx = execution_result.context + if ctx.task_execution is None: + return execution_result + + error_msg = execution_result.error_message or "Unknown error" + try: + recovery_result = await self._recovery_strategy.recover( + task_execution=ctx.task_execution, + error_message=error_msg, + context=ctx, + ) + updated_ctx = ctx.model_copy( + update={"task_execution": recovery_result.task_execution}, + ) + return execution_result.model_copy( + update={"context": updated_ctx}, + ) + except MemoryError, RecursionError: + raise + except Exception as exc: + logger.exception( + EXECUTION_RECOVERY_FAILED, + agent_id=agent_id, + task_id=task_id, + error=f"{type(exc).__name__}: {exc}", + ) + return execution_result + def _make_tool_invoker( self, identity: AgentIdentity, @@ -597,111 +659,7 @@ def _log_completion( duration_seconds=metrics.duration_seconds, ) - async def _record_costs( - self, - result: ExecutionResult, - identity: AgentIdentity, - agent_id: str, - task_id: str, - ) -> None: - """Record per-turn costs to the CostTracker if available. - - Each turn produces its own ``CostRecord``, preserving per-call - granularity. Turns with zero cost and zero tokens are skipped. - - Recording failures for regular exceptions are logged but do not - affect the execution result. ``MemoryError`` and - ``RecursionError`` propagate unconditionally as non-recoverable - system errors. - """ - if self._cost_tracker is None: - logger.debug( - EXECUTION_ENGINE_COST_SKIPPED, - agent_id=agent_id, - task_id=task_id, - reason="no cost tracker configured", - ) - return - - tracker = self._cost_tracker - - for turn in result.turns: - # Skip only when provably nothing happened (zero cost and - # zero tokens); a turn with tokens but zero cost (e.g., a - # free-tier provider) is still recorded. - if ( - turn.cost_usd <= 0.0 - and turn.input_tokens == 0 - and turn.output_tokens == 0 - ): - logger.debug( - EXECUTION_ENGINE_COST_SKIPPED, - agent_id=agent_id, - task_id=task_id, - turn_number=turn.turn_number, - reason="zero cost and zero tokens", - ) - continue - - record = CostRecord( - agent_id=agent_id, - task_id=task_id, - provider=identity.model.provider, - model=identity.model.model_id, - input_tokens=turn.input_tokens, - output_tokens=turn.output_tokens, - cost_usd=turn.cost_usd, - timestamp=datetime.now(UTC), - ) - await self._submit_cost( - record, - turn, - agent_id, - task_id, - tracker=tracker, - ) - - async def _submit_cost( - self, - record: CostRecord, - turn: TurnRecord, - agent_id: str, - task_id: str, - *, - tracker: CostTracker, - ) -> None: - """Submit a cost record to the tracker, logging failures.""" - try: - await tracker.record(record) - except MemoryError, RecursionError: - logger.error( - EXECUTION_ENGINE_COST_FAILED, - agent_id=agent_id, - task_id=task_id, - error="non-recoverable error in cost recording", - exc_info=True, - ) - raise - except Exception as exc: - logger.exception( - EXECUTION_ENGINE_COST_FAILED, - agent_id=agent_id, - task_id=task_id, - error=f"{type(exc).__name__}: {exc}", - cost_usd=turn.cost_usd, - input_tokens=turn.input_tokens, - output_tokens=turn.output_tokens, - ) - return - - logger.info( - EXECUTION_ENGINE_COST_RECORDED, - agent_id=agent_id, - task_id=task_id, - cost_usd=turn.cost_usd, - ) - - def _handle_fatal_error( # noqa: PLR0913 + async def _handle_fatal_error( # noqa: PLR0913 self, *, exc: Exception, @@ -738,6 +696,11 @@ def _handle_fatal_error( # noqa: PLR0913 termination_reason=TerminationReason.ERROR, error_message=error_msg, ) + error_execution = await self._apply_recovery( + error_execution, + agent_id, + task_id, + ) error_prompt = system_prompt or SystemPrompt( content="", template_version="error", @@ -775,7 +738,7 @@ def _handle_fatal_error( # noqa: PLR0913 error=f"Failed to build error result: {build_exc}", original_error=error_msg, ) - raise exc from None + raise exc from build_exc def _make_budget_checker(task: Task) -> BudgetChecker | None: diff --git a/src/ai_company/engine/cost_recording.py b/src/ai_company/engine/cost_recording.py new file mode 100644 index 0000000000..90745feb4f --- /dev/null +++ b/src/ai_company/engine/cost_recording.py @@ -0,0 +1,124 @@ +"""Per-turn cost recording for agent execution. + +Extracts cost-recording logic from ``AgentEngine`` to keep the engine +module under the 800-line limit while preserving full per-turn +granularity and structured logging. +""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from ai_company.budget.cost_record import CostRecord +from ai_company.observability import get_logger +from ai_company.observability.events.execution import ( + EXECUTION_ENGINE_COST_FAILED, + EXECUTION_ENGINE_COST_RECORDED, + EXECUTION_ENGINE_COST_SKIPPED, +) + +if TYPE_CHECKING: + from ai_company.budget.tracker import CostTracker + from ai_company.core.agent import AgentIdentity + from ai_company.engine.loop_protocol import ExecutionResult, TurnRecord + +logger = get_logger(__name__) + + +async def record_execution_costs( + result: ExecutionResult, + identity: AgentIdentity, + agent_id: str, + task_id: str, + *, + tracker: CostTracker | None, +) -> None: + """Record per-turn costs to the CostTracker if available. + + Each turn produces its own ``CostRecord``, preserving per-call + granularity. Turns with zero cost and zero tokens are skipped. + + Recording failures for regular exceptions are logged but do not + affect the execution result. ``MemoryError`` and + ``RecursionError`` propagate unconditionally as non-recoverable + system errors. + """ + if tracker is None: + logger.debug( + EXECUTION_ENGINE_COST_SKIPPED, + agent_id=agent_id, + task_id=task_id, + reason="no cost tracker configured", + ) + return + + for turn in result.turns: + # Skip only when provably nothing happened (zero cost and + # zero tokens); a turn with tokens but zero cost (e.g., a + # free-tier provider) is still recorded. + if turn.cost_usd <= 0.0 and turn.input_tokens == 0 and turn.output_tokens == 0: + logger.debug( + EXECUTION_ENGINE_COST_SKIPPED, + agent_id=agent_id, + task_id=task_id, + turn_number=turn.turn_number, + reason="zero cost and zero tokens", + ) + continue + + record = CostRecord( + agent_id=agent_id, + task_id=task_id, + provider=identity.model.provider, + model=identity.model.model_id, + input_tokens=turn.input_tokens, + output_tokens=turn.output_tokens, + cost_usd=turn.cost_usd, + timestamp=datetime.now(UTC), + ) + await _submit_cost_record( + record, + turn, + agent_id, + task_id, + tracker=tracker, + ) + + +async def _submit_cost_record( + record: CostRecord, + turn: TurnRecord, + agent_id: str, + task_id: str, + *, + tracker: CostTracker, +) -> None: + """Submit a cost record to the tracker, logging failures.""" + try: + await tracker.record(record) + except MemoryError, RecursionError: + logger.error( + EXECUTION_ENGINE_COST_FAILED, + agent_id=agent_id, + task_id=task_id, + error="non-recoverable error in cost recording", + exc_info=True, + ) + raise + except Exception as exc: + logger.exception( + EXECUTION_ENGINE_COST_FAILED, + agent_id=agent_id, + task_id=task_id, + error=f"{type(exc).__name__}: {exc}", + cost_usd=turn.cost_usd, + input_tokens=turn.input_tokens, + output_tokens=turn.output_tokens, + ) + return + + logger.info( + EXECUTION_ENGINE_COST_RECORDED, + agent_id=agent_id, + task_id=task_id, + cost_usd=turn.cost_usd, + ) diff --git a/src/ai_company/engine/recovery.py b/src/ai_company/engine/recovery.py new file mode 100644 index 0000000000..1644f44f09 --- /dev/null +++ b/src/ai_company/engine/recovery.py @@ -0,0 +1,177 @@ +"""Crash recovery strategy protocol and fail-and-reassign implementation. + +Defines the ``RecoveryStrategy`` protocol and the default +``FailAndReassignStrategy`` that transitions a crashed task execution +from its current status (typically ``IN_PROGRESS``) to ``FAILED`` +status, captures a redacted context snapshot, and reports whether the +task can be reassigned (based on retry count vs max retries). + +See DESIGN_SPEC Section 6.6 for the full crash recovery design. +""" + +from typing import Final, Protocol, runtime_checkable + +from pydantic import BaseModel, ConfigDict, Field, computed_field + +from ai_company.core.enums import TaskStatus +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.context import AgentContext, AgentContextSnapshot # noqa: TC001 +from ai_company.engine.task_execution import TaskExecution # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.execution import ( + EXECUTION_RECOVERY_COMPLETE, + EXECUTION_RECOVERY_SNAPSHOT, + EXECUTION_RECOVERY_START, +) + +logger = get_logger(__name__) + + +class RecoveryResult(BaseModel): + """Frozen result of a recovery strategy invocation. + + Attributes: + task_execution: Updated execution after recovery (typically + ``FAILED`` for the default strategy). + strategy_type: Identifier of the strategy used (e.g. ``"fail_reassign"``). + can_reassign: Computed — ``True`` when retry_count < task.max_retries. + The caller (task router) is responsible for incrementing + ``retry_count`` when creating the next ``TaskExecution``. + context_snapshot: Redacted snapshot (no message contents). + error_message: The error that triggered recovery. + """ + + model_config = ConfigDict(frozen=True) + + task_execution: TaskExecution = Field( + description="Updated execution with FAILED status", + ) + strategy_type: NotBlankStr = Field( + description="Identifier of the recovery strategy used", + ) + context_snapshot: AgentContextSnapshot = Field( + description="Redacted context snapshot (no message contents)", + ) + error_message: NotBlankStr = Field( + description="The error that triggered recovery", + ) + + @computed_field( # type: ignore[prop-decorator] + description="Whether the task can be reassigned for retry", + ) + @property + def can_reassign(self) -> bool: + """Whether the task can be reassigned for retry. + + Assumes the caller (task router) will increment ``retry_count`` + when creating the next ``TaskExecution`` for the reassigned task. + """ + return self.task_execution.retry_count < self.task_execution.task.max_retries + + +@runtime_checkable +class RecoveryStrategy(Protocol): + """Protocol for crash recovery strategies. + + Implementations decide how to handle a failed task execution: + transition the task, capture diagnostics, and report whether + reassignment is possible. + """ + + async def recover( + self, + *, + task_execution: TaskExecution, + error_message: str, + context: AgentContext, + ) -> RecoveryResult: + """Apply recovery to a failed task execution. + + Args: + task_execution: Current execution state (typically + ``IN_PROGRESS``, but may be ``ASSIGNED`` for early + setup failures). + error_message: Description of the failure. + context: Full agent context at the time of failure. + + Returns: + ``RecoveryResult`` with the updated execution and diagnostics. + """ + ... + + def get_strategy_type(self) -> str: + """Return the strategy type identifier.""" + ... + + +class FailAndReassignStrategy: + """Default recovery: transition to FAILED and report reassignment eligibility. + + 1. Capture a redacted ``AgentContextSnapshot`` (excludes message + contents to prevent leaking sensitive prompts/tool outputs). + 2. Log the snapshot at ERROR level. + 3. Transition ``TaskExecution`` to ``FAILED`` with the error as reason. + 4. Report ``can_reassign = retry_count < task.max_retries``. + """ + + STRATEGY_TYPE: Final[str] = "fail_reassign" + + async def recover( + self, + *, + task_execution: TaskExecution, + error_message: str, + context: AgentContext, + ) -> RecoveryResult: + """Apply fail-and-reassign recovery. + + Args: + task_execution: Current execution state. + error_message: Description of the failure. + context: Full agent context at the time of failure. + + Returns: + ``RecoveryResult`` with FAILED execution and reassignment info. + """ + logger.info( + EXECUTION_RECOVERY_START, + task_id=task_execution.task.id, + strategy=self.STRATEGY_TYPE, + retry_count=task_execution.retry_count, + ) + + snapshot = context.to_snapshot() + logger.error( + EXECUTION_RECOVERY_SNAPSHOT, + task_id=task_execution.task.id, + turn_count=snapshot.turn_count, + cost_usd=snapshot.accumulated_cost.cost_usd, + error_message=error_message, + ) + + failed_execution = task_execution.with_transition( + TaskStatus.FAILED, + reason=error_message, + ) + + result = RecoveryResult( + task_execution=failed_execution, + strategy_type=self.STRATEGY_TYPE, + context_snapshot=snapshot, + error_message=error_message, + ) + + logger.info( + EXECUTION_RECOVERY_COMPLETE, + task_id=task_execution.task.id, + strategy=self.STRATEGY_TYPE, + can_reassign=result.can_reassign, + retry_count=task_execution.retry_count, + max_retries=task_execution.task.max_retries, + ) + + return result + + def get_strategy_type(self) -> str: + """Return the strategy type identifier.""" + return self.STRATEGY_TYPE diff --git a/src/ai_company/engine/task_execution.py b/src/ai_company/engine/task_execution.py index df023ed963..4c2b8b840f 100644 --- a/src/ai_company/engine/task_execution.py +++ b/src/ai_company/engine/task_execution.py @@ -71,6 +71,7 @@ class TaskExecution(BaseModel): transition_log: Audit trail of status transitions. accumulated_cost: Running token usage and cost totals. turn_count: Number of LLM turns completed. + retry_count: Number of previous failure-reassignment cycles. started_at: Set by ``with_transition`` on first entry to ``IN_PROGRESS`` (``None`` until then). completed_at: When execution reached a terminal state. @@ -93,6 +94,11 @@ class TaskExecution(BaseModel): ge=0, description="Number of turns completed", ) + retry_count: int = Field( + default=0, + ge=0, + description="Number of previous failure-reassignment cycles", + ) started_at: AwareDatetime | None = Field( default=None, description="When execution entered IN_PROGRESS", @@ -103,16 +109,22 @@ class TaskExecution(BaseModel): ) @classmethod - def from_task(cls, task: Task) -> TaskExecution: + def from_task( + cls, + task: Task, + *, + retry_count: int = 0, + ) -> TaskExecution: """Create a fresh execution from a task definition. Args: task: The frozen task to wrap. + retry_count: Number of previous failure-reassignment cycles. Returns: New ``TaskExecution`` with status matching the task. """ - execution = cls(task=task, status=task.status) + execution = cls(task=task, status=task.status, retry_count=retry_count) logger.debug( EXECUTION_TASK_CREATED, task_id=task.id, diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index 30f50c220a..4542ffa3a6 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -34,3 +34,8 @@ EXECUTION_ENGINE_COST_FAILED: Final[str] = "execution.engine.cost_failed" EXECUTION_ENGINE_TASK_METRICS: Final[str] = "execution.engine.task_metrics" EXECUTION_ENGINE_TIMEOUT: Final[str] = "execution.engine.timeout" + +EXECUTION_RECOVERY_START: Final[str] = "execution.recovery.start" +EXECUTION_RECOVERY_COMPLETE: Final[str] = "execution.recovery.complete" +EXECUTION_RECOVERY_FAILED: Final[str] = "execution.recovery.failed" +EXECUTION_RECOVERY_SNAPSHOT: Final[str] = "execution.recovery.snapshot" diff --git a/tests/integration/engine/test_crash_recovery.py b/tests/integration/engine/test_crash_recovery.py new file mode 100644 index 0000000000..5362ef106b --- /dev/null +++ b/tests/integration/engine/test_crash_recovery.py @@ -0,0 +1,164 @@ +"""Integration test: crash recovery full flow. + +Engine.run() with failing provider -> task FAILED -> can_reassign checks. +""" + +from datetime import date +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest + +from ai_company.core.agent import ( + AgentIdentity, + ModelConfig, + PersonalityConfig, +) +from ai_company.core.enums import Priority, SeniorityLevel, TaskStatus, TaskType +from ai_company.core.task import Task +from ai_company.engine.agent_engine import AgentEngine +from ai_company.engine.loop_protocol import TerminationReason +from ai_company.engine.task_execution import TaskExecution + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from ai_company.providers.capabilities import ModelCapabilities + from ai_company.providers.models import ( + ChatMessage, + CompletionConfig, + CompletionResponse, + StreamChunk, + ToolDefinition, + ) + +pytestmark = [pytest.mark.integration, pytest.mark.timeout(30)] + + +class _FailingProvider: + """Mock provider that always raises on complete().""" + + def __init__(self, error: Exception | None = None) -> None: + self._error = error or RuntimeError("Provider crashed") + + async def complete( + self, + messages: list[ChatMessage], + model: str, + *, + tools: list[ToolDefinition] | None = None, + config: CompletionConfig | None = None, + ) -> CompletionResponse: + raise self._error + + async def stream( + self, + messages: list[ChatMessage], + model: str, + *, + tools: list[ToolDefinition] | None = None, + config: CompletionConfig | None = None, + ) -> AsyncIterator[StreamChunk]: + msg = "stream not supported" + raise NotImplementedError(msg) + + async def get_model_capabilities(self, model: str) -> ModelCapabilities: + from ai_company.providers.capabilities import ModelCapabilities + + return ModelCapabilities( + model_id=model, + provider="test-provider", + supports_tools=True, + supports_streaming=False, + max_context_tokens=8192, + max_output_tokens=4096, + cost_per_1k_input=0.01, + cost_per_1k_output=0.03, + ) + + +def _make_identity() -> AgentIdentity: + return AgentIdentity( + id=uuid4(), + name="Recovery Agent", + role="Developer", + department="Engineering", + level=SeniorityLevel.MID, + hiring_date=date(2026, 1, 15), + personality=PersonalityConfig(traits=("analytical",)), + model=ModelConfig( + provider="test-provider", + model_id="test-small-001", + ), + ) + + +def _make_task( + identity: AgentIdentity, + *, + max_retries: int = 1, +) -> Task: + return Task( + id="task-recovery", + title="Crash recovery test", + description="Test the crash recovery flow.", + type=TaskType.DEVELOPMENT, + priority=Priority.MEDIUM, + project="proj-001", + created_by="manager", + assigned_to=str(identity.id), + status=TaskStatus.ASSIGNED, + max_retries=max_retries, + ) + + +class TestCrashRecoveryFlow: + """Full flow: engine.run() with failing provider -> FAILED.""" + + async def test_first_failure_can_reassign(self) -> None: + """First failure with max_retries=1 -> FAILED, can_reassign=True.""" + identity = _make_identity() + task = _make_task(identity, max_retries=1) + provider = _FailingProvider() + + engine = AgentEngine(provider=provider) + result = await engine.run(identity=identity, task=task) + + assert result.termination_reason == TerminationReason.ERROR + assert result.is_success is False + + te = result.execution_result.context.task_execution + assert te is not None + assert te.status is TaskStatus.FAILED + + # retry_count=0 < max_retries=1 means reassignment is possible + assert te.retry_count < task.max_retries + + async def test_second_failure_cannot_reassign(self) -> None: + """Second failure (retry_count=1, max_retries=1) -> cannot reassign.""" + identity = _make_identity() + task = _make_task(identity, max_retries=1) + provider = _FailingProvider() + + # First run + engine = AgentEngine(provider=provider) + first_result = await engine.run(identity=identity, task=task) + + first_te = first_result.execution_result.context.task_execution + assert first_te is not None + assert first_te.status is TaskStatus.FAILED + + # Simulate reassignment: create new task in ASSIGNED status with + # retry_count from the first execution + 1 + reassigned_task = task.model_copy( + update={"status": TaskStatus.ASSIGNED}, + ) + # Create new execution with incremented retry_count + new_exe = TaskExecution.from_task( + reassigned_task, + retry_count=first_te.retry_count + 1, + ) + assert new_exe.retry_count == 1 + + # The reassigned execution should indicate no more retries + assert new_exe.retry_count >= task.max_retries diff --git a/tests/unit/core/test_enums.py b/tests/unit/core/test_enums.py index c956ac28d8..b28363486d 100644 --- a/tests/unit/core/test_enums.py +++ b/tests/unit/core/test_enums.py @@ -58,8 +58,8 @@ def test_proficiency_level_has_4_members(self) -> None: def test_department_name_has_9_members(self) -> None: assert len(DepartmentName) == 9 - def test_task_status_has_7_members(self) -> None: - assert len(TaskStatus) == 7 + def test_task_status_has_8_members(self) -> None: + assert len(TaskStatus) == 8 def test_task_type_has_6_members(self) -> None: assert len(TaskType) == 6 @@ -109,6 +109,7 @@ def test_task_status_values(self) -> None: assert TaskStatus.IN_REVIEW.value == "in_review" assert TaskStatus.COMPLETED.value == "completed" assert TaskStatus.BLOCKED.value == "blocked" + assert TaskStatus.FAILED.value == "failed" assert TaskStatus.CANCELLED.value == "cancelled" def test_task_type_values(self) -> None: diff --git a/tests/unit/core/test_task.py b/tests/unit/core/test_task.py index d319892474..327357f2da 100644 --- a/tests/unit/core/test_task.py +++ b/tests/unit/core/test_task.py @@ -291,6 +291,37 @@ def test_cancelled_with_assigned_to_allowed(self) -> None: task = _make_task(assigned_to="agent-1", status=TaskStatus.CANCELLED) assert task.assigned_to == "agent-1" + def test_failed_without_assigned_to_allowed(self) -> None: + task = _make_task(status=TaskStatus.FAILED) + assert task.assigned_to is None + assert task.status is TaskStatus.FAILED + + def test_failed_with_assigned_to_allowed(self) -> None: + task = _make_task(assigned_to="agent-1", status=TaskStatus.FAILED) + assert task.assigned_to == "agent-1" + + +# ── Task: Max Retries ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestTaskMaxRetries: + def test_max_retries_default(self) -> None: + task = _make_task() + assert task.max_retries == 1 + + def test_max_retries_custom(self) -> None: + task = _make_task(max_retries=3) + assert task.max_retries == 3 + + def test_max_retries_zero_allowed(self) -> None: + task = _make_task(max_retries=0) + assert task.max_retries == 0 + + def test_max_retries_negative_rejected(self) -> None: + with pytest.raises(ValidationError): + _make_task(max_retries=-1) + # ── Task: Budget ───────────────────────────────────────────────── @@ -453,6 +484,16 @@ def test_validators_enforced_on_transition(self) -> None: ): task.with_transition(TaskStatus.ASSIGNED) + def test_valid_transition_failed_to_assigned(self) -> None: + """FAILED -> ASSIGNED reassignment with new assignee.""" + task = _make_task(assigned_to="agent-1", status=TaskStatus.FAILED) + new_task = task.with_transition( + TaskStatus.ASSIGNED, + assigned_to="agent-2", + ) + assert new_task.status is TaskStatus.ASSIGNED + assert new_task.assigned_to == "agent-2" + def test_original_unchanged(self) -> None: """Ensure the original task is not modified (immutability).""" task = _make_task() diff --git a/tests/unit/core/test_task_transitions.py b/tests/unit/core/test_task_transitions.py index d446dbdff9..489d6535d8 100644 --- a/tests/unit/core/test_task_transitions.py +++ b/tests/unit/core/test_task_transitions.py @@ -52,6 +52,15 @@ def test_in_review_to_cancelled(self) -> None: def test_blocked_to_assigned(self) -> None: validate_transition(TaskStatus.BLOCKED, TaskStatus.ASSIGNED) + def test_in_progress_to_failed(self) -> None: + validate_transition(TaskStatus.IN_PROGRESS, TaskStatus.FAILED) + + def test_assigned_to_failed(self) -> None: + validate_transition(TaskStatus.ASSIGNED, TaskStatus.FAILED) + + def test_failed_to_assigned(self) -> None: + validate_transition(TaskStatus.FAILED, TaskStatus.ASSIGNED) + # ── Invalid Transitions ────────────────────────────────────────── @@ -98,6 +107,14 @@ def test_in_progress_to_assigned_rejected(self) -> None: with pytest.raises(ValueError, match="Invalid task status transition"): validate_transition(TaskStatus.IN_PROGRESS, TaskStatus.ASSIGNED) + def test_failed_to_completed_rejected(self) -> None: + with pytest.raises(ValueError, match="Invalid task status transition"): + validate_transition(TaskStatus.FAILED, TaskStatus.COMPLETED) + + def test_failed_to_in_progress_rejected(self) -> None: + with pytest.raises(ValueError, match="Invalid task status transition"): + validate_transition(TaskStatus.FAILED, TaskStatus.IN_PROGRESS) + def test_error_message_includes_allowed(self) -> None: with pytest.raises(ValueError, match="Allowed from 'created'"): validate_transition(TaskStatus.CREATED, TaskStatus.COMPLETED) @@ -122,6 +139,10 @@ def test_terminal_states_have_empty_transitions(self) -> None: assert VALID_TRANSITIONS[TaskStatus.COMPLETED] == frozenset() assert VALID_TRANSITIONS[TaskStatus.CANCELLED] == frozenset() + def test_failed_is_non_terminal(self) -> None: + """FAILED has outgoing transitions (reassignment).""" + assert len(VALID_TRANSITIONS[TaskStatus.FAILED]) > 0 + def test_all_targets_are_valid_statuses(self) -> None: """Every target in the transition map must be a valid TaskStatus.""" for source, targets in VALID_TRANSITIONS.items(): diff --git a/tests/unit/engine/test_agent_engine_errors.py b/tests/unit/engine/test_agent_engine_errors.py index 0838dab740..e12592b060 100644 --- a/tests/unit/engine/test_agent_engine_errors.py +++ b/tests/unit/engine/test_agent_engine_errors.py @@ -6,6 +6,7 @@ import pytest from ai_company.core.agent import AgentIdentity # noqa: TC001 +from ai_company.core.enums import TaskStatus from ai_company.core.task import Task # noqa: TC001 from ai_company.engine.agent_engine import AgentEngine from ai_company.engine.context import AgentContext @@ -14,14 +15,22 @@ TerminationReason, TurnRecord, ) +from ai_company.engine.recovery import ( + FailAndReassignStrategy, + RecoveryResult, +) from ai_company.providers.enums import FinishReason, MessageRole from ai_company.providers.models import ChatMessage if TYPE_CHECKING: + from ai_company.engine.task_execution import TaskExecution + from .conftest import MockCompletionProvider from .conftest import make_completion_response as _make_completion_response +pytestmark = pytest.mark.timeout(30) + @pytest.mark.unit class TestAgentEngineErrorHandling: @@ -350,3 +359,228 @@ async def test_memory_messages_in_context( if m.role == MessageRole.USER and "# Task:" in m.content ) assert sys_idx < mem_idx < task_idx + + +@pytest.mark.unit +class TestAgentEngineRecovery: + """Recovery strategy is invoked on error outcomes.""" + + async def test_provider_error_transitions_task_to_failed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Provider exception -> task status is FAILED.""" + provider = MagicMock() + provider.complete = AsyncMock(side_effect=RuntimeError("LLM is down")) + engine = AgentEngine(provider=provider) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.termination_reason == TerminationReason.ERROR + te = result.execution_result.context.task_execution + assert te is not None + assert te.status is TaskStatus.FAILED + + async def test_recovery_strategy_invoked_on_failure( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Custom recovery strategy's recover() is called on failure.""" + mock_strategy = MagicMock(spec=FailAndReassignStrategy) + # Use actual strategy for the real call, but track it + real_strategy = FailAndReassignStrategy() + mock_strategy.recover = AsyncMock( + side_effect=real_strategy.recover, + ) + mock_strategy.get_strategy_type = MagicMock(return_value="fail_reassign") + + provider = MagicMock() + provider.complete = AsyncMock(side_effect=RuntimeError("crash")) + engine = AgentEngine(provider=provider, recovery_strategy=mock_strategy) + + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + mock_strategy.recover.assert_called_once() + + async def test_recovery_failure_is_swallowed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """If recovery itself fails, engine still returns error result.""" + mock_strategy = MagicMock() + mock_strategy.recover = AsyncMock( + side_effect=ValueError("recovery broken"), + ) + + provider = MagicMock() + provider.complete = AsyncMock(side_effect=RuntimeError("LLM down")) + engine = AgentEngine(provider=provider, recovery_strategy=mock_strategy) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + # Engine still returns an error result, doesn't crash + assert result.termination_reason == TerminationReason.ERROR + assert result.is_success is False + + async def test_no_recovery_when_strategy_is_none( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Opting out of recovery: task stays IN_PROGRESS (not FAILED).""" + provider = MagicMock() + provider.complete = AsyncMock(side_effect=RuntimeError("crash")) + engine = AgentEngine(provider=provider, recovery_strategy=None) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.termination_reason == TerminationReason.ERROR + te = result.execution_result.context.task_execution + assert te is not None + # Without recovery, task stays at IN_PROGRESS (engine transitions + # ASSIGNED->IN_PROGRESS before the loop runs) + assert te.status is TaskStatus.IN_PROGRESS + + async def test_loop_timeout_triggers_recovery( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Wall-clock timeout -> ERROR -> recovery -> FAILED.""" + import asyncio + + async def slow_execute(**_kwargs: object) -> ExecutionResult: + await asyncio.sleep(10) + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + return ExecutionResult( + context=ctx, + termination_reason=TerminationReason.COMPLETED, + ) + + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(side_effect=slow_execute) + mock_loop.get_loop_type = MagicMock(return_value="react") + + provider = mock_provider_factory([]) + engine = AgentEngine( + provider=provider, + execution_loop=mock_loop, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + timeout_seconds=0.01, + ) + + assert result.termination_reason == TerminationReason.ERROR + te = result.execution_result.context.task_execution + assert te is not None + assert te.status is TaskStatus.FAILED + + async def test_custom_recovery_strategy_used( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Engine uses the custom strategy, not the default.""" + custom_results: list[str] = [] + + class CustomRecovery: + async def recover( + self, + *, + task_execution: TaskExecution, + error_message: str, + context: AgentContext, + ) -> RecoveryResult: + custom_results.append("custom_called") + real = FailAndReassignStrategy() + return await real.recover( + task_execution=task_execution, + error_message=error_message, + context=context, + ) + + def get_strategy_type(self) -> str: + return "custom" + + provider = MagicMock() + provider.complete = AsyncMock(side_effect=RuntimeError("crash")) + engine = AgentEngine( + provider=provider, + recovery_strategy=CustomRecovery(), + ) + + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert custom_results == ["custom_called"] + + async def test_memory_error_in_recovery_propagates( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """MemoryError from recovery strategy is not swallowed.""" + mock_strategy = MagicMock() + mock_strategy.recover = AsyncMock(side_effect=MemoryError("OOM")) + + provider = MagicMock() + provider.complete = AsyncMock(side_effect=RuntimeError("crash")) + engine = AgentEngine( + provider=provider, + recovery_strategy=mock_strategy, + ) + + with pytest.raises(MemoryError, match="OOM"): + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + async def test_recursion_error_in_recovery_propagates( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """RecursionError from recovery strategy is not swallowed.""" + mock_strategy = MagicMock() + mock_strategy.recover = AsyncMock( + side_effect=RecursionError("max depth"), + ) + + provider = MagicMock() + provider.complete = AsyncMock(side_effect=RuntimeError("crash")) + engine = AgentEngine( + provider=provider, + recovery_strategy=mock_strategy, + ) + + with pytest.raises(RecursionError, match="max depth"): + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) diff --git a/tests/unit/engine/test_agent_engine_lifecycle.py b/tests/unit/engine/test_agent_engine_lifecycle.py index 4651287c9e..5a5570e7e8 100644 --- a/tests/unit/engine/test_agent_engine_lifecycle.py +++ b/tests/unit/engine/test_agent_engine_lifecycle.py @@ -21,6 +21,8 @@ from .conftest import make_completion_response as _make_completion_response +pytestmark = pytest.mark.timeout(30) + @pytest.mark.unit class TestAgentEnginePostExecutionTransitions: @@ -160,7 +162,7 @@ async def test_budget_exhausted_stays_in_progress( assert te is not None assert te.status == TaskStatus.IN_PROGRESS - async def test_error_stays_in_progress( + async def test_error_transitions_to_failed( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, @@ -193,7 +195,7 @@ async def test_error_stays_in_progress( te = result.execution_result.context.task_execution assert te is not None - assert te.status == TaskStatus.IN_PROGRESS + assert te.status == TaskStatus.FAILED async def test_no_task_execution_passes_through( self, diff --git a/tests/unit/engine/test_recovery.py b/tests/unit/engine/test_recovery.py new file mode 100644 index 0000000000..66c7ed025f --- /dev/null +++ b/tests/unit/engine/test_recovery.py @@ -0,0 +1,252 @@ +"""Tests for crash recovery strategy protocol and FailAndReassignStrategy.""" + +from typing import TYPE_CHECKING + +import pytest +import structlog.testing + +from ai_company.core.enums import TaskStatus, TaskType +from ai_company.core.task import Task +from ai_company.engine.context import AgentContext + +if TYPE_CHECKING: + from ai_company.core.agent import AgentIdentity + +from pydantic import ValidationError + +from ai_company.engine.recovery import ( + FailAndReassignStrategy, + RecoveryResult, + RecoveryStrategy, +) +from ai_company.observability.events.execution import ( + EXECUTION_RECOVERY_COMPLETE, + EXECUTION_RECOVERY_SNAPSHOT, + EXECUTION_RECOVERY_START, +) + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestRecoveryStrategyProtocol: + """FailAndReassignStrategy satisfies the RecoveryStrategy protocol.""" + + def test_is_runtime_checkable(self) -> None: + strategy = FailAndReassignStrategy() + assert isinstance(strategy, RecoveryStrategy) + + def test_get_strategy_type(self) -> None: + strategy = FailAndReassignStrategy() + assert strategy.get_strategy_type() == "fail_reassign" + + +@pytest.mark.unit +class TestFailAndReassignStrategy: + """FailAndReassignStrategy behavior.""" + + async def test_happy_path_transitions_to_failed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Task transitions to FAILED, can_reassign=True when retries remain.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="starting", + ) + assert ctx.task_execution is not None + + strategy = FailAndReassignStrategy() + result = await strategy.recover( + task_execution=ctx.task_execution, + error_message="LLM crashed", + context=ctx, + ) + + assert isinstance(result, RecoveryResult) + assert result.task_execution.status is TaskStatus.FAILED + assert result.strategy_type == "fail_reassign" + assert result.can_reassign is True + assert result.error_message == "LLM crashed" + + async def test_max_retries_exceeded_cannot_reassign( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """can_reassign=False when retry_count >= max_retries.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="starting", + ) + assert ctx.task_execution is not None + + # Simulate retry_count already at max_retries (default=1) + exe_with_retries = ctx.task_execution.model_copy( + update={"retry_count": 1}, + ) + + strategy = FailAndReassignStrategy() + result = await strategy.recover( + task_execution=exe_with_retries, + error_message="Failed again", + context=ctx, + ) + + assert result.can_reassign is False + assert result.task_execution.status is TaskStatus.FAILED + + async def test_zero_max_retries_never_reassignable( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """Task with max_retries=0 is never reassignable.""" + task = Task( + id="task-no-retries", + title="No retries allowed", + description="This task cannot be retried.", + type=TaskType.DEVELOPMENT, + project="proj-001", + created_by="manager", + assigned_to=str(sample_agent_with_personality.id), + status=TaskStatus.ASSIGNED, + max_retries=0, + ) + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=task, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="starting", + ) + assert ctx.task_execution is not None + + strategy = FailAndReassignStrategy() + result = await strategy.recover( + task_execution=ctx.task_execution, + error_message="Crashed", + context=ctx, + ) + + assert result.can_reassign is False + assert result.task_execution.status is TaskStatus.FAILED + + async def test_snapshot_is_redacted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Snapshot contains metadata but no message contents.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="starting", + ) + assert ctx.task_execution is not None + + strategy = FailAndReassignStrategy() + result = await strategy.recover( + task_execution=ctx.task_execution, + error_message="Provider error", + context=ctx, + ) + + snapshot = result.context_snapshot + assert snapshot.task_id == sample_task_with_criteria.id + assert snapshot.agent_id == str(sample_agent_with_personality.id) + assert snapshot.turn_count >= 0 + # Snapshot is an AgentContextSnapshot — has no message contents + assert not hasattr(snapshot, "conversation") + + async def test_error_message_captured( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Error message is preserved in the result.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="starting", + ) + assert ctx.task_execution is not None + + strategy = FailAndReassignStrategy() + result = await strategy.recover( + task_execution=ctx.task_execution, + error_message="Specific error: connection reset", + context=ctx, + ) + + assert result.error_message == "Specific error: connection reset" + + async def test_recovery_result_frozen( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """RecoveryResult fields cannot be reassigned (frozen model).""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="starting", + ) + assert ctx.task_execution is not None + + strategy = FailAndReassignStrategy() + result = await strategy.recover( + task_execution=ctx.task_execution, + error_message="Crashed", + context=ctx, + ) + + with pytest.raises(ValidationError, match="frozen"): + result.error_message = "changed" # type: ignore[misc] + + async def test_recovery_logs_events( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Recovery emits start, snapshot, and complete events.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="starting", + ) + assert ctx.task_execution is not None + + strategy = FailAndReassignStrategy() + with structlog.testing.capture_logs() as logs: + await strategy.recover( + task_execution=ctx.task_execution, + error_message="Error", + context=ctx, + ) + + events = [entry["event"] for entry in logs] + assert EXECUTION_RECOVERY_START in events + assert EXECUTION_RECOVERY_SNAPSHOT in events + assert EXECUTION_RECOVERY_COMPLETE in events diff --git a/tests/unit/engine/test_task_execution.py b/tests/unit/engine/test_task_execution.py index d77fa0daec..7646551327 100644 --- a/tests/unit/engine/test_task_execution.py +++ b/tests/unit/engine/test_task_execution.py @@ -23,6 +23,8 @@ add_token_usage, ) +pytestmark = pytest.mark.timeout(30) + @pytest.mark.unit class TestStatusTransition: @@ -82,6 +84,37 @@ def test_not_terminal_initially(self, sample_task_with_criteria: Task) -> None: assert exe.is_terminal is False +@pytest.mark.unit +class TestTaskExecutionRetryCount: + """TaskExecution.retry_count field.""" + + def test_retry_count_default_zero(self, sample_task_with_criteria: Task) -> None: + exe = TaskExecution.from_task(sample_task_with_criteria) + assert exe.retry_count == 0 + + def test_from_task_with_retry_count(self, sample_task_with_criteria: Task) -> None: + exe = TaskExecution.from_task(sample_task_with_criteria, retry_count=2) + assert exe.retry_count == 2 + + def test_retry_count_preserved_on_transition( + self, sample_task_with_criteria: Task + ) -> None: + exe = TaskExecution.from_task(sample_task_with_criteria, retry_count=1) + result = exe.with_transition(TaskStatus.IN_PROGRESS) + assert result.retry_count == 1 + + def test_failed_transition_not_terminal( + self, sample_task_with_criteria: Task + ) -> None: + """FAILED does not set completed_at and is_terminal is False.""" + exe = TaskExecution.from_task(sample_task_with_criteria) + in_progress = exe.with_transition(TaskStatus.IN_PROGRESS) + failed = in_progress.with_transition(TaskStatus.FAILED, reason="crash") + assert failed.status is TaskStatus.FAILED + assert failed.completed_at is None + assert failed.is_terminal is False + + @pytest.mark.unit class TestTaskExecutionTransitions: """TaskExecution.with_transition valid and invalid paths."""