diff --git a/CLAUDE.md b/CLAUDE.md index 374867ce42..882c698f70 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -92,7 +92,7 @@ src/ai_company/ communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol config/ # YAML company config loading and validation core/ # Shared domain models, base classes, and resilience config (RetryConfig, RateLimiterConfig) - engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation + engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, prompt policy validation, and AgentEngine-TaskEngine incremental status sync hr/ # HR engine: hiring, firing, onboarding, offboarding, agent registry, performance tracking (task metrics, collaboration scoring, trend detection), promotion/demotion (criteria evaluation, approval strategies, model mapping) memory/ # Persistent agent memory (Mem0 initial, custom stack future — see Decision Log), retrieval pipeline (ranking, injection, context formatting, non-inferable filtering), shared org memory (org/), consolidation/archival (consolidation/) persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (see Memory & Persistence design page) diff --git a/docs/design/engine.md b/docs/design/engine.md index f7b7b08ae5..5f5cb386b3 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -226,6 +226,42 @@ Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop - **stop()**: Sets `_running = False`, drains the queue within a configurable timeout, then cancels. Abandoned futures receive a failure result. +### AgentEngine ↔ TaskEngine Incremental Sync + +`AgentEngine` syncs task status transitions to `TaskEngine` incrementally at +each lifecycle point, rather than reporting only the final status. This gives +real-time visibility into execution progress and improves crash recovery +(a crash mid-execution leaves the task at the last-reached stage, not stuck +at `ASSIGNED`). + +**Transition sequences** (1–3 `submit()` calls per execution, bounded): + +| Path | Synced transitions | +|------|--------------------| +| Happy (COMPLETED) | `IN_PROGRESS` → `IN_REVIEW` → `COMPLETED` | +| Shutdown | `IN_PROGRESS` → `INTERRUPTED` | +| Error | `IN_PROGRESS` → `FAILED` (after recovery) | +| MAX_TURNS / BUDGET | `IN_PROGRESS` only | + +**Semantics:** + +- **Best-effort**: Sync failures are logged and swallowed — agent execution + is never blocked by a TaskEngine issue. Each sync failure is isolated and + does not prevent subsequent transitions. +- **Critical IN_PROGRESS**: The initial `ASSIGNED → IN_PROGRESS` sync is + logged at `ERROR` on failure (TaskEngine state coherence for all subsequent + transitions depends on it). Other sync failures log at `WARNING`. +- **Direct `submit()`**: Uses `TaskEngine.submit()` with + `TransitionTaskMutation` directly (not the convenience `transition_task()` + method) to inspect `TaskMutationResult` success/failure without exception + propagation, keeping sync best-effort. +- **No concurrency concern**: Each task has exactly one executing agent at + any time. Parallel agents operate on separate tasks. + +**Snapshot channel**: TaskEngine publishes `TaskStateChanged` events to the +`"tasks"` channel (matching `CHANNEL_TASKS` in `api.channels`) so events +reach the `MessageBusBridge` and WebSocket consumers. + --- ## Agent Execution Loop @@ -250,8 +286,8 @@ All loop implementations satisfy the `ExecutionLoop` runtime-checkable protocol: **Supporting models:** `TerminationReason` -: Enum: `COMPLETED`, `MAX_TURNS`, `BUDGET_EXHAUSTED`, `SHUTDOWN`, `ERROR`. - `max_turns` defaults to 20. +: Enum: `COMPLETED`, `MAX_TURNS`, `BUDGET_EXHAUSTED`, `SHUTDOWN`, `ERROR`, + `PARKED`. `max_turns` defaults to 20. `TurnRecord` : Frozen per-turn stats (tokens, cost, tool calls, finish reason). @@ -377,7 +413,7 @@ invocation, and cost tracking into a single `run()` call. ```python async run( identity, task, completion_config?, max_turns?, - memory_messages?, timeout_seconds? + memory_messages?, timeout_seconds?, effective_autonomy? ) -> AgentRunResult ``` @@ -421,8 +457,12 @@ async run( - `ERROR` termination: recovery strategy is applied (default `FailAndReassignStrategy` transitions to FAILED; see [Crash Recovery](#agent-crash-recovery)). - - All other termination reasons (`MAX_TURNS`, `BUDGET_EXHAUSTED`) leave the - task in its current state. + - All other termination reasons (`MAX_TURNS`, `BUDGET_EXHAUSTED`, `PARKED`) + leave the task in its current state. `PARKED` indicates the agent was + suspended by an approval-timeout policy; the task remains at its current + status until explicitly resumed. + - Each transition is synced to TaskEngine incrementally (see + [AgentEngine ↔ TaskEngine Incremental Sync](#agentengine--taskengine-incremental-sync)). - Transition failures are logged but do not discard the successful execution result. 11. **Return result** -- wraps `ExecutionResult` in `AgentRunResult` with diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index e470afd57a..17d7579b0e 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING from ai_company.budget.errors import BudgetExhaustedError -from ai_company.core.enums import TaskStatus from ai_company.engine._validation import ( validate_agent, validate_run_inputs, @@ -19,11 +18,7 @@ from ai_company.engine.classification.pipeline import classify_execution_errors from ai_company.engine.context import DEFAULT_MAX_TURNS, AgentContext from ai_company.engine.cost_recording import record_execution_costs -from ai_company.engine.errors import ( - ExecutionStateError, - TaskEngineError, - TaskMutationError, -) +from ai_company.engine.errors import ExecutionStateError from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -39,6 +34,11 @@ from ai_company.engine.react_loop import ReactLoop from ai_company.engine.recovery import FailAndReassignStrategy, RecoveryStrategy from ai_company.engine.run_result import AgentRunResult +from ai_company.engine.task_sync import ( + apply_post_execution_transitions, + sync_to_task_engine, + transition_task_if_needed, +) from ai_company.observability import get_logger from ai_company.observability.events.execution import ( EXECUTION_ENGINE_BUDGET_STOPPED, @@ -103,23 +103,6 @@ _DEFAULT_RECOVERY_STRATEGY = FailAndReassignStrategy() """Module-level default instance for the recovery strategy.""" -_REPORTABLE_STATUSES: frozenset[TaskStatus] = frozenset( - { - TaskStatus.COMPLETED, - TaskStatus.FAILED, - TaskStatus.INTERRUPTED, - TaskStatus.CANCELLED, - } -) -"""Statuses that trigger a report to the centralized TaskEngine. - -Evaluated after each AgentEngine run. - -Note: ``FAILED`` and ``INTERRUPTED`` are not strictly terminal in the task -lifecycle (they can be reassigned), but represent final outcomes of this -particular ``AgentEngine`` run that should be reported. -""" - class AgentEngine: """Top-level orchestrator for agent execution. @@ -142,8 +125,9 @@ class AgentEngine: enhanced in-flight budget checking. security_config: Optional security subsystem configuration. approval_store: Optional approval queue store. - task_engine: Optional centralized task engine for reporting - final execution status. + task_engine: Optional centralized task engine for real-time + status sync (incremental transitions at each lifecycle + point, best-effort). """ def __init__( # noqa: PLR0913 @@ -247,7 +231,7 @@ async def run( # noqa: PLR0913 task_id=task_id, effective_autonomy=effective_autonomy, ) - ctx, system_prompt = self._prepare_context( + ctx, system_prompt = await self._prepare_context( identity=identity, task=task, agent_id=agent_id, @@ -364,9 +348,10 @@ async def _post_execution_pipeline( agent_id: str, task_id: str, ) -> ExecutionResult: - """Post-execution: costs, transitions, TaskEngine, recovery, classify. + """Post-execution: costs, transitions, recovery, classify. - Best-effort: classification and reporting failures are logged, + Each transition is synced to TaskEngine incrementally + (best-effort). Classification and sync failures are logged, never fatal. """ await record_execution_costs( @@ -376,18 +361,46 @@ async def _post_execution_pipeline( task_id, tracker=self._cost_tracker, ) - execution_result = self._apply_post_execution_transitions( + execution_result = await apply_post_execution_transitions( execution_result, agent_id, task_id, + self._task_engine, ) - await self._report_to_task_engine(execution_result, agent_id, task_id) if execution_result.termination_reason == TerminationReason.ERROR: + pre_recovery_ctx = execution_result.context + pre_recovery_status = ( + pre_recovery_ctx.task_execution.status + if pre_recovery_ctx.task_execution is not None + else None + ) execution_result = await self._apply_recovery( execution_result, agent_id, task_id, ) + # Sync post-recovery status to TaskEngine (typically FAILED, + # depends on recovery strategy). + ctx = execution_result.context + if ( + ctx.task_execution is not None + and pre_recovery_status is not None + and ctx.task_execution.status != pre_recovery_status + ): + logger.info( + EXECUTION_ENGINE_TASK_TRANSITION, + agent_id=agent_id, + task_id=task_id, + from_status=pre_recovery_status.value, + to_status=ctx.task_execution.status.value, + ) + await sync_to_task_engine( + self._task_engine, + target_status=ctx.task_execution.status, + task_id=task_id, + agent_id=agent_id, + reason=f"Post-recovery status: {ctx.task_execution.status.value}", + ) # Classification is non-critical — never destroys a result. if self._error_taxonomy_config is not None: try: @@ -399,12 +412,12 @@ async def _post_execution_pipeline( ) except MemoryError, RecursionError: raise - except Exception: + except Exception as exc: logger.warning( EXECUTION_ENGINE_ERROR, agent_id=agent_id, task_id=task_id, - error="classification failed", + error=f"classification failed: {type(exc).__name__}: {exc}", exc_info=True, ) return execution_result @@ -498,7 +511,7 @@ async def _run_loop_with_timeout( # noqa: PLR0913 # ── Setup ──────────────────────────────────────────────────── - def _prepare_context( # noqa: PLR0913 + async def _prepare_context( # noqa: PLR0913 self, *, identity: AgentIdentity, @@ -536,204 +549,16 @@ def _prepare_context( # noqa: PLR0913 ), ) - ctx = self._transition_task_if_needed(ctx, agent_id, task_id) + ctx = await transition_task_if_needed( + ctx, + agent_id, + task_id, + self._task_engine, + ) return ctx, system_prompt # ── Helpers ────────────────────────────────────────────────── - def _transition_task_if_needed( - self, - ctx: AgentContext, - agent_id: str, - task_id: str, - ) -> AgentContext: - """Transition ASSIGNED -> IN_PROGRESS; pass through IN_PROGRESS.""" - if ( - ctx.task_execution is not None - and ctx.task_execution.status == TaskStatus.ASSIGNED - ): - ctx = ctx.with_task_transition( - TaskStatus.IN_PROGRESS, - reason="Engine starting execution", - ) - logger.info( - EXECUTION_ENGINE_TASK_TRANSITION, - agent_id=agent_id, - task_id=task_id, - from_status=TaskStatus.ASSIGNED.value, - to_status=TaskStatus.IN_PROGRESS.value, - ) - return ctx - - def _apply_post_execution_transitions( - self, - execution_result: ExecutionResult, - agent_id: str, - task_id: str, - ) -> ExecutionResult: - """Apply post-execution task transitions based on termination reason. - - COMPLETED triggers IN_PROGRESS -> IN_REVIEW -> COMPLETED. - SHUTDOWN triggers current status -> INTERRUPTED. - Transition failures are logged but never discard the result. - """ - ctx = execution_result.context - if ctx.task_execution is None: - return execution_result - - reason = execution_result.termination_reason - - if reason == TerminationReason.SHUTDOWN: - return self._transition_to_interrupted( - execution_result, ctx, agent_id, task_id - ) - - if reason != TerminationReason.COMPLETED: - return execution_result - - try: - ctx = self._transition_to_complete(ctx, agent_id, task_id) - except (ValueError, ExecutionStateError) as exc: - logger.exception( - EXECUTION_ENGINE_ERROR, - agent_id=agent_id, - task_id=task_id, - error=f"Post-execution transition failed: {exc}", - ) - return execution_result - - return execution_result.model_copy(update={"context": ctx}) - - def _transition_to_complete( - self, - ctx: AgentContext, - agent_id: str, - task_id: str, - ) -> AgentContext: - """Transition IN_PROGRESS -> IN_REVIEW -> COMPLETED with logging.""" - prev_status = ctx.task_execution.status # type: ignore[union-attr] - ctx = ctx.with_task_transition( - TaskStatus.IN_REVIEW, - reason="Agent completed execution", - ) - logger.info( - EXECUTION_ENGINE_TASK_TRANSITION, - agent_id=agent_id, - task_id=task_id, - from_status=prev_status.value, - to_status=TaskStatus.IN_REVIEW.value, - ) - # TODO: Replace auto-complete with review gate (§6.5) - prev_status = ctx.task_execution.status # type: ignore[union-attr] - ctx = ctx.with_task_transition( - TaskStatus.COMPLETED, - reason="Auto-completed (review gate not implemented)", - ) - logger.info( - EXECUTION_ENGINE_TASK_TRANSITION, - agent_id=agent_id, - task_id=task_id, - from_status=prev_status.value, - to_status=TaskStatus.COMPLETED.value, - ) - return ctx - - def _transition_to_interrupted( - self, - execution_result: ExecutionResult, - ctx: AgentContext, - agent_id: str, - task_id: str, - ) -> ExecutionResult: - """Transition task to INTERRUPTED on graceful shutdown.""" - try: - prev_status = ctx.task_execution.status # type: ignore[union-attr] - ctx = ctx.with_task_transition( - TaskStatus.INTERRUPTED, - reason="Graceful shutdown requested", - ) - logger.info( - EXECUTION_ENGINE_TASK_TRANSITION, - agent_id=agent_id, - task_id=task_id, - from_status=prev_status.value, - to_status=TaskStatus.INTERRUPTED.value, - ) - return execution_result.model_copy(update={"context": ctx}) - except (ValueError, ExecutionStateError) as exc: - logger.exception( - EXECUTION_ENGINE_ERROR, - agent_id=agent_id, - task_id=task_id, - error=f"Post-execution INTERRUPTED transition failed: {exc}", - ) - return execution_result - - async def _report_to_task_engine( - self, - execution_result: ExecutionResult, - agent_id: str, - task_id: str, - ) -> None: - """Report final execution status to the centralized TaskEngine. - - Only reports final execution outcomes (COMPLETED, FAILED, - INTERRUPTED, CANCELLED); other statuses are silently skipped. - - Best-effort: failures are logged and swallowed. If no - ``TaskEngine`` is configured, this is a no-op. - """ - if self._task_engine is None: - return - ctx = execution_result.context - if ctx.task_execution is None: - return - - final_status = ctx.task_execution.status - if final_status not in _REPORTABLE_STATUSES: - return - - try: - # Best-effort: discard return value intentionally — if the - # transition is rejected (e.g. parallel mutation moved the task), - # the exception handlers below log the failure. - _ = await self._task_engine.transition_task( - task_id, - final_status, - requested_by=agent_id, - reason=( - "AgentEngine execution ended: " - f"{execution_result.termination_reason.value}" - ), - ) - except MemoryError, RecursionError: - raise - except TaskMutationError: - logger.warning( - EXECUTION_ENGINE_ERROR, - agent_id=agent_id, - task_id=task_id, - error="Failed to report final status to TaskEngine (mutation rejected)", - exc_info=True, - ) - except TaskEngineError: - logger.error( - EXECUTION_ENGINE_ERROR, - agent_id=agent_id, - task_id=task_id, - error="TaskEngine unavailable for status report", - exc_info=True, - ) - except Exception: - logger.error( - EXECUTION_ENGINE_ERROR, - agent_id=agent_id, - task_id=task_id, - error="Unexpected error reporting to TaskEngine" - " -- state may be divergent", - exc_info=True, - ) - async def _apply_recovery( self, execution_result: ExecutionResult, @@ -979,6 +804,11 @@ async def _handle_fatal_error( # noqa: PLR0913 error=error_msg, ) + pre_fatal_status = ( + ctx.task_execution.status + if ctx is not None and ctx.task_execution is not None + else None + ) try: error_execution = await self._build_error_execution( identity, @@ -988,6 +818,27 @@ async def _handle_fatal_error( # noqa: PLR0913 error_msg, ctx, ) + # Sync fatal-error recovery status to TaskEngine (best-effort). + error_ctx = error_execution.context + if ( + error_ctx.task_execution is not None + and pre_fatal_status is not None + and error_ctx.task_execution.status != pre_fatal_status + ): + logger.info( + EXECUTION_ENGINE_TASK_TRANSITION, + agent_id=agent_id, + task_id=task_id, + from_status=pre_fatal_status.value, + to_status=error_ctx.task_execution.status.value, + ) + await sync_to_task_engine( + self._task_engine, + target_status=error_ctx.task_execution.status, + task_id=task_id, + agent_id=agent_id, + reason=f"Fatal error recovery: {type(exc).__name__}", + ) error_prompt = build_error_prompt( identity, agent_id, diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index ae4cc7d050..57638ea371 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -606,8 +606,12 @@ async def list_tasks( _SNAPSHOT_SENDER: str = "task-engine" """Sender identity used in snapshot ``Message`` envelopes.""" - _SNAPSHOT_CHANNEL: str = "task_engine" - """Message bus channel for snapshot publication.""" + _SNAPSHOT_CHANNEL: str = "tasks" + """Message bus channel for snapshot publication. + + Must match ``CHANNEL_TASKS`` in ``api.channels`` so that events + reach the MessageBusBridge and WebSocket consumers. + """ async def _processing_loop(self) -> None: """Background loop: dequeue and process mutations sequentially. diff --git a/src/ai_company/engine/task_sync.py b/src/ai_company/engine/task_sync.py new file mode 100644 index 0000000000..751d82db79 --- /dev/null +++ b/src/ai_company/engine/task_sync.py @@ -0,0 +1,291 @@ +"""Task status sync — AgentEngine → TaskEngine integration. + +Module-level functions extracted from ``AgentEngine`` to keep the +orchestrator file focused on execution flow. Every function is +best-effort: sync failures are logged and swallowed so agent +execution is never blocked by a ``TaskEngine`` issue. +""" + +import asyncio +from typing import TYPE_CHECKING +from uuid import uuid4 + +from ai_company.core.enums import TaskStatus +from ai_company.engine.errors import ExecutionStateError, TaskEngineError +from ai_company.engine.loop_protocol import TerminationReason +from ai_company.engine.task_engine_models import TransitionTaskMutation +from ai_company.observability import get_logger +from ai_company.observability.events.execution import ( + EXECUTION_ENGINE_ERROR, + EXECUTION_ENGINE_SYNC_FAILED, + EXECUTION_ENGINE_TASK_SYNCED, + EXECUTION_ENGINE_TASK_TRANSITION, +) + +if TYPE_CHECKING: + from ai_company.engine.context import AgentContext + from ai_company.engine.loop_protocol import ExecutionResult + from ai_company.engine.task_engine import TaskEngine + +logger = get_logger(__name__) + +# Stepwise completion transitions: each (target_status, reason) pair +# is applied in order. ``apply_post_execution_transitions`` updates +# ``ctx`` after each step so partial-failure always reflects the +# furthest-reached state. +_COMPLETION_STEPS: tuple[tuple[TaskStatus, str], ...] = ( + (TaskStatus.IN_REVIEW, "Agent completed execution"), + # TODO: Replace auto-complete with review gate (engine.md, step 10) + (TaskStatus.COMPLETED, "Auto-completed (review gate not implemented)"), +) + + +async def sync_to_task_engine( # noqa: PLR0913 + task_engine: TaskEngine | None, + *, + target_status: TaskStatus, + task_id: str, + agent_id: str, + reason: str, + critical: bool = False, +) -> None: + """Sync a status transition to the centralized TaskEngine. + + Best-effort: failures are logged and swallowed so that agent + execution is never blocked by a TaskEngine issue. + ``MemoryError`` and ``RecursionError`` propagate unconditionally. + + Args: + task_engine: The task engine to sync to, or ``None`` (no-op). + target_status: The status to transition to. + task_id: Task identifier. + agent_id: Agent performing the transition. + reason: Human-readable reason for the transition. + critical: If ``True``, sync failure is logged at ERROR level + instead of WARNING (severity only — sync remains best-effort + regardless). + + Raises: + MemoryError: Propagated unconditionally (non-recoverable). + RecursionError: Propagated unconditionally (non-recoverable). + asyncio.CancelledError: Propagated so shutdown can proceed. + """ + if task_engine is None: + return + + try: + mutation = TransitionTaskMutation( + request_id=uuid4().hex, + requested_by=agent_id, + task_id=task_id, + target_status=target_status, + reason=reason, + ) + result = await task_engine.submit(mutation) + except MemoryError, RecursionError, asyncio.CancelledError: + raise + except Exception as exc: + _log_sync_issue( + critical=critical, + agent_id=agent_id, + task_id=task_id, + target_status=target_status, + error=( + "TaskEngine unavailable" + if isinstance(exc, TaskEngineError) + else "Unexpected error syncing to TaskEngine" + ), + exc_info=True, + ) + return + + if result.success: + logger.debug( + EXECUTION_ENGINE_TASK_SYNCED, + agent_id=agent_id, + task_id=task_id, + target_status=target_status.value, + version=result.version, + ) + return + + # Mutation was rejected (e.g. version conflict, invalid + # transition, task not found). + _log_sync_issue( + critical=critical, + agent_id=agent_id, + task_id=task_id, + target_status=target_status, + error=result.error or "Mutation rejected (no error detail)", + error_code=result.error_code, + ) + + +def _log_sync_issue( + *, + critical: bool, + agent_id: str, + task_id: str, + target_status: TaskStatus, + **extra: object, +) -> None: + """Log a sync failure at ERROR (critical) or WARNING severity.""" + common = { + "agent_id": agent_id, + "task_id": task_id, + "target_status": target_status.value, + **extra, + } + if critical: + logger.error(EXECUTION_ENGINE_SYNC_FAILED, **common) + else: + logger.warning(EXECUTION_ENGINE_SYNC_FAILED, **common) + + +async def transition_task_if_needed( + ctx: AgentContext, + agent_id: str, + task_id: str, + task_engine: TaskEngine | None, +) -> AgentContext: + """Transition ASSIGNED -> IN_PROGRESS; pass through IN_PROGRESS. + + Also syncs the transition to TaskEngine (best-effort). + """ + if ( + ctx.task_execution is not None + and ctx.task_execution.status == TaskStatus.ASSIGNED + ): + ctx = await _transition_and_sync( + ctx, + target_status=TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + agent_id=agent_id, + task_id=task_id, + task_engine=task_engine, + critical=True, + ) + return ctx + + +async def apply_post_execution_transitions( + execution_result: ExecutionResult, + agent_id: str, + task_id: str, + task_engine: TaskEngine | None, +) -> ExecutionResult: + """Apply post-execution task transitions based on termination reason. + + COMPLETED triggers IN_PROGRESS -> IN_REVIEW -> COMPLETED. + SHUTDOWN triggers current status -> INTERRUPTED. + Each transition is synced to TaskEngine incrementally. + Transition failures are logged but never discard the result. + ``MemoryError`` and ``RecursionError`` propagate unconditionally. + + Returns the original ``execution_result`` unchanged if no + transitions apply, or a copy with updated context reflecting + the furthest-reached state on success or partial failure. + """ + ctx = execution_result.context + if ctx.task_execution is None: + return execution_result + + reason = execution_result.termination_reason + + if reason == TerminationReason.SHUTDOWN: + return await _transition_to_interrupted( + execution_result, ctx, agent_id, task_id, task_engine + ) + + if reason != TerminationReason.COMPLETED: + return execution_result + + # Apply IN_PROGRESS -> IN_REVIEW -> COMPLETED stepwise so that + # ``ctx`` always reflects the furthest-reached state, even when + # one step raises (partial-completion safety). + for target, step_reason in _COMPLETION_STEPS: + try: + ctx = await _transition_and_sync( + ctx, + target_status=target, + reason=step_reason, + agent_id=agent_id, + task_id=task_id, + task_engine=task_engine, + ) + except (ValueError, ExecutionStateError) as exc: + logger.exception( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error=f"Post-execution transition failed: {exc}", + ) + break + + if ctx is execution_result.context: + return execution_result + return execution_result.model_copy(update={"context": ctx}) + + +async def _transition_and_sync( # noqa: PLR0913 + ctx: AgentContext, + *, + target_status: TaskStatus, + reason: str, + agent_id: str, + task_id: str, + task_engine: TaskEngine | None, + critical: bool = False, +) -> AgentContext: + """Apply a local task transition, log it, and sync to TaskEngine. + + Returns the updated context. The local transition (via + ``with_task_transition``) is applied unconditionally; the remote + sync is best-effort. + """ + prev_status = ctx.task_execution.status # type: ignore[union-attr] + ctx = ctx.with_task_transition(target_status, reason=reason) + logger.info( + EXECUTION_ENGINE_TASK_TRANSITION, + agent_id=agent_id, + task_id=task_id, + from_status=prev_status.value, + to_status=target_status.value, + ) + await sync_to_task_engine( + task_engine, + target_status=target_status, + task_id=task_id, + agent_id=agent_id, + reason=reason, + critical=critical, + ) + return ctx + + +async def _transition_to_interrupted( + execution_result: ExecutionResult, + ctx: AgentContext, + agent_id: str, + task_id: str, + task_engine: TaskEngine | None, +) -> ExecutionResult: + """Transition task to INTERRUPTED on graceful shutdown.""" + try: + ctx = await _transition_and_sync( + ctx, + target_status=TaskStatus.INTERRUPTED, + reason="Graceful shutdown requested", + agent_id=agent_id, + task_id=task_id, + task_engine=task_engine, + ) + return execution_result.model_copy(update={"context": ctx}) + except (ValueError, ExecutionStateError) as exc: + logger.exception( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error=f"Post-execution INTERRUPTED transition failed: {exc}", + ) + return execution_result diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index 700224a46e..e519866d5e 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -35,6 +35,8 @@ EXECUTION_ENGINE_TASK_METRICS: Final[str] = "execution.engine.task_metrics" EXECUTION_ENGINE_TIMEOUT: Final[str] = "execution.engine.timeout" EXECUTION_ENGINE_BUDGET_STOPPED: Final[str] = "execution.engine.budget_stopped" +EXECUTION_ENGINE_TASK_SYNCED: Final[str] = "execution.engine.task_synced" +EXECUTION_ENGINE_SYNC_FAILED: Final[str] = "execution.engine.sync_failed" EXECUTION_SHUTDOWN_SIGNAL: Final[str] = "execution.shutdown.signal" EXECUTION_SHUTDOWN_MANAGER_CREATED: Final[str] = "execution.shutdown.manager_created" diff --git a/tests/unit/engine/test_agent_engine.py b/tests/unit/engine/test_agent_engine.py index c85b370a05..ee15b23adf 100644 --- a/tests/unit/engine/test_agent_engine.py +++ b/tests/unit/engine/test_agent_engine.py @@ -17,7 +17,6 @@ from ai_company.engine.errors import ( ExecutionStateError, TaskEngineError, - TaskMutationError, ) from ai_company.engine.loop_protocol import ( ExecutionResult, @@ -25,6 +24,7 @@ TurnRecord, ) from ai_company.engine.run_result import AgentRunResult +from ai_company.engine.task_engine_models import TaskMutationResult from ai_company.observability.events.prompt import PROMPT_TOKEN_RATIO_HIGH from ai_company.providers.enums import FinishReason @@ -937,9 +937,34 @@ async def test_prompt_token_ratio_warning( # noqa: PLR0913 assert len(warning_events) == 0 +def _make_sync_success( + request_id: str = "test", + version: int = 1, +) -> TaskMutationResult: + """Build a successful TaskMutationResult for sync tests.""" + return TaskMutationResult( + request_id=request_id, + success=True, + version=version, + ) + + +def _make_sync_failure( + request_id: str = "test", + error: str = "rejected", +) -> TaskMutationResult: + """Build a failed TaskMutationResult for sync tests.""" + return TaskMutationResult( + request_id=request_id, + success=False, + error=error, + error_code="validation", + ) + + @pytest.mark.unit -class TestReportToTaskEngine: - """Tests for _report_to_task_engine interaction.""" +class TestSyncToTaskEngine: + """Tests for incremental TaskEngine status sync.""" async def test_no_task_engine_is_noop( self, @@ -947,7 +972,7 @@ async def test_no_task_engine_is_noop( sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], ) -> None: - """Without task_engine, run() succeeds and no reporting occurs.""" + """Without task_engine, run() succeeds and no syncing occurs.""" response = _make_completion_response() provider = mock_provider_factory([response]) engine = AgentEngine(provider=provider, task_engine=None) @@ -959,14 +984,44 @@ async def test_no_task_engine_is_noop( assert result.is_success is True - async def test_nonterminal_status_skipped( + async def test_completed_path_produces_three_syncs( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """COMPLETED path syncs IN_PROGRESS, IN_REVIEW, COMPLETED.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.submit = AsyncMock(return_value=_make_sync_success()) + + engine = AgentEngine(provider=provider, task_engine=mock_te) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.is_success is True + assert mock_te.submit.await_count == 3 + synced_statuses = [ + call.args[0].target_status for call in mock_te.submit.call_args_list + ] + assert synced_statuses == [ + TaskStatus.IN_PROGRESS, + TaskStatus.IN_REVIEW, + TaskStatus.COMPLETED, + ] + + async def test_shutdown_path_produces_two_syncs( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], ) -> None: - """Non-terminal task status does not trigger TaskEngine call.""" - # Build a mock loop that returns IN_PROGRESS (non-terminal) + """SHUTDOWN path syncs IN_PROGRESS then INTERRUPTED.""" ctx = AgentContext.from_identity( sample_agent_with_personality, task=sample_task_with_criteria, @@ -977,7 +1032,7 @@ async def test_nonterminal_status_skipped( ) mock_result = ExecutionResult( context=ctx, - termination_reason=TerminationReason.MAX_TURNS, + termination_reason=TerminationReason.SHUTDOWN, turns=( TurnRecord( turn_number=1, @@ -993,14 +1048,13 @@ async def test_nonterminal_status_skipped( mock_loop.get_loop_type = MagicMock(return_value="react") mock_te = MagicMock() - mock_te.transition_task = AsyncMock() + mock_te.submit = AsyncMock(return_value=_make_sync_success()) provider = mock_provider_factory([]) engine = AgentEngine( provider=provider, execution_loop=mock_loop, task_engine=mock_te, - recovery_strategy=None, ) await engine.run( @@ -1008,95 +1062,156 @@ async def test_nonterminal_status_skipped( task=sample_task_with_criteria, ) - mock_te.transition_task.assert_not_awaited() + assert mock_te.submit.await_count == 2 + synced_statuses = [ + call.args[0].target_status for call in mock_te.submit.call_args_list + ] + assert synced_statuses == [ + TaskStatus.IN_PROGRESS, + TaskStatus.INTERRUPTED, + ] - async def test_terminal_status_reported( + async def test_error_path_produces_two_syncs( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], ) -> None: - """COMPLETED status is reported to TaskEngine.""" - response = _make_completion_response() - provider = mock_provider_factory([response]) + """ERROR path syncs IN_PROGRESS then FAILED (after recovery).""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + ) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.ERROR, + error_message="something broke", + turns=( + TurnRecord( + turn_number=1, + input_tokens=10, + output_tokens=5, + cost_usd=0.001, + finish_reason=FinishReason.STOP, + ), + ), + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") mock_te = MagicMock() - mock_te.transition_task = AsyncMock() + mock_te.submit = AsyncMock(return_value=_make_sync_success()) + provider = mock_provider_factory([]) engine = AgentEngine( provider=provider, + execution_loop=mock_loop, task_engine=mock_te, ) - result = await engine.run( + await engine.run( identity=sample_agent_with_personality, task=sample_task_with_criteria, ) - assert result.is_success is True - mock_te.transition_task.assert_awaited_once() - call_args = mock_te.transition_task.call_args - assert call_args.args[0] == sample_task_with_criteria.id - assert call_args.args[1] == TaskStatus.COMPLETED - assert call_args.kwargs["requested_by"] == str( - sample_agent_with_personality.id, - ) + assert mock_te.submit.await_count == 2 + synced_statuses = [ + call.args[0].target_status for call in mock_te.submit.call_args_list + ] + assert synced_statuses == [ + TaskStatus.IN_PROGRESS, + TaskStatus.FAILED, + ] - async def test_mutation_error_swallowed( + async def test_max_turns_syncs_only_in_progress( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], ) -> None: - """TaskMutationError from TaskEngine is logged and swallowed.""" - response = _make_completion_response() - provider = mock_provider_factory([response]) + """MAX_TURNS path: only IN_PROGRESS is synced (no final transition).""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + ) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.MAX_TURNS, + turns=( + TurnRecord( + turn_number=1, + input_tokens=10, + output_tokens=5, + cost_usd=0.001, + finish_reason=FinishReason.STOP, + ), + ), + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") mock_te = MagicMock() - mock_te.transition_task = AsyncMock( - side_effect=TaskMutationError("rejected"), - ) + mock_te.submit = AsyncMock(return_value=_make_sync_success()) + provider = mock_provider_factory([]) engine = AgentEngine( provider=provider, + execution_loop=mock_loop, task_engine=mock_te, + recovery_strategy=None, ) - result = await engine.run( + await engine.run( identity=sample_agent_with_personality, task=sample_task_with_criteria, ) - # Run still succeeds despite task engine failure - assert result.is_success is True + assert mock_te.submit.await_count == 1 + assert ( + mock_te.submit.call_args_list[0].args[0].target_status + == TaskStatus.IN_PROGRESS + ) - async def test_unexpected_error_swallowed( + async def test_sync_failure_isolated_from_subsequent_transitions( self, sample_agent_with_personality: AgentIdentity, sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], ) -> None: - """Unexpected Exception from TaskEngine is logged and swallowed.""" + """A failed sync does not block subsequent transitions.""" response = _make_completion_response() provider = mock_provider_factory([response]) + # First call (IN_PROGRESS) fails, rest succeed mock_te = MagicMock() - mock_te.transition_task = AsyncMock( - side_effect=RuntimeError("connection lost"), + mock_te.submit = AsyncMock( + side_effect=[ + _make_sync_failure(), + _make_sync_success(), + _make_sync_success(), + ], ) - engine = AgentEngine( - provider=provider, - task_engine=mock_te, - ) + engine = AgentEngine(provider=provider, task_engine=mock_te) result = await engine.run( identity=sample_agent_with_personality, task=sample_task_with_criteria, ) - # Run still succeeds despite task engine failure + # Run still succeeds despite first sync failure assert result.is_success is True + assert mock_te.submit.await_count == 3 async def test_task_engine_error_swallowed( self, @@ -1104,26 +1219,46 @@ async def test_task_engine_error_swallowed( sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], ) -> None: - """TaskEngineError (non-mutation) from TaskEngine is logged and swallowed.""" + """TaskEngineError from submit() is logged and swallowed.""" response = _make_completion_response() provider = mock_provider_factory([response]) mock_te = MagicMock() - mock_te.transition_task = AsyncMock( + mock_te.submit = AsyncMock( side_effect=TaskEngineError("engine unavailable"), ) - engine = AgentEngine( - provider=provider, - task_engine=mock_te, + engine = AgentEngine(provider=provider, task_engine=mock_te) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.is_success is True + + async def test_unexpected_error_swallowed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Unexpected Exception from submit() is logged and swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.submit = AsyncMock( + side_effect=RuntimeError("connection lost"), ) + engine = AgentEngine(provider=provider, task_engine=mock_te) + result = await engine.run( identity=sample_agent_with_personality, task=sample_task_with_criteria, ) - # Run still succeeds despite task engine failure assert result.is_success is True async def test_memory_error_propagates( @@ -1132,22 +1267,51 @@ async def test_memory_error_propagates( sample_task_with_criteria: Task, mock_provider_factory: type[MockCompletionProvider], ) -> None: - """MemoryError from TaskEngine is re-raised, not swallowed.""" + """MemoryError from submit() is re-raised, not swallowed.""" response = _make_completion_response() provider = mock_provider_factory([response]) mock_te = MagicMock() - mock_te.transition_task = AsyncMock( + mock_te.submit = AsyncMock( side_effect=MemoryError("out of memory"), ) - engine = AgentEngine( - provider=provider, - task_engine=mock_te, - ) + engine = AgentEngine(provider=provider, task_engine=mock_te) with pytest.raises(MemoryError, match="out of memory"): await engine.run( identity=sample_agent_with_personality, task=sample_task_with_criteria, ) + + async def test_recursion_error_propagates( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """RecursionError from submit() is re-raised, not swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.submit = AsyncMock( + side_effect=RecursionError("maximum recursion depth exceeded"), + ) + + engine = AgentEngine(provider=provider, task_engine=mock_te) + + with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + +@pytest.mark.unit +def test_snapshot_channel_matches_api_channel() -> None: + """TaskEngine._SNAPSHOT_CHANNEL must match CHANNEL_TASKS in api.channels.""" + from ai_company.api.channels import CHANNEL_TASKS + from ai_company.engine.task_engine import TaskEngine + + assert TaskEngine._SNAPSHOT_CHANNEL == CHANNEL_TASKS diff --git a/tests/unit/engine/test_task_sync.py b/tests/unit/engine/test_task_sync.py new file mode 100644 index 0000000000..670b342ebf --- /dev/null +++ b/tests/unit/engine/test_task_sync.py @@ -0,0 +1,634 @@ +"""Unit tests for task_sync module — AgentEngine → TaskEngine sync functions.""" + +import asyncio +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.context import AgentContext +from ai_company.engine.errors import ExecutionStateError, TaskEngineError +from ai_company.engine.loop_protocol import ( + ExecutionResult, + TerminationReason, + TurnRecord, +) +from ai_company.engine.task_engine_models import ( + TaskErrorCode, + TaskMutationResult, +) +from ai_company.engine.task_sync import ( + apply_post_execution_transitions, + sync_to_task_engine, + transition_task_if_needed, +) +from ai_company.providers.enums import FinishReason + +if TYPE_CHECKING: + from ai_company.core.agent import AgentIdentity + from ai_company.core.task import Task + +pytestmark = pytest.mark.timeout(30) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_sync_success( + request_id: str = "test", + version: int = 1, +) -> TaskMutationResult: + """Build a successful TaskMutationResult for sync tests.""" + return TaskMutationResult( + request_id=request_id, + success=True, + version=version, + ) + + +def _make_sync_failure( + request_id: str = "test", + error: str = "rejected", + error_code: TaskErrorCode = "validation", +) -> TaskMutationResult: + """Build a failed TaskMutationResult for sync tests.""" + return TaskMutationResult( + request_id=request_id, + success=False, + error=error, + error_code=error_code, + ) + + +def _make_mock_task_engine( + side_effect: object | None = None, + return_value: TaskMutationResult | None = None, +) -> MagicMock: + """Build a mock TaskEngine with configurable submit behavior.""" + mock_te = MagicMock() + if side_effect is not None: + mock_te.submit = AsyncMock(side_effect=side_effect) + else: + mock_te.submit = AsyncMock( + return_value=return_value or _make_sync_success(), + ) + return mock_te + + +def _make_execution_result( + ctx: AgentContext, + reason: TerminationReason = TerminationReason.COMPLETED, + error_message: str | None = None, +) -> ExecutionResult: + """Build an ExecutionResult with a single dummy turn.""" + return ExecutionResult( + context=ctx, + termination_reason=reason, + error_message=error_message, + turns=( + TurnRecord( + turn_number=1, + input_tokens=10, + output_tokens=5, + cost_usd=0.001, + finish_reason=FinishReason.STOP, + ), + ), + ) + + +# =================================================================== +# sync_to_task_engine +# =================================================================== + + +@pytest.mark.unit +class TestSyncToTaskEngine: + """Direct tests for the sync_to_task_engine function.""" + + async def test_none_task_engine_is_noop(self) -> None: + """When task_engine is None, nothing happens (no error).""" + await sync_to_task_engine( + None, + target_status=TaskStatus.IN_PROGRESS, + task_id="task-1", + agent_id="agent-1", + reason="test", + ) + # No exception = success + + async def test_successful_sync(self) -> None: + """Successful submit logs debug and returns without error.""" + mock_te = _make_mock_task_engine() + + await sync_to_task_engine( + mock_te, + target_status=TaskStatus.IN_PROGRESS, + task_id="task-1", + agent_id="agent-1", + reason="starting", + ) + + mock_te.submit.assert_awaited_once() + mutation = mock_te.submit.call_args.args[0] + assert mutation.target_status == TaskStatus.IN_PROGRESS + assert mutation.task_id == "task-1" + assert mutation.requested_by == "agent-1" + assert mutation.reason == "starting" + + async def test_rejected_mutation_swallowed(self) -> None: + """A rejected mutation (success=False) is logged, not raised.""" + mock_te = _make_mock_task_engine( + return_value=_make_sync_failure( + error="version conflict", + error_code="version_conflict", + ), + ) + + # Should not raise + await sync_to_task_engine( + mock_te, + target_status=TaskStatus.COMPLETED, + task_id="task-1", + agent_id="agent-1", + reason="completing", + ) + + mock_te.submit.assert_awaited_once() + + async def test_rejected_mutation_empty_error_detail(self) -> None: + """Rejection with empty error uses fallback message.""" + mock_te = _make_mock_task_engine( + return_value=TaskMutationResult( + request_id="test", + success=False, + error="", + error_code="validation", + ), + ) + + await sync_to_task_engine( + mock_te, + target_status=TaskStatus.COMPLETED, + task_id="task-1", + agent_id="agent-1", + reason="completing", + ) + # No exception = fallback message was used for empty string + + async def test_task_engine_error_swallowed(self) -> None: + """TaskEngineError from submit() is logged and swallowed.""" + mock_te = _make_mock_task_engine( + side_effect=TaskEngineError("engine down"), + ) + + await sync_to_task_engine( + mock_te, + target_status=TaskStatus.IN_PROGRESS, + task_id="task-1", + agent_id="agent-1", + reason="test", + ) + + async def test_unexpected_exception_swallowed(self) -> None: + """Unexpected RuntimeError from submit() is swallowed.""" + mock_te = _make_mock_task_engine( + side_effect=RuntimeError("connection lost"), + ) + + await sync_to_task_engine( + mock_te, + target_status=TaskStatus.IN_PROGRESS, + task_id="task-1", + agent_id="agent-1", + reason="test", + ) + + @pytest.mark.parametrize( + ("exc_class", "exc_args"), + [ + (MemoryError, ("out of memory",)), + (RecursionError, ("maximum recursion depth exceeded",)), + (asyncio.CancelledError, ()), + ], + ids=["MemoryError", "RecursionError", "CancelledError"], + ) + async def test_non_swallowed_exception_propagates( + self, + exc_class: type[BaseException], + exc_args: tuple[str, ...], + ) -> None: + """Non-recoverable and cancellation exceptions propagate.""" + mock_te = _make_mock_task_engine( + side_effect=exc_class(*exc_args), + ) + + with pytest.raises(exc_class): + await sync_to_task_engine( + mock_te, + target_status=TaskStatus.IN_PROGRESS, + task_id="task-1", + agent_id="agent-1", + reason="test", + ) + + async def test_critical_flag_logs_at_error_level(self) -> None: + """critical=True escalates log severity to ERROR.""" + mock_te = _make_mock_task_engine( + side_effect=TaskEngineError("unavailable"), + ) + + with patch("ai_company.engine.task_sync.logger") as mock_logger: + await sync_to_task_engine( + mock_te, + target_status=TaskStatus.IN_PROGRESS, + task_id="task-1", + agent_id="agent-1", + reason="test", + critical=True, + ) + + mock_logger.error.assert_called_once() + mock_logger.warning.assert_not_called() + + +# =================================================================== +# transition_task_if_needed +# =================================================================== + + +@pytest.mark.unit +class TestTransitionTaskIfNeeded: + """Tests for ASSIGNED -> IN_PROGRESS pre-execution transition.""" + + async def test_assigned_transitions_to_in_progress( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """ASSIGNED task transitions to IN_PROGRESS and syncs.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + assert ctx.task_execution is not None + assert ctx.task_execution.status == TaskStatus.ASSIGNED + + mock_te = _make_mock_task_engine() + + result_ctx = await transition_task_if_needed( + ctx, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=mock_te, + ) + + assert result_ctx.task_execution is not None + assert result_ctx.task_execution.status == TaskStatus.IN_PROGRESS + mock_te.submit.assert_awaited_once() + assert mock_te.submit.call_args.args[0].target_status == TaskStatus.IN_PROGRESS + + async def test_in_progress_passes_through( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """IN_PROGRESS task is returned as-is (no sync).""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="already started") + + mock_te = _make_mock_task_engine() + + result_ctx = await transition_task_if_needed( + ctx, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=mock_te, + ) + + assert result_ctx.task_execution is not None + assert result_ctx.task_execution.status == TaskStatus.IN_PROGRESS + mock_te.submit.assert_not_awaited() + + async def test_no_task_execution_passes_through( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """Context without task_execution returns unchanged.""" + ctx = AgentContext.from_identity(sample_agent_with_personality) + assert ctx.task_execution is None + + mock_te = _make_mock_task_engine() + + result_ctx = await transition_task_if_needed( + ctx, + agent_id=str(sample_agent_with_personality.id), + task_id="irrelevant", + task_engine=mock_te, + ) + + assert result_ctx is ctx + mock_te.submit.assert_not_awaited() + + async def test_none_task_engine_still_transitions_locally( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Local transition works even when task_engine is None.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + result_ctx = await transition_task_if_needed( + ctx, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=None, + ) + + assert result_ctx.task_execution is not None + assert result_ctx.task_execution.status == TaskStatus.IN_PROGRESS + + +# =================================================================== +# apply_post_execution_transitions +# =================================================================== + + +@pytest.mark.unit +class TestApplyPostExecutionTransitions: + """Tests for post-execution transition logic.""" + + async def test_no_task_execution_returns_unchanged( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """Without task_execution, result is returned as-is.""" + ctx = AgentContext.from_identity(sample_agent_with_personality) + result = _make_execution_result(ctx) + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id="irrelevant", + task_engine=None, + ) + + assert out is result + + async def test_completed_transitions_through_in_review_to_completed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """COMPLETED termination: IN_PROGRESS -> IN_REVIEW -> COMPLETED.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=TerminationReason.COMPLETED) + + mock_te = _make_mock_task_engine() + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=mock_te, + ) + + assert out.context.task_execution is not None + assert out.context.task_execution.status == TaskStatus.COMPLETED + + # Two syncs: IN_REVIEW and COMPLETED + assert mock_te.submit.await_count == 2 + synced = [call.args[0].target_status for call in mock_te.submit.call_args_list] + assert synced == [TaskStatus.IN_REVIEW, TaskStatus.COMPLETED] + + async def test_shutdown_transitions_to_interrupted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """SHUTDOWN termination: current status -> INTERRUPTED.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=TerminationReason.SHUTDOWN) + + mock_te = _make_mock_task_engine() + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=mock_te, + ) + + assert out.context.task_execution is not None + assert out.context.task_execution.status == TaskStatus.INTERRUPTED + mock_te.submit.assert_awaited_once() + + @pytest.mark.parametrize( + "reason", + [TerminationReason.MAX_TURNS, TerminationReason.BUDGET_EXHAUSTED], + ids=["MAX_TURNS", "BUDGET_EXHAUSTED"], + ) + async def test_non_completion_reasons_return_unchanged( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + reason: TerminationReason, + ) -> None: + """Non-completion termination reasons leave task state unchanged.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=reason) + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=None, + ) + + assert out is result + + async def test_completed_partial_failure_returns_furthest_state( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Partial failure: IN_REVIEW succeeds locally, COMPLETED raises. + + The stepwise loop in ``apply_post_execution_transitions`` captures + the intermediate ``ctx`` after each successful step. When the + second step raises, the returned context reflects IN_REVIEW + (the furthest-reached state), not the original IN_PROGRESS. + """ + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=TerminationReason.COMPLETED) + + # Patch with_task_transition to raise on second call + call_count = 0 + original_transition = AgentContext.with_task_transition + + def patched_transition( + self: AgentContext, + target: TaskStatus, + *, + reason: str = "", + ) -> AgentContext: + nonlocal call_count + call_count += 1 + if call_count == 2: + msg = "Simulated transition failure" + raise ExecutionStateError(msg) + return original_transition(self, target, reason=reason) + + with patch.object(AgentContext, "with_task_transition", patched_transition): + mock_te = _make_mock_task_engine() + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=mock_te, + ) + + # Context reflects IN_REVIEW — the furthest-reached state + # before the second transition failed. + assert out.context.task_execution is not None + assert out.context.task_execution.status == TaskStatus.IN_REVIEW + # Result is a model_copy, not the bare original + assert out is not result + + async def test_shutdown_transition_failure_returns_original( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """SHUTDOWN: if INTERRUPTED transition fails, original result returned.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=TerminationReason.SHUTDOWN) + + def raise_on_transition( + self: AgentContext, + target: TaskStatus, + *, + reason: str = "", + ) -> AgentContext: + msg = "cannot interrupt" + raise ExecutionStateError(msg) + + with patch.object(AgentContext, "with_task_transition", raise_on_transition): + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=None, + ) + + # Original result returned when transition fails + assert out is result + assert out.context.task_execution is not None + assert out.context.task_execution.status == TaskStatus.IN_PROGRESS + + async def test_completed_with_none_task_engine( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """COMPLETED path works with task_engine=None (local only).""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=TerminationReason.COMPLETED) + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=None, + ) + + assert out.context.task_execution is not None + assert out.context.task_execution.status == TaskStatus.COMPLETED + + async def test_sync_failure_does_not_block_transitions( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Sync failures (rejected mutations) don't block local transitions.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=TerminationReason.COMPLETED) + + # All syncs fail but local transitions should still complete + mock_te = _make_mock_task_engine( + return_value=_make_sync_failure(), + ) + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=mock_te, + ) + + assert out.context.task_execution is not None + assert out.context.task_execution.status == TaskStatus.COMPLETED + + async def test_task_engine_exception_does_not_block_transitions( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """TaskEngineError from submit() doesn't block local transitions.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="started") + result = _make_execution_result(ctx, reason=TerminationReason.COMPLETED) + + mock_te = _make_mock_task_engine( + side_effect=TaskEngineError("engine down"), + ) + + out = await apply_post_execution_transitions( + result, + agent_id=str(sample_agent_with_personality.id), + task_id=sample_task_with_criteria.id, + task_engine=mock_te, + ) + + assert out.context.task_execution is not None + assert out.context.task_execution.status == TaskStatus.COMPLETED