diff --git a/CLAUDE.md b/CLAUDE.md index 8a4df2f256..3a4a6eed15 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -92,7 +92,7 @@ src/ai_company/ communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol config/ # YAML company config loading and validation core/ # Shared domain models, base classes, and resilience config (RetryConfig, RateLimiterConfig) - engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation + engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation hr/ # HR engine: hiring, firing, onboarding, offboarding, agent registry, performance tracking (task metrics, collaboration scoring, trend detection), promotion/demotion (criteria evaluation, approval strategies, model mapping) memory/ # Persistent agent memory (Mem0 initial, custom stack future — see Decision Log), retrieval pipeline (ranking, injection, context formatting, non-inferable filtering), shared org memory (org/), consolidation/archival (consolidation/) persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (see Memory & Persistence design page) @@ -127,7 +127,7 @@ src/ai_company/ - **Every module** with business logic MUST have: `from ai_company.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` 
(e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`). Import directly: `from ai_company.observability.events.<domain> import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. 
`PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`, `TASK_ENGINE_STARTED` from `events.task_engine`). Import directly: `from ai_company.observability.events.<domain> import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` — never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/docs/architecture/tech-stack.md b/docs/architecture/tech-stack.md index 4580351e7c..4b867d4798 100644 --- a/docs/architecture/tech-stack.md +++ b/docs/architecture/tech-stack.md @@ -119,7 +119,7 @@ These conventions are used throughout the codebase. For full details on each, se | **Cost tiers and quota tracking** | Adopted | Configurable `CostTierDefinition` with merge/override semantics. `QuotaTracker` enforces per-provider request/token quotas with window-based rotation. | | **Shared org memory** | Adopted | `OrgMemoryBackend` protocol with `HybridPromptRetrievalBackend`. 
Seniority-based write access control. Core policies in system prompts; extended facts retrieved on demand. | | **Memory consolidation** | Adopted | `ConsolidationStrategy` protocol with deduplication + summarization. `RetentionEnforcer` for age-based cleanup. `ArchivalStore` for cold storage. | -| **State coordination** | Planned | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_copy(update=...)` sequentially and publishes snapshots. | +| **State coordination** | Adopted | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_copy(update=...)` sequentially and publishes snapshots. | | **Workspace isolation** | Adopted | Pluggable `WorkspaceIsolationStrategy` protocol. Default: git worktrees with sequential merge on completion. | | **Graceful shutdown** | Adopted | Pluggable `ShutdownStrategy` protocol with cooperative 30-second timeout. Force-cancel after timeout with `INTERRUPTED` status. | | **Template inheritance** | Adopted | `extends` field triggers parent resolution at render time with deep merge by field type. Circular chain detection included. | diff --git a/docs/design/engine.md b/docs/design/engine.md index 62eafca175..6e40d84d5d 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -166,6 +166,63 @@ exceptions on failure; scoring-based strategies return --- +## TaskEngine — Centralized State Coordination + +All task state mutations flow through a single-writer `TaskEngine` that owns the +authoritative task state. This eliminates race conditions when multiple agents +attempt concurrent transitions on the same task. 
+ +### Architecture + +```text +Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop ──▶ Persistence + │ + ├──▶ Version tracking (optimistic concurrency) + └──▶ Snapshot publishing (MessageBus) +``` + +- **Single writer**: A background `asyncio.Task` consumes `TaskMutation` + requests sequentially from an `asyncio.Queue`. +- **Immutable updates**: Each mutation calls `model_copy(update=...)` on + frozen `Task` models — the original is never mutated. +- **Optimistic concurrency**: In-memory version counters per task. + Callers can pass `expected_version` to detect stale writes; on mismatch + the engine returns a failed `TaskMutationResult` with + `error_code="version_conflict"`. Convenience methods raise + `TaskVersionConflictError`. +- **Read-through**: `get_task()` and `list_tasks()` bypass the queue and + read directly from persistence — safe because TaskEngine is the sole writer. +- **Snapshot publishing**: On success, a `TaskStateChanged` event is published + to the message bus for downstream consumers (WebSocket bridge, audit, etc.). + +### Mutation Types + +| Mutation | Description | +|----------|-------------| +| `CreateTaskMutation` | Generates a unique ID, persists, and returns the new task. | +| `UpdateTaskMutation` | Applies field updates with immutable-field rejection (`id`, `status`, `created_by`) and re-validates via `model_validate`. | +| `TransitionTaskMutation` | Validates status transition via `Task.with_transition()`, supports field overrides. | +| `DeleteTaskMutation` | Removes from persistence and clears version tracking. | +| `CancelTaskMutation` | Shortcut for transition to `CANCELLED`. | + +### Error Handling + +- **Typed errors**: `TaskNotFoundError` and `TaskVersionConflictError` provide + precise failure classification — API controllers catch these directly instead + of parsing error strings. 
+- **Error sanitization**: Internal exception details (SQL paths, stack traces) + are replaced with a generic message before reaching callers. +- **Queue full**: `TaskEngineQueueFullError` signals backpressure when the + queue is at capacity. + +### Lifecycle + +- **start()**: Spawns the background processing task. +- **stop()**: Sets `_running = False`, drains the queue within a configurable + timeout, then cancels. Abandoned futures receive a failure result. + +--- + ## Agent Execution Loop The agent execution loop defines how an agent processes a task from start to @@ -346,7 +403,7 @@ async run( alone when no enforcer is configured. 8. **Delegate to loop** -- calls `ExecutionLoop.execute()` with context, provider, tool invoker, budget checker, and completion config. If - `timeout_seconds` is set, wraps the call in `asyncio.wait_for`; on expiry + `timeout_seconds` is set, wraps the call in `asyncio.wait`; on expiry the run returns with `TerminationReason.ERROR` but cost recording and post-execution processing still occur. 9. 
**Record costs** -- records accumulated `TokenUsage` to `CostTracker` (if diff --git a/src/ai_company/api/app.py b/src/ai_company/api/app.py index 2d0eb7435d..13428569f8 100644 --- a/src/ai_company/api/app.py +++ b/src/ai_company/api/app.py @@ -35,6 +35,7 @@ from ai_company.communication.bus_protocol import MessageBus # noqa: TC001 from ai_company.config.schema import RootConfig from ai_company.core.approval import ApprovalItem # noqa: TC001 +from ai_company.engine.task_engine import TaskEngine # noqa: TC001 from ai_company.observability import get_logger from ai_company.observability.events.api import ( API_APP_SHUTDOWN, @@ -87,7 +88,9 @@ def _on_expire(item: ApprovalItem) -> None: event.model_dump_json(), channels=[CHANNEL_APPROVALS], ) - except RuntimeError, OSError: + except MemoryError, RecursionError: + raise + except Exception: logger.warning( API_APPROVAL_PUBLISH_FAILED, approval_id=item.id, @@ -102,6 +105,7 @@ def _build_lifecycle( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + task_engine: TaskEngine | None, app_state: AppState, ) -> tuple[ Sequence[Callable[[], Awaitable[None]]], @@ -115,23 +119,49 @@ def _build_lifecycle( async def on_startup() -> None: logger.info(API_APP_STARTUP, version=__version__) - await _safe_startup(persistence, message_bus, bridge, app_state) + await _safe_startup( + persistence, + message_bus, + bridge, + task_engine, + app_state, + ) async def on_shutdown() -> None: logger.info(API_APP_SHUTDOWN, version=__version__) - await _safe_shutdown(bridge, message_bus, persistence) + await _safe_shutdown(bridge, task_engine, message_bus, persistence) return [on_startup], [on_shutdown] -async def _cleanup_on_failure( +async def _cleanup_on_failure( # noqa: PLR0913 *, persistence: PersistenceBackend | None, started_persistence: bool, message_bus: MessageBus | None, started_bus: bool, + bridge: MessageBusBridge | None = None, + started_bridge: bool = False, + task_engine: 
TaskEngine | None = None, + started_task_engine: bool = False, ) -> None: - """Reverse cleanup of persistence and message bus on startup failure.""" + """Reverse cleanup on startup failure (task engine, bridge, bus, persistence).""" + if started_task_engine and task_engine is not None: + try: + await task_engine.stop() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Cleanup: failed to stop task engine", + ) + if started_bridge and bridge is not None: + try: + await bridge.stop() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Cleanup: failed to stop message bus bridge", + ) if started_bus and message_bus is not None: try: await message_bus.stop() @@ -196,15 +226,18 @@ async def _safe_startup( persistence: PersistenceBackend | None, message_bus: MessageBus | None, bridge: MessageBusBridge | None, + task_engine: TaskEngine | None, app_state: AppState, ) -> None: - """Connect persistence, resolve JWT secret, start message bus and bridge. + """Connect persistence, resolve JWT secret, start bus, bridge, task engine. Executes in order; on failure, cleans up already-started components in reverse order before re-raising. 
""" started_bus = False + started_bridge = False started_persistence = False + started_task_engine = False try: if persistence is not None: try: @@ -239,22 +272,38 @@ async def _safe_startup( error="Failed to start message bus bridge", ) raise + started_bridge = True + if task_engine is not None: + try: + task_engine.start() + except Exception: + logger.exception( + API_APP_STARTUP, + error="Failed to start task engine", + ) + raise + started_task_engine = True except Exception: await _cleanup_on_failure( persistence=persistence, started_persistence=started_persistence, message_bus=message_bus, started_bus=started_bus, + bridge=bridge, + started_bridge=started_bridge, + task_engine=task_engine, + started_task_engine=started_task_engine, ) raise async def _safe_shutdown( bridge: MessageBusBridge | None, + task_engine: TaskEngine | None, message_bus: MessageBus | None, persistence: PersistenceBackend | None, ) -> None: - """Stop bridge, message bus and disconnect persistence.""" + """Stop bridge, task engine, message bus and disconnect persistence.""" if bridge is not None: try: await bridge.stop() @@ -263,6 +312,14 @@ async def _safe_shutdown( API_APP_SHUTDOWN, error="Failed to stop message bus bridge", ) + if task_engine is not None: + try: + await task_engine.stop() + except Exception: + logger.exception( + API_APP_SHUTDOWN, + error="Failed to stop task engine", + ) if message_bus is not None: try: await message_bus.stop() @@ -289,6 +346,7 @@ def create_app( # noqa: PLR0913 cost_tracker: CostTracker | None = None, approval_store: ApprovalStore | None = None, auth_service: AuthService | None = None, + task_engine: TaskEngine | None = None, ) -> Litestar: """Create and configure the Litestar application. @@ -302,6 +360,7 @@ def create_app( # noqa: PLR0913 cost_tracker: Cost tracking service. approval_store: Approval queue store. auth_service: Pre-built auth service (for testing). + task_engine: Centralized task state engine. Returns: Configured Litestar application. 
@@ -330,6 +389,7 @@ def create_app( # noqa: PLR0913 cost_tracker=cost_tracker, approval_store=effective_approval_store, auth_service=auth_service, + task_engine=task_engine, startup_time=time.monotonic(), ) @@ -347,6 +407,7 @@ def create_app( # noqa: PLR0913 persistence, message_bus, bridge, + task_engine, app_state, ) diff --git a/src/ai_company/api/controllers/tasks.py b/src/ai_company/api/controllers/tasks.py index c309969585..88cc018876 100644 --- a/src/ai_company/api/controllers/tasks.py +++ b/src/ai_company/api/controllers/tasks.py @@ -1,6 +1,4 @@ -"""Task controller — full CRUD via TaskRepository.""" - -from uuid import uuid4 +"""Task controller — full CRUD via TaskEngine.""" from litestar import Controller, delete, get, patch, post from litestar.datastructures import State # noqa: TC002 @@ -12,16 +10,29 @@ TransitionTaskRequest, UpdateTaskRequest, ) -from ai_company.api.errors import ApiValidationError, NotFoundError +from ai_company.api.errors import ( + ApiValidationError, + NotFoundError, + ServiceUnavailableError, +) from ai_company.api.guards import require_read_access, require_write_access from ai_company.api.pagination import PaginationLimit, PaginationOffset, paginate from ai_company.api.state import AppState # noqa: TC001 from ai_company.core.enums import TaskStatus # noqa: TC001 -from ai_company.core.task import Task +from ai_company.core.task import Task # noqa: TC001 +from ai_company.engine.errors import ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskInternalError, + TaskMutationError, + TaskNotFoundError, +) +from ai_company.engine.task_engine_models import CreateTaskData from ai_company.observability import get_logger from ai_company.observability.events.api import ( API_RESOURCE_NOT_FOUND, API_TASK_DELETED, + API_TASK_TRANSITION_FAILED, API_TASK_UPDATED, ) from ai_company.observability.events.task import ( @@ -33,7 +44,7 @@ class TaskController(Controller): - """Full CRUD for tasks via ``TaskRepository``.""" + """Full CRUD 
for tasks via ``TaskEngine``.""" path = "/tasks" tags = ("tasks",) @@ -63,7 +74,7 @@ async def list_tasks( # noqa: PLR0913 Paginated task list. """ app_state: AppState = state.app_state - tasks = await app_state.persistence.tasks.list_tasks( + tasks = await app_state.task_engine.list_tasks( status=status, assigned_to=assigned_to, project=project, @@ -90,7 +101,7 @@ async def get_task( NotFoundError: If the task is not found. """ app_state: AppState = state.app_state - task = await app_state.persistence.tasks.get(task_id) + task = await app_state.task_engine.get_task(task_id) if task is None: msg = f"Task {task_id!r} not found" logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) @@ -113,9 +124,7 @@ async def create_task( Created task envelope. """ app_state: AppState = state.app_state - task_id = f"task-{uuid4().hex}" - task = Task( - id=task_id, + task_data = CreateTaskData( title=data.title, description=data.description, type=data.type, @@ -126,7 +135,19 @@ async def create_task( estimated_complexity=data.estimated_complexity, budget_limit=data.budget_limit, ) - await app_state.persistence.tasks.save(task) + try: + task = await app_state.task_engine.create_task( + task_data, + requested_by=data.created_by, + ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskMutationError as exc: + raise ApiValidationError(str(exc)) from exc logger.info( TASK_CREATED, task_id=task.id, @@ -155,17 +176,29 @@ async def update_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state - task = await app_state.persistence.tasks.get(task_id) - if task is None: - msg = f"Task {task_id!r} not found" - logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) - raise NotFoundError(msg) - updates = data.model_dump(exclude_none=True) - if updates: - task = task.model_copy(update=updates) - await app_state.persistence.tasks.save(task) - logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) + try: + task = await app_state.task_engine.update_task( + task_id, + updates, + requested_by="api", + ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskNotFoundError as exc: + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskMutationError as exc: + raise ApiValidationError(str(exc)) from exc + logger.info(API_TASK_UPDATED, task_id=task_id, fields=list(updates)) return ApiResponse(data=task) @post( @@ -192,33 +225,46 @@ async def transition_task( NotFoundError: If the task is not found. 
""" app_state: AppState = state.app_state - task = await app_state.persistence.tasks.get(task_id) - if task is None: - msg = f"Task {task_id!r} not found" - logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) - raise NotFoundError(msg) - - overrides: dict[str, object] = {} + transition_kwargs: dict[str, object] = { + "requested_by": "api", + "reason": f"API transition to {data.target_status.value}", + } if data.assigned_to is not None: - overrides["assigned_to"] = data.assigned_to - + transition_kwargs["assigned_to"] = data.assigned_to try: - new_task = task.with_transition(data.target_status, **overrides) - except ValueError as exc: + task, from_status = await app_state.task_engine.transition_task( + task_id, + data.target_status, + **transition_kwargs, # type: ignore[arg-type] + ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskNotFoundError as exc: logger.warning( - TASK_STATUS_CHANGED, + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskMutationError as exc: + error_str = str(exc) + logger.warning( + API_TASK_TRANSITION_FAILED, task_id=task_id, - error=str(exc), + error=error_str, ) - raise ApiValidationError(str(exc)) from exc - await app_state.persistence.tasks.save(new_task) + raise ApiValidationError(error_str) from exc logger.info( TASK_STATUS_CHANGED, task_id=task_id, - from_status=task.status.value, - to_status=new_task.status.value, + from_status=from_status.value if from_status else None, + to_status=task.status.value, ) - return ApiResponse(data=new_task) + return ApiResponse(data=task) @delete("/{task_id:str}", guards=[require_write_access], status_code=200) async def delete_task( @@ -239,10 +285,25 @@ async def delete_task( 
NotFoundError: If the task is not found. """ app_state: AppState = state.app_state - deleted = await app_state.persistence.tasks.delete(task_id) - if not deleted: - msg = f"Task {task_id!r} not found" - logger.warning(API_RESOURCE_NOT_FOUND, resource="task", id=task_id) - raise NotFoundError(msg) + try: + await app_state.task_engine.delete_task( + task_id, + requested_by="api", + ) + except TaskEngineNotRunningError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskEngineQueueFullError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskNotFoundError as exc: + logger.warning( + API_RESOURCE_NOT_FOUND, + resource="task", + id=task_id, + ) + raise NotFoundError(str(exc)) from exc + except TaskInternalError as exc: + raise ServiceUnavailableError(str(exc)) from exc + except TaskMutationError as exc: + raise ApiValidationError(str(exc)) from exc logger.info(API_TASK_DELETED, task_id=task_id) return ApiResponse(data=None) diff --git a/src/ai_company/api/state.py b/src/ai_company/api/state.py index 2ffd80dab2..9a8ebe847c 100644 --- a/src/ai_company/api/state.py +++ b/src/ai_company/api/state.py @@ -11,6 +11,7 @@ from ai_company.budget.tracker import CostTracker # noqa: TC001 from ai_company.communication.bus_protocol import MessageBus # noqa: TC001 from ai_company.config.schema import RootConfig # noqa: TC001 +from ai_company.engine.task_engine import TaskEngine # noqa: TC001 from ai_company.observability import get_logger from ai_company.observability.events.api import API_APP_STARTUP, API_SERVICE_UNAVAILABLE from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 @@ -22,8 +23,8 @@ class AppState: """Typed application state container. Service fields (``persistence``, ``message_bus``, ``cost_tracker``, - ``auth_service``) accept ``None`` at construction time for dev/test - mode. Property + ``auth_service``, ``task_engine``) accept ``None`` at construction + time for dev/test mode. 
Property accessors raise ``ServiceUnavailableError`` (HTTP 503) when the service is not configured, producing a clear error instead of an opaque ``AttributeError``. @@ -39,6 +40,7 @@ class AppState: "_cost_tracker", "_message_bus", "_persistence", + "_task_engine", "approval_store", "config", "startup_time", @@ -53,6 +55,7 @@ def __init__( # noqa: PLR0913 message_bus: MessageBus | None = None, cost_tracker: CostTracker | None = None, auth_service: AuthService | None = None, + task_engine: TaskEngine | None = None, startup_time: float = 0.0, ) -> None: self.config = config @@ -61,6 +64,7 @@ def __init__( # noqa: PLR0913 self._message_bus = message_bus self._cost_tracker = cost_tracker self._auth_service = auth_service + self._task_engine = task_engine self.startup_time = startup_time def _require_service[T](self, service: T | None, name: str) -> T: @@ -99,6 +103,34 @@ def auth_service(self) -> AuthService: """Return auth service or raise 503.""" return self._require_service(self._auth_service, "auth_service") + @property + def task_engine(self) -> TaskEngine: + """Return task engine or raise 503.""" + return self._require_service(self._task_engine, "task_engine") + + @property + def has_task_engine(self) -> bool: + """Check whether the task engine is already configured.""" + return self._task_engine is not None + + def set_task_engine(self, engine: TaskEngine) -> None: + """Set the task engine (deferred initialisation). + + Supports late binding when the task engine is created after + ``AppState`` construction. + + Args: + engine: Fully configured task engine. + + Raises: + RuntimeError: If the task engine was already configured. 
+ """ + if self._task_engine is not None: + msg = "Task engine already configured" + logger.error(API_APP_STARTUP, error=msg) + raise RuntimeError(msg) + self._task_engine = engine + @property def has_auth_service(self) -> bool: """Check whether the auth service is already configured.""" diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index 02145cd16e..579d1fecbd 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -40,4 +40,5 @@ def default_config_dict() -> dict[str, Any]: "security": {}, "trust": {}, "promotion": {}, + "task_engine": {}, } diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index 39beb30bc2..aa6a3e37aa 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -21,6 +21,7 @@ from ai_company.core.resilience_config import RateLimiterConfig, RetryConfig from ai_company.core.role import CustomRole # noqa: TC001 from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.task_engine_config import TaskEngineConfig from ai_company.hr.promotion.config import PromotionConfig from ai_company.memory.config import CompanyMemoryConfig from ai_company.memory.org.config import OrgMemoryConfig @@ -414,6 +415,7 @@ class RootConfig(BaseModel): security: Security subsystem configuration. trust: Progressive trust configuration. promotion: Promotion/demotion configuration. + task_engine: Task engine configuration. 
""" model_config = ConfigDict(frozen=True) @@ -521,6 +523,10 @@ class RootConfig(BaseModel): default_factory=PromotionConfig, description="Promotion/demotion configuration", ) + task_engine: TaskEngineConfig = Field( + default_factory=TaskEngineConfig, + description="Task engine configuration", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index d9d43940c8..c9bd49e2b7 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -1,8 +1,8 @@ """Agent execution engine. -Re-exports the public API for the agent orchestrator, run results, -system prompt construction, runtime execution state, execution loops, -and engine errors. +Re-exports the public API for the agent orchestrator, task engine, +run results, system prompt construction, runtime execution state, +execution loops, and engine errors. """ from ai_company.engine.agent_engine import AgentEngine @@ -67,7 +67,14 @@ PromptBuildError, ResourceConflictError, TaskAssignmentError, + TaskEngineError, + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskInternalError, + TaskMutationError, + TaskNotFoundError, TaskRoutingError, + TaskVersionConflictError, WorkspaceCleanupError, WorkspaceError, WorkspaceLimitError, @@ -128,6 +135,19 @@ ShutdownResult, ShutdownStrategy, ) +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskData, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, + UpdateTaskMutation, +) from ai_company.engine.task_execution import StatusTransition, TaskExecution from ai_company.engine.workspace import ( MergeConflict, @@ -168,10 +188,13 @@ "AuctionAssignmentStrategy", "AutoTopologyConfig", "BudgetChecker", + 
"CancelTaskMutation", "ClassificationResult", "CleanupCallback", "CooperativeTimeoutStrategy", "CostOptimizedAssignmentStrategy", + "CreateTaskData", + "CreateTaskMutation", "DecompositionContext", "DecompositionCycleError", "DecompositionDepthError", @@ -181,6 +204,7 @@ "DecompositionService", "DecompositionStrategy", "DefaultTokenEstimator", + "DeleteTaskMutation", "DependencyGraph", "EngineError", "ErrorFinding", @@ -239,13 +263,27 @@ "TaskAssignmentService", "TaskAssignmentStrategy", "TaskCompletionMetrics", + "TaskEngine", + "TaskEngineConfig", + "TaskEngineError", + "TaskEngineNotRunningError", + "TaskEngineQueueFullError", "TaskExecution", + "TaskInternalError", + "TaskMutation", + "TaskMutationError", + "TaskMutationResult", + "TaskNotFoundError", "TaskRoutingError", "TaskRoutingService", + "TaskStateChanged", "TaskStructureClassifier", + "TaskVersionConflictError", "TerminationReason", "TopologySelector", + "TransitionTaskMutation", "TurnRecord", + "UpdateTaskMutation", "Workspace", "WorkspaceCleanupError", "WorkspaceError", diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 3915fe78e5..c4e3701411 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -19,7 +19,7 @@ from ai_company.engine.classification.pipeline import classify_execution_errors from ai_company.engine.context import DEFAULT_MAX_TURNS, AgentContext from ai_company.engine.cost_recording import record_execution_costs -from ai_company.engine.errors import ExecutionStateError +from ai_company.engine.errors import ExecutionStateError, TaskMutationError from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -84,6 +84,7 @@ ExecutionLoop, ShutdownChecker, ) + from ai_company.engine.task_engine import TaskEngine from ai_company.providers.models import CompletionConfig from ai_company.providers.protocol import CompletionProvider from ai_company.security.config import SecurityConfig @@ 
-98,6 +99,21 @@ _DEFAULT_RECOVERY_STRATEGY = FailAndReassignStrategy() """Module-level default instance for the recovery strategy.""" +_REPORTABLE_STATUSES: frozenset[TaskStatus] = frozenset( + { + TaskStatus.COMPLETED, + TaskStatus.FAILED, + TaskStatus.INTERRUPTED, + TaskStatus.CANCELLED, + } +) +"""Final execution outcomes that trigger a report to the centralized TaskEngine. + +Note: ``FAILED`` and ``INTERRUPTED`` are not strictly terminal in the task +lifecycle (they can be reassigned), but represent final outcomes of this +``AgentEngine`` run that should be reported. +""" + class AgentEngine: """Top-level orchestrator for agent execution. @@ -118,6 +134,10 @@ class AgentEngine: error_taxonomy_config: Post-execution error classification. budget_enforcer: Pre-flight checks, auto-downgrade, and enhanced in-flight budget checking. + security_config: Optional security subsystem configuration. + approval_store: Optional approval queue store. + task_engine: Optional centralized task engine for reporting + final execution status. """ def __init__( # noqa: PLR0913 @@ -133,6 +153,7 @@ def __init__( # noqa: PLR0913 budget_enforcer: BudgetEnforcer | None = None, security_config: SecurityConfig | None = None, approval_store: ApprovalStore | None = None, + task_engine: TaskEngine | None = None, ) -> None: self._provider = provider self._loop: ExecutionLoop = execution_loop or ReactLoop() @@ -154,6 +175,7 @@ def __init__( # noqa: PLR0913 self._cost_tracker = cost_tracker self._security_config = security_config self._approval_store = approval_store + self._task_engine = task_engine self._recovery_strategy = recovery_strategy self._shutdown_checker = shutdown_checker self._error_taxonomy_config = error_taxonomy_config @@ -336,7 +358,11 @@ async def _post_execution_pipeline( agent_id: str, task_id: str, ) -> ExecutionResult: - """Record costs, apply transitions, run recovery and classify.""" + """Post-execution: costs, transitions, TaskEngine, recovery, classify. 
+ + Best-effort: classification and reporting failures are logged, + never fatal. + """ await record_execution_costs( execution_result, identity, @@ -349,6 +375,7 @@ async def _post_execution_pipeline( agent_id, task_id, ) + await self._report_to_task_engine(execution_result, agent_id, task_id) if execution_result.termination_reason == TerminationReason.ERROR: execution_result = await self._apply_recovery( execution_result, @@ -636,6 +663,60 @@ def _transition_to_interrupted( ) return execution_result + async def _report_to_task_engine( + self, + execution_result: ExecutionResult, + agent_id: str, + task_id: str, + ) -> None: + """Report final execution status to the centralized TaskEngine. + + Only reports final execution outcomes (COMPLETED, FAILED, + INTERRUPTED, CANCELLED); other statuses are silently skipped. + + Best-effort: failures are logged and swallowed. If no + ``TaskEngine`` is configured, this is a no-op. + """ + if self._task_engine is None: + return + ctx = execution_result.context + if ctx.task_execution is None: + return + + final_status = ctx.task_execution.status + if final_status not in _REPORTABLE_STATUSES: + return + + try: + _, _ = await self._task_engine.transition_task( + task_id, + final_status, + requested_by=agent_id, + reason=( + "AgentEngine execution ended: " + f"{execution_result.termination_reason.value}" + ), + ) + except MemoryError, RecursionError: + raise + except TaskMutationError: + logger.warning( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error="Failed to report final status to TaskEngine (mutation rejected)", + exc_info=True, + ) + except Exception: + logger.error( + EXECUTION_ENGINE_ERROR, + agent_id=agent_id, + task_id=task_id, + error="Unexpected error reporting to TaskEngine" + " -- state may be divergent", + exc_info=True, + ) + async def _apply_recovery( self, execution_result: ExecutionResult, diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index 
class TaskEngineError(EngineError):
    """Root of the task engine exception hierarchy."""


class TaskEngineNotRunningError(TaskEngineError):
    """A mutation was submitted while the task engine was stopped."""


class TaskEngineQueueFullError(TaskEngineError):
    """The task engine queue has reached its capacity."""


class TaskMutationError(TaskEngineError):
    """A task mutation failed (task not found, validation failure, etc.)."""


class TaskNotFoundError(TaskMutationError):
    """The task targeted by a mutation does not exist."""


class TaskVersionConflictError(TaskMutationError):
    """The optimistic-concurrency version supplied did not match."""


class TaskInternalError(TaskMutationError):
    """A task mutation failed because of an internal engine error.

    :class:`TaskMutationError` on its own covers business-rule failures
    such as validation or not-found. This subclass instead marks an
    unexpected engine fault that the caller cannot resolve by changing
    the request. Maps to 5xx at the API layer.
    """
+""" + +import asyncio +import contextlib +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Never +from uuid import uuid4 + +from ai_company.core.enums import TaskStatus +from ai_company.core.task import Task +from ai_company.engine.errors import ( + TaskEngineNotRunningError, + TaskEngineQueueFullError, + TaskInternalError, + TaskMutationError, + TaskNotFoundError, + TaskVersionConflictError, +) +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskData, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, + UpdateTaskMutation, +) +from ai_company.observability import get_logger +from ai_company.observability.events.task_engine import ( + TASK_ENGINE_CREATED, + TASK_ENGINE_DRAIN_COMPLETE, + TASK_ENGINE_DRAIN_START, + TASK_ENGINE_DRAIN_TIMEOUT, + TASK_ENGINE_LOOP_ERROR, + TASK_ENGINE_MUTATION_APPLIED, + TASK_ENGINE_MUTATION_FAILED, + TASK_ENGINE_MUTATION_RECEIVED, + TASK_ENGINE_NOT_RUNNING, + TASK_ENGINE_QUEUE_FULL, + TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED, + TASK_ENGINE_SNAPSHOT_PUBLISHED, + TASK_ENGINE_STARTED, + TASK_ENGINE_STOPPED, + TASK_ENGINE_VERSION_CONFLICT, +) + +if TYPE_CHECKING: + from ai_company.communication.bus_protocol import MessageBus + from ai_company.persistence.protocol import PersistenceBackend + +logger = get_logger(__name__) + + +@dataclass +class _MutationEnvelope: + """Pairs a mutation request with its response future. + + Note: must be instantiated within a running event loop (the + ``future`` default factory calls ``asyncio.get_running_loop()``). + """ + + mutation: TaskMutation + future: asyncio.Future[TaskMutationResult] = field( + default_factory=lambda: asyncio.get_running_loop().create_future(), + ) + + +class TaskEngine: + """Centralized single-writer for all task state mutations. 
def start(self) -> None:
    """Spawn the background processing loop.

    Must be called from within a running event loop. The engine is only
    marked as running after the background task has actually been
    created, so a failed start leaves it stopped and ``start()`` can be
    retried.

    Raises:
        RuntimeError: If already running, or if there is no running
            event loop on which to schedule the processing task.
    """
    if self._running:
        msg = "TaskEngine is already running"
        logger.warning(TASK_ENGINE_STARTED, error=msg)
        raise RuntimeError(msg)

    loop_coro = self._processing_loop()
    try:
        processing_task = asyncio.create_task(
            loop_coro,
            name="task-engine-loop",
        )
    except RuntimeError:
        # No running event loop: close the coroutine so it does not trigger
        # a "never awaited" warning, and leave the engine marked stopped.
        loop_coro.close()
        raise

    # The new task does not run until control returns to the event loop, so
    # the flag is already set when the loop first checks it.
    self._running = True
    self._processing_task = processing_task
    logger.info(
        TASK_ENGINE_STARTED,
        max_queue_size=self._config.max_queue_size,
    )
+ """ + if not self._running: + return + self._running = False + effective_timeout = ( + timeout if timeout is not None else self._config.drain_timeout_seconds + ) + + if self._processing_task is not None: + logger.info( + TASK_ENGINE_DRAIN_START, + pending=self._queue.qsize(), + timeout_seconds=effective_timeout, + ) + try: + await asyncio.wait_for( + self._processing_task, + timeout=effective_timeout, + ) + logger.info(TASK_ENGINE_DRAIN_COMPLETE) + except TimeoutError: + logger.warning( + TASK_ENGINE_DRAIN_TIMEOUT, + remaining=self._queue.qsize(), + ) + self._processing_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._processing_task + self._fail_remaining_futures() + self._processing_task = None + + logger.info(TASK_ENGINE_STOPPED) + + def _fail_remaining_futures(self) -> None: + """Fail all remaining enqueued futures after drain timeout.""" + while not self._queue.empty(): + with contextlib.suppress(asyncio.QueueEmpty): + envelope = self._queue.get_nowait() + if not envelope.future.done(): + envelope.future.set_result( + TaskMutationResult( + request_id=envelope.mutation.request_id, + success=False, + error="TaskEngine shut down before processing", + error_code="internal", + ), + ) + + @property + def is_running(self) -> bool: + """Whether the engine is accepting mutations.""" + return self._running + + # -- Submit & convenience methods -------------------------------------- + + async def submit(self, mutation: TaskMutation) -> TaskMutationResult: + """Submit a mutation and await its result. + + Args: + mutation: The mutation to apply. + + Returns: + Result of the mutation. + + Raises: + TaskEngineNotRunningError: If the engine is not running. + TaskEngineQueueFullError: If the queue is at capacity. 
async def create_task(
    self,
    data: CreateTaskData,
    *,
    requested_by: str,
) -> Task:
    """Convenience: create a task and return the created Task.

    Failed results are converted with ``_raise_typed_error``, as in the
    other convenience methods, so an engine-side fault surfaces as
    ``TaskInternalError`` rather than a bare ``TaskMutationError``.

    Args:
        data: Task creation data.
        requested_by: Identity of the requester.

    Returns:
        The created task.

    Raises:
        TaskEngineNotRunningError: If the engine is not running.
        TaskEngineQueueFullError: If the queue is at capacity.
        TaskInternalError: If the engine failed internally, or reported
            success without returning a task.
        TaskMutationError: If the mutation fails for any other reason.
    """
    mutation = CreateTaskMutation(
        request_id=uuid4().hex,
        requested_by=requested_by,
        task_data=data,
    )
    result = await self.submit(mutation)
    if not result.success:
        self._raise_typed_error(result)
    if result.task is None:
        # Success without a task is an engine fault, not a caller error.
        # TaskInternalError subclasses TaskMutationError, so existing
        # handlers still catch it.
        msg = "Internal error: create succeeded but task is None"
        raise TaskInternalError(msg)
    return result.task
async def transition_task(
    self,
    task_id: str,
    target_status: TaskStatus,
    *,
    requested_by: str,
    reason: str = "",
    expected_version: int | None = None,
    **overrides: object,
) -> tuple[Task, TaskStatus | None]:
    """Convenience: transition task status and return the updated Task.

    Args:
        task_id: Target task identifier.
        target_status: Desired target status.
        requested_by: Identity of the requester.
        reason: Reason for the transition. When empty, a default of
            ``"Transition to <status>"`` is used.
        expected_version: Optional optimistic concurrency version.
        **overrides: Additional field overrides for the transition.

    Returns:
        Tuple of (transitioned task, status before the transition).
        The second element is ``None`` when the previous status is unknown.

    Raises:
        TaskEngineNotRunningError: If the engine is not running.
        TaskEngineQueueFullError: If the queue is at capacity.
        TaskNotFoundError: If the task is not found.
        TaskVersionConflictError: If ``expected_version`` doesn't match.
        TaskInternalError: If the engine failed internally, or reported
            success without returning a task.
        TaskMutationError: If the mutation fails for any other reason.
    """
    effective_reason = reason or f"Transition to {target_status.value}"
    mutation = TransitionTaskMutation(
        request_id=uuid4().hex,
        requested_by=requested_by,
        task_id=task_id,
        target_status=target_status,
        reason=effective_reason,
        overrides=dict(overrides),
        expected_version=expected_version,
    )
    result = await self.submit(mutation)
    if not result.success:
        self._raise_typed_error(result)
    if result.task is None:
        # Success without a task is an engine fault, not a caller error.
        # TaskInternalError subclasses TaskMutationError, so existing
        # handlers still catch it.
        msg = "Internal error: transition succeeded but task is None"
        raise TaskInternalError(msg)
    return result.task, result.previous_status
@staticmethod
def _raise_typed_error(result: TaskMutationResult) -> Never:
    """Convert a failed mutation result into the matching exception.

    The exception class is picked from ``result.error_code``; codes
    without a dedicated class raise the generic ``TaskMutationError``.
    The message is ``result.error``, or ``"Mutation failed"`` when that
    is empty.
    """
    message = result.error or "Mutation failed"
    error_types: dict[str | None, type[TaskMutationError]] = {
        "not_found": TaskNotFoundError,
        "version_conflict": TaskVersionConflictError,
        "internal": TaskInternalError,
    }
    error_type = error_types.get(result.error_code, TaskMutationError)
    raise error_type(message)
+ """ + return await self._persistence.tasks.list_tasks( + status=status, + assigned_to=assigned_to, + project=project, + ) + + # -- Background processing --------------------------------------------- + + async def _processing_loop(self) -> None: + """Background loop: dequeue and process mutations sequentially.""" + while self._running or not self._queue.empty(): + try: + envelope = await asyncio.wait_for( + self._queue.get(), + timeout=0.5, + ) + except TimeoutError: + continue + try: + await self._process_one(envelope) + except Exception: + logger.exception( + TASK_ENGINE_LOOP_ERROR, + error="Unhandled exception in processing loop", + ) + if not envelope.future.done(): + envelope.future.set_result( + TaskMutationResult( + request_id=envelope.mutation.request_id, + success=False, + error="Internal error in processing loop", + error_code="internal", + ), + ) + + async def _process_one(self, envelope: _MutationEnvelope) -> None: + """Process a single mutation envelope.""" + mutation = envelope.mutation + logger.debug( + TASK_ENGINE_MUTATION_RECEIVED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + ) + try: + result = await self._apply_mutation(mutation) + if not envelope.future.done(): + envelope.future.set_result(result) + if result.success and self._config.publish_snapshots: + await self._publish_snapshot(mutation, result) + except Exception as exc: + internal_msg = f"{type(exc).__name__}: {exc}" + logger.exception( + TASK_ENGINE_MUTATION_FAILED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + error=internal_msg, + ) + if not envelope.future.done(): + envelope.future.set_result( + TaskMutationResult( + request_id=mutation.request_id, + success=False, + error="Internal error processing mutation", + error_code="internal", + ), + ) + + async def _apply_mutation(self, mutation: TaskMutation) -> TaskMutationResult: + """Dispatch and apply a mutation by type. 
+ + Raises: + TypeError: If the mutation type is unrecognised. + """ + match mutation: + case CreateTaskMutation(): + return await self._apply_create(mutation) + case UpdateTaskMutation(): + return await self._apply_update(mutation) + case TransitionTaskMutation(): + return await self._apply_transition(mutation) + case DeleteTaskMutation(): + return await self._apply_delete(mutation) + case CancelTaskMutation(): + return await self._apply_cancel(mutation) + case _: + msg = f"Unknown mutation type: {type(mutation).__name__}" # type: ignore[unreachable] + raise TypeError(msg) + + def _not_found_result( + self, + mutation_type: str, + request_id: str, + task_id: str, + ) -> TaskMutationResult: + """Build a failure result for a missing task and log it.""" + error = f"Task {task_id!r} not found" + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type=mutation_type, + request_id=request_id, + task_id=task_id, + error=error, + ) + return TaskMutationResult( + request_id=request_id, + success=False, + error=error, + error_code="not_found", + ) + + async def _apply_create( + self, + mutation: CreateTaskMutation, + ) -> TaskMutationResult: + """Create a new task.""" + data = mutation.task_data + task_id = f"task-{uuid4().hex}" + + task = Task( + id=task_id, + title=data.title, + description=data.description, + type=data.type, + priority=data.priority, + project=data.project, + created_by=data.created_by, + assigned_to=data.assigned_to, + estimated_complexity=data.estimated_complexity, + budget_limit=data.budget_limit, + ) + await self._persistence.tasks.save(task) + self._versions[task_id] = 1 + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="create", + request_id=mutation.request_id, + task_id=task_id, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=task, + version=1, + ) + + async def _apply_update( + self, + mutation: UpdateTaskMutation, + ) -> TaskMutationResult: + """Update task fields.""" + task = await 
self._persistence.tasks.get(mutation.task_id) + if task is None: + return self._not_found_result( + "update", + mutation.request_id, + mutation.task_id, + ) + + try: + self._check_version(mutation.task_id, mutation.expected_version) + except TaskVersionConflictError as exc: + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + error_code="version_conflict", + ) + + if not mutation.updates: + version = self._versions.get(mutation.task_id, 0) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=task, + version=version, + previous_status=task.status, + ) + + merged = task.model_dump() + merged.update(mutation.updates) + updated = Task.model_validate(merged) + await self._persistence.tasks.save(updated) + version = self._bump_version(mutation.task_id) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="update", + request_id=mutation.request_id, + task_id=mutation.task_id, + fields=list(mutation.updates), + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=updated, + version=version, + previous_status=task.status, + ) + + async def _apply_transition( + self, + mutation: TransitionTaskMutation, + ) -> TaskMutationResult: + """Perform a task status transition.""" + task = await self._persistence.tasks.get(mutation.task_id) + if task is None: + return self._not_found_result( + "transition", + mutation.request_id, + mutation.task_id, + ) + + try: + self._check_version(mutation.task_id, mutation.expected_version) + except TaskVersionConflictError as exc: + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + error_code="version_conflict", + ) + + previous_status = task.status + + try: + updated = task.with_transition( + mutation.target_status, + **mutation.overrides, + ) + except ValueError as exc: + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="transition", + 
request_id=mutation.request_id, + task_id=mutation.task_id, + error=str(exc), + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + error_code="validation", + ) + + await self._persistence.tasks.save(updated) + version = self._bump_version(mutation.task_id) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="transition", + request_id=mutation.request_id, + task_id=mutation.task_id, + from_status=previous_status.value, + to_status=mutation.target_status.value, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=updated, + version=version, + previous_status=previous_status, + ) + + async def _apply_delete( + self, + mutation: DeleteTaskMutation, + ) -> TaskMutationResult: + """Delete a task.""" + deleted = await self._persistence.tasks.delete(mutation.task_id) + if not deleted: + return self._not_found_result( + "delete", + mutation.request_id, + mutation.task_id, + ) + + self._versions.pop(mutation.task_id, None) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="delete", + request_id=mutation.request_id, + task_id=mutation.task_id, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + version=0, + ) + + async def _apply_cancel( + self, + mutation: CancelTaskMutation, + ) -> TaskMutationResult: + """Cancel a task (shortcut for transition to CANCELLED).""" + task = await self._persistence.tasks.get(mutation.task_id) + if task is None: + return self._not_found_result( + "cancel", + mutation.request_id, + mutation.task_id, + ) + + previous_status = task.status + try: + updated = task.with_transition(TaskStatus.CANCELLED) + except ValueError as exc: + logger.warning( + TASK_ENGINE_MUTATION_FAILED, + mutation_type="cancel", + request_id=mutation.request_id, + task_id=mutation.task_id, + error=str(exc), + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=False, + error=str(exc), + 
error_code="validation", + ) + + await self._persistence.tasks.save(updated) + version = self._bump_version(mutation.task_id) + + logger.info( + TASK_ENGINE_MUTATION_APPLIED, + mutation_type="cancel", + request_id=mutation.request_id, + task_id=mutation.task_id, + from_status=previous_status.value, + to_status=TaskStatus.CANCELLED.value, + ) + return TaskMutationResult( + request_id=mutation.request_id, + success=True, + task=updated, + version=version, + previous_status=previous_status, + ) + + # -- Snapshot publishing ----------------------------------------------- + + async def _publish_snapshot( + self, + mutation: TaskMutation, + result: TaskMutationResult, + ) -> None: + """Publish a TaskStateChanged event to the message bus. + + Best-effort: failures are logged and swallowed. + """ + if self._message_bus is None: + return + + if isinstance(mutation, DeleteTaskMutation): + new_status = None + elif result.task is not None: + new_status = result.task.status + else: + new_status = None + + event = TaskStateChanged( + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + requested_by=mutation.requested_by, + task=result.task, + previous_status=result.previous_status, + new_status=new_status, + version=result.version, + timestamp=datetime.now(UTC), + ) + + try: + # Deferred to break circular import: + # communication -> engine -> communication + from ai_company.communication.enums import MessageType # noqa: PLC0415 + from ai_company.communication.message import Message # noqa: PLC0415 + + msg = Message( + timestamp=datetime.now(UTC), + sender="task-engine", + to="task_engine", + type=MessageType.TASK_UPDATE, + channel="task_engine", + content=event.model_dump_json(), + ) + await self._message_bus.publish(msg) + logger.debug( + TASK_ENGINE_SNAPSHOT_PUBLISHED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.warning( + 
TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED, + mutation_type=mutation.mutation_type, + request_id=mutation.request_id, + exc_info=True, + ) + + # -- Version tracking -------------------------------------------------- + + def _bump_version(self, task_id: str) -> int: + """Increment and return the version counter for a task.""" + version = self._versions.get(task_id, 0) + 1 + self._versions[task_id] = version + return version + + def _check_version( + self, + task_id: str, + expected_version: int | None, + ) -> None: + """Check optimistic concurrency version if provided. + + Raises: + TaskVersionConflictError: If versions don't match. + """ + if expected_version is None: + return + current = self._versions.get(task_id, 0) + if current != expected_version: + msg = ( + f"Version conflict for task {task_id!r}: " + f"expected {expected_version}, current {current}" + ) + logger.warning( + TASK_ENGINE_VERSION_CONFLICT, + task_id=task_id, + expected_version=expected_version, + current_version=current, + ) + raise TaskVersionConflictError(msg) diff --git a/src/ai_company/engine/task_engine_config.py b/src/ai_company/engine/task_engine_config.py new file mode 100644 index 0000000000..e0e4e5e4da --- /dev/null +++ b/src/ai_company/engine/task_engine_config.py @@ -0,0 +1,37 @@ +"""Task engine configuration model.""" + +from pydantic import BaseModel, ConfigDict, Field + + +class TaskEngineConfig(BaseModel): + """Configuration for the centralized task engine. + + Controls queue sizing, drain behaviour on shutdown, and whether + state-change snapshots are published to the message bus. + + Attributes: + max_queue_size: Maximum pending mutations before backpressure + is applied. ``0`` means unbounded. + drain_timeout_seconds: Seconds to wait for pending mutations + to drain during ``stop()``. + publish_snapshots: Whether to publish ``TaskStateChanged`` + events to the message bus after each mutation. 
class CreateTaskData(BaseModel):
    """Caller-supplied payload for creating a task.

    Server-generated fields are excluded. The model mirrors
    :class:`~ai_company.api.dto.CreateTaskRequest` but is defined in the
    engine layer to avoid a dependency on the API; the two are kept in
    step by convention only.

    Attributes:
        title: Short task title.
        description: Detailed task description.
        type: Task work type.
        priority: Task priority level (defaults to medium).
        project: Project ID.
        created_by: Agent name of the creator.
        assigned_to: Optional assignee agent ID.
        estimated_complexity: Complexity estimate (defaults to medium).
        budget_limit: Maximum USD spend; must be non-negative.
    """

    model_config = ConfigDict(frozen=True, allow_inf_nan=False)

    # Required fields.
    title: NotBlankStr = Field(description="Short task title")
    description: NotBlankStr = Field(description="Detailed task description")
    type: TaskType = Field(description="Task work type")
    project: NotBlankStr = Field(description="Project ID")
    created_by: NotBlankStr = Field(description="Agent name of the creator")

    # Optional fields with defaults.
    priority: Priority = Field(
        default=Priority.MEDIUM,
        description="Task priority",
    )
    assigned_to: NotBlankStr | None = Field(
        default=None,
        description="Assignee agent ID",
    )
    estimated_complexity: Complexity = Field(
        default=Complexity.MEDIUM,
        description="Complexity estimate",
    )
    budget_limit: float = Field(
        default=0.0,
        ge=0.0,
        description="Maximum USD spend",
    )
+ requested_by: Identity of the requester. + task_id: Target task identifier. + updates: Field-value pairs to apply (immutable fields rejected). + expected_version: Optional optimistic concurrency version. + """ + + model_config = ConfigDict(frozen=True) + + mutation_type: Literal["update"] = "update" + request_id: NotBlankStr = Field(description="Unique request identifier") + requested_by: NotBlankStr = Field(description="Identity of the requester") + task_id: NotBlankStr = Field(description="Target task identifier") + updates: dict[str, object] = Field(description="Field-value pairs to apply") + expected_version: int | None = Field( + default=None, + ge=1, + description="Optional optimistic concurrency version", + ) + + @model_validator(mode="after") + def _reject_immutable_fields(self) -> Self: + forbidden = set(self.updates) & _IMMUTABLE_TASK_FIELDS + if forbidden: + msg = f"Cannot update immutable fields: {sorted(forbidden)}" + raise ValueError(msg) + return self + + +_IMMUTABLE_OVERRIDE_FIELDS: frozenset[str] = frozenset( + { + "id", + "created_by", + "status", + } +) +"""Fields that must not be overridden during a transition.""" + + +class TransitionTaskMutation(BaseModel): + """Request to perform a task status transition. + + Attributes: + mutation_type: Discriminator literal. + request_id: Unique request identifier for tracing. + requested_by: Identity of the requester. + task_id: Target task identifier. + target_status: Desired target status. + reason: Reason for the transition. + overrides: Additional field overrides (immutable fields rejected). + expected_version: Optional optimistic concurrency version. 
+ """ + + model_config = ConfigDict(frozen=True) + + mutation_type: Literal["transition"] = "transition" + request_id: NotBlankStr = Field(description="Unique request identifier") + requested_by: NotBlankStr = Field(description="Identity of the requester") + task_id: NotBlankStr = Field(description="Target task identifier") + target_status: TaskStatus = Field(description="Desired target status") + reason: NotBlankStr = Field(description="Reason for the transition") + overrides: dict[str, object] = Field( + default_factory=dict, + description="Additional field overrides", + ) + expected_version: int | None = Field( + default=None, + ge=1, + description="Optional optimistic concurrency version", + ) + + @model_validator(mode="after") + def _reject_immutable_overrides(self) -> Self: + forbidden = set(self.overrides) & _IMMUTABLE_OVERRIDE_FIELDS + if forbidden: + msg = f"Cannot override immutable fields: {sorted(forbidden)}" + raise ValueError(msg) + return self + + +class DeleteTaskMutation(BaseModel): + """Request to delete a task. + + Attributes: + mutation_type: Discriminator literal. + request_id: Unique request identifier for tracing. + requested_by: Identity of the requester. + task_id: Target task identifier. + """ + + model_config = ConfigDict(frozen=True) + + mutation_type: Literal["delete"] = "delete" + request_id: NotBlankStr = Field(description="Unique request identifier") + requested_by: NotBlankStr = Field(description="Identity of the requester") + task_id: NotBlankStr = Field(description="Target task identifier") + + +class CancelTaskMutation(BaseModel): + """Request to cancel a task (shortcut for transition to CANCELLED). + + Attributes: + mutation_type: Discriminator literal. + request_id: Unique request identifier for tracing. + requested_by: Identity of the requester. + task_id: Target task identifier. + reason: Reason for cancellation. 
+ """ + + model_config = ConfigDict(frozen=True) + + mutation_type: Literal["cancel"] = "cancel" + request_id: NotBlankStr = Field(description="Unique request identifier") + requested_by: NotBlankStr = Field(description="Identity of the requester") + task_id: NotBlankStr = Field(description="Target task identifier") + reason: NotBlankStr = Field(description="Reason for cancellation") + + +TaskMutation = ( + CreateTaskMutation + | UpdateTaskMutation + | TransitionTaskMutation + | DeleteTaskMutation + | CancelTaskMutation +) +"""Union of all task mutation request types.""" + + +# ── Mutation result ─────────────────────────────────────────── + + +class TaskMutationResult(BaseModel): + """Result of a processed task mutation. + + Attributes: + request_id: Echoed request identifier. + success: Whether the mutation succeeded. + task: The task after mutation (``None`` on delete or failure). + version: Current version counter for the task. + previous_status: Status before the mutation (``None`` on create + or failure). + error: Error description (``None`` on success). + error_code: Machine-readable error classification for reliable + dispatch (``None`` on success). 
+ """ + + model_config = ConfigDict(frozen=True) + + request_id: NotBlankStr = Field(description="Echoed request identifier") + success: bool = Field(description="Whether the mutation succeeded") + task: Task | None = Field(default=None, description="Task after mutation") + version: int = Field(default=0, ge=0, description="Version counter") + previous_status: TaskStatus | None = Field( + default=None, + description="Status before mutation", + ) + error: str | None = Field(default=None, description="Error description") + error_code: ( + Literal["not_found", "version_conflict", "validation", "internal"] | None + ) = Field( + default=None, + description="Machine-readable error classification", + ) + + @model_validator(mode="after") + def _check_consistency(self) -> Self: + if self.success and self.error is not None: + msg = "Successful result must not carry an error" + raise ValueError(msg) + if not self.success and self.error is None: + msg = "Failed result must carry an error description" + raise ValueError(msg) + return self + + +# ── State-change event ──────────────────────────────────────── + + +class TaskStateChanged(BaseModel): + """Event published to the message bus after each successful mutation. + + Attributes: + mutation_type: Type of mutation that triggered the event. + request_id: Originating request identifier. + requested_by: Identity of the requester. + task: Task snapshot after mutation (``None`` on delete). + previous_status: Status before the mutation (``None`` on create). + new_status: Status after the mutation (``None`` on delete). + version: Version counter after mutation. + timestamp: When the mutation was applied. 
+ """ + + model_config = ConfigDict(frozen=True) + + mutation_type: Literal["create", "update", "transition", "delete", "cancel"] = ( + Field(description="Mutation type that triggered event") + ) + request_id: NotBlankStr = Field(description="Originating request identifier") + requested_by: NotBlankStr = Field(description="Identity of the requester") + task: Task | None = Field( + default=None, + description="Task snapshot after mutation", + ) + previous_status: TaskStatus | None = Field( + default=None, + description="Status before mutation", + ) + new_status: TaskStatus | None = Field( + default=None, + description="Status after mutation", + ) + version: int = Field(ge=0, description="Version counter after mutation") + timestamp: AwareDatetime = Field( + default_factory=lambda: datetime.now(UTC), + description="When the mutation was applied", + ) diff --git a/src/ai_company/observability/events/api.py b/src/ai_company/observability/events/api.py index 8890cd7538..67aecd48cb 100644 --- a/src/ai_company/observability/events/api.py +++ b/src/ai_company/observability/events/api.py @@ -35,3 +35,4 @@ API_AUTH_TOKEN_ISSUED: Final[str] = "api.auth.token_issued" # noqa: S105 API_AUTH_SETUP_COMPLETE: Final[str] = "api.auth.setup_complete" API_AUTH_PASSWORD_CHANGED: Final[str] = "api.auth.password_changed" # noqa: S105 +API_TASK_TRANSITION_FAILED: Final[str] = "api.task.transition_failed" diff --git a/src/ai_company/observability/events/task_engine.py b/src/ai_company/observability/events/task_engine.py new file mode 100644 index 0000000000..a97427a64e --- /dev/null +++ b/src/ai_company/observability/events/task_engine.py @@ -0,0 +1,19 @@ +"""Task engine event constants.""" + +from typing import Final + +TASK_ENGINE_CREATED: Final[str] = "task_engine.created" +TASK_ENGINE_STARTED: Final[str] = "task_engine.started" +TASK_ENGINE_STOPPED: Final[str] = "task_engine.stopped" +TASK_ENGINE_MUTATION_RECEIVED: Final[str] = "task_engine.mutation.received" 
+TASK_ENGINE_MUTATION_APPLIED: Final[str] = "task_engine.mutation.applied" +TASK_ENGINE_MUTATION_FAILED: Final[str] = "task_engine.mutation.failed" +TASK_ENGINE_SNAPSHOT_PUBLISHED: Final[str] = "task_engine.snapshot.published" +TASK_ENGINE_SNAPSHOT_PUBLISH_FAILED: Final[str] = "task_engine.snapshot.publish_failed" +TASK_ENGINE_QUEUE_FULL: Final[str] = "task_engine.queue.full" +TASK_ENGINE_DRAIN_START: Final[str] = "task_engine.drain.start" +TASK_ENGINE_DRAIN_COMPLETE: Final[str] = "task_engine.drain.complete" +TASK_ENGINE_DRAIN_TIMEOUT: Final[str] = "task_engine.drain.timeout" +TASK_ENGINE_NOT_RUNNING: Final[str] = "task_engine.not_running" +TASK_ENGINE_VERSION_CONFLICT: Final[str] = "task_engine.version.conflict" +TASK_ENGINE_LOOP_ERROR: Final[str] = "task_engine.loop.error" diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index 32baeee0ce..aa7befec2b 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -2,6 +2,7 @@ import asyncio import uuid +from collections.abc import Generator # noqa: TC003 from datetime import UTC, datetime, timedelta from typing import Any @@ -26,6 +27,7 @@ TaskStatus, ) from ai_company.core.task import Task +from ai_company.engine.task_engine import TaskEngine from ai_company.persistence.errors import DuplicateRecordError, QueryError from ai_company.security.models import AuditEntry, AuditVerdictStr # noqa: TC001 from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 @@ -604,6 +606,16 @@ def root_config() -> RootConfig: return RootConfig(company_name="test-company") +@pytest.fixture +def fake_task_engine( + fake_persistence: FakePersistenceBackend, +) -> TaskEngine: + """TaskEngine backed by the shared fake persistence.""" + return TaskEngine( + persistence=fake_persistence, + ) + + @pytest.fixture def test_client( # noqa: PLR0913 fake_persistence: FakePersistenceBackend, @@ -612,7 +624,8 @@ def test_client( # noqa: PLR0913 approval_store: ApprovalStore, root_config: 
RootConfig, auth_service: AuthService, -) -> TestClient[Any]: + fake_task_engine: TaskEngine, +) -> Generator[TestClient[Any]]: # Pre-seed users for each role so JWT sub claims resolve _seed_test_users(fake_persistence, auth_service) @@ -623,11 +636,12 @@ def test_client( # noqa: PLR0913 cost_tracker=cost_tracker, approval_store=approval_store, auth_service=auth_service, + task_engine=fake_task_engine, ) - client = TestClient(app) - # Default: CEO token (most tests need write access) - client.headers.update(make_auth_headers("ceo")) - return client + with TestClient(app) as client: + # Default: CEO token (most tests need write access) + client.headers.update(make_auth_headers("ceo")) + yield client def _seed_test_users( diff --git a/tests/unit/api/test_app.py b/tests/unit/api/test_app.py index be172178fb..4a10fadbec 100644 --- a/tests/unit/api/test_app.py +++ b/tests/unit/api/test_app.py @@ -63,7 +63,7 @@ async def failing_start() -> None: bus.start = failing_start # type: ignore[method-assign] with pytest.raises(RuntimeError, match="bus boom"): - await _safe_startup(persistence, bus, None, app_state) + await _safe_startup(persistence, bus, None, None, app_state) # Persistence should have been disconnected during cleanup assert not persistence.is_connected @@ -81,4 +81,50 @@ async def failing_disconnect() -> None: persistence.disconnect = failing_disconnect # type: ignore[method-assign] # Should not raise even when disconnect fails - await _safe_shutdown(None, None, persistence) + await _safe_shutdown(None, None, None, persistence) + + async def test_task_engine_failure_cleans_up( + self, + root_config: Any, + ) -> None: + """Task engine start fails → persistence + bus cleaned up.""" + from unittest.mock import MagicMock + + from ai_company.api.app import _safe_startup + from ai_company.api.approval_store import ApprovalStore + from ai_company.api.state import AppState + from tests.unit.api.conftest import ( + FakeMessageBus, + FakePersistenceBackend, + ) + + 
persistence = FakePersistenceBackend() + bus = FakeMessageBus() + mock_te = MagicMock() + mock_te.start = MagicMock(side_effect=RuntimeError("engine boom")) + mock_te.stop = MagicMock() + + app_state = AppState( + config=root_config, + approval_store=ApprovalStore(), + persistence=persistence, + ) + + with pytest.raises(RuntimeError, match="engine boom"): + await _safe_startup(persistence, bus, None, mock_te, app_state) + + # Persistence and bus should be cleaned up + assert not persistence.is_connected + assert not bus.is_running + + async def test_shutdown_task_engine_failure_does_not_propagate(self) -> None: + """Task engine stop failure during shutdown is logged, not raised.""" + from unittest.mock import AsyncMock, MagicMock + + from ai_company.api.app import _safe_shutdown + + mock_te = MagicMock() + mock_te.stop = AsyncMock(side_effect=RuntimeError("stop boom")) + + # Should not raise even when task engine stop fails + await _safe_shutdown(None, mock_te, None, None) diff --git a/tests/unit/api/test_state.py b/tests/unit/api/test_state.py index 792907cead..404a13d51a 100644 --- a/tests/unit/api/test_state.py +++ b/tests/unit/api/test_state.py @@ -90,3 +90,47 @@ def test_set_auth_service_twice_raises(self) -> None: state = _make_state(auth_service=svc) with pytest.raises(RuntimeError, match="already configured"): state.set_auth_service(svc) + + +@pytest.mark.unit +class TestAppStateTaskEngine: + """Tests for task_engine property, has_task_engine, set_task_engine.""" + + def test_task_engine_raises_when_none(self) -> None: + state = _make_state(task_engine=None) + with pytest.raises(ServiceUnavailableError): + _ = state.task_engine + + def test_task_engine_returns_when_set(self) -> None: + from unittest.mock import MagicMock + + engine = MagicMock() + state = _make_state(task_engine=engine) + assert state.task_engine is engine + + def test_has_task_engine_false_when_none(self) -> None: + state = _make_state(task_engine=None) + assert state.has_task_engine is 
False + + def test_has_task_engine_true_when_set(self) -> None: + from unittest.mock import MagicMock + + engine = MagicMock() + state = _make_state(task_engine=engine) + assert state.has_task_engine is True + + def test_set_task_engine_succeeds_once(self) -> None: + from unittest.mock import MagicMock + + engine = MagicMock() + state = _make_state() + state.set_task_engine(engine) + assert state.task_engine is engine + + def test_set_task_engine_twice_raises(self) -> None: + from unittest.mock import MagicMock + + engine = MagicMock() + state = _make_state(task_engine=engine) + with pytest.raises(RuntimeError, match="already configured"): + state.set_task_engine(engine) diff --git a/tests/unit/engine/conftest.py b/tests/unit/engine/conftest.py index 3901cd71f4..4b6d6d1e99 100644 --- a/tests/unit/engine/conftest.py +++ b/tests/unit/engine/conftest.py @@ -27,6 +27,8 @@ from ai_company.core.role import Authority, Role from ai_company.core.task import AcceptanceCriterion, Task from ai_company.engine.context import AgentContext +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_config import TaskEngineConfig from ai_company.engine.task_execution import TaskExecution from ai_company.providers.capabilities import ModelCapabilities from ai_company.providers.enums import FinishReason @@ -38,9 +40,10 @@ TokenUsage, ToolDefinition, ) +from tests.unit.engine.task_engine_helpers import FakeMessageBus, FakePersistence if TYPE_CHECKING: - from collections.abc import AsyncIterator + from collections.abc import AsyncGenerator, AsyncIterator from ai_company.core.enums import ConflictEscalation from ai_company.engine.workspace.models import ( @@ -402,3 +405,56 @@ def make_assignment_task(**overrides: object) -> Task: } defaults.update(overrides) return Task(**defaults) # type: ignore[arg-type] + + +# ── TaskEngine fixtures ─────────────────────────────────────── + + +@pytest.fixture +def persistence() -> FakePersistence: + """Provide a fresh 
FakePersistence instance.""" + return FakePersistence() + + +@pytest.fixture +def message_bus() -> FakeMessageBus: + """Provide a fresh FakeMessageBus instance.""" + return FakeMessageBus() + + +@pytest.fixture +def config() -> TaskEngineConfig: + """Provide a TaskEngineConfig with a sensible queue size.""" + return TaskEngineConfig(max_queue_size=100) + + +@pytest.fixture +async def engine( + persistence: FakePersistence, + config: TaskEngineConfig, +) -> AsyncGenerator[TaskEngine]: + """Create and start a TaskEngine, stop on teardown.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + yield eng + await eng.stop(timeout=2.0) + + +@pytest.fixture +async def engine_with_bus( + persistence: FakePersistence, + message_bus: FakeMessageBus, + config: TaskEngineConfig, +) -> AsyncGenerator[TaskEngine]: + """Create and start a TaskEngine with a message bus.""" + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + yield eng + await eng.stop(timeout=2.0) diff --git a/tests/unit/engine/task_engine_helpers.py b/tests/unit/engine/task_engine_helpers.py new file mode 100644 index 0000000000..6837d0be44 --- /dev/null +++ b/tests/unit/engine/task_engine_helpers.py @@ -0,0 +1,102 @@ +"""Shared fakes and helpers for TaskEngine tests.""" + +from typing import TYPE_CHECKING + +from ai_company.core.task import Task # noqa: TC001 +from ai_company.engine.task_engine_models import CreateTaskData + +if TYPE_CHECKING: + from ai_company.core.enums import TaskStatus + + +# ── Fakes ───────────────────────────────────────────────────── + + +class FakeTaskRepository: + """Minimal in-memory task repository for engine tests.""" + + def __init__(self) -> None: + self._tasks: dict[str, Task] = {} + + async def save(self, task: Task) -> None: + self._tasks[task.id] = task + + async def get(self, task_id: str) -> Task | None: + return 
self._tasks.get(task_id) + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + result = list(self._tasks.values()) + if status is not None: + result = [t for t in result if t.status == status] + if assigned_to is not None: + result = [t for t in result if t.assigned_to == assigned_to] + if project is not None: + result = [t for t in result if t.project == project] + return tuple(result) + + async def delete(self, task_id: str) -> bool: + return self._tasks.pop(task_id, None) is not None + + +class FakePersistence: + """Minimal fake persistence backend with only a task repository.""" + + def __init__(self) -> None: + self._tasks = FakeTaskRepository() + + @property + def tasks(self) -> FakeTaskRepository: + return self._tasks + + +class FakeMessageBus: + """Minimal fake message bus that records published messages.""" + + def __init__(self) -> None: + self.published: list[object] = [] + self._running = False + + async def start(self) -> None: + self._running = True + + async def stop(self) -> None: + self._running = False + + @property + def is_running(self) -> bool: + return self._running + + async def publish(self, message: object) -> None: + self.published.append(message) + + +class FailingMessageBus(FakeMessageBus): + """Message bus that always fails on publish.""" + + async def publish(self, message: object) -> None: + msg = "Publish failed" + raise RuntimeError(msg) + + +# ── Helpers ──────────────────────────────────────────────────── + + +def _make_create_data(**overrides: object) -> CreateTaskData: + """Build a CreateTaskData with sensible defaults.""" + from ai_company.core.enums import TaskType + + defaults: dict[str, object] = { + "title": "Test task", + "description": "A test task", + "type": TaskType.DEVELOPMENT, + "project": "test-project", + "created_by": "test-agent", + } + defaults.update(overrides) + return CreateTaskData(**defaults) # 
type: ignore[arg-type] diff --git a/tests/unit/engine/test_agent_engine.py b/tests/unit/engine/test_agent_engine.py index 6ff75b0839..75da927abc 100644 --- a/tests/unit/engine/test_agent_engine.py +++ b/tests/unit/engine/test_agent_engine.py @@ -14,7 +14,7 @@ from ai_company.core.task import Task from ai_company.engine.agent_engine import AgentEngine from ai_company.engine.context import AgentContext -from ai_company.engine.errors import ExecutionStateError +from ai_company.engine.errors import ExecutionStateError, TaskMutationError from ai_company.engine.loop_protocol import ( ExecutionResult, TerminationReason, @@ -931,3 +931,187 @@ async def test_prompt_token_ratio_warning( # noqa: PLR0913 assert "prompt_token_ratio" in warning_events[0] else: assert len(warning_events) == 0 + + +@pytest.mark.unit +class TestReportToTaskEngine: + """Tests for _report_to_task_engine interaction.""" + + async def test_no_task_engine_is_noop( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Without task_engine, run() succeeds and no reporting occurs.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + engine = AgentEngine(provider=provider, task_engine=None) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.is_success is True + + async def test_nonterminal_status_skipped( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Non-terminal task status does not trigger TaskEngine call.""" + # Build a mock loop that returns IN_PROGRESS (non-terminal) + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition( + TaskStatus.IN_PROGRESS, + reason="Engine starting 
execution", + ) + mock_result = ExecutionResult( + context=ctx, + termination_reason=TerminationReason.MAX_TURNS, + turns=( + TurnRecord( + turn_number=1, + input_tokens=10, + output_tokens=5, + cost_usd=0.001, + finish_reason=FinishReason.STOP, + ), + ), + ) + mock_loop = MagicMock() + mock_loop.execute = AsyncMock(return_value=mock_result) + mock_loop.get_loop_type = MagicMock(return_value="react") + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock() + + provider = mock_provider_factory([]) + engine = AgentEngine( + provider=provider, + execution_loop=mock_loop, + task_engine=mock_te, + recovery_strategy=None, + ) + + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + mock_te.transition_task.assert_not_awaited() + + async def test_terminal_status_reported( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """COMPLETED status is reported to TaskEngine.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock() + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + assert result.is_success is True + mock_te.transition_task.assert_awaited_once() + call_args = mock_te.transition_task.call_args + assert call_args.args[1] == TaskStatus.COMPLETED + + async def test_mutation_error_swallowed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """TaskMutationError from TaskEngine is logged and swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock( + 
side_effect=TaskMutationError("rejected"), + ) + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + # Run still succeeds despite task engine failure + assert result.is_success is True + + async def test_unexpected_error_swallowed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """Unexpected Exception from TaskEngine is logged and swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock( + side_effect=RuntimeError("connection lost"), + ) + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + result = await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) + + # Run still succeeds despite task engine failure + assert result.is_success is True + + async def test_memory_error_propagates( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + """MemoryError from TaskEngine is re-raised, not swallowed.""" + response = _make_completion_response() + provider = mock_provider_factory([response]) + + mock_te = MagicMock() + mock_te.transition_task = AsyncMock( + side_effect=MemoryError("out of memory"), + ) + + engine = AgentEngine( + provider=provider, + task_engine=mock_te, + ) + + with pytest.raises(MemoryError, match="out of memory"): + await engine.run( + identity=sample_agent_with_personality, + task=sample_task_with_criteria, + ) diff --git a/tests/unit/engine/test_task_engine_integration.py b/tests/unit/engine/test_task_engine_integration.py new file mode 100644 index 0000000000..2298a9de1b --- /dev/null +++ 
b/tests/unit/engine/test_task_engine_integration.py @@ -0,0 +1,326 @@ +"""Integration tests for TaskEngine: publishing, ordering, queue, versioning, drain.""" + +import asyncio +import contextlib + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.errors import TaskEngineQueueFullError +from ai_company.engine.task_engine import TaskEngine, _MutationEnvelope +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import ( + CreateTaskMutation, + DeleteTaskMutation, + TransitionTaskMutation, + UpdateTaskMutation, +) +from tests.unit.engine.task_engine_helpers import ( + FailingMessageBus, + FakeMessageBus, + FakePersistence, + _make_create_data, +) + +# ── Snapshot publishing ─────────────────────────────────────── + + +@pytest.mark.unit +class TestSnapshotPublishing: + """Tests for event publishing to the message bus.""" + + async def test_snapshot_published_on_create( + self, + engine_with_bus: TaskEngine, + message_bus: FakeMessageBus, + ) -> None: + await engine_with_bus.create_task( + _make_create_data(), + requested_by="alice", + ) + # Yield to event loop so the processing loop completes snapshot publication + await asyncio.sleep(0) + assert len(message_bus.published) == 1 + + async def test_snapshot_publish_failure_does_not_affect_mutation( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + failing_bus = FailingMessageBus() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=failing_bus, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + task = await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + assert task.id.startswith("task-") + + stored = await persistence.tasks.get(task.id) + assert stored is not None + finally: + await eng.stop(timeout=2.0) + + async def test_no_snapshot_when_disabled( + self, + persistence: FakePersistence, + message_bus: FakeMessageBus, + ) 
-> None: + no_snap_config = TaskEngineConfig(publish_snapshots=False) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + message_bus=message_bus, # type: ignore[arg-type] + config=no_snap_config, + ) + eng.start() + try: + await eng.create_task( + _make_create_data(), + requested_by="alice", + ) + await asyncio.sleep(0) + assert len(message_bus.published) == 0 + finally: + await eng.stop(timeout=2.0) + + async def test_pending_mutations_drained_on_stop( + self, + persistence: FakePersistence, + ) -> None: + """Tasks submitted before stop() are processed during drain.""" + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + + # Submit concurrently without awaiting — items enter the queue + create_tasks = [ + asyncio.create_task( + eng.create_task(_make_create_data(), requested_by="alice") + ) + for _ in range(5) + ] + + # Yield to the event loop so the tasks can enqueue their mutations before stop() + await asyncio.sleep(0) + + # Stop while tasks may still be in flight — drain timeout is generous + await eng.stop(timeout=5.0) + + # All futures resolved (drained during stop or completed before stop) + results = await asyncio.gather(*create_tasks) + assert len(results) == 5 + stored = await persistence.tasks.list_tasks() + assert len(stored) == 5 + + +# ── Sequential ordering ────────────────────────────────────── + + +@pytest.mark.unit +class TestSequentialOrdering: + """Tests that mutations are processed sequentially.""" + + async def test_concurrent_submits( + self, + engine: TaskEngine, + ) -> None: + """Multiple concurrent creates all succeed without interleaving.""" + tasks = await asyncio.gather( + *( + engine.create_task( + _make_create_data(title=f"Task {i}"), + requested_by="alice", + ) + for i in range(10) + ), + ) + assert len(tasks) == 10 + ids = {t.id for t in tasks} + assert len(ids) == 10 # all unique + + +# ── Queue backpressure ──────────────────────────────────────── + + +@pytest.mark.unit +class 
TestQueueFull: + """Tests for queue full backpressure.""" + + async def test_queue_full_raises( + self, + persistence: FakePersistence, + ) -> None: + tiny_config = TaskEngineConfig(max_queue_size=1) + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=tiny_config, + ) + # Start the engine but pause the processing loop + eng._running = True + + # First submit fills the queue + mutation1 = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + eng._queue.put_nowait(_MutationEnvelope(mutation=mutation1)) + + # Second submit should fail because queue is full + mutation2 = CreateTaskMutation( + request_id="req-2", + requested_by="alice", + task_data=_make_create_data(), + ) + with pytest.raises(TaskEngineQueueFullError, match="queue is full"): + await eng.submit(mutation2) + + eng._running = False + + +# ── Version tracking ────────────────────────────────────────── + + +@pytest.mark.unit +class TestVersionTracking: + """Tests for the in-memory version counter.""" + + async def test_version_increments( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + r1 = await engine.submit(mutation) + assert r1.version == 1 + + update = UpdateTaskMutation( + request_id="req-2", + requested_by="alice", + task_id=r1.task.id, # type: ignore[union-attr] + updates={"title": "Updated"}, + ) + r2 = await engine.submit(update) + assert r2.version == 2 + + async def test_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # version is 1 after create; expected_version=99 should fail + update = UpdateTaskMutation( + request_id="req-2", + requested_by="alice", + task_id=task.id, + updates={"title": "X"}, + expected_version=99, + ) + result = await engine.submit(update) + assert result.success is False + assert 
"conflict" in (result.error or "").lower() + + async def test_version_reset_on_delete( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + delete = DeleteTaskMutation( + request_id="req-3", + requested_by="alice", + task_id=task.id, + ) + result = await engine.submit(delete) + assert result.version == 0 + + async def test_transition_version_conflict( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + expected_version=99, + ) + result = await engine.submit(mutation) + assert result.success is False + assert "conflict" in (result.error or "").lower() + + +# ── Drain timeout ───────────────────────────────────────────── + + +@pytest.mark.unit +class TestDrainTimeout: + """Verify drain-timeout cleanup resolves outstanding futures.""" + + async def test_drain_timeout_resolves_pending_futures( + self, + persistence: FakePersistence, + ) -> None: + """Futures still in queue are failed when stop() times out.""" + # Block the processing loop with a slow save + block = asyncio.Event() + original_save = persistence.tasks.save + + async def slow_save(task: object) -> None: + await block.wait() + await original_save(task) # type: ignore[arg-type] + + persistence.tasks.save = slow_save # type: ignore[method-assign] + + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + + # Submit a task — it'll block in slow_save, holding the processing loop + blocked_task = asyncio.create_task( + eng.create_task(_make_create_data(), requested_by="alice") + ) + + # Queue a second task directly so it's definitely waiting + mutation2 = CreateTaskMutation( + request_id="req-queued", + requested_by="alice", + 
task_data=_make_create_data(), + ) + envelope = _MutationEnvelope(mutation=mutation2) + # Give the engine a tick to start processing the first task + await asyncio.sleep(0.05) + eng._queue.put_nowait(envelope) + + # Stop with a very short timeout — loop is blocked, so timeout fires + await eng.stop(timeout=0.05) + + # The queued envelope (not yet processed) must be failed + assert envelope.future.done() + result = envelope.future.result() + assert result.success is False + assert result.error_code == "internal" + + # Release the block so slow_save can finish, then cancel the blocked task + # (its future was never set because the processing loop was cancelled) + block.set() + blocked_task.cancel() + with contextlib.suppress(Exception, asyncio.CancelledError): + await blocked_task diff --git a/tests/unit/engine/test_task_engine_lifecycle.py b/tests/unit/engine/test_task_engine_lifecycle.py new file mode 100644 index 0000000000..818b2fa2d7 --- /dev/null +++ b/tests/unit/engine/test_task_engine_lifecycle.py @@ -0,0 +1,109 @@ +"""Lifecycle and config tests for TaskEngine.""" + +import pytest + +from ai_company.engine.errors import TaskEngineNotRunningError +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_config import TaskEngineConfig +from ai_company.engine.task_engine_models import CreateTaskMutation +from tests.unit.engine.task_engine_helpers import FakePersistence, _make_create_data + +# ── Lifecycle tests ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestTaskEngineLifecycle: + """Tests for start/stop lifecycle.""" + + async def test_start_sets_running( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + assert eng.is_running is False + eng.start() + assert eng.is_running is True + await eng.stop(timeout=2.0) # type: ignore[unreachable] + assert eng.is_running is False + + async def test_double_start_raises( + self, + 
persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + with pytest.raises(RuntimeError, match="already running"): + eng.start() + await eng.stop(timeout=2.0) + + async def test_stop_idempotent( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + await eng.stop(timeout=2.0) + await eng.stop(timeout=2.0) # no error + + async def test_restart( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + eng.start() + await eng.stop(timeout=2.0) + eng.start() + assert eng.is_running is True + await eng.stop(timeout=2.0) + + +# ── Submit to stopped engine ────────────────────────────────── + + +@pytest.mark.unit +class TestSubmitToStoppedEngine: + """Submitting to a stopped engine raises TaskEngineNotRunningError.""" + + async def test_submit_raises( + self, + persistence: FakePersistence, + ) -> None: + eng = TaskEngine(persistence=persistence) # type: ignore[arg-type] + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + with pytest.raises(TaskEngineNotRunningError): + await eng.submit(mutation) + + +# ── TaskEngineConfig ────────────────────────────────────────── + + +@pytest.mark.unit +class TestTaskEngineConfig: + """Tests for TaskEngineConfig model.""" + + def test_defaults(self) -> None: + cfg = TaskEngineConfig() + assert cfg.max_queue_size == 1000 + assert cfg.drain_timeout_seconds == 10.0 + assert cfg.publish_snapshots is True + + def test_custom_values(self) -> None: + cfg = TaskEngineConfig( + max_queue_size=500, + drain_timeout_seconds=5.0, + publish_snapshots=False, + ) + assert cfg.max_queue_size == 500 + assert cfg.drain_timeout_seconds == 5.0 + assert cfg.publish_snapshots is False + + def test_frozen(self) -> None: + from pydantic import ValidationError + + cfg = 
TaskEngineConfig() + with pytest.raises(ValidationError): + cfg.max_queue_size = 999 # type: ignore[misc] diff --git a/tests/unit/engine/test_task_engine_models.py b/tests/unit/engine/test_task_engine_models.py new file mode 100644 index 0000000000..ee50199fcc --- /dev/null +++ b/tests/unit/engine/test_task_engine_models.py @@ -0,0 +1,312 @@ +"""Tests for task engine request, response, and event models.""" + +import pytest +from pydantic import ValidationError + +from ai_company.core.enums import Complexity, Priority, TaskStatus, TaskType +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskData, + CreateTaskMutation, + DeleteTaskMutation, + TaskMutationResult, + TaskStateChanged, + TransitionTaskMutation, + UpdateTaskMutation, +) + + +@pytest.mark.unit +class TestCreateTaskData: + """Tests for CreateTaskData model.""" + + def test_minimal_construction(self) -> None: + data = CreateTaskData( + title="Fix bug", + description="Fix the login bug", + type=TaskType.DEVELOPMENT, + project="proj-1", + created_by="alice", + ) + assert data.title == "Fix bug" + assert data.priority == Priority.MEDIUM + assert data.estimated_complexity == Complexity.MEDIUM + assert data.budget_limit == 0.0 + assert data.assigned_to is None + + def test_full_construction(self) -> None: + data = CreateTaskData( + title="Implement feature", + description="Add new dashboard", + type=TaskType.DEVELOPMENT, + priority=Priority.HIGH, + project="proj-2", + created_by="bob", + assigned_to="charlie", + estimated_complexity=Complexity.COMPLEX, + budget_limit=5.0, + ) + assert data.assigned_to == "charlie" + assert data.budget_limit == 5.0 + + def test_blank_title_rejected(self) -> None: + with pytest.raises(ValueError, match="must not be whitespace"): + CreateTaskData( + title=" ", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + + def test_negative_budget_rejected(self) -> None: + with pytest.raises(ValueError, 
match="greater than or equal to 0"): + CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + budget_limit=-1.0, + ) + + def test_frozen(self) -> None: + data = CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + with pytest.raises(ValidationError): + data.title = "changed" # type: ignore[misc] + + +@pytest.mark.unit +class TestCreateTaskMutation: + """Tests for CreateTaskMutation model.""" + + def test_construction(self) -> None: + data = CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=data, + ) + assert mutation.mutation_type == "create" + assert mutation.request_id == "req-1" + assert mutation.requested_by == "alice" + + def test_mutation_type_literal(self) -> None: + data = CreateTaskData( + title="Task", + description="desc", + type=TaskType.DEVELOPMENT, + project="proj", + created_by="alice", + ) + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=data, + ) + assert mutation.mutation_type == "create" + + +@pytest.mark.unit +class TestUpdateTaskMutation: + """Tests for UpdateTaskMutation model.""" + + def test_construction(self) -> None: + mutation = UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={"title": "New title"}, + ) + assert mutation.mutation_type == "update" + assert mutation.task_id == "task-123" + assert mutation.updates == {"title": "New title"} + assert mutation.expected_version is None + + def test_with_expected_version(self) -> None: + mutation = UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={}, + expected_version=3, + ) + assert mutation.expected_version == 3 + + def test_empty_updates(self) -> 
None: + mutation = UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={}, + ) + assert mutation.updates == {} + + def test_expected_version_must_be_positive(self) -> None: + with pytest.raises(ValueError, match="greater than or equal to 1"): + UpdateTaskMutation( + request_id="req-2", + requested_by="bob", + task_id="task-123", + updates={}, + expected_version=0, + ) + + +@pytest.mark.unit +class TestTransitionTaskMutation: + """Tests for TransitionTaskMutation model.""" + + def test_construction(self) -> None: + mutation = TransitionTaskMutation( + request_id="req-3", + requested_by="charlie", + task_id="task-456", + target_status=TaskStatus.IN_PROGRESS, + reason="Starting work", + ) + assert mutation.mutation_type == "transition" + assert mutation.target_status == TaskStatus.IN_PROGRESS + assert mutation.reason == "Starting work" + assert mutation.overrides == {} + + def test_with_overrides(self) -> None: + mutation = TransitionTaskMutation( + request_id="req-3", + requested_by="charlie", + task_id="task-456", + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "dave"}, + ) + assert mutation.overrides == {"assigned_to": "dave"} + + +@pytest.mark.unit +class TestDeleteTaskMutation: + """Tests for DeleteTaskMutation model.""" + + def test_construction(self) -> None: + mutation = DeleteTaskMutation( + request_id="req-4", + requested_by="alice", + task_id="task-789", + ) + assert mutation.mutation_type == "delete" + assert mutation.task_id == "task-789" + + +@pytest.mark.unit +class TestCancelTaskMutation: + """Tests for CancelTaskMutation model.""" + + def test_construction(self) -> None: + mutation = CancelTaskMutation( + request_id="req-5", + requested_by="bob", + task_id="task-abc", + reason="No longer needed", + ) + assert mutation.mutation_type == "cancel" + assert mutation.reason == "No longer needed" + + +@pytest.mark.unit +class TestTaskMutationResult: + """Tests for 
TaskMutationResult model.""" + + def test_success_result(self) -> None: + result = TaskMutationResult( + request_id="req-1", + success=True, + version=1, + ) + assert result.success is True + assert result.task is None + assert result.error is None + + def test_failure_result(self) -> None: + result = TaskMutationResult( + request_id="req-1", + success=False, + error="Not found", + ) + assert result.success is False + assert result.error == "Not found" + assert result.version == 0 + + def test_frozen(self) -> None: + result = TaskMutationResult( + request_id="req-1", + success=True, + version=1, + ) + with pytest.raises(ValidationError): + result.success = False # type: ignore[misc] + + +@pytest.mark.unit +class TestTaskStateChanged: + """Tests for TaskStateChanged event model.""" + + def test_construction(self) -> None: + event = TaskStateChanged( + mutation_type="create", + request_id="req-1", + requested_by="alice", + new_status=TaskStatus.CREATED, + version=1, + ) + assert event.mutation_type == "create" + assert event.previous_status is None + assert event.new_status == TaskStatus.CREATED + assert event.timestamp is not None + + def test_transition_event(self) -> None: + event = TaskStateChanged( + mutation_type="transition", + request_id="req-2", + requested_by="bob", + previous_status=TaskStatus.CREATED, + new_status=TaskStatus.ASSIGNED, + version=2, + ) + assert event.previous_status == TaskStatus.CREATED + assert event.new_status == TaskStatus.ASSIGNED + + def test_delete_event(self) -> None: + event = TaskStateChanged( + mutation_type="delete", + request_id="req-3", + requested_by="charlie", + version=0, + ) + assert event.task is None + assert event.previous_status is None + assert event.new_status is None + + def test_serialization_roundtrip(self) -> None: + event = TaskStateChanged( + mutation_type="create", + request_id="req-1", + requested_by="alice", + new_status=TaskStatus.CREATED, + version=1, + ) + json_str = event.model_dump_json() + restored = 
TaskStateChanged.model_validate_json(json_str) + assert restored.mutation_type == event.mutation_type + assert restored.request_id == event.request_id + assert restored.version == event.version diff --git a/tests/unit/engine/test_task_engine_mutations.py b/tests/unit/engine/test_task_engine_mutations.py new file mode 100644 index 0000000000..4cbb3474fe --- /dev/null +++ b/tests/unit/engine/test_task_engine_mutations.py @@ -0,0 +1,583 @@ +"""CRUD mutation, typed error, and consistency tests for TaskEngine.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.core.task import Task # noqa: TC001 +from ai_company.engine.errors import ( + TaskMutationError, + TaskNotFoundError, + TaskVersionConflictError, +) +from ai_company.engine.task_engine import TaskEngine +from ai_company.engine.task_engine_models import ( + CancelTaskMutation, + CreateTaskMutation, + TransitionTaskMutation, + UpdateTaskMutation, +) +from tests.unit.engine.task_engine_helpers import ( + FakePersistence, + FakeTaskRepository, + _make_create_data, +) + +if TYPE_CHECKING: + from ai_company.engine.task_engine_config import TaskEngineConfig + +# ── Create mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestCreateTask: + """Tests for task creation via TaskEngine.""" + + async def test_create_task( + self, + engine: TaskEngine, + persistence: FakePersistence, + ) -> None: + task = await engine.create_task( + _make_create_data(title="My Task"), + requested_by="alice", + ) + assert task.title == "My Task" + assert task.id.startswith("task-") + assert task.status == TaskStatus.CREATED + + stored = await persistence.tasks.get(task.id) + assert stored is not None + assert stored.title == "My Task" + + async def test_create_returns_version_1( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = 
await engine.submit(mutation) + assert result.success is True + assert result.version == 1 + + async def test_create_with_assignee( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(assigned_to=None), + requested_by="alice", + ) + assert task.assigned_to is None + + +# ── Update mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestUpdateTask: + """Tests for task update via TaskEngine.""" + + async def test_update_fields( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(title="Original"), + requested_by="alice", + ) + updated = await engine.update_task( + task.id, + {"title": "Updated"}, + requested_by="alice", + ) + assert updated.title == "Updated" + assert updated.id == task.id + + async def test_update_empty_no_op( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + result = await engine.update_task( + task.id, + {}, + requested_by="alice", + ) + assert result.title == task.title + + async def test_update_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskMutationError, match="not found"): + await engine.update_task( + "task-nonexistent", + {"title": "X"}, + requested_by="alice", + ) + + +# ── Transition mutation ─────────────────────────────────────── + + +@pytest.mark.unit +class TestTransitionTask: + """Tests for task status transitions via TaskEngine.""" + + async def test_valid_transition( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + assigned, _ = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + assert assigned.status == TaskStatus.ASSIGNED + assert assigned.assigned_to == "bob" + + async def test_invalid_transition( + self, + engine: TaskEngine, + ) -> 
None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskMutationError): + await engine.transition_task( + task.id, + TaskStatus.COMPLETED, + requested_by="alice", + reason="Skip to done", + assigned_to="bob", + ) + + async def test_transition_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskMutationError, match="not found"): + await engine.transition_task( + "task-nonexistent", + TaskStatus.ASSIGNED, + requested_by="alice", + reason="test", + ) + + +# ── Delete mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestDeleteTask: + """Tests for task deletion via TaskEngine.""" + + async def test_delete_task( + self, + engine: TaskEngine, + persistence: FakePersistence, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + deleted = await engine.delete_task(task.id, requested_by="alice") + assert deleted is True + + stored = await persistence.tasks.get(task.id) + assert stored is None + + async def test_delete_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskMutationError, match="not found"): + await engine.delete_task( + "task-nonexistent", + requested_by="alice", + ) + + +# ── Cancel mutation ─────────────────────────────────────────── + + +@pytest.mark.unit +class TestCancelTask: + """Tests for task cancellation via TaskEngine.""" + + async def test_cancel_assigned_task( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + assigned, _ = await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + cancelled = await engine.cancel_task( + assigned.id, + requested_by="alice", + reason="No longer needed", + ) + assert cancelled.status == TaskStatus.CANCELLED + + async def test_cancel_from_created_fails( + self, + engine: 
TaskEngine, + ) -> None: + """CREATED -> CANCELLED is not a valid transition.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskMutationError): + await engine.cancel_task( + task.id, + requested_by="alice", + reason="Oops", + ) + + +# ── Read-through ────────────────────────────────────────────── + + +@pytest.mark.unit +class TestReadThrough: + """Tests for read-through methods that bypass the queue.""" + + async def test_get_task( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(title="Findme"), + requested_by="alice", + ) + found = await engine.get_task(task.id) + assert found is not None + assert found.title == "Findme" + + async def test_get_task_not_found( + self, + engine: TaskEngine, + ) -> None: + result = await engine.get_task("task-nonexistent") + assert result is None + + async def test_list_tasks( + self, + engine: TaskEngine, + ) -> None: + await engine.create_task( + _make_create_data(project="proj-a"), + requested_by="alice", + ) + await engine.create_task( + _make_create_data(project="proj-b"), + requested_by="alice", + ) + all_tasks = await engine.list_tasks() + assert len(all_tasks) == 2 + + filtered = await engine.list_tasks(project="proj-a") + assert len(filtered) == 1 + + async def test_list_tasks_by_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + + created = await engine.list_tasks(status=TaskStatus.CREATED) + assigned = await engine.list_tasks(status=TaskStatus.ASSIGNED) + assert len(created) == 0 + assert len(assigned) == 1 + + +# ── Cancel not found ───────────────────────────────────────── + + +@pytest.mark.unit +class TestCancelNotFound: + """Cancel mutation on a non-existent task.""" + + async def 
test_cancel_not_found( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError, match="not found"): + await engine.cancel_task( + "task-nonexistent", + requested_by="alice", + reason="test", + ) + + +# ── Previous status in results ──────────────────────────────── + + +@pytest.mark.unit +class TestPreviousStatus: + """Verify previous_status is populated in mutation results.""" + + async def test_create_has_no_previous_status( + self, + engine: TaskEngine, + ) -> None: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status is None + + async def test_transition_has_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + mutation = TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + target_status=TaskStatus.ASSIGNED, + reason="Assigning", + overrides={"assigned_to": "bob"}, + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status == TaskStatus.CREATED + + async def test_cancel_has_previous_status( + self, + engine: TaskEngine, + ) -> None: + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + # First move to ASSIGNED so cancel is valid + await engine.transition_task( + task.id, + TaskStatus.ASSIGNED, + requested_by="alice", + reason="Assigning", + assigned_to="bob", + ) + mutation = CancelTaskMutation( + request_id="req-1", + requested_by="alice", + task_id=task.id, + reason="No longer needed", + ) + result = await engine.submit(mutation) + assert result.success is True + assert result.previous_status == TaskStatus.ASSIGNED + + +# ── Immutable field rejection ───────────────────────────────── + + +@pytest.mark.unit +class TestImmutableFieldRejection: + 
"""UpdateTaskMutation and TransitionTaskMutation reject immutable fields.""" + + def test_update_rejects_status(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"status": "completed"}, + ) + + def test_update_rejects_id(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + UpdateTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + updates={"id": "new-id"}, + ) + + def test_transition_rejects_id_override(self) -> None: + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="immutable"): + TransitionTaskMutation( + request_id="req-1", + requested_by="alice", + task_id="task-1", + target_status=TaskStatus.ASSIGNED, + reason="test", + overrides={"id": "new-id"}, + ) + + +# ── Typed error propagation ────────────────────────────────── + + +@pytest.mark.unit +class TestTypedErrors: + """Convenience methods raise typed errors.""" + + async def test_update_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.update_task( + "task-nonexistent", + {"title": "X"}, + requested_by="alice", + ) + + async def test_delete_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.delete_task( + "task-nonexistent", + requested_by="alice", + ) + + async def test_transition_not_found_raises_typed( + self, + engine: TaskEngine, + ) -> None: + with pytest.raises(TaskNotFoundError): + await engine.transition_task( + "task-nonexistent", + TaskStatus.ASSIGNED, + requested_by="alice", + reason="test", + ) + + async def test_update_version_conflict_raises_typed( + self, + engine: TaskEngine, + ) -> None: + """Version conflict via convenience method raises 
TaskVersionConflictError.""" + task = await engine.create_task( + _make_create_data(), + requested_by="alice", + ) + with pytest.raises(TaskVersionConflictError, match="conflict"): + await engine.update_task( + task.id, + {"title": "changed"}, + requested_by="alice", + expected_version=99, + ) + + +# ── Error propagation ──────────────────────────────────────── + + +@pytest.mark.unit +class TestErrorPropagation: + """Tests for error propagation via futures.""" + + async def test_persistence_error_returns_failure( + self, + persistence: FakePersistence, + config: TaskEngineConfig, + ) -> None: + """Persistence errors during mutation are captured in the result.""" + + class FailingSaveRepo(FakeTaskRepository): + async def save(self, task: Task) -> None: + msg = "Disk full" + raise OSError(msg) + + persistence._tasks = FailingSaveRepo() + eng = TaskEngine( + persistence=persistence, # type: ignore[arg-type] + config=config, + ) + eng.start() + try: + mutation = CreateTaskMutation( + request_id="req-1", + requested_by="alice", + task_data=_make_create_data(), + ) + result = await eng.submit(mutation) + assert result.success is False + assert result.error == "Internal error processing mutation" + finally: + await eng.stop(timeout=2.0) + + +# ── TaskMutationResult consistency ──────────────────────────── + + +@pytest.mark.unit +class TestMutationResultConsistency: + """Verify _check_consistency validator on TaskMutationResult.""" + + def test_success_with_error_rejected(self) -> None: + """Successful result must not carry an error.""" + from pydantic import ValidationError + + from ai_company.engine.task_engine_models import TaskMutationResult + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=True, + error="oops", + ) + + def test_failure_without_error_rejected(self) -> None: + """Failed result must carry an error description.""" + from pydantic import ValidationError + + from ai_company.engine.task_engine_models 
import TaskMutationResult + + with pytest.raises(ValidationError, match="error"): + TaskMutationResult( + request_id="r", + success=False, + ) diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 9e51eb994d..9d4238b3dd 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -210,6 +210,7 @@ def test_all_domain_modules_discovered(self) -> None: "security", "task", "task_assignment", + "task_engine", "task_routing", "template", "timeout",