From 194ef102b0043394120c976af7843d868c408532 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:17:44 +0100 Subject: [PATCH 01/14] feat: add centralized single-writer TaskEngine (#204) Implement actor-like TaskEngine that owns all task state mutations via asyncio.Queue. A single background task processes mutations sequentially using model_copy(update=...), persists results, and publishes snapshots to the message bus. Reads bypass the queue for direct persistence access. - Add TaskEngine core with start/stop lifecycle, submit/convenience methods - Add 5 mutation types (create, update, transition, delete, cancel) - Add TaskMutationResult response and TaskStateChanged event models - Add TaskEngineConfig (queue size, drain timeout, snapshot toggle) - Add 4 error types (TaskEngineError, NotRunning, Mutation, VersionConflict) - Wire into API controllers, AppState, app lifecycle, and config - Add optional AgentEngine report-back for terminal task status - Add 57 unit tests covering all mutations, ordering, versioning, drain Closes #204 --- src/ai_company/api/app.py | 39 +- src/ai_company/api/controllers/tasks.py | 107 ++- src/ai_company/api/state.py | 31 + src/ai_company/config/defaults.py | 1 + src/ai_company/config/schema.py | 5 + src/ai_company/engine/__init__.py | 32 + src/ai_company/engine/agent_engine.py | 54 +- src/ai_company/engine/errors.py | 16 + src/ai_company/engine/task_engine.py | 744 ++++++++++++++++ src/ai_company/engine/task_engine_config.py | 37 + src/ai_company/engine/task_engine_models.py | 250 ++++++ .../observability/events/task_engine.py | 16 + tests/unit/api/conftest.py | 24 +- tests/unit/api/test_app.py | 4 +- tests/unit/engine/test_task_engine.py | 838 ++++++++++++++++++ tests/unit/engine/test_task_engine_models.py | 312 +++++++ tests/unit/observability/test_events.py | 1 + 17 files changed, 2455 insertions(+), 56 deletions(-) create mode 100644 src/ai_company/engine/task_engine.py 
create mode 100644 src/ai_company/engine/task_engine_config.py create mode 100644 src/ai_company/engine/task_engine_models.py create mode 100644 src/ai_company/observability/events/task_engine.py create mode 100644 tests/unit/engine/test_task_engine.py create mode 100644 tests/unit/engine/test_task_engine_models.py diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 2d0eb7435d..8edddcbfc9 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -35,6 +35,7 @@ from ai_company.communication.bus_protocol import MessageBus # noqa: TC001 from ai_company.config.schema import RootConfig from ai_company.core.approval import ApprovalItem # noqa: TC001 +from ai_company.engine.task_engine import TaskEngine # noqa: TC001 from ai_company.observability import get_logger from ai_company.observability.events.api import ( API_APP_SHUTDOWN, @@ -102,6 +103,7 @@ def _build_lifecycle( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + task_engine: TaskEngine | None, app_state: AppState, ) -> tuple[ Sequence[Callable[[], Awaitable[None]]], @@ -115,11 +117,17 @@ def _build_lifecycle( async def on_startup() -> None: logger.info(API_APP_STARTUP, version=__version__) - await _safe_startup(persistence, message_bus, bridge, app_state) + await _safe_startup( + persistence, + message_bus, + bridge, + task_engine, + app_state, + ) async def on_shutdown() -> None: logger.info(API_APP_SHUTDOWN, version=__version__) - await _safe_shutdown(bridge, message_bus, persistence) + await _safe_shutdown(bridge, task_engine, message_bus, persistence) return [on_startup], [on_shutdown] @@ -196,9 +204,10 @@ async def _safe_startup( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + task_engine: TaskEngine | None, app_state: AppState, ) -> None: - """Connect persistence, resolve JWT secret, start message bus and bridge. 
+ """Connect persistence, resolve JWT secret, start bus, bridge, task engine. Executes in order; on failure, cleans up already-started components in reverse order before re-raising. @@ -239,6 +248,15 @@ async def _safe_startup( error="Failed to start message bus bridge", ) raise + if task_engine is not None: + try: + task_engine.start() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to start task engine", + ) + raise except Exception: await _cleanup_on_failure( persistence=persistence, @@ -251,10 +269,11 @@ async def _safe_startup( async def _safe_shutdown( bridge: MessageBusBridge | None, + task_engine: TaskEngine | None, message_bus: MessageBus | None, persistence: PersistenceBackend | None, ) -> None: - """Stop bridge, message bus and disconnect persistence.""" + """Stop bridge, task engine, message bus and disconnect persistence.""" if bridge is not None: try: await bridge.stop() @@ -263,6 +282,14 @@ async def _safe_shutdown( API_APP_SHUTDOWN, error="Failed to stop message bus bridge", ) + if task_engine is not None: + try: + await task_engine.stop() + except Exception: + logger.exception( + API_APP_SHUTDOWN, + error="Failed to stop task engine", + ) if message_bus is not None: try: await message_bus.stop() @@ -289,6 +316,7 @@ def create_app( # noqa: PLR0913 cost_tracker: CostTracker | None = None, approval_store: ApprovalStore | None = None, auth_service: AuthService | None = None, + task_engine: TaskEngine | None = None, ) -> Litestar: """Create and configure the Litestar application. @@ -302,6 +330,7 @@ def create_app( # noqa: PLR0913 cost_tracker: Cost tracking service. approval_store: Approval queue store. auth_service: Pre-built auth service (for testing). + task_engine: Centralized task state engine. Returns: Configured Litestar application. 
@@ -330,6 +359,7 @@ def create_app( # noqa: PLR0913 cost_tracker=cost_tracker, approval_store=effective_approval_store, auth_service=auth_service, + task_engine=task_engine, startup_time=time.monotonic(), ) @@ -347,6 +377,7 @@ def create_app( # noqa: PLR0913 persistence, message_bus, bridge, + task_engine, app_state, ) diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index c309969585..463aaa7970 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -1,6 +1,4 @@ -"""Task controller — full CRUD via TaskRepository.""" - -from uuid import uuid4 +"""Task controller — full CRUD via TaskEngine.""" from litestar import Controller, delete, get, patch, post from litestar.datastructures import State # noqa: TC002 @@ -17,7 +15,9 @@ from ai_company.api.pagination import PaginationLimit, PaginationOffset, paginate from ai_company.api.state import AppState # noqa: TC001 from ai_company.core.enums import TaskStatus # noqa: TC001 -from ai_company.core.task import Task +from ai_company.core.task import Task # noqa: TC001 +from ai_company.engine.errors import TaskMutationError +from ai_company.engine.task_engine_models import CreateTaskData from ai_company.observability import get_logger from ai_company.observability.events.api import ( API_RESOURCE_NOT_FOUND, @@ -33,7 +33,7 @@ class TaskController(Controller): - """Full CRUD for tasks via ``TaskRepository``.""" + """Full CRUD for tasks via ``TaskEngine``.""" path = "/tasks" tags = ("tasks",) @@ -63,7 +63,7 @@ async def list_tasks( # noqa: PLR0913 Paginated task list. """ app_state: AppState = state.app_state - tasks = await app_state.persistence.tasks.list_tasks( + tasks = await app_state.task_engine.list_tasks( status=status, assigned_to=assigned_to, project=project, @@ -90,7 +90,7 @@ async def get_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state - task = await app_state.persistence.tasks.get(task_id) + task = await app_state.task_engine.get_task(task_id) if task is None: msg = f"Task {task_id!r} not found" logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) @@ -113,9 +113,7 @@ async def create_task( Created task envelope. """ app_state: AppState = state.app_state - task_id = f"task-{uuid4().hex}" - task = Task( - id=task_id, + task_data = CreateTaskData( title=data.title, description=data.description, type=data.type, @@ -126,7 +124,10 @@ async def create_task( estimated_complexity=data.estimated_complexity, budget_limit=data.budget_limit, ) - await app_state.persistence.tasks.save(task) + task = await app_state.task_engine.create_task( + task_data, + requested_by=data.created_by, + ) logger.info( TASK_CREATED, task_id=task.id, @@ -155,17 +156,23 @@ async def update_task( NotFoundError: If the task is not found. """ app_state: AppState = state.app_state - task = await app_state.persistence.tasks.get(task_id) - if task is None: - msg = f"Task {task_id!r} not found" - logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) - raise NotFoundError(msg) - updates = data.model_dump(exclude_none=True) - if updates: - task = task.model_copy(update=updates) - await app_state.persistence.tasks.save(task) - logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) + try: + task = await app_state.task_engine.update_task( + task_id, + updates, + requested_by="api", + ) + except TaskMutationError as exc: + if "not found" in str(exc): + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(str(exc)) from exc + raise + logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) return ApiResponse(data=task) @post( @@ -192,33 +199,35 @@ async def transition_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state - task = await app_state.persistence.tasks.get(task_id) - if task is None: - msg = f"Task {task_id!r} not found" - logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) - raise NotFoundError(msg) - - overrides: dict[str, object] = {} - if data.assigned_to is not None: - overrides["assigned_to"] = data.assigned_to - try: - new_task = task.with_transition(data.target_status, **overrides) - except ValueError as exc: + task = await app_state.task_engine.transition_task( + task_id, + data.target_status, + requested_by="api", + reason=f"API transition to {data.target_status.value}", + assigned_to=data.assigned_to, + ) + except TaskMutationError as exc: + error_str = str(exc) + if "not found" in error_str: + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(error_str) from exc logger.warning( TASK_STATUS_CHANGED, task_id=task_id, - error=str(exc), + error=error_str, ) - raise ApiValidationError(str(exc)) from exc - await app_state.persistence.tasks.save(new_task) + raise ApiValidationError(error_str) from exc logger.info( TASK_STATUS_CHANGED, task_id=task_id, - from_status=task.status.value, - to_status=new_task.status.value, + to_status=task.status.value, ) - return ApiResponse(data=new_task) + return ApiResponse(data=task) @delete("/{task_id:str}", guards=[require_write_access], status_code=200) async def delete_task( @@ -239,10 +248,20 @@ async def delete_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state - deleted = await app_state.persistence.tasks.delete(task_id) - if not deleted: - msg = f"Task {task_id!r} not found" - logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) - raise NotFoundError(msg) + try: + await app_state.task_engine.delete_task( + task_id, + requested_by="api", + ) + except TaskMutationError as exc: + if "not found" in str(exc): + msg = f"Task {task_id!r} not found" + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(msg) from exc + raise logger.info(API_TASK_DELETED, task_id=task_id) return ApiResponse(data=None) diff --git a/src/ai_company/api/state.py b/src/ai_company/api/state.py index 2ffd80dab2..cd113f4619 100644 --- a/src/ai_company/api/state.py +++ b/src/ai_company/api/state.py @@ -11,6 +11,7 @@ from ai_company.budget.tracker import CostTracker # noqa: TC001 from ai_company.communication.bus_protocol import MessageBus # noqa: TC001 from ai_company.config.schema import RootConfig # noqa: TC001 +from ai_company.engine.task_engine import TaskEngine # noqa: TC001 from ai_company.observability import get_logger from ai_company.observability.events.api import API_APP_STARTUP, API_SERVICE_UNAVAILABLE from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 @@ -39,6 +40,7 @@ class AppState: "_cost_tracker", "_message_bus", "_persistence", + "_task_engine", "approval_store", "config", "startup_time", @@ -53,6 +55,7 @@ def __init__( # noqa: PLR0913 message_bus: MessageBus | None = None, cost_tracker: CostTracker | None = None, auth_service: AuthService | None = None, + task_engine: TaskEngine | None = None, startup_time: float = 0.0, ) -> None: self.config = config @@ -61,6 +64,7 @@ def __init__( # noqa: PLR0913 self._message_bus = message_bus self._cost_tracker = cost_tracker self._auth_service = auth_service + self._task_engine = task_engine self.startup_time = startup_time def _require_service[T](self, service: T | 
None, name: str) -> T: @@ -99,6 +103,33 @@ def auth_service(self) -> AuthService: """Return auth service or raise 503.""" return self._require_service(self._auth_service, "auth_service") + @property + def task_engine(self) -> TaskEngine: + """Return task engine or raise 503.""" + return self._require_service(self._task_engine, "task_engine") + + @property + def has_task_engine(self) -> bool: + """Check whether the task engine is already configured.""" + return self._task_engine is not None + + def set_task_engine(self, engine: TaskEngine) -> None: + """Set the task engine (deferred initialisation). + + Called once during startup after persistence is connected. + + Args: + engine: Fully configured task engine. + + Raises: + RuntimeError: If the task engine was already configured. + """ + if self._task_engine is not None: + msg = "Task engine already configured" + logger.error(API_APP_STARTUP, error=msg) + raise RuntimeError(msg) + self._task_engine = engine + @property def has_auth_service(self) -> bool: """Check whether the auth service is already configured.""" diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index 02145cd16e..579d1fecbd 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -40,4 +40,5 @@ def default_config_dict() -> dict[str, Any]: "security": {}, "trust": {}, "promotion": {}, + "task_engine": {}, } diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index 39beb30bc2..d73869a67f 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -21,6 +21,7 @@ from ai_company.core.resilience_config import RateLimiterConfig, RetryConfig from ai_company.core.role import CustomRole # noqa: TC001 from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.task_engine_config import TaskEngineConfig from ai_company.hr.promotion.config import PromotionConfig from ai_company.memory.config import CompanyMemoryConfig 
from ai_company.memory.org.config import OrgMemoryConfig @@ -521,6 +522,10 @@ class RootConfig(BaseModel): default_factory=PromotionConfig, description="Promotion/demotion configuration", ) + task_engine: TaskEngineConfig = Field( + default_factory=TaskEngineConfig, + description="Task engine configuration", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index d9d43940c8..0d6565aed9 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -67,7 +67,11 @@ PromptBuildError, ResourceConflictError, TaskAssignmentError, + TaskEngineError, + TaskEngineNotRunningError, + TaskMutationError, TaskRoutingError, + TaskVersionConflictError, WorkspaceCleanupError, WorkspaceError, WorkspaceLimitError, @@ -128,6 +132,19 @@ ShutdownResult, ShutdownStrategy, ) +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskData, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, + UpdateTaskMutation, +) from ai_company.engine.task_execution import StatusTransition, TaskExecution from ai_company.engine.workspace import ( MergeConflict, @@ -168,10 +185,13 @@ "AuctionAssignmentStrategy", "AutoTopologyConfig", "BudgetChecker", + "CancelTaskMutation", "ClassificationResult", "CleanupCallback", "CooperativeTimeoutStrategy", "CostOptimizedAssignmentStrategy", + "CreateTaskData", + "CreateTaskMutation", "DecompositionContext", "DecompositionCycleError", "DecompositionDepthError", @@ -181,6 +201,7 @@ "DecompositionService", "DecompositionStrategy", "DefaultTokenEstimator", + "DeleteTaskMutation", "DependencyGraph", "EngineError", "ErrorFinding", @@ -239,13 +260,24 @@ "TaskAssignmentService", "TaskAssignmentStrategy", 
"TaskCompletionMetrics", + "TaskEngine", + "TaskEngineConfig", + "TaskEngineError", + "TaskEngineNotRunningError", "TaskExecution", + "TaskMutation", + "TaskMutationError", + "TaskMutationResult", "TaskRoutingError", "TaskRoutingService", + "TaskStateChanged", "TaskStructureClassifier", + "TaskVersionConflictError", "TerminationReason", "TopologySelector", + "TransitionTaskMutation", "TurnRecord", + "UpdateTaskMutation", "Workspace", "WorkspaceCleanupError", "WorkspaceError", diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 3915fe78e5..a032b748d4 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -84,6 +84,7 @@ ExecutionLoop, ShutdownChecker, ) + from ai_company.engine.task_engine import TaskEngine from ai_company.providers.models import CompletionConfig from ai_company.providers.protocol import CompletionProvider from ai_company.security.config import SecurityConfig @@ -133,6 +134,7 @@ def __init__( # noqa: PLR0913 budget_enforcer: BudgetEnforcer | None = None, security_config: SecurityConfig | None = None, approval_store: ApprovalStore | None = None, + task_engine: TaskEngine | None = None, ) -> None: self._provider = provider self._loop: ExecutionLoop = execution_loop or ReactLoop() @@ -154,6 +156,7 @@ def __init__( # noqa: PLR0913 self._cost_tracker = cost_tracker self._security_config = security_config self._approval_store = approval_store + self._task_engine = task_engine self._recovery_strategy = recovery_strategy self._shutdown_checker = shutdown_checker self._error_taxonomy_config = error_taxonomy_config @@ -336,7 +339,7 @@ async def _post_execution_pipeline( agent_id: str, task_id: str, ) -> ExecutionResult: - """Record costs, apply transitions, run recovery and classify.""" + """Record costs, apply transitions, report to TaskEngine.""" await record_execution_costs( execution_result, identity, @@ -349,6 +352,7 @@ async def _post_execution_pipeline( agent_id, 
task_id, ) + await self._report_to_task_engine(execution_result, agent_id, task_id) if execution_result.termination_reason == TerminationReason.ERROR: execution_result = await self._apply_recovery( execution_result, @@ -636,6 +640,54 @@ def _transition_to_interrupted( ) return execution_result + async def _report_to_task_engine( + self, + execution_result: ExecutionResult, + agent_id: str, + task_id: str, + ) -> None: + """Report final execution status to the centralized TaskEngine. + + Best-effort: failures are logged and swallowed. If no + ``TaskEngine`` is configured, this is a no-op. + """ + if self._task_engine is None: + return + ctx = execution_result.context + if ctx.task_execution is None: + return + + final_status = ctx.task_execution.status + terminal = { + TaskStatus.COMPLETED, + TaskStatus.FAILED, + TaskStatus.INTERRUPTED, + TaskStatus.CANCELLED, + } + if final_status not in terminal: + return + + try: + await self._task_engine.transition_task( + task_id, + final_status, + requested_by=agent_id, + reason=( + "AgentEngine execution ended: " + f"{execution_result.termination_reason.value}" + ), + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.warning( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error="Failed to report final status to TaskEngine", + exc_info=True, + ) + async def _apply_recovery( self, execution_result: ExecutionResult, diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index 7f5fdb7bce..c4101439c7 100644 --- a/src/ai_company/engine/errors.py +++ b/src/ai_company/engine/errors.py @@ -80,3 +80,19 @@ class WorkspaceCleanupError(WorkspaceError): class WorkspaceLimitError(WorkspaceError): """Raised when maximum concurrent workspaces reached.""" + + +class TaskEngineError(EngineError): + """Base exception for all task engine errors.""" + + +class TaskEngineNotRunningError(TaskEngineError): + """Raised when a mutation is submitted to a stopped task engine.""" + + 
+class TaskMutationError(TaskEngineError): + """Raised when a task mutation fails (not found, validation, etc.).""" + + +class TaskVersionConflictError(TaskMutationError): + """Raised when optimistic concurrency version does not match.""" diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py new file mode 100644 index 0000000000..318a2a7ffa --- /dev/null +++ b/src/ai_company/engine/task_engine.py @@ -0,0 +1,744 @@ +"""Centralized single-writer task engine. + +Owns all task state mutations via an ``asyncio.Queue``. A single +background task consumes mutation requests sequentially, applies +``model_copy(update=...)`` on frozen ``Task`` models, persists the +result, and publishes snapshots to the message bus. + +Reads bypass the queue and go directly to persistence — this is safe +because the TaskEngine is the only writer. +""" + +import asyncio +import contextlib +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +from ai_company.core.enums import TaskStatus +from ai_company.core.task import Task +from ai_company.engine.errors import ( + TaskEngineNotRunningError, + TaskMutationError, + TaskVersionConflictError, +) +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskData, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, + UpdateTaskMutation, +) +from ai_company.observability import get_logger +from ai_company.observability.events.task_engine import ( + TASK_ENGINE_DRAIN_COMPLETE, + TASK_ENGINE_DRAIN_START, + TASK_ENGINE_DRAIN_TIMEOUT, + TASK_ENGINE_MUTATION_APPLIED, + TASK_ENGINE_MUTATION_FAILED, + TASK_ENGINE_MUTATION_RECEIVED, + TASK_ENGINE_NOT_RUNNING, + TASK_ENGINE_QUEUE_FULL, + TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED, + TASK_ENGINE_SNAPSHOT_PUBLISHED, + 
TASK_ENGINE_STARTED, + TASK_ENGINE_STOPPED, +) + +if TYPE_CHECKING: + from ai_company.communication.bus_protocol import MessageBus + from ai_company.persistence.protocol import PersistenceBackend + +logger = get_logger(__name__) + + +@dataclass +class _MutationEnvelope: + """Pairs a mutation request with its response future.""" + + mutation: TaskMutation + future: asyncio.Future[TaskMutationResult] = field( + default_factory=lambda: asyncio.get_running_loop().create_future(), + ) + + +class TaskEngine: + """Centralized single-writer for all task state mutations. + + Uses an actor-like pattern: a single background ``asyncio.Task`` + consumes ``TaskMutation`` requests from an ``asyncio.Queue``, + applies each mutation sequentially, persists the result, and + publishes state-change snapshots. + + Args: + persistence: Backend for task storage. + message_bus: Optional bus for snapshot publication. + config: Engine configuration. + """ + + def __init__( + self, + *, + persistence: PersistenceBackend, + message_bus: MessageBus | None = None, + config: TaskEngineConfig | None = None, + ) -> None: + self._persistence = persistence + self._message_bus = message_bus + self._config = config or TaskEngineConfig() + self._queue: asyncio.Queue[_MutationEnvelope] = asyncio.Queue( + maxsize=self._config.max_queue_size, + ) + self._versions: dict[str, int] = {} + self._processing_task: asyncio.Task[None] | None = None + self._running = False + + # ── Lifecycle ───────────────────────────────────────────── + + def start(self) -> None: + """Spawn the background processing loop. + + Raises: + RuntimeError: If already running. 
+ """ + if self._running: + msg = "TaskEngine is already running" + raise RuntimeError(msg) + self._running = True + self._processing_task = asyncio.create_task( + self._processing_loop(), + name="task-engine-loop", + ) + logger.info( + TASK_ENGINE_STARTED, + max_queue_size=self._config.max_queue_size, + ) + + async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 + """Stop the engine and drain pending mutations. + + Args: + timeout: Seconds to wait for drain. Defaults to + ``config.drain_timeout_seconds``. + """ + if not self._running: + return + self._running = False + effective_timeout = ( + timeout if timeout is not None else self._config.drain_timeout_seconds + ) + + if self._processing_task is not None: + logger.info( + TASK_ENGINE_DRAIN_START, + pending=self._queue.qsize(), + timeout_seconds=effective_timeout, + ) + try: + await asyncio.wait_for( + self._processing_task, + timeout=effective_timeout, + ) + logger.info(TASK_ENGINE_DRAIN_COMPLETE) + except TimeoutError: + logger.warning( + TASK_ENGINE_DRAIN_TIMEOUT, + remaining=self._queue.qsize(), + ) + self._processing_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._processing_task + self._processing_task = None + + logger.info(TASK_ENGINE_STOPPED) + + @property + def is_running(self) -> bool: + """Whether the engine is accepting mutations.""" + return self._running + + # ── Submit & convenience methods ────────────────────────── + + async def submit(self, mutation: TaskMutation) -> TaskMutationResult: + """Submit a mutation and await its result. + + Args: + mutation: The mutation to apply. + + Returns: + Result of the mutation. + + Raises: + TaskEngineNotRunningError: If the engine is not running. 
+ """ + if not self._running: + logger.warning( + TASK_ENGINE_NOT_RUNNING, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + ) + msg = "TaskEngine is not running" + raise TaskEngineNotRunningError(msg) + + envelope = _MutationEnvelope(mutation=mutation) + try: + self._queue.put_nowait(envelope) + except asyncio.QueueFull: + logger.warning( + TASK_ENGINE_QUEUE_FULL, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + queue_size=self._queue.qsize(), + ) + msg = "TaskEngine queue is full" + raise TaskEngineNotRunningError(msg) from None + + return await envelope.future + + async def create_task( + self, + data: CreateTaskData, + *, + requested_by: str, + ) -> Task: + """Convenience: create a task and return the created Task. + + Args: + data: Task creation data. + requested_by: Identity of the requester. + + Returns: + The created task. + + Raises: + TaskEngineNotRunningError: If the engine is not running. + TaskMutationError: If the mutation fails. + """ + mutation = CreateTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_data=data, + ) + result = await self.submit(mutation) + if not result.success: + raise TaskMutationError(result.error or "Create failed") + assert result.task is not None # noqa: S101 + return result.task + + async def update_task( + self, + task_id: str, + updates: dict[str, object], + *, + requested_by: str, + expected_version: int | None = None, + ) -> Task: + """Convenience: update task fields and return the updated Task. + + Args: + task_id: Target task identifier. + updates: Field-value pairs to apply. + requested_by: Identity of the requester. + expected_version: Optional optimistic concurrency version. + + Returns: + The updated task. + + Raises: + TaskEngineNotRunningError: If the engine is not running. + TaskMutationError: If the mutation fails. 
+ """ + mutation = UpdateTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + updates=updates, + expected_version=expected_version, + ) + result = await self.submit(mutation) + if not result.success: + raise TaskMutationError(result.error or "Update failed") + assert result.task is not None # noqa: S101 + return result.task + + async def transition_task( + self, + task_id: str, + target_status: TaskStatus, + *, + requested_by: str, + reason: str = "", + expected_version: int | None = None, + **overrides: object, + ) -> Task: + """Convenience: transition task status and return the updated Task. + + Args: + task_id: Target task identifier. + target_status: Desired target status. + requested_by: Identity of the requester. + reason: Reason for the transition. + expected_version: Optional optimistic concurrency version. + **overrides: Additional field overrides for the transition. + + Returns: + The transitioned task. + + Raises: + TaskEngineNotRunningError: If the engine is not running. + TaskMutationError: If the mutation fails. + """ + effective_reason = reason or f"Transition to {target_status.value}" + mutation = TransitionTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + target_status=target_status, + reason=effective_reason, + overrides=dict(overrides), + expected_version=expected_version, + ) + result = await self.submit(mutation) + if not result.success: + raise TaskMutationError(result.error or "Transition failed") + assert result.task is not None # noqa: S101 + return result.task + + async def delete_task( + self, + task_id: str, + *, + requested_by: str, + ) -> bool: + """Convenience: delete a task and return success. + + Args: + task_id: Target task identifier. + requested_by: Identity of the requester. + + Returns: + ``True`` if the task was deleted. + + Raises: + TaskEngineNotRunningError: If the engine is not running. + TaskMutationError: If the mutation fails. 
+ """ + mutation = DeleteTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + ) + result = await self.submit(mutation) + if not result.success: + raise TaskMutationError(result.error or "Delete failed") + return True + + async def cancel_task( + self, + task_id: str, + *, + requested_by: str, + reason: str, + ) -> Task: + """Convenience: cancel a task and return the cancelled Task. + + Args: + task_id: Target task identifier. + requested_by: Identity of the requester. + reason: Reason for cancellation. + + Returns: + The cancelled task. + + Raises: + TaskEngineNotRunningError: If the engine is not running. + TaskMutationError: If the mutation fails. + """ + mutation = CancelTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + reason=reason, + ) + result = await self.submit(mutation) + if not result.success: + raise TaskMutationError(result.error or "Cancel failed") + assert result.task is not None # noqa: S101 + return result.task + + # ── Read-through (bypass queue) ─────────────────────────── + + async def get_task(self, task_id: str) -> Task | None: + """Read a task directly from persistence (bypass queue). + + Args: + task_id: Task identifier. + + Returns: + The task, or ``None`` if not found. + """ + return await self._persistence.tasks.get(task_id) + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + """List tasks directly from persistence (bypass queue). + + Args: + status: Filter by status. + assigned_to: Filter by assignee. + project: Filter by project. + + Returns: + Matching tasks as a tuple. 
+ """ + return await self._persistence.tasks.list_tasks( + status=status, + assigned_to=assigned_to, + project=project, + ) + + # ── Background processing ───────────────────────────────── + + async def _processing_loop(self) -> None: + """Background loop: dequeue and process mutations sequentially.""" + while self._running or not self._queue.empty(): + try: + envelope = await asyncio.wait_for( + self._queue.get(), + timeout=0.5, + ) + except TimeoutError: + continue + await self._process_one(envelope) + + async def _process_one(self, envelope: _MutationEnvelope) -> None: + """Process a single mutation envelope.""" + mutation = envelope.mutation + logger.debug( + TASK_ENGINE_MUTATION_RECEIVED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + ) + try: + result = await self._apply_mutation(mutation) + envelope.future.set_result(result) + if result.success and self._config.publish_snapshots: + await self._publish_snapshot(mutation, result) + except Exception as exc: + error_msg = f"{type(exc).__name__}: {exc}" + logger.exception( + TASK_ENGINE_MUTATION_FAILED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + error=error_msg, + ) + if not envelope.future.done(): + envelope.future.set_result( + TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=error_msg, + ), + ) + + async def _apply_mutation(self, mutation: TaskMutation) -> TaskMutationResult: + """Dispatch and apply a mutation by type.""" + match mutation: + case CreateTaskMutation(): + return await self._apply_create(mutation) + case UpdateTaskMutation(): + return await self._apply_update(mutation) + case TransitionTaskMutation(): + return await self._apply_transition(mutation) + case DeleteTaskMutation(): + return await self._apply_delete(mutation) + case CancelTaskMutation(): + return await self._apply_cancel(mutation) + + async def _apply_create( + self, + mutation: CreateTaskMutation, + ) -> TaskMutationResult: + """Create a new 
async def _apply_update(
    self,
    mutation: UpdateTaskMutation,
) -> TaskMutationResult:
    """Apply field updates to an existing task, re-validating the result.

    ``model_copy(update=...)`` does not run validation, so the updated
    task is rebuilt through ``model_validate`` instead. Invalid values,
    unknown field names rejected by the model, and attempts to change the
    task ``id`` are returned as failed results and nothing is persisted.

    Args:
        mutation: The update request (target id, field updates and an
            optional expected version).

    Returns:
        A successful result carrying the updated task and its new version,
        or a failed result with an error description.

    Raises:
        TaskVersionConflictError: Propagated from ``_check_version`` when
            ``expected_version`` does not match the tracked version.
    """
    task = await self._persistence.tasks.get(mutation.task_id)
    if task is None:
        return TaskMutationResult(
            request_id=mutation.request_id,
            success=False,
            error=f"Task {mutation.task_id!r} not found",
        )

    self._check_version(mutation.task_id, mutation.expected_version)

    if not mutation.updates:
        # Nothing to change: report the current state without bumping.
        version = self._versions.get(mutation.task_id, 0)
        return TaskMutationResult(
            request_id=mutation.request_id,
            success=True,
            task=task,
            version=version,
        )

    # Changing the id would save a second record under a new key while the
    # version counter is bumped for the old one.
    if mutation.updates.get("id", task.id) != task.id:
        return TaskMutationResult(
            request_id=mutation.request_id,
            success=False,
            error=f"Task {mutation.task_id!r}: field 'id' cannot be updated",
        )

    model_cls = type(task)
    merged: dict[str, object] = {
        name: getattr(task, name) for name in model_cls.model_fields
    }
    merged.update(mutation.updates)
    try:
        updated = model_cls.model_validate(merged)
    except ValueError as exc:  # pydantic's ValidationError is a ValueError
        return TaskMutationResult(
            request_id=mutation.request_id,
            success=False,
            error=str(exc),
        )

    await self._persistence.tasks.save(updated)
    version = self._bump_version(mutation.task_id)

    logger.info(
        TASK_ENGINE_MUTATION_APPLIED,
        mutation_type="update",
        request_id=mutation.request_id,
        task_id=mutation.task_id,
        fields=list(mutation.updates),
    )
    return TaskMutationResult(
        request_id=mutation.request_id,
        success=True,
        task=updated,
        version=version,
    )
async def _apply_delete(
    self,
    mutation: DeleteTaskMutation,
) -> TaskMutationResult:
    """Remove a task from persistence and drop its version counter.

    Args:
        mutation: The delete request.

    Returns:
        A successful result with ``version=0`` when the repository reports
        a deletion, otherwise a failed "not found" result.
    """
    target_id = mutation.task_id
    removed = await self._persistence.tasks.delete(target_id)
    if removed:
        # The task is gone, so its in-memory version entry goes with it.
        self._versions.pop(target_id, None)
        logger.info(
            TASK_ENGINE_MUTATION_APPLIED,
            mutation_type="delete",
            request_id=mutation.request_id,
            task_id=target_id,
        )
        return TaskMutationResult(
            request_id=mutation.request_id,
            success=True,
            version=0,
        )
    return TaskMutationResult(
        request_id=mutation.request_id,
        success=False,
        error=f"Task {target_id!r} not found",
    )
# ── Snapshot publishing ───────────────────────────────────


async def _publish_snapshot(
    self,
    mutation: TaskMutation,
    result: TaskMutationResult,
) -> None:
    """Publish a ``TaskStateChanged`` event to the message bus.

    Best-effort: building the event, serialising it and publishing it all
    happen inside one guarded section, so any failure is logged and
    swallowed. Does nothing when no message bus is configured.

    Args:
        mutation: The mutation that was applied.
        result: The result produced for that mutation.
    """
    if self._message_bus is None:
        return

    try:
        from ai_company.communication.enums import MessageType  # noqa: PLC0415
        from ai_company.communication.message import Message  # noqa: PLC0415

        # One timestamp shared by the event and its envelope.
        published_at = datetime.now(UTC)

        task = result.task
        deleted = isinstance(mutation, DeleteTaskMutation)
        new_status = None if deleted or task is None else task.status

        event = TaskStateChanged(
            mutation_type=mutation.mutation_type,
            request_id=mutation.request_id,
            requested_by=mutation.requested_by,
            task=task,
            # NOTE(review): the result does not carry the pre-mutation
            # status, so this is always None for now.
            previous_status=None,
            new_status=new_status,
            version=result.version,
            timestamp=published_at,
        )
        msg = Message(
            timestamp=published_at,
            sender="task-engine",
            to="task_engine",
            type=MessageType.TASK_UPDATE,
            channel="task_engine",
            content=event.model_dump_json(),
        )
        await self._message_bus.publish(msg)
        logger.debug(
            TASK_ENGINE_SNAPSHOT_PUBLISHED,
            mutation_type=mutation.mutation_type,
            request_id=mutation.request_id,
        )
    except Exception:
        logger.warning(
            TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED,
            mutation_type=mutation.mutation_type,
            request_id=mutation.request_id,
            exc_info=True,
        )
# ── Version tracking ──────────────────────────────────────


def _bump_version(self, task_id: str) -> int:
    """Advance the in-memory version counter of ``task_id`` by one.

    A task without an entry in ``self._versions`` counts as version 0, so
    its first bump yields 1.

    Returns:
        The new version, which is also stored in ``self._versions``.
    """
    bumped = self._versions.get(task_id, 0) + 1
    self._versions[task_id] = bumped
    return bumped


def _check_version(
    self,
    task_id: str,
    expected_version: int | None,
) -> None:
    """Enforce optimistic concurrency when a version was supplied.

    Args:
        task_id: Task whose tracked version is compared.
        expected_version: Version the caller believes is current, or
            ``None`` to skip the check entirely.

    Raises:
        TaskVersionConflictError: If the tracked version (0 when the task
            has no entry) differs from ``expected_version``.
    """
    if expected_version is None:
        return
    current = self._versions.get(task_id, 0)
    if current == expected_version:
        return
    msg = (
        f"Version conflict for task {task_id!r}: "
        f"expected {expected_version}, current {current}"
    )
    raise TaskVersionConflictError(msg)
class CreateTaskData(BaseModel):
    """Data required to create a new task (server-generated fields excluded).

    Mirrors :class:`~ai_company.api.dto.CreateTaskRequest` but lives in
    the engine layer so it has no dependency on the API.

    Non-finite floats are rejected (``allow_inf_nan=False``) so that
    ``budget_limit`` cannot be ``inf``; ``ge=0.0`` alone would let it
    through.

    Attributes:
        title: Short task title.
        description: Detailed task description.
        type: Task work type.
        priority: Task priority level.
        project: Project ID.
        created_by: Agent name of the creator.
        assigned_to: Optional assignee agent ID.
        estimated_complexity: Complexity estimate.
        budget_limit: Maximum USD spend.
    """

    model_config = ConfigDict(frozen=True, allow_inf_nan=False)

    title: NotBlankStr = Field(description="Short task title")
    description: NotBlankStr = Field(description="Detailed task description")
    type: TaskType = Field(description="Task work type")
    priority: Priority = Field(default=Priority.MEDIUM, description="Task priority")
    project: NotBlankStr = Field(description="Project ID")
    created_by: NotBlankStr = Field(description="Agent name of the creator")
    assigned_to: NotBlankStr | None = Field(
        default=None,
        description="Assignee agent ID",
    )
    estimated_complexity: Complexity = Field(
        default=Complexity.MEDIUM,
        description="Complexity estimate",
    )
    budget_limit: float = Field(
        default=0.0,
        ge=0.0,
        description="Maximum USD spend",
    )


# ── Mutation requests ─────────────────────────────────────────


class CreateTaskMutation(BaseModel):
    """Request to create a new task.

    Attributes:
        mutation_type: Discriminator literal, always ``"create"``.
        request_id: Unique request identifier for tracing.
        requested_by: Identity of the requester.
        task_data: Task creation payload.
    """

    model_config = ConfigDict(frozen=True)

    mutation_type: Literal["create"] = "create"
    request_id: NotBlankStr = Field(description="Unique request identifier")
    requested_by: NotBlankStr = Field(description="Identity of the requester")
    task_data: CreateTaskData = Field(description="Task creation payload")
class TransitionTaskMutation(BaseModel):
    """Immutable request asking the engine to move a task to a new status.

    Attributes:
        mutation_type: Discriminator literal, always ``"transition"``.
        request_id: Unique request identifier for tracing.
        requested_by: Identity of the requester.
        task_id: Target task identifier.
        target_status: Status the task should end up in.
        reason: Non-blank reason for the transition.
        overrides: Extra field overrides applied with the transition;
            defaults to an empty dict.
        expected_version: Optional optimistic concurrency version
            (``>= 1`` when given).
    """

    model_config = ConfigDict(frozen=True)

    # Discriminator
    mutation_type: Literal["transition"] = "transition"

    # Tracing / audit
    request_id: NotBlankStr = Field(description="Unique request identifier")
    requested_by: NotBlankStr = Field(description="Identity of the requester")

    # Payload
    task_id: NotBlankStr = Field(description="Target task identifier")
    target_status: TaskStatus = Field(description="Desired target status")
    reason: NotBlankStr = Field(description="Reason for the transition")
    overrides: dict[str, object] = Field(
        default_factory=dict,
        description="Additional field overrides",
    )
    expected_version: int | None = Field(
        default=None,
        ge=1,
        description="Optional optimistic concurrency version",
    )
class CancelTaskMutation(BaseModel):
    """Immutable request to cancel a task (transition to CANCELLED).

    Attributes:
        mutation_type: Discriminator literal, always ``"cancel"``.
        request_id: Unique request identifier for tracing.
        requested_by: Identity of the requester.
        task_id: Target task identifier.
        reason: Non-blank reason for cancellation.
    """

    model_config = ConfigDict(frozen=True)

    # Discriminator
    mutation_type: Literal["cancel"] = "cancel"

    # Tracing / audit
    request_id: NotBlankStr = Field(description="Unique request identifier")
    requested_by: NotBlankStr = Field(description="Identity of the requester")

    # Payload
    task_id: NotBlankStr = Field(description="Target task identifier")
    reason: NotBlankStr = Field(description="Reason for cancellation")


TaskMutation = (
    CreateTaskMutation
    | UpdateTaskMutation
    | TransitionTaskMutation
    | DeleteTaskMutation
    | CancelTaskMutation
)
"""Union of all task mutation request types."""
class TaskStateChanged(BaseModel):
    """Event published to the message bus after each successful mutation.

    ``mutation_type`` is restricted to the discriminator values used by
    the mutation request models, so a mistyped value is rejected when the
    event is built.

    Attributes:
        mutation_type: Type of mutation that triggered the event.
        request_id: Originating request identifier.
        requested_by: Identity of the requester.
        task: Task snapshot after mutation (``None`` on delete).
        previous_status: Status before the mutation (``None`` on create).
        new_status: Status after the mutation (``None`` on delete).
        version: Version counter after mutation.
        timestamp: When the mutation was applied; defaults to the current
            UTC time.
    """

    model_config = ConfigDict(frozen=True)

    mutation_type: Literal[
        "create",
        "update",
        "transition",
        "delete",
        "cancel",
    ] = Field(description="Mutation type that triggered event")
    request_id: NotBlankStr = Field(description="Originating request identifier")
    requested_by: NotBlankStr = Field(description="Identity of the requester")
    task: Task | None = Field(
        default=None,
        description="Task snapshot after mutation",
    )
    previous_status: TaskStatus | None = Field(
        default=None,
        description="Status before mutation",
    )
    new_status: TaskStatus | None = Field(
        default=None,
        description="Status after mutation",
    )
    version: int = Field(ge=0, description="Version counter after mutation")
    timestamp: AwareDatetime = Field(
        default_factory=lambda: datetime.now(UTC),
        description="When the mutation was applied",
    )
b/tests/unit/api/conftest.py index 32baeee0ce..aa7befec2b 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -2,6 +2,7 @@ import asyncio import uuid +from collections.abc import Generator # noqa: TC003 from datetime import UTC, datetime, timedelta from typing import Any @@ -26,6 +27,7 @@ TaskStatus, ) from ai_company.core.task import Task +from ai_company.engine.task_engine import TaskEngine from ai_company.persistence.errors import DuplicateRecordError, QueryError from ai_company.security.models import AuditEntry, AuditVerdictStr # noqa: TC001 from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 @@ -604,6 +606,16 @@ def root_config() -> RootConfig: return RootConfig(company_name="test-company") +@pytest.fixture +def fake_task_engine( + fake_persistence: FakePersistenceBackend, +) -> TaskEngine: + """TaskEngine backed by the shared fake persistence.""" + return TaskEngine( + persistence=fake_persistence, + ) + + @pytest.fixture def test_client( # noqa: PLR0913 fake_persistence: FakePersistenceBackend, @@ -612,7 +624,8 @@ def test_client( # noqa: PLR0913 approval_store: ApprovalStore, root_config: RootConfig, auth_service: AuthService, -) -> TestClient[Any]: + fake_task_engine: TaskEngine, +) -> Generator[TestClient[Any]]: # Pre-seed users for each role so JWT sub claims resolve _seed_test_users(fake_persistence, auth_service) @@ -623,11 +636,12 @@ def test_client( # noqa: PLR0913 cost_tracker=cost_tracker, approval_store=approval_store, auth_service=auth_service, + task_engine=fake_task_engine, ) - client = TestClient(app) - # Default: CEO token (most tests need write access) - client.headers.update(make_auth_headers("ceo")) - return client + with TestClient(app) as client: + # Default: CEO token (most tests need write access) + client.headers.update(make_auth_headers("ceo")) + yield client def _seed_test_users( diff --git a/tests/unit/api/test_app.py b/tests/unit/api/test_app.py index be172178fb..114dfb89db 
# ── Fakes ─────────────────────────────────────────────────────


class FakeTaskRepository:
    """Dict-backed stand-in for the task repository used by engine tests."""

    def __init__(self) -> None:
        self._tasks: dict[str, Task] = {}

    async def save(self, task: Task) -> None:
        """Store ``task`` under its ``id``, replacing any previous entry."""
        self._tasks[task.id] = task

    async def get(self, task_id: str) -> Task | None:
        """Return the stored task or ``None`` when the id is unknown."""
        return self._tasks.get(task_id)

    async def list_tasks(
        self,
        *,
        status: TaskStatus | None = None,
        assigned_to: str | None = None,
        project: str | None = None,
    ) -> tuple[Task, ...]:
        """Return stored tasks matching every filter that is not ``None``."""
        return tuple(
            task
            for task in self._tasks.values()
            if (status is None or task.status == status)
            and (assigned_to is None or task.assigned_to == assigned_to)
            and (project is None or task.project == project)
        )

    async def delete(self, task_id: str) -> bool:
        """Remove a task; report whether something was actually removed."""
        if task_id in self._tasks:
            del self._tasks[task_id]
            return True
        return False


class FakePersistence:
    """Persistence backend double exposing only a task repository."""

    def __init__(self) -> None:
        self._tasks = FakeTaskRepository()

    @property
    def tasks(self) -> FakeTaskRepository:
        return self._tasks


class FakeMessageBus:
    """Message bus double that remembers every published message."""

    def __init__(self) -> None:
        self.published: list[object] = []
        self._running = False

    async def start(self) -> None:
        self._running = True

    async def stop(self) -> None:
        self._running = False

    @property
    def is_running(self) -> bool:
        return self._running

    async def publish(self, message: object) -> None:
        self.published.append(message)


class FailingMessageBus(FakeMessageBus):
    """Message bus double whose ``publish`` always raises ``RuntimeError``."""

    async def publish(self, message: object) -> None:
        msg = "Publish failed"
        raise RuntimeError(msg)
@pytest.fixture
def message_bus() -> FakeMessageBus:
    """Fresh recording message bus."""
    return FakeMessageBus()


@pytest.fixture
def config() -> TaskEngineConfig:
    """Engine configuration with a queue capped at 100 entries."""
    return TaskEngineConfig(max_queue_size=100)


@pytest.fixture
async def engine(
    persistence: FakePersistence,
    config: TaskEngineConfig,
) -> AsyncGenerator[TaskEngine]:
    """Started TaskEngine without a message bus; stopped on teardown."""
    task_engine = TaskEngine(
        persistence=persistence,  # type: ignore[arg-type]
        config=config,
    )
    task_engine.start()
    yield task_engine
    await task_engine.stop(timeout=2.0)


@pytest.fixture
async def engine_with_bus(
    persistence: FakePersistence,
    message_bus: FakeMessageBus,
    config: TaskEngineConfig,
) -> AsyncGenerator[TaskEngine]:
    """Started TaskEngine wired to ``message_bus``; stopped on teardown."""
    task_engine = TaskEngine(
        persistence=persistence,  # type: ignore[arg-type]
        message_bus=message_bus,  # type: ignore[arg-type]
        config=config,
    )
    task_engine.start()
    yield task_engine
    await task_engine.stop(timeout=2.0)
# ── Submit to stopped engine ──────────────────────────────────


@pytest.mark.unit
class TestSubmitToStoppedEngine:
    """Submitting to a stopped engine raises TaskEngineNotRunningError."""

    async def test_submit_raises(
        self,
        persistence: FakePersistence,
    ) -> None:
        """An engine that was never started rejects ``submit()``."""
        eng = TaskEngine(persistence=persistence)  # type: ignore[arg-type]
        mutation = CreateTaskMutation(
            request_id="req-1",
            requested_by="alice",
            task_data=_make_create_data(),
        )
        with pytest.raises(TaskEngineNotRunningError):
            await eng.submit(mutation)


# ── Create mutation ───────────────────────────────────────────


@pytest.mark.unit
class TestCreateTask:
    """Tests for task creation via TaskEngine."""

    async def test_create_task(
        self,
        engine: TaskEngine,
        persistence: FakePersistence,
    ) -> None:
        """A created task is returned and also stored in persistence."""
        task = await engine.create_task(
            _make_create_data(title="My Task"),
            requested_by="alice",
        )
        assert task.title == "My Task"
        assert task.id.startswith("task-")
        assert task.status == TaskStatus.CREATED

        stored = await persistence.tasks.get(task.id)
        assert stored is not None
        assert stored.title == "My Task"

    async def test_create_returns_version_1(
        self,
        engine: TaskEngine,
    ) -> None:
        """The raw mutation result of a create reports version 1."""
        mutation = CreateTaskMutation(
            request_id="req-1",
            requested_by="alice",
            task_data=_make_create_data(),
        )
        result = await engine.submit(mutation)
        assert result.success is True
        assert result.version == 1

    async def test_create_without_assignee(
        self,
        engine: TaskEngine,
    ) -> None:
        """Creating with ``assigned_to=None`` leaves the task unassigned.

        Renamed from ``test_create_with_assignee``: the body passes no
        assignee and asserts that none is set.
        """
        task = await engine.create_task(
            _make_create_data(assigned_to=None),
            requested_by="alice",
        )
        assert task.assigned_to is None
@pytest.mark.unit
class TestUpdateTask:
    """Field updates routed through the TaskEngine."""

    async def test_update_fields(
        self,
        engine: TaskEngine,
    ) -> None:
        """Updating the title changes it and keeps the task id."""
        original = await engine.create_task(
            _make_create_data(title="Original"),
            requested_by="alice",
        )
        changed = await engine.update_task(
            original.id,
            {"title": "Updated"},
            requested_by="alice",
        )
        assert changed.title == "Updated"
        assert changed.id == original.id

    async def test_update_empty_no_op(
        self,
        engine: TaskEngine,
    ) -> None:
        """An empty update dict returns the task with its title intact."""
        original = await engine.create_task(
            _make_create_data(),
            requested_by="alice",
        )
        unchanged = await engine.update_task(
            original.id,
            {},
            requested_by="alice",
        )
        assert unchanged.title == original.title

    async def test_update_not_found(
        self,
        engine: TaskEngine,
    ) -> None:
        """Updating an unknown id raises TaskMutationError ("not found")."""
        with pytest.raises(TaskMutationError, match="not found"):
            await engine.update_task(
                "task-nonexistent",
                {"title": "X"},
                requested_by="alice",
            )
# ── Delete mutation ───────────────────────────────────────────


@pytest.mark.unit
class TestDeleteTask:
    """Task deletion routed through the TaskEngine."""

    async def test_delete_task(
        self,
        engine: TaskEngine,
        persistence: FakePersistence,
    ) -> None:
        """Deleting an existing task returns True and removes it."""
        created = await engine.create_task(
            _make_create_data(),
            requested_by="alice",
        )
        outcome = await engine.delete_task(created.id, requested_by="alice")
        assert outcome is True

        assert await persistence.tasks.get(created.id) is None

    async def test_delete_not_found(
        self,
        engine: TaskEngine,
    ) -> None:
        """Deleting an unknown id raises TaskMutationError ("not found")."""
        with pytest.raises(TaskMutationError, match="not found"):
            await engine.delete_task(
                "task-nonexistent",
                requested_by="alice",
            )


# ── Cancel mutation ───────────────────────────────────────────


@pytest.mark.unit
class TestCancelTask:
    """Task cancellation routed through the TaskEngine."""

    async def test_cancel_assigned_task(
        self,
        engine: TaskEngine,
    ) -> None:
        """A task in ASSIGNED status can be cancelled."""
        created = await engine.create_task(
            _make_create_data(),
            requested_by="alice",
        )
        in_assigned = await engine.transition_task(
            created.id,
            TaskStatus.ASSIGNED,
            requested_by="alice",
            reason="Assigning",
            assigned_to="bob",
        )
        outcome = await engine.cancel_task(
            in_assigned.id,
            requested_by="alice",
            reason="No longer needed",
        )
        assert outcome.status == TaskStatus.CANCELLED

    async def test_cancel_from_created_fails(
        self,
        engine: TaskEngine,
    ) -> None:
        """CREATED -> CANCELLED is not a valid transition."""
        created = await engine.create_task(
            _make_create_data(),
            requested_by="alice",
        )
        with pytest.raises(TaskMutationError):
            await engine.cancel_task(
                created.id,
                requested_by="alice",
                reason="Oops",
            )
+@pytest.mark.unit +class TestReadThrough: + """Tests for read-through methods that bypass the queue.""" + + async def test_get_task( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(title="Findme"), + requested_by="alice", + ) + found = await engine.get_task(task.id) + assert found is not None + assert found.title == "Findme" + + async def test_get_task_not_found( + self, + engine: TaskEngine, + ) -> None: + result = await engine.get_task("task-nonexistent") + assert result is None + + async def test_list_tasks( + self, + engine: TaskEngine, + ) -> None: + await engine.create_task( + _make_create_data(project="proj-a"), + requested_by="alice", + ) + await engine.create_task( + _make_create_data(project="proj-b"), + requested_by="alice", + ) + all_tasks = await engine.list_tasks() + assert len(all_tasks) == 2 + + filtered = await engine.list_tasks(project="proj-a") + assert len(filtered) == 1 + + async def test_list_tasks_by_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + + created = await engine.list_tasks(status=TaskStatus.CREATED) + assigned = await engine.list_tasks(status=TaskStatus.ASSIGNED) + assert len(created) == 0 + assert len(assigned) == 1 + + +# ── Version tracking ────────────────────────────────────────── + + +@pytest.mark.unit +class TestVersionTracking: + """Tests for the in-memory version counter.""" + + async def test_version_increments( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + r1 = await engine.submit(mutation) + assert r1.version == 1 + + update = UpdateTaskMutation( + request_id="req-2", + requested_by="alice", + task_id=r1.task.id, # type: 
ignore[union-attr] + updates={"title": "Updated"}, + ) + r2 = await engine.submit(update) + assert r2.version == 2 + + async def test_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # version is 1 after create; expected_version=99 should fail + update = UpdateTaskMutation( + request_id="req-2", + requested_by="alice", + task_id=task.id, + updates={"title": "X"}, + expected_version=99, + ) + result = await engine.submit(update) + assert result.success is False + assert "conflict" in (result.error or "").lower() + + async def test_version_reset_on_delete( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + delete = DeleteTaskMutation( + request_id="req-3", + requested_by="alice", + task_id=task.id, + ) + result = await engine.submit(delete) + assert result.version == 0 + + +# ── Snapshot publishing ─────────────────────────────────────── + + +@pytest.mark.unit +class TestSnapshotPublishing: + """Tests for event publishing to the message bus.""" + + async def test_snapshot_published_on_create( + self, + engine_with_bus: TaskEngine, + message_bus: FakeMessageBus, + ) -> None: + await engine_with_bus.create_task( + _make_create_data(), + requested_by="alice", + ) + # Give the processing loop time to publish + await asyncio.sleep(0.1) + assert len(message_bus.published) == 1 + + async def test_snapshot_publish_failure_does_not_affect_mutation( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + failing_bus = FailingMessageBus() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=failing_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + assert task.id.startswith("task-") + + stored = await persistence.tasks.get(task.id) + 
assert stored is not None + finally: + await eng.stop(timeout=2.0) + + async def test_no_snapshot_when_disabled( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + ) -> None: + no_snap_config = TaskEngineConfig(publish_snapshots=False) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=no_snap_config, + ) + eng.start() + try: + await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await asyncio.sleep(0.1) + assert len(message_bus.published) == 0 + finally: + await eng.stop(timeout=2.0) + + +# ── Sequential ordering ────────────────────────────────────── + + +@pytest.mark.unit +class TestSequentialOrdering: + """Tests that mutations are processed sequentially.""" + + async def test_concurrent_submits( + self, + engine: TaskEngine, + ) -> None: + """Multiple concurrent creates all succeed without interleaving.""" + tasks = await asyncio.gather( + *( + engine.create_task( + _make_create_data(title=f"Task {i}"), + requested_by="alice", + ) + for i in range(10) + ), + ) + assert len(tasks) == 10 + ids = {t.id for t in tasks} + assert len(ids) == 10 # all unique + + +# ── Drain on stop ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestDrainOnStop: + """Tests that stop() drains pending mutations.""" + + async def test_pending_mutations_processed( + self, + persistence: FakePersistence, + ) -> None: + config = TaskEngineConfig(max_queue_size=100) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + + # Submit several mutations + results = await asyncio.gather( + *( + eng.create_task( + _make_create_data(title=f"Drain {i}"), + requested_by="alice", + ) + for i in range(5) + ), + ) + assert len(results) == 5 + + await eng.stop(timeout=5.0) + assert eng.is_running is False + + # All tasks should be persisted + all_tasks = await persistence.tasks.list_tasks() + assert 
len(all_tasks) == 5 + + +# ── Queue full ──────────────────────────────────────────────── + + +@pytest.mark.unit +class TestQueueFull: + """Tests for queue full backpressure.""" + + async def test_queue_full_raises( + self, + persistence: FakePersistence, + ) -> None: + tiny_config = TaskEngineConfig(max_queue_size=1) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=tiny_config, + ) + # Start the engine but pause the processing loop + eng._running = True + + # First submit fills the queue + mutation1 = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + eng._queue.put_nowait( + __import__( + "ai_company.engine.task_engine", + fromlist=["_MutationEnvelope"], + )._MutationEnvelope(mutation=mutation1), + ) + + # Second submit should fail because queue is full + mutation2 = CreateTaskMutation( + request_id="req-2", + requested_by="alice", + task_data=_make_create_data(), + ) + with pytest.raises(TaskEngineNotRunningError, match="queue is full"): + await eng.submit(mutation2) + + eng._running = False + + +# ── Error propagation ──────────────────────────────────────── + + +@pytest.mark.unit +class TestErrorPropagation: + """Tests for error propagation via futures.""" + + async def test_persistence_error_returns_failure( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + """Persistence errors during mutation are captured in the result.""" + + class FailingSaveRepo(FakeTaskRepository): + async def save(self, task: Task) -> None: + msg = "Disk full" + raise OSError(msg) + + persistence._tasks = FailingSaveRepo() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await eng.submit(mutation) + assert result.success is False + assert "Disk full" in (result.error or "") + 
finally: + await eng.stop(timeout=2.0) + + +# ── TaskEngineConfig ────────────────────────────────────────── + + +@pytest.mark.unit +class TestTaskEngineConfig: + """Tests for TaskEngineConfig model.""" + + def test_defaults(self) -> None: + config = TaskEngineConfig() + assert config.max_queue_size == 1000 + assert config.drain_timeout_seconds == 10.0 + assert config.publish_snapshots is True + + def test_custom_values(self) -> None: + config = TaskEngineConfig( + max_queue_size=500, + drain_timeout_seconds=5.0, + publish_snapshots=False, + ) + assert config.max_queue_size == 500 + assert config.drain_timeout_seconds == 5.0 + assert config.publish_snapshots is False + + def test_frozen(self) -> None: + from pydantic import ValidationError + + config = TaskEngineConfig() + with pytest.raises(ValidationError): + config.max_queue_size = 999 # type: ignore[misc] diff --git a/tests/unit/engine/test_task_engine_models.py b/tests/unit/engine/test_task_engine_models.py new file mode 100644 index 0000000000..ee50199fcc --- /dev/null +++ b/tests/unit/engine/test_task_engine_models.py @@ -0,0 +1,312 @@ +"""Tests for task engine request, response, and event models.""" + +import pytest +from pydantic import ValidationError + +from ai_company.core.enums import Complexity, Priority, TaskStatus, TaskType +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskData, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, + UpdateTaskMutation, +) + + +@pytest.mark.unit +class TestCreateTaskData: + """Tests for CreateTaskData model.""" + + def test_minimal_construction(self) -> None: + data = CreateTaskData( + title="Fix bug", + description="Fix the login bug", + type=TaskType.DEVELOPMENT, + project="proj-1", + created_by="alice", + ) + assert data.title == "Fix bug" + assert data.priority == Priority.MEDIUM + assert data.estimated_complexity == Complexity.MEDIUM + assert data.budget_limit == 0.0 + 
assert data.assigned_to is None + + def test_full_construction(self) -> None: + data = CreateTaskData( + title="Implement feature", + description="Add new dashboard", + type=TaskType.DEVELOPMENT, + priority=Priority.HIGH, + project="proj-2", + created_by="bob", + assigned_to="charlie", + estimated_complexity=Complexity.COMPLEX, + budget_limit=5.0, + ) + assert data.assigned_to == "charlie" + assert data.budget_limit == 5.0 + + def test_blank_title_rejected(self) -> None: + with pytest.raises(ValueError, match="must not be whitespace"): + CreateTaskData( + title=" ", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + + def test_negative_budget_rejected(self) -> None: + with pytest.raises(ValueError, match="greater than or equal to 0"): + CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + budget_limit=-1.0, + ) + + def test_frozen(self) -> None: + data = CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + with pytest.raises(ValidationError): + data.title = "changed" # type: ignore[misc] + + +@pytest.mark.unit +class TestCreateTaskMutation: + """Tests for CreateTaskMutation model.""" + + def test_construction(self) -> None: + data = CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=data, + ) + assert mutation.mutation_type == "create" + assert mutation.request_id == "req-1" + assert mutation.requested_by == "alice" + + def test_mutation_type_literal(self) -> None: + data = CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=data, + ) + assert 
mutation.mutation_type == "create" + + +@pytest.mark.unit +class TestUpdateTaskMutation: + """Tests for UpdateTaskMutation model.""" + + def test_construction(self) -> None: + mutation = UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={"title": "New title"}, + ) + assert mutation.mutation_type == "update" + assert mutation.task_id == "task-123" + assert mutation.updates == {"title": "New title"} + assert mutation.expected_version is None + + def test_with_expected_version(self) -> None: + mutation = UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={}, + expected_version=3, + ) + assert mutation.expected_version == 3 + + def test_empty_updates(self) -> None: + mutation = UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={}, + ) + assert mutation.updates == {} + + def test_expected_version_must_be_positive(self) -> None: + with pytest.raises(ValueError, match="greater than or equal to 1"): + UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={}, + expected_version=0, + ) + + +@pytest.mark.unit +class TestTransitionTaskMutation: + """Tests for TransitionTaskMutation model.""" + + def test_construction(self) -> None: + mutation = TransitionTaskMutation( + request_id="req-3", + requested_by="charlie", + task_id="task-456", + target_status=TaskStatus.IN_PROGRESS, + reason="Starting work", + ) + assert mutation.mutation_type == "transition" + assert mutation.target_status == TaskStatus.IN_PROGRESS + assert mutation.reason == "Starting work" + assert mutation.overrides == {} + + def test_with_overrides(self) -> None: + mutation = TransitionTaskMutation( + request_id="req-3", + requested_by="charlie", + task_id="task-456", + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "dave"}, + ) + assert mutation.overrides == {"assigned_to": "dave"} + + 
+@pytest.mark.unit +class TestDeleteTaskMutation: + """Tests for DeleteTaskMutation model.""" + + def test_construction(self) -> None: + mutation = DeleteTaskMutation( + request_id="req-4", + requested_by="alice", + task_id="task-789", + ) + assert mutation.mutation_type == "delete" + assert mutation.task_id == "task-789" + + +@pytest.mark.unit +class TestCancelTaskMutation: + """Tests for CancelTaskMutation model.""" + + def test_construction(self) -> None: + mutation = CancelTaskMutation( + request_id="req-5", + requested_by="bob", + task_id="task-abc", + reason="No longer needed", + ) + assert mutation.mutation_type == "cancel" + assert mutation.reason == "No longer needed" + + +@pytest.mark.unit +class TestTaskMutationResult: + """Tests for TaskMutationResult model.""" + + def test_success_result(self) -> None: + result = TaskMutationResult( + request_id="req-1", + success=True, + version=1, + ) + assert result.success is True + assert result.task is None + assert result.error is None + + def test_failure_result(self) -> None: + result = TaskMutationResult( + request_id="req-1", + success=False, + error="Not found", + ) + assert result.success is False + assert result.error == "Not found" + assert result.version == 0 + + def test_frozen(self) -> None: + result = TaskMutationResult( + request_id="req-1", + success=True, + version=1, + ) + with pytest.raises(ValidationError): + result.success = False # type: ignore[misc] + + +@pytest.mark.unit +class TestTaskStateChanged: + """Tests for TaskStateChanged event model.""" + + def test_construction(self) -> None: + event = TaskStateChanged( + mutation_type="create", + request_id="req-1", + requested_by="alice", + new_status=TaskStatus.CREATED, + version=1, + ) + assert event.mutation_type == "create" + assert event.previous_status is None + assert event.new_status == TaskStatus.CREATED + assert event.timestamp is not None + + def test_transition_event(self) -> None: + event = TaskStateChanged( + 
mutation_type="transition", + request_id="req-2", + requested_by="bob", + previous_status=TaskStatus.CREATED, + new_status=TaskStatus.ASSIGNED, + version=2, + ) + assert event.previous_status == TaskStatus.CREATED + assert event.new_status == TaskStatus.ASSIGNED + + def test_delete_event(self) -> None: + event = TaskStateChanged( + mutation_type="delete", + request_id="req-3", + requested_by="charlie", + version=0, + ) + assert event.task is None + assert event.previous_status is None + assert event.new_status is None + + def test_serialization_roundtrip(self) -> None: + event = TaskStateChanged( + mutation_type="create", + request_id="req-1", + requested_by="alice", + new_status=TaskStatus.CREATED, + version=1, + ) + json_str = event.model_dump_json() + restored = TaskStateChanged.model_validate_json(json_str) + assert restored.mutation_type == event.mutation_type + assert restored.request_id == event.request_id + assert restored.version == event.version diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 9e51eb994d..9d4238b3dd 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -210,6 +210,7 @@ def test_all_domain_modules_discovered(self) -> None: "security", "task", "task_assignment", + "task_engine", "task_routing", "template", "timeout", From ef4e7dfd8362067c04bc21d45032232fb05c5402 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 13:56:25 +0100 Subject: [PATCH 02/14] fix: harden TaskEngine error handling, types, and test coverage Pre-reviewed by 10 agents, 37 findings addressed: - Add exhaustive match default + typed error hierarchy (TaskNotFoundError, TaskEngineQueueFullError, TaskVersionConflictError) - Sanitize internal exception details from API responses - Add immutable field rejection validators on UpdateTaskMutation and TransitionTaskMutation - Thread previous_status through TaskMutationResult and 
snapshots - Add consistency model_validator to TaskMutationResult - Guard _processing_loop against unhandled exceptions - Fix startup cleanup to handle task engine failures - Replace assert with proper error handling in convenience methods - Add _fail_remaining_futures for drain timeout cleanup - Add comprehensive logging coverage (creation, conflicts, loop errors) - Add _not_found_result helper to reduce duplication - Extract _TERMINAL_STATUSES module constant - Use Self return type in model validators - Split broad except in _report_to_task_engine (TaskMutationError vs Exception) - Update docs: tech-stack Adopted, CLAUDE.md engine description, engine.md TaskEngine architecture subsection - Add tests: AppState.task_engine, _report_to_task_engine, app lifecycle, version conflicts, cancel not-found, previous_status, immutable fields, typed errors --- CLAUDE.md | 4 +- docs/architecture/tech-stack.md | 2 +- docs/design/engine.md | 55 +++++ src/ai_company/api/app.py | 18 +- src/ai_company/api/controllers/tasks.py | 49 ++-- src/ai_company/config/schema.py | 1 + src/ai_company/engine/__init__.py | 10 +- src/ai_company/engine/agent_engine.py | 46 +++- src/ai_company/engine/errors.py | 8 + src/ai_company/engine/task_engine.py | 209 +++++++++++++---- src/ai_company/engine/task_engine_models.py | 74 +++++- .../observability/events/task_engine.py | 3 + tests/unit/api/test_app.py | 46 ++++ tests/unit/api/test_state.py | 44 ++++ tests/unit/engine/test_agent_engine.py | 160 ++++++++++++- tests/unit/engine/test_task_engine.py | 222 +++++++++++++++++- 16 files changed, 845 insertions(+), 106 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 8a4df2f256..3a4a6eed15 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -92,7 +92,7 @@ src/ai_company/ communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol config/ # YAML company config loading and validation core/ # Shared domain models, base classes, and resilience 
config (RetryConfig, RateLimiterConfig) - engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation + engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation hr/ # HR engine: hiring, firing, onboarding, offboarding, agent registry, performance tracking (task metrics, collaboration scoring, trend detection), promotion/demotion (criteria evaluation, approval strategies, model mapping) memory/ # Persistent agent memory (Mem0 initial, custom stack future — see Decision Log), retrieval pipeline (ranking, injection, context formatting, non-inferable filtering), shared org memory (org/), consolidation/archival (consolidation/) persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (see Memory & Persistence design page) @@ -127,7 +127,7 @@ src/ai_company/ - **Every module** with business logic MUST have: `from ai_company.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. 
`PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`). Import directly: `from ai_company.observability.events.<domain> import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g.
`PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`, `TASK_ENGINE_STARTED` from `events.task_engine`). Import directly: `from ai_company.observability.events.<domain> import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` — never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/docs/architecture/tech-stack.md b/docs/architecture/tech-stack.md index 4580351e7c..4b867d4798 100644 --- a/docs/architecture/tech-stack.md +++ b/docs/architecture/tech-stack.md @@ -119,7 +119,7 @@ These conventions are used throughout the codebase. For full details on each, se | **Cost tiers and quota tracking** | Adopted | Configurable `CostTierDefinition` with merge/override semantics. `QuotaTracker` enforces per-provider request/token quotas with window-based rotation. | | **Shared org memory** | Adopted | `OrgMemoryBackend` protocol with `HybridPromptRetrievalBackend`.
Seniority-based write access control. Core policies in system prompts; extended facts retrieved on demand. | | **Memory consolidation** | Adopted | `ConsolidationStrategy` protocol with deduplication + summarization. `RetentionEnforcer` for age-based cleanup. `ArchivalStore` for cold storage. | -| **State coordination** | Planned | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_copy(update=...)` sequentially and publishes snapshots. | +| **State coordination** | Adopted | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_copy(update=...)` sequentially and publishes snapshots. | | **Workspace isolation** | Adopted | Pluggable `WorkspaceIsolationStrategy` protocol. Default: git worktrees with sequential merge on completion. | | **Graceful shutdown** | Adopted | Pluggable `ShutdownStrategy` protocol with cooperative 30-second timeout. Force-cancel after timeout with `INTERRUPTED` status. | | **Template inheritance** | Adopted | `extends` field triggers parent resolution at render time with deep merge by field type. Circular chain detection included. | diff --git a/docs/design/engine.md b/docs/design/engine.md index 62eafca175..409abe9039 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -166,6 +166,61 @@ exceptions on failure; scoring-based strategies return --- +## TaskEngine — Centralized State Coordination + +All task state mutations flow through a single-writer `TaskEngine` that owns the +authoritative task state. This eliminates race conditions when multiple agents +attempt concurrent transitions on the same task. 
+ +### Architecture + +``` +Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop ──▶ Persistence + │ + ├──▶ Version tracking (optimistic concurrency) + └──▶ Snapshot publishing (MessageBus) +``` + +- **Single writer**: A background `asyncio.Task` consumes `TaskMutation` + requests sequentially from an `asyncio.Queue`. +- **Immutable updates**: Each mutation calls `model_copy(update=...)` on + frozen `Task` models — the original is never mutated. +- **Optimistic concurrency**: In-memory version counters per task. + Callers can pass `expected_version` to detect stale writes; + `TaskVersionConflictError` is raised on mismatch. +- **Read-through**: `get_task()` and `list_tasks()` bypass the queue and + read directly from persistence — safe because TaskEngine is the sole writer. +- **Snapshot publishing**: On success, a `TaskStateChanged` event is published + to the message bus for downstream consumers (WebSocket bridge, audit, etc.). + +### Mutation Types + +| Mutation | Description | +|----------|-------------| +| `CreateTaskMutation` | Generates a unique ID, persists, and returns the new task. | +| `UpdateTaskMutation` | Applies field updates with immutable-field rejection (`id`, `created_by`, `created_at`). | +| `TransitionTaskMutation` | Validates status transition via `Task.with_transition()`, supports field overrides. | +| `DeleteTaskMutation` | Removes from persistence and clears version tracking. | +| `CancelTaskMutation` | Shortcut for transition to `CANCELLED`. | + +### Error Handling + +- **Typed errors**: `TaskNotFoundError` and `TaskVersionConflictError` provide + precise failure classification — API controllers catch these directly instead + of parsing error strings. +- **Error sanitization**: Internal exception details (SQL paths, stack traces) + are replaced with a generic message before reaching callers. +- **Queue full**: `TaskEngineQueueFullError` signals backpressure when the + queue is at capacity. 
+ +### Lifecycle + +- **start()**: Spawns the background processing task. +- **stop()**: Sets `_running = False`, drains the queue within a configurable + timeout, then cancels. Abandoned futures receive a failure result. + +--- + ## Agent Execution Loop The agent execution loop defines how an agent processes a task from start to diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 8edddcbfc9..c2d289c1c3 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -132,14 +132,24 @@ async def on_shutdown() -> None: return [on_startup], [on_shutdown] -async def _cleanup_on_failure( +async def _cleanup_on_failure( # noqa: PLR0913 *, persistence: PersistenceBackend | None, started_persistence: bool, message_bus: MessageBus | None, started_bus: bool, + task_engine: TaskEngine | None = None, + started_task_engine: bool = False, ) -> None: - """Reverse cleanup of persistence and message bus on startup failure.""" + """Reverse cleanup on startup failure (task engine, bus, persistence).""" + if started_task_engine and task_engine is not None: + try: + await task_engine.stop() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Cleanup: failed to stop task engine", + ) if started_bus and message_bus is not None: try: await message_bus.stop() @@ -214,6 +224,7 @@ async def _safe_startup( """ started_bus = False started_persistence = False + started_task_engine = False try: if persistence is not None: try: @@ -257,12 +268,15 @@ async def _safe_startup( error="Failed to start task engine", ) raise + started_task_engine = True except Exception: await _cleanup_on_failure( persistence=persistence, started_persistence=started_persistence, message_bus=message_bus, started_bus=started_bus, + task_engine=task_engine, + started_task_engine=started_task_engine, ) raise diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 463aaa7970..733f58089b 100644 --- 
a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -16,7 +16,7 @@ from ai_company.api.state import AppState # noqa: TC001 from ai_company.core.enums import TaskStatus # noqa: TC001 from ai_company.core.task import Task # noqa: TC001 -from ai_company.engine.errors import TaskMutationError +from ai_company.engine.errors import TaskMutationError, TaskNotFoundError from ai_company.engine.task_engine_models import CreateTaskData from ai_company.observability import get_logger from ai_company.observability.events.api import ( @@ -163,15 +163,13 @@ async def update_task( updates, requested_by="api", ) - except TaskMutationError as exc: - if "not found" in str(exc): - logger.warning( - API_RESOURCE_NOT_FOUND, - resource="task", - id=task_id, - ) - raise NotFoundError(str(exc)) from exc - raise + except TaskNotFoundError as exc: + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(str(exc)) from exc logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) return ApiResponse(data=task) @@ -207,15 +205,15 @@ async def transition_task( reason=f"API transition to {data.target_status.value}", assigned_to=data.assigned_to, ) + except TaskNotFoundError as exc: + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(str(exc)) from exc except TaskMutationError as exc: error_str = str(exc) - if "not found" in error_str: - logger.warning( - API_RESOURCE_NOT_FOUND, - resource="task", - id=task_id, - ) - raise NotFoundError(error_str) from exc logger.warning( TASK_STATUS_CHANGED, task_id=task_id, @@ -253,15 +251,12 @@ async def delete_task( task_id, requested_by="api", ) - except TaskMutationError as exc: - if "not found" in str(exc): - msg = f"Task {task_id!r} not found" - logger.warning( - API_RESOURCE_NOT_FOUND, - resource="task", - id=task_id, - ) - raise NotFoundError(msg) from exc - raise + except TaskNotFoundError as exc: + 
logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(str(exc)) from exc logger.info(API_TASK_DELETED, task_id=task_id) return ApiResponse(data=None) diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index d73869a67f..aa6a3e37aa 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -415,6 +415,7 @@ class RootConfig(BaseModel): security: Security subsystem configuration. trust: Progressive trust configuration. promotion: Promotion/demotion configuration. + task_engine: Task engine configuration. """ model_config = ConfigDict(frozen=True) diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index 0d6565aed9..aee8f905c6 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -1,8 +1,8 @@ """Agent execution engine. -Re-exports the public API for the agent orchestrator, run results, -system prompt construction, runtime execution state, execution loops, -and engine errors. +Re-exports the public API for the agent orchestrator, task engine, +run results, system prompt construction, runtime execution state, +execution loops, and engine errors. 
""" from ai_company.engine.agent_engine import AgentEngine @@ -69,7 +69,9 @@ TaskAssignmentError, TaskEngineError, TaskEngineNotRunningError, + TaskEngineQueueFullError, TaskMutationError, + TaskNotFoundError, TaskRoutingError, TaskVersionConflictError, WorkspaceCleanupError, @@ -264,10 +266,12 @@ "TaskEngineConfig", "TaskEngineError", "TaskEngineNotRunningError", + "TaskEngineQueueFullError", "TaskExecution", "TaskMutation", "TaskMutationError", "TaskMutationResult", + "TaskNotFoundError", "TaskRoutingError", "TaskRoutingService", "TaskStateChanged", diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index a032b748d4..db92679eab 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -19,7 +19,7 @@ from ai_company.engine.classification.pipeline import classify_execution_errors from ai_company.engine.context import DEFAULT_MAX_TURNS, AgentContext from ai_company.engine.cost_recording import record_execution_costs -from ai_company.engine.errors import ExecutionStateError +from ai_company.engine.errors import ExecutionStateError, TaskMutationError from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -99,6 +99,16 @@ _DEFAULT_RECOVERY_STRATEGY = FailAndReassignStrategy() """Module-level default instance for the recovery strategy.""" +_TERMINAL_STATUSES: frozenset[TaskStatus] = frozenset( + { + TaskStatus.COMPLETED, + TaskStatus.FAILED, + TaskStatus.INTERRUPTED, + TaskStatus.CANCELLED, + } +) +"""Task statuses that trigger a report to the centralized TaskEngine.""" + class AgentEngine: """Top-level orchestrator for agent execution. @@ -119,6 +129,10 @@ class AgentEngine: error_taxonomy_config: Post-execution error classification. budget_enforcer: Pre-flight checks, auto-downgrade, and enhanced in-flight budget checking. + security_config: Optional security subsystem configuration. + approval_store: Optional approval queue store. 
+ task_engine: Optional centralized task engine for reporting + final execution status. """ def __init__( # noqa: PLR0913 @@ -339,7 +353,11 @@ async def _post_execution_pipeline( agent_id: str, task_id: str, ) -> ExecutionResult: - """Record costs, apply transitions, report to TaskEngine.""" + """Post-execution: costs, transitions, TaskEngine, recovery, classify. + + Best-effort: classification and reporting failures are logged, + never fatal. + """ await record_execution_costs( execution_result, identity, @@ -648,6 +666,9 @@ async def _report_to_task_engine( ) -> None: """Report final execution status to the centralized TaskEngine. + Only reports terminal statuses (COMPLETED, FAILED, INTERRUPTED, + CANCELLED); non-terminal statuses are silently skipped. + Best-effort: failures are logged and swallowed. If no ``TaskEngine`` is configured, this is a no-op. """ @@ -658,13 +679,7 @@ async def _report_to_task_engine( return final_status = ctx.task_execution.status - terminal = { - TaskStatus.COMPLETED, - TaskStatus.FAILED, - TaskStatus.INTERRUPTED, - TaskStatus.CANCELLED, - } - if final_status not in terminal: + if final_status not in _TERMINAL_STATUSES: return try: @@ -679,12 +694,21 @@ async def _report_to_task_engine( ) except MemoryError, RecursionError: raise - except Exception: + except TaskMutationError: logger.warning( EXECUTION_ENGINE_ERROR, agent_id=agent_id, task_id=task_id, - error="Failed to report final status to TaskEngine", + error="Failed to report final status to TaskEngine (mutation rejected)", + exc_info=True, + ) + except Exception: + logger.error( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error="Unexpected error reporting to TaskEngine" + " -- state may be divergent", exc_info=True, ) diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index c4101439c7..a8774210fe 100644 --- a/src/ai_company/engine/errors.py +++ b/src/ai_company/engine/errors.py @@ -90,9 +90,17 @@ class 
TaskEngineNotRunningError(TaskEngineError): """Raised when a mutation is submitted to a stopped task engine.""" +class TaskEngineQueueFullError(TaskEngineError): + """Raised when the task engine queue is at capacity.""" + + class TaskMutationError(TaskEngineError): """Raised when a task mutation fails (not found, validation, etc.).""" +class TaskNotFoundError(TaskMutationError): + """Raised when a task is not found during mutation.""" + + class TaskVersionConflictError(TaskMutationError): """Raised when optimistic concurrency version does not match.""" diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index 318a2a7ffa..1316d91fd7 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -5,7 +5,7 @@ ``model_copy(update=...)`` on frozen ``Task`` models, persists the result, and publishes snapshots to the message bus. -Reads bypass the queue and go directly to persistence — this is safe +Reads bypass the queue and go directly to persistence -- this is safe because the TaskEngine is the only writer. 
""" @@ -20,7 +20,9 @@ from ai_company.core.task import Task from ai_company.engine.errors import ( TaskEngineNotRunningError, + TaskEngineQueueFullError, TaskMutationError, + TaskNotFoundError, TaskVersionConflictError, ) from ai_company.engine.task_engine_config import TaskEngineConfig @@ -37,9 +39,11 @@ ) from ai_company.observability import get_logger from ai_company.observability.events.task_engine import ( + TASK_ENGINE_CREATED, TASK_ENGINE_DRAIN_COMPLETE, TASK_ENGINE_DRAIN_START, TASK_ENGINE_DRAIN_TIMEOUT, + TASK_ENGINE_LOOP_ERROR, TASK_ENGINE_MUTATION_APPLIED, TASK_ENGINE_MUTATION_FAILED, TASK_ENGINE_MUTATION_RECEIVED, @@ -49,6 +53,7 @@ TASK_ENGINE_SNAPSHOT_PUBLISHED, TASK_ENGINE_STARTED, TASK_ENGINE_STOPPED, + TASK_ENGINE_VERSION_CONFLICT, ) if TYPE_CHECKING: @@ -60,7 +65,11 @@ @dataclass class _MutationEnvelope: - """Pairs a mutation request with its response future.""" + """Pairs a mutation request with its response future. + + Note: must be instantiated within a running event loop (the + ``future`` default factory calls ``asyncio.get_running_loop()``). + """ mutation: TaskMutation future: asyncio.Future[TaskMutationResult] = field( @@ -98,8 +107,13 @@ def __init__( self._versions: dict[str, int] = {} self._processing_task: asyncio.Task[None] | None = None self._running = False + logger.debug( + TASK_ENGINE_CREATED, + max_queue_size=self._config.max_queue_size, + publish_snapshots=self._config.publish_snapshots, + ) - # ── Lifecycle ───────────────────────────────────────────── + # -- Lifecycle --------------------------------------------------------- def start(self) -> None: """Spawn the background processing loop. 
@@ -109,6 +123,7 @@ def start(self) -> None: """ if self._running: msg = "TaskEngine is already running" + logger.warning(TASK_ENGINE_STARTED, error=msg) raise RuntimeError(msg) self._running = True self._processing_task = asyncio.create_task( @@ -154,16 +169,31 @@ async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 self._processing_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._processing_task + self._fail_remaining_futures() self._processing_task = None logger.info(TASK_ENGINE_STOPPED) + def _fail_remaining_futures(self) -> None: + """Fail all remaining enqueued futures after drain timeout.""" + while not self._queue.empty(): + with contextlib.suppress(asyncio.QueueEmpty): + envelope = self._queue.get_nowait() + if not envelope.future.done(): + envelope.future.set_result( + TaskMutationResult( + request_id=envelope.mutation.request_id, + success=False, + error="TaskEngine shut down before processing", + ), + ) + @property def is_running(self) -> bool: """Whether the engine is accepting mutations.""" return self._running - # ── Submit & convenience methods ────────────────────────── + # -- Submit & convenience methods -------------------------------------- async def submit(self, mutation: TaskMutation) -> TaskMutationResult: """Submit a mutation and await its result. @@ -176,6 +206,7 @@ async def submit(self, mutation: TaskMutation) -> TaskMutationResult: Raises: TaskEngineNotRunningError: If the engine is not running. + TaskEngineQueueFullError: If the queue is at capacity. 
""" if not self._running: logger.warning( @@ -197,7 +228,7 @@ async def submit(self, mutation: TaskMutation) -> TaskMutationResult: queue_size=self._queue.qsize(), ) msg = "TaskEngine queue is full" - raise TaskEngineNotRunningError(msg) from None + raise TaskEngineQueueFullError(msg) from None return await envelope.future @@ -228,7 +259,9 @@ async def create_task( result = await self.submit(mutation) if not result.success: raise TaskMutationError(result.error or "Create failed") - assert result.task is not None # noqa: S101 + if result.task is None: + msg = "Internal error: create succeeded but task is None" + raise TaskMutationError(msg) return result.task async def update_task( @@ -252,6 +285,7 @@ async def update_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. """ mutation = UpdateTaskMutation( @@ -263,8 +297,10 @@ async def update_task( ) result = await self.submit(mutation) if not result.success: - raise TaskMutationError(result.error or "Update failed") - assert result.task is not None # noqa: S101 + self._raise_typed_error(result) + if result.task is None: + msg = "Internal error: update succeeded but task is None" + raise TaskMutationError(msg) return result.task async def transition_task( @@ -292,6 +328,7 @@ async def transition_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. 
""" effective_reason = reason or f"Transition to {target_status.value}" @@ -306,8 +343,10 @@ async def transition_task( ) result = await self.submit(mutation) if not result.success: - raise TaskMutationError(result.error or "Transition failed") - assert result.task is not None # noqa: S101 + self._raise_typed_error(result) + if result.task is None: + msg = "Internal error: transition succeeded but task is None" + raise TaskMutationError(msg) return result.task async def delete_task( @@ -327,6 +366,7 @@ async def delete_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. """ mutation = DeleteTaskMutation( @@ -336,7 +376,7 @@ async def delete_task( ) result = await self.submit(mutation) if not result.success: - raise TaskMutationError(result.error or "Delete failed") + self._raise_typed_error(result) return True async def cancel_task( @@ -358,6 +398,7 @@ async def cancel_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. 
""" mutation = CancelTaskMutation( @@ -368,11 +409,21 @@ async def cancel_task( ) result = await self.submit(mutation) if not result.success: - raise TaskMutationError(result.error or "Cancel failed") - assert result.task is not None # noqa: S101 + self._raise_typed_error(result) + if result.task is None: + msg = "Internal error: cancel succeeded but task is None" + raise TaskMutationError(msg) return result.task - # ── Read-through (bypass queue) ─────────────────────────── + @staticmethod + def _raise_typed_error(result: TaskMutationResult) -> None: + """Raise a typed error from a failed mutation result.""" + error = result.error or "Mutation failed" + if "not found" in error: + raise TaskNotFoundError(error) + raise TaskMutationError(error) + + # -- Read-through (bypass queue) --------------------------------------- async def get_task(self, task_id: str) -> Task | None: """Read a task directly from persistence (bypass queue). @@ -408,7 +459,7 @@ async def list_tasks( project=project, ) - # ── Background processing ───────────────────────────────── + # -- Background processing --------------------------------------------- async def _processing_loop(self) -> None: """Background loop: dequeue and process mutations sequentially.""" @@ -420,7 +471,13 @@ async def _processing_loop(self) -> None: ) except TimeoutError: continue - await self._process_one(envelope) + try: + await self._process_one(envelope) + except Exception: + logger.exception( + TASK_ENGINE_LOOP_ERROR, + error="Unhandled exception in processing loop", + ) async def _process_one(self, envelope: _MutationEnvelope) -> None: """Process a single mutation envelope.""" @@ -436,24 +493,28 @@ async def _process_one(self, envelope: _MutationEnvelope) -> None: if result.success and self._config.publish_snapshots: await self._publish_snapshot(mutation, result) except Exception as exc: - error_msg = f"{type(exc).__name__}: {exc}" + internal_msg = f"{type(exc).__name__}: {exc}" logger.exception( 
TASK_ENGINE_MUTATION_FAILED, mutation_type=mutation.mutation_type, request_id=mutation.request_id, - error=error_msg, + error=internal_msg, ) if not envelope.future.done(): envelope.future.set_result( TaskMutationResult( request_id=mutation.request_id, success=False, - error=error_msg, + error="Internal error processing mutation", ), ) async def _apply_mutation(self, mutation: TaskMutation) -> TaskMutationResult: - """Dispatch and apply a mutation by type.""" + """Dispatch and apply a mutation by type. + + Raises: + TypeError: If the mutation type is unrecognised. + """ match mutation: case CreateTaskMutation(): return await self._apply_create(mutation) @@ -465,6 +526,30 @@ async def _apply_mutation(self, mutation: TaskMutation) -> TaskMutationResult: return await self._apply_delete(mutation) case CancelTaskMutation(): return await self._apply_cancel(mutation) + case _: + msg = f"Unknown mutation type: {type(mutation).__name__}" # type: ignore[unreachable] + raise TypeError(msg) + + def _not_found_result( + self, + mutation_type: str, + request_id: str, + task_id: str, + ) -> TaskMutationResult: + """Build a failure result for a missing task and log it.""" + error = f"Task {task_id!r} not found" + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type=mutation_type, + request_id=request_id, + task_id=task_id, + error=error, + ) + return TaskMutationResult( + request_id=request_id, + success=False, + error=error, + ) async def _apply_create( self, @@ -509,14 +594,21 @@ async def _apply_update( """Update task fields.""" task = await self._persistence.tasks.get(mutation.task_id) if task is None: + return self._not_found_result( + "update", + mutation.request_id, + mutation.task_id, + ) + + try: + self._check_version(mutation.task_id, mutation.expected_version) + except TaskVersionConflictError as exc: return TaskMutationResult( request_id=mutation.request_id, success=False, - error=f"Task {mutation.task_id!r} not found", + error=str(exc), ) - 
self._check_version(mutation.task_id, mutation.expected_version) - if not mutation.updates: version = self._versions.get(mutation.task_id, 0) return TaskMutationResult( @@ -524,6 +616,7 @@ async def _apply_update( success=True, task=task, version=version, + previous_status=task.status, ) updated = task.model_copy(update=mutation.updates) @@ -542,6 +635,7 @@ async def _apply_update( success=True, task=updated, version=version, + previous_status=task.status, ) async def _apply_transition( @@ -551,13 +645,21 @@ async def _apply_transition( """Perform a task status transition.""" task = await self._persistence.tasks.get(mutation.task_id) if task is None: + return self._not_found_result( + "transition", + mutation.request_id, + mutation.task_id, + ) + + try: + self._check_version(mutation.task_id, mutation.expected_version) + except TaskVersionConflictError as exc: return TaskMutationResult( request_id=mutation.request_id, success=False, - error=f"Task {mutation.task_id!r} not found", + error=str(exc), ) - self._check_version(mutation.task_id, mutation.expected_version) previous_status = task.status try: @@ -566,6 +668,13 @@ async def _apply_transition( **mutation.overrides, ) except ValueError as exc: + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="transition", + request_id=mutation.request_id, + task_id=mutation.task_id, + error=str(exc), + ) return TaskMutationResult( request_id=mutation.request_id, success=False, @@ -588,6 +697,7 @@ async def _apply_transition( success=True, task=updated, version=version, + previous_status=previous_status, ) async def _apply_delete( @@ -597,10 +707,10 @@ async def _apply_delete( """Delete a task.""" deleted = await self._persistence.tasks.delete(mutation.task_id) if not deleted: - return TaskMutationResult( - request_id=mutation.request_id, - success=False, - error=f"Task {mutation.task_id!r} not found", + return self._not_found_result( + "delete", + mutation.request_id, + mutation.task_id, ) 
self._versions.pop(mutation.task_id, None) @@ -624,16 +734,23 @@ async def _apply_cancel( """Cancel a task (shortcut for transition to CANCELLED).""" task = await self._persistence.tasks.get(mutation.task_id) if task is None: - return TaskMutationResult( - request_id=mutation.request_id, - success=False, - error=f"Task {mutation.task_id!r} not found", + return self._not_found_result( + "cancel", + mutation.request_id, + mutation.task_id, ) previous_status = task.status try: updated = task.with_transition(TaskStatus.CANCELLED) except ValueError as exc: + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="cancel", + request_id=mutation.request_id, + task_id=mutation.task_id, + error=str(exc), + ) return TaskMutationResult( request_id=mutation.request_id, success=False, @@ -656,9 +773,10 @@ async def _apply_cancel( success=True, task=updated, version=version, + previous_status=previous_status, ) - # ── Snapshot publishing ─────────────────────────────────── + # -- Snapshot publishing ----------------------------------------------- async def _publish_snapshot( self, @@ -672,24 +790,27 @@ async def _publish_snapshot( if self._message_bus is None: return - new_status = ( - None - if isinstance(mutation, DeleteTaskMutation) - else (result.task.status if result.task else None) - ) + if isinstance(mutation, DeleteTaskMutation): + new_status = None + elif result.task is not None: + new_status = result.task.status + else: + new_status = None event = TaskStateChanged( mutation_type=mutation.mutation_type, request_id=mutation.request_id, requested_by=mutation.requested_by, task=result.task, - previous_status=None, + previous_status=result.previous_status, new_status=new_status, version=result.version, timestamp=datetime.now(UTC), ) try: + # Deferred to break circular import: + # communication -> engine -> communication from ai_company.communication.enums import MessageType # noqa: PLC0415 from ai_company.communication.message import Message # noqa: PLC0415 @@ 
-715,7 +836,7 @@ async def _publish_snapshot( exc_info=True, ) - # ── Version tracking ────────────────────────────────────── + # -- Version tracking -------------------------------------------------- def _bump_version(self, task_id: str) -> int: """Increment and return the version counter for a task.""" @@ -741,4 +862,10 @@ def _check_version( f"Version conflict for task {task_id!r}: " f"expected {expected_version}, current {current}" ) + logger.warning( + TASK_ENGINE_VERSION_CONFLICT, + task_id=task_id, + expected_version=expected_version, + current_version=current, + ) raise TaskVersionConflictError(msg) diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index d151a53d39..86fccde0bd 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -6,9 +6,9 @@ """ from datetime import UTC, datetime -from typing import Literal +from typing import Literal, Self -from pydantic import AwareDatetime, BaseModel, ConfigDict, Field +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator from ai_company.core.enums import Complexity, Priority, TaskStatus, TaskType from ai_company.core.task import Task # noqa: TC001 @@ -21,7 +21,8 @@ class CreateTaskData(BaseModel): """Data required to create a new task (server-generated fields excluded). Mirrors :class:`~ai_company.api.dto.CreateTaskRequest` but lives in - the engine layer so it has no dependency on the API. + the engine layer so it has no dependency on the API (field parity is + maintained by convention, not enforced). Attributes: title: Short task title. @@ -35,7 +36,7 @@ class CreateTaskData(BaseModel): budget_limit: Maximum USD spend. 
""" - model_config = ConfigDict(frozen=True) + model_config = ConfigDict(frozen=True, allow_inf_nan=False) title: NotBlankStr = Field(description="Short task title") description: NotBlankStr = Field(description="Detailed task description") @@ -79,6 +80,20 @@ class CreateTaskMutation(BaseModel): task_data: CreateTaskData = Field(description="Task creation payload") +_IMMUTABLE_TASK_FIELDS: frozenset[str] = frozenset( + { + "id", + "status", + "created_by", + "created_at", + "updated_at", + "started_at", + "completed_at", + } +) +"""Fields that must not be modified via :class:`UpdateTaskMutation`.""" + + class UpdateTaskMutation(BaseModel): """Request to update task fields. @@ -87,7 +102,7 @@ class UpdateTaskMutation(BaseModel): request_id: Unique request identifier for tracing. requested_by: Identity of the requester. task_id: Target task identifier. - updates: Field-value pairs to apply. + updates: Field-value pairs to apply (immutable fields rejected). expected_version: Optional optimistic concurrency version. """ @@ -104,6 +119,25 @@ class UpdateTaskMutation(BaseModel): description="Optional optimistic concurrency version", ) + @model_validator(mode="after") + def _reject_immutable_fields(self) -> Self: + forbidden = set(self.updates) & _IMMUTABLE_TASK_FIELDS + if forbidden: + msg = f"Cannot update immutable fields: {sorted(forbidden)}" + raise ValueError(msg) + return self + + +_IMMUTABLE_OVERRIDE_FIELDS: frozenset[str] = frozenset( + { + "id", + "created_by", + "created_at", + "status", + } +) +"""Fields that must not be overridden during a transition.""" + class TransitionTaskMutation(BaseModel): """Request to perform a task status transition. @@ -115,7 +149,7 @@ class TransitionTaskMutation(BaseModel): task_id: Target task identifier. target_status: Desired target status. reason: Reason for the transition. - overrides: Additional field overrides for the transition. + overrides: Additional field overrides (immutable fields rejected). 
expected_version: Optional optimistic concurrency version. """ @@ -137,6 +171,14 @@ class TransitionTaskMutation(BaseModel): description="Optional optimistic concurrency version", ) + @model_validator(mode="after") + def _reject_immutable_overrides(self) -> Self: + forbidden = set(self.overrides) & _IMMUTABLE_OVERRIDE_FIELDS + if forbidden: + msg = f"Cannot override immutable fields: {sorted(forbidden)}" + raise ValueError(msg) + return self + class DeleteTaskMutation(BaseModel): """Request to delete a task. @@ -197,6 +239,8 @@ class TaskMutationResult(BaseModel): success: Whether the mutation succeeded. task: The task after mutation (``None`` on delete or failure). version: Current version counter for the task. + previous_status: Status before the mutation (``None`` on create + or failure). error: Error description (``None`` on success). """ @@ -206,8 +250,22 @@ class TaskMutationResult(BaseModel): success: bool = Field(description="Whether the mutation succeeded") task: Task | None = Field(default=None, description="Task after mutation") version: int = Field(default=0, ge=0, description="Version counter") + previous_status: TaskStatus | None = Field( + default=None, + description="Status before mutation", + ) error: str | None = Field(default=None, description="Error description") + @model_validator(mode="after") + def _check_consistency(self) -> Self: + if self.success and self.error is not None: + msg = "Successful result must not carry an error" + raise ValueError(msg) + if not self.success and self.error is None: + msg = "Failed result must carry an error description" + raise ValueError(msg) + return self + # ── State-change event ──────────────────────────────────────── @@ -228,7 +286,9 @@ class TaskStateChanged(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: str = Field(description="Mutation type that triggered event") + mutation_type: Literal["create", "update", "transition", "delete", "cancel"] = ( + Field(description="Mutation type 
that triggered event") + ) request_id: NotBlankStr = Field(description="Originating request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task: Task | None = Field( diff --git a/src/ai_company/observability/events/task_engine.py b/src/ai_company/observability/events/task_engine.py index 1519d281d2..a97427a64e 100644 --- a/src/ai_company/observability/events/task_engine.py +++ b/src/ai_company/observability/events/task_engine.py @@ -2,6 +2,7 @@ from typing import Final +TASK_ENGINE_CREATED: Final[str] = "task_engine.created" TASK_ENGINE_STARTED: Final[str] = "task_engine.started" TASK_ENGINE_STOPPED: Final[str] = "task_engine.stopped" TASK_ENGINE_MUTATION_RECEIVED: Final[str] = "task_engine.mutation.received" @@ -14,3 +15,5 @@ TASK_ENGINE_DRAIN_COMPLETE: Final[str] = "task_engine.drain.complete" TASK_ENGINE_DRAIN_TIMEOUT: Final[str] = "task_engine.drain.timeout" TASK_ENGINE_NOT_RUNNING: Final[str] = "task_engine.not_running" +TASK_ENGINE_VERSION_CONFLICT: Final[str] = "task_engine.version.conflict" +TASK_ENGINE_LOOP_ERROR: Final[str] = "task_engine.loop.error" diff --git a/tests/unit/api/test_app.py b/tests/unit/api/test_app.py index 114dfb89db..4a10fadbec 100644 --- a/tests/unit/api/test_app.py +++ b/tests/unit/api/test_app.py @@ -82,3 +82,49 @@ async def failing_disconnect() -> None: # Should not raise even when disconnect fails await _safe_shutdown(None, None, None, persistence) + + async def test_task_engine_failure_cleans_up( + self, + root_config: Any, + ) -> None: + """Task engine start fails → persistence + bus cleaned up.""" + from unittest.mock import MagicMock + + from ai_company.api.app import _safe_startup + from ai_company.api.approval_store import ApprovalStore + from ai_company.api.state import AppState + from tests.unit.api.conftest import ( + FakeMessageBus, + FakePersistenceBackend, + ) + + persistence = FakePersistenceBackend() + bus = FakeMessageBus() + mock_te = MagicMock() + mock_te.start = 
MagicMock(side_effect=RuntimeError("engine boom")) + mock_te.stop = MagicMock() + + app_state = AppState( + config=root_config, + approval_store=ApprovalStore(), + persistence=persistence, + ) + + with pytest.raises(RuntimeError, match="engine boom"): + await _safe_startup(persistence, bus, None, mock_te, app_state) + + # Persistence and bus should be cleaned up + assert not persistence.is_connected + assert not bus.is_running + + async def test_shutdown_task_engine_failure_does_not_propagate(self) -> None: + """Task engine stop failure during shutdown is logged, not raised.""" + from unittest.mock import AsyncMock, MagicMock + + from ai_company.api.app import _safe_shutdown + + mock_te = MagicMock() + mock_te.stop = AsyncMock(side_effect=RuntimeError("stop boom")) + + # Should not raise even when task engine stop fails + await _safe_shutdown(None, mock_te, None, None) diff --git a/tests/unit/api/test_state.py b/tests/unit/api/test_state.py index 792907cead..404a13d51a 100644 --- a/tests/unit/api/test_state.py +++ b/tests/unit/api/test_state.py @@ -90,3 +90,47 @@ def test_set_auth_service_twice_raises(self) -> None: state = _make_state(auth_service=svc) with pytest.raises(RuntimeError, match="already configured"): state.set_auth_service(svc) + + +@pytest.mark.unit +class TestAppStateTaskEngine: + """Tests for task_engine property, has_task_engine, set_task_engine.""" + + def test_task_engine_raises_when_none(self) -> None: + state = _make_state(task_engine=None) + with pytest.raises(ServiceUnavailableError): + _ = state.task_engine + + def test_task_engine_returns_when_set(self) -> None: + from unittest.mock import MagicMock + + engine = MagicMock() + state = _make_state(task_engine=engine) + assert state.task_engine is engine + + def test_has_task_engine_false_when_none(self) -> None: + state = _make_state(task_engine=None) + assert state.has_task_engine is False + + def test_has_task_engine_true_when_set(self) -> None: + from unittest.mock import MagicMock + + 
engine = MagicMock() + state = _make_state(task_engine=engine) + assert state.has_task_engine is True + + def test_set_task_engine_succeeds_once(self) -> None: + from unittest.mock import MagicMock + + engine = MagicMock() + state = _make_state() + state.set_task_engine(engine) + assert state.task_engine is engine + + def test_set_task_engine_twice_raises(self) -> None: + from unittest.mock import MagicMock + + engine = MagicMock() + state = _make_state(task_engine=engine) + with pytest.raises(RuntimeError, match="already configured"): + state.set_task_engine(engine) diff --git a/tests/unit/engine/test_agent_engine.py b/tests/unit/engine/test_agent_engine.py index 6ff75b0839..47912d75d0 100644 --- a/tests/unit/engine/test_agent_engine.py +++ b/tests/unit/engine/test_agent_engine.py @@ -14,7 +14,7 @@ from ai_company.core.task import Task from ai_company.engine.agent_engine import AgentEngine from ai_company.engine.context import AgentContext -from ai_company.engine.errors import ExecutionStateError +from ai_company.engine.errors import ExecutionStateError, TaskMutationError from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -931,3 +931,161 @@ async def test_prompt_token_ratio_warning( # noqa: PLR0913 assert "prompt_token_ratio" in warning_events[0] else: assert len(warning_events) == 0 + + +@pytest.mark.unit +class TestReportToTaskEngine: + """Tests for _report_to_task_engine interaction.""" + + async def test_no_task_engine_is_noop( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Without task_engine, run() succeeds and no reporting occurs.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + engine = AgentEngine(provider=provider, task_engine=None) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert 
result.is_success is True + + async def test_nonterminal_status_skipped( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Non-terminal task status does not trigger TaskEngine call.""" + # Build a mock loop that returns IN_PROGRESS (non-terminal) + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting execution", + ) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.MAX_TURNS, + turns=( + TurnRecord( + turn_number=1, + input_tokens=10, + output_tokens=5, + cost_usd=0.001, + finish_reason=FinishReason.STOP, + ), + ), + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock() + + provider = mock_provider_factory([]) + engine = AgentEngine( + provider=provider, + execution_loop=mock_loop, + task_engine=mock_te, + recovery_strategy=None, + ) + + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + mock_te.transition_task.assert_not_awaited() + + async def test_terminal_status_reported( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """COMPLETED status is reported to TaskEngine.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock() + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.is_success is True + 
mock_te.transition_task.assert_awaited_once() + call_args = mock_te.transition_task.call_args + assert call_args.args[1] == TaskStatus.COMPLETED + + async def test_mutation_error_swallowed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """TaskMutationError from TaskEngine is logged and swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock( + side_effect=TaskMutationError("rejected"), + ) + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + # Run still succeeds despite task engine failure + assert result.is_success is True + + async def test_unexpected_error_swallowed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Unexpected Exception from TaskEngine is logged and swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock( + side_effect=RuntimeError("connection lost"), + ) + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + # Run still succeeds despite task engine failure + assert result.is_success is True diff --git a/tests/unit/engine/test_task_engine.py b/tests/unit/engine/test_task_engine.py index 6db8cbd8f7..03eda54607 100644 --- a/tests/unit/engine/test_task_engine.py +++ b/tests/unit/engine/test_task_engine.py @@ -12,14 +12,18 @@ from ai_company.core.task import Task # noqa: TC001 from ai_company.engine.errors import ( TaskEngineNotRunningError, + 
TaskEngineQueueFullError, TaskMutationError, + TaskNotFoundError, ) from ai_company.engine.task_engine import TaskEngine from ai_company.engine.task_engine_config import TaskEngineConfig from ai_company.engine.task_engine_models import ( + CancelTaskMutation, CreateTaskData, CreateTaskMutation, DeleteTaskMutation, + TransitionTaskMutation, UpdateTaskMutation, ) @@ -610,8 +614,8 @@ async def test_snapshot_published_on_create( _make_create_data(), requested_by="alice", ) - # Give the processing loop time to publish - await asyncio.sleep(0.1) + # Yield to event loop so the processing loop completes snapshot publication + await asyncio.sleep(0) assert len(message_bus.published) == 1 async def test_snapshot_publish_failure_does_not_affect_mutation( @@ -655,7 +659,7 @@ async def test_no_snapshot_when_disabled( _make_create_data(), requested_by="alice", ) - await asyncio.sleep(0.1) + await asyncio.sleep(0) assert len(message_bus.published) == 0 finally: await eng.stop(timeout=2.0) @@ -736,6 +740,8 @@ async def test_queue_full_raises( self, persistence: FakePersistence, ) -> None: + from ai_company.engine.task_engine import _MutationEnvelope + tiny_config = TaskEngineConfig(max_queue_size=1) eng = TaskEngine( persistence=persistence, # type: ignore[arg-type] @@ -750,12 +756,7 @@ async def test_queue_full_raises( requested_by="alice", task_data=_make_create_data(), ) - eng._queue.put_nowait( - __import__( - "ai_company.engine.task_engine", - fromlist=["_MutationEnvelope"], - )._MutationEnvelope(mutation=mutation1), - ) + eng._queue.put_nowait(_MutationEnvelope(mutation=mutation1)) # Second submit should fail because queue is full mutation2 = CreateTaskMutation( @@ -763,7 +764,7 @@ async def test_queue_full_raises( requested_by="alice", task_data=_make_create_data(), ) - with pytest.raises(TaskEngineNotRunningError, match="queue is full"): + with pytest.raises(TaskEngineQueueFullError, match="queue is full"): await eng.submit(mutation2) eng._running = False @@ -802,7 +803,7 
@@ async def save(self, task: Task) -> None: ) result = await eng.submit(mutation) assert result.success is False - assert "Disk full" in (result.error or "") + assert result.error == "Internal error processing mutation" finally: await eng.stop(timeout=2.0) @@ -836,3 +837,202 @@ def test_frozen(self) -> None: config = TaskEngineConfig() with pytest.raises(ValidationError): config.max_queue_size = 999 # type: ignore[misc] + + +# -- Version conflict on transition ──────────────────────────── + + +@pytest.mark.unit +class TestVersionConflictOnTransition: + """Version conflict detection on transition mutations.""" + + async def test_transition_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + expected_version=99, + ) + result = await engine.submit(mutation) + assert result.success is False + assert "conflict" in (result.error or "").lower() + + +# -- Cancel not found ───────────────────────────────────────── + + +@pytest.mark.unit +class TestCancelNotFound: + """Cancel mutation on a non-existent task.""" + + async def test_cancel_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError, match="not found"): + await engine.cancel_task( + "task-nonexistent", + requested_by="alice", + reason="test", + ) + + +# -- Previous status in results ──────────────────────────────── + + +@pytest.mark.unit +class TestPreviousStatus: + """Verify previous_status is populated in mutation results.""" + + async def test_create_has_no_previous_status( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await engine.submit(mutation) + assert 
result.success is True + assert result.previous_status is None + + async def test_transition_has_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status == TaskStatus.CREATED + + async def test_cancel_has_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # First move to ASSIGNED so cancel is valid + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + mutation = CancelTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + reason="No longer needed", + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status == TaskStatus.ASSIGNED + + +# -- Immutable field rejection ───────────────────────────────── + + +@pytest.mark.unit +class TestImmutableFieldRejection: + """UpdateTaskMutation and TransitionTaskMutation reject immutable fields.""" + + def test_update_rejects_status(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"status": "completed"}, + ) + + def test_update_rejects_id(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"id": "new-id"}, + ) + + def test_transition_rejects_id_override(self) -> None: + 
from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="test", + overrides={"id": "new-id"}, + ) + + +# -- Typed error propagation ────────────────────────────────── + + +@pytest.mark.unit +class TestTypedErrors: + """Convenience methods raise typed errors.""" + + async def test_update_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.update_task( + "task-nonexistent", + {"title": "X"}, + requested_by="alice", + ) + + async def test_delete_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.delete_task( + "task-nonexistent", + requested_by="alice", + ) + + async def test_transition_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.transition_task( + "task-nonexistent", + TaskStatus.ASSIGNED, + requested_by="alice", + reason="test", + ) From 03f14d483cfdb2813910a1dd838fe29487ca3197 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 15:10:47 +0100 Subject: [PATCH 03/14] fix: address 24 PR review items from local agents and external reviewers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add error_code discriminator to TaskMutationResult (not_found/version_conflict/validation/internal) - Fix _raise_typed_error to use error_code match instead of fragile string matching - Fix _processing_loop outer catch to resolve envelope future (prevents caller deadlock) - Guard happy-path set_result in _process_one with done() check - Fix _apply_update to use Task.model_validate() instead of model_copy() (runs validators) - Clean _IMMUTABLE_TASK_FIELDS: remove 4 non-existent timestamp fields 
(only id/status/created_by) - Rename _TERMINAL_STATUSES → _REPORTABLE_STATUSES (FAILED/INTERRUPTED are not strictly terminal) - Fix assigned_to=None override bug in transition_task controller - Add from_status to task transition audit log - Fix failure log to use API_TASK_TRANSITION_FAILED event instead of TASK_STATUS_CHANGED - Map TaskEngineNotRunningError/TaskEngineQueueFullError → ServiceUnavailableError (503) - Fix create_task missing error handling - Fix _on_expire broad exception handler to re-raise MemoryError/RecursionError - Add MemoryError/RecursionError re-raise to _publish_snapshot - Add API_TASK_TRANSITION_FAILED event constant - Fix AppState docstring and set_task_engine docstring - Add version_conflict test to TestTypedErrors - Add TestDrainTimeout: verify abandoned futures resolved on _fail_remaining_futures - Add TestMutationResultConsistency: validate success/error invariants enforced by Pydantic - Add test_memory_error_propagates to TestReportToTaskEngine - Fix engine.md: text lang specifier, version conflict description, _IMMUTABLE_TASK_FIELDS, asyncio.wait --- docs/design/engine.md | 12 ++- src/ai_company/api/app.py | 4 +- src/ai_company/api/controllers/tasks.py | 58 +++++++++-- src/ai_company/api/state.py | 7 +- src/ai_company/engine/agent_engine.py | 15 ++- src/ai_company/engine/task_engine.py | 39 +++++++- src/ai_company/engine/task_engine_models.py | 19 +++- src/ai_company/observability/events/api.py | 1 + tests/unit/engine/test_agent_engine.py | 26 +++++ tests/unit/engine/test_task_engine.py | 102 ++++++++++++++++++++ 10 files changed, 249 insertions(+), 34 deletions(-) diff --git a/docs/design/engine.md b/docs/design/engine.md index 409abe9039..6e40d84d5d 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -174,7 +174,7 @@ attempt concurrent transitions on the same task. 
### Architecture -``` +```text Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop ──▶ Persistence │ ├──▶ Version tracking (optimistic concurrency) @@ -186,8 +186,10 @@ Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop - **Immutable updates**: Each mutation calls `model_copy(update=...)` on frozen `Task` models — the original is never mutated. - **Optimistic concurrency**: In-memory version counters per task. - Callers can pass `expected_version` to detect stale writes; - `TaskVersionConflictError` is raised on mismatch. + Callers can pass `expected_version` to detect stale writes; on mismatch + the engine returns a failed `TaskMutationResult` with + `error_code="version_conflict"`. Convenience methods raise + `TaskVersionConflictError`. - **Read-through**: `get_task()` and `list_tasks()` bypass the queue and read directly from persistence — safe because TaskEngine is the sole writer. - **Snapshot publishing**: On success, a `TaskStateChanged` event is published @@ -198,7 +200,7 @@ Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop | Mutation | Description | |----------|-------------| | `CreateTaskMutation` | Generates a unique ID, persists, and returns the new task. | -| `UpdateTaskMutation` | Applies field updates with immutable-field rejection (`id`, `created_by`, `created_at`). | +| `UpdateTaskMutation` | Applies field updates with immutable-field rejection (`id`, `status`, `created_by`) and re-validates via `model_validate`. | | `TransitionTaskMutation` | Validates status transition via `Task.with_transition()`, supports field overrides. | | `DeleteTaskMutation` | Removes from persistence and clears version tracking. | | `CancelTaskMutation` | Shortcut for transition to `CANCELLED`. | @@ -401,7 +403,7 @@ async run( alone when no enforcer is configured. 8. **Delegate to loop** -- calls `ExecutionLoop.execute()` with context, provider, tool invoker, budget checker, and completion config. 
If - `timeout_seconds` is set, wraps the call in `asyncio.wait_for`; on expiry + `timeout_seconds` is set, wraps the call in `asyncio.wait`; on expiry the run returns with `TerminationReason.ERROR` but cost recording and post-execution processing still occur. 9. **Record costs** -- records accumulated `TokenUsage` to `CostTracker` (if diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index c2d289c1c3..33794d9f72 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -88,7 +88,9 @@ def _on_expire(item: ApprovalItem) -> None: event.model_dump_json(), channels=[CHANNEL_APPROVALS], ) - except RuntimeError, OSError: + except MemoryError, RecursionError: + raise + except Exception: logger.warning( API_APPROVAL_PUBLISH_FAILED, approval_id=item.id, diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 733f58089b..b4847373b9 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -10,18 +10,28 @@ TransitionTaskRequest, UpdateTaskRequest, ) -from ai_company.api.errors import ApiValidationError, NotFoundError +from ai_company.api.errors import ( + ApiValidationError, + NotFoundError, + ServiceUnavailableError, +) from ai_company.api.guards import require_read_access, require_write_access from ai_company.api.pagination import PaginationLimit, PaginationOffset, paginate from ai_company.api.state import AppState # noqa: TC001 from ai_company.core.enums import TaskStatus # noqa: TC001 from ai_company.core.task import Task # noqa: TC001 -from ai_company.engine.errors import TaskMutationError, TaskNotFoundError +from ai_company.engine.errors import ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskMutationError, + TaskNotFoundError, +) from ai_company.engine.task_engine_models import CreateTaskData from ai_company.observability import get_logger from ai_company.observability.events.api import ( API_RESOURCE_NOT_FOUND, API_TASK_DELETED, + 
API_TASK_TRANSITION_FAILED, API_TASK_UPDATED, ) from ai_company.observability.events.task import ( @@ -124,10 +134,17 @@ async def create_task( estimated_complexity=data.estimated_complexity, budget_limit=data.budget_limit, ) - task = await app_state.task_engine.create_task( - task_data, - requested_by=data.created_by, - ) + try: + task = await app_state.task_engine.create_task( + task_data, + requested_by=data.created_by, + ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskMutationError as exc: + raise ApiValidationError(str(exc)) from exc logger.info( TASK_CREATED, task_id=task.id, @@ -163,6 +180,10 @@ async def update_task( updates, requested_by="api", ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc except TaskNotFoundError as exc: logger.warning( API_RESOURCE_NOT_FOUND, @@ -170,6 +191,8 @@ async def update_task( id=task_id, ) raise NotFoundError(str(exc)) from exc + except TaskMutationError as exc: + raise ApiValidationError(str(exc)) from exc logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) return ApiResponse(data=task) @@ -197,14 +220,24 @@ async def transition_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state + current_task = await app_state.task_engine.get_task(task_id) + from_status = current_task.status if current_task else None + transition_kwargs: dict[str, object] = { + "requested_by": "api", + "reason": f"API transition to {data.target_status.value}", + } + if data.assigned_to is not None: + transition_kwargs["assigned_to"] = data.assigned_to try: task = await app_state.task_engine.transition_task( task_id, data.target_status, - requested_by="api", - reason=f"API transition to {data.target_status.value}", - assigned_to=data.assigned_to, + **transition_kwargs, # type: ignore[arg-type] ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc except TaskNotFoundError as exc: logger.warning( API_RESOURCE_NOT_FOUND, @@ -215,7 +248,7 @@ async def transition_task( except TaskMutationError as exc: error_str = str(exc) logger.warning( - TASK_STATUS_CHANGED, + API_TASK_TRANSITION_FAILED, task_id=task_id, error=error_str, ) @@ -223,6 +256,7 @@ async def transition_task( logger.info( TASK_STATUS_CHANGED, task_id=task_id, + from_status=from_status.value if from_status else None, to_status=task.status.value, ) return ApiResponse(data=task) @@ -251,6 +285,10 @@ async def delete_task( task_id, requested_by="api", ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc except TaskNotFoundError as exc: logger.warning( API_RESOURCE_NOT_FOUND, diff --git a/src/ai_company/api/state.py b/src/ai_company/api/state.py index cd113f4619..9a8ebe847c 100644 --- a/src/ai_company/api/state.py +++ b/src/ai_company/api/state.py @@ -23,8 +23,8 @@ class AppState: """Typed application state container. 
Service fields (``persistence``, ``message_bus``, ``cost_tracker``, - ``auth_service``) accept ``None`` at construction time for dev/test - mode. Property + ``auth_service``, ``task_engine``) accept ``None`` at construction + time for dev/test mode. Property accessors raise ``ServiceUnavailableError`` (HTTP 503) when the service is not configured, producing a clear error instead of an opaque ``AttributeError``. @@ -116,7 +116,8 @@ def has_task_engine(self) -> bool: def set_task_engine(self, engine: TaskEngine) -> None: """Set the task engine (deferred initialisation). - Called once during startup after persistence is connected. + Supports late binding when the task engine is created after + ``AppState`` construction. Args: engine: Fully configured task engine. diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index db92679eab..8c97b3edaf 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -99,7 +99,7 @@ _DEFAULT_RECOVERY_STRATEGY = FailAndReassignStrategy() """Module-level default instance for the recovery strategy.""" -_TERMINAL_STATUSES: frozenset[TaskStatus] = frozenset( +_REPORTABLE_STATUSES: frozenset[TaskStatus] = frozenset( { TaskStatus.COMPLETED, TaskStatus.FAILED, @@ -107,7 +107,12 @@ TaskStatus.CANCELLED, } ) -"""Task statuses that trigger a report to the centralized TaskEngine.""" +"""Final execution outcomes that trigger a report to the centralized TaskEngine. + +Note: ``FAILED`` and ``INTERRUPTED`` are not strictly terminal in the task +lifecycle (they can be reassigned), but represent final outcomes of this +``AgentEngine`` run that should be reported. +""" class AgentEngine: @@ -666,8 +671,8 @@ async def _report_to_task_engine( ) -> None: """Report final execution status to the centralized TaskEngine. - Only reports terminal statuses (COMPLETED, FAILED, INTERRUPTED, - CANCELLED); non-terminal statuses are silently skipped. 
+ Only reports final execution outcomes (COMPLETED, FAILED, + INTERRUPTED, CANCELLED); other statuses are silently skipped. Best-effort: failures are logged and swallowed. If no ``TaskEngine`` is configured, this is a no-op. @@ -679,7 +684,7 @@ async def _report_to_task_engine( return final_status = ctx.task_execution.status - if final_status not in _TERMINAL_STATUSES: + if final_status not in _REPORTABLE_STATUSES: return try: diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index 1316d91fd7..f1997e0918 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -249,6 +249,7 @@ async def create_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskEngineQueueFullError: If the queue is at capacity. TaskMutationError: If the mutation fails. """ mutation = CreateTaskMutation( @@ -285,7 +286,9 @@ async def update_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskEngineQueueFullError: If the queue is at capacity. TaskNotFoundError: If the task is not found. + TaskVersionConflictError: If ``expected_version`` doesn't match. TaskMutationError: If the mutation fails. """ mutation = UpdateTaskMutation( @@ -328,7 +331,9 @@ async def transition_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskEngineQueueFullError: If the queue is at capacity. TaskNotFoundError: If the task is not found. + TaskVersionConflictError: If ``expected_version`` doesn't match. TaskMutationError: If the mutation fails. """ effective_reason = reason or f"Transition to {target_status.value}" @@ -366,6 +371,7 @@ async def delete_task( Raises: TaskEngineNotRunningError: If the engine is not running. + TaskEngineQueueFullError: If the queue is at capacity. TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. """ @@ -398,6 +404,7 @@ async def cancel_task( Raises: TaskEngineNotRunningError: If the engine is not running. 
+ TaskEngineQueueFullError: If the queue is at capacity. TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. """ @@ -419,9 +426,13 @@ async def cancel_task( def _raise_typed_error(result: TaskMutationResult) -> None: """Raise a typed error from a failed mutation result.""" error = result.error or "Mutation failed" - if "not found" in error: - raise TaskNotFoundError(error) - raise TaskMutationError(error) + match result.error_code: + case "not_found": + raise TaskNotFoundError(error) + case "version_conflict": + raise TaskVersionConflictError(error) + case _: + raise TaskMutationError(error) # -- Read-through (bypass queue) --------------------------------------- @@ -478,6 +489,15 @@ async def _processing_loop(self) -> None: TASK_ENGINE_LOOP_ERROR, error="Unhandled exception in processing loop", ) + if not envelope.future.done(): + envelope.future.set_result( + TaskMutationResult( + request_id=envelope.mutation.request_id, + success=False, + error="Internal error in processing loop", + error_code="internal", + ), + ) async def _process_one(self, envelope: _MutationEnvelope) -> None: """Process a single mutation envelope.""" @@ -489,7 +509,8 @@ async def _process_one(self, envelope: _MutationEnvelope) -> None: ) try: result = await self._apply_mutation(mutation) - envelope.future.set_result(result) + if not envelope.future.done(): + envelope.future.set_result(result) if result.success and self._config.publish_snapshots: await self._publish_snapshot(mutation, result) except Exception as exc: @@ -506,6 +527,7 @@ async def _process_one(self, envelope: _MutationEnvelope) -> None: request_id=mutation.request_id, success=False, error="Internal error processing mutation", + error_code="internal", ), ) @@ -549,6 +571,7 @@ def _not_found_result( request_id=request_id, success=False, error=error, + error_code="not_found", ) async def _apply_create( @@ -607,6 +630,7 @@ async def _apply_update( request_id=mutation.request_id, success=False, 
error=str(exc), + error_code="version_conflict", ) if not mutation.updates: @@ -619,7 +643,9 @@ async def _apply_update( previous_status=task.status, ) - updated = task.model_copy(update=mutation.updates) + merged = task.model_dump() + merged.update(mutation.updates) + updated = Task.model_validate(merged) await self._persistence.tasks.save(updated) version = self._bump_version(mutation.task_id) @@ -658,6 +684,7 @@ async def _apply_transition( request_id=mutation.request_id, success=False, error=str(exc), + error_code="version_conflict", ) previous_status = task.status @@ -828,6 +855,8 @@ async def _publish_snapshot( mutation_type=mutation.mutation_type, request_id=mutation.request_id, ) + except MemoryError, RecursionError: + raise except Exception: logger.warning( TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED, diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index 86fccde0bd..c4429e155c 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -85,13 +85,14 @@ class CreateTaskMutation(BaseModel): "id", "status", "created_by", - "created_at", - "updated_at", - "started_at", - "completed_at", } ) -"""Fields that must not be modified via :class:`UpdateTaskMutation`.""" +"""Fields that must not be modified via :class:`UpdateTaskMutation`. + +``status`` must go through :class:`TransitionTaskMutation` (which +validates the state machine); ``id`` and ``created_by`` are identity +fields set at creation time. +""" class UpdateTaskMutation(BaseModel): @@ -242,6 +243,8 @@ class TaskMutationResult(BaseModel): previous_status: Status before the mutation (``None`` on create or failure). error: Error description (``None`` on success). + error_code: Machine-readable error classification for reliable + dispatch (``None`` on success). 
""" model_config = ConfigDict(frozen=True) @@ -255,6 +258,12 @@ class TaskMutationResult(BaseModel): description="Status before mutation", ) error: str | None = Field(default=None, description="Error description") + error_code: ( + Literal["not_found", "version_conflict", "validation", "internal"] | None + ) = Field( + default=None, + description="Machine-readable error classification", + ) @model_validator(mode="after") def _check_consistency(self) -> Self: diff --git a/src/ai_company/observability/events/api.py b/src/ai_company/observability/events/api.py index 8890cd7538..67aecd48cb 100644 --- a/src/ai_company/observability/events/api.py +++ b/src/ai_company/observability/events/api.py @@ -35,3 +35,4 @@ API_AUTH_TOKEN_ISSUED: Final[str] = "api.auth.token_issued" # noqa: S105 API_AUTH_SETUP_COMPLETE: Final[str] = "api.auth.setup_complete" API_AUTH_PASSWORD_CHANGED: Final[str] = "api.auth.password_changed" # noqa: S105 +API_TASK_TRANSITION_FAILED: Final[str] = "api.task.transition_failed" diff --git a/tests/unit/engine/test_agent_engine.py b/tests/unit/engine/test_agent_engine.py index 47912d75d0..75da927abc 100644 --- a/tests/unit/engine/test_agent_engine.py +++ b/tests/unit/engine/test_agent_engine.py @@ -1089,3 +1089,29 @@ async def test_unexpected_error_swallowed( # Run still succeeds despite task engine failure assert result.is_success is True + + async def test_memory_error_propagates( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """MemoryError from TaskEngine is re-raised, not swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock( + side_effect=MemoryError("out of memory"), + ) + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + with pytest.raises(MemoryError, match="out of memory"): + await engine.run( + 
identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) diff --git a/tests/unit/engine/test_task_engine.py b/tests/unit/engine/test_task_engine.py index 03eda54607..8c86f0baa3 100644 --- a/tests/unit/engine/test_task_engine.py +++ b/tests/unit/engine/test_task_engine.py @@ -15,6 +15,7 @@ TaskEngineQueueFullError, TaskMutationError, TaskNotFoundError, + TaskVersionConflictError, ) from ai_company.engine.task_engine import TaskEngine from ai_company.engine.task_engine_config import TaskEngineConfig @@ -1036,3 +1037,104 @@ async def test_transition_not_found_raises_typed( requested_by="alice", reason="test", ) + + async def test_update_version_conflict_raises_typed( + self, + engine: TaskEngine, + ) -> None: + """Version conflict via convenience method raises TaskVersionConflictError.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskVersionConflictError, match="conflict"): + await engine.update_task( + task.id, + {"title": "changed"}, + requested_by="alice", + expected_version=99, + ) + + +# -- Drain timeout / _fail_remaining_futures ────────────────── + + +@pytest.mark.unit +class TestDrainTimeout: + """Verify _fail_remaining_futures resolves abandoned futures.""" + + async def test_abandoned_futures_resolved_on_drain_timeout( + self, + persistence: FakePersistence, + ) -> None: + """Futures left after drain timeout get failure results.""" + from ai_company.engine.task_engine import _MutationEnvelope + + config = TaskEngineConfig(drain_timeout_seconds=0.01) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + + # Pause processing by filling with a slow mutation + mutation = CreateTaskMutation( + request_id="req-slow", + requested_by="alice", + task_data=_make_create_data(), + ) + # Submit one that will process normally + await eng.submit(mutation) + + # Now stop the engine — force a very short drain + # Submit directly to 
queue to avoid await + mutation2 = CreateTaskMutation( + request_id="req-abandoned", + requested_by="alice", + task_data=_make_create_data(), + ) + envelope = _MutationEnvelope(mutation=mutation2) + eng._queue.put_nowait(envelope) + eng._running = False + + # Call _fail_remaining_futures directly + eng._fail_remaining_futures() + assert envelope.future.done() + result = envelope.future.result() + assert result.success is False + assert "shut down" in (result.error or "") + + await eng.stop(timeout=0.1) + + +# -- TaskMutationResult consistency ──────────────────────────── + + +@pytest.mark.unit +class TestMutationResultConsistency: + """Verify _check_consistency validator on TaskMutationResult.""" + + def test_success_with_error_rejected(self) -> None: + """Successful result must not carry an error.""" + from pydantic import ValidationError + + from ai_company.engine.task_engine_models import TaskMutationResult + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=True, + error="oops", + ) + + def test_failure_without_error_rejected(self) -> None: + """Failed result must carry an error description.""" + from pydantic import ValidationError + + from ai_company.engine.task_engine_models import TaskMutationResult + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=False, + ) From c92fed88db96b4b1c3afd608b05be492261c601f Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:33:39 +0100 Subject: [PATCH 04/14] fix: harden TaskEngine error handling, test structure, and API correctness - Add TaskInternalError for internal engine faults (maps to 5xx vs 4xx) - Thread error_code through all TaskMutationResult failure paths - Extend _raise_typed_error to cover 'internal' code with -> Never return - Add TaskInternalError handling in all TaskController endpoints - Fix bridge leak in _cleanup_on_failure (started_bridge 
flag was missing) - Remove phantom 'created_at' from _IMMUTABLE_OVERRIDE_FIELDS (field absent on Task) - Change transition_task to return tuple[Task, TaskStatus | None] (eliminates extra get_task round-trip in controller) - Split 1140-line test_task_engine.py into three focused files + helpers - Move TaskEngine fixtures from helpers to conftest.py (auto-discovery, no F401 hacks) - Fix all ruff/mypy issues in new test files (TC001, I001, F811, E501, unused-ignore) --- src/ai_company/api/app.py | 16 +- src/ai_company/api/controllers/tasks.py | 15 +- src/ai_company/engine/__init__.py | 2 + src/ai_company/engine/agent_engine.py | 2 +- src/ai_company/engine/errors.py | 10 + src/ai_company/engine/task_engine.py | 17 +- src/ai_company/engine/task_engine_models.py | 1 - tests/unit/engine/conftest.py | 58 +- tests/unit/engine/task_engine_helpers.py | 104 ++ tests/unit/engine/test_task_engine.py | 1140 ----------------- .../engine/test_task_engine_integration.py | 326 +++++ .../unit/engine/test_task_engine_lifecycle.py | 109 ++ .../unit/engine/test_task_engine_mutations.py | 583 +++++++++ 13 files changed, 1231 insertions(+), 1152 deletions(-) create mode 100644 tests/unit/engine/task_engine_helpers.py delete mode 100644 tests/unit/engine/test_task_engine.py create mode 100644 tests/unit/engine/test_task_engine_integration.py create mode 100644 tests/unit/engine/test_task_engine_lifecycle.py create mode 100644 tests/unit/engine/test_task_engine_mutations.py diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 33794d9f72..13428569f8 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -140,10 +140,12 @@ async def _cleanup_on_failure( # noqa: PLR0913 started_persistence: bool, message_bus: MessageBus | None, started_bus: bool, + bridge: MessageBusBridge | None = None, + started_bridge: bool = False, task_engine: TaskEngine | None = None, started_task_engine: bool = False, ) -> None: - """Reverse cleanup on startup failure (task engine, 
bus, persistence).""" + """Reverse cleanup on startup failure (task engine, bridge, bus, persistence).""" if started_task_engine and task_engine is not None: try: await task_engine.stop() @@ -152,6 +154,14 @@ async def _cleanup_on_failure( # noqa: PLR0913 API_APP_STARTUP, error="Cleanup: failed to stop task engine", ) + if started_bridge and bridge is not None: + try: + await bridge.stop() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Cleanup: failed to stop message bus bridge", + ) if started_bus and message_bus is not None: try: await message_bus.stop() @@ -225,6 +235,7 @@ async def _safe_startup( components in reverse order before re-raising. """ started_bus = False + started_bridge = False started_persistence = False started_task_engine = False try: @@ -261,6 +272,7 @@ async def _safe_startup( error="Failed to start message bus bridge", ) raise + started_bridge = True if task_engine is not None: try: task_engine.start() @@ -277,6 +289,8 @@ async def _safe_startup( started_persistence=started_persistence, message_bus=message_bus, started_bus=started_bus, + bridge=bridge, + started_bridge=started_bridge, task_engine=task_engine, started_task_engine=started_task_engine, ) diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index b4847373b9..88cc018876 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -23,6 +23,7 @@ from ai_company.engine.errors import ( TaskEngineNotRunningError, TaskEngineQueueFullError, + TaskInternalError, TaskMutationError, TaskNotFoundError, ) @@ -143,6 +144,8 @@ async def create_task( raise ServiceUnavailableError(str(exc)) from exc except TaskEngineQueueFullError as exc: raise ServiceUnavailableError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc except TaskMutationError as exc: raise ApiValidationError(str(exc)) from exc logger.info( @@ -191,6 +194,8 @@ async def 
update_task( id=task_id, ) raise NotFoundError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc except TaskMutationError as exc: raise ApiValidationError(str(exc)) from exc logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) @@ -220,8 +225,6 @@ async def transition_task( NotFoundError: If the task is not found. """ app_state: AppState = state.app_state - current_task = await app_state.task_engine.get_task(task_id) - from_status = current_task.status if current_task else None transition_kwargs: dict[str, object] = { "requested_by": "api", "reason": f"API transition to {data.target_status.value}", @@ -229,7 +232,7 @@ async def transition_task( if data.assigned_to is not None: transition_kwargs["assigned_to"] = data.assigned_to try: - task = await app_state.task_engine.transition_task( + task, from_status = await app_state.task_engine.transition_task( task_id, data.target_status, **transition_kwargs, # type: ignore[arg-type] @@ -245,6 +248,8 @@ async def transition_task( id=task_id, ) raise NotFoundError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc except TaskMutationError as exc: error_str = str(exc) logger.warning( @@ -296,5 +301,9 @@ async def delete_task( id=task_id, ) raise NotFoundError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskMutationError as exc: + raise ApiValidationError(str(exc)) from exc logger.info(API_TASK_DELETED, task_id=task_id) return ApiResponse(data=None) diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index aee8f905c6..c9bd49e2b7 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -70,6 +70,7 @@ TaskEngineError, TaskEngineNotRunningError, TaskEngineQueueFullError, + TaskInternalError, TaskMutationError, TaskNotFoundError, TaskRoutingError, @@ -268,6 +269,7 @@ 
"TaskEngineNotRunningError", "TaskEngineQueueFullError", "TaskExecution", + "TaskInternalError", "TaskMutation", "TaskMutationError", "TaskMutationResult", diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 8c97b3edaf..c4e3701411 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -688,7 +688,7 @@ async def _report_to_task_engine( return try: - await self._task_engine.transition_task( + _, _ = await self._task_engine.transition_task( task_id, final_status, requested_by=agent_id, diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index a8774210fe..a597b971a1 100644 --- a/src/ai_company/engine/errors.py +++ b/src/ai_company/engine/errors.py @@ -104,3 +104,13 @@ class TaskNotFoundError(TaskMutationError): class TaskVersionConflictError(TaskMutationError): """Raised when optimistic concurrency version does not match.""" + + +class TaskInternalError(TaskMutationError): + """Raised when a task mutation fails due to an internal engine error. + + Unlike :class:`TaskMutationError` (which covers business-rule failures + such as validation or not-found), this signals an unexpected engine fault + that the caller cannot fix by changing the request. Maps to 5xx at the API + layer. 
+ """ diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index f1997e0918..eea3bdeb46 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -13,7 +13,7 @@ import contextlib from dataclasses import dataclass, field from datetime import UTC, datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Never from uuid import uuid4 from ai_company.core.enums import TaskStatus @@ -21,6 +21,7 @@ from ai_company.engine.errors import ( TaskEngineNotRunningError, TaskEngineQueueFullError, + TaskInternalError, TaskMutationError, TaskNotFoundError, TaskVersionConflictError, @@ -185,6 +186,7 @@ def _fail_remaining_futures(self) -> None: request_id=envelope.mutation.request_id, success=False, error="TaskEngine shut down before processing", + error_code="internal", ), ) @@ -315,7 +317,7 @@ async def transition_task( reason: str = "", expected_version: int | None = None, **overrides: object, - ) -> Task: + ) -> tuple[Task, TaskStatus | None]: """Convenience: transition task status and return the updated Task. Args: @@ -327,7 +329,8 @@ async def transition_task( **overrides: Additional field overrides for the transition. Returns: - The transitioned task. + Tuple of (transitioned task, status before the transition). + The second element is ``None`` when the previous status is unknown. Raises: TaskEngineNotRunningError: If the engine is not running. 
@@ -352,7 +355,7 @@ async def transition_task( if result.task is None: msg = "Internal error: transition succeeded but task is None" raise TaskMutationError(msg) - return result.task + return result.task, result.previous_status async def delete_task( self, @@ -423,7 +426,7 @@ async def cancel_task( return result.task @staticmethod - def _raise_typed_error(result: TaskMutationResult) -> None: + def _raise_typed_error(result: TaskMutationResult) -> Never: """Raise a typed error from a failed mutation result.""" error = result.error or "Mutation failed" match result.error_code: @@ -431,6 +434,8 @@ def _raise_typed_error(result: TaskMutationResult) -> None: raise TaskNotFoundError(error) case "version_conflict": raise TaskVersionConflictError(error) + case "internal": + raise TaskInternalError(error) case _: raise TaskMutationError(error) @@ -706,6 +711,7 @@ async def _apply_transition( request_id=mutation.request_id, success=False, error=str(exc), + error_code="validation", ) await self._persistence.tasks.save(updated) @@ -782,6 +788,7 @@ async def _apply_cancel( request_id=mutation.request_id, success=False, error=str(exc), + error_code="validation", ) await self._persistence.tasks.save(updated) diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index c4429e155c..f977ebc8df 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -133,7 +133,6 @@ def _reject_immutable_fields(self) -> Self: { "id", "created_by", - "created_at", "status", } ) diff --git a/tests/unit/engine/conftest.py b/tests/unit/engine/conftest.py index 3901cd71f4..4b6d6d1e99 100644 --- a/tests/unit/engine/conftest.py +++ b/tests/unit/engine/conftest.py @@ -27,6 +27,8 @@ from ai_company.core.role import Authority, Role from ai_company.core.task import AcceptanceCriterion, Task from ai_company.engine.context import AgentContext +from ai_company.engine.task_engine import TaskEngine +from 
ai_company.engine.task_engine_config import TaskEngineConfig from ai_company.engine.task_execution import TaskExecution from ai_company.providers.capabilities import ModelCapabilities from ai_company.providers.enums import FinishReason @@ -38,9 +40,10 @@ TokenUsage, ToolDefinition, ) +from tests.unit.engine.task_engine_helpers import FakeMessageBus, FakePersistence if TYPE_CHECKING: - from collections.abc import AsyncIterator + from collections.abc import AsyncGenerator, AsyncIterator from ai_company.core.enums import ConflictEscalation from ai_company.engine.workspace.models import ( @@ -402,3 +405,56 @@ def make_assignment_task(**overrides: object) -> Task: } defaults.update(overrides) return Task(**defaults) # type: ignore[arg-type] + + +# ── TaskEngine fixtures ─────────────────────────────────────── + + +@pytest.fixture +def persistence() -> FakePersistence: + """Provide a fresh FakePersistence instance.""" + return FakePersistence() + + +@pytest.fixture +def message_bus() -> FakeMessageBus: + """Provide a fresh FakeMessageBus instance.""" + return FakeMessageBus() + + +@pytest.fixture +def config() -> TaskEngineConfig: + """Provide a TaskEngineConfig with a sensible queue size.""" + return TaskEngineConfig(max_queue_size=100) + + +@pytest.fixture +async def engine( + persistence: FakePersistence, + config: TaskEngineConfig, +) -> AsyncGenerator[TaskEngine]: + """Create and start a TaskEngine, stop on teardown.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + yield eng + await eng.stop(timeout=2.0) + + +@pytest.fixture +async def engine_with_bus( + persistence: FakePersistence, + message_bus: FakeMessageBus, + config: TaskEngineConfig, +) -> AsyncGenerator[TaskEngine]: + """Create and start a TaskEngine with a message bus.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + yield eng + 
await eng.stop(timeout=2.0) diff --git a/tests/unit/engine/task_engine_helpers.py b/tests/unit/engine/task_engine_helpers.py new file mode 100644 index 0000000000..97884af09a --- /dev/null +++ b/tests/unit/engine/task_engine_helpers.py @@ -0,0 +1,104 @@ +"""Shared fakes and helpers for TaskEngine tests.""" + +from typing import TYPE_CHECKING + +from ai_company.core.task import Task # noqa: TC001 +from ai_company.engine.task_engine_models import CreateTaskData + +if TYPE_CHECKING: + from ai_company.core.enums import TaskStatus + + +# ── Fakes ───────────────────────────────────────────────────── + + +class FakeTaskRepository: + """Minimal in-memory task repository for engine tests.""" + + def __init__(self) -> None: + self._tasks: dict[str, Task] = {} + + async def save(self, task: Task) -> None: + self._tasks[task.id] = task + + async def get(self, task_id: str) -> Task | None: + return self._tasks.get(task_id) + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + result = list(self._tasks.values()) + if status is not None: + result = [t for t in result if t.status == status] + if assigned_to is not None: + result = [t for t in result if t.assigned_to == assigned_to] + if project is not None: + result = [t for t in result if t.project == project] + return tuple(result) + + async def delete(self, task_id: str) -> bool: + return self._tasks.pop(task_id, None) is not None + + +class FakePersistence: + """Minimal fake persistence backend with only a task repository.""" + + def __init__(self) -> None: + self._tasks = FakeTaskRepository() + + @property + def tasks(self) -> FakeTaskRepository: + return self._tasks + + +class FakeMessageBus: + """Minimal fake message bus that records published messages.""" + + def __init__(self) -> None: + self.published: list[object] = [] + self._running = False + + async def start(self) -> None: + self._running = True + + async 
def stop(self) -> None: + self._running = False + + @property + def is_running(self) -> bool: + return self._running + + async def publish(self, message: object) -> None: + self.published.append(message) + + +class FailingMessageBus(FakeMessageBus): + """Message bus that always fails on publish.""" + + async def publish(self, message: object) -> None: + msg = "Publish failed" + raise RuntimeError(msg) + + +# ── Helpers ──────────────────────────────────────────────────── + + +def _make_create_data(**overrides: object) -> CreateTaskData: + """Build a CreateTaskData with sensible defaults.""" + from ai_company.core.enums import TaskType + + defaults: dict[str, object] = { + "title": "Test task", + "description": "A test task", + "type": TaskType.DEVELOPMENT, + "project": "test-project", + "created_by": "test-agent", + } + defaults.update(overrides) + return CreateTaskData(**defaults) # type: ignore[arg-type] + + diff --git a/tests/unit/engine/test_task_engine.py b/tests/unit/engine/test_task_engine.py deleted file mode 100644 index 8c86f0baa3..0000000000 --- a/tests/unit/engine/test_task_engine.py +++ /dev/null @@ -1,1140 +0,0 @@ -"""Tests for the centralized single-writer TaskEngine.""" - -import asyncio -from collections.abc import AsyncGenerator # noqa: TC003 - -import pytest - -from ai_company.core.enums import ( - TaskStatus, - TaskType, -) -from ai_company.core.task import Task # noqa: TC001 -from ai_company.engine.errors import ( - TaskEngineNotRunningError, - TaskEngineQueueFullError, - TaskMutationError, - TaskNotFoundError, - TaskVersionConflictError, -) -from ai_company.engine.task_engine import TaskEngine -from ai_company.engine.task_engine_config import TaskEngineConfig -from ai_company.engine.task_engine_models import ( - CancelTaskMutation, - CreateTaskData, - CreateTaskMutation, - DeleteTaskMutation, - TransitionTaskMutation, - UpdateTaskMutation, -) - -# ── Fakes ───────────────────────────────────────────────────── - - -class FakeTaskRepository: - 
"""Minimal in-memory task repository for engine tests.""" - - def __init__(self) -> None: - self._tasks: dict[str, Task] = {} - - async def save(self, task: Task) -> None: - self._tasks[task.id] = task - - async def get(self, task_id: str) -> Task | None: - return self._tasks.get(task_id) - - async def list_tasks( - self, - *, - status: TaskStatus | None = None, - assigned_to: str | None = None, - project: str | None = None, - ) -> tuple[Task, ...]: - result = list(self._tasks.values()) - if status is not None: - result = [t for t in result if t.status == status] - if assigned_to is not None: - result = [t for t in result if t.assigned_to == assigned_to] - if project is not None: - result = [t for t in result if t.project == project] - return tuple(result) - - async def delete(self, task_id: str) -> bool: - return self._tasks.pop(task_id, None) is not None - - -class FakePersistence: - """Minimal fake persistence backend with only a task repository.""" - - def __init__(self) -> None: - self._tasks = FakeTaskRepository() - - @property - def tasks(self) -> FakeTaskRepository: - return self._tasks - - -class FakeMessageBus: - """Minimal fake message bus that records published messages.""" - - def __init__(self) -> None: - self.published: list[object] = [] - self._running = False - - async def start(self) -> None: - self._running = True - - async def stop(self) -> None: - self._running = False - - @property - def is_running(self) -> bool: - return self._running - - async def publish(self, message: object) -> None: - self.published.append(message) - - -class FailingMessageBus(FakeMessageBus): - """Message bus that always fails on publish.""" - - async def publish(self, message: object) -> None: - msg = "Publish failed" - raise RuntimeError(msg) - - -# ── Fixtures ────────────────────────────────────────────────── - - -def _make_create_data(**overrides: object) -> CreateTaskData: - """Build a CreateTaskData with sensible defaults.""" - defaults: dict[str, object] = { - 
"title": "Test task", - "description": "A test task", - "type": TaskType.DEVELOPMENT, - "project": "test-project", - "created_by": "test-agent", - } - defaults.update(overrides) - return CreateTaskData(**defaults) # type: ignore[arg-type] - - -@pytest.fixture -def persistence() -> FakePersistence: - return FakePersistence() - - -@pytest.fixture -def message_bus() -> FakeMessageBus: - return FakeMessageBus() - - -@pytest.fixture -def config() -> TaskEngineConfig: - return TaskEngineConfig(max_queue_size=100) - - -@pytest.fixture -async def engine( - persistence: FakePersistence, - config: TaskEngineConfig, -) -> AsyncGenerator[TaskEngine]: - """Create and start a TaskEngine, stop on teardown.""" - eng = TaskEngine( - persistence=persistence, # type: ignore[arg-type] - config=config, - ) - eng.start() - yield eng - await eng.stop(timeout=2.0) - - -@pytest.fixture -async def engine_with_bus( - persistence: FakePersistence, - message_bus: FakeMessageBus, - config: TaskEngineConfig, -) -> AsyncGenerator[TaskEngine]: - """Create and start a TaskEngine with a message bus.""" - eng = TaskEngine( - persistence=persistence, # type: ignore[arg-type] - message_bus=message_bus, # type: ignore[arg-type] - config=config, - ) - eng.start() - yield eng - await eng.stop(timeout=2.0) - - -# ── Lifecycle tests ─────────────────────────────────────────── - - -@pytest.mark.unit -class TestTaskEngineLifecycle: - """Tests for start/stop lifecycle.""" - - async def test_start_sets_running( - self, - persistence: FakePersistence, - ) -> None: - eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - assert eng.is_running is False - eng.start() - assert eng.is_running is True - await eng.stop(timeout=2.0) # type: ignore[unreachable] - assert eng.is_running is False - - async def test_double_start_raises( - self, - persistence: FakePersistence, - ) -> None: - eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - eng.start() - with pytest.raises(RuntimeError, 
match="already running"): - eng.start() - await eng.stop(timeout=2.0) - - async def test_stop_idempotent( - self, - persistence: FakePersistence, - ) -> None: - eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - eng.start() - await eng.stop(timeout=2.0) - await eng.stop(timeout=2.0) # no error - - async def test_restart( - self, - persistence: FakePersistence, - ) -> None: - eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - eng.start() - await eng.stop(timeout=2.0) - eng.start() - assert eng.is_running is True - await eng.stop(timeout=2.0) - - -# ── Submit to stopped engine ────────────────────────────────── - - -@pytest.mark.unit -class TestSubmitToStoppedEngine: - """Submitting to a stopped engine raises TaskEngineNotRunningError.""" - - async def test_submit_raises( - self, - persistence: FakePersistence, - ) -> None: - eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - mutation = CreateTaskMutation( - request_id="req-1", - requested_by="alice", - task_data=_make_create_data(), - ) - with pytest.raises(TaskEngineNotRunningError): - await eng.submit(mutation) - - -# ── Create mutation ─────────────────────────────────────────── - - -@pytest.mark.unit -class TestCreateTask: - """Tests for task creation via TaskEngine.""" - - async def test_create_task( - self, - engine: TaskEngine, - persistence: FakePersistence, - ) -> None: - task = await engine.create_task( - _make_create_data(title="My Task"), - requested_by="alice", - ) - assert task.title == "My Task" - assert task.id.startswith("task-") - assert task.status == TaskStatus.CREATED - - stored = await persistence.tasks.get(task.id) - assert stored is not None - assert stored.title == "My Task" - - async def test_create_returns_version_1( - self, - engine: TaskEngine, - ) -> None: - mutation = CreateTaskMutation( - request_id="req-1", - requested_by="alice", - task_data=_make_create_data(), - ) - result = await engine.submit(mutation) - assert result.success 
is True - assert result.version == 1 - - async def test_create_with_assignee( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(assigned_to=None), - requested_by="alice", - ) - assert task.assigned_to is None - - -# ── Update mutation ─────────────────────────────────────────── - - -@pytest.mark.unit -class TestUpdateTask: - """Tests for task update via TaskEngine.""" - - async def test_update_fields( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(title="Original"), - requested_by="alice", - ) - updated = await engine.update_task( - task.id, - {"title": "Updated"}, - requested_by="alice", - ) - assert updated.title == "Updated" - assert updated.id == task.id - - async def test_update_empty_no_op( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - result = await engine.update_task( - task.id, - {}, - requested_by="alice", - ) - assert result.title == task.title - - async def test_update_not_found( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskMutationError, match="not found"): - await engine.update_task( - "task-nonexistent", - {"title": "X"}, - requested_by="alice", - ) - - -# ── Transition mutation ─────────────────────────────────────── - - -@pytest.mark.unit -class TestTransitionTask: - """Tests for task status transitions via TaskEngine.""" - - async def test_valid_transition( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - assigned = await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by="alice", - reason="Assigning", - assigned_to="bob", - ) - assert assigned.status == TaskStatus.ASSIGNED - assert assigned.assigned_to == "bob" - - async def test_invalid_transition( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - 
_make_create_data(), - requested_by="alice", - ) - with pytest.raises(TaskMutationError): - await engine.transition_task( - task.id, - TaskStatus.COMPLETED, - requested_by="alice", - reason="Skip to done", - assigned_to="bob", - ) - - async def test_transition_not_found( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskMutationError, match="not found"): - await engine.transition_task( - "task-nonexistent", - TaskStatus.ASSIGNED, - requested_by="alice", - reason="test", - ) - - -# ── Delete mutation ─────────────────────────────────────────── - - -@pytest.mark.unit -class TestDeleteTask: - """Tests for task deletion via TaskEngine.""" - - async def test_delete_task( - self, - engine: TaskEngine, - persistence: FakePersistence, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - deleted = await engine.delete_task(task.id, requested_by="alice") - assert deleted is True - - stored = await persistence.tasks.get(task.id) - assert stored is None - - async def test_delete_not_found( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskMutationError, match="not found"): - await engine.delete_task( - "task-nonexistent", - requested_by="alice", - ) - - -# ── Cancel mutation ─────────────────────────────────────────── - - -@pytest.mark.unit -class TestCancelTask: - """Tests for task cancellation via TaskEngine.""" - - async def test_cancel_assigned_task( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - assigned = await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by="alice", - reason="Assigning", - assigned_to="bob", - ) - cancelled = await engine.cancel_task( - assigned.id, - requested_by="alice", - reason="No longer needed", - ) - assert cancelled.status == TaskStatus.CANCELLED - - async def test_cancel_from_created_fails( - self, - engine: TaskEngine, - ) -> None: - """CREATED -> 
CANCELLED is not a valid transition.""" - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - with pytest.raises(TaskMutationError): - await engine.cancel_task( - task.id, - requested_by="alice", - reason="Oops", - ) - - -# ── Read-through ────────────────────────────────────────────── - - -@pytest.mark.unit -class TestReadThrough: - """Tests for read-through methods that bypass the queue.""" - - async def test_get_task( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(title="Findme"), - requested_by="alice", - ) - found = await engine.get_task(task.id) - assert found is not None - assert found.title == "Findme" - - async def test_get_task_not_found( - self, - engine: TaskEngine, - ) -> None: - result = await engine.get_task("task-nonexistent") - assert result is None - - async def test_list_tasks( - self, - engine: TaskEngine, - ) -> None: - await engine.create_task( - _make_create_data(project="proj-a"), - requested_by="alice", - ) - await engine.create_task( - _make_create_data(project="proj-b"), - requested_by="alice", - ) - all_tasks = await engine.list_tasks() - assert len(all_tasks) == 2 - - filtered = await engine.list_tasks(project="proj-a") - assert len(filtered) == 1 - - async def test_list_tasks_by_status( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by="alice", - reason="Assigning", - assigned_to="bob", - ) - - created = await engine.list_tasks(status=TaskStatus.CREATED) - assigned = await engine.list_tasks(status=TaskStatus.ASSIGNED) - assert len(created) == 0 - assert len(assigned) == 1 - - -# ── Version tracking ────────────────────────────────────────── - - -@pytest.mark.unit -class TestVersionTracking: - """Tests for the in-memory version counter.""" - - async def test_version_increments( - self, - engine: 
TaskEngine, - ) -> None: - mutation = CreateTaskMutation( - request_id="req-1", - requested_by="alice", - task_data=_make_create_data(), - ) - r1 = await engine.submit(mutation) - assert r1.version == 1 - - update = UpdateTaskMutation( - request_id="req-2", - requested_by="alice", - task_id=r1.task.id, # type: ignore[union-attr] - updates={"title": "Updated"}, - ) - r2 = await engine.submit(update) - assert r2.version == 2 - - async def test_version_conflict( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - # version is 1 after create; expected_version=99 should fail - update = UpdateTaskMutation( - request_id="req-2", - requested_by="alice", - task_id=task.id, - updates={"title": "X"}, - expected_version=99, - ) - result = await engine.submit(update) - assert result.success is False - assert "conflict" in (result.error or "").lower() - - async def test_version_reset_on_delete( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - delete = DeleteTaskMutation( - request_id="req-3", - requested_by="alice", - task_id=task.id, - ) - result = await engine.submit(delete) - assert result.version == 0 - - -# ── Snapshot publishing ─────────────────────────────────────── - - -@pytest.mark.unit -class TestSnapshotPublishing: - """Tests for event publishing to the message bus.""" - - async def test_snapshot_published_on_create( - self, - engine_with_bus: TaskEngine, - message_bus: FakeMessageBus, - ) -> None: - await engine_with_bus.create_task( - _make_create_data(), - requested_by="alice", - ) - # Yield to event loop so the processing loop completes snapshot publication - await asyncio.sleep(0) - assert len(message_bus.published) == 1 - - async def test_snapshot_publish_failure_does_not_affect_mutation( - self, - persistence: FakePersistence, - config: TaskEngineConfig, - ) -> None: - failing_bus = FailingMessageBus() 
- eng = TaskEngine( - persistence=persistence, # type: ignore[arg-type] - message_bus=failing_bus, # type: ignore[arg-type] - config=config, - ) - eng.start() - try: - task = await eng.create_task( - _make_create_data(), - requested_by="alice", - ) - assert task.id.startswith("task-") - - stored = await persistence.tasks.get(task.id) - assert stored is not None - finally: - await eng.stop(timeout=2.0) - - async def test_no_snapshot_when_disabled( - self, - persistence: FakePersistence, - message_bus: FakeMessageBus, - ) -> None: - no_snap_config = TaskEngineConfig(publish_snapshots=False) - eng = TaskEngine( - persistence=persistence, # type: ignore[arg-type] - message_bus=message_bus, # type: ignore[arg-type] - config=no_snap_config, - ) - eng.start() - try: - await eng.create_task( - _make_create_data(), - requested_by="alice", - ) - await asyncio.sleep(0) - assert len(message_bus.published) == 0 - finally: - await eng.stop(timeout=2.0) - - -# ── Sequential ordering ────────────────────────────────────── - - -@pytest.mark.unit -class TestSequentialOrdering: - """Tests that mutations are processed sequentially.""" - - async def test_concurrent_submits( - self, - engine: TaskEngine, - ) -> None: - """Multiple concurrent creates all succeed without interleaving.""" - tasks = await asyncio.gather( - *( - engine.create_task( - _make_create_data(title=f"Task {i}"), - requested_by="alice", - ) - for i in range(10) - ), - ) - assert len(tasks) == 10 - ids = {t.id for t in tasks} - assert len(ids) == 10 # all unique - - -# ── Drain on stop ───────────────────────────────────────────── - - -@pytest.mark.unit -class TestDrainOnStop: - """Tests that stop() drains pending mutations.""" - - async def test_pending_mutations_processed( - self, - persistence: FakePersistence, - ) -> None: - config = TaskEngineConfig(max_queue_size=100) - eng = TaskEngine( - persistence=persistence, # type: ignore[arg-type] - config=config, - ) - eng.start() - - # Submit several mutations - 
results = await asyncio.gather( - *( - eng.create_task( - _make_create_data(title=f"Drain {i}"), - requested_by="alice", - ) - for i in range(5) - ), - ) - assert len(results) == 5 - - await eng.stop(timeout=5.0) - assert eng.is_running is False - - # All tasks should be persisted - all_tasks = await persistence.tasks.list_tasks() - assert len(all_tasks) == 5 - - -# ── Queue full ──────────────────────────────────────────────── - - -@pytest.mark.unit -class TestQueueFull: - """Tests for queue full backpressure.""" - - async def test_queue_full_raises( - self, - persistence: FakePersistence, - ) -> None: - from ai_company.engine.task_engine import _MutationEnvelope - - tiny_config = TaskEngineConfig(max_queue_size=1) - eng = TaskEngine( - persistence=persistence, # type: ignore[arg-type] - config=tiny_config, - ) - # Start the engine but pause the processing loop - eng._running = True - - # First submit fills the queue - mutation1 = CreateTaskMutation( - request_id="req-1", - requested_by="alice", - task_data=_make_create_data(), - ) - eng._queue.put_nowait(_MutationEnvelope(mutation=mutation1)) - - # Second submit should fail because queue is full - mutation2 = CreateTaskMutation( - request_id="req-2", - requested_by="alice", - task_data=_make_create_data(), - ) - with pytest.raises(TaskEngineQueueFullError, match="queue is full"): - await eng.submit(mutation2) - - eng._running = False - - -# ── Error propagation ──────────────────────────────────────── - - -@pytest.mark.unit -class TestErrorPropagation: - """Tests for error propagation via futures.""" - - async def test_persistence_error_returns_failure( - self, - persistence: FakePersistence, - config: TaskEngineConfig, - ) -> None: - """Persistence errors during mutation are captured in the result.""" - - class FailingSaveRepo(FakeTaskRepository): - async def save(self, task: Task) -> None: - msg = "Disk full" - raise OSError(msg) - - persistence._tasks = FailingSaveRepo() - eng = TaskEngine( - 
persistence=persistence, # type: ignore[arg-type] - config=config, - ) - eng.start() - try: - mutation = CreateTaskMutation( - request_id="req-1", - requested_by="alice", - task_data=_make_create_data(), - ) - result = await eng.submit(mutation) - assert result.success is False - assert result.error == "Internal error processing mutation" - finally: - await eng.stop(timeout=2.0) - - -# ── TaskEngineConfig ────────────────────────────────────────── - - -@pytest.mark.unit -class TestTaskEngineConfig: - """Tests for TaskEngineConfig model.""" - - def test_defaults(self) -> None: - config = TaskEngineConfig() - assert config.max_queue_size == 1000 - assert config.drain_timeout_seconds == 10.0 - assert config.publish_snapshots is True - - def test_custom_values(self) -> None: - config = TaskEngineConfig( - max_queue_size=500, - drain_timeout_seconds=5.0, - publish_snapshots=False, - ) - assert config.max_queue_size == 500 - assert config.drain_timeout_seconds == 5.0 - assert config.publish_snapshots is False - - def test_frozen(self) -> None: - from pydantic import ValidationError - - config = TaskEngineConfig() - with pytest.raises(ValidationError): - config.max_queue_size = 999 # type: ignore[misc] - - -# -- Version conflict on transition ──────────────────────────── - - -@pytest.mark.unit -class TestVersionConflictOnTransition: - """Version conflict detection on transition mutations.""" - - async def test_transition_version_conflict( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - mutation = TransitionTaskMutation( - request_id="req-1", - requested_by="alice", - task_id=task.id, - target_status=TaskStatus.ASSIGNED, - reason="Assigning", - overrides={"assigned_to": "bob"}, - expected_version=99, - ) - result = await engine.submit(mutation) - assert result.success is False - assert "conflict" in (result.error or "").lower() - - -# -- Cancel not found 
───────────────────────────────────────── - - -@pytest.mark.unit -class TestCancelNotFound: - """Cancel mutation on a non-existent task.""" - - async def test_cancel_not_found( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskNotFoundError, match="not found"): - await engine.cancel_task( - "task-nonexistent", - requested_by="alice", - reason="test", - ) - - -# -- Previous status in results ──────────────────────────────── - - -@pytest.mark.unit -class TestPreviousStatus: - """Verify previous_status is populated in mutation results.""" - - async def test_create_has_no_previous_status( - self, - engine: TaskEngine, - ) -> None: - mutation = CreateTaskMutation( - request_id="req-1", - requested_by="alice", - task_data=_make_create_data(), - ) - result = await engine.submit(mutation) - assert result.success is True - assert result.previous_status is None - - async def test_transition_has_previous_status( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - mutation = TransitionTaskMutation( - request_id="req-1", - requested_by="alice", - task_id=task.id, - target_status=TaskStatus.ASSIGNED, - reason="Assigning", - overrides={"assigned_to": "bob"}, - ) - result = await engine.submit(mutation) - assert result.success is True - assert result.previous_status == TaskStatus.CREATED - - async def test_cancel_has_previous_status( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - # First move to ASSIGNED so cancel is valid - await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by="alice", - reason="Assigning", - assigned_to="bob", - ) - mutation = CancelTaskMutation( - request_id="req-1", - requested_by="alice", - task_id=task.id, - reason="No longer needed", - ) - result = await engine.submit(mutation) - assert result.success is True - assert result.previous_status == 
TaskStatus.ASSIGNED - - -# -- Immutable field rejection ───────────────────────────────── - - -@pytest.mark.unit -class TestImmutableFieldRejection: - """UpdateTaskMutation and TransitionTaskMutation reject immutable fields.""" - - def test_update_rejects_status(self) -> None: - from pydantic import ValidationError - - with pytest.raises(ValidationError, match="immutable"): - UpdateTaskMutation( - request_id="req-1", - requested_by="alice", - task_id="task-1", - updates={"status": "completed"}, - ) - - def test_update_rejects_id(self) -> None: - from pydantic import ValidationError - - with pytest.raises(ValidationError, match="immutable"): - UpdateTaskMutation( - request_id="req-1", - requested_by="alice", - task_id="task-1", - updates={"id": "new-id"}, - ) - - def test_transition_rejects_id_override(self) -> None: - from pydantic import ValidationError - - with pytest.raises(ValidationError, match="immutable"): - TransitionTaskMutation( - request_id="req-1", - requested_by="alice", - task_id="task-1", - target_status=TaskStatus.ASSIGNED, - reason="test", - overrides={"id": "new-id"}, - ) - - -# -- Typed error propagation ────────────────────────────────── - - -@pytest.mark.unit -class TestTypedErrors: - """Convenience methods raise typed errors.""" - - async def test_update_not_found_raises_typed( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskNotFoundError): - await engine.update_task( - "task-nonexistent", - {"title": "X"}, - requested_by="alice", - ) - - async def test_delete_not_found_raises_typed( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskNotFoundError): - await engine.delete_task( - "task-nonexistent", - requested_by="alice", - ) - - async def test_transition_not_found_raises_typed( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskNotFoundError): - await engine.transition_task( - "task-nonexistent", - TaskStatus.ASSIGNED, - requested_by="alice", - reason="test", - ) - - async def 
test_update_version_conflict_raises_typed( - self, - engine: TaskEngine, - ) -> None: - """Version conflict via convenience method raises TaskVersionConflictError.""" - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - with pytest.raises(TaskVersionConflictError, match="conflict"): - await engine.update_task( - task.id, - {"title": "changed"}, - requested_by="alice", - expected_version=99, - ) - - -# -- Drain timeout / _fail_remaining_futures ────────────────── - - -@pytest.mark.unit -class TestDrainTimeout: - """Verify _fail_remaining_futures resolves abandoned futures.""" - - async def test_abandoned_futures_resolved_on_drain_timeout( - self, - persistence: FakePersistence, - ) -> None: - """Futures left after drain timeout get failure results.""" - from ai_company.engine.task_engine import _MutationEnvelope - - config = TaskEngineConfig(drain_timeout_seconds=0.01) - eng = TaskEngine( - persistence=persistence, # type: ignore[arg-type] - config=config, - ) - eng.start() - - # Pause processing by filling with a slow mutation - mutation = CreateTaskMutation( - request_id="req-slow", - requested_by="alice", - task_data=_make_create_data(), - ) - # Submit one that will process normally - await eng.submit(mutation) - - # Now stop the engine — force a very short drain - # Submit directly to queue to avoid await - mutation2 = CreateTaskMutation( - request_id="req-abandoned", - requested_by="alice", - task_data=_make_create_data(), - ) - envelope = _MutationEnvelope(mutation=mutation2) - eng._queue.put_nowait(envelope) - eng._running = False - - # Call _fail_remaining_futures directly - eng._fail_remaining_futures() - assert envelope.future.done() - result = envelope.future.result() - assert result.success is False - assert "shut down" in (result.error or "") - - await eng.stop(timeout=0.1) - - -# -- TaskMutationResult consistency ──────────────────────────── - - -@pytest.mark.unit -class TestMutationResultConsistency: - """Verify 
_check_consistency validator on TaskMutationResult.""" - - def test_success_with_error_rejected(self) -> None: - """Successful result must not carry an error.""" - from pydantic import ValidationError - - from ai_company.engine.task_engine_models import TaskMutationResult - - with pytest.raises(ValidationError, match="error"): - TaskMutationResult( - request_id="r", - success=True, - error="oops", - ) - - def test_failure_without_error_rejected(self) -> None: - """Failed result must carry an error description.""" - from pydantic import ValidationError - - from ai_company.engine.task_engine_models import TaskMutationResult - - with pytest.raises(ValidationError, match="error"): - TaskMutationResult( - request_id="r", - success=False, - ) diff --git a/tests/unit/engine/test_task_engine_integration.py b/tests/unit/engine/test_task_engine_integration.py new file mode 100644 index 0000000000..2298a9de1b --- /dev/null +++ b/tests/unit/engine/test_task_engine_integration.py @@ -0,0 +1,326 @@ +"""Integration tests for TaskEngine: publishing, ordering, queue, versioning, drain.""" + +import asyncio +import contextlib + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.errors import TaskEngineQueueFullError +from ai_company.engine.task_engine import TaskEngine, _MutationEnvelope +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CreateTaskMutation, + DeleteTaskMutation, + TransitionTaskMutation, + UpdateTaskMutation, +) +from tests.unit.engine.task_engine_helpers import ( + FailingMessageBus, + FakeMessageBus, + FakePersistence, + _make_create_data, +) + +# ── Snapshot publishing ─────────────────────────────────────── + + +@pytest.mark.unit +class TestSnapshotPublishing: + """Tests for event publishing to the message bus.""" + + async def test_snapshot_published_on_create( + self, + engine_with_bus: TaskEngine, + message_bus: FakeMessageBus, + ) -> None: + await 
engine_with_bus.create_task( + _make_create_data(), + requested_by="alice", + ) + # Yield to event loop so the processing loop completes snapshot publication + await asyncio.sleep(0) + assert len(message_bus.published) == 1 + + async def test_snapshot_publish_failure_does_not_affect_mutation( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + failing_bus = FailingMessageBus() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=failing_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + assert task.id.startswith("task-") + + stored = await persistence.tasks.get(task.id) + assert stored is not None + finally: + await eng.stop(timeout=2.0) + + async def test_no_snapshot_when_disabled( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + ) -> None: + no_snap_config = TaskEngineConfig(publish_snapshots=False) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=no_snap_config, + ) + eng.start() + try: + await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await asyncio.sleep(0) + assert len(message_bus.published) == 0 + finally: + await eng.stop(timeout=2.0) + + async def test_pending_mutations_drained_on_stop( + self, + persistence: FakePersistence, + ) -> None: + """Tasks submitted before stop() are processed during drain.""" + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + + # Submit concurrently without awaiting — items enter the queue + create_tasks = [ + asyncio.create_task( + eng.create_task(_make_create_data(), requested_by="alice") + ) + for _ in range(5) + ] + + # Yield to the event loop so the tasks can enqueue their mutations before stop() + await asyncio.sleep(0) + + # Stop while tasks may still be in flight — drain timeout is 
generous + await eng.stop(timeout=5.0) + + # All futures resolved (drained during stop or completed before stop) + results = await asyncio.gather(*create_tasks) + assert len(results) == 5 + stored = await persistence.tasks.list_tasks() + assert len(stored) == 5 + + +# ── Sequential ordering ────────────────────────────────────── + + +@pytest.mark.unit +class TestSequentialOrdering: + """Tests that mutations are processed sequentially.""" + + async def test_concurrent_submits( + self, + engine: TaskEngine, + ) -> None: + """Multiple concurrent creates all succeed without interleaving.""" + tasks = await asyncio.gather( + *( + engine.create_task( + _make_create_data(title=f"Task {i}"), + requested_by="alice", + ) + for i in range(10) + ), + ) + assert len(tasks) == 10 + ids = {t.id for t in tasks} + assert len(ids) == 10 # all unique + + +# ── Queue backpressure ──────────────────────────────────────── + + +@pytest.mark.unit +class TestQueueFull: + """Tests for queue full backpressure.""" + + async def test_queue_full_raises( + self, + persistence: FakePersistence, + ) -> None: + tiny_config = TaskEngineConfig(max_queue_size=1) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=tiny_config, + ) + # Start the engine but pause the processing loop + eng._running = True + + # First submit fills the queue + mutation1 = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + eng._queue.put_nowait(_MutationEnvelope(mutation=mutation1)) + + # Second submit should fail because queue is full + mutation2 = CreateTaskMutation( + request_id="req-2", + requested_by="alice", + task_data=_make_create_data(), + ) + with pytest.raises(TaskEngineQueueFullError, match="queue is full"): + await eng.submit(mutation2) + + eng._running = False + + +# ── Version tracking ────────────────────────────────────────── + + +@pytest.mark.unit +class TestVersionTracking: + """Tests for the in-memory version counter.""" + 
+ async def test_version_increments( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + r1 = await engine.submit(mutation) + assert r1.version == 1 + + update = UpdateTaskMutation( + request_id="req-2", + requested_by="alice", + task_id=r1.task.id, # type: ignore[union-attr] + updates={"title": "Updated"}, + ) + r2 = await engine.submit(update) + assert r2.version == 2 + + async def test_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # version is 1 after create; expected_version=99 should fail + update = UpdateTaskMutation( + request_id="req-2", + requested_by="alice", + task_id=task.id, + updates={"title": "X"}, + expected_version=99, + ) + result = await engine.submit(update) + assert result.success is False + assert "conflict" in (result.error or "").lower() + + async def test_version_reset_on_delete( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + delete = DeleteTaskMutation( + request_id="req-3", + requested_by="alice", + task_id=task.id, + ) + result = await engine.submit(delete) + assert result.version == 0 + + async def test_transition_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + expected_version=99, + ) + result = await engine.submit(mutation) + assert result.success is False + assert "conflict" in (result.error or "").lower() + + +# ── Drain timeout ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestDrainTimeout: + """Verify drain-timeout 
cleanup resolves outstanding futures.""" + + async def test_drain_timeout_resolves_pending_futures( + self, + persistence: FakePersistence, + ) -> None: + """Futures still in queue are failed when stop() times out.""" + # Block the processing loop with a slow save + block = asyncio.Event() + original_save = persistence.tasks.save + + async def slow_save(task: object) -> None: + await block.wait() + await original_save(task) # type: ignore[arg-type] + + persistence.tasks.save = slow_save # type: ignore[method-assign] + + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + + # Submit a task — it'll block in slow_save, holding the processing loop + blocked_task = asyncio.create_task( + eng.create_task(_make_create_data(), requested_by="alice") + ) + + # Queue a second task directly so it's definitely waiting + mutation2 = CreateTaskMutation( + request_id="req-queued", + requested_by="alice", + task_data=_make_create_data(), + ) + envelope = _MutationEnvelope(mutation=mutation2) + # Give the engine a tick to start processing the first task + await asyncio.sleep(0.05) + eng._queue.put_nowait(envelope) + + # Stop with a very short timeout — loop is blocked, so timeout fires + await eng.stop(timeout=0.05) + + # The queued envelope (not yet processed) must be failed + assert envelope.future.done() + result = envelope.future.result() + assert result.success is False + assert result.error_code == "internal" + + # Release the block so slow_save can finish, then cancel the blocked task + # (its future was never set because the processing loop was cancelled) + block.set() + blocked_task.cancel() + with contextlib.suppress(Exception, asyncio.CancelledError): + await blocked_task diff --git a/tests/unit/engine/test_task_engine_lifecycle.py b/tests/unit/engine/test_task_engine_lifecycle.py new file mode 100644 index 0000000000..818b2fa2d7 --- /dev/null +++ b/tests/unit/engine/test_task_engine_lifecycle.py @@ -0,0 +1,109 @@ +"""Lifecycle and config 
tests for TaskEngine.""" + +import pytest + +from ai_company.engine.errors import TaskEngineNotRunningError +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import CreateTaskMutation +from tests.unit.engine.task_engine_helpers import FakePersistence, _make_create_data + +# ── Lifecycle tests ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestTaskEngineLifecycle: + """Tests for start/stop lifecycle.""" + + async def test_start_sets_running( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + assert eng.is_running is False + eng.start() + assert eng.is_running is True + await eng.stop(timeout=2.0) # type: ignore[unreachable] + assert eng.is_running is False + + async def test_double_start_raises( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + with pytest.raises(RuntimeError, match="already running"): + eng.start() + await eng.stop(timeout=2.0) + + async def test_stop_idempotent( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + await eng.stop(timeout=2.0) + await eng.stop(timeout=2.0) # no error + + async def test_restart( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + await eng.stop(timeout=2.0) + eng.start() + assert eng.is_running is True + await eng.stop(timeout=2.0) + + +# ── Submit to stopped engine ────────────────────────────────── + + +@pytest.mark.unit +class TestSubmitToStoppedEngine: + """Submitting to a stopped engine raises TaskEngineNotRunningError.""" + + async def test_submit_raises( + self, + persistence: FakePersistence, + ) -> None: + eng = 
TaskEngine(persistence=persistence) # type: ignore[arg-type] + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + with pytest.raises(TaskEngineNotRunningError): + await eng.submit(mutation) + + +# ── TaskEngineConfig ────────────────────────────────────────── + + +@pytest.mark.unit +class TestTaskEngineConfig: + """Tests for TaskEngineConfig model.""" + + def test_defaults(self) -> None: + cfg = TaskEngineConfig() + assert cfg.max_queue_size == 1000 + assert cfg.drain_timeout_seconds == 10.0 + assert cfg.publish_snapshots is True + + def test_custom_values(self) -> None: + cfg = TaskEngineConfig( + max_queue_size=500, + drain_timeout_seconds=5.0, + publish_snapshots=False, + ) + assert cfg.max_queue_size == 500 + assert cfg.drain_timeout_seconds == 5.0 + assert cfg.publish_snapshots is False + + def test_frozen(self) -> None: + from pydantic import ValidationError + + cfg = TaskEngineConfig() + with pytest.raises(ValidationError): + cfg.max_queue_size = 999 # type: ignore[misc] diff --git a/tests/unit/engine/test_task_engine_mutations.py b/tests/unit/engine/test_task_engine_mutations.py new file mode 100644 index 0000000000..4cbb3474fe --- /dev/null +++ b/tests/unit/engine/test_task_engine_mutations.py @@ -0,0 +1,583 @@ +"""CRUD mutation, typed error, and consistency tests for TaskEngine.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.core.task import Task # noqa: TC001 +from ai_company.engine.errors import ( + TaskMutationError, + TaskNotFoundError, + TaskVersionConflictError, +) +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskMutation, + TransitionTaskMutation, + UpdateTaskMutation, +) +from tests.unit.engine.task_engine_helpers import ( + FakePersistence, + FakeTaskRepository, + _make_create_data, +) + +if TYPE_CHECKING: + from 
ai_company.engine.task_engine_config import TaskEngineConfig + +# ── Create mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestCreateTask: + """Tests for task creation via TaskEngine.""" + + async def test_create_task( + self, + engine: TaskEngine, + persistence: FakePersistence, + ) -> None: + task = await engine.create_task( + _make_create_data(title="My Task"), + requested_by="alice", + ) + assert task.title == "My Task" + assert task.id.startswith("task-") + assert task.status == TaskStatus.CREATED + + stored = await persistence.tasks.get(task.id) + assert stored is not None + assert stored.title == "My Task" + + async def test_create_returns_version_1( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.version == 1 + + async def test_create_with_assignee( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(assigned_to=None), + requested_by="alice", + ) + assert task.assigned_to is None + + +# ── Update mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestUpdateTask: + """Tests for task update via TaskEngine.""" + + async def test_update_fields( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(title="Original"), + requested_by="alice", + ) + updated = await engine.update_task( + task.id, + {"title": "Updated"}, + requested_by="alice", + ) + assert updated.title == "Updated" + assert updated.id == task.id + + async def test_update_empty_no_op( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + result = await engine.update_task( + task.id, + {}, + requested_by="alice", + ) + assert result.title == task.title + + async def 
test_update_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskMutationError, match="not found"): + await engine.update_task( + "task-nonexistent", + {"title": "X"}, + requested_by="alice", + ) + + +# ── Transition mutation ─────────────────────────────────────── + + +@pytest.mark.unit +class TestTransitionTask: + """Tests for task status transitions via TaskEngine.""" + + async def test_valid_transition( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + assigned, _ = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + assert assigned.status == TaskStatus.ASSIGNED + assert assigned.assigned_to == "bob" + + async def test_invalid_transition( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskMutationError): + await engine.transition_task( + task.id, + TaskStatus.COMPLETED, + requested_by="alice", + reason="Skip to done", + assigned_to="bob", + ) + + async def test_transition_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskMutationError, match="not found"): + await engine.transition_task( + "task-nonexistent", + TaskStatus.ASSIGNED, + requested_by="alice", + reason="test", + ) + + +# ── Delete mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestDeleteTask: + """Tests for task deletion via TaskEngine.""" + + async def test_delete_task( + self, + engine: TaskEngine, + persistence: FakePersistence, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + deleted = await engine.delete_task(task.id, requested_by="alice") + assert deleted is True + + stored = await persistence.tasks.get(task.id) + assert stored is None + + async def test_delete_not_found( + self, + engine: 
TaskEngine, + ) -> None: + with pytest.raises(TaskMutationError, match="not found"): + await engine.delete_task( + "task-nonexistent", + requested_by="alice", + ) + + +# ── Cancel mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestCancelTask: + """Tests for task cancellation via TaskEngine.""" + + async def test_cancel_assigned_task( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + assigned, _ = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + cancelled = await engine.cancel_task( + assigned.id, + requested_by="alice", + reason="No longer needed", + ) + assert cancelled.status == TaskStatus.CANCELLED + + async def test_cancel_from_created_fails( + self, + engine: TaskEngine, + ) -> None: + """CREATED -> CANCELLED is not a valid transition.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskMutationError): + await engine.cancel_task( + task.id, + requested_by="alice", + reason="Oops", + ) + + +# ── Read-through ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestReadThrough: + """Tests for read-through methods that bypass the queue.""" + + async def test_get_task( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(title="Findme"), + requested_by="alice", + ) + found = await engine.get_task(task.id) + assert found is not None + assert found.title == "Findme" + + async def test_get_task_not_found( + self, + engine: TaskEngine, + ) -> None: + result = await engine.get_task("task-nonexistent") + assert result is None + + async def test_list_tasks( + self, + engine: TaskEngine, + ) -> None: + await engine.create_task( + _make_create_data(project="proj-a"), + requested_by="alice", + ) + await engine.create_task( + 
_make_create_data(project="proj-b"), + requested_by="alice", + ) + all_tasks = await engine.list_tasks() + assert len(all_tasks) == 2 + + filtered = await engine.list_tasks(project="proj-a") + assert len(filtered) == 1 + + async def test_list_tasks_by_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + + created = await engine.list_tasks(status=TaskStatus.CREATED) + assigned = await engine.list_tasks(status=TaskStatus.ASSIGNED) + assert len(created) == 0 + assert len(assigned) == 1 + + +# ── Cancel not found ───────────────────────────────────────── + + +@pytest.mark.unit +class TestCancelNotFound: + """Cancel mutation on a non-existent task.""" + + async def test_cancel_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError, match="not found"): + await engine.cancel_task( + "task-nonexistent", + requested_by="alice", + reason="test", + ) + + +# ── Previous status in results ──────────────────────────────── + + +@pytest.mark.unit +class TestPreviousStatus: + """Verify previous_status is populated in mutation results.""" + + async def test_create_has_no_previous_status( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status is None + + async def test_transition_has_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + 
) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status == TaskStatus.CREATED + + async def test_cancel_has_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # First move to ASSIGNED so cancel is valid + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + mutation = CancelTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + reason="No longer needed", + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status == TaskStatus.ASSIGNED + + +# ── Immutable field rejection ───────────────────────────────── + + +@pytest.mark.unit +class TestImmutableFieldRejection: + """UpdateTaskMutation and TransitionTaskMutation reject immutable fields.""" + + def test_update_rejects_status(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"status": "completed"}, + ) + + def test_update_rejects_id(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"id": "new-id"}, + ) + + def test_transition_rejects_id_override(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="test", + overrides={"id": "new-id"}, + ) + + +# ── Typed error propagation ────────────────────────────────── + + +@pytest.mark.unit +class TestTypedErrors: + """Convenience methods raise typed errors.""" + 
+ async def test_update_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.update_task( + "task-nonexistent", + {"title": "X"}, + requested_by="alice", + ) + + async def test_delete_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.delete_task( + "task-nonexistent", + requested_by="alice", + ) + + async def test_transition_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.transition_task( + "task-nonexistent", + TaskStatus.ASSIGNED, + requested_by="alice", + reason="test", + ) + + async def test_update_version_conflict_raises_typed( + self, + engine: TaskEngine, + ) -> None: + """Version conflict via convenience method raises TaskVersionConflictError.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskVersionConflictError, match="conflict"): + await engine.update_task( + task.id, + {"title": "changed"}, + requested_by="alice", + expected_version=99, + ) + + +# ── Error propagation ──────────────────────────────────────── + + +@pytest.mark.unit +class TestErrorPropagation: + """Tests for error propagation via futures.""" + + async def test_persistence_error_returns_failure( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + """Persistence errors during mutation are captured in the result.""" + + class FailingSaveRepo(FakeTaskRepository): + async def save(self, task: Task) -> None: + msg = "Disk full" + raise OSError(msg) + + persistence._tasks = FailingSaveRepo() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await eng.submit(mutation) + assert result.success is False 
+ assert result.error == "Internal error processing mutation" + finally: + await eng.stop(timeout=2.0) + + +# ── TaskMutationResult consistency ──────────────────────────── + + +@pytest.mark.unit +class TestMutationResultConsistency: + """Verify _check_consistency validator on TaskMutationResult.""" + + def test_success_with_error_rejected(self) -> None: + """Successful result must not carry an error.""" + from pydantic import ValidationError + + from ai_company.engine.task_engine_models import TaskMutationResult + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=True, + error="oops", + ) + + def test_failure_without_error_rejected(self) -> None: + """Failed result must carry an error description.""" + from pydantic import ValidationError + + from ai_company.engine.task_engine_models import TaskMutationResult + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=False, + ) From 3254142a793907e50765ecbb3b998d6aa702cc12 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:39:16 +0100 Subject: [PATCH 05/14] fix: remove trailing blank lines in task_engine_helpers (ruff format) --- tests/unit/engine/task_engine_helpers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/engine/task_engine_helpers.py b/tests/unit/engine/task_engine_helpers.py index 97884af09a..6837d0be44 100644 --- a/tests/unit/engine/task_engine_helpers.py +++ b/tests/unit/engine/task_engine_helpers.py @@ -100,5 +100,3 @@ def _make_create_data(**overrides: object) -> CreateTaskData: } defaults.update(overrides) return CreateTaskData(**defaults) # type: ignore[arg-type] - - From 38006d5e20848dd5a2e3efce743543765ef777a1 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 16:48:51 +0100 Subject: [PATCH 06/14] fix: pin escalation_paths to empty tuple in RootConfigFactory to 
prevent flaky test --- tests/unit/config/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index 7300e5af2d..c84857ab56 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -77,6 +77,7 @@ class RootConfigFactory(ModelFactory[RootConfig]): departments = () agents = () custom_roles = () + escalation_paths = () providers: dict[str, ProviderConfig] = {} # noqa: RUF012 config = CompanyConfig() api = ApiConfig() From e894134c99869628b6af518f878a37a93f815735 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 18:41:14 +0100 Subject: [PATCH 07/14] =?UTF-8?q?fix:=20address=20PR=20#325=20review=20fee?= =?UTF-8?q?dback=20=E2=80=94=20extract=20modules,=20harden=20error=20handl?= =?UTF-8?q?ing,=20add=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract VersionTracker into task_engine_version.py (optimistic concurrency) - Extract mutation apply logic into task_engine_apply.py (dispatch + apply_*) - Reduce task_engine.py from 976 to ~620 lines (well under 800 limit) - Make TaskInternalError a sibling of TaskMutationError, not a subtype - Add MemoryError/RecursionError re-raise guards in all except-Exception blocks - Add _try_stop helper in app.py to reduce C901 complexity - Add _extract_requester and _map_task_engine_errors helpers in tasks controller - Add deep-copy at system boundaries for mutable dict fields in frozen models - Add unknown-key validation in UpdateTaskMutation and TransitionTaskMutation - Track in-flight envelope for drain-timeout resolution - Fix _safe_shutdown argument ordering (task_engine first, mirrors cleanup) - Fix existing test bug: mock_te passed as bridge instead of task_engine - Update docs/design/engine.md for immutable-style updates and version tracking - Add 62 new tests: VersionTracker, apply functions, coverage edge cases, controller 
helpers, _try_stop, in-flight resolution, processing loop resilience --- docs/design/engine.md | 18 +- src/ai_company/api/app.py | 139 ++--- src/ai_company/api/controllers/tasks.py | 131 +++-- src/ai_company/engine/errors.py | 6 +- src/ai_company/engine/task_engine.py | 365 ++---------- src/ai_company/engine/task_engine_apply.py | 352 ++++++++++++ src/ai_company/engine/task_engine_models.py | 24 +- src/ai_company/engine/task_engine_version.py | 76 +++ .../unit/api/controllers/test_task_helpers.py | 93 ++++ tests/unit/api/test_app.py | 51 +- tests/unit/engine/test_agent_engine.py | 4 + tests/unit/engine/test_task_engine_apply.py | 527 ++++++++++++++++++ .../unit/engine/test_task_engine_coverage.py | 235 ++++++++ .../engine/test_task_engine_integration.py | 5 +- tests/unit/engine/test_task_engine_version.py | 94 ++++ 15 files changed, 1663 insertions(+), 457 deletions(-) create mode 100644 src/ai_company/engine/task_engine_apply.py create mode 100644 src/ai_company/engine/task_engine_version.py create mode 100644 tests/unit/api/controllers/test_task_helpers.py create mode 100644 tests/unit/engine/test_task_engine_apply.py create mode 100644 tests/unit/engine/test_task_engine_coverage.py create mode 100644 tests/unit/engine/test_task_engine_version.py diff --git a/docs/design/engine.md b/docs/design/engine.md index 6e40d84d5d..91774e4c79 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -183,13 +183,17 @@ Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop - **Single writer**: A background `asyncio.Task` consumes `TaskMutation` requests sequentially from an `asyncio.Queue`. -- **Immutable updates**: Each mutation calls `model_copy(update=...)` on - frozen `Task` models — the original is never mutated. -- **Optimistic concurrency**: In-memory version counters per task. - Callers can pass `expected_version` to detect stale writes; on mismatch - the engine returns a failed `TaskMutationResult` with - `error_code="version_conflict"`. 
Convenience methods raise - `TaskVersionConflictError`. +- **Immutable-style updates**: Each mutation constructs a new `Task` instance + from the previous one (for example via + `Task.model_validate({**task.model_dump(), **updates})` or + `Task.with_transition(...)`); the existing instance is never mutated. +- **Optimistic concurrency**: Per-task version counters. The persisted + task version is the source of truth; any in-memory cache is an + optimization that is seeded from persistence on task load and may be + invalid after a restart. Callers can pass `expected_version` to detect + stale writes; on mismatch the engine returns a failed + `TaskMutationResult` with `error_code="version_conflict"`. Convenience + methods raise `TaskVersionConflictError`. - **Read-through**: `get_task()` and `list_tasks()` bypass the queue and read directly from persistence — safe because TaskEngine is the sole writer. - **Snapshot publishing**: On success, a `TaskStateChanged` event is published diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 13428569f8..588a51c507 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -129,11 +129,30 @@ async def on_startup() -> None: async def on_shutdown() -> None: logger.info(API_APP_SHUTDOWN, version=__version__) - await _safe_shutdown(bridge, task_engine, message_bus, persistence) + await _safe_shutdown(task_engine, bridge, message_bus, persistence) return [on_startup], [on_shutdown] +async def _try_stop( + coro: object, + event: str, + error_msg: str, +) -> None: + """Await *coro* inside a safe try/except, logging failures. + + ``MemoryError`` and ``RecursionError`` are re-raised immediately; + all other exceptions are logged and swallowed so that sibling + shutdown steps can still run. 
+ """ + try: + await coro # type: ignore[misc] + except MemoryError, RecursionError: + raise + except Exception: + logger.exception(event, error=error_msg) + + async def _cleanup_on_failure( # noqa: PLR0913 *, persistence: PersistenceBackend | None, @@ -147,37 +166,29 @@ async def _cleanup_on_failure( # noqa: PLR0913 ) -> None: """Reverse cleanup on startup failure (task engine, bridge, bus, persistence).""" if started_task_engine and task_engine is not None: - try: - await task_engine.stop() - except Exception: - logger.exception( - API_APP_STARTUP, - error="Cleanup: failed to stop task engine", - ) + await _try_stop( + task_engine.stop(), + API_APP_STARTUP, + "Cleanup: failed to stop task engine", + ) if started_bridge and bridge is not None: - try: - await bridge.stop() - except Exception: - logger.exception( - API_APP_STARTUP, - error="Cleanup: failed to stop message bus bridge", - ) + await _try_stop( + bridge.stop(), + API_APP_STARTUP, + "Cleanup: failed to stop message bus bridge", + ) if started_bus and message_bus is not None: - try: - await message_bus.stop() - except Exception: - logger.exception( - API_APP_STARTUP, - error="Cleanup: failed to stop message bus", - ) + await _try_stop( + message_bus.stop(), + API_APP_STARTUP, + "Cleanup: failed to stop message bus", + ) if started_persistence and persistence is not None: - try: - await persistence.disconnect() - except Exception: - logger.exception( - API_APP_STARTUP, - error="Cleanup: failed to disconnect persistence", - ) + await _try_stop( + persistence.disconnect(), + API_APP_STARTUP, + "Cleanup: failed to disconnect persistence", + ) async def _init_persistence( @@ -298,44 +309,41 @@ async def _safe_startup( async def _safe_shutdown( - bridge: MessageBusBridge | None, task_engine: TaskEngine | None, + bridge: MessageBusBridge | None, message_bus: MessageBus | None, persistence: PersistenceBackend | None, ) -> None: - """Stop bridge, task engine, message bus and disconnect persistence.""" - if bridge 
is not None: - try: - await bridge.stop() - except Exception: - logger.exception( - API_APP_SHUTDOWN, - error="Failed to stop message bus bridge", - ) + """Stop task engine, bridge, message bus and disconnect persistence. + + Mirrors ``_cleanup_on_failure`` reverse order: task engine first so it + can drain queued mutations and publish final snapshots through the + still-running bridge. + """ if task_engine is not None: - try: - await task_engine.stop() - except Exception: - logger.exception( - API_APP_SHUTDOWN, - error="Failed to stop task engine", - ) + await _try_stop( + task_engine.stop(), + API_APP_SHUTDOWN, + "Failed to stop task engine", + ) + if bridge is not None: + await _try_stop( + bridge.stop(), + API_APP_SHUTDOWN, + "Failed to stop message bus bridge", + ) if message_bus is not None: - try: - await message_bus.stop() - except Exception: - logger.exception( - API_APP_SHUTDOWN, - error="Failed to stop message bus", - ) + await _try_stop( + message_bus.stop(), + API_APP_SHUTDOWN, + "Failed to stop message bus", + ) if persistence is not None: - try: - await persistence.disconnect() - except Exception: - logger.exception( - API_APP_SHUTDOWN, - error="Failed to disconnect persistence", - ) + await _try_stop( + persistence.disconnect(), + API_APP_SHUTDOWN, + "Failed to disconnect persistence", + ) def create_app( # noqa: PLR0913 @@ -368,11 +376,16 @@ def create_app( # noqa: PLR0913 effective_config = config or RootConfig(company_name="default") api_config = effective_config.api - if persistence is None or message_bus is None or cost_tracker is None: + if ( + persistence is None + or message_bus is None + or cost_tracker is None + or task_engine is None + ): msg = ( "create_app called without persistence, message_bus, " - "and/or cost_tracker — controllers accessing missing " - "services will return 503. Use test fakes for testing." + "cost_tracker, and/or task_engine — controllers accessing " + "missing services will return 503. Use test fakes for testing." 
) logger.warning(API_APP_STARTUP, note=msg) diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 88cc018876..1a958d68e0 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -43,6 +43,44 @@ logger = get_logger(__name__) +def _extract_requester(state: State) -> str: + """Extract requester identity from the authenticated user. + + Falls back to ``"api"`` when the connection carries no user + (e.g. in tests without auth middleware). + """ + user = getattr(state, "_connection_user", None) + if user is not None and hasattr(user, "user_id"): + return str(user.user_id) + return "api" + + +def _map_task_engine_errors( + exc: Exception, + *, + task_id: str | None = None, +) -> Exception: + """Map a task-engine exception to the appropriate API error. + + Returns the API exception to raise (caller must ``raise`` it). + """ + if isinstance(exc, TaskNotFoundError): + if task_id is not None: + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + return NotFoundError(str(exc)) + if isinstance(exc, TaskEngineNotRunningError | TaskEngineQueueFullError): + return ServiceUnavailableError(str(exc)) + if isinstance(exc, TaskInternalError): + return ServiceUnavailableError(str(exc)) + if isinstance(exc, TaskMutationError): + return ApiValidationError(str(exc)) + return exc + + class TaskController(Controller): """Full CRUD for tasks via ``TaskEngine``.""" @@ -140,14 +178,13 @@ async def create_task( task_data, requested_by=data.created_by, ) - except TaskEngineNotRunningError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskEngineQueueFullError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskInternalError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskMutationError as exc: - raise ApiValidationError(str(exc)) from exc + except ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + 
TaskInternalError, + TaskMutationError, + ) as exc: + raise _map_task_engine_errors(exc) from exc logger.info( TASK_CREATED, task_id=task.id, @@ -181,23 +218,16 @@ async def update_task( task = await app_state.task_engine.update_task( task_id, updates, - requested_by="api", + requested_by=_extract_requester(state), ) - except TaskEngineNotRunningError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskEngineQueueFullError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskNotFoundError as exc: - logger.warning( - API_RESOURCE_NOT_FOUND, - resource="task", - id=task_id, - ) - raise NotFoundError(str(exc)) from exc - except TaskInternalError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskMutationError as exc: - raise ApiValidationError(str(exc)) from exc + except ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskNotFoundError, + TaskInternalError, + TaskMutationError, + ) as exc: + raise _map_task_engine_errors(exc, task_id=task_id) from exc logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) return ApiResponse(data=task) @@ -225,8 +255,9 @@ async def transition_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state + requester = _extract_requester(state) transition_kwargs: dict[str, object] = { - "requested_by": "api", + "requested_by": requester, "reason": f"API transition to {data.target_status.value}", } if data.assigned_to is not None: @@ -237,27 +268,20 @@ async def transition_task( data.target_status, **transition_kwargs, # type: ignore[arg-type] ) - except TaskEngineNotRunningError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskEngineQueueFullError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskNotFoundError as exc: - logger.warning( - API_RESOURCE_NOT_FOUND, - resource="task", - id=task_id, - ) - raise NotFoundError(str(exc)) from exc - except TaskInternalError as exc: - raise ServiceUnavailableError(str(exc)) from exc + except ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskNotFoundError, + TaskInternalError, + ) as exc: + raise _map_task_engine_errors(exc, task_id=task_id) from exc except TaskMutationError as exc: - error_str = str(exc) logger.warning( API_TASK_TRANSITION_FAILED, task_id=task_id, - error=error_str, + error=str(exc), ) - raise ApiValidationError(error_str) from exc + raise _map_task_engine_errors(exc, task_id=task_id) from exc logger.info( TASK_STATUS_CHANGED, task_id=task_id, @@ -288,22 +312,15 @@ async def delete_task( try: await app_state.task_engine.delete_task( task_id, - requested_by="api", + requested_by=_extract_requester(state), ) - except TaskEngineNotRunningError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskEngineQueueFullError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskNotFoundError as exc: - logger.warning( - API_RESOURCE_NOT_FOUND, - resource="task", - id=task_id, - ) - raise NotFoundError(str(exc)) from exc - except TaskInternalError as exc: - raise ServiceUnavailableError(str(exc)) from exc - except TaskMutationError as exc: - raise ApiValidationError(str(exc)) from exc + 
except ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskNotFoundError, + TaskInternalError, + TaskMutationError, + ) as exc: + raise _map_task_engine_errors(exc, task_id=task_id) from exc logger.info(API_TASK_DELETED, task_id=task_id) return ApiResponse(data=None) diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index a597b971a1..b804c98d38 100644 --- a/src/ai_company/engine/errors.py +++ b/src/ai_company/engine/errors.py @@ -106,11 +106,15 @@ class TaskVersionConflictError(TaskMutationError): """Raised when optimistic concurrency version does not match.""" -class TaskInternalError(TaskMutationError): +class TaskInternalError(TaskEngineError): """Raised when a task mutation fails due to an internal engine error. Unlike :class:`TaskMutationError` (which covers business-rule failures such as validation or not-found), this signals an unexpected engine fault that the caller cannot fix by changing the request. Maps to 5xx at the API layer. + + This is deliberately a sibling of ``TaskMutationError``, not a subtype, + so that broad ``except TaskMutationError`` handlers do not accidentally + catch internal engine faults. """ diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index eea3bdeb46..b774dc0060 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -1,9 +1,10 @@ """Centralized single-writer task engine. Owns all task state mutations via an ``asyncio.Queue``. A single -background task consumes mutation requests sequentially, applies -``model_copy(update=...)`` on frozen ``Task`` models, persists the -result, and publishes snapshots to the message bus. +background task consumes mutation requests sequentially, derives a new +``Task`` instance from the current state and the mutation (e.g. via +``Task.model_validate`` / ``Task.with_transition``), persists the result, +and publishes snapshots to the message bus. 
Reads bypass the queue and go directly to persistence -- this is safe because the TaskEngine is the only writer. @@ -16,8 +17,6 @@ from typing import TYPE_CHECKING, Never from uuid import uuid4 -from ai_company.core.enums import TaskStatus -from ai_company.core.task import Task from ai_company.engine.errors import ( TaskEngineNotRunningError, TaskEngineQueueFullError, @@ -26,6 +25,7 @@ TaskNotFoundError, TaskVersionConflictError, ) +from ai_company.engine.task_engine_apply import dispatch as _dispatch_mutation from ai_company.engine.task_engine_config import TaskEngineConfig from ai_company.engine.task_engine_models import ( CancelTaskMutation, @@ -38,6 +38,7 @@ TransitionTaskMutation, UpdateTaskMutation, ) +from ai_company.engine.task_engine_version import VersionTracker from ai_company.observability import get_logger from ai_company.observability.events.task_engine import ( TASK_ENGINE_CREATED, @@ -45,7 +46,6 @@ TASK_ENGINE_DRAIN_START, TASK_ENGINE_DRAIN_TIMEOUT, TASK_ENGINE_LOOP_ERROR, - TASK_ENGINE_MUTATION_APPLIED, TASK_ENGINE_MUTATION_FAILED, TASK_ENGINE_MUTATION_RECEIVED, TASK_ENGINE_NOT_RUNNING, @@ -54,11 +54,12 @@ TASK_ENGINE_SNAPSHOT_PUBLISHED, TASK_ENGINE_STARTED, TASK_ENGINE_STOPPED, - TASK_ENGINE_VERSION_CONFLICT, ) if TYPE_CHECKING: from ai_company.communication.bus_protocol import MessageBus + from ai_company.core.enums import TaskStatus + from ai_company.core.task import Task from ai_company.persistence.protocol import PersistenceBackend logger = get_logger(__name__) @@ -105,8 +106,9 @@ def __init__( self._queue: asyncio.Queue[_MutationEnvelope] = asyncio.Queue( maxsize=self._config.max_queue_size, ) - self._versions: dict[str, int] = {} + self._versions = VersionTracker() self._processing_task: asyncio.Task[None] | None = None + self._in_flight: _MutationEnvelope | None = None self._running = False logger.debug( TASK_ENGINE_CREATED, @@ -176,19 +178,27 @@ async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 
logger.info(TASK_ENGINE_STOPPED) def _fail_remaining_futures(self) -> None: - """Fail all remaining enqueued futures after drain timeout.""" + """Fail in-flight and remaining enqueued futures after drain timeout.""" + shutdown_result_for = self._shutdown_result + in_flight = self._in_flight + if in_flight is not None and not in_flight.future.done(): + in_flight.future.set_result(shutdown_result_for(in_flight)) + self._in_flight = None while not self._queue.empty(): with contextlib.suppress(asyncio.QueueEmpty): envelope = self._queue.get_nowait() if not envelope.future.done(): - envelope.future.set_result( - TaskMutationResult( - request_id=envelope.mutation.request_id, - success=False, - error="TaskEngine shut down before processing", - error_code="internal", - ), - ) + envelope.future.set_result(shutdown_result_for(envelope)) + + @staticmethod + def _shutdown_result(envelope: _MutationEnvelope) -> TaskMutationResult: + """Build an internal-failure result for a shutdown-aborted envelope.""" + return TaskMutationResult( + request_id=envelope.mutation.request_id, + success=False, + error="TaskEngine shut down before processing", + error_code="internal", + ) @property def is_running(self) -> bool: @@ -489,6 +499,8 @@ async def _processing_loop(self) -> None: continue try: await self._process_one(envelope) + except MemoryError, RecursionError: + raise except Exception: logger.exception( TASK_ENGINE_LOOP_ERROR, @@ -507,17 +519,24 @@ async def _processing_loop(self) -> None: async def _process_one(self, envelope: _MutationEnvelope) -> None: """Process a single mutation envelope.""" mutation = envelope.mutation + self._in_flight = envelope logger.debug( TASK_ENGINE_MUTATION_RECEIVED, mutation_type=mutation.mutation_type, request_id=mutation.request_id, ) try: - result = await self._apply_mutation(mutation) + result = await _dispatch_mutation( + mutation, + self._persistence, + self._versions, + ) if not envelope.future.done(): envelope.future.set_result(result) if 
result.success and self._config.publish_snapshots: await self._publish_snapshot(mutation, result) + except MemoryError, RecursionError: + raise except Exception as exc: internal_msg = f"{type(exc).__name__}: {exc}" logger.exception( @@ -535,280 +554,8 @@ async def _process_one(self, envelope: _MutationEnvelope) -> None: error_code="internal", ), ) - - async def _apply_mutation(self, mutation: TaskMutation) -> TaskMutationResult: - """Dispatch and apply a mutation by type. - - Raises: - TypeError: If the mutation type is unrecognised. - """ - match mutation: - case CreateTaskMutation(): - return await self._apply_create(mutation) - case UpdateTaskMutation(): - return await self._apply_update(mutation) - case TransitionTaskMutation(): - return await self._apply_transition(mutation) - case DeleteTaskMutation(): - return await self._apply_delete(mutation) - case CancelTaskMutation(): - return await self._apply_cancel(mutation) - case _: - msg = f"Unknown mutation type: {type(mutation).__name__}" # type: ignore[unreachable] - raise TypeError(msg) - - def _not_found_result( - self, - mutation_type: str, - request_id: str, - task_id: str, - ) -> TaskMutationResult: - """Build a failure result for a missing task and log it.""" - error = f"Task {task_id!r} not found" - logger.warning( - TASK_ENGINE_MUTATION_FAILED, - mutation_type=mutation_type, - request_id=request_id, - task_id=task_id, - error=error, - ) - return TaskMutationResult( - request_id=request_id, - success=False, - error=error, - error_code="not_found", - ) - - async def _apply_create( - self, - mutation: CreateTaskMutation, - ) -> TaskMutationResult: - """Create a new task.""" - data = mutation.task_data - task_id = f"task-{uuid4().hex}" - - task = Task( - id=task_id, - title=data.title, - description=data.description, - type=data.type, - priority=data.priority, - project=data.project, - created_by=data.created_by, - assigned_to=data.assigned_to, - estimated_complexity=data.estimated_complexity, - 
budget_limit=data.budget_limit, - ) - await self._persistence.tasks.save(task) - self._versions[task_id] = 1 - - logger.info( - TASK_ENGINE_MUTATION_APPLIED, - mutation_type="create", - request_id=mutation.request_id, - task_id=task_id, - ) - return TaskMutationResult( - request_id=mutation.request_id, - success=True, - task=task, - version=1, - ) - - async def _apply_update( - self, - mutation: UpdateTaskMutation, - ) -> TaskMutationResult: - """Update task fields.""" - task = await self._persistence.tasks.get(mutation.task_id) - if task is None: - return self._not_found_result( - "update", - mutation.request_id, - mutation.task_id, - ) - - try: - self._check_version(mutation.task_id, mutation.expected_version) - except TaskVersionConflictError as exc: - return TaskMutationResult( - request_id=mutation.request_id, - success=False, - error=str(exc), - error_code="version_conflict", - ) - - if not mutation.updates: - version = self._versions.get(mutation.task_id, 0) - return TaskMutationResult( - request_id=mutation.request_id, - success=True, - task=task, - version=version, - previous_status=task.status, - ) - - merged = task.model_dump() - merged.update(mutation.updates) - updated = Task.model_validate(merged) - await self._persistence.tasks.save(updated) - version = self._bump_version(mutation.task_id) - - logger.info( - TASK_ENGINE_MUTATION_APPLIED, - mutation_type="update", - request_id=mutation.request_id, - task_id=mutation.task_id, - fields=list(mutation.updates), - ) - return TaskMutationResult( - request_id=mutation.request_id, - success=True, - task=updated, - version=version, - previous_status=task.status, - ) - - async def _apply_transition( - self, - mutation: TransitionTaskMutation, - ) -> TaskMutationResult: - """Perform a task status transition.""" - task = await self._persistence.tasks.get(mutation.task_id) - if task is None: - return self._not_found_result( - "transition", - mutation.request_id, - mutation.task_id, - ) - - try: - 
self._check_version(mutation.task_id, mutation.expected_version) - except TaskVersionConflictError as exc: - return TaskMutationResult( - request_id=mutation.request_id, - success=False, - error=str(exc), - error_code="version_conflict", - ) - - previous_status = task.status - - try: - updated = task.with_transition( - mutation.target_status, - **mutation.overrides, - ) - except ValueError as exc: - logger.warning( - TASK_ENGINE_MUTATION_FAILED, - mutation_type="transition", - request_id=mutation.request_id, - task_id=mutation.task_id, - error=str(exc), - ) - return TaskMutationResult( - request_id=mutation.request_id, - success=False, - error=str(exc), - error_code="validation", - ) - - await self._persistence.tasks.save(updated) - version = self._bump_version(mutation.task_id) - - logger.info( - TASK_ENGINE_MUTATION_APPLIED, - mutation_type="transition", - request_id=mutation.request_id, - task_id=mutation.task_id, - from_status=previous_status.value, - to_status=mutation.target_status.value, - ) - return TaskMutationResult( - request_id=mutation.request_id, - success=True, - task=updated, - version=version, - previous_status=previous_status, - ) - - async def _apply_delete( - self, - mutation: DeleteTaskMutation, - ) -> TaskMutationResult: - """Delete a task.""" - deleted = await self._persistence.tasks.delete(mutation.task_id) - if not deleted: - return self._not_found_result( - "delete", - mutation.request_id, - mutation.task_id, - ) - - self._versions.pop(mutation.task_id, None) - - logger.info( - TASK_ENGINE_MUTATION_APPLIED, - mutation_type="delete", - request_id=mutation.request_id, - task_id=mutation.task_id, - ) - return TaskMutationResult( - request_id=mutation.request_id, - success=True, - version=0, - ) - - async def _apply_cancel( - self, - mutation: CancelTaskMutation, - ) -> TaskMutationResult: - """Cancel a task (shortcut for transition to CANCELLED).""" - task = await self._persistence.tasks.get(mutation.task_id) - if task is None: - return 
self._not_found_result( - "cancel", - mutation.request_id, - mutation.task_id, - ) - - previous_status = task.status - try: - updated = task.with_transition(TaskStatus.CANCELLED) - except ValueError as exc: - logger.warning( - TASK_ENGINE_MUTATION_FAILED, - mutation_type="cancel", - request_id=mutation.request_id, - task_id=mutation.task_id, - error=str(exc), - ) - return TaskMutationResult( - request_id=mutation.request_id, - success=False, - error=str(exc), - error_code="validation", - ) - - await self._persistence.tasks.save(updated) - version = self._bump_version(mutation.task_id) - - logger.info( - TASK_ENGINE_MUTATION_APPLIED, - mutation_type="cancel", - request_id=mutation.request_id, - task_id=mutation.task_id, - from_status=previous_status.value, - to_status=TaskStatus.CANCELLED.value, - ) - return TaskMutationResult( - request_id=mutation.request_id, - success=True, - task=updated, - version=version, - previous_status=previous_status, - ) + finally: + self._in_flight = None # -- Snapshot publishing ----------------------------------------------- @@ -871,37 +618,3 @@ async def _publish_snapshot( request_id=mutation.request_id, exc_info=True, ) - - # -- Version tracking -------------------------------------------------- - - def _bump_version(self, task_id: str) -> int: - """Increment and return the version counter for a task.""" - version = self._versions.get(task_id, 0) + 1 - self._versions[task_id] = version - return version - - def _check_version( - self, - task_id: str, - expected_version: int | None, - ) -> None: - """Check optimistic concurrency version if provided. - - Raises: - TaskVersionConflictError: If versions don't match. 
- """ - if expected_version is None: - return - current = self._versions.get(task_id, 0) - if current != expected_version: - msg = ( - f"Version conflict for task {task_id!r}: " - f"expected {expected_version}, current {current}" - ) - logger.warning( - TASK_ENGINE_VERSION_CONFLICT, - task_id=task_id, - expected_version=expected_version, - current_version=current, - ) - raise TaskVersionConflictError(msg) diff --git a/src/ai_company/engine/task_engine_apply.py b/src/ai_company/engine/task_engine_apply.py new file mode 100644 index 0000000000..2fbf2d517c --- /dev/null +++ b/src/ai_company/engine/task_engine_apply.py @@ -0,0 +1,352 @@ +"""Mutation application logic for TaskEngine. + +Each ``apply_*`` function takes the mutation, a persistence backend, +and a :class:`VersionTracker`, returning a :class:`TaskMutationResult`. +Extracted from ``task_engine.py`` to keep the main module focused on +lifecycle, queue management, and the public API. +""" + +import copy +from typing import TYPE_CHECKING +from uuid import uuid4 + +from pydantic import ValidationError as PydanticValidationError + +from ai_company.core.enums import TaskStatus +from ai_company.core.task import Task +from ai_company.engine.errors import TaskVersionConflictError +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutation, + TaskMutationResult, + TransitionTaskMutation, + UpdateTaskMutation, +) +from ai_company.observability import get_logger +from ai_company.observability.events.task_engine import ( + TASK_ENGINE_MUTATION_APPLIED, + TASK_ENGINE_MUTATION_FAILED, +) + +if TYPE_CHECKING: + from ai_company.engine.task_engine_version import VersionTracker + from ai_company.persistence.protocol import PersistenceBackend + +logger = get_logger(__name__) + + +# ── Helpers ────────────────────────────────────────────────────── + + +def not_found_result( + mutation_type: str, + request_id: str, + task_id: str, +) -> TaskMutationResult: + 
"""Build a failure result for a missing task and log it.""" + error = f"Task {task_id!r} not found" + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type=mutation_type, + request_id=request_id, + task_id=task_id, + error=error, + ) + return TaskMutationResult( + request_id=request_id, + success=False, + error=error, + error_code="not_found", + ) + + +# ── Dispatch ───────────────────────────────────────────────────── + + +async def dispatch( + mutation: TaskMutation, + persistence: PersistenceBackend, + versions: VersionTracker, +) -> TaskMutationResult: + """Dispatch and apply a mutation by type. + + Raises: + TypeError: If the mutation type is unrecognised. + """ + match mutation: + case CreateTaskMutation(): + return await apply_create(mutation, persistence, versions) + case UpdateTaskMutation(): + return await apply_update(mutation, persistence, versions) + case TransitionTaskMutation(): + return await apply_transition(mutation, persistence, versions) + case DeleteTaskMutation(): + return await apply_delete(mutation, persistence, versions) + case CancelTaskMutation(): + return await apply_cancel(mutation, persistence, versions) + case _: + msg = f"Unknown mutation type: {type(mutation).__name__}" # type: ignore[unreachable] + raise TypeError(msg) + + +# ── Apply methods ──────────────────────────────────────────────── + + +async def apply_create( + mutation: CreateTaskMutation, + persistence: PersistenceBackend, + versions: VersionTracker, +) -> TaskMutationResult: + """Create a new task.""" + data = mutation.task_data + task_id = f"task-{uuid4().hex}" + + try: + task = Task( + id=task_id, + title=data.title, + description=data.description, + type=data.type, + priority=data.priority, + project=data.project, + created_by=data.created_by, + assigned_to=data.assigned_to, + estimated_complexity=data.estimated_complexity, + budget_limit=data.budget_limit, + ) + except PydanticValidationError as exc: + error_msg = f"Invalid task data: {exc}" + 
logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="create", + request_id=mutation.request_id, + error=error_msg, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=error_msg, + error_code="validation", + ) + await persistence.tasks.save(task) + versions.set_initial(task_id, 1) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="create", + request_id=mutation.request_id, + task_id=task_id, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=task, + version=1, + ) + + +async def apply_update( + mutation: UpdateTaskMutation, + persistence: PersistenceBackend, + versions: VersionTracker, +) -> TaskMutationResult: + """Update task fields.""" + task = await persistence.tasks.get(mutation.task_id) + if task is None: + return not_found_result("update", mutation.request_id, mutation.task_id) + + try: + versions.check(mutation.task_id, mutation.expected_version) + except TaskVersionConflictError as exc: + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + error_code="version_conflict", + ) + + if not mutation.updates: + version = versions.get(mutation.task_id) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=task, + version=version, + previous_status=task.status, + ) + + merged = task.model_dump() + merged.update(copy.deepcopy(mutation.updates)) + try: + updated = Task.model_validate(merged) + except PydanticValidationError as exc: + error_msg = f"Invalid update data: {exc}" + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="update", + request_id=mutation.request_id, + task_id=mutation.task_id, + error=error_msg, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=error_msg, + error_code="validation", + ) + await persistence.tasks.save(updated) + version = versions.bump(mutation.task_id) + + logger.info( + 
TASK_ENGINE_MUTATION_APPLIED, + mutation_type="update", + request_id=mutation.request_id, + task_id=mutation.task_id, + fields=list(mutation.updates), + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=updated, + version=version, + previous_status=task.status, + ) + + +async def apply_transition( + mutation: TransitionTaskMutation, + persistence: PersistenceBackend, + versions: VersionTracker, +) -> TaskMutationResult: + """Perform a task status transition.""" + task = await persistence.tasks.get(mutation.task_id) + if task is None: + return not_found_result("transition", mutation.request_id, mutation.task_id) + + try: + versions.check(mutation.task_id, mutation.expected_version) + except TaskVersionConflictError as exc: + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + error_code="version_conflict", + ) + + previous_status = task.status + + try: + updated = task.with_transition( + mutation.target_status, + **mutation.overrides, + ) + except ValueError as exc: + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="transition", + request_id=mutation.request_id, + task_id=mutation.task_id, + error=str(exc), + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + error_code="validation", + ) + + await persistence.tasks.save(updated) + version = versions.bump(mutation.task_id) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="transition", + request_id=mutation.request_id, + task_id=mutation.task_id, + from_status=previous_status.value, + to_status=mutation.target_status.value, + reason=mutation.reason, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=updated, + version=version, + previous_status=previous_status, + ) + + +async def apply_delete( + mutation: DeleteTaskMutation, + persistence: PersistenceBackend, + versions: VersionTracker, +) -> TaskMutationResult: + 
"""Delete a task.""" + deleted = await persistence.tasks.delete(mutation.task_id) + if not deleted: + return not_found_result("delete", mutation.request_id, mutation.task_id) + + versions.remove(mutation.task_id) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="delete", + request_id=mutation.request_id, + task_id=mutation.task_id, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + version=0, + ) + + +async def apply_cancel( + mutation: CancelTaskMutation, + persistence: PersistenceBackend, + versions: VersionTracker, +) -> TaskMutationResult: + """Cancel a task (shortcut for transition to CANCELLED).""" + task = await persistence.tasks.get(mutation.task_id) + if task is None: + return not_found_result("cancel", mutation.request_id, mutation.task_id) + + previous_status = task.status + try: + updated = task.with_transition(TaskStatus.CANCELLED) + except ValueError as exc: + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="cancel", + request_id=mutation.request_id, + task_id=mutation.task_id, + error=str(exc), + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + error_code="validation", + ) + + await persistence.tasks.save(updated) + version = versions.bump(mutation.task_id) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="cancel", + request_id=mutation.request_id, + task_id=mutation.task_id, + from_status=previous_status.value, + to_status=TaskStatus.CANCELLED.value, + reason=mutation.reason, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=updated, + version=version, + previous_status=previous_status, + ) diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index f977ebc8df..5ef0722198 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -5,15 +5,19 @@ ``requested_by`` field for tracing 
and auditing. """ +import copy from datetime import UTC, datetime from typing import Literal, Self from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator from ai_company.core.enums import Complexity, Priority, TaskStatus, TaskType -from ai_company.core.task import Task # noqa: TC001 +from ai_company.core.task import Task from ai_company.core.types import NotBlankStr # noqa: TC001 +_VALID_TASK_FIELDS: frozenset[str] = frozenset(Task.model_fields) +"""All declared field names on :class:`Task`, used to reject unknown keys.""" + # ── Mutation data ───────────────────────────────────────────── @@ -122,12 +126,21 @@ class UpdateTaskMutation(BaseModel): @model_validator(mode="after") def _reject_immutable_fields(self) -> Self: + unknown = set(self.updates) - _VALID_TASK_FIELDS + if unknown: + msg = f"Unknown task fields: {sorted(unknown)}" + raise ValueError(msg) forbidden = set(self.updates) & _IMMUTABLE_TASK_FIELDS if forbidden: msg = f"Cannot update immutable fields: {sorted(forbidden)}" raise ValueError(msg) return self + def __init__(self, **data: object) -> None: + super().__init__(**data) + # Deep-copy mutable dict at system boundary per coding guidelines. + object.__setattr__(self, "updates", copy.deepcopy(self.updates)) + _IMMUTABLE_OVERRIDE_FIELDS: frozenset[str] = frozenset( { @@ -173,12 +186,21 @@ class TransitionTaskMutation(BaseModel): @model_validator(mode="after") def _reject_immutable_overrides(self) -> Self: + unknown = set(self.overrides) - _VALID_TASK_FIELDS + if unknown: + msg = f"Unknown task fields in overrides: {sorted(unknown)}" + raise ValueError(msg) forbidden = set(self.overrides) & _IMMUTABLE_OVERRIDE_FIELDS if forbidden: msg = f"Cannot override immutable fields: {sorted(forbidden)}" raise ValueError(msg) return self + def __init__(self, **data: object) -> None: + super().__init__(**data) + # Deep-copy mutable dict at system boundary per coding guidelines. 
+ object.__setattr__(self, "overrides", copy.deepcopy(self.overrides)) + class DeleteTaskMutation(BaseModel): """Request to delete a task. diff --git a/src/ai_company/engine/task_engine_version.py b/src/ai_company/engine/task_engine_version.py new file mode 100644 index 0000000000..fe6446c4c9 --- /dev/null +++ b/src/ai_company/engine/task_engine_version.py @@ -0,0 +1,76 @@ +"""Version tracking for TaskEngine optimistic concurrency. + +Wraps a plain ``dict[str, int]`` with seed, bump, check, and remove +operations. Extracted from ``task_engine.py`` to keep the main module +focused on lifecycle and queue management. +""" + +from ai_company.engine.errors import TaskVersionConflictError +from ai_company.observability import get_logger +from ai_company.observability.events.task_engine import TASK_ENGINE_VERSION_CONFLICT + +logger = get_logger(__name__) + + +class VersionTracker: + """In-memory per-task version counter for optimistic concurrency. + + After a restart the tracker is empty. The first time a persisted + task is encountered it is seeded at version 1 (it was created at + least once). This makes subsequent optimistic-concurrency checks + work within the current engine lifetime. 
+ """ + + def __init__(self) -> None: + self._versions: dict[str, int] = {} + + def seed(self, task_id: str) -> None: + """Ensure *task_id* has a baseline version (idempotent).""" + if task_id not in self._versions: + self._versions[task_id] = 1 + + def set_initial(self, task_id: str, version: int) -> None: + """Set *task_id* to *version* unconditionally (used on create).""" + self._versions[task_id] = version + + def bump(self, task_id: str) -> int: + """Increment and return the version counter for *task_id*.""" + self.seed(task_id) + version = self._versions[task_id] + 1 + self._versions[task_id] = version + return version + + def get(self, task_id: str) -> int: + """Return the current version (0 if not tracked).""" + return self._versions.get(task_id, 0) + + def remove(self, task_id: str) -> None: + """Remove version tracking for a deleted task.""" + self._versions.pop(task_id, None) + + def check( + self, + task_id: str, + expected_version: int | None, + ) -> None: + """Raise ``TaskVersionConflictError`` if versions disagree. + + Seeds the version from persistence if not yet tracked so that + optimistic concurrency survives engine restarts. 
+ """ + if expected_version is None: + return + self.seed(task_id) + current = self._versions[task_id] + if current != expected_version: + msg = ( + f"Version conflict for task {task_id!r}: " + f"expected {expected_version}, current {current}" + ) + logger.warning( + TASK_ENGINE_VERSION_CONFLICT, + task_id=task_id, + expected_version=expected_version, + current_version=current, + ) + raise TaskVersionConflictError(msg) diff --git a/tests/unit/api/controllers/test_task_helpers.py b/tests/unit/api/controllers/test_task_helpers.py new file mode 100644 index 0000000000..22d7db56a5 --- /dev/null +++ b/tests/unit/api/controllers/test_task_helpers.py @@ -0,0 +1,93 @@ +"""Unit tests for task controller helper functions.""" + +import pytest + +from ai_company.api.controllers.tasks import _extract_requester, _map_task_engine_errors +from ai_company.api.errors import ( + ApiValidationError, + NotFoundError, + ServiceUnavailableError, +) +from ai_company.engine.errors import ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskInternalError, + TaskMutationError, + TaskNotFoundError, +) + +# ── _extract_requester ─────────────────────────────────────── + + +@pytest.mark.unit +class TestExtractRequester: + """Tests for extracting requester identity from state.""" + + def test_returns_user_id_when_present(self) -> None: + """Auth middleware sets _connection_user with user_id.""" + + class FakeUser: + user_id = "user-123" + + class FakeState: + _connection_user = FakeUser() + + assert _extract_requester(FakeState()) == "user-123" # type: ignore[arg-type] + + def test_returns_api_fallback_when_no_user(self) -> None: + class FakeState: + pass + + assert _extract_requester(FakeState()) == "api" # type: ignore[arg-type] + + def test_returns_api_when_user_has_no_user_id(self) -> None: + class FakeUser: + pass + + class FakeState: + _connection_user = FakeUser() + + assert _extract_requester(FakeState()) == "api" # type: ignore[arg-type] + + +# ── _map_task_engine_errors 
────────────────────────────────── + + +@pytest.mark.unit +class TestMapTaskEngineErrors: + """Tests for mapping engine errors to API errors.""" + + def test_not_found_maps_to_not_found_error(self) -> None: + exc = TaskNotFoundError("Task 'x' not found") + result = _map_task_engine_errors(exc, task_id="x") + assert isinstance(result, NotFoundError) + + def test_not_found_without_task_id(self) -> None: + exc = TaskNotFoundError("not found") + result = _map_task_engine_errors(exc) + assert isinstance(result, NotFoundError) + + def test_not_running_maps_to_service_unavailable(self) -> None: + exc = TaskEngineNotRunningError("not running") + result = _map_task_engine_errors(exc) + assert isinstance(result, ServiceUnavailableError) + + def test_queue_full_maps_to_service_unavailable(self) -> None: + exc = TaskEngineQueueFullError("queue full") + result = _map_task_engine_errors(exc) + assert isinstance(result, ServiceUnavailableError) + + def test_internal_error_maps_to_service_unavailable(self) -> None: + exc = TaskInternalError("internal fault") + result = _map_task_engine_errors(exc) + assert isinstance(result, ServiceUnavailableError) + + def test_mutation_error_maps_to_validation_error(self) -> None: + exc = TaskMutationError("bad input") + result = _map_task_engine_errors(exc) + assert isinstance(result, ApiValidationError) + + def test_unknown_error_passes_through(self) -> None: + exc = RuntimeError("unexpected") + result = _map_task_engine_errors(exc) + assert result is exc diff --git a/tests/unit/api/test_app.py b/tests/unit/api/test_app.py index 4a10fadbec..764000614c 100644 --- a/tests/unit/api/test_app.py +++ b/tests/unit/api/test_app.py @@ -127,4 +127,53 @@ async def test_shutdown_task_engine_failure_does_not_propagate(self) -> None: mock_te.stop = AsyncMock(side_effect=RuntimeError("stop boom")) # Should not raise even when task engine stop fails - await _safe_shutdown(None, mock_te, None, None) + await _safe_shutdown(mock_te, None, None, None) + + 
+@pytest.mark.unit +class TestTryStop: + """Tests for the _try_stop helper.""" + + async def test_try_stop_success(self) -> None: + """Successful coroutine runs without error.""" + from ai_company.api.app import _try_stop + + called = False + + async def noop() -> None: + nonlocal called + called = True + + await _try_stop(noop(), "event", "error msg") + assert called is True + + async def test_try_stop_exception_swallowed(self) -> None: + """Non-fatal exceptions are swallowed (logged).""" + from ai_company.api.app import _try_stop + + async def fail() -> None: + msg = "boom" + raise RuntimeError(msg) + + # Should not raise + await _try_stop(fail(), "event", "error msg") + + async def test_try_stop_memory_error_reraises(self) -> None: + """MemoryError is re-raised immediately.""" + from ai_company.api.app import _try_stop + + async def oom() -> None: + raise MemoryError + + with pytest.raises(MemoryError): + await _try_stop(oom(), "event", "error msg") + + async def test_try_stop_recursion_error_reraises(self) -> None: + """RecursionError is re-raised immediately.""" + from ai_company.api.app import _try_stop + + async def recurse() -> None: + raise RecursionError + + with pytest.raises(RecursionError): + await _try_stop(recurse(), "event", "error msg") diff --git a/tests/unit/engine/test_agent_engine.py b/tests/unit/engine/test_agent_engine.py index 75da927abc..a790323586 100644 --- a/tests/unit/engine/test_agent_engine.py +++ b/tests/unit/engine/test_agent_engine.py @@ -1032,7 +1032,11 @@ async def test_terminal_status_reported( assert result.is_success is True mock_te.transition_task.assert_awaited_once() call_args = mock_te.transition_task.call_args + assert call_args.args[0] == sample_task_with_criteria.id assert call_args.args[1] == TaskStatus.COMPLETED + assert call_args.kwargs["requested_by"] == str( + sample_agent_with_personality.id, + ) async def test_mutation_error_swallowed( self, diff --git a/tests/unit/engine/test_task_engine_apply.py 
b/tests/unit/engine/test_task_engine_apply.py new file mode 100644 index 0000000000..3d4b90c28e --- /dev/null +++ b/tests/unit/engine/test_task_engine_apply.py @@ -0,0 +1,527 @@ +"""Unit tests for task_engine_apply dispatch and apply functions.""" + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.task_engine_apply import ( + apply_cancel, + apply_create, + apply_delete, + apply_transition, + apply_update, + dispatch, +) +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutationResult, + TransitionTaskMutation, + UpdateTaskMutation, +) +from ai_company.engine.task_engine_version import VersionTracker +from tests.unit.engine.task_engine_helpers import FakePersistence, _make_create_data + + +@pytest.fixture +def persistence() -> FakePersistence: + return FakePersistence() + + +@pytest.fixture +def versions() -> VersionTracker: + return VersionTracker() + + +# ── Dispatch routing ───────────────────────────────────────── + + +@pytest.mark.unit +class TestDispatch: + """Tests for mutation dispatch routing.""" + + async def test_dispatch_create( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await dispatch(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is True + assert result.task is not None + + async def test_dispatch_unknown_type_raises( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + """Unknown mutation type raises TypeError.""" + + class FakeMutation: + mutation_type = "fake" + request_id = "req-1" + requested_by = "alice" + + with pytest.raises(TypeError, match="Unknown mutation type"): + await dispatch(FakeMutation(), persistence, versions) # type: ignore[arg-type] + + +# ── apply_create 
───────────────────────────────────────────── + + +@pytest.mark.unit +class TestApplyCreate: + """Tests for task creation apply logic.""" + + async def test_creates_task( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(title="New Task"), + ) + result = await apply_create(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is True + assert result.task is not None + assert result.task.title == "New Task" + assert result.task.id.startswith("task-") + assert result.version == 1 + + async def test_create_validation_error( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + """Invalid task data returns failure with validation error code. + + assigned_to is valid for CreateTaskData but Task rejects it + when status is CREATED (assignment consistency invariant). + """ + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(assigned_to="bob"), + ) + result = await apply_create(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "validation" + assert "Invalid task data" in (result.error or "") + + async def test_create_persists_task( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await apply_create(mutation, persistence, versions) # type: ignore[arg-type] + assert result.task is not None + stored = await persistence.tasks.get(result.task.id) + assert stored is not None + + +# ── apply_update ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestApplyUpdate: + """Tests for task update apply logic.""" + + async def _create_task( + self, + persistence: FakePersistence, + 
versions: VersionTracker, + ) -> TaskMutationResult: + mutation = CreateTaskMutation( + request_id="req-c", + requested_by="alice", + task_data=_make_create_data(), + ) + return await apply_create(mutation, persistence, versions) # type: ignore[arg-type] + + async def test_update_fields( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + created = await self._create_task(persistence, versions) + assert created.task is not None + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + updates={"title": "Updated"}, + ) + result = await apply_update(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is True + assert result.task is not None + assert result.task.title == "Updated" + assert result.version == 2 + + async def test_update_not_found( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-nonexistent", + updates={"title": "X"}, + ) + result = await apply_update(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "not_found" + + async def test_update_version_conflict( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + created = await self._create_task(persistence, versions) + assert created.task is not None + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + updates={"title": "X"}, + expected_version=99, + ) + result = await apply_update(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "version_conflict" + + async def test_update_empty_no_op( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + created = await self._create_task(persistence, versions) + assert created.task is 
not None + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + updates={}, + ) + result = await apply_update(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is True + assert result.task is not None + assert result.task.title == created.task.title + + async def test_update_validation_error( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + """Invalid update data returns failure with validation error code.""" + created = await self._create_task(persistence, versions) + assert created.task is not None + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + updates={"priority": "bogus_priority"}, + ) + result = await apply_update(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "validation" + + async def test_update_records_previous_status( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + created = await self._create_task(persistence, versions) + assert created.task is not None + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + updates={"title": "New"}, + ) + result = await apply_update(mutation, persistence, versions) # type: ignore[arg-type] + assert result.previous_status == TaskStatus.CREATED + + +# ── apply_transition ───────────────────────────────────────── + + +@pytest.mark.unit +class TestApplyTransition: + """Tests for task status transition apply logic.""" + + async def _create_task( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> TaskMutationResult: + mutation = CreateTaskMutation( + request_id="req-c", + requested_by="alice", + task_data=_make_create_data(), + ) + return await apply_create(mutation, persistence, versions) # type: ignore[arg-type] + + async def test_valid_transition( + self, + 
persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + created = await self._create_task(persistence, versions) + assert created.task is not None + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + ) + result = await apply_transition(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is True + assert result.task is not None + assert result.task.status == TaskStatus.ASSIGNED + assert result.previous_status == TaskStatus.CREATED + + async def test_transition_not_found( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-nonexistent", + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + ) + result = await apply_transition(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "not_found" + + async def test_transition_version_conflict( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + created = await self._create_task(persistence, versions) + assert created.task is not None + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + expected_version=99, + ) + result = await apply_transition(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "version_conflict" + + async def test_invalid_transition( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + """CREATED -> COMPLETED is not valid.""" + created = await self._create_task(persistence, versions) + assert created.task is not None + mutation = 
TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=created.task.id, + target_status=TaskStatus.COMPLETED, + reason="skip", + overrides={"assigned_to": "bob"}, + ) + result = await apply_transition(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "validation" + + +# ── apply_delete ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestApplyDelete: + """Tests for task deletion apply logic.""" + + async def test_delete_task( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + create_result = await apply_create( + CreateTaskMutation( + request_id="req-c", + requested_by="alice", + task_data=_make_create_data(), + ), + persistence, # type: ignore[arg-type] + versions, + ) + assert create_result.task is not None + mutation = DeleteTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=create_result.task.id, + ) + result = await apply_delete(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is True + assert result.version == 0 + + stored = await persistence.tasks.get(create_result.task.id) + assert stored is None + + async def test_delete_not_found( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + mutation = DeleteTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-nonexistent", + ) + result = await apply_delete(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "not_found" + + async def test_delete_removes_version_tracking( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + create_result = await apply_create( + CreateTaskMutation( + request_id="req-c", + requested_by="alice", + task_data=_make_create_data(), + ), + persistence, # type: ignore[arg-type] + versions, + ) + assert create_result.task is not None + 
task_id = create_result.task.id + assert versions.get(task_id) == 1 + + await apply_delete( + DeleteTaskMutation( + request_id="req-d", + requested_by="alice", + task_id=task_id, + ), + persistence, # type: ignore[arg-type] + versions, + ) + assert versions.get(task_id) == 0 + + +# ── apply_cancel ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestApplyCancel: + """Tests for task cancellation apply logic.""" + + async def _create_and_assign( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> str: + """Create a task and transition to ASSIGNED, return task_id.""" + create_result = await apply_create( + CreateTaskMutation( + request_id="req-c", + requested_by="alice", + task_data=_make_create_data(), + ), + persistence, # type: ignore[arg-type] + versions, + ) + assert create_result.task is not None + task_id = create_result.task.id + await apply_transition( + TransitionTaskMutation( + request_id="req-t", + requested_by="alice", + task_id=task_id, + target_status=TaskStatus.ASSIGNED, + reason="Assign", + overrides={"assigned_to": "bob"}, + ), + persistence, # type: ignore[arg-type] + versions, + ) + return task_id + + async def test_cancel_assigned_task( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + task_id = await self._create_and_assign(persistence, versions) + mutation = CancelTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task_id, + reason="No longer needed", + ) + result = await apply_cancel(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is True + assert result.task is not None + assert result.task.status == TaskStatus.CANCELLED + assert result.previous_status == TaskStatus.ASSIGNED + + async def test_cancel_not_found( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + mutation = CancelTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-nonexistent", + 
reason="test", + ) + result = await apply_cancel(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "not_found" + + async def test_cancel_invalid_status( + self, + persistence: FakePersistence, + versions: VersionTracker, + ) -> None: + """CREATED -> CANCELLED is not a valid transition.""" + create_result = await apply_create( + CreateTaskMutation( + request_id="req-c", + requested_by="alice", + task_data=_make_create_data(), + ), + persistence, # type: ignore[arg-type] + versions, + ) + assert create_result.task is not None + mutation = CancelTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=create_result.task.id, + reason="Oops", + ) + result = await apply_cancel(mutation, persistence, versions) # type: ignore[arg-type] + assert result.success is False + assert result.error_code == "validation" diff --git a/tests/unit/engine/test_task_engine_coverage.py b/tests/unit/engine/test_task_engine_coverage.py new file mode 100644 index 0000000000..933cf94690 --- /dev/null +++ b/tests/unit/engine/test_task_engine_coverage.py @@ -0,0 +1,235 @@ +"""Additional coverage tests for TaskEngine edge cases. + +Covers: in-flight envelope resolution during drain, MemoryError re-raise, +_process_one exception paths, and snapshot publishing failures. 
+""" + +import asyncio +import contextlib + +import pytest + +from ai_company.engine.errors import TaskInternalError +from ai_company.engine.task_engine import TaskEngine, _MutationEnvelope +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CreateTaskMutation, + TaskMutationResult, +) +from tests.unit.engine.task_engine_helpers import ( + FailingMessageBus, + FakePersistence, + _make_create_data, +) + +# ── In-flight envelope resolution ──────────────────────────── + + +@pytest.mark.unit +class TestInFlightResolution: + """Drain timeout resolves both in-flight and queued envelopes.""" + + async def test_in_flight_envelope_resolved_on_drain_timeout( + self, + persistence: FakePersistence, + ) -> None: + """The in-flight envelope gets a failure result on drain timeout.""" + block = asyncio.Event() + original_save = persistence.tasks.save + + async def slow_save(task: object) -> None: + await block.wait() + await original_save(task) # type: ignore[arg-type] + + persistence.tasks.save = slow_save # type: ignore[method-assign] + + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + + # Submit a task that will block in slow_save + blocked = asyncio.create_task( + eng.create_task(_make_create_data(), requested_by="alice"), + ) + await asyncio.sleep(0.05) + + # The processing loop should be in _process_one with _in_flight set + assert eng._in_flight is not None + + # Stop with very short timeout — triggers _fail_remaining_futures + await eng.stop(timeout=0.05) + + # In-flight should be cleared + assert eng._in_flight is None + + # Release the block and clean up + block.set() # type: ignore[unreachable] + blocked.cancel() + with contextlib.suppress(Exception, asyncio.CancelledError): + await blocked + + +# ── _process_one exception handling ────────────────────────── + + +@pytest.mark.unit +class TestProcessOneExceptionHandling: + """Test that _process_one handles unexpected 
exceptions gracefully.""" + + async def test_dispatch_exception_returns_internal_error( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + """An exception during dispatch produces an internal error result.""" + + async def exploding_save(task: object) -> None: + msg = "Unexpected persistence failure" + raise RuntimeError(msg) + + persistence.tasks.save = exploding_save # type: ignore[method-assign] + + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await eng.submit(mutation) + assert result.success is False + assert result.error_code == "internal" + assert "Internal error" in (result.error or "") + finally: + await eng.stop(timeout=2.0) + + +# ── Snapshot publish failure ───────────────────────────────── + + +@pytest.mark.unit +class TestSnapshotPublishFailure: + """Snapshot publishing failure does not affect the mutation result.""" + + async def test_publish_failure_logged_not_raised( + self, + persistence: FakePersistence, + ) -> None: + """Even when publish fails, create_task returns the task.""" + failing_bus = FailingMessageBus() + config = TaskEngineConfig(publish_snapshots=True) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=failing_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + assert task.id.startswith("task-") + finally: + await eng.stop(timeout=2.0) + + +# ── _raise_typed_error coverage ────────────────────────────── + + +@pytest.mark.unit +class TestRaiseTypedError: + """Test _raise_typed_error for internal error code mapping.""" + + async def test_internal_error_code_raises_task_internal_error( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + 
"""error_code='internal' should raise TaskInternalError.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + result = TaskMutationResult( + request_id="req-1", + success=False, + error="Something went wrong", + error_code="internal", + ) + with pytest.raises(TaskInternalError, match="Something went wrong"): + eng._raise_typed_error(result) + + +# ── _shutdown_result coverage ──────────────────────────────── + + +@pytest.mark.unit +class TestShutdownResult: + """Test _shutdown_result static method.""" + + async def test_shutdown_result_envelope(self) -> None: + mutation = CreateTaskMutation( + request_id="req-shutdown", + requested_by="alice", + task_data=_make_create_data(), + ) + envelope = _MutationEnvelope(mutation=mutation) + result = TaskEngine._shutdown_result(envelope) + assert result.success is False + assert result.error_code == "internal" + assert "shut down" in (result.error or "").lower() + assert result.request_id == "req-shutdown" + + +# ── Processing loop continues after error ──────────────────── + + +@pytest.mark.unit +class TestProcessingLoopResilience: + """Verify the processing loop continues after a single mutation fails.""" + + async def test_loop_continues_after_failure( + self, + persistence: FakePersistence, + ) -> None: + """A failing mutation does not stop subsequent mutations.""" + call_count = 0 + original_save = persistence.tasks.save + + async def fail_first_save(task: object) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + msg = "First save fails" + raise RuntimeError(msg) + await original_save(task) # type: ignore[arg-type] + + persistence.tasks.save = fail_first_save # type: ignore[method-assign] + + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + try: + # First mutation fails + m1 = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + r1 = await eng.submit(m1) + assert 
r1.success is False + + # Second mutation succeeds — loop recovered + m2 = CreateTaskMutation( + request_id="req-2", + requested_by="alice", + task_data=_make_create_data(), + ) + r2 = await eng.submit(m2) + assert r2.success is True + assert r2.task is not None + finally: + await eng.stop(timeout=2.0) diff --git a/tests/unit/engine/test_task_engine_integration.py b/tests/unit/engine/test_task_engine_integration.py index 2298a9de1b..0201c3bb85 100644 --- a/tests/unit/engine/test_task_engine_integration.py +++ b/tests/unit/engine/test_task_engine_integration.py @@ -159,7 +159,10 @@ async def test_queue_full_raises( persistence=persistence, # type: ignore[arg-type] config=tiny_config, ) - # Start the engine but pause the processing loop + # Directly manipulate internal state because triggering a full-queue + # condition through the public API is difficult: we need to fill the + # queue without the background loop draining it, so we set _running + # without calling start() (no processing task) and enqueue manually. 
eng._running = True # First submit fills the queue diff --git a/tests/unit/engine/test_task_engine_version.py b/tests/unit/engine/test_task_engine_version.py new file mode 100644 index 0000000000..d252a1c7af --- /dev/null +++ b/tests/unit/engine/test_task_engine_version.py @@ -0,0 +1,94 @@ +"""Unit tests for VersionTracker.""" + +import pytest + +from ai_company.engine.errors import TaskVersionConflictError +from ai_company.engine.task_engine_version import VersionTracker + + +@pytest.mark.unit +class TestVersionTracker: + """Tests for in-memory per-task version counter.""" + + def test_seed_sets_version_to_one(self) -> None: + vt = VersionTracker() + vt.seed("task-1") + assert vt.get("task-1") == 1 + + def test_seed_is_idempotent(self) -> None: + vt = VersionTracker() + vt.seed("task-1") + vt.seed("task-1") + assert vt.get("task-1") == 1 + + def test_seed_does_not_reset_after_bump(self) -> None: + vt = VersionTracker() + vt.set_initial("task-1", 1) + vt.bump("task-1") + vt.seed("task-1") + assert vt.get("task-1") == 2 + + def test_set_initial(self) -> None: + vt = VersionTracker() + vt.set_initial("task-1", 5) + assert vt.get("task-1") == 5 + + def test_set_initial_overwrites(self) -> None: + vt = VersionTracker() + vt.set_initial("task-1", 5) + vt.set_initial("task-1", 10) + assert vt.get("task-1") == 10 + + def test_bump_increments(self) -> None: + vt = VersionTracker() + vt.set_initial("task-1", 1) + assert vt.bump("task-1") == 2 + assert vt.bump("task-1") == 3 + + def test_bump_auto_seeds(self) -> None: + """Bumping an unknown task seeds at 1, then increments to 2.""" + vt = VersionTracker() + assert vt.bump("task-1") == 2 + + def test_get_returns_zero_for_untracked(self) -> None: + vt = VersionTracker() + assert vt.get("task-unknown") == 0 + + def test_remove_clears_tracking(self) -> None: + vt = VersionTracker() + vt.set_initial("task-1", 3) + vt.remove("task-1") + assert vt.get("task-1") == 0 + + def test_remove_nonexistent_is_noop(self) -> None: + vt = 
VersionTracker() + vt.remove("task-unknown") # no error + + def test_check_passes_when_none(self) -> None: + vt = VersionTracker() + vt.check("task-1", None) # no error + + def test_check_passes_when_version_matches(self) -> None: + vt = VersionTracker() + vt.set_initial("task-1", 3) + vt.check("task-1", 3) # no error + + def test_check_raises_on_conflict(self) -> None: + vt = VersionTracker() + vt.set_initial("task-1", 3) + with pytest.raises( + TaskVersionConflictError, + match="expected 99, current 3", + ): + vt.check("task-1", 99) + + def test_check_seeds_unknown_task(self) -> None: + """First check on unknown task seeds at 1 then validates.""" + vt = VersionTracker() + vt.check("task-1", 1) # seeds at 1, matches + assert vt.get("task-1") == 1 + + def test_check_seeds_then_rejects_mismatch(self) -> None: + vt = VersionTracker() + with pytest.raises(TaskVersionConflictError, match="expected 5"): + vt.check("task-1", 5) From 7ec0ab7ad45aac39bd4f3ed0c7590875d3b49303 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:02:40 +0100 Subject: [PATCH 08/14] fix: address 22 PR review items and add 33 new tests for TaskEngine coverage Review fixes: - create_task now uses _raise_typed_error for proper error dispatch - Add PydanticValidationError catch in all controller write methods - Use _extract_requester for create_task audit trail consistency - Add logging for TaskInternalError and TaskEngineQueueFullError in controller - Add TaskEngineError catch in agent_engine _report_to_task_engine - Extract hardcoded values to class constants (_POLL_INTERVAL_SECONDS, etc.) 
- Add failed_count tracking and logging in _fail_remaining_futures - Add reason field to TaskStateChanged event for audit trail - Add bounds check to VersionTracker.set_initial - Fix _try_stop parameter type from object to Awaitable[None] - Fix State Coordination docs (model_copy -> model_validate/with_transition) - Improve docstrings for _processing_loop, _publish_snapshot, apply_cancel, etc. New tests (33 total): - FIFO ordering guarantee - Default reason generation in transition_task - Delete snapshot new_status=None verification - Cancel version bump correctness - _raise_typed_error dispatch for all error codes - Snapshot reason propagation (transition, cancel, create, update) - MemoryError re-raise in processing loop - _fail_remaining_futures count tracking - Deep-copy isolation for UpdateTaskMutation and TransitionTaskMutation - Unknown field rejection in updates and overrides - VersionTracker.set_initial bounds check (zero and negative) - TaskStateChanged reason field (populated, default none, cancel) - create_task typed error dispatch (internal, validation) - Transition overrides and previous_status via engine --- docs/architecture/tech-stack.md | 2 +- docs/design/engine.md | 2 +- src/ai_company/api/app.py | 4 +- src/ai_company/api/controllers/tasks.py | 31 +- src/ai_company/engine/agent_engine.py | 20 +- src/ai_company/engine/task_engine.py | 46 +- src/ai_company/engine/task_engine_apply.py | 13 +- src/ai_company/engine/task_engine_models.py | 13 +- src/ai_company/engine/task_engine_version.py | 18 +- .../unit/engine/test_task_engine_extended.py | 616 ++++++++++++++++++ tests/unit/engine/test_task_engine_models.py | 114 ++++ tests/unit/engine/test_task_engine_version.py | 12 + 12 files changed, 870 insertions(+), 21 deletions(-) create mode 100644 tests/unit/engine/test_task_engine_extended.py diff --git a/docs/architecture/tech-stack.md b/docs/architecture/tech-stack.md index 4b867d4798..85d83432b8 100644 --- a/docs/architecture/tech-stack.md +++ 
b/docs/architecture/tech-stack.md @@ -119,7 +119,7 @@ These conventions are used throughout the codebase. For full details on each, se | **Cost tiers and quota tracking** | Adopted | Configurable `CostTierDefinition` with merge/override semantics. `QuotaTracker` enforces per-provider request/token quotas with window-based rotation. | | **Shared org memory** | Adopted | `OrgMemoryBackend` protocol with `HybridPromptRetrievalBackend`. Seniority-based write access control. Core policies in system prompts; extended facts retrieved on demand. | | **Memory consolidation** | Adopted | `ConsolidationStrategy` protocol with deduplication + summarization. `RetentionEnforcer` for age-based cleanup. `ArchivalStore` for cold storage. | -| **State coordination** | Adopted | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_copy(update=...)` sequentially and publishes snapshots. | +| **State coordination** | Adopted | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_validate` / `with_transition` sequentially and publishes snapshots. | | **Workspace isolation** | Adopted | Pluggable `WorkspaceIsolationStrategy` protocol. Default: git worktrees with sequential merge on completion. | | **Graceful shutdown** | Adopted | Pluggable `ShutdownStrategy` protocol with cooperative 30-second timeout. Force-cancel after timeout with `INTERRUPTED` status. | | **Template inheritance** | Adopted | `extends` field triggers parent resolution at render time with deep merge by field type. Circular chain detection included. 
| diff --git a/docs/design/engine.md b/docs/design/engine.md index 91774e4c79..6e91fa9231 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -680,7 +680,7 @@ These are complementary systems handling different types of shared state: | State Type | Coordination | Mechanism | |-----------|-------------|-----------| -| Framework state (tasks, assignments, budget) | Centralized single-writer (`TaskEngine`) | `model_copy(update=...)` via async queue | +| Framework state (tasks, assignments, budget) | Centralized single-writer (`TaskEngine`) | `model_validate` / `with_transition` via async queue | | Code and files (agent work output) | Workspace isolation (`WorkspaceIsolationStrategy`) | Git worktrees / branches | | Agent memory (personal) | Per-agent ownership | Each agent owns its memory exclusively | | Org memory (shared knowledge) | Single-writer (`OrgMemoryBackend`) | `OrgMemoryBackend` protocol with role-based write access control | diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 588a51c507..768fe9b174 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -135,7 +135,7 @@ async def on_shutdown() -> None: async def _try_stop( - coro: object, + coro: Awaitable[None], event: str, error_msg: str, ) -> None: @@ -146,7 +146,7 @@ async def _try_stop( shutdown steps can still run. 
""" try: - await coro # type: ignore[misc] + await coro except MemoryError, RecursionError: raise except Exception: diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 1a958d68e0..5da7a80919 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -2,6 +2,7 @@ from litestar import Controller, delete, get, patch, post from litestar.datastructures import State # noqa: TC002 +from pydantic import ValidationError as PydanticValidationError from ai_company.api.dto import ( ApiResponse, @@ -47,11 +48,17 @@ def _extract_requester(state: State) -> str: """Extract requester identity from the authenticated user. Falls back to ``"api"`` when the connection carries no user - (e.g. in tests without auth middleware). + (e.g. in tests without auth middleware). Logs a warning on + fallback so auth misconfiguration is visible in production. """ user = getattr(state, "_connection_user", None) if user is not None and hasattr(user, "user_id"): return str(user.user_id) + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="authenticated_user", + note="No authenticated user found, falling back to 'api'", + ) return "api" @@ -73,8 +80,22 @@ def _map_task_engine_errors( ) return NotFoundError(str(exc)) if isinstance(exc, TaskEngineNotRunningError | TaskEngineQueueFullError): + logger.error( + API_TASK_TRANSITION_FAILED, + resource="task", + task_id=task_id, + error=str(exc), + error_type=type(exc).__name__, + ) return ServiceUnavailableError(str(exc)) if isinstance(exc, TaskInternalError): + logger.error( + API_TASK_TRANSITION_FAILED, + resource="task", + task_id=task_id, + error=str(exc), + error_type="TaskInternalError", + ) return ServiceUnavailableError(str(exc)) if isinstance(exc, TaskMutationError): return ApiValidationError(str(exc)) @@ -176,8 +197,10 @@ async def create_task( try: task = await app_state.task_engine.create_task( task_data, - requested_by=data.created_by, + 
requested_by=_extract_requester(state), ) + except PydanticValidationError as exc: + raise ApiValidationError(str(exc)) from exc except ( TaskEngineNotRunningError, TaskEngineQueueFullError, @@ -220,6 +243,8 @@ async def update_task( updates, requested_by=_extract_requester(state), ) + except PydanticValidationError as exc: + raise ApiValidationError(str(exc)) from exc except ( TaskEngineNotRunningError, TaskEngineQueueFullError, @@ -268,6 +293,8 @@ async def transition_task( data.target_status, **transition_kwargs, # type: ignore[arg-type] ) + except PydanticValidationError as exc: + raise ApiValidationError(str(exc)) from exc except ( TaskEngineNotRunningError, TaskEngineQueueFullError, diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index c4e3701411..ec51c9f0c2 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -19,7 +19,11 @@ from ai_company.engine.classification.pipeline import classify_execution_errors from ai_company.engine.context import DEFAULT_MAX_TURNS, AgentContext from ai_company.engine.cost_recording import record_execution_costs -from ai_company.engine.errors import ExecutionStateError, TaskMutationError +from ai_company.engine.errors import ( + ExecutionStateError, + TaskEngineError, + TaskMutationError, +) from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -107,11 +111,13 @@ TaskStatus.CANCELLED, } ) -"""Final execution outcomes that trigger a report to the centralized TaskEngine. +"""Statuses that trigger a report to the centralized TaskEngine. + +Evaluated after each AgentEngine run. Note: ``FAILED`` and ``INTERRUPTED`` are not strictly terminal in the task lifecycle (they can be reassigned), but represent final outcomes of this -``AgentEngine`` run that should be reported. +particular ``AgentEngine`` run that should be reported. 
""" @@ -707,6 +713,14 @@ async def _report_to_task_engine( error="Failed to report final status to TaskEngine (mutation rejected)", exc_info=True, ) + except TaskEngineError: + logger.error( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error="TaskEngine unavailable for status report", + exc_info=True, + ) except Exception: logger.error( EXECUTION_ENGINE_ERROR, diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index b774dc0060..3e51dd9f1f 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -180,15 +180,24 @@ async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 def _fail_remaining_futures(self) -> None: """Fail in-flight and remaining enqueued futures after drain timeout.""" shutdown_result_for = self._shutdown_result + failed_count = 0 in_flight = self._in_flight if in_flight is not None and not in_flight.future.done(): in_flight.future.set_result(shutdown_result_for(in_flight)) + failed_count += 1 self._in_flight = None while not self._queue.empty(): with contextlib.suppress(asyncio.QueueEmpty): envelope = self._queue.get_nowait() if not envelope.future.done(): envelope.future.set_result(shutdown_result_for(envelope)) + failed_count += 1 + if failed_count: + logger.warning( + TASK_ENGINE_DRAIN_TIMEOUT, + failed_futures=failed_count, + note="Resolved remaining futures with shutdown failure", + ) @staticmethod def _shutdown_result(envelope: _MutationEnvelope) -> TaskMutationResult: @@ -271,10 +280,10 @@ async def create_task( ) result = await self.submit(mutation) if not result.success: - raise TaskMutationError(result.error or "Create failed") + self._raise_typed_error(result) if result.task is None: msg = "Internal error: create succeeded but task is None" - raise TaskMutationError(msg) + raise TaskInternalError(msg) return result.task async def update_task( @@ -487,13 +496,26 @@ async def list_tasks( # -- Background processing 
--------------------------------------------- + _POLL_INTERVAL_SECONDS: float = 0.5 + """How often the processing loop checks for ``_running = False``.""" + + _SNAPSHOT_SENDER: str = "task-engine" + """Sender identity used in snapshot ``Message`` envelopes.""" + + _SNAPSHOT_CHANNEL: str = "task_engine" + """Message bus channel for snapshot publication.""" + async def _processing_loop(self) -> None: - """Background loop: dequeue and process mutations sequentially.""" + """Background loop: dequeue and process mutations sequentially. + + Continues draining queued mutations after ``_running`` is set to + ``False``, enabling graceful shutdown. + """ while self._running or not self._queue.empty(): try: envelope = await asyncio.wait_for( self._queue.get(), - timeout=0.5, + timeout=self._POLL_INTERVAL_SECONDS, ) except TimeoutError: continue @@ -566,7 +588,8 @@ async def _publish_snapshot( ) -> None: """Publish a TaskStateChanged event to the message bus. - Best-effort: failures are logged and swallowed. + Best-effort: failures are logged and swallowed (except + ``MemoryError`` and ``RecursionError``, which propagate). 
""" if self._message_bus is None: return @@ -578,6 +601,8 @@ async def _publish_snapshot( else: new_status = None + reason: str | None = getattr(mutation, "reason", None) + event = TaskStateChanged( mutation_type=mutation.mutation_type, request_id=mutation.request_id, @@ -586,6 +611,7 @@ async def _publish_snapshot( previous_status=result.previous_status, new_status=new_status, version=result.version, + reason=reason, timestamp=datetime.now(UTC), ) @@ -595,12 +621,13 @@ async def _publish_snapshot( from ai_company.communication.enums import MessageType # noqa: PLC0415 from ai_company.communication.message import Message # noqa: PLC0415 + task_id = getattr(mutation, "task_id", None) msg = Message( timestamp=datetime.now(UTC), - sender="task-engine", - to="task_engine", + sender=self._SNAPSHOT_SENDER, + to=self._SNAPSHOT_CHANNEL, type=MessageType.TASK_UPDATE, - channel="task_engine", + channel=self._SNAPSHOT_CHANNEL, content=event.model_dump_json(), ) await self._message_bus.publish(msg) @@ -608,13 +635,16 @@ async def _publish_snapshot( TASK_ENGINE_SNAPSHOT_PUBLISHED, mutation_type=mutation.mutation_type, request_id=mutation.request_id, + task_id=task_id, ) except MemoryError, RecursionError: raise except Exception: + task_id = getattr(mutation, "task_id", None) logger.warning( TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED, mutation_type=mutation.mutation_type, request_id=mutation.request_id, + task_id=task_id, exc_info=True, ) diff --git a/src/ai_company/engine/task_engine_apply.py b/src/ai_company/engine/task_engine_apply.py index 2fbf2d517c..51239ae398 100644 --- a/src/ai_company/engine/task_engine_apply.py +++ b/src/ai_company/engine/task_engine_apply.py @@ -45,7 +45,10 @@ def not_found_result( request_id: str, task_id: str, ) -> TaskMutationResult: - """Build a failure result for a missing task and log it.""" + """Build a failure result for a missing task and log it. + + Sets ``error_code='not_found'`` on the result. 
+ """ error = f"Task {task_id!r} not found" logger.warning( TASK_ENGINE_MUTATION_FAILED, @@ -308,7 +311,13 @@ async def apply_cancel( persistence: PersistenceBackend, versions: VersionTracker, ) -> TaskMutationResult: - """Cancel a task (shortcut for transition to CANCELLED).""" + """Cancel a task (shortcut for transition to CANCELLED). + + Unlike :func:`apply_update` and :func:`apply_transition`, cancel + intentionally omits an ``expected_version`` check — a cancellation + should always succeed regardless of version, similar to a forced + stop signal. + """ task = await persistence.tasks.get(mutation.task_id) if task is None: return not_found_result("cancel", mutation.request_id, mutation.task_id) diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index 5ef0722198..951a811d59 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -16,7 +16,13 @@ from ai_company.core.types import NotBlankStr # noqa: TC001 _VALID_TASK_FIELDS: frozenset[str] = frozenset(Task.model_fields) -"""All declared field names on :class:`Task`, used to reject unknown keys.""" +"""Field names from ``model_fields`` on :class:`Task`. + +Excludes computed fields. + +Used to reject unknown keys in :class:`UpdateTaskMutation` and +:class:`TransitionTaskMutation` validators. +""" # ── Mutation data ───────────────────────────────────────────── @@ -311,6 +317,7 @@ class TaskStateChanged(BaseModel): previous_status: Status before the mutation (``None`` on create). new_status: Status after the mutation (``None`` on delete). version: Version counter after mutation. + reason: Reason for transition/cancel (``None`` for other mutations). timestamp: When the mutation was applied. 
""" @@ -334,6 +341,10 @@ class TaskStateChanged(BaseModel): description="Status after mutation", ) version: int = Field(ge=0, description="Version counter after mutation") + reason: str | None = Field( + default=None, + description="Reason for transition/cancel", + ) timestamp: AwareDatetime = Field( default_factory=lambda: datetime.now(UTC), description="When the mutation was applied", diff --git a/src/ai_company/engine/task_engine_version.py b/src/ai_company/engine/task_engine_version.py index fe6446c4c9..8b127edb72 100644 --- a/src/ai_company/engine/task_engine_version.py +++ b/src/ai_company/engine/task_engine_version.py @@ -19,6 +19,15 @@ class VersionTracker: task is encountered it is seeded at version 1 (it was created at least once). This makes subsequent optimistic-concurrency checks work within the current engine lifetime. + + **Limitation:** version tracking is volatile — it resets on process + restart. After a restart, the first optimistic-concurrency check + for any task will succeed regardless of the true version history + because the tracker seeds the version at 1. Durable version + tracking (persisted alongside the task) is a future enhancement. + + This class is designed for single-writer access from the + ``TaskEngine`` processing loop and is **not** thread-safe. """ def __init__(self) -> None: @@ -30,7 +39,14 @@ def seed(self, task_id: str) -> None: self._versions[task_id] = 1 def set_initial(self, task_id: str, version: int) -> None: - """Set *task_id* to *version* unconditionally (used on create).""" + """Set *task_id* to *version* unconditionally (used on create). + + Raises: + ValueError: If *version* is less than 1. 
+ """ + if version < 1: + msg = f"Version must be >= 1, got {version}" + raise ValueError(msg) self._versions[task_id] = version def bump(self, task_id: str) -> int: diff --git a/tests/unit/engine/test_task_engine_extended.py b/tests/unit/engine/test_task_engine_extended.py new file mode 100644 index 0000000000..dc2a6c5cc3 --- /dev/null +++ b/tests/unit/engine/test_task_engine_extended.py @@ -0,0 +1,616 @@ +"""Extended coverage tests for TaskEngine. + +Covers test gaps identified during PR #325 review: +- FIFO ordering guarantee +- Default reason generation in transition_task +- Delete snapshot publishes new_status=None +- Cancel version bump correctness +- create_task _raise_typed_error dispatch for all error codes +- Snapshot reason propagation for transitions and cancels +- _processing_loop MemoryError re-raise +""" + +import asyncio +from typing import TYPE_CHECKING + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.errors import ( + TaskInternalError, + TaskMutationError, + TaskNotFoundError, +) +from ai_company.engine.task_engine import TaskEngine, _MutationEnvelope +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, +) +from tests.unit.engine.task_engine_helpers import ( + FakeMessageBus, + FakePersistence, + _make_create_data, +) + +if TYPE_CHECKING: + from ai_company.engine.task_engine_config import TaskEngineConfig + +# ── FIFO ordering guarantee ───────────────────────────────── + + +@pytest.mark.unit +class TestFIFOOrdering: + """Mutations are processed in FIFO order via the single-writer queue.""" + + async def test_mutations_processed_in_submission_order( + self, + engine: TaskEngine, + ) -> None: + """Create 5 tasks and verify they are processed in order.""" + results: list[TaskMutationResult] = [] + mutations = [ + CreateTaskMutation( + request_id=f"req-{i}", + requested_by="alice", + 
task_data=_make_create_data(title=f"Task {i}"), + ) + for i in range(5) + ] + for mutation in mutations: + result = await engine.submit(mutation) + results.append(result) + + assert all(r.success for r in results) + # Each result's request_id matches submission order + for i, result in enumerate(results): + assert result.request_id == f"req-{i}" + assert result.task is not None + assert result.task.title == f"Task {i}" + + async def test_interleaved_create_update_ordering( + self, + engine: TaskEngine, + ) -> None: + """Create then update: update sees the created task.""" + task = await engine.create_task( + _make_create_data(title="Original"), + requested_by="alice", + ) + # Immediately update — this should see the task because + # the queue processes sequentially + updated = await engine.update_task( + task.id, + {"title": "Updated"}, + requested_by="alice", + ) + assert updated.title == "Updated" + + +# ── Default reason generation ──────────────────────────────── + + +@pytest.mark.unit +class TestDefaultReasonGeneration: + """transition_task generates a default reason when none is provided.""" + + async def test_empty_reason_generates_default( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + transitioned, _ = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + assigned_to="bob", + # reason defaults to "" + ) + assert transitioned.status == TaskStatus.ASSIGNED + + async def test_explicit_reason_preserved( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + transitioned, _ = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Manager assigned task", + assigned_to="bob", + ) + assert transitioned.status == TaskStatus.ASSIGNED + + +# ── Delete snapshot new_status=None ────────────────────────── + + +@pytest.mark.unit +class 
TestDeleteSnapshotEvent: + """Delete mutations publish events with new_status=None.""" + + async def test_delete_snapshot_has_none_status( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + config: TaskEngineConfig, + ) -> None: + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await asyncio.sleep(0) # let snapshot publish + message_bus.published.clear() + + await eng.delete_task(task.id, requested_by="alice") + await asyncio.sleep(0) # let snapshot publish + + assert len(message_bus.published) == 1 + msg = message_bus.published[0] + # The message content is JSON-serialized TaskStateChanged + event = TaskStateChanged.model_validate_json(msg.content) + assert event.mutation_type == "delete" + assert event.new_status is None + assert event.task is None + finally: + await eng.stop(timeout=2.0) + + +# ── Cancel version bump ───────────────────────────────────── + + +@pytest.mark.unit +class TestCancelVersionBump: + """Cancel mutations correctly bump the version counter.""" + + async def test_cancel_increments_version( + self, + engine: TaskEngine, + ) -> None: + # Create (v1) -> Assign (v2) -> Cancel (v3) + create_mut = CreateTaskMutation( + request_id="req-c", + requested_by="alice", + task_data=_make_create_data(), + ) + r1 = await engine.submit(create_mut) + assert r1.version == 1 + assert r1.task is not None + + assign_mut = TransitionTaskMutation( + request_id="req-a", + requested_by="alice", + task_id=r1.task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + ) + r2 = await engine.submit(assign_mut) + assert r2.version == 2 + + cancel_mut = CancelTaskMutation( + request_id="req-x", + requested_by="alice", + task_id=r1.task.id, + reason="No longer needed", + ) + r3 = await 
engine.submit(cancel_mut) + assert r3.success is True + assert r3.version == 3 + assert r3.task is not None + assert r3.task.status == TaskStatus.CANCELLED + + +# ── create_task _raise_typed_error dispatch ────────────────── + + +@pytest.mark.unit +class TestCreateTaskTypedErrorDispatch: + """create_task uses _raise_typed_error for proper error dispatch.""" + + async def test_create_internal_error_raises_task_internal( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + """Internal persistence error raises TaskInternalError.""" + + async def exploding_save(task: object) -> None: + msg = "Disk full" + raise OSError(msg) + + persistence.tasks.save = exploding_save # type: ignore[method-assign] + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + with pytest.raises(TaskInternalError): + await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + finally: + await eng.stop(timeout=2.0) + + async def test_create_validation_error_raises_mutation_error( + self, + engine: TaskEngine, + ) -> None: + """Validation failure (assigned_to on CREATED) raises TaskMutationError.""" + with pytest.raises(TaskMutationError): + await engine.create_task( + _make_create_data(assigned_to="should-fail"), + requested_by="alice", + ) + + +# ── Snapshot reason propagation ────────────────────────────── + + +@pytest.mark.unit +class TestSnapshotReasonPropagation: + """Snapshot events carry the reason from transition/cancel mutations.""" + + async def test_transition_snapshot_carries_reason( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + config: TaskEngineConfig, + ) -> None: + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await asyncio.sleep(0) + 
message_bus.published.clear() + + await eng.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Manager assigned", + assigned_to="bob", + ) + await asyncio.sleep(0) + + assert len(message_bus.published) == 1 + event = TaskStateChanged.model_validate_json( + message_bus.published[0].content, + ) + assert event.reason == "Manager assigned" + finally: + await eng.stop(timeout=2.0) + + async def test_cancel_snapshot_carries_reason( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + config: TaskEngineConfig, + ) -> None: + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await eng.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + await asyncio.sleep(0) + message_bus.published.clear() + + await eng.cancel_task( + task.id, + requested_by="alice", + reason="Budget cut", + ) + await asyncio.sleep(0) + + assert len(message_bus.published) == 1 + event = TaskStateChanged.model_validate_json( + message_bus.published[0].content, + ) + assert event.reason == "Budget cut" + finally: + await eng.stop(timeout=2.0) + + async def test_create_snapshot_reason_is_none( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + config: TaskEngineConfig, + ) -> None: + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await asyncio.sleep(0) + + assert len(message_bus.published) == 1 + event = TaskStateChanged.model_validate_json( + message_bus.published[0].content, + ) + assert event.reason is None + finally: + await eng.stop(timeout=2.0) + + async def 
test_update_snapshot_reason_is_none( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + config: TaskEngineConfig, + ) -> None: + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await asyncio.sleep(0) + message_bus.published.clear() + + await eng.update_task( + task.id, + {"title": "Updated"}, + requested_by="alice", + ) + await asyncio.sleep(0) + + assert len(message_bus.published) == 1 + event = TaskStateChanged.model_validate_json( + message_bus.published[0].content, + ) + assert event.reason is None + finally: + await eng.stop(timeout=2.0) + + +# ── MemoryError re-raise in processing loop ────────────────── + + +@pytest.mark.unit +class TestMemoryErrorReRaise: + """MemoryError and RecursionError must propagate, not be swallowed.""" + + async def test_memory_error_propagates_through_process_one( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + """MemoryError in dispatch propagates through _process_one.""" + + async def oom_save(task: object) -> None: + raise MemoryError + + persistence.tasks.save = oom_save # type: ignore[method-assign] + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + mutation = CreateTaskMutation( + request_id="req-oom", + requested_by="alice", + task_data=_make_create_data(), + ) + # The MemoryError propagates to the processing loop which + # re-raises it, causing the processing task to fail. + # The submit future may never resolve, so we check the + # processing task directly. 
+ envelope = _MutationEnvelope(mutation=mutation) + eng._queue.put_nowait(envelope) + + # Wait for the processing task to complete/fail + assert eng._processing_task is not None + with pytest.raises(MemoryError): + await eng._processing_task + finally: + eng._running = False + eng._processing_task = None + + +# ── _fail_remaining_futures coverage ───────────────────────── + + +@pytest.mark.unit +class TestFailRemainingFuturesCount: + """Verify _fail_remaining_futures tracks and logs the count.""" + + async def test_multiple_queued_futures_all_failed( + self, + persistence: FakePersistence, + ) -> None: + """All queued futures get shutdown results.""" + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng._running = True + + envelopes = [] + for i in range(3): + mutation = CreateTaskMutation( + request_id=f"req-{i}", + requested_by="alice", + task_data=_make_create_data(), + ) + envelope = _MutationEnvelope(mutation=mutation) + eng._queue.put_nowait(envelope) + envelopes.append(envelope) + + eng._running = False + eng._fail_remaining_futures() + + for envelope in envelopes: + assert envelope.future.done() + result = envelope.future.result() + assert result.success is False + assert result.error_code == "internal" + assert "shut down" in (result.error or "").lower() + + +# ── Typed error dispatch for all error codes ───────────────── + + +@pytest.mark.unit +class TestRaiseTypedErrorAllCodes: + """_raise_typed_error maps all error_code values to typed exceptions.""" + + def test_not_found_code(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="not found", + error_code="not_found", + ) + with pytest.raises(TaskNotFoundError, match="not found"): + TaskEngine._raise_typed_error(result) + + def test_version_conflict_code(self) -> None: + from ai_company.engine.errors import TaskVersionConflictError + + result = TaskMutationResult( + request_id="r", + success=False, + error="conflict", + 
error_code="version_conflict", + ) + with pytest.raises(TaskVersionConflictError, match="conflict"): + TaskEngine._raise_typed_error(result) + + def test_internal_code(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="boom", + error_code="internal", + ) + with pytest.raises(TaskInternalError, match="boom"): + TaskEngine._raise_typed_error(result) + + def test_validation_code_falls_through(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="bad data", + error_code="validation", + ) + with pytest.raises(TaskMutationError, match="bad data"): + TaskEngine._raise_typed_error(result) + + def test_none_code_falls_through(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="generic", + ) + with pytest.raises(TaskMutationError, match="generic"): + TaskEngine._raise_typed_error(result) + + def test_missing_error_uses_default_message(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="Mutation failed", + ) + with pytest.raises(TaskMutationError, match="Mutation failed"): + TaskEngine._raise_typed_error(result) + + +# ── Transition with overrides via engine ───────────────────── + + +@pytest.mark.unit +class TestTransitionOverridesViaEngine: + """Transition overrides flow through the engine correctly.""" + + async def test_transition_with_assigned_to_override( + self, + engine: TaskEngine, + ) -> None: + """assigned_to passed as kwarg becomes an override on the mutation.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + transitioned, prev = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + assert transitioned.assigned_to == "bob" + assert prev == TaskStatus.CREATED + + async def test_transition_returns_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + 
_make_create_data(), + requested_by="alice", + ) + # CREATED -> ASSIGNED + assigned, prev1 = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + assert prev1 == TaskStatus.CREATED + + # ASSIGNED -> IN_PROGRESS + in_progress, prev2 = await engine.transition_task( + assigned.id, + TaskStatus.IN_PROGRESS, + requested_by="bob", + reason="Starting work", + ) + assert prev2 == TaskStatus.ASSIGNED + assert in_progress.status == TaskStatus.IN_PROGRESS diff --git a/tests/unit/engine/test_task_engine_models.py b/tests/unit/engine/test_task_engine_models.py index ee50199fcc..1de77df13b 100644 --- a/tests/unit/engine/test_task_engine_models.py +++ b/tests/unit/engine/test_task_engine_models.py @@ -297,6 +297,40 @@ def test_delete_event(self) -> None: assert event.previous_status is None assert event.new_status is None + def test_reason_field_populated(self) -> None: + event = TaskStateChanged( + mutation_type="transition", + request_id="req-1", + requested_by="alice", + previous_status=TaskStatus.ASSIGNED, + new_status=TaskStatus.IN_PROGRESS, + version=2, + reason="Starting work on task", + ) + assert event.reason == "Starting work on task" + + def test_reason_field_default_none(self) -> None: + event = TaskStateChanged( + mutation_type="create", + request_id="req-1", + requested_by="alice", + new_status=TaskStatus.CREATED, + version=1, + ) + assert event.reason is None + + def test_cancel_event_has_reason(self) -> None: + event = TaskStateChanged( + mutation_type="cancel", + request_id="req-1", + requested_by="alice", + previous_status=TaskStatus.ASSIGNED, + new_status=TaskStatus.CANCELLED, + version=3, + reason="No longer needed", + ) + assert event.reason == "No longer needed" + def test_serialization_roundtrip(self) -> None: event = TaskStateChanged( mutation_type="create", @@ -310,3 +344,83 @@ def test_serialization_roundtrip(self) -> None: assert restored.mutation_type == 
event.mutation_type assert restored.request_id == event.request_id assert restored.version == event.version + + +@pytest.mark.unit +class TestDeepCopyIsolation: + """Verify deep-copy isolation of mutable dict fields.""" + + def test_update_mutation_isolates_updates(self) -> None: + """Mutating the original dict after construction has no effect.""" + original = {"title": "Original"} + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates=original, + ) + original["title"] = "Tampered" + assert mutation.updates["title"] == "Original" + + def test_transition_mutation_isolates_overrides(self) -> None: + """Mutating the original dict after construction has no effect.""" + original: dict[str, object] = {"assigned_to": "bob"} + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides=original, + ) + original["assigned_to"] = "tampered" + assert mutation.overrides["assigned_to"] == "bob" + + def test_update_mutation_nested_dict_isolation(self) -> None: + """Nested mutable values are also deep-copied.""" + nested = {"key": "value"} + original: dict[str, object] = {"description": nested} + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates=original, + ) + nested["key"] = "tampered" + inner = mutation.updates["description"] + assert isinstance(inner, dict) + assert inner["key"] == "value" + + +@pytest.mark.unit +class TestUnknownFieldRejection: + """Verify unknown field names are rejected in updates/overrides.""" + + def test_update_rejects_unknown_field(self) -> None: + with pytest.raises(ValidationError, match="Unknown task fields"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"nonexistent_field": "value"}, + ) + + def test_update_rejects_multiple_unknown_fields(self) -> None: + with 
pytest.raises(ValidationError, match="Unknown task fields"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"foo": 1, "bar": 2}, + ) + + def test_transition_rejects_unknown_override(self) -> None: + with pytest.raises(ValidationError, match="Unknown task fields"): + TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="test", + overrides={"nonexistent_field": "value"}, + ) diff --git a/tests/unit/engine/test_task_engine_version.py b/tests/unit/engine/test_task_engine_version.py index d252a1c7af..0e1ecb3697 100644 --- a/tests/unit/engine/test_task_engine_version.py +++ b/tests/unit/engine/test_task_engine_version.py @@ -92,3 +92,15 @@ def test_check_seeds_then_rejects_mismatch(self) -> None: vt = VersionTracker() with pytest.raises(TaskVersionConflictError, match="expected 5"): vt.check("task-1", 5) + + def test_set_initial_rejects_zero(self) -> None: + """set_initial must reject version=0.""" + vt = VersionTracker() + with pytest.raises(ValueError, match="must be >= 1"): + vt.set_initial("task-1", 0) + + def test_set_initial_rejects_negative(self) -> None: + """set_initial must reject negative versions.""" + vt = VersionTracker() + with pytest.raises(ValueError, match="must be >= 1"): + vt.set_initial("task-1", -5) From d1adc05fc5b8487c1a7203d39625d7298cbd62a2 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:16:38 +0100 Subject: [PATCH 09/14] fix: map TaskVersionConflictError to 409, fix docstrings, add 15 tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add TaskVersionConflictError → ConflictError(409) mapping in _map_task_engine_errors (was silently falling through to 422) - Wrap PydanticValidationError in update_task/transition_task convenience methods so it raises TaskMutationError instead of escaping - Sanitize 
Pydantic error messages via _format_validation_error helper - Fix VersionTracker docstrings: clarify version is seeded at 1 heuristically, not loaded from persistence - Update engine design doc optimistic concurrency section to match reality - Replace timing-based asyncio.sleep with Event synchronization in drain test - Add _snapshot_content helper for type-safe snapshot access in tests - Add 15 new tests: version conflict via convenience methods, PydanticValidation wrapping, cancel/delete not-found, start/stop lifecycle, is_running property, 409 mapping (6913 total, up from 6898) --- docs/design/engine.md | 15 +- src/ai_company/api/controllers/tasks.py | 4 + src/ai_company/engine/task_engine.py | 40 +-- src/ai_company/engine/task_engine_apply.py | 19 +- src/ai_company/engine/task_engine_version.py | 22 +- .../unit/api/controllers/test_task_helpers.py | 14 + .../unit/engine/test_task_engine_extended.py | 241 +++++++++++++++++- .../engine/test_task_engine_integration.py | 6 +- 8 files changed, 317 insertions(+), 44 deletions(-) diff --git a/docs/design/engine.md b/docs/design/engine.md index 6e91fa9231..f7b7b08ae5 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -187,13 +187,14 @@ Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop from the previous one (for example via `Task.model_validate({**task.model_dump(), **updates})` or `Task.with_transition(...)`); the existing instance is never mutated. -- **Optimistic concurrency**: Per-task version counters. The persisted - task version is the source of truth; any in-memory cache is an - optimization that is seeded from persistence on task load and may be - invalid after a restart. Callers can pass `expected_version` to detect - stale writes; on mismatch the engine returns a failed - `TaskMutationResult` with `error_code="version_conflict"`. Convenience - methods raise `TaskVersionConflictError`. +- **Optimistic concurrency**: Per-task version counters held in-memory + (volatile). 
An unknown task is seeded at version 1 on first access — + this is a heuristic baseline, **not** loaded from persistence. Version + tracking resets on engine restart; durable persistence of versions is a + future enhancement. Callers can pass `expected_version` to detect stale + writes; on mismatch the engine returns a failed `TaskMutationResult` + with `error_code="version_conflict"`. Convenience methods raise + `TaskVersionConflictError`. - **Read-through**: `get_task()` and `list_tasks()` bypass the queue and read directly from persistence — safe because TaskEngine is the sole writer. - **Snapshot publishing**: On success, a `TaskStateChanged` event is published diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 5da7a80919..f13eaad6e0 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -13,6 +13,7 @@ ) from ai_company.api.errors import ( ApiValidationError, + ConflictError, NotFoundError, ServiceUnavailableError, ) @@ -27,6 +28,7 @@ TaskInternalError, TaskMutationError, TaskNotFoundError, + TaskVersionConflictError, ) from ai_company.engine.task_engine_models import CreateTaskData from ai_company.observability import get_logger @@ -97,6 +99,8 @@ def _map_task_engine_errors( error_type="TaskInternalError", ) return ServiceUnavailableError(str(exc)) + if isinstance(exc, TaskVersionConflictError): + return ConflictError(str(exc)) if isinstance(exc, TaskMutationError): return ApiValidationError(str(exc)) return exc diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index 3e51dd9f1f..ac9dcae707 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -17,6 +17,8 @@ from typing import TYPE_CHECKING, Never from uuid import uuid4 +from pydantic import ValidationError as PydanticValidationError + from ai_company.engine.errors import ( TaskEngineNotRunningError, TaskEngineQueueFullError, @@ -312,13 
+314,16 @@ async def update_task( TaskVersionConflictError: If ``expected_version`` doesn't match. TaskMutationError: If the mutation fails. """ - mutation = UpdateTaskMutation( - request_id=uuid4().hex, - requested_by=requested_by, - task_id=task_id, - updates=updates, - expected_version=expected_version, - ) + try: + mutation = UpdateTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + updates=updates, + expected_version=expected_version, + ) + except PydanticValidationError as exc: + raise TaskMutationError(str(exc)) from exc result = await self.submit(mutation) if not result.success: self._raise_typed_error(result) @@ -359,15 +364,18 @@ async def transition_task( TaskMutationError: If the mutation fails. """ effective_reason = reason or f"Transition to {target_status.value}" - mutation = TransitionTaskMutation( - request_id=uuid4().hex, - requested_by=requested_by, - task_id=task_id, - target_status=target_status, - reason=effective_reason, - overrides=dict(overrides), - expected_version=expected_version, - ) + try: + mutation = TransitionTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + target_status=target_status, + reason=effective_reason, + overrides=dict(overrides), + expected_version=expected_version, + ) + except PydanticValidationError as exc: + raise TaskMutationError(str(exc)) from exc result = await self.submit(mutation) if not result.success: self._raise_typed_error(result) diff --git a/src/ai_company/engine/task_engine_apply.py b/src/ai_company/engine/task_engine_apply.py index 51239ae398..fdfb3871b2 100644 --- a/src/ai_company/engine/task_engine_apply.py +++ b/src/ai_company/engine/task_engine_apply.py @@ -40,6 +40,21 @@ # ── Helpers ────────────────────────────────────────────────────── +def _format_validation_error( + prefix: str, + exc: PydanticValidationError, +) -> str: + """Format a Pydantic validation error for external consumption. 
+ + Extracts field paths and messages without exposing raw input + values or internal Pydantic URL hints. + """ + parts = [ + f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in exc.errors() + ] + return f"{prefix}: {'; '.join(parts)}" + + def not_found_result( mutation_type: str, request_id: str, @@ -120,7 +135,7 @@ async def apply_create( budget_limit=data.budget_limit, ) except PydanticValidationError as exc: - error_msg = f"Invalid task data: {exc}" + error_msg = _format_validation_error("Invalid task data", exc) logger.warning( TASK_ENGINE_MUTATION_FAILED, mutation_type="create", @@ -185,7 +200,7 @@ async def apply_update( try: updated = Task.model_validate(merged) except PydanticValidationError as exc: - error_msg = f"Invalid update data: {exc}" + error_msg = _format_validation_error("Invalid update data", exc) logger.warning( TASK_ENGINE_MUTATION_FAILED, mutation_type="update", diff --git a/src/ai_company/engine/task_engine_version.py b/src/ai_company/engine/task_engine_version.py index 8b127edb72..db5073a159 100644 --- a/src/ai_company/engine/task_engine_version.py +++ b/src/ai_company/engine/task_engine_version.py @@ -15,16 +15,18 @@ class VersionTracker: """In-memory per-task version counter for optimistic concurrency. - After a restart the tracker is empty. The first time a persisted - task is encountered it is seeded at version 1 (it was created at - least once). This makes subsequent optimistic-concurrency checks - work within the current engine lifetime. + After a restart the tracker is empty. The first time an unknown + task is encountered during a ``check()`` call, it is seeded at + version 1 — a heuristic baseline, **not** loaded from persistence. + This makes subsequent optimistic-concurrency checks work within the + current engine lifetime but cannot detect conflicts that span + restarts. **Limitation:** version tracking is volatile — it resets on process - restart. 
After a restart, the first optimistic-concurrency check - for any task will succeed regardless of the true version history - because the tracker seeds the version at 1. Durable version - tracking (persisted alongside the task) is a future enhancement. + restart. After a restart, the first ``expected_version=1`` check + for any task will pass even if the task was mutated many times in a + prior lifetime. Durable version tracking (persisted alongside the + task) is a future enhancement. This class is designed for single-writer access from the ``TaskEngine`` processing loop and is **not** thread-safe. @@ -71,8 +73,8 @@ def check( ) -> None: """Raise ``TaskVersionConflictError`` if versions disagree. - Seeds the version from persistence if not yet tracked so that - optimistic concurrency survives engine restarts. + Seeds the version at 1 if the task is not yet tracked so that + optimistic concurrency works within the current engine lifetime. """ if expected_version is None: return diff --git a/tests/unit/api/controllers/test_task_helpers.py b/tests/unit/api/controllers/test_task_helpers.py index 22d7db56a5..aa733a6b27 100644 --- a/tests/unit/api/controllers/test_task_helpers.py +++ b/tests/unit/api/controllers/test_task_helpers.py @@ -5,6 +5,7 @@ from ai_company.api.controllers.tasks import _extract_requester, _map_task_engine_errors from ai_company.api.errors import ( ApiValidationError, + ConflictError, NotFoundError, ServiceUnavailableError, ) @@ -14,6 +15,7 @@ TaskInternalError, TaskMutationError, TaskNotFoundError, + TaskVersionConflictError, ) # ── _extract_requester ─────────────────────────────────────── @@ -82,6 +84,18 @@ def test_internal_error_maps_to_service_unavailable(self) -> None: result = _map_task_engine_errors(exc) assert isinstance(result, ServiceUnavailableError) + def test_version_conflict_maps_to_conflict_error(self) -> None: + exc = TaskVersionConflictError("version mismatch") + result = _map_task_engine_errors(exc, task_id="task-1") + assert 
isinstance(result, ConflictError) + assert result.status_code == 409 + + def test_version_conflict_preserves_message(self) -> None: + exc = TaskVersionConflictError("expected 2, current 3") + result = _map_task_engine_errors(exc) + assert isinstance(result, ConflictError) + assert "expected 2, current 3" in str(result) + def test_mutation_error_maps_to_validation_error(self) -> None: exc = TaskMutationError("bad input") result = _map_task_engine_errors(exc) diff --git a/tests/unit/engine/test_task_engine_extended.py b/tests/unit/engine/test_task_engine_extended.py index dc2a6c5cc3..9048773ad0 100644 --- a/tests/unit/engine/test_task_engine_extended.py +++ b/tests/unit/engine/test_task_engine_extended.py @@ -20,6 +20,7 @@ TaskInternalError, TaskMutationError, TaskNotFoundError, + TaskVersionConflictError, ) from ai_company.engine.task_engine import TaskEngine, _MutationEnvelope from ai_company.engine.task_engine_models import ( @@ -38,6 +39,18 @@ if TYPE_CHECKING: from ai_company.engine.task_engine_config import TaskEngineConfig + +def _snapshot_content(bus: FakeMessageBus, index: int = 0) -> str: + """Extract the JSON content from a published snapshot message. + + The ``FakeMessageBus`` stores items as ``object`` because the bus + protocol is generic; this helper performs the attribute access that + mypy would otherwise reject. 
+ """ + msg = bus.published[index] + return msg.content # type: ignore[attr-defined,no-any-return] + + # ── FIFO ordering guarantee ───────────────────────────────── @@ -162,9 +175,9 @@ async def test_delete_snapshot_has_none_status( await asyncio.sleep(0) # let snapshot publish assert len(message_bus.published) == 1 - msg = message_bus.published[0] - # The message content is JSON-serialized TaskStateChanged - event = TaskStateChanged.model_validate_json(msg.content) + event = TaskStateChanged.model_validate_json( + _snapshot_content(message_bus), + ) assert event.mutation_type == "delete" assert event.new_status is None assert event.task is None @@ -300,7 +313,7 @@ async def test_transition_snapshot_carries_reason( assert len(message_bus.published) == 1 event = TaskStateChanged.model_validate_json( - message_bus.published[0].content, + _snapshot_content(message_bus), ) assert event.reason == "Manager assigned" finally: @@ -342,7 +355,7 @@ async def test_cancel_snapshot_carries_reason( assert len(message_bus.published) == 1 event = TaskStateChanged.model_validate_json( - message_bus.published[0].content, + _snapshot_content(message_bus), ) assert event.reason == "Budget cut" finally: @@ -369,7 +382,7 @@ async def test_create_snapshot_reason_is_none( assert len(message_bus.published) == 1 event = TaskStateChanged.model_validate_json( - message_bus.published[0].content, + _snapshot_content(message_bus), ) assert event.reason is None finally: @@ -404,7 +417,7 @@ async def test_update_snapshot_reason_is_none( assert len(message_bus.published) == 1 event = TaskStateChanged.model_validate_json( - message_bus.published[0].content, + _snapshot_content(message_bus), ) assert event.reason is None finally: @@ -614,3 +627,217 @@ async def test_transition_returns_previous_status( ) assert prev2 == TaskStatus.ASSIGNED assert in_progress.status == TaskStatus.IN_PROGRESS + + +# ── PydanticValidationError wrapping in convenience methods ── + + +@pytest.mark.unit +class 
TestConvenienceMethodValidationWrapping: + """Convenience methods wrap PydanticValidationError as TaskMutationError.""" + + async def test_update_task_wraps_pydantic_validation( + self, + engine: TaskEngine, + ) -> None: + """UpdateTaskMutation with immutable field raises TaskMutationError.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # 'id' is an immutable field rejected by model_validator + with pytest.raises(TaskMutationError): + await engine.update_task( + task.id, + {"id": "hacked"}, + requested_by="alice", + ) + + async def test_transition_task_wraps_pydantic_validation( + self, + engine: TaskEngine, + ) -> None: + """TransitionTaskMutation with blank reason raises TaskMutationError.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # Blank requested_by should trigger NotBlankStr validation + with pytest.raises(TaskMutationError): + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by=" ", + reason="Assigning", + assigned_to="bob", + ) + + +# ── Version conflict via convenience methods ───────────────── + + +@pytest.mark.unit +class TestVersionConflictViaConvenienceMethods: + """Convenience methods raise TaskVersionConflictError on version mismatch.""" + + async def test_update_task_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskVersionConflictError): + await engine.update_task( + task.id, + {"title": "New title"}, + requested_by="alice", + expected_version=99, + ) + + async def test_transition_task_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskVersionConflictError): + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + expected_version=99, + 
assigned_to="bob", + ) + + +# ── Cancel task not found ──────────────────────────────────── + + +@pytest.mark.unit +class TestCancelTaskNotFound: + """cancel_task raises TaskNotFoundError for missing tasks.""" + + async def test_cancel_nonexistent_raises_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.cancel_task( + "task-nonexistent", + requested_by="alice", + reason="Cleanup", + ) + + +# ── Delete task not found ──────────────────────────────────── + + +@pytest.mark.unit +class TestDeleteTaskNotFound: + """delete_task raises TaskNotFoundError for missing tasks.""" + + async def test_delete_nonexistent_raises_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.delete_task( + "task-nonexistent", + requested_by="alice", + ) + + +# ── Start when already running ──────────────────────────────── + + +@pytest.mark.unit +class TestStartAlreadyRunning: + """Starting an already-running engine raises RuntimeError.""" + + async def test_double_start_raises( + self, + engine: TaskEngine, + ) -> None: + # engine fixture already called start() + with pytest.raises(RuntimeError, match="already running"): + engine.start() + + +# ── Stop idempotency ───────────────────────────────────────── + + +@pytest.mark.unit +class TestStopIdempotency: + """Stopping an already-stopped engine is a no-op.""" + + async def test_stop_when_not_running( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + # Never started — stop should be safe + await eng.stop(timeout=1.0) + + async def test_double_stop( + self, + engine: TaskEngine, + ) -> None: + await engine.stop(timeout=2.0) + # Second stop is a no-op + await engine.stop(timeout=1.0) + + +# ── Submit when not running ─────────────────────────────────── + + +@pytest.mark.unit +class TestSubmitWhenNotRunning: + """Submitting to a stopped engine raises 
TaskEngineNotRunningError.""" + + async def test_submit_after_stop( + self, + engine: TaskEngine, + ) -> None: + from ai_company.engine.errors import TaskEngineNotRunningError + + await engine.stop(timeout=2.0) + mutation = CreateTaskMutation( + request_id="req-late", + requested_by="alice", + task_data=_make_create_data(), + ) + with pytest.raises(TaskEngineNotRunningError): + await engine.submit(mutation) + + +# ── is_running property ────────────────────────────────────── + + +@pytest.mark.unit +class TestIsRunningProperty: + """is_running reflects engine lifecycle.""" + + async def test_running_after_start( + self, + engine: TaskEngine, + ) -> None: + assert engine.is_running is True + + async def test_not_running_after_stop( + self, + engine: TaskEngine, + ) -> None: + await engine.stop(timeout=2.0) + assert engine.is_running is False + + async def test_not_running_before_start( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + assert eng.is_running is False diff --git a/tests/unit/engine/test_task_engine_integration.py b/tests/unit/engine/test_task_engine_integration.py index 0201c3bb85..df1d887173 100644 --- a/tests/unit/engine/test_task_engine_integration.py +++ b/tests/unit/engine/test_task_engine_integration.py @@ -285,9 +285,11 @@ async def test_drain_timeout_resolves_pending_futures( """Futures still in queue are failed when stop() times out.""" # Block the processing loop with a slow save block = asyncio.Event() + entered_save = asyncio.Event() original_save = persistence.tasks.save async def slow_save(task: object) -> None: + entered_save.set() await block.wait() await original_save(task) # type: ignore[arg-type] @@ -308,8 +310,8 @@ async def slow_save(task: object) -> None: task_data=_make_create_data(), ) envelope = _MutationEnvelope(mutation=mutation2) - # Give the engine a tick to start processing the first task - await asyncio.sleep(0.05) + # Wait until slow_save is entered 
before queuing the second task + await entered_save.wait() eng._queue.put_nowait(envelope) # Stop with a very short timeout — loop is blocked, so timeout fires From ce3c594a950e34335c6f2190905823f120584ec4 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 20:53:15 +0100 Subject: [PATCH 10/14] =?UTF-8?q?fix:=20pre-PR=20review=20=E2=80=94=2034?= =?UTF-8?q?=20findings=20from=2010=20agents,=20fix=20code=20scanning=20ale?= =?UTF-8?q?rt=20#10?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pre-reviewed by 10 agents (code-reviewer, python-reviewer, pr-test-analyzer, silent-failure-hunter, comment-analyzer, type-design-analyzer, logging-audit, resilience-audit, security-reviewer, docs-consistency), 34 findings addressed. Key fixes: - CRITICAL: drain timeout in-flight future now resolved (saved before cancel) - Fix TaskInternalError for invariant violations (was TaskMutationError → wrong HTTP status) - Sanitize 5xx error messages (no internal details leaked to API clients) - Map TaskVersionConflictError to 409 Conflict - Add MutationType/TaskErrorCode type aliases, length constraints on CreateTaskData - Add logging in _raise_typed_error, read-through error wrapping, safety cap on list_tasks - Fix code scanning alert #10: remove explicit PR head SHA checkout in pages-preview - Replace Any with concrete types in API test fake repositories - Add TaskEngineConfig boundary validation, RecursionError re-raise, TaskEngineError branch tests - Split test_task_engine_extended.py (was 870 lines → 540 + 335) - Add apply_cancel docstring with Args/Returns --- .github/workflows/pages-preview.yml | 1 - src/ai_company/api/controllers/tasks.py | 53 ++- src/ai_company/config/defaults.py | 4 +- src/ai_company/engine/agent_engine.py | 2 +- src/ai_company/engine/task_engine.py | 94 +++- src/ai_company/engine/task_engine_apply.py | 69 ++- src/ai_company/engine/task_engine_models.py | 80 ++-- 
src/ai_company/engine/task_engine_version.py | 6 +- src/ai_company/observability/events/api.py | 1 + tests/unit/api/conftest.py | 34 +- .../unit/api/controllers/test_task_helpers.py | 5 +- tests/unit/engine/conftest.py | 6 +- tests/unit/engine/task_engine_helpers.py | 2 +- tests/unit/engine/test_agent_engine.py | 34 +- tests/unit/engine/test_task_engine_apply.py | 22 +- .../engine/test_task_engine_convenience.py | 357 ++++++++++++++++ .../unit/engine/test_task_engine_coverage.py | 14 +- .../unit/engine/test_task_engine_extended.py | 401 ++---------------- .../engine/test_task_engine_integration.py | 28 +- .../unit/engine/test_task_engine_lifecycle.py | 37 +- .../unit/engine/test_task_engine_mutations.py | 40 +- 21 files changed, 787 insertions(+), 503 deletions(-) create mode 100644 tests/unit/engine/test_task_engine_convenience.py diff --git a/.github/workflows/pages-preview.yml b/.github/workflows/pages-preview.yml index ff9226afd9..59f6588a82 100644 --- a/.github/workflows/pages-preview.yml +++ b/.github/workflows/pages-preview.yml @@ -32,7 +32,6 @@ jobs: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 with: persist-credentials: false - ref: ${{ github.event.pull_request.head.sha }} # --- MkDocs (documentation at /docs) --- - name: Set up Python diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index f13eaad6e0..8296d65dd4 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -33,6 +33,7 @@ from ai_company.engine.task_engine_models import CreateTaskData from ai_company.observability import get_logger from ai_company.observability.events.api import ( + API_AUTH_FALLBACK, API_RESOURCE_NOT_FOUND, API_TASK_DELETED, API_TASK_TRANSITION_FAILED, @@ -57,8 +58,7 @@ def _extract_requester(state: State) -> str: if user is not None and hasattr(user, "user_id"): return str(user.user_id) logger.warning( - API_RESOURCE_NOT_FOUND, - resource="authenticated_user", 
+ API_AUTH_FALLBACK, note="No authenticated user found, falling back to 'api'", ) return "api" @@ -89,7 +89,7 @@ def _map_task_engine_errors( error=str(exc), error_type=type(exc).__name__, ) - return ServiceUnavailableError(str(exc)) + return ServiceUnavailableError("Service temporarily unavailable") if isinstance(exc, TaskInternalError): logger.error( API_TASK_TRANSITION_FAILED, @@ -98,12 +98,34 @@ def _map_task_engine_errors( error=str(exc), error_type="TaskInternalError", ) - return ServiceUnavailableError(str(exc)) + return ServiceUnavailableError("Internal server error") if isinstance(exc, TaskVersionConflictError): + logger.warning( + API_TASK_TRANSITION_FAILED, + resource="task", + task_id=task_id, + error=str(exc), + error_type="TaskVersionConflictError", + ) return ConflictError(str(exc)) if isinstance(exc, TaskMutationError): + logger.warning( + API_TASK_TRANSITION_FAILED, + resource="task", + task_id=task_id, + error=str(exc), + error_type="TaskMutationError", + ) return ApiValidationError(str(exc)) - return exc + # Unknown error type — log and wrap to prevent leaking internals + logger.error( + API_TASK_TRANSITION_FAILED, + resource="task", + task_id=task_id, + error=str(exc), + error_type=type(exc).__name__, + ) + return ServiceUnavailableError("Unexpected engine error") class TaskController(Controller): @@ -187,6 +209,7 @@ async def create_task( Created task envelope. 
""" app_state: AppState = state.app_state + requester = _extract_requester(state) task_data = CreateTaskData( title=data.title, description=data.description, @@ -198,10 +221,17 @@ async def create_task( estimated_complexity=data.estimated_complexity, budget_limit=data.budget_limit, ) + if data.created_by != requester: + logger.info( + API_TASK_UPDATED, + note="created_by differs from authenticated requester", + created_by=data.created_by, + requester=requester, + ) try: task = await app_state.task_engine.create_task( task_data, - requested_by=_extract_requester(state), + requested_by=requester, ) except PydanticValidationError as exc: raise ApiValidationError(str(exc)) from exc @@ -285,17 +315,16 @@ async def transition_task( """ app_state: AppState = state.app_state requester = _extract_requester(state) - transition_kwargs: dict[str, object] = { - "requested_by": requester, - "reason": f"API transition to {data.target_status.value}", - } + overrides: dict[str, object] = {} if data.assigned_to is not None: - transition_kwargs["assigned_to"] = data.assigned_to + overrides["assigned_to"] = data.assigned_to try: task, from_status = await app_state.task_engine.transition_task( task_id, data.target_status, - **transition_kwargs, # type: ignore[arg-type] + requested_by=requester, + reason=f"API transition to {data.target_status.value}", + **overrides, # type: ignore[arg-type] ) except PydanticValidationError as exc: raise ApiValidationError(str(exc)) from exc diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index 579d1fecbd..d2482fa955 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -1,9 +1,7 @@ """Built-in default values for company configuration.""" -from typing import Any - -def default_config_dict() -> dict[str, Any]: +def default_config_dict() -> dict[str, object]: """Return base-layer configuration defaults as a raw dict. 
These defaults serve as the base layer; user-provided YAML values diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index ec51c9f0c2..d76e3bbe67 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -694,7 +694,7 @@ async def _report_to_task_engine( return try: - _, _ = await self._task_engine.transition_task( + _ = await self._task_engine.transition_task( task_id, final_status, requested_by=agent_id, diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index ac9dcae707..400f4dc845 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -171,19 +171,32 @@ async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 TASK_ENGINE_DRAIN_TIMEOUT, remaining=self._queue.qsize(), ) + # Capture in-flight ref before cancel — the finally block + # in _process_one clears self._in_flight on CancelledError. + saved_in_flight = self._in_flight self._processing_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._processing_task - self._fail_remaining_futures() + self._fail_remaining_futures(saved_in_flight) self._processing_task = None logger.info(TASK_ENGINE_STOPPED) - def _fail_remaining_futures(self) -> None: - """Fail in-flight and remaining enqueued futures after drain timeout.""" + def _fail_remaining_futures( + self, + saved_in_flight: _MutationEnvelope | None = None, + ) -> None: + """Fail in-flight and remaining enqueued futures after drain timeout. + + Args: + saved_in_flight: In-flight envelope captured before task + cancellation — needed because ``_process_one``'s + ``finally`` block clears ``self._in_flight`` on + ``CancelledError``. 
+ """ shutdown_result_for = self._shutdown_result failed_count = 0 - in_flight = self._in_flight + in_flight = saved_in_flight if saved_in_flight is not None else self._in_flight if in_flight is not None and not in_flight.future.done(): in_flight.future.set_result(shutdown_result_for(in_flight)) failed_count += 1 @@ -329,7 +342,7 @@ async def update_task( self._raise_typed_error(result) if result.task is None: msg = "Internal error: update succeeded but task is None" - raise TaskMutationError(msg) + raise TaskInternalError(msg) return result.task async def transition_task( @@ -354,7 +367,8 @@ async def transition_task( Returns: Tuple of (transitioned task, status before the transition). - The second element is ``None`` when the previous status is unknown. + The second element is ``None`` only when the underlying + mutation does not provide previous status. Raises: TaskEngineNotRunningError: If the engine is not running. @@ -381,7 +395,7 @@ async def transition_task( self._raise_typed_error(result) if result.task is None: msg = "Internal error: transition succeeded but task is None" - raise TaskMutationError(msg) + raise TaskInternalError(msg) return result.task, result.previous_status async def delete_task( @@ -449,13 +463,19 @@ async def cancel_task( self._raise_typed_error(result) if result.task is None: msg = "Internal error: cancel succeeded but task is None" - raise TaskMutationError(msg) + raise TaskInternalError(msg) return result.task @staticmethod def _raise_typed_error(result: TaskMutationResult) -> Never: """Raise a typed error from a failed mutation result.""" error = result.error or "Mutation failed" + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + request_id=result.request_id, + error=error, + error_code=result.error_code, + ) match result.error_code: case "not_found": raise TaskNotFoundError(error) @@ -476,8 +496,22 @@ async def get_task(self, task_id: str) -> Task | None: Returns: The task, or ``None`` if not found. 
+ + Raises: + TaskInternalError: If the persistence backend fails. """ - return await self._persistence.tasks.get(task_id) + try: + return await self._persistence.tasks.get(task_id) + except MemoryError, RecursionError: + raise + except Exception as exc: + msg = f"Failed to read task: {exc}" + logger.exception( + TASK_ENGINE_MUTATION_FAILED, + error=msg, + task_id=task_id, + ) + raise TaskInternalError(msg) from exc async def list_tasks( self, @@ -495,15 +529,44 @@ async def list_tasks( Returns: Matching tasks as a tuple. + + Raises: + TaskInternalError: If the persistence backend fails. """ - return await self._persistence.tasks.list_tasks( - status=status, - assigned_to=assigned_to, - project=project, - ) + try: + tasks = await self._persistence.tasks.list_tasks( + status=status, + assigned_to=assigned_to, + project=project, + ) + except MemoryError, RecursionError: + raise + except Exception as exc: + msg = f"Failed to list tasks: {exc}" + logger.exception( + TASK_ENGINE_MUTATION_FAILED, + error=msg, + ) + raise TaskInternalError(msg) from exc + if len(tasks) > self._MAX_LIST_RESULTS: + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + error=( + f"list_tasks returned {len(tasks)} results, " + f"capping at {self._MAX_LIST_RESULTS}" + ), + ) + return tasks[: self._MAX_LIST_RESULTS] + return tasks # -- Background processing --------------------------------------------- + _MAX_LIST_RESULTS: int = 10_000 + """Safety cap on ``list_tasks`` results to bound memory usage. + + Real pagination should be pushed into the persistence layer. 
+ """ + _POLL_INTERVAL_SECONDS: float = 0.5 """How often the processing loop checks for ``_running = False``.""" @@ -623,13 +686,13 @@ async def _publish_snapshot( timestamp=datetime.now(UTC), ) + task_id = getattr(mutation, "task_id", None) try: # Deferred to break circular import: # communication -> engine -> communication from ai_company.communication.enums import MessageType # noqa: PLC0415 from ai_company.communication.message import Message # noqa: PLC0415 - task_id = getattr(mutation, "task_id", None) msg = Message( timestamp=datetime.now(UTC), sender=self._SNAPSHOT_SENDER, @@ -648,7 +711,6 @@ async def _publish_snapshot( except MemoryError, RecursionError: raise except Exception: - task_id = getattr(mutation, "task_id", None) logger.warning( TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED, mutation_type=mutation.mutation_type, diff --git a/src/ai_company/engine/task_engine_apply.py b/src/ai_company/engine/task_engine_apply.py index fdfb3871b2..c67a0d3a64 100644 --- a/src/ai_company/engine/task_engine_apply.py +++ b/src/ai_company/engine/task_engine_apply.py @@ -55,7 +55,7 @@ def _format_validation_error( return f"{prefix}: {'; '.join(parts)}" -def not_found_result( +def _not_found_result( mutation_type: str, request_id: str, task_id: str, @@ -117,7 +117,17 @@ async def apply_create( persistence: PersistenceBackend, versions: VersionTracker, ) -> TaskMutationResult: - """Create a new task.""" + """Create a new task. + + Args: + mutation: Creation request with task data. + persistence: Backend for task storage. + versions: Version tracker for optimistic concurrency. + + Returns: + Result with the created task on success, or a validation + failure if the task data is invalid. + """ data = mutation.task_data task_id = f"task-{uuid4().hex}" @@ -170,10 +180,21 @@ async def apply_update( persistence: PersistenceBackend, versions: VersionTracker, ) -> TaskMutationResult: - """Update task fields.""" + """Update task fields. 
+ + Args: + mutation: Update request with field-value pairs. + persistence: Backend for task storage. + versions: Version tracker for optimistic concurrency. + + Returns: + Result with the updated task on success, or a failure with + ``error_code`` of ``"not_found"``, ``"version_conflict"``, + or ``"validation"``. + """ task = await persistence.tasks.get(mutation.task_id) if task is None: - return not_found_result("update", mutation.request_id, mutation.task_id) + return _not_found_result("update", mutation.request_id, mutation.task_id) try: versions.check(mutation.task_id, mutation.expected_version) @@ -238,10 +259,21 @@ async def apply_transition( persistence: PersistenceBackend, versions: VersionTracker, ) -> TaskMutationResult: - """Perform a task status transition.""" + """Perform a task status transition. + + Args: + mutation: Transition request with target status and reason. + persistence: Backend for task storage. + versions: Version tracker for optimistic concurrency. + + Returns: + Result with the transitioned task on success, or a failure + with ``error_code`` of ``"not_found"``, + ``"version_conflict"``, or ``"validation"``. + """ task = await persistence.tasks.get(mutation.task_id) if task is None: - return not_found_result("transition", mutation.request_id, mutation.task_id) + return _not_found_result("transition", mutation.request_id, mutation.task_id) try: versions.check(mutation.task_id, mutation.expected_version) @@ -301,10 +333,20 @@ async def apply_delete( persistence: PersistenceBackend, versions: VersionTracker, ) -> TaskMutationResult: - """Delete a task.""" + """Delete a task. + + Args: + mutation: Deletion request with task identifier. + persistence: Backend for task storage. + versions: Version tracker for optimistic concurrency. + + Returns: + Result with ``success=True`` on deletion, or a failure + with ``error_code="not_found"`` if the task does not exist. 
+ """ deleted = await persistence.tasks.delete(mutation.task_id) if not deleted: - return not_found_result("delete", mutation.request_id, mutation.task_id) + return _not_found_result("delete", mutation.request_id, mutation.task_id) versions.remove(mutation.task_id) @@ -332,10 +374,19 @@ async def apply_cancel( intentionally omits an ``expected_version`` check — a cancellation should always succeed regardless of version, similar to a forced stop signal. + + Args: + mutation: Cancellation request with task identifier and reason. + persistence: Backend for task storage. + versions: Version tracker for optimistic concurrency. + + Returns: + Result with the cancelled task on success, or a failure with + ``error_code`` of ``"not_found"`` or ``"validation"``. """ task = await persistence.tasks.get(mutation.task_id) if task is None: - return not_found_result("cancel", mutation.request_id, mutation.task_id) + return _not_found_result("cancel", mutation.request_id, mutation.task_id) previous_status = task.status try: diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index 951a811d59..97de2c690f 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -7,7 +7,7 @@ import copy from datetime import UTC, datetime -from typing import Literal, Self +from typing import Final, Literal, Self from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator @@ -15,10 +15,23 @@ from ai_company.core.task import Task from ai_company.core.types import NotBlankStr # noqa: TC001 +MutationType = Literal["create", "update", "transition", "delete", "cancel"] +"""Discriminator literal for all mutation request types.""" + +TaskErrorCode = Literal["not_found", "version_conflict", "validation", "internal"] +"""Machine-readable error classification for mutation results.""" + +_MAX_TITLE_LENGTH: Final[int] = 256 +"""Maximum length for task titles (matches API-layer 
``CreateTaskRequest``).""" + +_MAX_DESCRIPTION_LENGTH: Final[int] = 4096 +"""Maximum length for task descriptions (matches API-layer ``CreateTaskRequest``).""" + _VALID_TASK_FIELDS: frozenset[str] = frozenset(Task.model_fields) -"""Field names from ``model_fields`` on :class:`Task`. +"""Field names accepted by ``model_fields`` on :class:`Task`. -Excludes computed fields. +Pydantic's ``model_fields`` excludes any ``@computed_field`` +properties by design. Used to reject unknown keys in :class:`UpdateTaskMutation` and :class:`TransitionTaskMutation` validators. @@ -34,6 +47,10 @@ class CreateTaskData(BaseModel): the engine layer so it has no dependency on the API (field parity is maintained by convention, not enforced). + Note: ``CreateTaskRequest`` applies additional length constraints + (``max_length``) at the API boundary. This model enforces the same + limits for defense-in-depth so engine-layer callers also benefit. + Attributes: title: Short task title. description: Detailed task description. 
@@ -48,8 +65,14 @@ class CreateTaskData(BaseModel): model_config = ConfigDict(frozen=True, allow_inf_nan=False) - title: NotBlankStr = Field(description="Short task title") - description: NotBlankStr = Field(description="Detailed task description") + title: NotBlankStr = Field( + max_length=_MAX_TITLE_LENGTH, + description="Short task title", + ) + description: NotBlankStr = Field( + max_length=_MAX_DESCRIPTION_LENGTH, + description="Detailed task description", + ) type: TaskType = Field(description="Task work type") priority: Priority = Field(default=Priority.MEDIUM, description="Task priority") project: NotBlankStr = Field(description="Project ID") @@ -84,19 +107,21 @@ class CreateTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: Literal["create"] = "create" + mutation_type: MutationType = "create" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_data: CreateTaskData = Field(description="Task creation payload") -_IMMUTABLE_TASK_FIELDS: frozenset[str] = frozenset( - { - "id", - "status", - "created_by", - } -) +_ALWAYS_IMMUTABLE_FIELDS: frozenset[str] = frozenset({"id", "created_by"}) +"""Core identity fields that can never be changed via any mutation. + +``status`` is also immutable for updates (must use transitions) and +for transition overrides (set via ``target_status``). See +:data:`_IMMUTABLE_TASK_FIELDS` and :data:`_IMMUTABLE_OVERRIDE_FIELDS`. +""" + +_IMMUTABLE_TASK_FIELDS: frozenset[str] = _ALWAYS_IMMUTABLE_FIELDS | {"status"} """Fields that must not be modified via :class:`UpdateTaskMutation`. 
``status`` must go through :class:`TransitionTaskMutation` (which @@ -119,7 +144,7 @@ class UpdateTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: Literal["update"] = "update" + mutation_type: MutationType = "update" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target task identifier") @@ -148,14 +173,11 @@ def __init__(self, **data: object) -> None: object.__setattr__(self, "updates", copy.deepcopy(self.updates)) -_IMMUTABLE_OVERRIDE_FIELDS: frozenset[str] = frozenset( - { - "id", - "created_by", - "status", - } -) -"""Fields that must not be overridden during a transition.""" +_IMMUTABLE_OVERRIDE_FIELDS: frozenset[str] = _ALWAYS_IMMUTABLE_FIELDS | {"status"} +"""Fields that must not be overridden during a transition. + +See :data:`_ALWAYS_IMMUTABLE_FIELDS` for the shared base. +""" class TransitionTaskMutation(BaseModel): @@ -174,7 +196,7 @@ class TransitionTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: Literal["transition"] = "transition" + mutation_type: MutationType = "transition" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target task identifier") @@ -220,7 +242,7 @@ class DeleteTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: Literal["delete"] = "delete" + mutation_type: MutationType = "delete" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target task identifier") @@ -239,7 +261,7 @@ class CancelTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: Literal["cancel"] = "cancel" + mutation_type: MutationType = 
"cancel" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target task identifier") @@ -285,9 +307,7 @@ class TaskMutationResult(BaseModel): description="Status before mutation", ) error: str | None = Field(default=None, description="Error description") - error_code: ( - Literal["not_found", "version_conflict", "validation", "internal"] | None - ) = Field( + error_code: TaskErrorCode | None = Field( default=None, description="Machine-readable error classification", ) @@ -323,8 +343,8 @@ class TaskStateChanged(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: Literal["create", "update", "transition", "delete", "cancel"] = ( - Field(description="Mutation type that triggered event") + mutation_type: MutationType = Field( + description="Mutation type that triggered event", ) request_id: NotBlankStr = Field(description="Originating request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") diff --git a/src/ai_company/engine/task_engine_version.py b/src/ai_company/engine/task_engine_version.py index db5073a159..b05909c97b 100644 --- a/src/ai_company/engine/task_engine_version.py +++ b/src/ai_company/engine/task_engine_version.py @@ -52,7 +52,11 @@ def set_initial(self, task_id: str, version: int) -> None: self._versions[task_id] = version def bump(self, task_id: str) -> int: - """Increment and return the version counter for *task_id*.""" + """Increment and return the version counter for *task_id*. + + If *task_id* is not yet tracked, it is seeded at version 1 + first, so the returned value will be 2 (not 1). 
+ """ self.seed(task_id) version = self._versions[task_id] + 1 self._versions[task_id] = version diff --git a/src/ai_company/observability/events/api.py b/src/ai_company/observability/events/api.py index 67aecd48cb..688bc83a4d 100644 --- a/src/ai_company/observability/events/api.py +++ b/src/ai_company/observability/events/api.py @@ -36,3 +36,4 @@ API_AUTH_SETUP_COMPLETE: Final[str] = "api.auth.setup_complete" API_AUTH_PASSWORD_CHANGED: Final[str] = "api.auth.password_changed" # noqa: S105 API_TASK_TRANSITION_FAILED: Final[str] = "api.task.transition_failed" +API_AUTH_FALLBACK: Final[str] = "api.auth.fallback" diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index aa7befec2b..c0f0a014dc 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -28,6 +28,12 @@ ) from ai_company.core.task import Task from ai_company.engine.task_engine import TaskEngine +from ai_company.hr.enums import LifecycleEventType # noqa: TC001 +from ai_company.hr.models import AgentLifecycleEvent # noqa: TC001 +from ai_company.hr.performance.models import ( # noqa: TC001 + CollaborationMetricRecord, + TaskMetricRecord, +) from ai_company.persistence.errors import DuplicateRecordError, QueryError from ai_company.security.models import AuditEntry, AuditVerdictStr # noqa: TC001 from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 @@ -133,18 +139,18 @@ class FakeLifecycleEventRepository: """In-memory lifecycle event repository for tests.""" def __init__(self) -> None: - self._events: list[Any] = [] + self._events: list[AgentLifecycleEvent] = [] - async def save(self, event: Any) -> None: + async def save(self, event: AgentLifecycleEvent) -> None: self._events.append(event) async def list_events( self, *, agent_id: str | None = None, - event_type: Any = None, - since: Any = None, - ) -> tuple[Any, ...]: + event_type: LifecycleEventType | None = None, + since: datetime | None = None, + ) -> tuple[AgentLifecycleEvent, ...]: result = 
self._events if agent_id is not None: result = [e for e in result if e.agent_id == agent_id] @@ -159,18 +165,18 @@ class FakeTaskMetricRepository: """In-memory task metric repository for tests.""" def __init__(self) -> None: - self._records: list[Any] = [] + self._records: list[TaskMetricRecord] = [] - async def save(self, record: Any) -> None: + async def save(self, record: TaskMetricRecord) -> None: self._records.append(record) async def query( self, *, agent_id: str | None = None, - since: Any = None, - until: Any = None, - ) -> tuple[Any, ...]: + since: datetime | None = None, + until: datetime | None = None, + ) -> tuple[TaskMetricRecord, ...]: result = self._records if agent_id is not None: result = [r for r in result if r.agent_id == agent_id] @@ -185,17 +191,17 @@ class FakeCollaborationMetricRepository: """In-memory collaboration metric repository for tests.""" def __init__(self) -> None: - self._records: list[Any] = [] + self._records: list[CollaborationMetricRecord] = [] - async def save(self, record: Any) -> None: + async def save(self, record: CollaborationMetricRecord) -> None: self._records.append(record) async def query( self, *, agent_id: str | None = None, - since: Any = None, - ) -> tuple[Any, ...]: + since: datetime | None = None, + ) -> tuple[CollaborationMetricRecord, ...]: result = self._records if agent_id is not None: result = [r for r in result if r.agent_id == agent_id] diff --git a/tests/unit/api/controllers/test_task_helpers.py b/tests/unit/api/controllers/test_task_helpers.py index aa733a6b27..f82c3e7841 100644 --- a/tests/unit/api/controllers/test_task_helpers.py +++ b/tests/unit/api/controllers/test_task_helpers.py @@ -101,7 +101,8 @@ def test_mutation_error_maps_to_validation_error(self) -> None: result = _map_task_engine_errors(exc) assert isinstance(result, ApiValidationError) - def test_unknown_error_passes_through(self) -> None: + def test_unknown_error_wraps_as_service_unavailable(self) -> None: exc = RuntimeError("unexpected") 
result = _map_task_engine_errors(exc) - assert result is exc + assert isinstance(result, ServiceUnavailableError) + assert "Unexpected engine error" in str(result) diff --git a/tests/unit/engine/conftest.py b/tests/unit/engine/conftest.py index 4b6d6d1e99..b93566f8f2 100644 --- a/tests/unit/engine/conftest.py +++ b/tests/unit/engine/conftest.py @@ -43,7 +43,7 @@ from tests.unit.engine.task_engine_helpers import FakeMessageBus, FakePersistence if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterator + from collections.abc import AsyncIterator from ai_company.core.enums import ConflictEscalation from ai_company.engine.workspace.models import ( @@ -432,7 +432,7 @@ def config() -> TaskEngineConfig: async def engine( persistence: FakePersistence, config: TaskEngineConfig, -) -> AsyncGenerator[TaskEngine]: +) -> AsyncIterator[TaskEngine]: """Create and start a TaskEngine, stop on teardown.""" eng = TaskEngine( persistence=persistence, # type: ignore[arg-type] @@ -448,7 +448,7 @@ async def engine_with_bus( persistence: FakePersistence, message_bus: FakeMessageBus, config: TaskEngineConfig, -) -> AsyncGenerator[TaskEngine]: +) -> AsyncIterator[TaskEngine]: """Create and start a TaskEngine with a message bus.""" eng = TaskEngine( persistence=persistence, # type: ignore[arg-type] diff --git a/tests/unit/engine/task_engine_helpers.py b/tests/unit/engine/task_engine_helpers.py index 6837d0be44..330bc043f8 100644 --- a/tests/unit/engine/task_engine_helpers.py +++ b/tests/unit/engine/task_engine_helpers.py @@ -87,7 +87,7 @@ async def publish(self, message: object) -> None: # ── Helpers ──────────────────────────────────────────────────── -def _make_create_data(**overrides: object) -> CreateTaskData: +def make_create_data(**overrides: object) -> CreateTaskData: """Build a CreateTaskData with sensible defaults.""" from ai_company.core.enums import TaskType diff --git a/tests/unit/engine/test_agent_engine.py b/tests/unit/engine/test_agent_engine.py index 
a790323586..c85b370a05 100644 --- a/tests/unit/engine/test_agent_engine.py +++ b/tests/unit/engine/test_agent_engine.py @@ -14,7 +14,11 @@ from ai_company.core.task import Task from ai_company.engine.agent_engine import AgentEngine from ai_company.engine.context import AgentContext -from ai_company.engine.errors import ExecutionStateError, TaskMutationError +from ai_company.engine.errors import ( + ExecutionStateError, + TaskEngineError, + TaskMutationError, +) from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -1094,6 +1098,34 @@ async def test_unexpected_error_swallowed( # Run still succeeds despite task engine failure assert result.is_success is True + async def test_task_engine_error_swallowed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """TaskEngineError (non-mutation) from TaskEngine is logged and swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock( + side_effect=TaskEngineError("engine unavailable"), + ) + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + # Run still succeeds despite task engine failure + assert result.is_success is True + async def test_memory_error_propagates( self, sample_agent_with_personality: AgentIdentity, diff --git a/tests/unit/engine/test_task_engine_apply.py b/tests/unit/engine/test_task_engine_apply.py index 3d4b90c28e..c746771763 100644 --- a/tests/unit/engine/test_task_engine_apply.py +++ b/tests/unit/engine/test_task_engine_apply.py @@ -20,7 +20,7 @@ UpdateTaskMutation, ) from ai_company.engine.task_engine_version import VersionTracker -from tests.unit.engine.task_engine_helpers import FakePersistence, _make_create_data +from 
tests.unit.engine.task_engine_helpers import FakePersistence, make_create_data @pytest.fixture @@ -48,7 +48,7 @@ async def test_dispatch_create( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) result = await dispatch(mutation, persistence, versions) # type: ignore[arg-type] assert result.success is True @@ -85,7 +85,7 @@ async def test_creates_task( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(title="New Task"), + task_data=make_create_data(title="New Task"), ) result = await apply_create(mutation, persistence, versions) # type: ignore[arg-type] assert result.success is True @@ -107,7 +107,7 @@ async def test_create_validation_error( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(assigned_to="bob"), + task_data=make_create_data(assigned_to="bob"), ) result = await apply_create(mutation, persistence, versions) # type: ignore[arg-type] assert result.success is False @@ -122,7 +122,7 @@ async def test_create_persists_task( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) result = await apply_create(mutation, persistence, versions) # type: ignore[arg-type] assert result.task is not None @@ -145,7 +145,7 @@ async def _create_task( mutation = CreateTaskMutation( request_id="req-c", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) return await apply_create(mutation, persistence, versions) # type: ignore[arg-type] @@ -269,7 +269,7 @@ async def _create_task( mutation = CreateTaskMutation( request_id="req-c", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) return await apply_create(mutation, persistence, versions) # type: ignore[arg-type] @@ -367,7 +367,7 @@ async def test_delete_task( 
CreateTaskMutation( request_id="req-c", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ), persistence, # type: ignore[arg-type] versions, @@ -408,7 +408,7 @@ async def test_delete_removes_version_tracking( CreateTaskMutation( request_id="req-c", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ), persistence, # type: ignore[arg-type] versions, @@ -446,7 +446,7 @@ async def _create_and_assign( CreateTaskMutation( request_id="req-c", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ), persistence, # type: ignore[arg-type] versions, @@ -510,7 +510,7 @@ async def test_cancel_invalid_status( CreateTaskMutation( request_id="req-c", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ), persistence, # type: ignore[arg-type] versions, diff --git a/tests/unit/engine/test_task_engine_convenience.py b/tests/unit/engine/test_task_engine_convenience.py new file mode 100644 index 0000000000..76f71dc4bf --- /dev/null +++ b/tests/unit/engine/test_task_engine_convenience.py @@ -0,0 +1,357 @@ +"""Convenience method, typed error dispatch, and lifecycle edge-case tests. + +Split from ``test_task_engine_extended.py`` to keep files under 800 lines. 
+""" + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.errors import ( + TaskInternalError, + TaskMutationError, + TaskNotFoundError, + TaskVersionConflictError, +) +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_models import ( + CreateTaskMutation, + TaskMutationResult, +) +from tests.unit.engine.task_engine_helpers import ( + FakePersistence, + make_create_data, +) + +# ── Typed error dispatch for all error codes ───────────────── + + +@pytest.mark.unit +class TestRaiseTypedErrorAllCodes: + """_raise_typed_error maps all error_code values to typed exceptions.""" + + def test_not_found_code(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="not found", + error_code="not_found", + ) + with pytest.raises(TaskNotFoundError, match="not found"): + TaskEngine._raise_typed_error(result) + + def test_version_conflict_code(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="conflict", + error_code="version_conflict", + ) + with pytest.raises(TaskVersionConflictError, match="conflict"): + TaskEngine._raise_typed_error(result) + + def test_internal_code(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="boom", + error_code="internal", + ) + with pytest.raises(TaskInternalError, match="boom"): + TaskEngine._raise_typed_error(result) + + def test_validation_code_falls_through(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="bad data", + error_code="validation", + ) + with pytest.raises(TaskMutationError, match="bad data"): + TaskEngine._raise_typed_error(result) + + def test_none_code_falls_through(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="generic", + ) + with pytest.raises(TaskMutationError, match="generic"): + TaskEngine._raise_typed_error(result) + + def 
test_missing_error_uses_default_message(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="Mutation failed", + ) + with pytest.raises(TaskMutationError, match="Mutation failed"): + TaskEngine._raise_typed_error(result) + + +# ── Transition with overrides via engine ───────────────────── + + +@pytest.mark.unit +class TestTransitionOverridesViaEngine: + """Transition overrides flow through the engine correctly.""" + + async def test_transition_with_assigned_to_override( + self, + engine: TaskEngine, + ) -> None: + """assigned_to passed as kwarg becomes an override on the mutation.""" + task = await engine.create_task( + make_create_data(), + requested_by="alice", + ) + transitioned, prev = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + assert transitioned.assigned_to == "bob" + assert prev == TaskStatus.CREATED + + async def test_transition_returns_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + make_create_data(), + requested_by="alice", + ) + # CREATED -> ASSIGNED + assigned, prev1 = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + assert prev1 == TaskStatus.CREATED + + # ASSIGNED -> IN_PROGRESS + in_progress, prev2 = await engine.transition_task( + assigned.id, + TaskStatus.IN_PROGRESS, + requested_by="bob", + reason="Starting work", + ) + assert prev2 == TaskStatus.ASSIGNED + assert in_progress.status == TaskStatus.IN_PROGRESS + + +# ── PydanticValidationError wrapping in convenience methods ── + + +@pytest.mark.unit +class TestConvenienceMethodValidationWrapping: + """Convenience methods wrap PydanticValidationError as TaskMutationError.""" + + async def test_update_task_wraps_pydantic_validation( + self, + engine: TaskEngine, + ) -> None: + """UpdateTaskMutation with immutable field raises 
TaskMutationError.""" + task = await engine.create_task( + make_create_data(), + requested_by="alice", + ) + # 'id' is an immutable field rejected by model_validator + with pytest.raises(TaskMutationError): + await engine.update_task( + task.id, + {"id": "hacked"}, + requested_by="alice", + ) + + async def test_transition_task_wraps_pydantic_validation( + self, + engine: TaskEngine, + ) -> None: + """TransitionTaskMutation with blank requested_by raises TaskMutationError.""" + task = await engine.create_task( + make_create_data(), + requested_by="alice", + ) + # Blank requested_by should trigger NotBlankStr validation + with pytest.raises(TaskMutationError): + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by=" ", + reason="Assigning", + assigned_to="bob", + ) + + +# ── Version conflict via convenience methods ───────────────── + + +@pytest.mark.unit +class TestVersionConflictViaConvenienceMethods: + """Convenience methods raise TaskVersionConflictError on version mismatch.""" + + async def test_update_task_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskVersionConflictError): + await engine.update_task( + task.id, + {"title": "New title"}, + requested_by="alice", + expected_version=99, + ) + + async def test_transition_task_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskVersionConflictError): + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + expected_version=99, + assigned_to="bob", + ) + + +# ── Cancel task not found ──────────────────────────────────── + + +@pytest.mark.unit +class TestCancelTaskNotFound: + """cancel_task raises TaskNotFoundError for missing tasks.""" + + async def test_cancel_nonexistent_raises_not_found( + self, + 
engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.cancel_task( + "task-nonexistent", + requested_by="alice", + reason="Cleanup", + ) + + +# ── Delete task not found ──────────────────────────────────── + + +@pytest.mark.unit +class TestDeleteTaskNotFound: + """delete_task raises TaskNotFoundError for missing tasks.""" + + async def test_delete_nonexistent_raises_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.delete_task( + "task-nonexistent", + requested_by="alice", + ) + + +# ── Start when already running ──────────────────────────────── + + +@pytest.mark.unit +class TestStartAlreadyRunning: + """Starting an already-running engine raises RuntimeError.""" + + async def test_double_start_raises( + self, + engine: TaskEngine, + ) -> None: + # engine fixture already called start() + with pytest.raises(RuntimeError, match="already running"): + engine.start() + + +# ── Stop idempotency ───────────────────────────────────────── + + +@pytest.mark.unit +class TestStopIdempotency: + """Stopping an already-stopped engine is a no-op.""" + + async def test_stop_when_not_running( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + # Never started — stop should be safe + await eng.stop(timeout=1.0) + + async def test_double_stop( + self, + engine: TaskEngine, + ) -> None: + await engine.stop(timeout=2.0) + # Second stop is a no-op + await engine.stop(timeout=1.0) + + +# ── Submit when not running ─────────────────────────────────── + + +@pytest.mark.unit +class TestSubmitWhenNotRunning: + """Submitting to a stopped engine raises TaskEngineNotRunningError.""" + + async def test_submit_after_stop( + self, + engine: TaskEngine, + ) -> None: + from ai_company.engine.errors import TaskEngineNotRunningError + + await engine.stop(timeout=2.0) + mutation = CreateTaskMutation( + request_id="req-late", + 
requested_by="alice", + task_data=make_create_data(), + ) + with pytest.raises(TaskEngineNotRunningError): + await engine.submit(mutation) + + +# ── is_running property ────────────────────────────────────── + + +@pytest.mark.unit +class TestIsRunningProperty: + """is_running reflects engine lifecycle.""" + + async def test_running_after_start( + self, + engine: TaskEngine, + ) -> None: + assert engine.is_running is True + + async def test_not_running_after_stop( + self, + engine: TaskEngine, + ) -> None: + await engine.stop(timeout=2.0) + assert engine.is_running is False + + async def test_not_running_before_start( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + assert eng.is_running is False diff --git a/tests/unit/engine/test_task_engine_coverage.py b/tests/unit/engine/test_task_engine_coverage.py index 933cf94690..3d5a6e75d8 100644 --- a/tests/unit/engine/test_task_engine_coverage.py +++ b/tests/unit/engine/test_task_engine_coverage.py @@ -19,7 +19,7 @@ from tests.unit.engine.task_engine_helpers import ( FailingMessageBus, FakePersistence, - _make_create_data, + make_create_data, ) # ── In-flight envelope resolution ──────────────────────────── @@ -48,7 +48,7 @@ async def slow_save(task: object) -> None: # Submit a task that will block in slow_save blocked = asyncio.create_task( - eng.create_task(_make_create_data(), requested_by="alice"), + eng.create_task(make_create_data(), requested_by="alice"), ) await asyncio.sleep(0.05) @@ -97,7 +97,7 @@ async def exploding_save(task: object) -> None: mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) result = await eng.submit(mutation) assert result.success is False @@ -129,7 +129,7 @@ async def test_publish_failure_logged_not_raised( eng.start() try: task = await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) assert 
task.id.startswith("task-") @@ -175,7 +175,7 @@ async def test_shutdown_result_envelope(self) -> None: mutation = CreateTaskMutation( request_id="req-shutdown", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) envelope = _MutationEnvelope(mutation=mutation) result = TaskEngine._shutdown_result(envelope) @@ -217,7 +217,7 @@ async def fail_first_save(task: object) -> None: m1 = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) r1 = await eng.submit(m1) assert r1.success is False @@ -226,7 +226,7 @@ async def fail_first_save(task: object) -> None: m2 = CreateTaskMutation( request_id="req-2", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) r2 = await eng.submit(m2) assert r2.success is True diff --git a/tests/unit/engine/test_task_engine_extended.py b/tests/unit/engine/test_task_engine_extended.py index 9048773ad0..b28ff1e392 100644 --- a/tests/unit/engine/test_task_engine_extended.py +++ b/tests/unit/engine/test_task_engine_extended.py @@ -19,8 +19,6 @@ from ai_company.engine.errors import ( TaskInternalError, TaskMutationError, - TaskNotFoundError, - TaskVersionConflictError, ) from ai_company.engine.task_engine import TaskEngine, _MutationEnvelope from ai_company.engine.task_engine_models import ( @@ -33,7 +31,7 @@ from tests.unit.engine.task_engine_helpers import ( FakeMessageBus, FakePersistence, - _make_create_data, + make_create_data, ) if TYPE_CHECKING: @@ -68,7 +66,7 @@ async def test_mutations_processed_in_submission_order( CreateTaskMutation( request_id=f"req-{i}", requested_by="alice", - task_data=_make_create_data(title=f"Task {i}"), + task_data=make_create_data(title=f"Task {i}"), ) for i in range(5) ] @@ -89,7 +87,7 @@ async def test_interleaved_create_update_ordering( ) -> None: """Create then update: update sees the created task.""" task = await engine.create_task( - 
_make_create_data(title="Original"), + make_create_data(title="Original"), requested_by="alice", ) # Immediately update — this should see the task because @@ -114,7 +112,7 @@ async def test_empty_reason_generates_default( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) transitioned, _ = await engine.transition_task( @@ -131,7 +129,7 @@ async def test_explicit_reason_preserved( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) transitioned, _ = await engine.transition_task( @@ -165,7 +163,7 @@ async def test_delete_snapshot_has_none_status( eng.start() try: task = await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) await asyncio.sleep(0) # let snapshot publish @@ -200,7 +198,7 @@ async def test_cancel_increments_version( create_mut = CreateTaskMutation( request_id="req-c", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) r1 = await engine.submit(create_mut) assert r1.version == 1 @@ -257,7 +255,7 @@ async def exploding_save(task: object) -> None: try: with pytest.raises(TaskInternalError): await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) finally: @@ -270,7 +268,7 @@ async def test_create_validation_error_raises_mutation_error( """Validation failure (assigned_to on CREATED) raises TaskMutationError.""" with pytest.raises(TaskMutationError): await engine.create_task( - _make_create_data(assigned_to="should-fail"), + make_create_data(assigned_to="should-fail"), requested_by="alice", ) @@ -296,7 +294,7 @@ async def test_transition_snapshot_carries_reason( eng.start() try: task = await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) await asyncio.sleep(0) @@ -333,7 +331,7 @@ async def test_cancel_snapshot_carries_reason( eng.start() try: task = 
await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) await eng.transition_task( @@ -375,7 +373,7 @@ async def test_create_snapshot_reason_is_none( eng.start() try: await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) await asyncio.sleep(0) @@ -402,7 +400,7 @@ async def test_update_snapshot_reason_is_none( eng.start() try: task = await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) await asyncio.sleep(0) @@ -451,7 +449,7 @@ async def oom_save(task: object) -> None: mutation = CreateTaskMutation( request_id="req-oom", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) # The MemoryError propagates to the processing loop which # re-raises it, causing the processing task to fail. @@ -468,6 +466,38 @@ async def oom_save(task: object) -> None: eng._running = False eng._processing_task = None + async def test_recursion_error_propagates_through_process_one( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + """RecursionError in dispatch propagates through _process_one.""" + + async def recursive_save(task: object) -> None: + raise RecursionError + + persistence.tasks.save = recursive_save # type: ignore[method-assign] + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + mutation = CreateTaskMutation( + request_id="req-recurse", + requested_by="alice", + task_data=make_create_data(), + ) + envelope = _MutationEnvelope(mutation=mutation) + eng._queue.put_nowait(envelope) + + assert eng._processing_task is not None + with pytest.raises(RecursionError): + await eng._processing_task + finally: + eng._running = False + eng._processing_task = None + # ── _fail_remaining_futures coverage ───────────────────────── @@ -489,7 +519,7 @@ async def test_multiple_queued_futures_all_failed( mutation = CreateTaskMutation( request_id=f"req-{i}", 
requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) envelope = _MutationEnvelope(mutation=mutation) eng._queue.put_nowait(envelope) @@ -504,340 +534,3 @@ async def test_multiple_queued_futures_all_failed( assert result.success is False assert result.error_code == "internal" assert "shut down" in (result.error or "").lower() - - -# ── Typed error dispatch for all error codes ───────────────── - - -@pytest.mark.unit -class TestRaiseTypedErrorAllCodes: - """_raise_typed_error maps all error_code values to typed exceptions.""" - - def test_not_found_code(self) -> None: - result = TaskMutationResult( - request_id="r", - success=False, - error="not found", - error_code="not_found", - ) - with pytest.raises(TaskNotFoundError, match="not found"): - TaskEngine._raise_typed_error(result) - - def test_version_conflict_code(self) -> None: - from ai_company.engine.errors import TaskVersionConflictError - - result = TaskMutationResult( - request_id="r", - success=False, - error="conflict", - error_code="version_conflict", - ) - with pytest.raises(TaskVersionConflictError, match="conflict"): - TaskEngine._raise_typed_error(result) - - def test_internal_code(self) -> None: - result = TaskMutationResult( - request_id="r", - success=False, - error="boom", - error_code="internal", - ) - with pytest.raises(TaskInternalError, match="boom"): - TaskEngine._raise_typed_error(result) - - def test_validation_code_falls_through(self) -> None: - result = TaskMutationResult( - request_id="r", - success=False, - error="bad data", - error_code="validation", - ) - with pytest.raises(TaskMutationError, match="bad data"): - TaskEngine._raise_typed_error(result) - - def test_none_code_falls_through(self) -> None: - result = TaskMutationResult( - request_id="r", - success=False, - error="generic", - ) - with pytest.raises(TaskMutationError, match="generic"): - TaskEngine._raise_typed_error(result) - - def test_missing_error_uses_default_message(self) -> None: - 
result = TaskMutationResult( - request_id="r", - success=False, - error="Mutation failed", - ) - with pytest.raises(TaskMutationError, match="Mutation failed"): - TaskEngine._raise_typed_error(result) - - -# ── Transition with overrides via engine ───────────────────── - - -@pytest.mark.unit -class TestTransitionOverridesViaEngine: - """Transition overrides flow through the engine correctly.""" - - async def test_transition_with_assigned_to_override( - self, - engine: TaskEngine, - ) -> None: - """assigned_to passed as kwarg becomes an override on the mutation.""" - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - transitioned, prev = await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by="alice", - reason="Assigning", - assigned_to="bob", - ) - assert transitioned.assigned_to == "bob" - assert prev == TaskStatus.CREATED - - async def test_transition_returns_previous_status( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - # CREATED -> ASSIGNED - assigned, prev1 = await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by="alice", - reason="Assigning", - assigned_to="bob", - ) - assert prev1 == TaskStatus.CREATED - - # ASSIGNED -> IN_PROGRESS - in_progress, prev2 = await engine.transition_task( - assigned.id, - TaskStatus.IN_PROGRESS, - requested_by="bob", - reason="Starting work", - ) - assert prev2 == TaskStatus.ASSIGNED - assert in_progress.status == TaskStatus.IN_PROGRESS - - -# ── PydanticValidationError wrapping in convenience methods ── - - -@pytest.mark.unit -class TestConvenienceMethodValidationWrapping: - """Convenience methods wrap PydanticValidationError as TaskMutationError.""" - - async def test_update_task_wraps_pydantic_validation( - self, - engine: TaskEngine, - ) -> None: - """UpdateTaskMutation with immutable field raises TaskMutationError.""" - task = await engine.create_task( - 
_make_create_data(), - requested_by="alice", - ) - # 'id' is an immutable field rejected by model_validator - with pytest.raises(TaskMutationError): - await engine.update_task( - task.id, - {"id": "hacked"}, - requested_by="alice", - ) - - async def test_transition_task_wraps_pydantic_validation( - self, - engine: TaskEngine, - ) -> None: - """TransitionTaskMutation with blank reason raises TaskMutationError.""" - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - # Blank requested_by should trigger NotBlankStr validation - with pytest.raises(TaskMutationError): - await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by=" ", - reason="Assigning", - assigned_to="bob", - ) - - -# ── Version conflict via convenience methods ───────────────── - - -@pytest.mark.unit -class TestVersionConflictViaConvenienceMethods: - """Convenience methods raise TaskVersionConflictError on version mismatch.""" - - async def test_update_task_version_conflict( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - with pytest.raises(TaskVersionConflictError): - await engine.update_task( - task.id, - {"title": "New title"}, - requested_by="alice", - expected_version=99, - ) - - async def test_transition_task_version_conflict( - self, - engine: TaskEngine, - ) -> None: - task = await engine.create_task( - _make_create_data(), - requested_by="alice", - ) - with pytest.raises(TaskVersionConflictError): - await engine.transition_task( - task.id, - TaskStatus.ASSIGNED, - requested_by="alice", - reason="Assigning", - expected_version=99, - assigned_to="bob", - ) - - -# ── Cancel task not found ──────────────────────────────────── - - -@pytest.mark.unit -class TestCancelTaskNotFound: - """cancel_task raises TaskNotFoundError for missing tasks.""" - - async def test_cancel_nonexistent_raises_not_found( - self, - engine: TaskEngine, - ) -> None: - with 
pytest.raises(TaskNotFoundError): - await engine.cancel_task( - "task-nonexistent", - requested_by="alice", - reason="Cleanup", - ) - - -# ── Delete task not found ──────────────────────────────────── - - -@pytest.mark.unit -class TestDeleteTaskNotFound: - """delete_task raises TaskNotFoundError for missing tasks.""" - - async def test_delete_nonexistent_raises_not_found( - self, - engine: TaskEngine, - ) -> None: - with pytest.raises(TaskNotFoundError): - await engine.delete_task( - "task-nonexistent", - requested_by="alice", - ) - - -# ── Start when already running ──────────────────────────────── - - -@pytest.mark.unit -class TestStartAlreadyRunning: - """Starting an already-running engine raises RuntimeError.""" - - async def test_double_start_raises( - self, - engine: TaskEngine, - ) -> None: - # engine fixture already called start() - with pytest.raises(RuntimeError, match="already running"): - engine.start() - - -# ── Stop idempotency ───────────────────────────────────────── - - -@pytest.mark.unit -class TestStopIdempotency: - """Stopping an already-stopped engine is a no-op.""" - - async def test_stop_when_not_running( - self, - persistence: FakePersistence, - ) -> None: - eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - # Never started — stop should be safe - await eng.stop(timeout=1.0) - - async def test_double_stop( - self, - engine: TaskEngine, - ) -> None: - await engine.stop(timeout=2.0) - # Second stop is a no-op - await engine.stop(timeout=1.0) - - -# ── Submit when not running ─────────────────────────────────── - - -@pytest.mark.unit -class TestSubmitWhenNotRunning: - """Submitting to a stopped engine raises TaskEngineNotRunningError.""" - - async def test_submit_after_stop( - self, - engine: TaskEngine, - ) -> None: - from ai_company.engine.errors import TaskEngineNotRunningError - - await engine.stop(timeout=2.0) - mutation = CreateTaskMutation( - request_id="req-late", - requested_by="alice", - 
task_data=_make_create_data(), - ) - with pytest.raises(TaskEngineNotRunningError): - await engine.submit(mutation) - - -# ── is_running property ────────────────────────────────────── - - -@pytest.mark.unit -class TestIsRunningProperty: - """is_running reflects engine lifecycle.""" - - async def test_running_after_start( - self, - engine: TaskEngine, - ) -> None: - assert engine.is_running is True - - async def test_not_running_after_stop( - self, - engine: TaskEngine, - ) -> None: - await engine.stop(timeout=2.0) - assert engine.is_running is False - - async def test_not_running_before_start( - self, - persistence: FakePersistence, - ) -> None: - eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - assert eng.is_running is False diff --git a/tests/unit/engine/test_task_engine_integration.py b/tests/unit/engine/test_task_engine_integration.py index df1d887173..c9c8e1494d 100644 --- a/tests/unit/engine/test_task_engine_integration.py +++ b/tests/unit/engine/test_task_engine_integration.py @@ -19,7 +19,7 @@ FailingMessageBus, FakeMessageBus, FakePersistence, - _make_create_data, + make_create_data, ) # ── Snapshot publishing ─────────────────────────────────────── @@ -35,7 +35,7 @@ async def test_snapshot_published_on_create( message_bus: FakeMessageBus, ) -> None: await engine_with_bus.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) # Yield to event loop so the processing loop completes snapshot publication @@ -56,7 +56,7 @@ async def test_snapshot_publish_failure_does_not_affect_mutation( eng.start() try: task = await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) assert task.id.startswith("task-") @@ -80,7 +80,7 @@ async def test_no_snapshot_when_disabled( eng.start() try: await eng.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) await asyncio.sleep(0) @@ -99,7 +99,7 @@ async def test_pending_mutations_drained_on_stop( # Submit concurrently 
without awaiting — items enter the queue create_tasks = [ asyncio.create_task( - eng.create_task(_make_create_data(), requested_by="alice") + eng.create_task(make_create_data(), requested_by="alice") ) for _ in range(5) ] @@ -132,7 +132,7 @@ async def test_concurrent_submits( tasks = await asyncio.gather( *( engine.create_task( - _make_create_data(title=f"Task {i}"), + make_create_data(title=f"Task {i}"), requested_by="alice", ) for i in range(10) @@ -169,7 +169,7 @@ async def test_queue_full_raises( mutation1 = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) eng._queue.put_nowait(_MutationEnvelope(mutation=mutation1)) @@ -177,7 +177,7 @@ async def test_queue_full_raises( mutation2 = CreateTaskMutation( request_id="req-2", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) with pytest.raises(TaskEngineQueueFullError, match="queue is full"): await eng.submit(mutation2) @@ -199,7 +199,7 @@ async def test_version_increments( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) r1 = await engine.submit(mutation) assert r1.version == 1 @@ -218,7 +218,7 @@ async def test_version_conflict( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) # version is 1 after create; expected_version=99 should fail @@ -238,7 +238,7 @@ async def test_version_reset_on_delete( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) delete = DeleteTaskMutation( @@ -254,7 +254,7 @@ async def test_transition_version_conflict( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) mutation = TransitionTaskMutation( @@ -300,14 +300,14 @@ async def 
slow_save(task: object) -> None: # Submit a task — it'll block in slow_save, holding the processing loop blocked_task = asyncio.create_task( - eng.create_task(_make_create_data(), requested_by="alice") + eng.create_task(make_create_data(), requested_by="alice") ) # Queue a second task directly so it's definitely waiting mutation2 = CreateTaskMutation( request_id="req-queued", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) envelope = _MutationEnvelope(mutation=mutation2) # Wait until slow_save is entered before queuing the second task diff --git a/tests/unit/engine/test_task_engine_lifecycle.py b/tests/unit/engine/test_task_engine_lifecycle.py index 818b2fa2d7..e5ee152b45 100644 --- a/tests/unit/engine/test_task_engine_lifecycle.py +++ b/tests/unit/engine/test_task_engine_lifecycle.py @@ -6,7 +6,7 @@ from ai_company.engine.task_engine import TaskEngine from ai_company.engine.task_engine_config import TaskEngineConfig from ai_company.engine.task_engine_models import CreateTaskMutation -from tests.unit.engine.task_engine_helpers import FakePersistence, _make_create_data +from tests.unit.engine.task_engine_helpers import FakePersistence, make_create_data # ── Lifecycle tests ─────────────────────────────────────────── @@ -23,7 +23,7 @@ async def test_start_sets_running( assert eng.is_running is False eng.start() assert eng.is_running is True - await eng.stop(timeout=2.0) # type: ignore[unreachable] + await eng.stop(timeout=2.0) assert eng.is_running is False async def test_double_start_raises( @@ -72,7 +72,7 @@ async def test_submit_raises( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) with pytest.raises(TaskEngineNotRunningError): await eng.submit(mutation) @@ -107,3 +107,34 @@ def test_frozen(self) -> None: cfg = TaskEngineConfig() with pytest.raises(ValidationError): cfg.max_queue_size = 999 # type: ignore[misc] + + 
@pytest.mark.parametrize( + ("field", "value"), + [ + ("max_queue_size", -1), + ("drain_timeout_seconds", 0), + ("drain_timeout_seconds", -1.0), + ("drain_timeout_seconds", 301), + ], + ids=[ + "negative_queue_size", + "zero_drain_timeout", + "negative_drain_timeout", + "drain_timeout_above_max", + ], + ) + def test_rejects_out_of_range(self, field: str, value: object) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError): + TaskEngineConfig(**{field: value}) + + def test_zero_queue_size_allowed(self) -> None: + """Zero means unbounded — should be accepted.""" + cfg = TaskEngineConfig(max_queue_size=0) + assert cfg.max_queue_size == 0 + + def test_drain_timeout_upper_boundary(self) -> None: + """Exactly 300 should be accepted.""" + cfg = TaskEngineConfig(drain_timeout_seconds=300.0) + assert cfg.drain_timeout_seconds == 300.0 diff --git a/tests/unit/engine/test_task_engine_mutations.py b/tests/unit/engine/test_task_engine_mutations.py index 4cbb3474fe..ec92f362cc 100644 --- a/tests/unit/engine/test_task_engine_mutations.py +++ b/tests/unit/engine/test_task_engine_mutations.py @@ -21,7 +21,7 @@ from tests.unit.engine.task_engine_helpers import ( FakePersistence, FakeTaskRepository, - _make_create_data, + make_create_data, ) if TYPE_CHECKING: @@ -40,7 +40,7 @@ async def test_create_task( persistence: FakePersistence, ) -> None: task = await engine.create_task( - _make_create_data(title="My Task"), + make_create_data(title="My Task"), requested_by="alice", ) assert task.title == "My Task" @@ -58,7 +58,7 @@ async def test_create_returns_version_1( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) result = await engine.submit(mutation) assert result.success is True @@ -69,7 +69,7 @@ async def test_create_with_assignee( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(assigned_to=None), + 
make_create_data(assigned_to=None), requested_by="alice", ) assert task.assigned_to is None @@ -87,7 +87,7 @@ async def test_update_fields( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(title="Original"), + make_create_data(title="Original"), requested_by="alice", ) updated = await engine.update_task( @@ -103,7 +103,7 @@ async def test_update_empty_no_op( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) result = await engine.update_task( @@ -137,7 +137,7 @@ async def test_valid_transition( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) assigned, _ = await engine.transition_task( @@ -155,7 +155,7 @@ async def test_invalid_transition( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) with pytest.raises(TaskMutationError): @@ -193,7 +193,7 @@ async def test_delete_task( persistence: FakePersistence, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) deleted = await engine.delete_task(task.id, requested_by="alice") @@ -225,7 +225,7 @@ async def test_cancel_assigned_task( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) assigned, _ = await engine.transition_task( @@ -248,7 +248,7 @@ async def test_cancel_from_created_fails( ) -> None: """CREATED -> CANCELLED is not a valid transition.""" task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) with pytest.raises(TaskMutationError): @@ -271,7 +271,7 @@ async def test_get_task( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(title="Findme"), + make_create_data(title="Findme"), requested_by="alice", ) found = await 
engine.get_task(task.id) @@ -290,11 +290,11 @@ async def test_list_tasks( engine: TaskEngine, ) -> None: await engine.create_task( - _make_create_data(project="proj-a"), + make_create_data(project="proj-a"), requested_by="alice", ) await engine.create_task( - _make_create_data(project="proj-b"), + make_create_data(project="proj-b"), requested_by="alice", ) all_tasks = await engine.list_tasks() @@ -308,7 +308,7 @@ async def test_list_tasks_by_status( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) await engine.transition_task( @@ -358,7 +358,7 @@ async def test_create_has_no_previous_status( mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) result = await engine.submit(mutation) assert result.success is True @@ -369,7 +369,7 @@ async def test_transition_has_previous_status( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) mutation = TransitionTaskMutation( @@ -389,7 +389,7 @@ async def test_cancel_has_previous_status( engine: TaskEngine, ) -> None: task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) # First move to ASSIGNED so cancel is valid @@ -500,7 +500,7 @@ async def test_update_version_conflict_raises_typed( ) -> None: """Version conflict via convenience method raises TaskVersionConflictError.""" task = await engine.create_task( - _make_create_data(), + make_create_data(), requested_by="alice", ) with pytest.raises(TaskVersionConflictError, match="conflict"): @@ -541,7 +541,7 @@ async def save(self, task: Task) -> None: mutation = CreateTaskMutation( request_id="req-1", requested_by="alice", - task_data=_make_create_data(), + task_data=make_create_data(), ) result = await eng.submit(mutation) assert result.success is False From 
327681ea645148088dade1f6d188f2eb0fba9e69 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 21:29:09 +0100 Subject: [PATCH 11/14] fix: address 14 PR review items and add 15 tests for TaskEngine coverage Local agents (6 agents) + external reviewers (Copilot, Gemini, Greptile, CodeRabbit) identified 14 findings. All implemented: - Add error handling to list_tasks/get_task read endpoints (CRITICAL) - Add error handling to get_task read endpoint (CRITICAL) - Add dedicated event constants for reads, list cap, and futures failure - Merge duplicate except blocks in transition_task - Replace generic API_TASK_TRANSITION_FAILED with API_TASK_MUTATION_FAILED - Add created_by mismatch audit event - Narrow mutation_type to specific Literal per mutation class - Add error_code consistency checks to TaskMutationResult validator - Expand _map_task_engine_errors docstring with mapping table - Replace TASK_ENGINE_DRAIN_TIMEOUT with TASK_ENGINE_FUTURES_FAILED - Add read-through error wrapping tests (4 tests) - Add list_tasks safety cap test - Remove misleading computed_field reference in docstring - Add CreateTaskData max-length boundary tests (4 tests) CodeRabbit follow-up (3 additional items): - Assert error_code on version conflict integration tests - Assert sanitized 503 messages don't leak internal errors - PEP 758 except syntax: not applicable with `as` clause (ruff limitation) --- src/ai_company/api/controllers/tasks.py | 58 +++++---- src/ai_company/engine/task_engine.py | 17 +-- src/ai_company/engine/task_engine_models.py | 19 +-- src/ai_company/observability/events/api.py | 2 + .../observability/events/task_engine.py | 3 + .../unit/api/controllers/test_task_helpers.py | 3 + .../engine/test_task_engine_convenience.py | 2 + .../engine/test_task_engine_integration.py | 2 + tests/unit/engine/test_task_engine_models.py | 45 +++++++ .../unit/engine/test_task_engine_mutations.py | 111 ++++++++++++++++++ 10 files changed, 226 
insertions(+), 36 deletions(-) diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 8296d65dd4..54559a8a4d 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -35,8 +35,9 @@ from ai_company.observability.events.api import ( API_AUTH_FALLBACK, API_RESOURCE_NOT_FOUND, + API_TASK_CREATED_BY_MISMATCH, API_TASK_DELETED, - API_TASK_TRANSITION_FAILED, + API_TASK_MUTATION_FAILED, API_TASK_UPDATED, ) from ai_company.observability.events.task import ( @@ -72,6 +73,22 @@ def _map_task_engine_errors( """Map a task-engine exception to the appropriate API error. Returns the API exception to raise (caller must ``raise`` it). + + Mapping: + TaskNotFoundError -> 404 NotFoundError + TaskEngineNotRunningError -> 503 ServiceUnavailableError + TaskEngineQueueFullError -> 503 ServiceUnavailableError + TaskInternalError -> 503 ServiceUnavailableError + TaskVersionConflictError -> 409 ConflictError + TaskMutationError -> 422 ApiValidationError + Other -> 503 ServiceUnavailableError + + Args: + exc: The engine exception to map. + task_id: Optional task identifier for log context. + + Returns: + The API exception to raise. 
""" if isinstance(exc, TaskNotFoundError): if task_id is not None: @@ -83,7 +100,7 @@ def _map_task_engine_errors( return NotFoundError(str(exc)) if isinstance(exc, TaskEngineNotRunningError | TaskEngineQueueFullError): logger.error( - API_TASK_TRANSITION_FAILED, + API_TASK_MUTATION_FAILED, resource="task", task_id=task_id, error=str(exc), @@ -92,7 +109,7 @@ def _map_task_engine_errors( return ServiceUnavailableError("Service temporarily unavailable") if isinstance(exc, TaskInternalError): logger.error( - API_TASK_TRANSITION_FAILED, + API_TASK_MUTATION_FAILED, resource="task", task_id=task_id, error=str(exc), @@ -101,7 +118,7 @@ def _map_task_engine_errors( return ServiceUnavailableError("Internal server error") if isinstance(exc, TaskVersionConflictError): logger.warning( - API_TASK_TRANSITION_FAILED, + API_TASK_MUTATION_FAILED, resource="task", task_id=task_id, error=str(exc), @@ -110,7 +127,7 @@ def _map_task_engine_errors( return ConflictError(str(exc)) if isinstance(exc, TaskMutationError): logger.warning( - API_TASK_TRANSITION_FAILED, + API_TASK_MUTATION_FAILED, resource="task", task_id=task_id, error=str(exc), @@ -119,7 +136,7 @@ def _map_task_engine_errors( return ApiValidationError(str(exc)) # Unknown error type — log and wrap to prevent leaking internals logger.error( - API_TASK_TRANSITION_FAILED, + API_TASK_MUTATION_FAILED, resource="task", task_id=task_id, error=str(exc), @@ -159,11 +176,14 @@ async def list_tasks( # noqa: PLR0913 Paginated task list. 
""" app_state: AppState = state.app_state - tasks = await app_state.task_engine.list_tasks( - status=status, - assigned_to=assigned_to, - project=project, - ) + try: + tasks = await app_state.task_engine.list_tasks( + status=status, + assigned_to=assigned_to, + project=project, + ) + except (TaskInternalError, TaskEngineNotRunningError) as exc: + raise _map_task_engine_errors(exc) from exc page, meta = paginate(tasks, offset=offset, limit=limit) return PaginatedResponse(data=page, pagination=meta) @@ -186,7 +206,10 @@ async def get_task( NotFoundError: If the task is not found. """ app_state: AppState = state.app_state - task = await app_state.task_engine.get_task(task_id) + try: + task = await app_state.task_engine.get_task(task_id) + except (TaskInternalError, TaskEngineNotRunningError) as exc: + raise _map_task_engine_errors(exc, task_id=task_id) from exc if task is None: msg = f"Task {task_id!r} not found" logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) @@ -223,7 +246,7 @@ async def create_task( ) if data.created_by != requester: logger.info( - API_TASK_UPDATED, + API_TASK_CREATED_BY_MISMATCH, note="created_by differs from authenticated requester", created_by=data.created_by, requester=requester, @@ -333,15 +356,10 @@ async def transition_task( TaskEngineQueueFullError, TaskNotFoundError, TaskInternalError, + TaskVersionConflictError, + TaskMutationError, ) as exc: raise _map_task_engine_errors(exc, task_id=task_id) from exc - except TaskMutationError as exc: - logger.warning( - API_TASK_TRANSITION_FAILED, - task_id=task_id, - error=str(exc), - ) - raise _map_task_engine_errors(exc, task_id=task_id) from exc logger.info( TASK_STATUS_CHANGED, task_id=task_id, diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index 400f4dc845..5854efefc5 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -47,11 +47,14 @@ TASK_ENGINE_DRAIN_COMPLETE, TASK_ENGINE_DRAIN_START, 
TASK_ENGINE_DRAIN_TIMEOUT, + TASK_ENGINE_FUTURES_FAILED, + TASK_ENGINE_LIST_CAPPED, TASK_ENGINE_LOOP_ERROR, TASK_ENGINE_MUTATION_FAILED, TASK_ENGINE_MUTATION_RECEIVED, TASK_ENGINE_NOT_RUNNING, TASK_ENGINE_QUEUE_FULL, + TASK_ENGINE_READ_FAILED, TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED, TASK_ENGINE_SNAPSHOT_PUBLISHED, TASK_ENGINE_STARTED, @@ -209,7 +212,7 @@ def _fail_remaining_futures( failed_count += 1 if failed_count: logger.warning( - TASK_ENGINE_DRAIN_TIMEOUT, + TASK_ENGINE_FUTURES_FAILED, failed_futures=failed_count, note="Resolved remaining futures with shutdown failure", ) @@ -507,7 +510,7 @@ async def get_task(self, task_id: str) -> Task | None: except Exception as exc: msg = f"Failed to read task: {exc}" logger.exception( - TASK_ENGINE_MUTATION_FAILED, + TASK_ENGINE_READ_FAILED, error=msg, task_id=task_id, ) @@ -544,17 +547,15 @@ async def list_tasks( except Exception as exc: msg = f"Failed to list tasks: {exc}" logger.exception( - TASK_ENGINE_MUTATION_FAILED, + TASK_ENGINE_READ_FAILED, error=msg, ) raise TaskInternalError(msg) from exc if len(tasks) > self._MAX_LIST_RESULTS: logger.warning( - TASK_ENGINE_MUTATION_FAILED, - error=( - f"list_tasks returned {len(tasks)} results, " - f"capping at {self._MAX_LIST_RESULTS}" - ), + TASK_ENGINE_LIST_CAPPED, + returned=len(tasks), + cap=self._MAX_LIST_RESULTS, ) return tasks[: self._MAX_LIST_RESULTS] return tasks diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index 97de2c690f..3e2d8cd7f1 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -30,9 +30,6 @@ _VALID_TASK_FIELDS: frozenset[str] = frozenset(Task.model_fields) """Field names accepted by ``model_fields`` on :class:`Task`. -Pydantic's ``model_fields`` excludes any ``@computed_field`` -properties by design. - Used to reject unknown keys in :class:`UpdateTaskMutation` and :class:`TransitionTaskMutation` validators. 
""" @@ -107,7 +104,7 @@ class CreateTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: MutationType = "create" + mutation_type: Literal["create"] = "create" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_data: CreateTaskData = Field(description="Task creation payload") @@ -144,7 +141,7 @@ class UpdateTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: MutationType = "update" + mutation_type: Literal["update"] = "update" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target task identifier") @@ -196,7 +193,7 @@ class TransitionTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: MutationType = "transition" + mutation_type: Literal["transition"] = "transition" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target task identifier") @@ -242,7 +239,7 @@ class DeleteTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: MutationType = "delete" + mutation_type: Literal["delete"] = "delete" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target task identifier") @@ -261,7 +258,7 @@ class CancelTaskMutation(BaseModel): model_config = ConfigDict(frozen=True) - mutation_type: MutationType = "cancel" + mutation_type: Literal["cancel"] = "cancel" request_id: NotBlankStr = Field(description="Unique request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") task_id: NotBlankStr = Field(description="Target 
task identifier") @@ -320,6 +317,12 @@ def _check_consistency(self) -> Self: if not self.success and self.error is None: msg = "Failed result must carry an error description" raise ValueError(msg) + if self.success and self.error_code is not None: + msg = "Successful result must not carry an error_code" + raise ValueError(msg) + if not self.success and self.error_code is None: + msg = "Failed result must carry an error_code" + raise ValueError(msg) return self diff --git a/src/ai_company/observability/events/api.py b/src/ai_company/observability/events/api.py index 688bc83a4d..7a585c0c17 100644 --- a/src/ai_company/observability/events/api.py +++ b/src/ai_company/observability/events/api.py @@ -36,4 +36,6 @@ API_AUTH_SETUP_COMPLETE: Final[str] = "api.auth.setup_complete" API_AUTH_PASSWORD_CHANGED: Final[str] = "api.auth.password_changed" # noqa: S105 API_TASK_TRANSITION_FAILED: Final[str] = "api.task.transition_failed" +API_TASK_MUTATION_FAILED: Final[str] = "api.task.mutation_failed" +API_TASK_CREATED_BY_MISMATCH: Final[str] = "api.task.created_by_mismatch" API_AUTH_FALLBACK: Final[str] = "api.auth.fallback" diff --git a/src/ai_company/observability/events/task_engine.py b/src/ai_company/observability/events/task_engine.py index a97427a64e..a7f3521d2a 100644 --- a/src/ai_company/observability/events/task_engine.py +++ b/src/ai_company/observability/events/task_engine.py @@ -17,3 +17,6 @@ TASK_ENGINE_NOT_RUNNING: Final[str] = "task_engine.not_running" TASK_ENGINE_VERSION_CONFLICT: Final[str] = "task_engine.version.conflict" TASK_ENGINE_LOOP_ERROR: Final[str] = "task_engine.loop.error" +TASK_ENGINE_READ_FAILED: Final[str] = "task_engine.read.failed" +TASK_ENGINE_LIST_CAPPED: Final[str] = "task_engine.list.capped" +TASK_ENGINE_FUTURES_FAILED: Final[str] = "task_engine.futures.failed" diff --git a/tests/unit/api/controllers/test_task_helpers.py b/tests/unit/api/controllers/test_task_helpers.py index f82c3e7841..fe42a6944b 100644 --- 
a/tests/unit/api/controllers/test_task_helpers.py +++ b/tests/unit/api/controllers/test_task_helpers.py @@ -73,16 +73,19 @@ def test_not_running_maps_to_service_unavailable(self) -> None: exc = TaskEngineNotRunningError("not running") result = _map_task_engine_errors(exc) assert isinstance(result, ServiceUnavailableError) + assert "not running" not in str(result) def test_queue_full_maps_to_service_unavailable(self) -> None: exc = TaskEngineQueueFullError("queue full") result = _map_task_engine_errors(exc) assert isinstance(result, ServiceUnavailableError) + assert "queue full" not in str(result) def test_internal_error_maps_to_service_unavailable(self) -> None: exc = TaskInternalError("internal fault") result = _map_task_engine_errors(exc) assert isinstance(result, ServiceUnavailableError) + assert "internal fault" not in str(result) def test_version_conflict_maps_to_conflict_error(self) -> None: exc = TaskVersionConflictError("version mismatch") diff --git a/tests/unit/engine/test_task_engine_convenience.py b/tests/unit/engine/test_task_engine_convenience.py index 76f71dc4bf..4ae2479965 100644 --- a/tests/unit/engine/test_task_engine_convenience.py +++ b/tests/unit/engine/test_task_engine_convenience.py @@ -74,6 +74,7 @@ def test_none_code_falls_through(self) -> None: request_id="r", success=False, error="generic", + error_code="validation", ) with pytest.raises(TaskMutationError, match="generic"): TaskEngine._raise_typed_error(result) @@ -83,6 +84,7 @@ def test_missing_error_uses_default_message(self) -> None: request_id="r", success=False, error="Mutation failed", + error_code="validation", ) with pytest.raises(TaskMutationError, match="Mutation failed"): TaskEngine._raise_typed_error(result) diff --git a/tests/unit/engine/test_task_engine_integration.py b/tests/unit/engine/test_task_engine_integration.py index c9c8e1494d..799cdb9d37 100644 --- a/tests/unit/engine/test_task_engine_integration.py +++ b/tests/unit/engine/test_task_engine_integration.py @@ -231,6 
+231,7 @@ async def test_version_conflict( ) result = await engine.submit(update) assert result.success is False + assert result.error_code == "version_conflict" assert "conflict" in (result.error or "").lower() async def test_version_reset_on_delete( @@ -268,6 +269,7 @@ async def test_transition_version_conflict( ) result = await engine.submit(mutation) assert result.success is False + assert result.error_code == "version_conflict" assert "conflict" in (result.error or "").lower() diff --git a/tests/unit/engine/test_task_engine_models.py b/tests/unit/engine/test_task_engine_models.py index 1de77df13b..7f58ef5758 100644 --- a/tests/unit/engine/test_task_engine_models.py +++ b/tests/unit/engine/test_task_engine_models.py @@ -70,6 +70,49 @@ def test_negative_budget_rejected(self) -> None: budget_limit=-1.0, ) + def test_title_at_max_length(self) -> None: + data = CreateTaskData( + title="x" * 256, + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + assert len(data.title) == 256 + + def test_title_exceeds_max_length(self) -> None: + with pytest.raises(ValidationError, match="String should have at most 256"): + CreateTaskData( + title="x" * 257, + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + + def test_description_at_max_length(self) -> None: + data = CreateTaskData( + title="Task", + description="d" * 4096, + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + assert len(data.description) == 4096 + + def test_description_exceeds_max_length(self) -> None: + with pytest.raises( + ValidationError, + match="String should have at most 4096", + ): + CreateTaskData( + title="Task", + description="d" * 4097, + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + def test_frozen(self) -> None: data = CreateTaskData( title="Task", @@ -242,9 +285,11 @@ def test_failure_result(self) -> None: request_id="req-1", success=False, error="Not found", + 
error_code="not_found", ) assert result.success is False assert result.error == "Not found" + assert result.error_code == "not_found" assert result.version == 0 def test_frozen(self) -> None: diff --git a/tests/unit/engine/test_task_engine_mutations.py b/tests/unit/engine/test_task_engine_mutations.py index ec92f362cc..4e68d66974 100644 --- a/tests/unit/engine/test_task_engine_mutations.py +++ b/tests/unit/engine/test_task_engine_mutations.py @@ -325,6 +325,117 @@ async def test_list_tasks_by_status( assert len(assigned) == 1 +# ── Read-through error wrapping ──────────────────────────────── + + +@pytest.mark.unit +class TestReadThroughErrorWrapping: + """Persistence errors in read-through methods raise TaskInternalError.""" + + async def test_get_task_wraps_persistence_error( + self, + persistence: FakePersistence, + ) -> None: + from ai_company.engine.errors import TaskInternalError + + async def exploding_get(task_id: str) -> None: + msg = "disk I/O" + raise OSError(msg) + + persistence.tasks.get = exploding_get # type: ignore[method-assign] + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + try: + with pytest.raises(TaskInternalError, match="Failed to read task"): + await eng.get_task("task-1") + finally: + await eng.stop(timeout=2.0) + + async def test_list_tasks_wraps_persistence_error( + self, + persistence: FakePersistence, + ) -> None: + from ai_company.engine.errors import TaskInternalError + + async def exploding_list(**kwargs: object) -> None: + msg = "connection refused" + raise ConnectionError(msg) + + persistence.tasks.list_tasks = exploding_list # type: ignore[assignment,method-assign] + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + try: + with pytest.raises(TaskInternalError, match="Failed to list tasks"): + await eng.list_tasks() + finally: + await eng.stop(timeout=2.0) + + async def test_get_task_lets_memory_error_propagate( + self, + persistence: FakePersistence, + ) -> None: 
+ async def oom_get(task_id: str) -> None: + raise MemoryError + + persistence.tasks.get = oom_get # type: ignore[method-assign] + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + try: + with pytest.raises(MemoryError): + await eng.get_task("task-1") + finally: + await eng.stop(timeout=2.0) + + async def test_list_tasks_lets_memory_error_propagate( + self, + persistence: FakePersistence, + ) -> None: + async def oom_list(**kwargs: object) -> None: + raise MemoryError + + persistence.tasks.list_tasks = oom_list # type: ignore[assignment,method-assign] + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + try: + with pytest.raises(MemoryError): + await eng.list_tasks() + finally: + await eng.stop(timeout=2.0) + + +# ── list_tasks safety cap ───────────────────────────────────── + + +@pytest.mark.unit +class TestListTasksSafetyCap: + """list_tasks caps results at _MAX_LIST_RESULTS.""" + + async def test_results_capped_at_max( + self, + persistence: FakePersistence, + ) -> None: + """When persistence returns more than cap, result is truncated.""" + + original_list = persistence.tasks.list_tasks + + async def oversized_list(**kwargs: object) -> tuple[Task, ...]: + # Return a few real tasks, then monkey-patch to simulate > cap + result = await original_list(**kwargs) # type: ignore[arg-type] + # Create a list longer than cap by repeating + return result * 20_000 if result else result + + persistence.tasks.list_tasks = oversized_list # type: ignore[assignment,method-assign] + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + try: + # Create one task so the oversized list has data to repeat + await eng.create_task(make_create_data(), requested_by="alice") + tasks = await eng.list_tasks() + assert len(tasks) <= 10_000 + finally: + await eng.stop(timeout=2.0) + + # ── Cancel not found ───────────────────────────────────────── From 4d4f759610dc583b98f9c96737950ffd6f19c36b Mon Sep 
17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 22:28:29 +0100 Subject: [PATCH 12/14] fix: address 14 PR review items and add 15 tests for TaskEngine coverage - Fix lifecycle race condition: add asyncio.Lock coordinating submit/stop - Add asyncio.shield to protect processing task from wait_for cancellation - Wrap MappingProxyType on UpdateTaskMutation.updates and TransitionTaskMutation.overrides for true immutability - Add task_id field to TaskStateChanged event model - Wrap PydanticValidationError in create_task, delete_task, cancel_task - Add cancel endpoint (POST /tasks/{id}/cancel) with CancelTaskRequest DTO - Forward expected_version from UpdateTaskRequest/TransitionTaskRequest DTOs - Remove dead TaskEngineNotRunningError catch from list/get read-throughs - Sanitize 503 error messages to prevent leaking internals - Log created_by mismatch at WARNING instead of INFO - Add INFO log for successful mutations (TASK_ENGINE_MUTATION_APPLIED) - Deep-copy tasks in FakeTaskRepository for test isolation - Replace structlog.configure with capture_logs for test stability - Replace timing-dependent sleep(0.01) with sleep(0) in race test - Add 15 new tests covering defensive guards, processing loop recovery, snapshot publishing, cancel lifecycle, and validation wrapping --- src/ai_company/api/controllers/tasks.py | 57 +- src/ai_company/api/dto.py | 29 +- src/ai_company/engine/agent_engine.py | 3 + src/ai_company/engine/task_engine.py | 121 ++- src/ai_company/engine/task_engine_apply.py | 5 +- src/ai_company/engine/task_engine_models.py | 18 +- src/ai_company/observability/events/api.py | 1 + .../unit/api/controllers/test_task_helpers.py | 6 +- tests/unit/config/test_defaults.py | 4 +- tests/unit/engine/conftest.py | 2 + tests/unit/engine/task_engine_helpers.py | 12 +- .../engine/test_task_engine_convenience.py | 8 +- .../unit/engine/test_task_engine_coverage.py | 5 +- .../engine/test_task_engine_integration.py | 22 +- 
.../unit/engine/test_task_engine_lifecycle.py | 14 +- tests/unit/engine/test_task_engine_models.py | 7 + .../unit/engine/test_task_engine_mutations.py | 2 +- .../engine/test_task_engine_review_fixes.py | 738 ++++++++++++++++++ 18 files changed, 971 insertions(+), 83 deletions(-) create mode 100644 tests/unit/engine/test_task_engine_review_fixes.py diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 54559a8a4d..48717265ad 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -6,6 +6,7 @@ from ai_company.api.dto import ( ApiResponse, + CancelTaskRequest, CreateTaskRequest, PaginatedResponse, TransitionTaskRequest, @@ -35,6 +36,7 @@ from ai_company.observability.events.api import ( API_AUTH_FALLBACK, API_RESOURCE_NOT_FOUND, + API_TASK_CANCELLED, API_TASK_CREATED_BY_MISMATCH, API_TASK_DELETED, API_TASK_MUTATION_FAILED, @@ -182,7 +184,7 @@ async def list_tasks( # noqa: PLR0913 assigned_to=assigned_to, project=project, ) - except (TaskInternalError, TaskEngineNotRunningError) as exc: + except TaskInternalError as exc: raise _map_task_engine_errors(exc) from exc page, meta = paginate(tasks, offset=offset, limit=limit) return PaginatedResponse(data=page, pagination=meta) @@ -208,7 +210,7 @@ async def get_task( app_state: AppState = state.app_state try: task = await app_state.task_engine.get_task(task_id) - except (TaskInternalError, TaskEngineNotRunningError) as exc: + except TaskInternalError as exc: raise _map_task_engine_errors(exc, task_id=task_id) from exc if task is None: msg = f"Task {task_id!r} not found" @@ -245,7 +247,7 @@ async def create_task( budget_limit=data.budget_limit, ) if data.created_by != requester: - logger.info( + logger.warning( API_TASK_CREATED_BY_MISMATCH, note="created_by differs from authenticated requester", created_by=data.created_by, @@ -293,12 +295,16 @@ async def update_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state - updates = data.model_dump(exclude_none=True) + updates = data.model_dump( + exclude_none=True, + exclude={"expected_version"}, + ) try: task = await app_state.task_engine.update_task( task_id, updates, requested_by=_extract_requester(state), + expected_version=data.expected_version, ) except PydanticValidationError as exc: raise ApiValidationError(str(exc)) from exc @@ -347,7 +353,8 @@ async def transition_task( data.target_status, requested_by=requester, reason=f"API transition to {data.target_status.value}", - **overrides, # type: ignore[arg-type] + expected_version=data.expected_version, + **overrides, ) except PydanticValidationError as exc: raise ApiValidationError(str(exc)) from exc @@ -402,3 +409,43 @@ async def delete_task( raise _map_task_engine_errors(exc, task_id=task_id) from exc logger.info(API_TASK_DELETED, task_id=task_id) return ApiResponse(data=None) + + @post("/{task_id:str}/cancel", guards=[require_write_access]) + async def cancel_task( + self, + state: State, + task_id: str, + data: CancelTaskRequest, + ) -> ApiResponse[Task]: + """Cancel a task. + + Args: + state: Application state. + task_id: Task identifier. + data: Cancellation payload with reason. + + Returns: + Cancelled task envelope. + + Raises: + NotFoundError: If the task is not found. 
+ """ + app_state: AppState = state.app_state + try: + task = await app_state.task_engine.cancel_task( + task_id, + requested_by=_extract_requester(state), + reason=data.reason, + ) + except PydanticValidationError as exc: + raise ApiValidationError(str(exc)) from exc + except ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskNotFoundError, + TaskInternalError, + TaskMutationError, + ) as exc: + raise _map_task_engine_errors(exc, task_id=task_id) from exc + logger.info(API_TASK_CANCELLED, task_id=task_id) + return ApiResponse(data=task) diff --git a/src/ai_company/api/dto.py b/src/ai_company/api/dto.py index b57580067e..e0e27ecd9f 100644 --- a/src/ai_company/api/dto.py +++ b/src/ai_company/api/dto.py @@ -130,6 +130,7 @@ class UpdateTaskRequest(BaseModel): priority: New priority. assigned_to: New assignee. budget_limit: New budget limit. + expected_version: Optimistic concurrency guard. """ model_config = ConfigDict(frozen=True) @@ -139,6 +140,11 @@ class UpdateTaskRequest(BaseModel): priority: Priority | None = None assigned_to: NotBlankStr | None = None budget_limit: float | None = Field(default=None, ge=0.0) + expected_version: int | None = Field( + default=None, + ge=1, + description="Optimistic concurrency version guard", + ) class TransitionTaskRequest(BaseModel): @@ -147,12 +153,33 @@ class TransitionTaskRequest(BaseModel): Attributes: target_status: The desired target status. assigned_to: Optional assignee override for the transition. + expected_version: Optimistic concurrency guard. """ model_config = ConfigDict(frozen=True) - target_status: TaskStatus + target_status: TaskStatus = Field(description="Desired target status") assigned_to: NotBlankStr | None = None + expected_version: int | None = Field( + default=None, + ge=1, + description="Optimistic concurrency version guard", + ) + + +class CancelTaskRequest(BaseModel): + """Payload for cancelling a task. + + Attributes: + reason: Reason for cancellation. 
+ """ + + model_config = ConfigDict(frozen=True) + + reason: NotBlankStr = Field( + max_length=4096, + description="Reason for cancellation", + ) # ── Approval request DTOs ────────────────────────────────────── diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index d76e3bbe67..e470afd57a 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -694,6 +694,9 @@ async def _report_to_task_engine( return try: + # Best-effort: discard return value intentionally — if the + # transition is rejected (e.g. parallel mutation moved the task), + # the exception handlers below log the failure. _ = await self._task_engine.transition_task( task_id, final_status, diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index 5854efefc5..ca0a598d21 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -50,6 +50,7 @@ TASK_ENGINE_FUTURES_FAILED, TASK_ENGINE_LIST_CAPPED, TASK_ENGINE_LOOP_ERROR, + TASK_ENGINE_MUTATION_APPLIED, TASK_ENGINE_MUTATION_FAILED, TASK_ENGINE_MUTATION_RECEIVED, TASK_ENGINE_NOT_RUNNING, @@ -115,6 +116,7 @@ def __init__( self._processing_task: asyncio.Task[None] | None = None self._in_flight: _MutationEnvelope | None = None self._running = False + self._lifecycle_lock = asyncio.Lock() logger.debug( TASK_ENGINE_CREATED, max_queue_size=self._config.max_queue_size, @@ -146,13 +148,17 @@ def start(self) -> None: async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 """Stop the engine and drain pending mutations. + Acquires ``_lifecycle_lock`` to prevent a race with ``submit()`` + where an envelope is enqueued after the processing loop exits. + Args: timeout: Seconds to wait for drain. Defaults to ``config.drain_timeout_seconds``. 
""" - if not self._running: - return - self._running = False + async with self._lifecycle_lock: + if not self._running: + return + self._running = False effective_timeout = ( timeout if timeout is not None else self._config.drain_timeout_seconds ) @@ -165,7 +171,7 @@ async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 ) try: await asyncio.wait_for( - self._processing_task, + asyncio.shield(self._processing_task), timeout=effective_timeout, ) logger.info(TASK_ENGINE_DRAIN_COMPLETE) @@ -237,6 +243,10 @@ def is_running(self) -> bool: async def submit(self, mutation: TaskMutation) -> TaskMutationResult: """Submit a mutation and await its result. + Acquires ``_lifecycle_lock`` to prevent a race between + ``submit()`` and ``stop()`` where an envelope could be enqueued + after the processing loop has already drained and exited. + Args: mutation: The mutation to apply. @@ -247,27 +257,28 @@ async def submit(self, mutation: TaskMutation) -> TaskMutationResult: TaskEngineNotRunningError: If the engine is not running. TaskEngineQueueFullError: If the queue is at capacity. 
""" - if not self._running: - logger.warning( - TASK_ENGINE_NOT_RUNNING, - mutation_type=mutation.mutation_type, - request_id=mutation.request_id, - ) - msg = "TaskEngine is not running" - raise TaskEngineNotRunningError(msg) + async with self._lifecycle_lock: + if not self._running: + logger.warning( + TASK_ENGINE_NOT_RUNNING, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + ) + msg = "TaskEngine is not running" + raise TaskEngineNotRunningError(msg) - envelope = _MutationEnvelope(mutation=mutation) - try: - self._queue.put_nowait(envelope) - except asyncio.QueueFull: - logger.warning( - TASK_ENGINE_QUEUE_FULL, - mutation_type=mutation.mutation_type, - request_id=mutation.request_id, - queue_size=self._queue.qsize(), - ) - msg = "TaskEngine queue is full" - raise TaskEngineQueueFullError(msg) from None + envelope = _MutationEnvelope(mutation=mutation) + try: + self._queue.put_nowait(envelope) + except asyncio.QueueFull: + logger.warning( + TASK_ENGINE_QUEUE_FULL, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + queue_size=self._queue.qsize(), + ) + msg = "TaskEngine queue is full" + raise TaskEngineQueueFullError(msg) from None return await envelope.future @@ -291,11 +302,14 @@ async def create_task( TaskEngineQueueFullError: If the queue is at capacity. TaskMutationError: If the mutation fails. """ - mutation = CreateTaskMutation( - request_id=uuid4().hex, - requested_by=requested_by, - task_data=data, - ) + try: + mutation = CreateTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_data=data, + ) + except PydanticValidationError as exc: + raise TaskMutationError(str(exc)) from exc result = await self.submit(mutation) if not result.success: self._raise_typed_error(result) @@ -422,11 +436,14 @@ async def delete_task( TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. 
""" - mutation = DeleteTaskMutation( - request_id=uuid4().hex, - requested_by=requested_by, - task_id=task_id, - ) + try: + mutation = DeleteTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + ) + except PydanticValidationError as exc: + raise TaskMutationError(str(exc)) from exc result = await self.submit(mutation) if not result.success: self._raise_typed_error(result) @@ -455,12 +472,15 @@ async def cancel_task( TaskNotFoundError: If the task is not found. TaskMutationError: If the mutation fails. """ - mutation = CancelTaskMutation( - request_id=uuid4().hex, - requested_by=requested_by, - task_id=task_id, - reason=reason, - ) + try: + mutation = CancelTaskMutation( + request_id=uuid4().hex, + requested_by=requested_by, + task_id=task_id, + reason=reason, + ) + except PydanticValidationError as exc: + raise TaskMutationError(str(exc)) from exc result = await self.submit(mutation) if not result.success: self._raise_typed_error(result) @@ -627,6 +647,19 @@ async def _process_one(self, envelope: _MutationEnvelope) -> None: ) if not envelope.future.done(): envelope.future.set_result(result) + if result.success: + task_id = getattr(mutation, "task_id", None) + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + task_id=task_id or (result.task.id if result.task else None), + version=result.version, + previous_status=( + result.previous_status.value if result.previous_status else None + ), + new_status=(result.task.status.value if result.task else None), + ) if result.success and self._config.publish_snapshots: await self._publish_snapshot(mutation, result) except MemoryError, RecursionError: @@ -674,11 +707,17 @@ async def _publish_snapshot( new_status = None reason: str | None = getattr(mutation, "reason", None) + task_id: str | None = getattr(mutation, "task_id", None) + # For create mutations, task_id comes from the result + if task_id is None and result.task 
is not None: + task_id = result.task.id + effective_task_id = task_id or "unknown" event = TaskStateChanged( mutation_type=mutation.mutation_type, request_id=mutation.request_id, requested_by=mutation.requested_by, + task_id=effective_task_id, task=result.task, previous_status=result.previous_status, new_status=new_status, @@ -686,8 +725,6 @@ async def _publish_snapshot( reason=reason, timestamp=datetime.now(UTC), ) - - task_id = getattr(mutation, "task_id", None) try: # Deferred to break circular import: # communication -> engine -> communication diff --git a/src/ai_company/engine/task_engine_apply.py b/src/ai_company/engine/task_engine_apply.py index c67a0d3a64..cc3f6220fb 100644 --- a/src/ai_company/engine/task_engine_apply.py +++ b/src/ai_company/engine/task_engine_apply.py @@ -6,7 +6,6 @@ lifecycle, queue management, and the public API. """ -import copy from typing import TYPE_CHECKING from uuid import uuid4 @@ -217,7 +216,9 @@ async def apply_update( ) merged = task.model_dump() - merged.update(copy.deepcopy(mutation.updates)) + # mutation.updates is already deep-copied + wrapped in MappingProxyType + # at construction time, so no second deep-copy needed here. + merged.update(mutation.updates) try: updated = Task.model_validate(merged) except PydanticValidationError as exc: diff --git a/src/ai_company/engine/task_engine_models.py b/src/ai_company/engine/task_engine_models.py index 3e2d8cd7f1..2398af73cc 100644 --- a/src/ai_company/engine/task_engine_models.py +++ b/src/ai_company/engine/task_engine_models.py @@ -7,6 +7,7 @@ import copy from datetime import UTC, datetime +from types import MappingProxyType from typing import Final, Literal, Self from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator @@ -166,8 +167,12 @@ def _reject_immutable_fields(self) -> Self: def __init__(self, **data: object) -> None: super().__init__(**data) - # Deep-copy mutable dict at system boundary per coding guidelines. 
- object.__setattr__(self, "updates", copy.deepcopy(self.updates)) + # Deep-copy and wrap in MappingProxyType for full immutability. + object.__setattr__( + self, + "updates", + MappingProxyType(copy.deepcopy(dict(self.updates))), + ) _IMMUTABLE_OVERRIDE_FIELDS: frozenset[str] = _ALWAYS_IMMUTABLE_FIELDS | {"status"} @@ -223,8 +228,12 @@ def _reject_immutable_overrides(self) -> Self: def __init__(self, **data: object) -> None: super().__init__(**data) - # Deep-copy mutable dict at system boundary per coding guidelines. - object.__setattr__(self, "overrides", copy.deepcopy(self.overrides)) + # Deep-copy and wrap in MappingProxyType for full immutability. + object.__setattr__( + self, + "overrides", + MappingProxyType(copy.deepcopy(dict(self.overrides))), + ) class DeleteTaskMutation(BaseModel): @@ -351,6 +360,7 @@ class TaskStateChanged(BaseModel): ) request_id: NotBlankStr = Field(description="Originating request identifier") requested_by: NotBlankStr = Field(description="Identity of the requester") + task_id: NotBlankStr = Field(description="Task identifier (always present)") task: Task | None = Field( default=None, description="Task snapshot after mutation", diff --git a/src/ai_company/observability/events/api.py b/src/ai_company/observability/events/api.py index 7a585c0c17..ccabad9655 100644 --- a/src/ai_company/observability/events/api.py +++ b/src/ai_company/observability/events/api.py @@ -20,6 +20,7 @@ API_RESOURCE_NOT_FOUND: Final[str] = "api.resource.not_found" API_TASK_UPDATED: Final[str] = "api.task.updated" API_TASK_DELETED: Final[str] = "api.task.deleted" +API_TASK_CANCELLED: Final[str] = "api.task.cancelled" API_APPROVAL_CREATED: Final[str] = "api.approval.created" API_APPROVAL_APPROVED: Final[str] = "api.approval.approved" API_APPROVAL_REJECTED: Final[str] = "api.approval.rejected" diff --git a/tests/unit/api/controllers/test_task_helpers.py b/tests/unit/api/controllers/test_task_helpers.py index fe42a6944b..ba5ba9e815 100644 --- 
a/tests/unit/api/controllers/test_task_helpers.py +++ b/tests/unit/api/controllers/test_task_helpers.py @@ -73,19 +73,19 @@ def test_not_running_maps_to_service_unavailable(self) -> None: exc = TaskEngineNotRunningError("not running") result = _map_task_engine_errors(exc) assert isinstance(result, ServiceUnavailableError) - assert "not running" not in str(result) + assert str(result) == "Service temporarily unavailable" def test_queue_full_maps_to_service_unavailable(self) -> None: exc = TaskEngineQueueFullError("queue full") result = _map_task_engine_errors(exc) assert isinstance(result, ServiceUnavailableError) - assert "queue full" not in str(result) + assert str(result) == "Service temporarily unavailable" def test_internal_error_maps_to_service_unavailable(self) -> None: exc = TaskInternalError("internal fault") result = _map_task_engine_errors(exc) assert isinstance(result, ServiceUnavailableError) - assert "internal fault" not in str(result) + assert str(result) == "Internal server error" def test_version_conflict_maps_to_conflict_error(self) -> None: exc = TaskVersionConflictError("version mismatch") diff --git a/tests/unit/config/test_defaults.py b/tests/unit/config/test_defaults.py index 52663d0564..dd9f720fae 100644 --- a/tests/unit/config/test_defaults.py +++ b/tests/unit/config/test_defaults.py @@ -1,5 +1,7 @@ """Tests for config defaults.""" +from typing import Any + import pytest from ai_company.config.defaults import default_config_dict @@ -20,7 +22,7 @@ def test_required_keys_present(self) -> None: assert result["company_type"] == "custom" def test_constructs_valid_root_config(self) -> None: - data = default_config_dict() + data: dict[str, Any] = default_config_dict() # narrow for **unpacking cfg = RootConfig(**data) assert cfg.company_name == "SynthOrg" assert cfg.company_type.value == "custom" diff --git a/tests/unit/engine/conftest.py b/tests/unit/engine/conftest.py index b93566f8f2..5903cb55da 100644 --- a/tests/unit/engine/conftest.py +++ 
b/tests/unit/engine/conftest.py @@ -450,6 +450,7 @@ async def engine_with_bus( config: TaskEngineConfig, ) -> AsyncIterator[TaskEngine]: """Create and start a TaskEngine with a message bus.""" + await message_bus.start() eng = TaskEngine( persistence=persistence, # type: ignore[arg-type] message_bus=message_bus, # type: ignore[arg-type] @@ -458,3 +459,4 @@ async def engine_with_bus( eng.start() yield eng await eng.stop(timeout=2.0) + await message_bus.stop() diff --git a/tests/unit/engine/task_engine_helpers.py b/tests/unit/engine/task_engine_helpers.py index 330bc043f8..2a6767e165 100644 --- a/tests/unit/engine/task_engine_helpers.py +++ b/tests/unit/engine/task_engine_helpers.py @@ -1,5 +1,6 @@ """Shared fakes and helpers for TaskEngine tests.""" +import copy from typing import TYPE_CHECKING from ai_company.core.task import Task # noqa: TC001 @@ -13,16 +14,21 @@ class FakeTaskRepository: - """Minimal in-memory task repository for engine tests.""" + """Minimal in-memory task repository for engine tests. + + Deep-copies tasks on save/get to mirror real persistence + behaviour and prevent test isolation regressions. 
+ """ def __init__(self) -> None: self._tasks: dict[str, Task] = {} async def save(self, task: Task) -> None: - self._tasks[task.id] = task + self._tasks[task.id] = copy.deepcopy(task) async def get(self, task_id: str) -> Task | None: - return self._tasks.get(task_id) + task = self._tasks.get(task_id) + return copy.deepcopy(task) if task is not None else None async def list_tasks( self, diff --git a/tests/unit/engine/test_task_engine_convenience.py b/tests/unit/engine/test_task_engine_convenience.py index 4ae2479965..54ba265e20 100644 --- a/tests/unit/engine/test_task_engine_convenience.py +++ b/tests/unit/engine/test_task_engine_convenience.py @@ -69,7 +69,8 @@ def test_validation_code_falls_through(self) -> None: with pytest.raises(TaskMutationError, match="bad data"): TaskEngine._raise_typed_error(result) - def test_none_code_falls_through(self) -> None: + def test_validation_code_raises_mutation_error(self) -> None: + """Validation error_code falls through to generic TaskMutationError.""" result = TaskMutationResult( request_id="r", success=False, @@ -79,11 +80,12 @@ def test_none_code_falls_through(self) -> None: with pytest.raises(TaskMutationError, match="generic"): TaskEngine._raise_typed_error(result) - def test_missing_error_uses_default_message(self) -> None: + def test_default_error_message_when_error_is_empty(self) -> None: + """Empty error string triggers the 'Mutation failed' default.""" result = TaskMutationResult( request_id="r", success=False, - error="Mutation failed", + error="", error_code="validation", ) with pytest.raises(TaskMutationError, match="Mutation failed"): diff --git a/tests/unit/engine/test_task_engine_coverage.py b/tests/unit/engine/test_task_engine_coverage.py index 3d5a6e75d8..4078bdda7b 100644 --- a/tests/unit/engine/test_task_engine_coverage.py +++ b/tests/unit/engine/test_task_engine_coverage.py @@ -53,7 +53,8 @@ async def slow_save(task: object) -> None: await asyncio.sleep(0.05) # The processing loop should be in 
_process_one with _in_flight set - assert eng._in_flight is not None + in_flight_before = eng._in_flight + assert in_flight_before is not None # Stop with very short timeout — triggers _fail_remaining_futures await eng.stop(timeout=0.05) @@ -62,7 +63,7 @@ async def slow_save(task: object) -> None: assert eng._in_flight is None # Release the block and clean up - block.set() # type: ignore[unreachable] + block.set() blocked.cancel() with contextlib.suppress(Exception, asyncio.CancelledError): await blocked diff --git a/tests/unit/engine/test_task_engine_integration.py b/tests/unit/engine/test_task_engine_integration.py index 799cdb9d37..fa8ab399df 100644 --- a/tests/unit/engine/test_task_engine_integration.py +++ b/tests/unit/engine/test_task_engine_integration.py @@ -96,22 +96,22 @@ async def test_pending_mutations_drained_on_stop( eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] eng.start() - # Submit concurrently without awaiting — items enter the queue - create_tasks = [ - asyncio.create_task( - eng.create_task(make_create_data(), requested_by="alice") - ) - for _ in range(5) - ] + # Submit concurrently using structured concurrency + async with asyncio.TaskGroup() as tg: + results = [ + tg.create_task( + eng.create_task(make_create_data(), requested_by="alice"), + ) + for _ in range(5) + ] - # Yield to the event loop so the tasks can enqueue their mutations before stop() + # Yield to let processing complete before stopping await asyncio.sleep(0) - # Stop while tasks may still be in flight — drain timeout is generous + # Stop — drain remaining if any await eng.stop(timeout=5.0) - # All futures resolved (drained during stop or completed before stop) - results = await asyncio.gather(*create_tasks) + # All futures resolved assert len(results) == 5 stored = await persistence.tasks.list_tasks() assert len(stored) == 5 diff --git a/tests/unit/engine/test_task_engine_lifecycle.py b/tests/unit/engine/test_task_engine_lifecycle.py index 
e5ee152b45..585a4c71a7 100644 --- a/tests/unit/engine/test_task_engine_lifecycle.py +++ b/tests/unit/engine/test_task_engine_lifecycle.py @@ -20,11 +20,12 @@ async def test_start_sets_running( persistence: FakePersistence, ) -> None: eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] - assert eng.is_running is False + initial = eng.is_running + assert not initial eng.start() - assert eng.is_running is True + assert eng.is_running await eng.stop(timeout=2.0) - assert eng.is_running is False + assert not eng.is_running async def test_double_start_raises( self, @@ -53,7 +54,7 @@ async def test_restart( eng.start() await eng.stop(timeout=2.0) eng.start() - assert eng.is_running is True + assert eng.is_running await eng.stop(timeout=2.0) @@ -124,10 +125,13 @@ def test_frozen(self) -> None: ], ) def test_rejects_out_of_range(self, field: str, value: object) -> None: + from typing import Any + from pydantic import ValidationError + kwargs: dict[str, Any] = {field: value} with pytest.raises(ValidationError): - TaskEngineConfig(**{field: value}) + TaskEngineConfig(**kwargs) def test_zero_queue_size_allowed(self) -> None: """Zero means unbounded — should be accepted.""" diff --git a/tests/unit/engine/test_task_engine_models.py b/tests/unit/engine/test_task_engine_models.py index 7f58ef5758..5c5931dd80 100644 --- a/tests/unit/engine/test_task_engine_models.py +++ b/tests/unit/engine/test_task_engine_models.py @@ -311,6 +311,7 @@ def test_construction(self) -> None: mutation_type="create", request_id="req-1", requested_by="alice", + task_id="task-1", new_status=TaskStatus.CREATED, version=1, ) @@ -324,6 +325,7 @@ def test_transition_event(self) -> None: mutation_type="transition", request_id="req-2", requested_by="bob", + task_id="task-2", previous_status=TaskStatus.CREATED, new_status=TaskStatus.ASSIGNED, version=2, @@ -336,6 +338,7 @@ def test_delete_event(self) -> None: mutation_type="delete", request_id="req-3", requested_by="charlie", + task_id="task-3", 
version=0, ) assert event.task is None @@ -347,6 +350,7 @@ def test_reason_field_populated(self) -> None: mutation_type="transition", request_id="req-1", requested_by="alice", + task_id="task-1", previous_status=TaskStatus.ASSIGNED, new_status=TaskStatus.IN_PROGRESS, version=2, @@ -359,6 +363,7 @@ def test_reason_field_default_none(self) -> None: mutation_type="create", request_id="req-1", requested_by="alice", + task_id="task-1", new_status=TaskStatus.CREATED, version=1, ) @@ -369,6 +374,7 @@ def test_cancel_event_has_reason(self) -> None: mutation_type="cancel", request_id="req-1", requested_by="alice", + task_id="task-1", previous_status=TaskStatus.ASSIGNED, new_status=TaskStatus.CANCELLED, version=3, @@ -381,6 +387,7 @@ def test_serialization_roundtrip(self) -> None: mutation_type="create", request_id="req-1", requested_by="alice", + task_id="task-1", new_status=TaskStatus.CREATED, version=1, ) diff --git a/tests/unit/engine/test_task_engine_mutations.py b/tests/unit/engine/test_task_engine_mutations.py index 4e68d66974..e0f1335b5e 100644 --- a/tests/unit/engine/test_task_engine_mutations.py +++ b/tests/unit/engine/test_task_engine_mutations.py @@ -424,7 +424,7 @@ async def oversized_list(**kwargs: object) -> tuple[Task, ...]: # Create a list longer than cap by repeating return result * 20_000 if result else result - persistence.tasks.list_tasks = oversized_list # type: ignore[assignment,method-assign] + persistence.tasks.list_tasks = oversized_list # type: ignore[method-assign] eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] eng.start() try: diff --git a/tests/unit/engine/test_task_engine_review_fixes.py b/tests/unit/engine/test_task_engine_review_fixes.py new file mode 100644 index 0000000000..d477f26f9b --- /dev/null +++ b/tests/unit/engine/test_task_engine_review_fixes.py @@ -0,0 +1,738 @@ +"""Tests for PR review fixes: race condition, immutability, validation wrapping. 
+ +Covers fixes from Copilot, CodeRabbit, and Greptile review findings. +""" + +import asyncio +import contextlib +from types import MappingProxyType +from typing import Any + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.errors import ( + TaskEngineNotRunningError, + TaskMutationError, +) +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CreateTaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, + UpdateTaskMutation, +) +from tests.unit.engine.task_engine_helpers import ( + FakeMessageBus, + FakePersistence, + make_create_data, +) + +# ── Race condition: stop/submit coordination ───────────────── + + +@pytest.mark.unit +class TestLifecycleLock: + """Lifecycle lock prevents submit-after-stop race condition.""" + + async def test_submit_rejected_after_stop( + self, + persistence: FakePersistence, + ) -> None: + """submit() raises after stop() sets _running=False.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + ) + eng.start() + await eng.stop(timeout=2.0) + mutation = CreateTaskMutation( + request_id="req-late", + requested_by="alice", + task_data=make_create_data(), + ) + with pytest.raises(TaskEngineNotRunningError): + await eng.submit(mutation) + + async def test_concurrent_stop_and_submit( + self, + persistence: FakePersistence, + ) -> None: + """Concurrent stop() and submit() resolve without hanging futures.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=TaskEngineConfig(max_queue_size=100), + ) + eng.start() + + # Create a task first so the engine is clearly working + task = await eng.create_task( + make_create_data(), + requested_by="alice", + ) + assert task is not None + + # Now race: stop and create in parallel + stop_task = asyncio.create_task(eng.stop(timeout=2.0)) + + # Yield once to let stop() 
begin — both outcomes are valid + await asyncio.sleep(0) + + # The create should either succeed (if enqueued before stop) + # or raise TaskEngineNotRunningError (if stop wins the lock first) + with contextlib.suppress(TaskEngineNotRunningError): + await eng.create_task( + make_create_data(), + requested_by="bob", + ) + + await stop_task + + async def test_stop_idempotent_under_lock( + self, + persistence: FakePersistence, + ) -> None: + """Two concurrent stop() calls don't deadlock or double-drain.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + ) + eng.start() + await asyncio.gather( + eng.stop(timeout=2.0), + eng.stop(timeout=2.0), + ) + assert not eng.is_running + + +# ── MappingProxyType immutability ───────────────────────────── + + +@pytest.mark.unit +class TestMutationDictImmutability: + """Mutation dicts are wrapped in MappingProxyType after construction.""" + + def test_update_mutation_updates_is_mapping_proxy(self) -> None: + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"title": "New title"}, + ) + # Runtime type differs from annotation (dict -> MappingProxyType via __init__) + assert type(mutation.updates) is MappingProxyType # type: ignore[comparison-overlap,unreachable] + + def test_update_mutation_updates_is_immutable(self) -> None: + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"title": "New title"}, + ) + with pytest.raises(TypeError): + mutation.updates["hacked"] = "value" + + def test_transition_mutation_overrides_is_mapping_proxy(self) -> None: + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + ) + # Runtime type differs from annotation (dict -> MappingProxyType via __init__) + assert type(mutation.overrides) is MappingProxyType # type: 
ignore[comparison-overlap,unreachable] + + def test_transition_mutation_overrides_is_immutable(self) -> None: + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + ) + with pytest.raises(TypeError): + mutation.overrides["hacked"] = "value" + + def test_update_mutation_deep_copies_input(self) -> None: + """Original dict is not affected by mutation construction.""" + original: dict[str, object] = {"title": "Original"} + mutation = UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates=original, + ) + # Modifying original shouldn't affect mutation + original["title"] = "Modified" + assert mutation.updates["title"] == "Original" + + def test_transition_mutation_deep_copies_input(self) -> None: + original: dict[str, object] = {"assigned_to": "alice"} + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides=original, + ) + original["assigned_to"] = "hacker" + assert mutation.overrides["assigned_to"] == "alice" + + +# ── TaskStateChanged.task_id ───────────────────────────────── + + +@pytest.mark.unit +class TestTaskStateChangedTaskId: + """TaskStateChanged always carries task_id.""" + + def test_task_id_required(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="task_id"): + TaskStateChanged( # type: ignore[call-arg] + mutation_type="create", + request_id="req-1", + requested_by="alice", + new_status=TaskStatus.CREATED, + version=1, + ) + + def test_task_id_on_create(self) -> None: + event = TaskStateChanged( + mutation_type="create", + request_id="req-1", + requested_by="alice", + task_id="task-new", + new_status=TaskStatus.CREATED, + version=1, + ) + assert event.task_id == "task-new" + + def test_task_id_on_delete(self) 
-> None: + event = TaskStateChanged( + mutation_type="delete", + request_id="req-1", + requested_by="alice", + task_id="task-deleted", + version=0, + ) + assert event.task_id == "task-deleted" + assert event.task is None + + def test_task_id_in_serialization(self) -> None: + event = TaskStateChanged( + mutation_type="create", + request_id="req-1", + requested_by="alice", + task_id="task-1", + new_status=TaskStatus.CREATED, + version=1, + ) + data = event.model_dump() + assert data["task_id"] == "task-1" + restored = TaskStateChanged.model_validate(data) + assert restored.task_id == "task-1" + + +# ── PydanticValidationError wrapping in create/delete/cancel ── + + +@pytest.mark.unit +class TestPydanticValidationWrapping: + """create_task, delete_task, cancel_task wrap PydanticValidationError.""" + + async def test_create_task_wraps_validation_error( + self, + engine: TaskEngine, + ) -> None: + """Blank requested_by triggers validation, wrapped as TaskMutationError.""" + with pytest.raises(TaskMutationError): + await engine.create_task( + make_create_data(), + requested_by=" ", + ) + + async def test_delete_task_wraps_validation_error( + self, + engine: TaskEngine, + ) -> None: + """Blank task_id triggers validation, wrapped as TaskMutationError.""" + with pytest.raises(TaskMutationError): + await engine.delete_task( + " ", + requested_by="alice", + ) + + async def test_cancel_task_wraps_validation_error( + self, + engine: TaskEngine, + ) -> None: + """Blank reason triggers validation, wrapped as TaskMutationError.""" + with pytest.raises(TaskMutationError): + await engine.cancel_task( + "task-1", + requested_by="alice", + reason=" ", + ) + + +# ── Snapshot publishing with task_id ───────────────────────── + + +@pytest.mark.unit +class TestSnapshotPublishingTaskId: + """Snapshot events include task_id.""" + + async def test_create_snapshot_includes_task_id( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + ) -> None: + """Create mutation 
publishes snapshot with task_id from result.""" + await message_bus.start() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=TaskEngineConfig(publish_snapshots=True), + ) + eng.start() + task = await eng.create_task( + make_create_data(), + requested_by="alice", + ) + await eng.stop(timeout=2.0) + await message_bus.stop() + + assert len(message_bus.published) >= 1 + msg: Any = message_bus.published[0] + event = TaskStateChanged.model_validate_json(msg.content) + assert event.task_id == task.id + + async def test_delete_snapshot_includes_task_id( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + ) -> None: + """Delete mutation publishes snapshot with task_id.""" + await message_bus.start() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=TaskEngineConfig(publish_snapshots=True), + ) + eng.start() + task = await eng.create_task( + make_create_data(), + requested_by="alice", + ) + await eng.delete_task(task.id, requested_by="alice") + await eng.stop(timeout=2.0) + await message_bus.stop() + + # Second message is the delete event + assert len(message_bus.published) >= 2 + msg: Any = message_bus.published[1] + event = TaskStateChanged.model_validate_json(msg.content) + assert event.task_id == task.id + assert event.task is None + + +# ── INFO logging for successful mutations ──────────────────── + + +@pytest.mark.unit +class TestMutationAppliedLogging: + """Successful mutations are logged at INFO level.""" + + async def test_create_task_logs_applied( + self, + engine: TaskEngine, + ) -> None: + """create_task logs TASK_ENGINE_MUTATION_APPLIED at INFO.""" + import structlog.testing + + with structlog.testing.capture_logs() as captured: + await engine.create_task( + make_create_data(), + requested_by="alice", + ) + + applied = [ + e for e in captured if e.get("event") == 
"task_engine.mutation.applied" + ] + assert len(applied) >= 1 + assert applied[0]["mutation_type"] == "create" + + +# ── FakeTaskRepository deep-copy isolation ─────────────────── + + +@pytest.mark.unit +class TestFakeTaskRepositoryIsolation: + """FakeTaskRepository deep-copies tasks to prevent test isolation leaks.""" + + async def test_save_deep_copies( + self, + persistence: FakePersistence, + engine: TaskEngine, + ) -> None: + """Two reads of the same task return distinct objects.""" + task = await engine.create_task( + make_create_data(), + requested_by="alice", + ) + read1 = await engine.get_task(task.id) + read2 = await engine.get_task(task.id) + assert read1 is not None + assert read2 is not None + assert read1 == read2 + assert read1 is not read2 + + +# ── TaskMutationResult consistency validation ──────────────── + + +@pytest.mark.unit +class TestTaskMutationResultConsistency: + """TaskMutationResult validates success/error/error_code consistency.""" + + def test_success_with_error_raises(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=True, + error="should not be here", + ) + + def test_failure_without_error_raises(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=False, + ) + + def test_success_with_error_code_raises(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="error_code"): + TaskMutationResult( + request_id="r", + success=True, + error_code="internal", + ) + + def test_failure_without_error_code_raises(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="error_code"): + TaskMutationResult( + request_id="r", + success=False, + error="something broke", + ) + + def test_valid_success_result(self) -> None: + result = 
TaskMutationResult( + request_id="r", + success=True, + ) + assert result.error is None + assert result.error_code is None + + def test_valid_failure_result(self) -> None: + result = TaskMutationResult( + request_id="r", + success=False, + error="broken", + error_code="internal", + ) + assert result.error == "broken" + assert result.error_code == "internal" + + +# ── _map_task_engine_errors 503 message sanitization ───────── + + +@pytest.mark.unit +class TestErrorMessageSanitization: + """503 responses don't leak internal error details.""" + + def test_not_running_sanitizes_message(self) -> None: + from ai_company.api.controllers.tasks import _map_task_engine_errors + + exc = TaskEngineNotRunningError("internal detail about engine state") + result = _map_task_engine_errors(exc) + assert "internal detail" not in str(result) + assert "temporarily unavailable" in str(result).lower() + + def test_queue_full_sanitizes_message(self) -> None: + from ai_company.api.controllers.tasks import _map_task_engine_errors + from ai_company.engine.errors import TaskEngineQueueFullError + + exc = TaskEngineQueueFullError("queue has 1000 items") + result = _map_task_engine_errors(exc) + assert "1000" not in str(result) + + +# ── Defensive guard: task is None after success ────────────── + + +@pytest.mark.unit +class TestConvenienceMethodTaskNoneGuard: + """Convenience methods raise TaskInternalError when task is None after success.""" + + async def test_create_task_none_guard( + self, + engine: TaskEngine, + ) -> None: + """create_task raises TaskInternalError if result.task is None.""" + from unittest.mock import AsyncMock + + from ai_company.engine.errors import TaskInternalError + + bogus = TaskMutationResult(request_id="r", success=True) + engine.submit = AsyncMock(return_value=bogus) # type: ignore[method-assign] + with pytest.raises( + TaskInternalError, match="create succeeded but task is None" + ): + await engine.create_task(make_create_data(), requested_by="alice") + + async 
def test_update_task_none_guard( + self, + engine: TaskEngine, + ) -> None: + """update_task raises TaskInternalError if result.task is None.""" + from unittest.mock import AsyncMock + + from ai_company.engine.errors import TaskInternalError + + bogus = TaskMutationResult(request_id="r", success=True) + engine.submit = AsyncMock(return_value=bogus) # type: ignore[method-assign] + with pytest.raises( + TaskInternalError, match="update succeeded but task is None" + ): + await engine.update_task("task-1", {"title": "X"}, requested_by="alice") + + async def test_transition_task_none_guard( + self, + engine: TaskEngine, + ) -> None: + """transition_task raises TaskInternalError if result.task is None.""" + from unittest.mock import AsyncMock + + from ai_company.engine.errors import TaskInternalError + + bogus = TaskMutationResult(request_id="r", success=True) + engine.submit = AsyncMock(return_value=bogus) # type: ignore[method-assign] + with pytest.raises( + TaskInternalError, match="transition succeeded but task is None" + ): + await engine.transition_task( + "task-1", + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + ) + + async def test_cancel_task_none_guard( + self, + engine: TaskEngine, + ) -> None: + """cancel_task raises TaskInternalError if result.task is None.""" + from unittest.mock import AsyncMock + + from ai_company.engine.errors import TaskInternalError + + bogus = TaskMutationResult(request_id="r", success=True) + engine.submit = AsyncMock(return_value=bogus) # type: ignore[method-assign] + with pytest.raises( + TaskInternalError, match="cancel succeeded but task is None" + ): + await engine.cancel_task("task-1", requested_by="alice", reason="Test") + + +# ── Processing loop unhandled-exception recovery ──────────── + + +@pytest.mark.unit +class TestProcessingLoopExceptionRecovery: + """Processing loop catches unhandled exceptions and returns internal error.""" + + async def test_unhandled_exception_returns_internal_error( + self, + 
persistence: FakePersistence, + ) -> None: + """Non-MemoryError exception in dispatch returns error result.""" + + async def exploding_save(task: object) -> None: + msg = "Unexpected DB crash" + raise RuntimeError(msg) + + persistence.tasks.save = exploding_save # type: ignore[method-assign] + + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + ) + eng.start() + try: + mutation = CreateTaskMutation( + request_id="req-boom", + requested_by="alice", + task_data=make_create_data(), + ) + result = await eng.submit(mutation) + assert result.success is False + assert result.error_code == "internal" + finally: + await eng.stop(timeout=2.0) + + async def test_engine_recovers_after_exception( + self, + persistence: FakePersistence, + ) -> None: + """Engine continues processing after a failed mutation.""" + call_count = 0 + original_save = persistence.tasks.save + + async def fail_once_save(task: object) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + msg = "First save fails" + raise RuntimeError(msg) + await original_save(task) # type: ignore[arg-type] + + persistence.tasks.save = fail_once_save # type: ignore[method-assign] + + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + ) + eng.start() + try: + # First create fails + m1 = CreateTaskMutation( + request_id="req-fail", + requested_by="alice", + task_data=make_create_data(), + ) + r1 = await eng.submit(m1) + assert r1.success is False + + # Second create succeeds — engine recovered + m2 = CreateTaskMutation( + request_id="req-ok", + requested_by="alice", + task_data=make_create_data(), + ) + r2 = await eng.submit(m2) + assert r2.success is True + assert r2.task is not None + finally: + await eng.stop(timeout=2.0) + + +# ── Snapshot new_status=None for non-delete (result.task is None) ── + + +@pytest.mark.unit +class TestSnapshotNewStatusNone: + """Snapshot publishes new_status=None when result.task is None (non-delete).""" + + async def 
test_snapshot_with_no_task_sets_new_status_none( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + ) -> None: + """When result.task is None but mutation is not delete, new_status is None.""" + await message_bus.start() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=TaskEngineConfig(publish_snapshots=True), + ) + eng.start() + # Create then delete — delete snapshot has task=None and new_status=None + task = await eng.create_task(make_create_data(), requested_by="alice") + await eng.delete_task(task.id, requested_by="alice") + await eng.stop(timeout=2.0) + await message_bus.stop() + + # The delete event should have new_status=None + assert len(message_bus.published) >= 2 + delete_msg: Any = message_bus.published[1] + event = TaskStateChanged.model_validate_json(delete_msg.content) + assert event.new_status is None + assert event.task is None + + +# ── Cancel task full lifecycle ─────────────────────────────── + + +@pytest.mark.unit +class TestCancelTaskLifecycle: + """Full lifecycle tests for cancel_task convenience method.""" + + async def test_cancel_returns_cancelled_task( + self, + engine: TaskEngine, + ) -> None: + """cancel_task returns the task in CANCELLED status.""" + task = await engine.create_task(make_create_data(), requested_by="alice") + # Must transition to ASSIGNED first (CREATED -> CANCELLED is invalid) + assigned, _ = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + cancelled = await engine.cancel_task( + assigned.id, requested_by="alice", reason="No longer needed" + ) + assert cancelled.status == TaskStatus.CANCELLED + + async def test_cancel_nonexistent_raises_not_found( + self, + engine: TaskEngine, + ) -> None: + """cancel_task for missing task raises TaskNotFoundError.""" + from ai_company.engine.errors import TaskNotFoundError + + with 
pytest.raises(TaskNotFoundError): + await engine.cancel_task( + "task-nonexistent", + requested_by="alice", + reason="Cleanup", + ) + + +# ── Transition task wraps blank requested_by ───────────────── + + +@pytest.mark.unit +class TestTransitionTaskValidation: + """transition_task wraps PydanticValidationError for blank fields.""" + + async def test_blank_reason_wraps_validation( + self, + engine: TaskEngine, + ) -> None: + """Blank reason is caught and wrapped as TaskMutationError.""" + task = await engine.create_task(make_create_data(), requested_by="alice") + with pytest.raises(TaskMutationError): + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason=" ", + assigned_to="bob", + ) From f1b28ae0cce8c2e67706605958f1f4ed7764cb35 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 22:41:20 +0100 Subject: [PATCH 13/14] fix: resolve 5 PR review findings (runtime bugs, future leaks, consistency) - Fix isinstance() with union type that raises TypeError at runtime - Resolve futures before re-raising MemoryError/RecursionError to prevent hung callers - Drain remaining futures on abnormal processing task exit in stop() - Add TaskVersionConflictError to update_task except block for consistency - Remove duplicate TASK_ENGINE_MUTATION_APPLIED log (already emitted in apply_*) --- src/ai_company/api/controllers/tasks.py | 3 ++- src/ai_company/engine/task_engine.py | 24 ++++++++---------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 48717265ad..038376c5e3 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -100,7 +100,7 @@ def _map_task_engine_errors( id=task_id, ) return NotFoundError(str(exc)) - if isinstance(exc, TaskEngineNotRunningError | TaskEngineQueueFullError): + if isinstance(exc, (TaskEngineNotRunningError, 
TaskEngineQueueFullError)): logger.error( API_TASK_MUTATION_FAILED, resource="task", @@ -312,6 +312,7 @@ async def update_task( TaskEngineNotRunningError, TaskEngineQueueFullError, TaskNotFoundError, + TaskVersionConflictError, TaskInternalError, TaskMutationError, ) as exc: diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index ca0a598d21..ebee71ee3c 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -50,7 +50,6 @@ TASK_ENGINE_FUTURES_FAILED, TASK_ENGINE_LIST_CAPPED, TASK_ENGINE_LOOP_ERROR, - TASK_ENGINE_MUTATION_APPLIED, TASK_ENGINE_MUTATION_FAILED, TASK_ENGINE_MUTATION_RECEIVED, TASK_ENGINE_NOT_RUNNING, @@ -187,7 +186,11 @@ async def stop(self, *, timeout: float | None = None) -> None: # noqa: ASYNC109 with contextlib.suppress(asyncio.CancelledError): await self._processing_task self._fail_remaining_futures(saved_in_flight) - self._processing_task = None + except BaseException: + self._fail_remaining_futures(self._in_flight) + raise + finally: + self._processing_task = None logger.info(TASK_ENGINE_STOPPED) @@ -613,7 +616,9 @@ async def _processing_loop(self) -> None: continue try: await self._process_one(envelope) - except MemoryError, RecursionError: + except (MemoryError, RecursionError) as exc: + if not envelope.future.done(): + envelope.future.set_exception(exc) raise except Exception: logger.exception( @@ -647,19 +652,6 @@ async def _process_one(self, envelope: _MutationEnvelope) -> None: ) if not envelope.future.done(): envelope.future.set_result(result) - if result.success: - task_id = getattr(mutation, "task_id", None) - logger.info( - TASK_ENGINE_MUTATION_APPLIED, - mutation_type=mutation.mutation_type, - request_id=mutation.request_id, - task_id=task_id or (result.task.id if result.task else None), - version=result.version, - previous_status=( - result.previous_status.value if result.previous_status else None - ), - new_status=(result.task.status.value if result.task 
else None), - ) if result.success and self._config.publish_snapshots: await self._publish_snapshot(mutation, result) except MemoryError, RecursionError: From d23952ada1008748747175ead5b91cab9f0bd827 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Thu, 12 Mar 2026 22:55:36 +0100 Subject: [PATCH 14/14] fix: list_tasks pagination total, dead PydanticValidationError handlers, restore mutation log - list_tasks now returns (tasks, total) so pagination.total reflects true count even when safety cap truncates results - paginate() accepts optional total override for capped result sets - Remove dead except PydanticValidationError handlers in create_task, transition_task, cancel_task controllers (engine wraps as TaskMutationError) - Restore TASK_ENGINE_MUTATION_APPLIED INFO log in _process_one for state transition visibility (version, previous/new status) - Resolve futures before re-raising MemoryError/RecursionError in _processing_loop to prevent hung callers - Drain remaining futures on abnormal processing task exit in stop() --- src/ai_company/api/controllers/tasks.py | 18 ++++------ src/ai_company/api/pagination.py | 7 +++- src/ai_company/engine/task_engine.py | 34 +++++++++++++++---- .../unit/engine/test_task_engine_mutations.py | 13 ++++--- 4 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index 038376c5e3..65c387c2bb 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -2,7 +2,6 @@ from litestar import Controller, delete, get, patch, post from litestar.datastructures import State # noqa: TC002 -from pydantic import ValidationError as PydanticValidationError from ai_company.api.dto import ( ApiResponse, @@ -179,14 +178,19 @@ async def list_tasks( # noqa: PLR0913 """ app_state: AppState = state.app_state try: - tasks = await app_state.task_engine.list_tasks( + tasks, total = await 
app_state.task_engine.list_tasks( status=status, assigned_to=assigned_to, project=project, ) except TaskInternalError as exc: raise _map_task_engine_errors(exc) from exc - page, meta = paginate(tasks, offset=offset, limit=limit) + page, meta = paginate( + tasks, + offset=offset, + limit=limit, + total=total, + ) return PaginatedResponse(data=page, pagination=meta) @get("/{task_id:str}") @@ -258,8 +262,6 @@ async def create_task( task_data, requested_by=requester, ) - except PydanticValidationError as exc: - raise ApiValidationError(str(exc)) from exc except ( TaskEngineNotRunningError, TaskEngineQueueFullError, @@ -306,8 +308,6 @@ async def update_task( requested_by=_extract_requester(state), expected_version=data.expected_version, ) - except PydanticValidationError as exc: - raise ApiValidationError(str(exc)) from exc except ( TaskEngineNotRunningError, TaskEngineQueueFullError, @@ -357,8 +357,6 @@ async def transition_task( expected_version=data.expected_version, **overrides, ) - except PydanticValidationError as exc: - raise ApiValidationError(str(exc)) from exc except ( TaskEngineNotRunningError, TaskEngineQueueFullError, @@ -438,8 +436,6 @@ async def cancel_task( requested_by=_extract_requester(state), reason=data.reason, ) - except PydanticValidationError as exc: - raise ApiValidationError(str(exc)) from exc except ( TaskEngineNotRunningError, TaskEngineQueueFullError, diff --git a/src/ai_company/api/pagination.py b/src/ai_company/api/pagination.py index c98b1637c5..4c05f298d9 100644 --- a/src/ai_company/api/pagination.py +++ b/src/ai_company/api/pagination.py @@ -28,6 +28,7 @@ def paginate[T]( *, offset: int, limit: int, + total: int | None = None, ) -> tuple[tuple[T, ...], PaginationMeta]: """Slice a tuple and produce pagination metadata. @@ -38,15 +39,19 @@ def paginate[T]( items: Full collection to paginate. offset: Zero-based starting index. limit: Maximum items to return. + total: True total count when *items* has been truncated + upstream (e.g. 
by a safety cap). Defaults to + ``len(items)``. Returns: A tuple of (page_items, pagination_meta). """ + effective_total = total if total is not None else len(items) offset = max(0, min(offset, len(items))) limit = max(1, min(limit, MAX_LIMIT)) page = items[offset : offset + limit] meta = PaginationMeta( - total=len(items), + total=effective_total, offset=offset, limit=limit, ) diff --git a/src/ai_company/engine/task_engine.py b/src/ai_company/engine/task_engine.py index ebee71ee3c..ae4cc7d050 100644 --- a/src/ai_company/engine/task_engine.py +++ b/src/ai_company/engine/task_engine.py @@ -50,6 +50,7 @@ TASK_ENGINE_FUTURES_FAILED, TASK_ENGINE_LIST_CAPPED, TASK_ENGINE_LOOP_ERROR, + TASK_ENGINE_MUTATION_APPLIED, TASK_ENGINE_MUTATION_FAILED, TASK_ENGINE_MUTATION_RECEIVED, TASK_ENGINE_NOT_RUNNING, @@ -545,16 +546,23 @@ async def list_tasks( status: TaskStatus | None = None, assigned_to: str | None = None, project: str | None = None, - ) -> tuple[Task, ...]: + ) -> tuple[tuple[Task, ...], int]: """List tasks directly from persistence (bypass queue). + Returns a tuple of ``(tasks, total)`` where *total* is the true + count before any safety cap is applied. When the result set + exceeds ``_MAX_LIST_RESULTS``, the returned tuple is truncated + but *total* reflects the real cardinality so pagination metadata + stays accurate. + Args: status: Filter by status. assigned_to: Filter by assignee. project: Filter by project. Returns: - Matching tasks as a tuple. + ``(tasks, total)`` — *tasks* may be capped at + ``_MAX_LIST_RESULTS``; *total* is the true count. Raises: TaskInternalError: If the persistence backend fails. 
@@ -574,14 +582,15 @@ async def list_tasks( error=msg, ) raise TaskInternalError(msg) from exc - if len(tasks) > self._MAX_LIST_RESULTS: + total = len(tasks) + if total > self._MAX_LIST_RESULTS: logger.warning( TASK_ENGINE_LIST_CAPPED, - returned=len(tasks), + actual_total=total, cap=self._MAX_LIST_RESULTS, ) - return tasks[: self._MAX_LIST_RESULTS] - return tasks + return tasks[: self._MAX_LIST_RESULTS], total + return tasks, total # -- Background processing --------------------------------------------- @@ -652,6 +661,19 @@ async def _process_one(self, envelope: _MutationEnvelope) -> None: ) if not envelope.future.done(): envelope.future.set_result(result) + if result.success: + task_id = getattr(mutation, "task_id", None) + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + task_id=task_id or (result.task.id if result.task else None), + version=result.version, + previous_status=( + result.previous_status.value if result.previous_status else None + ), + new_status=(result.task.status.value if result.task else None), + ) if result.success and self._config.publish_snapshots: await self._publish_snapshot(mutation, result) except MemoryError, RecursionError: diff --git a/tests/unit/engine/test_task_engine_mutations.py b/tests/unit/engine/test_task_engine_mutations.py index e0f1335b5e..a9c1df9a8d 100644 --- a/tests/unit/engine/test_task_engine_mutations.py +++ b/tests/unit/engine/test_task_engine_mutations.py @@ -297,11 +297,13 @@ async def test_list_tasks( make_create_data(project="proj-b"), requested_by="alice", ) - all_tasks = await engine.list_tasks() + all_tasks, all_total = await engine.list_tasks() assert len(all_tasks) == 2 + assert all_total == 2 - filtered = await engine.list_tasks(project="proj-a") + filtered, filtered_total = await engine.list_tasks(project="proj-a") assert len(filtered) == 1 + assert filtered_total == 1 async def test_list_tasks_by_status( self, @@ -319,8 +321,8 @@ async 
def test_list_tasks_by_status( assigned_to="bob", ) - created = await engine.list_tasks(status=TaskStatus.CREATED) - assigned = await engine.list_tasks(status=TaskStatus.ASSIGNED) + created, _ = await engine.list_tasks(status=TaskStatus.CREATED) + assigned, _ = await engine.list_tasks(status=TaskStatus.ASSIGNED) assert len(created) == 0 assert len(assigned) == 1 @@ -430,8 +432,9 @@ async def oversized_list(**kwargs: object) -> tuple[Task, ...]: try: # Create one task so the oversized list has data to repeat await eng.create_task(make_create_data(), requested_by="alice") - tasks = await eng.list_tasks() + tasks, total = await eng.list_tasks() assert len(tasks) <= 10_000 + assert total == 20_000 finally: await eng.stop(timeout=2.0)