Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ src/ai_company/
communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol
config/ # YAML company config loading and validation
core/ # Shared domain models, base classes, and resilience config (RetryConfig, RateLimiterConfig)
engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation
engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination error classification, and prompt policy validation
hr/ # HR engine: hiring, firing, onboarding, offboarding, agent registry, performance tracking (task metrics, collaboration scoring, trend detection), promotion/demotion (criteria evaluation, approval strategies, model mapping)
memory/ # Persistent agent memory (Mem0 initial, custom stack future — see Decision Log), retrieval pipeline (ranking, injection, context formatting, non-inferable filtering), shared org memory (org/), consolidation/archival (consolidation/)
persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (see Memory & Persistence design page)
Expand Down Expand Up @@ -127,7 +127,7 @@ src/ai_company/
- **Every module** with business logic MUST have: `from ai_company.observability import get_logger` then `logger = get_logger(__name__)`
- **Never** use `import logging` / `logging.getLogger()` / `print()` in application code
- **Variable name**: always `logger` (not `_logger`, not `log`)
- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`). Import directly: `from ai_company.observability.events.<domain> import EVENT_CONSTANT`
- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`, `TASK_ENGINE_STARTED` from `events.task_engine`). Import directly: `from ai_company.observability.events.<domain> import EVENT_CONSTANT`
- **Structured kwargs**: always `logger.info(EVENT, key=value)` — never `logger.info("msg %s", val)`
- **All error paths** must log at WARNING or ERROR with context before raising
- **All state transitions** must log at INFO
Expand Down
2 changes: 1 addition & 1 deletion docs/architecture/tech-stack.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ These conventions are used throughout the codebase. For full details on each, se
| **Cost tiers and quota tracking** | Adopted | Configurable `CostTierDefinition` with merge/override semantics. `QuotaTracker` enforces per-provider request/token quotas with window-based rotation. |
| **Shared org memory** | Adopted | `OrgMemoryBackend` protocol with `HybridPromptRetrievalBackend`. Seniority-based write access control. Core policies in system prompts; extended facts retrieved on demand. |
| **Memory consolidation** | Adopted | `ConsolidationStrategy` protocol with deduplication + summarization. `RetentionEnforcer` for age-based cleanup. `ArchivalStore` for cold storage. |
| **State coordination** | Planned | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_copy(update=...)` sequentially and publishes snapshots. |
| **State coordination** | Adopted | Centralized single-writer `TaskEngine` with `asyncio.Queue`. Agents submit requests; engine applies `model_copy(update=...)` sequentially and publishes snapshots. |
| **Workspace isolation** | Adopted | Pluggable `WorkspaceIsolationStrategy` protocol. Default: git worktrees with sequential merge on completion. |
| **Graceful shutdown** | Adopted | Pluggable `ShutdownStrategy` protocol with cooperative 30-second timeout. Force-cancel after timeout with `INTERRUPTED` status. |
| **Template inheritance** | Adopted | `extends` field triggers parent resolution at render time with deep merge by field type. Circular chain detection included. |
Expand Down
59 changes: 58 additions & 1 deletion docs/design/engine.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,63 @@ exceptions on failure; scoring-based strategies return

---

## TaskEngine — Centralized State Coordination

All task state mutations flow through a single-writer `TaskEngine` that owns the
authoritative task state. This eliminates race conditions when multiple agents
attempt concurrent transitions on the same task.

### Architecture

```text
Agent / API ──submit()──▶ asyncio.Queue ──▶ _processing_loop ──▶ Persistence
├──▶ Version tracking (optimistic concurrency)
└──▶ Snapshot publishing (MessageBus)
```

- **Single writer**: A background `asyncio.Task` consumes `TaskMutation`
requests sequentially from an `asyncio.Queue`.
- **Immutable updates**: Each mutation calls `model_copy(update=...)` on
frozen `Task` models — the original is never mutated.
- **Optimistic concurrency**: In-memory version counters per task.
Callers can pass `expected_version` to detect stale writes; on mismatch
the engine returns a failed `TaskMutationResult` with
`error_code="version_conflict"`. Convenience methods raise
`TaskVersionConflictError`.
- **Read-through**: `get_task()` and `list_tasks()` bypass the queue and
read directly from persistence — safe because TaskEngine is the sole writer.
- **Snapshot publishing**: On success, a `TaskStateChanged` event is published
to the message bus for downstream consumers (WebSocket bridge, audit, etc.).

### Mutation Types

| Mutation | Description |
|----------|-------------|
| `CreateTaskMutation` | Generates a unique ID, persists, and returns the new task. |
| `UpdateTaskMutation` | Applies field updates with immutable-field rejection (`id`, `status`, `created_by`) and re-validates via `model_validate`. |
| `TransitionTaskMutation` | Validates status transition via `Task.with_transition()`, supports field overrides. |
| `DeleteTaskMutation` | Removes from persistence and clears version tracking. |
| `CancelTaskMutation` | Shortcut for transition to `CANCELLED`. |

### Error Handling

- **Typed errors**: `TaskNotFoundError` and `TaskVersionConflictError` provide
precise failure classification — API controllers catch these directly instead
of parsing error strings.
- **Error sanitization**: Internal exception details (SQL paths, stack traces)
are replaced with a generic message before reaching callers.
- **Queue full**: `TaskEngineQueueFullError` signals backpressure when the
queue is at capacity.

### Lifecycle

- **start()**: Spawns the background processing task.
- **stop()**: Sets `_running = False`, drains the queue within a configurable
timeout, then cancels. Abandoned futures receive a failure result.

---

## Agent Execution Loop

The agent execution loop defines how an agent processes a task from start to
Expand Down Expand Up @@ -346,7 +403,7 @@ async run(
alone when no enforcer is configured.
8. **Delegate to loop** -- calls `ExecutionLoop.execute()` with context,
provider, tool invoker, budget checker, and completion config. If
`timeout_seconds` is set, wraps the call in `asyncio.wait_for`; on expiry
`timeout_seconds` is set, wraps the call in `asyncio.wait`; on expiry
the run returns with `TerminationReason.ERROR` but cost recording and
post-execution processing still occur.
9. **Record costs** -- records accumulated `TokenUsage` to `CostTracker` (if
Expand Down
75 changes: 68 additions & 7 deletions src/ai_company/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ai_company.communication.bus_protocol import MessageBus # noqa: TC001
from ai_company.config.schema import RootConfig
from ai_company.core.approval import ApprovalItem # noqa: TC001
from ai_company.engine.task_engine import TaskEngine # noqa: TC001
from ai_company.observability import get_logger
from ai_company.observability.events.api import (
API_APP_SHUTDOWN,
Expand Down Expand Up @@ -87,7 +88,9 @@ def _on_expire(item: ApprovalItem) -> None:
event.model_dump_json(),
channels=[CHANNEL_APPROVALS],
)
except RuntimeError, OSError:
except MemoryError, RecursionError:
raise
except Exception:
logger.warning(
API_APPROVAL_PUBLISH_FAILED,
approval_id=item.id,
Expand All @@ -102,6 +105,7 @@ def _build_lifecycle(
persistence: PersistenceBackend | None,
message_bus: MessageBus | None,
bridge: MessageBusBridge | None,
task_engine: TaskEngine | None,
app_state: AppState,
) -> tuple[
Sequence[Callable[[], Awaitable[None]]],
Expand All @@ -115,23 +119,49 @@ def _build_lifecycle(

async def on_startup() -> None:
logger.info(API_APP_STARTUP, version=__version__)
await _safe_startup(persistence, message_bus, bridge, app_state)
await _safe_startup(
persistence,
message_bus,
bridge,
task_engine,
app_state,
)

async def on_shutdown() -> None:
logger.info(API_APP_SHUTDOWN, version=__version__)
await _safe_shutdown(bridge, message_bus, persistence)
await _safe_shutdown(bridge, task_engine, message_bus, persistence)

return [on_startup], [on_shutdown]


async def _cleanup_on_failure(
async def _cleanup_on_failure( # noqa: PLR0913
*,
persistence: PersistenceBackend | None,
started_persistence: bool,
message_bus: MessageBus | None,
started_bus: bool,
bridge: MessageBusBridge | None = None,
started_bridge: bool = False,
task_engine: TaskEngine | None = None,
started_task_engine: bool = False,
) -> None:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"""Reverse cleanup of persistence and message bus on startup failure."""
"""Reverse cleanup on startup failure (task engine, bridge, bus, persistence)."""
if started_task_engine and task_engine is not None:
try:
await task_engine.stop()
except Exception:
logger.exception(
API_APP_STARTUP,
error="Cleanup: failed to stop task engine",
)
if started_bridge and bridge is not None:
try:
await bridge.stop()
except Exception:
logger.exception(
API_APP_STARTUP,
error="Cleanup: failed to stop message bus bridge",
)
if started_bus and message_bus is not None:
try:
await message_bus.stop()
Expand Down Expand Up @@ -196,15 +226,18 @@ async def _safe_startup(
persistence: PersistenceBackend | None,
message_bus: MessageBus | None,
bridge: MessageBusBridge | None,
task_engine: TaskEngine | None,
app_state: AppState,
) -> None:
"""Connect persistence, resolve JWT secret, start message bus and bridge.
"""Connect persistence, resolve JWT secret, start bus, bridge, task engine.

Executes in order; on failure, cleans up already-started
components in reverse order before re-raising.
"""
started_bus = False
started_bridge = False
started_persistence = False
started_task_engine = False
try:
if persistence is not None:
try:
Expand Down Expand Up @@ -239,22 +272,38 @@ async def _safe_startup(
error="Failed to start message bus bridge",
)
raise
started_bridge = True
if task_engine is not None:
try:
task_engine.start()
except Exception:
logger.exception(
API_APP_STARTUP,
error="Failed to start task engine",
)
raise
started_task_engine = True
except Exception:
await _cleanup_on_failure(
persistence=persistence,
started_persistence=started_persistence,
message_bus=message_bus,
started_bus=started_bus,
bridge=bridge,
started_bridge=started_bridge,
task_engine=task_engine,
started_task_engine=started_task_engine,
)
raise


async def _safe_shutdown(
bridge: MessageBusBridge | None,
task_engine: TaskEngine | None,
message_bus: MessageBus | None,
persistence: PersistenceBackend | None,
) -> None:
"""Stop bridge, message bus and disconnect persistence."""
"""Stop bridge, task engine, message bus and disconnect persistence."""
if bridge is not None:
try:
await bridge.stop()
Expand All @@ -263,6 +312,14 @@ async def _safe_shutdown(
API_APP_SHUTDOWN,
error="Failed to stop message bus bridge",
)
if task_engine is not None:
try:
await task_engine.stop()
except Exception:
logger.exception(
API_APP_SHUTDOWN,
error="Failed to stop task engine",
)
if message_bus is not None:
try:
await message_bus.stop()
Expand All @@ -289,6 +346,7 @@ def create_app( # noqa: PLR0913
cost_tracker: CostTracker | None = None,
approval_store: ApprovalStore | None = None,
auth_service: AuthService | None = None,
task_engine: TaskEngine | None = None,
) -> Litestar:
"""Create and configure the Litestar application.

Expand All @@ -302,6 +360,7 @@ def create_app( # noqa: PLR0913
cost_tracker: Cost tracking service.
approval_store: Approval queue store.
auth_service: Pre-built auth service (for testing).
task_engine: Centralized task state engine.

Returns:
Configured Litestar application.
Expand Down Expand Up @@ -330,6 +389,7 @@ def create_app( # noqa: PLR0913
cost_tracker=cost_tracker,
approval_store=effective_approval_store,
auth_service=auth_service,
task_engine=task_engine,
startup_time=time.monotonic(),
)

Expand All @@ -347,6 +407,7 @@ def create_app( # noqa: PLR0913
persistence,
message_bus,
bridge,
task_engine,
app_state,
)

Expand Down
Loading
Loading