diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index 277fe3baf6..249c0a8634 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -25,5 +25,6 @@ jobs: allow-licenses: >- MIT, Apache-2.0, BSD-2-Clause, BSD-3-Clause, ISC, MPL-2.0, PSF-2.0, Unlicense, 0BSD, - CC0-1.0, Python-2.0 + CC0-1.0, Python-2.0, + LicenseRef-scancode-free-unknown # aiosqlite 0.21.0 — MIT per classifiers, scancode misdetects comment-summary-in-pr: always diff --git a/CLAUDE.md b/CLAUDE.md index cd321ebf3c..438a4a6b9c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -51,6 +51,7 @@ src/ai_company/ core/ # Shared domain models and base classes engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, and workspace isolation memory/ # Persistent agent memory (Mem0 initial, custom stack future — ADR-001) + persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (§7.5) observability/ # Structured logging, correlation tracking, log sinks providers/ # LLM provider abstraction (LiteLLM adapter) security/ # SecOps agent, approval gates, audit diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 146c986579..b72e443f4f 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -12,7 +12,7 @@ 4. [Company Structure](#4-company-structure) 5. [Communication Architecture](#5-communication-architecture) — 5.6 Conflict Resolution, 5.7 Meeting Protocol 6. [Task & Workflow Engine](#6-task--workflow-engine) — 6.5 Execution Loop, 6.6 Crash Recovery, **6.7 Graceful Shutdown**, **6.8 Workspace Isolation**, **6.9 Task Decomposability & Coordination Topology** -7. [Memory & Persistence](#7-memory--persistence) — 7.4 Shared Org Memory (Research Directions) +7. [Memory & Persistence](#7-memory--persistence) — 7.4 Shared Org Memory (Research Directions), **7.5 Operational Data Persistence** 8. [HR & Workforce Management](#8-hr--workforce-management) 9. [Model Provider Layer](#9-model-provider-layer) 10. [Cost & Budget Management](#10-cost--budget-management) @@ -81,7 +81,7 @@ The MVP validates the core hypothesis: **a single agent can complete a real task > **Implementation snapshot (2026-03-08):** > - **Done:** M0–M4 (tooling, config/core, providers, single-agent engine, multi-agent orchestration). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). -> - **In progress:** M5 — memory layer implementation, persistence, budget enforcement. +> - **In progress:** M5 — memory layer implementation, budget enforcement. Persistence backend (§7.5) completed. > - **Not started (mostly placeholders):** M6 API/CLI surface, M7 security + approval system. ### 1.5 Configuration Philosophy @@ -1332,6 +1332,143 @@ org_memory: > **Extensibility:** All backends implement the `OrgMemoryBackend` protocol (`query(context) → list[OrgFact]`, `write(fact, author)`, `list_policies()`). The MVP ships with Backend 1; Backends 2 and 3 are research directions that may be explored if the default approach proves insufficient. The selected memory layer backend Mem0 (ADR-001) provides optional graph memory via Neo4j/FalkorDB, which could reduce implementation effort for Backends 2-3. > **Write access control:** Core policies are human-only. ADRs and procedures can be written by senior+ agents. All writes are versioned and auditable. This prevents agents from corrupting shared organizational knowledge while allowing senior agents to document decisions. +### 7.5 Operational Data Persistence + +Agent memory (§7.1–7.4) is handled by the `MemoryBackend` protocol (Mem0 initial, custom stack future — ADR-001). **Operational data** — tasks, cost records, messages, audit logs — is a separate concern managed by a pluggable `PersistenceBackend` protocol. Application code depends only on repository protocols; the storage engine is an implementation detail swappable via config. + +```text +┌──────────────────────────────────────────────────────────────────┐ +│ Application Code │ +│ engine/ budget/ communication/ security/ │ +│ │ │ │ │ │ +│ ▼ ▼ ▼ ▼ │ +│ ┌──────┐ ┌──────┐ ┌──────────┐ ┌──────────┐ │ +│ │ Task │ │ Cost │ │ Message │ │ Audit │ ← Repository │ +│ │ Repo │ │ Repo │ │ Repo │ │ Repo │ Protocols │ +│ └──┬───┘ └──┬───┘ └────┬─────┘ └────┬─────┘ │ +│ └────────┴──────────┴────────────┘ │ +│ │ │ +│ ┌───────────────────┴───────────────────────────────────────┐ │ +│ │ PersistenceBackend (protocol) │ │ +│ │ connect() · disconnect() · health_check() · migrate() │ │ +│ └───────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ┌───────────────────┴───────────────────────────────────────┐ │ +│ │ SQLitePersistenceBackend (initial) │ │ +│ │ PostgresPersistenceBackend (future) │ │ +│ │ MariaDBPersistenceBackend (future) │ │ +│ └───────────────────────────────────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────────┘ +``` + +#### Protocol Design + +```python +@runtime_checkable +class PersistenceBackend(Protocol): + """Lifecycle management for operational data storage.""" + + async def connect(self) -> None: ... + async def disconnect(self) -> None: ... + async def health_check(self) -> bool: ... + async def migrate(self) -> None: ... + + @property + def is_connected(self) -> bool: ... + @property + def backend_name(self) -> str: ... + + @property + def tasks(self) -> TaskRepository: ... + @property + def cost_records(self) -> CostRecordRepository: ... + @property + def messages(self) -> MessageRepository: ... +``` + +Each entity type has its own repository protocol: + +```python +@runtime_checkable +class TaskRepository(Protocol): + """CRUD + query interface for Task persistence.""" + + async def save(self, task: Task) -> None: ... + async def get(self, task_id: str) -> Task | None: ... + async def list_tasks(self, *, status: TaskStatus | None = None, assigned_to: str | None = None, project: str | None = None) -> tuple[Task, ...]: ... + async def delete(self, task_id: str) -> bool: ... + +@runtime_checkable +class CostRecordRepository(Protocol): + """CRUD + aggregation interface for CostRecord persistence.""" + + async def save(self, record: CostRecord) -> None: ... + async def query(self, *, agent_id: str | None = None, task_id: str | None = None) -> tuple[CostRecord, ...]: ... + async def aggregate(self, *, agent_id: str | None = None) -> float: ... + +@runtime_checkable +class MessageRepository(Protocol): + """CRUD + query interface for Message persistence.""" + + async def save(self, message: Message) -> None: ... + async def get_history(self, channel: str, *, limit: int | None = None) -> tuple[Message, ...]: ... +``` + +#### Configuration + +```yaml +persistence: + backend: "sqlite" # sqlite, postgresql, mariadb (future) + sqlite: + path: "/data/ai-company.db" # database file path (mounted volume in Docker) + wal_mode: true # WAL for concurrent read performance + journal_size_limit: 67108864 # 64 MB WAL journal limit + # postgresql: # future + # url: "postgresql://user:pass@host:5432/ai_company" + # pool_size: 10 + # mariadb: # future + # url: "mariadb://user:pass@host:3306/ai_company" + # pool_size: 10 +``` + +#### Entities Persisted + +| Entity | Source Module | Repository | Key Queries | +|--------|-------------|------------|-------------| +| `Task` | `core/task.py` | `TaskRepository` | by status, by assignee, by project | +| `CostRecord` | `budget/cost_record.py` | `CostRecordRepository` | by agent, by task, aggregations | +| `Message` | `communication/message.py` | `MessageRepository` | by channel | +| Audit entries (planned — M7) | `security/` | `AuditRepository` (planned) | by agent, by action type, time range | + +#### Migration Strategy + +- Migrations run programmatically at startup via `PersistenceBackend.migrate()` +- Initial migration creates all tables +- Versioned migrations implemented per-backend (e.g. `persistence/sqlite/migrations.py` for SQLite) +- SQLite uses `user_version` pragma for version tracking; PostgreSQL/MariaDB use a migrations table + +#### Key Principles + +- **App code never imports a concrete backend** — only repository protocols +- **Adding a new backend** requires implementing `PersistenceBackend` + all repository protocols — no changes to consumers +- **Same entity models everywhere** — repositories accept and return the existing frozen Pydantic models (Task, CostRecord, Message), no ORM models or data transfer objects +- **Async throughout** — all repository methods are async, matching the project's concurrency model + +#### Multi-Tenancy + +Each company gets its own database. The `PersistenceConfig` embedded in a company's `RootConfig` specifies the backend type and connection details (e.g. a unique SQLite file path or PostgreSQL database URL). The `create_backend(config)` factory returns an isolated `PersistenceBackend` instance per company — no shared state, no cross-company data leakage. + +```python +# One database per company — configured in each company's YAML +company_a_backend = create_backend(company_a_config.persistence) +company_b_backend = create_backend(company_b_config.persistence) +# Each backend has independent lifecycle: connect → migrate → use → disconnect +``` + +#### Future: Runtime Backend Switching + +Runtime backend switching (e.g. migrating a company from SQLite to PostgreSQL during operation) is a planned future capability. The protocol-based design already supports this — the engine would disconnect the current backend, connect a new one with different config, and migrate. Implementation details (data migration tooling, zero-downtime switchover, connection draining) are deferred to the PostgreSQL backend milestone. + --- ## 8. HR & Workforce Management @@ -2339,7 +2476,7 @@ Run: ai-company start acme-corp | **Agent Memory** | Mem0 (Qdrant + SQLite) → custom (Neo4j + Qdrant) | Mem0 in-process as initial backend behind pluggable `MemoryBackend` protocol ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Qdrant embedded + SQLite for persistence. Custom stack (Neo4j + Qdrant external) as future upgrade. Config-driven backend selection | | **Message Bus** | Internal (async queues) → Redis | Start with Python asyncio queues, upgrade to Redis for multi-process/distributed | | **Task Queue** | Internal → Celery/Redis | Start simple, scale with Celery when needed | -| **Database** | SQLite → PostgreSQL | Start lightweight, migrate to Postgres for production/multi-user | +| **Database** | SQLite (aiosqlite) → PostgreSQL / MariaDB | Pluggable `PersistenceBackend` protocol (§7.5). SQLite ships first via aiosqlite async driver. PostgreSQL, MariaDB as future backends — swap via config, no app code changes | | **Web UI** | Vue 3 + Vite | Modern, fast, good ecosystem. Simpler than React for dashboards | | **Real-time** | WebSocket (FastAPI native) | Real-time agent activity, task updates, chat feed | | **Containerization** | Docker + Docker Compose | Isolated code execution, reproducible environments | @@ -2492,6 +2629,18 @@ ai-company/ │ │ ├── retrieval.py # Memory retrieval & ranking (M5) │ │ ├── consolidation.py # Memory compression over time (M5) │ │ └── shared.py # Shared knowledge base (M5) +│ ├── persistence/ # Operational data persistence (§7.5) +│ │ ├── __init__.py # Package exports +│ │ ├── protocol.py # PersistenceBackend protocol (M5) +│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository (M5); AuditRepository planned (M7) +│ │ ├── config.py # PersistenceConfig model (M5) +│ │ ├── errors.py # Persistence error hierarchy (M5) +│ │ ├── factory.py # create_backend() factory (M5) +│ │ └── sqlite/ # SQLite backend (M5, initial) +│ │ ├── __init__.py # Package exports +│ │ ├── backend.py # SQLitePersistenceBackend +│ │ ├── repositories.py # SQLite repository implementations +│ │ └── migrations.py # Schema migrations (user_version pragma) │ ├── observability/ # Structured logging & correlation │ │ ├── __init__.py # get_logger() entry point │ │ ├── _logger.py # Logger configuration @@ -2512,6 +2661,7 @@ ai-company/ │ │ │ ├── git.py # GIT_* constants │ │ │ ├── meeting.py # MEETING_* constants │ │ │ ├── parallel.py # PARALLEL_* constants +│ │ │ ├── persistence.py # PERSISTENCE_* constants │ │ │ ├── personality.py # PERSONALITY_* constants │ │ │ ├── prompt.py # PROMPT_* constants │ │ │ ├── provider.py # PROVIDER_* constants @@ -2650,6 +2800,7 @@ ai-company/ | Config | YAML + Pydantic | JSON, TOML, Python dicts | Human-friendly, strict validation, good IDE support | | CLI | Typer | Click, argparse, Fire | Built on Click, auto-completion, type hints | | Web UI | Vue 3 | React, Svelte, HTMX | Simpler than React for dashboards, good with FastAPI | +| Persistence | Pluggable protocol + repository protocols | ORM (SQLAlchemy), raw SQL, hybrid | Same frozen Pydantic models in and out (no DTOs), async throughout, backend-swappable via config. Repository protocols decouple app code from storage engine. See §7.5 | | Sandboxing | Layered: subprocess + Docker | Docker-only, subprocess-only, WASM | Risk-proportionate: fast subprocess for file/git, Docker isolation for code execution. Pluggable `SandboxBackend` protocol enables K8s migration later | ### 15.5 Engineering Conventions diff --git a/README.md b/README.md index 7fe72ce0f2..3570a76166 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Multi-Agent Core (M4)** - Message bus, delegation with loop prevention, conflict resolution, meeting protocols - **Task Intelligence (M4)** - Task decomposition, routing, assignment strategies, workspace isolation via git worktrees - **Templates** - Built-in templates, inheritance/merge, rendering, personality presets +- **Persistence Layer (M5)** - Pluggable `PersistenceBackend` protocol with SQLite backend (aiosqlite), repository protocols, schema migrations ### Not implemented yet (planned milestones) @@ -43,7 +44,7 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Mem0** for agent memory (initial backend; custom stack future — see [ADR-001](docs/decisions/ADR-001-memory-layer.md)) - **MCP** for tool integration (planned) - **Vue 3** for web dashboard (planned) -- **SQLite** → PostgreSQL for data persistence (planned) +- **SQLite** (aiosqlite) → PostgreSQL for operational data persistence ## System Requirements diff --git a/pyproject.toml b/pyproject.toml index ff9c2e9721..c1c4507bc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "aiosqlite==0.21.0", "jinja2==3.1.6", "jsonschema==4.26.0", "litellm==1.82.0", @@ -155,6 +156,10 @@ ignore_missing_imports = true module = "jsonschema.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "aiosqlite.*" +ignore_missing_imports = true + [tool.pydantic-mypy] init_forbid_extra = true init_typed = true diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index 0cced78a43..04973883b1 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -30,4 +30,5 @@ def default_config_dict() -> dict[str, Any]: "escalation_paths": [], "coordination_metrics": {}, "task_assignment": {}, + "persistence": {}, } diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index 2c90dea800..9288a60df4 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -20,6 +20,7 @@ from ai_company.observability import get_logger from ai_company.observability.config import LogConfig # noqa: TC001 from ai_company.observability.events.config import CONFIG_VALIDATION_FAILED +from ai_company.persistence.config import PersistenceConfig logger = get_logger(__name__) @@ -458,6 +459,7 @@ class RootConfig(BaseModel): escalation_paths: Cross-department escalation paths. coordination_metrics: Coordination metrics configuration. task_assignment: Task assignment configuration. + persistence: Persistence backend configuration. """ model_config = ConfigDict(frozen=True) @@ -525,6 +527,10 @@ class RootConfig(BaseModel): default_factory=TaskAssignmentConfig, description="Task assignment configuration", ) + persistence: PersistenceConfig = Field( + default_factory=PersistenceConfig, + description="Persistence backend configuration", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py new file mode 100644 index 0000000000..598e57ec3e --- /dev/null +++ b/src/ai_company/observability/events/persistence.py @@ -0,0 +1,63 @@ +"""Persistence event constants for structured logging. + +Constants follow the ``persistence..`` naming convention +and are passed as the first argument to ``logger.info()``/``logger.debug()`` +calls in the persistence layer. +""" + +from typing import Final + +PERSISTENCE_BACKEND_CONNECTING: Final[str] = "persistence.backend.connecting" +PERSISTENCE_BACKEND_CONNECTED: Final[str] = "persistence.backend.connected" +PERSISTENCE_BACKEND_CONNECTION_FAILED: Final[str] = ( + "persistence.backend.connection_failed" +) +PERSISTENCE_BACKEND_ALREADY_CONNECTED: Final[str] = ( + "persistence.backend.already_connected" +) +PERSISTENCE_BACKEND_DISCONNECTING: Final[str] = "persistence.backend.disconnecting" +PERSISTENCE_BACKEND_DISCONNECTED: Final[str] = "persistence.backend.disconnected" +PERSISTENCE_BACKEND_DISCONNECT_ERROR: Final[str] = ( + "persistence.backend.disconnect_error" +) +PERSISTENCE_BACKEND_HEALTH_CHECK: Final[str] = "persistence.backend.health_check" +PERSISTENCE_BACKEND_CREATED: Final[str] = "persistence.backend.created" +PERSISTENCE_BACKEND_UNKNOWN: Final[str] = "persistence.backend.unknown" +PERSISTENCE_BACKEND_WAL_MODE_FAILED: Final[str] = "persistence.backend.wal_mode_failed" +PERSISTENCE_BACKEND_NOT_CONNECTED: Final[str] = "persistence.backend.not_connected" + +PERSISTENCE_MIGRATION_STARTED: Final[str] = "persistence.migration.started" +PERSISTENCE_MIGRATION_COMPLETED: Final[str] = "persistence.migration.completed" +PERSISTENCE_MIGRATION_SKIPPED: Final[str] = "persistence.migration.skipped" +PERSISTENCE_MIGRATION_FAILED: Final[str] = "persistence.migration.failed" + +PERSISTENCE_TASK_SAVED: Final[str] = "persistence.task.saved" +PERSISTENCE_TASK_SAVE_FAILED: Final[str] = "persistence.task.save_failed" +PERSISTENCE_TASK_FETCHED: Final[str] = "persistence.task.fetched" +PERSISTENCE_TASK_FETCH_FAILED: Final[str] = "persistence.task.fetch_failed" +PERSISTENCE_TASK_LISTED: Final[str] = "persistence.task.listed" +PERSISTENCE_TASK_LIST_FAILED: Final[str] = "persistence.task.list_failed" +PERSISTENCE_TASK_DELETED: Final[str] = "persistence.task.deleted" +PERSISTENCE_TASK_DELETE_FAILED: Final[str] = "persistence.task.delete_failed" + +PERSISTENCE_COST_RECORD_SAVED: Final[str] = "persistence.cost_record.saved" +PERSISTENCE_COST_RECORD_SAVE_FAILED: Final[str] = "persistence.cost_record.save_failed" +PERSISTENCE_COST_RECORD_QUERIED: Final[str] = "persistence.cost_record.queried" +PERSISTENCE_COST_RECORD_QUERY_FAILED: Final[str] = ( + "persistence.cost_record.query_failed" +) +PERSISTENCE_COST_RECORD_AGGREGATED: Final[str] = "persistence.cost_record.aggregated" +PERSISTENCE_COST_RECORD_AGGREGATE_FAILED: Final[str] = ( + "persistence.cost_record.aggregate_failed" +) + +PERSISTENCE_TASK_DESERIALIZE_FAILED: Final[str] = "persistence.task.deserialize_failed" + +PERSISTENCE_MESSAGE_SAVED: Final[str] = "persistence.message.saved" +PERSISTENCE_MESSAGE_SAVE_FAILED: Final[str] = "persistence.message.save_failed" +PERSISTENCE_MESSAGE_DUPLICATE: Final[str] = "persistence.message.duplicate" +PERSISTENCE_MESSAGE_HISTORY_FETCHED: Final[str] = "persistence.message.history_fetched" +PERSISTENCE_MESSAGE_HISTORY_FAILED: Final[str] = "persistence.message.history_failed" +PERSISTENCE_MESSAGE_DESERIALIZE_FAILED: Final[str] = ( + "persistence.message.deserialize_failed" +) diff --git a/src/ai_company/persistence/__init__.py b/src/ai_company/persistence/__init__.py new file mode 100644 index 0000000000..60a844bdf2 --- /dev/null +++ b/src/ai_company/persistence/__init__.py @@ -0,0 +1,39 @@ +"""Pluggable persistence layer for operational data (DESIGN_SPEC §7.5). + +Re-exports the protocol, repository protocols, config models, factory, +and error hierarchy so consumers can import from ``ai_company.persistence`` +directly. +""" + +from ai_company.persistence.config import PersistenceConfig, SQLiteConfig +from ai_company.persistence.errors import ( + DuplicateRecordError, + MigrationError, + PersistenceConnectionError, + PersistenceError, + QueryError, + RecordNotFoundError, +) +from ai_company.persistence.factory import create_backend +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.repositories import ( + CostRecordRepository, + MessageRepository, + TaskRepository, +) + +__all__ = [ + "CostRecordRepository", + "DuplicateRecordError", + "MessageRepository", + "MigrationError", + "PersistenceBackend", + "PersistenceConfig", + "PersistenceConnectionError", + "PersistenceError", + "QueryError", + "RecordNotFoundError", + "SQLiteConfig", + "TaskRepository", + "create_backend", +] diff --git a/src/ai_company/persistence/config.py b/src/ai_company/persistence/config.py new file mode 100644 index 0000000000..61fc6544ae --- /dev/null +++ b/src/ai_company/persistence/config.py @@ -0,0 +1,109 @@ +"""Persistence configuration models. + +Frozen Pydantic models for persistence backend selection and +backend-specific settings. +""" + +from pathlib import PurePosixPath, PureWindowsPath +from typing import ClassVar, Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.config import CONFIG_VALIDATION_FAILED + +logger = get_logger(__name__) + + +class SQLiteConfig(BaseModel): + """SQLite-specific persistence configuration. + + Attributes: + path: Database file path. Use ``":memory:"`` for in-memory + databases (useful for testing). + wal_mode: Whether to enable WAL journal mode for concurrent + read performance. + journal_size_limit: Maximum WAL journal size in bytes + (default 64 MB). + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + path: NotBlankStr = Field( + default="ai-company.db", + description="Database file path", + ) + wal_mode: bool = Field( + default=True, + description="Enable WAL journal mode", + ) + journal_size_limit: int = Field( + default=67_108_864, + ge=0, + description="Maximum WAL journal size in bytes", + ) + + @model_validator(mode="after") + def _reject_traversal(self) -> Self: + """Reject parent-directory traversal to prevent path escapes. + + The special ``:memory:`` identifier is passed through unchanged. + Paths containing ``..`` components are rejected to prevent + path-traversal attacks in multi-tenant configs. Absolute paths + are allowed for operational flexibility. + """ + if self.path == ":memory:": + return self + parts = PureWindowsPath(self.path).parts + PurePosixPath(self.path).parts + if ".." in parts: + msg = "Database path must not contain parent-directory traversal (..)" + logger.warning( + CONFIG_VALIDATION_FAILED, + field="path", + value=self.path, + reason=msg, + ) + raise ValueError(msg) + return self + + +class PersistenceConfig(BaseModel): + """Top-level persistence configuration. + + Attributes: + backend: Backend name — currently only ``"sqlite"`` is + implemented. + sqlite: SQLite-specific settings (used when + ``backend="sqlite"``). + """ + + model_config = ConfigDict(frozen=True) + + _VALID_BACKENDS: ClassVar[frozenset[str]] = frozenset({"sqlite"}) + + backend: NotBlankStr = Field( + default="sqlite", + description="Persistence backend name", + ) + sqlite: SQLiteConfig = Field( + default_factory=SQLiteConfig, + description="SQLite-specific settings", + ) + + @model_validator(mode="after") + def _validate_backend_name(self) -> Self: + """Ensure backend is a known persistence backend.""" + if self.backend not in self._VALID_BACKENDS: + msg = ( + f"Unknown persistence backend {self.backend!r}. " + f"Valid backends: {sorted(self._VALID_BACKENDS)}" + ) + logger.warning( + CONFIG_VALIDATION_FAILED, + field="backend", + value=self.backend, + reason=msg, + ) + raise ValueError(msg) + return self diff --git a/src/ai_company/persistence/errors.py b/src/ai_company/persistence/errors.py new file mode 100644 index 0000000000..27b675364a --- /dev/null +++ b/src/ai_company/persistence/errors.py @@ -0,0 +1,34 @@ +"""Persistence error hierarchy. + +All persistence-related errors inherit from ``PersistenceError`` so +callers can catch the entire family with a single except clause. +""" + + +class PersistenceError(Exception): + """Base exception for all persistence operations.""" + + +class PersistenceConnectionError(PersistenceError): + """Raised when a backend connection cannot be established or is lost.""" + + +class MigrationError(PersistenceError): + """Raised when a database migration fails.""" + + +class RecordNotFoundError(PersistenceError): + """Raised when a requested record does not exist. + + Currently unused — ``TaskRepository.get()`` returns ``None`` + on miss, and other repositories use collection-returning queries. + Reserved for future strict-fetch methods (e.g. ``get_or_raise``). + """ + + +class DuplicateRecordError(PersistenceError): + """Raised when inserting a record that already exists.""" + + +class QueryError(PersistenceError): + """Raised when a query fails due to invalid parameters or backend issues.""" diff --git a/src/ai_company/persistence/factory.py b/src/ai_company/persistence/factory.py new file mode 100644 index 0000000000..791751f892 --- /dev/null +++ b/src/ai_company/persistence/factory.py @@ -0,0 +1,61 @@ +"""Factory for creating persistence backends from configuration. + +Each company gets its own ``PersistenceBackend`` instance, which maps +to its own database. This enables multi-tenancy: one database per +company, selectable via the ``PersistenceConfig`` embedded in each +company's ``RootConfig``. +""" + +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_BACKEND_CREATED, + PERSISTENCE_BACKEND_UNKNOWN, +) +from ai_company.persistence.config import PersistenceConfig # noqa: TC001 +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend + +logger = get_logger(__name__) + + +def create_backend(config: PersistenceConfig) -> PersistenceBackend: + """Create a persistence backend from configuration. + + Factory function that maps ``config.backend`` to the correct + concrete backend class. Each call returns a new, disconnected + backend instance — the caller is responsible for calling + ``connect()`` and ``migrate()``. + + Args: + config: Persistence configuration (includes backend selection + and backend-specific settings). + + Returns: + A new, disconnected backend instance. + + Raises: + PersistenceConnectionError: If the backend name is not + recognized. + + Example:: + + config = PersistenceConfig( + backend="sqlite", + sqlite=SQLiteConfig(path="data/company-a.db"), + ) + backend = create_backend(config) + await backend.connect() + await backend.migrate() + """ + if config.backend == "sqlite": + backend: PersistenceBackend = SQLitePersistenceBackend(config.sqlite) + logger.debug( + PERSISTENCE_BACKEND_CREATED, + backend="sqlite", + path=config.sqlite.path, + ) + return backend + msg = f"Unknown persistence backend: {config.backend!r}" + logger.error(PERSISTENCE_BACKEND_UNKNOWN, backend=config.backend) + raise PersistenceConnectionError(msg) diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py new file mode 100644 index 0000000000..28ac619b84 --- /dev/null +++ b/src/ai_company/persistence/protocol.py @@ -0,0 +1,87 @@ +"""PersistenceBackend protocol — lifecycle + repository access. + +Application code depends on this protocol for storage lifecycle +management. Repository protocols provide entity-level access. +""" + +from typing import Protocol, runtime_checkable + +from ai_company.persistence.repositories import ( + CostRecordRepository, # noqa: TC001 + MessageRepository, # noqa: TC001 + TaskRepository, # noqa: TC001 +) + + +@runtime_checkable +class PersistenceBackend(Protocol): + """Lifecycle management for operational data storage. + + Concrete backends implement this protocol to provide connection + management, health monitoring, schema migrations, and access to + entity-specific repositories. + + Attributes: + is_connected: Whether the backend has an active connection. + backend_name: Human-readable backend identifier. + tasks: Repository for Task persistence. + cost_records: Repository for CostRecord persistence. + messages: Repository for Message persistence. + """ + + async def connect(self) -> None: + """Establish connection to the storage backend. + + Raises: + PersistenceConnectionError: If the connection cannot be + established. + """ + ... + + async def disconnect(self) -> None: + """Close the storage backend connection. + + Safe to call even if not connected. + """ + ... + + async def health_check(self) -> bool: + """Check whether the backend is healthy and responsive. + + Returns: + ``True`` if the backend is reachable and operational. + """ + ... + + async def migrate(self) -> None: + """Run pending schema migrations. + + Raises: + MigrationError: If a migration fails. + """ + ... + + @property + def is_connected(self) -> bool: + """Whether the backend has an active connection.""" + ... + + @property + def backend_name(self) -> str: + """Human-readable backend identifier (e.g. ``"sqlite"``).""" + ... + + @property + def tasks(self) -> TaskRepository: + """Repository for Task persistence.""" + ... + + @property + def cost_records(self) -> CostRecordRepository: + """Repository for CostRecord persistence.""" + ... + + @property + def messages(self) -> MessageRepository: + """Repository for Message persistence.""" + ... diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py new file mode 100644 index 0000000000..18a3092f76 --- /dev/null +++ b/src/ai_company/persistence/repositories.py @@ -0,0 +1,172 @@ +"""Repository protocols for operational data persistence. + +Each entity type has its own protocol so that application code depends +only on abstract interfaces, never on a concrete backend. +""" + +from typing import Protocol, runtime_checkable + +from ai_company.budget.cost_record import CostRecord # noqa: TC001 +from ai_company.communication.message import Message # noqa: TC001 +from ai_company.core.enums import TaskStatus # noqa: TC001 +from ai_company.core.task import Task # noqa: TC001 +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +@runtime_checkable +class TaskRepository(Protocol): + """CRUD + query interface for Task persistence.""" + + async def save(self, task: Task) -> None: + """Persist a task (insert or update). + + Args: + task: The task to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, task_id: NotBlankStr) -> Task | None: + """Retrieve a task by its ID. + + Args: + task_id: The task identifier. + + Returns: + The task, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: NotBlankStr | None = None, + project: NotBlankStr | None = None, + ) -> tuple[Task, ...]: + """List tasks with optional filters. + + Args: + status: Filter by task status. + assigned_to: Filter by assignee agent ID. + project: Filter by project ID. + + Returns: + Matching tasks as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, task_id: NotBlankStr) -> bool: + """Delete a task by ID. + + Args: + task_id: The task identifier. + + Returns: + ``True`` if the task was deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + +@runtime_checkable +class CostRecordRepository(Protocol): + """Append-only persistence + query/aggregation for CostRecord.""" + + async def save(self, record: CostRecord) -> None: + """Persist a cost record (append-only). + + Args: + record: The cost record to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def query( + self, + *, + agent_id: NotBlankStr | None = None, + task_id: NotBlankStr | None = None, + ) -> tuple[CostRecord, ...]: + """Query cost records with optional filters. + + Args: + agent_id: Filter by agent identifier. + task_id: Filter by task identifier. + + Returns: + Matching cost records as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def aggregate( + self, + *, + agent_id: NotBlankStr | None = None, + task_id: NotBlankStr | None = None, + ) -> float: + """Sum total cost_usd, optionally filtered by agent and/or task. + + Args: + agent_id: Filter by agent identifier. + task_id: Filter by task identifier. + + Returns: + Total cost in USD. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + +@runtime_checkable +class MessageRepository(Protocol): + """Write + history query interface for Message persistence.""" + + async def save(self, message: Message) -> None: + """Persist a message. + + Args: + message: The message to persist. + + Raises: + DuplicateRecordError: If a message with the same ID exists. + PersistenceError: If the operation fails. + """ + ... + + async def get_history( + self, + channel: NotBlankStr, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + """Retrieve message history for a channel. + + Args: + channel: Channel name to query. + limit: Maximum number of messages to return (newest first). + + Returns: + Messages ordered by timestamp descending. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/sqlite/__init__.py b/src/ai_company/persistence/sqlite/__init__.py new file mode 100644 index 0000000000..8453cef33f --- /dev/null +++ b/src/ai_company/persistence/sqlite/__init__.py @@ -0,0 +1,21 @@ +"""SQLite persistence backend (DESIGN_SPEC §7.5 — initial backend).""" + +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend +from ai_company.persistence.sqlite.migrations import ( + SCHEMA_VERSION, + run_migrations, +) +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) + +__all__ = [ + "SCHEMA_VERSION", + "SQLiteCostRecordRepository", + "SQLiteMessageRepository", + "SQLitePersistenceBackend", + "SQLiteTaskRepository", + "run_migrations", +] diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py new file mode 100644 index 0000000000..a5c7a28ccb --- /dev/null +++ b/src/ai_company/persistence/sqlite/backend.py @@ -0,0 +1,221 @@ +"""SQLite persistence backend implementation.""" + +import asyncio +import sqlite3 +from typing import TYPE_CHECKING + +import aiosqlite + +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_BACKEND_ALREADY_CONNECTED, + PERSISTENCE_BACKEND_CONNECTED, + PERSISTENCE_BACKEND_CONNECTING, + PERSISTENCE_BACKEND_CONNECTION_FAILED, + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + PERSISTENCE_BACKEND_DISCONNECTED, + PERSISTENCE_BACKEND_DISCONNECTING, + PERSISTENCE_BACKEND_HEALTH_CHECK, + PERSISTENCE_BACKEND_NOT_CONNECTED, + PERSISTENCE_BACKEND_WAL_MODE_FAILED, +) +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.sqlite.migrations import run_migrations +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) + +if TYPE_CHECKING: + from ai_company.persistence.config import SQLiteConfig + +logger = get_logger(__name__) + + +class SQLitePersistenceBackend: + """SQLite implementation of the PersistenceBackend protocol. + + Uses a single ``aiosqlite.Connection`` with WAL mode enabled by + default for file-based databases (in-memory databases do not + support WAL). Configurable via ``SQLiteConfig.wal_mode``. + + Args: + config: SQLite-specific configuration. + """ + + def __init__(self, config: SQLiteConfig) -> None: + self._config = config + self._lifecycle_lock = asyncio.Lock() + self._db: aiosqlite.Connection | None = None + self._tasks: SQLiteTaskRepository | None = None + self._cost_records: SQLiteCostRecordRepository | None = None + self._messages: SQLiteMessageRepository | None = None + + def _clear_state(self) -> None: + """Reset connection and repository references to ``None``.""" + self._db = None + self._tasks = None + self._cost_records = None + self._messages = None + + async def connect(self) -> None: + """Open the SQLite database and configure WAL mode.""" + async with self._lifecycle_lock: + if self._db is not None: + logger.debug(PERSISTENCE_BACKEND_ALREADY_CONNECTED) + return + + logger.info(PERSISTENCE_BACKEND_CONNECTING, path=self._config.path) + try: + self._db = await aiosqlite.connect(self._config.path) + self._db.row_factory = aiosqlite.Row + + if self._config.wal_mode: + cursor = await self._db.execute("PRAGMA journal_mode=WAL") + row = await cursor.fetchone() + actual_mode = row[0] if row else "unknown" + if actual_mode != "wal" and self._config.path != ":memory:": + logger.warning( + PERSISTENCE_BACKEND_WAL_MODE_FAILED, + requested="wal", + actual=actual_mode, + ) + # PRAGMA does not support parameterized queries; + # journal_size_limit is validated as int >= 0 by Pydantic. + limit = int(self._config.journal_size_limit) + await self._db.execute(f"PRAGMA journal_size_limit={limit}") + + self._tasks = SQLiteTaskRepository(self._db) + self._cost_records = SQLiteCostRecordRepository(self._db) + self._messages = SQLiteMessageRepository(self._db) + except (sqlite3.Error, OSError) as exc: + logger.exception( + PERSISTENCE_BACKEND_CONNECTION_FAILED, + path=self._config.path, + error=str(exc), + ) + if self._db is not None: + try: + await self._db.close() + except (sqlite3.Error, OSError) as cleanup_exc: + logger.warning( + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + path=self._config.path, + error=str(cleanup_exc), + error_type=type(cleanup_exc).__name__, + context="cleanup_after_connect_failure", + ) + self._clear_state() + msg = "Failed to connect to persistence backend" + raise PersistenceConnectionError(msg) from exc + + logger.info(PERSISTENCE_BACKEND_CONNECTED, path=self._config.path) + + async def disconnect(self) -> None: + """Close the database connection.""" + async with self._lifecycle_lock: + if self._db is None: + return + + logger.info(PERSISTENCE_BACKEND_DISCONNECTING, path=self._config.path) + try: + await self._db.close() + logger.info( + PERSISTENCE_BACKEND_DISCONNECTED, + path=self._config.path, + ) + except (sqlite3.Error, OSError) as exc: + logger.warning( + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + path=self._config.path, + error=str(exc), + error_type=type(exc).__name__, + ) + finally: + self._clear_state() + + async def health_check(self) -> bool: + """Check database connectivity.""" + if self._db is None: + return False + try: + cursor = await self._db.execute("SELECT 1") + row = await cursor.fetchone() + healthy = row is not None + except (sqlite3.Error, aiosqlite.Error) as exc: + logger.warning( + PERSISTENCE_BACKEND_HEALTH_CHECK, + healthy=False, + error=str(exc), + error_type=type(exc).__name__, + ) + return False + logger.debug(PERSISTENCE_BACKEND_HEALTH_CHECK, healthy=healthy) + return healthy + + async def migrate(self) -> None: + """Run pending schema migrations. + + Raises: + PersistenceConnectionError: If not connected. + MigrationError: If migration fails. + """ + if self._db is None: + msg = "Cannot migrate: not connected" + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) + raise PersistenceConnectionError(msg) + await run_migrations(self._db) + + @property + def is_connected(self) -> bool: + """Whether the backend has an active connection.""" + return self._db is not None + + @property + def backend_name(self) -> str: + """Human-readable backend identifier.""" + return "sqlite" + + def _require_connected[T](self, repo: T | None, name: str) -> T: + """Return *repo* or raise if the backend is not connected. + + Args: + repo: Repository instance (``None`` when disconnected). + name: Repository name for the error message. + + Raises: + PersistenceConnectionError: If *repo* is ``None``. + """ + if repo is None: + msg = f"Not connected — call connect() before accessing {name}" + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) + raise PersistenceConnectionError(msg) + return repo + + @property + def tasks(self) -> SQLiteTaskRepository: + """Repository for Task persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._tasks, "tasks") + + @property + def cost_records(self) -> SQLiteCostRecordRepository: + """Repository for CostRecord persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._cost_records, "cost_records") + + @property + def messages(self) -> SQLiteMessageRepository: + """Repository for Message persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._messages, "messages") diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py new file mode 100644 index 0000000000..6bf6372d32 --- /dev/null +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -0,0 +1,204 @@ +"""SQLite schema migrations using the user_version pragma. + +Each migration is a function that receives a connection and applies +DDL statements. ``run_migrations`` checks the current version and +runs only the migrations that haven't been applied yet. +""" + +import sqlite3 +from collections.abc import Callable, Coroutine, Sequence +from typing import Any + +import aiosqlite + +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_MIGRATION_COMPLETED, + PERSISTENCE_MIGRATION_FAILED, + PERSISTENCE_MIGRATION_SKIPPED, + PERSISTENCE_MIGRATION_STARTED, +) +from ai_company.persistence.errors import MigrationError + +logger = get_logger(__name__) + +# Current schema version — bump when adding new migrations. +SCHEMA_VERSION = 1 + +_V1_STATEMENTS: Sequence[str] = ( + # ── Tasks ───────────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS tasks ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + type TEXT NOT NULL, + priority TEXT NOT NULL DEFAULT 'medium', + project TEXT NOT NULL, + created_by TEXT NOT NULL, + assigned_to TEXT, + status TEXT NOT NULL DEFAULT 'created', + estimated_complexity TEXT NOT NULL DEFAULT 'medium', + budget_limit REAL NOT NULL DEFAULT 0.0, + deadline TEXT, + max_retries INTEGER NOT NULL DEFAULT 1, + parent_task_id TEXT, + task_structure TEXT, + coordination_topology TEXT NOT NULL DEFAULT 'auto', + reviewers TEXT NOT NULL DEFAULT '[]', + dependencies TEXT NOT NULL DEFAULT '[]', + artifacts_expected TEXT NOT NULL DEFAULT '[]', + acceptance_criteria TEXT NOT NULL DEFAULT '[]', + delegation_chain TEXT NOT NULL DEFAULT '[]' +)""", + "CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)", + "CREATE INDEX IF NOT EXISTS idx_tasks_assigned_to ON tasks(assigned_to)", + "CREATE INDEX IF NOT EXISTS idx_tasks_project ON tasks(project)", + # ── Cost records ────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS cost_records ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT NOT NULL, + task_id TEXT NOT NULL, + provider TEXT NOT NULL, + model TEXT NOT NULL, + input_tokens INTEGER NOT NULL, + output_tokens INTEGER NOT NULL, + cost_usd REAL NOT NULL, + timestamp TEXT NOT NULL, + call_category TEXT +)""", + "CREATE INDEX IF NOT EXISTS idx_cost_records_agent_id ON cost_records(agent_id)", + "CREATE INDEX IF NOT EXISTS idx_cost_records_task_id ON cost_records(task_id)", + # ── Messages ────────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + timestamp TEXT NOT NULL, + sender TEXT NOT NULL, + "to" TEXT NOT NULL, + type TEXT NOT NULL, + priority TEXT NOT NULL DEFAULT 'normal', + channel TEXT NOT NULL, + content TEXT NOT NULL, + attachments TEXT NOT NULL DEFAULT '[]', + metadata TEXT NOT NULL DEFAULT '{}' +)""", + "CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel)", + "CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp)", +) + +_MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] + + +async def get_user_version(db: aiosqlite.Connection) -> int: + """Read the current schema version from the SQLite user_version pragma.""" + cursor = await db.execute("PRAGMA user_version") + row = await cursor.fetchone() + return int(row[0]) if row else 0 + + +async def set_user_version(db: aiosqlite.Connection, version: int) -> None: + """Set the schema version via the SQLite user_version pragma. + + Args: + db: An open aiosqlite connection. + version: Non-negative integer schema version. + + Raises: + MigrationError: If *version* is not a valid non-negative integer. + """ + if not isinstance(version, int) or version < 0: + msg = f"Schema version must be a non-negative integer, got {version!r}" + logger.error( + PERSISTENCE_MIGRATION_FAILED, + error=msg, + version=version, + ) + raise MigrationError(msg) + # PRAGMA does not support parameterized queries; version is validated above. + await db.execute(f"PRAGMA user_version = {version}") + + +async def _apply_v1(db: aiosqlite.Connection) -> None: + """Apply schema version 1: create tasks, cost_records, messages.""" + for stmt in _V1_STATEMENTS: + await db.execute(stmt) + + +# Ordered list of (target_version, migration_function) pairs. Each migration +# is applied when the current schema version is below its target version. +_MIGRATIONS: list[tuple[int, _MigrateFn]] = [ + (1, _apply_v1), +] + + +async def run_migrations(db: aiosqlite.Connection) -> None: + """Run pending migrations up to ``SCHEMA_VERSION``. + + .. note:: + + SQLite implicitly commits before each DDL statement, so + multi-statement migrations are **not** fully atomic. All DDL + uses ``IF NOT EXISTS`` guards so that a partial failure + (e.g. disk full after creating some tables) can be recovered + by re-running the migration. + + Args: + db: An open aiosqlite connection. + + Raises: + MigrationError: If any migration step fails. + """ + try: + current = await get_user_version(db) + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to read current schema version" + logger.exception(PERSISTENCE_MIGRATION_FAILED, error=str(exc)) + raise MigrationError(msg) from exc + + if current >= SCHEMA_VERSION: + logger.debug( + PERSISTENCE_MIGRATION_SKIPPED, + current_version=current, + target_version=SCHEMA_VERSION, + ) + return + + logger.info( + PERSISTENCE_MIGRATION_STARTED, + current_version=current, + target_version=SCHEMA_VERSION, + ) + + try: + for target_version, migrate_fn in _MIGRATIONS: + if current < target_version: + await migrate_fn(db) + current = target_version + + await set_user_version(db, SCHEMA_VERSION) + await db.commit() + except (sqlite3.Error, aiosqlite.Error, MigrationError) as exc: + try: + await db.rollback() + except (sqlite3.Error, aiosqlite.Error) as rollback_exc: + logger.error( # noqa: TRY400 + PERSISTENCE_MIGRATION_FAILED, + error=f"Rollback also failed: {rollback_exc}", + original_error=str(exc), + ) + if isinstance(exc, MigrationError): + raise + msg = f"Migration to version {SCHEMA_VERSION} failed" + logger.exception( + PERSISTENCE_MIGRATION_FAILED, + target_version=SCHEMA_VERSION, + error=str(exc), + ) + raise MigrationError(msg) from exc + + logger.info( + PERSISTENCE_MIGRATION_COMPLETED, + version=SCHEMA_VERSION, + ) diff --git a/src/ai_company/persistence/sqlite/repositories.py b/src/ai_company/persistence/sqlite/repositories.py new file mode 100644 index 0000000000..d2dbfb342a --- /dev/null +++ b/src/ai_company/persistence/sqlite/repositories.py @@ -0,0 +1,497 @@ +"""SQLite repository implementations for Task, CostRecord, and Message.""" + +import json +import sqlite3 + +import aiosqlite +from pydantic import BaseModel, ValidationError + +from ai_company.budget.cost_record import CostRecord +from ai_company.communication.message import Message +from ai_company.core.enums import TaskStatus # noqa: TC001 +from ai_company.core.task import Task +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_COST_RECORD_AGGREGATE_FAILED, + PERSISTENCE_COST_RECORD_AGGREGATED, + PERSISTENCE_COST_RECORD_QUERIED, + PERSISTENCE_COST_RECORD_QUERY_FAILED, + PERSISTENCE_COST_RECORD_SAVE_FAILED, + PERSISTENCE_COST_RECORD_SAVED, + PERSISTENCE_MESSAGE_DESERIALIZE_FAILED, + PERSISTENCE_MESSAGE_DUPLICATE, + PERSISTENCE_MESSAGE_HISTORY_FAILED, + PERSISTENCE_MESSAGE_HISTORY_FETCHED, + PERSISTENCE_MESSAGE_SAVE_FAILED, + PERSISTENCE_MESSAGE_SAVED, + PERSISTENCE_TASK_DELETE_FAILED, + PERSISTENCE_TASK_DELETED, + PERSISTENCE_TASK_DESERIALIZE_FAILED, + PERSISTENCE_TASK_FETCH_FAILED, + PERSISTENCE_TASK_FETCHED, + PERSISTENCE_TASK_LIST_FAILED, + PERSISTENCE_TASK_LISTED, + PERSISTENCE_TASK_SAVE_FAILED, + PERSISTENCE_TASK_SAVED, +) +from ai_company.persistence.errors import DuplicateRecordError, QueryError + +logger = get_logger(__name__) + + +def _json_list(items: tuple[object, ...]) -> str: + """Serialize a tuple of Pydantic models or scalars to a JSON array. + + Items must be JSON-serializable or Pydantic models. + Non-serializable items will raise ``TypeError``. + """ + return json.dumps( + [ + item.model_dump(mode="json") if isinstance(item, BaseModel) else item + for item in items + ] + ) + + +class SQLiteTaskRepository: + """SQLite implementation of the TaskRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, task: Task) -> None: + """Persist a task (upsert semantics).""" + try: + params = task.model_dump(mode="json") + # Tuple fields must be stored as JSON strings. + params["reviewers"] = _json_list(task.reviewers) + params["dependencies"] = _json_list(task.dependencies) + params["artifacts_expected"] = _json_list(task.artifacts_expected) + params["acceptance_criteria"] = _json_list( + task.acceptance_criteria, + ) + params["delegation_chain"] = _json_list(task.delegation_chain) + + await self._db.execute( + """\ +INSERT INTO tasks ( + id, title, description, type, priority, project, created_by, + assigned_to, status, estimated_complexity, budget_limit, deadline, + max_retries, parent_task_id, task_structure, coordination_topology, + reviewers, dependencies, artifacts_expected, acceptance_criteria, + delegation_chain +) VALUES ( + :id, :title, :description, :type, :priority, :project, :created_by, + :assigned_to, :status, :estimated_complexity, :budget_limit, :deadline, + :max_retries, :parent_task_id, :task_structure, :coordination_topology, + :reviewers, :dependencies, :artifacts_expected, :acceptance_criteria, + :delegation_chain +) +ON CONFLICT(id) DO UPDATE SET + title=excluded.title, + description=excluded.description, + type=excluded.type, + priority=excluded.priority, + project=excluded.project, + created_by=excluded.created_by, + assigned_to=excluded.assigned_to, + status=excluded.status, + estimated_complexity=excluded.estimated_complexity, + budget_limit=excluded.budget_limit, + deadline=excluded.deadline, + max_retries=excluded.max_retries, + parent_task_id=excluded.parent_task_id, + task_structure=excluded.task_structure, + coordination_topology=excluded.coordination_topology, + reviewers=excluded.reviewers, + dependencies=excluded.dependencies, + artifacts_expected=excluded.artifacts_expected, + acceptance_criteria=excluded.acceptance_criteria, + delegation_chain=excluded.delegation_chain +""", + params, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save task {task.id!r}" + logger.exception( + PERSISTENCE_TASK_SAVE_FAILED, task_id=task.id, error=str(exc) + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_TASK_SAVED, task_id=task.id) + + #: Fields stored as JSON strings that need deserialization. + _JSON_FIELDS: tuple[str, ...] = ( + "reviewers", + "dependencies", + "artifacts_expected", + "acceptance_criteria", + "delegation_chain", + ) + + def _row_to_task(self, row: aiosqlite.Row) -> Task: + """Reconstruct a Task from a database row.""" + try: + data = dict(row) + for field in self._JSON_FIELDS: + data[field] = json.loads(data[field]) + return Task.model_validate(data) + except ( + json.JSONDecodeError, + ValidationError, + KeyError, + TypeError, + ) as exc: + task_id = row["id"] if row else "unknown" + msg = f"Failed to deserialize task {task_id!r}" + logger.exception( + PERSISTENCE_TASK_DESERIALIZE_FAILED, + task_id=task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + _TASK_COLUMNS = """\ +id, title, description, type, priority, project, created_by, + assigned_to, status, estimated_complexity, budget_limit, deadline, + max_retries, parent_task_id, task_structure, coordination_topology, + reviewers, dependencies, artifacts_expected, acceptance_criteria, + delegation_chain""" + + async def get(self, task_id: str) -> Task | None: + """Retrieve a task by its ID.""" + try: + cursor = await self._db.execute( + f"SELECT {self._TASK_COLUMNS} FROM tasks WHERE id = ?", # noqa: S608 + (task_id,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch task {task_id!r}" + logger.exception( + PERSISTENCE_TASK_FETCH_FAILED, + task_id=task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + logger.debug(PERSISTENCE_TASK_FETCHED, task_id=task_id, found=False) + return None + logger.debug(PERSISTENCE_TASK_FETCHED, task_id=task_id, found=True) + return self._row_to_task(row) + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + """List tasks with optional filters.""" + clauses: list[str] = [] + params: list[str] = [] + if status is not None: + clauses.append("status = ?") + params.append(status.value) + if assigned_to is not None: + clauses.append("assigned_to = ?") + params.append(assigned_to) + if project is not None: + clauses.append("project = ?") + params.append(project) + + query = f"SELECT {self._TASK_COLUMNS} FROM tasks" # noqa: S608 + if clauses: + query += " WHERE " + " AND ".join(clauses) + + try: + cursor = await self._db.execute(query, params) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to list tasks" + logger.exception(PERSISTENCE_TASK_LIST_FAILED, error=str(exc)) + raise QueryError(msg) from exc + tasks = tuple(self._row_to_task(row) for row in rows) + logger.debug(PERSISTENCE_TASK_LISTED, count=len(tasks)) + return tasks + + async def delete(self, task_id: str) -> bool: + """Delete a task by ID.""" + try: + cursor = await self._db.execute( + "DELETE FROM tasks WHERE id = ?", (task_id,) + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete task {task_id!r}" + logger.exception( + PERSISTENCE_TASK_DELETE_FAILED, + task_id=task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + deleted = cursor.rowcount > 0 + logger.debug(PERSISTENCE_TASK_DELETED, task_id=task_id, deleted=deleted) + return deleted + + +class SQLiteCostRecordRepository: + """SQLite implementation of the CostRecordRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, record: CostRecord) -> None: + """Persist a cost record (append-only).""" + try: + data = record.model_dump(mode="json") + await self._db.execute( + """\ +INSERT INTO cost_records ( + agent_id, task_id, provider, model, input_tokens, + output_tokens, cost_usd, timestamp, call_category +) VALUES ( + :agent_id, :task_id, :provider, :model, :input_tokens, + :output_tokens, :cost_usd, :timestamp, :call_category +)""", + data, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save cost record for agent {record.agent_id!r}" + logger.exception( + PERSISTENCE_COST_RECORD_SAVE_FAILED, + agent_id=record.agent_id, + task_id=record.task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug( + PERSISTENCE_COST_RECORD_SAVED, + agent_id=record.agent_id, + task_id=record.task_id, + ) + + async def query( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> tuple[CostRecord, ...]: + """Query cost records with optional filters.""" + clauses: list[str] = [] + params: list[str] = [] + if agent_id is not None: + clauses.append("agent_id = ?") + params.append(agent_id) + if task_id is not None: + clauses.append("task_id = ?") + params.append(task_id) + + sql = """\ +SELECT agent_id, task_id, provider, model, input_tokens, + output_tokens, cost_usd, timestamp, call_category +FROM cost_records""" + if clauses: + sql += " WHERE " + " AND ".join(clauses) + + try: + cursor = await self._db.execute(sql, params) + rows = await cursor.fetchall() + records = tuple(CostRecord.model_validate(dict(row)) for row in rows) + except ( + sqlite3.Error, + aiosqlite.Error, + json.JSONDecodeError, + ValidationError, + ) as exc: + msg = "Failed to query cost records" + logger.exception(PERSISTENCE_COST_RECORD_QUERY_FAILED, error=str(exc)) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_COST_RECORD_QUERIED, count=len(records)) + return records + + async def aggregate( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> float: + """Sum total cost_usd, optionally filtered by agent and/or task.""" + try: + sql = "SELECT COALESCE(SUM(cost_usd), 0.0) FROM cost_records" + conditions: list[str] = [] + params: list[str] = [] + if agent_id is not None: + conditions.append("agent_id = ?") + params.append(agent_id) + if task_id is not None: + conditions.append("task_id = ?") + params.append(task_id) + if conditions: + sql += " WHERE " + " AND ".join(conditions) + cursor = await self._db.execute(sql, tuple(params)) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to aggregate cost records" + logger.exception( + PERSISTENCE_COST_RECORD_AGGREGATE_FAILED, + agent_id=agent_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + msg = "aggregate query returned no rows" + logger.error( + PERSISTENCE_COST_RECORD_AGGREGATE_FAILED, + agent_id=agent_id, + error=msg, + ) + raise QueryError(msg) + total = float(row[0]) + logger.debug( + PERSISTENCE_COST_RECORD_AGGREGATED, + agent_id=agent_id, + total_usd=total, + ) + return total + + +class SQLiteMessageRepository: + """SQLite implementation of the MessageRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, message: Message) -> None: + """Persist a message.""" + data = message.model_dump(mode="json") + msg_id = str(message.id) + + try: + await self._db.execute( + """\ +INSERT INTO messages ( + id, timestamp, sender, "to", type, priority, + channel, content, attachments, metadata +) VALUES ( + :id, :timestamp, :sender, :to, :type, :priority, + :channel, :content, :attachments, :metadata +)""", + { + "id": msg_id, + "timestamp": data["timestamp"], + "sender": data["sender"], + "to": data["to"], + "type": data["type"], + "priority": data["priority"], + "channel": data["channel"], + "content": data["content"], + "attachments": json.dumps(data["attachments"]), + "metadata": json.dumps(data["metadata"]), + }, + ) + await self._db.commit() + except sqlite3.IntegrityError as exc: + error_text = str(exc) + is_duplicate_id = ( + "UNIQUE constraint failed: messages.id" in error_text + or "PRIMARY KEY" in error_text + ) + if is_duplicate_id: + err_msg = f"Message {msg_id} already exists" + logger.warning(PERSISTENCE_MESSAGE_DUPLICATE, message_id=msg_id) + raise DuplicateRecordError(err_msg) from exc + # Other integrity errors (NOT NULL, different UNIQUE). + msg = f"Failed to save message {msg_id!r}" + logger.exception( + PERSISTENCE_MESSAGE_SAVE_FAILED, + message_id=msg_id, + error=error_text, + ) + raise QueryError(msg) from exc + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save message {msg_id!r}" + logger.exception( + PERSISTENCE_MESSAGE_SAVE_FAILED, + message_id=msg_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_MESSAGE_SAVED, message_id=msg_id) + + def _row_to_message(self, row: aiosqlite.Row) -> Message: + """Reconstruct a Message from a database row.""" + try: + data = dict(row) + # Map DB column "sender" to Message's "from" alias. + data["from"] = data.pop("sender") + data["attachments"] = json.loads(data["attachments"]) + data["metadata"] = json.loads(data["metadata"]) + return Message.model_validate(data) + except ( + json.JSONDecodeError, + ValidationError, + KeyError, + TypeError, + ) as exc: + msg_id = row["id"] if row else "unknown" + msg = f"Failed to deserialize message {msg_id!r}" + logger.exception( + PERSISTENCE_MESSAGE_DESERIALIZE_FAILED, + message_id=msg_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + async def get_history( + self, + channel: str, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + """Retrieve message history for a channel, newest first.""" + if limit is not None and limit < 1: + msg = f"limit must be a positive integer, got {limit}" + raise QueryError(msg) + sql = """\ +SELECT id, timestamp, sender, "to", type, priority, + channel, content, attachments, metadata +FROM messages +WHERE channel = ? +ORDER BY timestamp DESC""" + params: list[object] = [channel] + if limit is not None: + sql += " LIMIT ?" + params.append(limit) + + try: + cursor = await self._db.execute(sql, params) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch message history for channel {channel!r}" + logger.exception( + PERSISTENCE_MESSAGE_HISTORY_FAILED, + channel=channel, + error=str(exc), + ) + raise QueryError(msg) from exc + messages = tuple(self._row_to_message(row) for row in rows) + logger.debug( + PERSISTENCE_MESSAGE_HISTORY_FETCHED, + channel=channel, + count=len(messages), + ) + return messages diff --git a/tests/integration/persistence/__init__.py b/tests/integration/persistence/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/persistence/conftest.py b/tests/integration/persistence/conftest.py new file mode 100644 index 0000000000..b041e1ae39 --- /dev/null +++ b/tests/integration/persistence/conftest.py @@ -0,0 +1,28 @@ +"""Fixtures for persistence integration tests.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.persistence.config import SQLiteConfig +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + +@pytest.fixture +def db_path(tmp_path: Path) -> str: + """Return a temporary on-disk database path.""" + return str(tmp_path / "test.db") + + +@pytest.fixture +async def on_disk_backend(db_path: str) -> AsyncGenerator[SQLitePersistenceBackend]: + """Connected + migrated on-disk SQLite backend.""" + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + yield backend + await backend.disconnect() diff --git a/tests/integration/persistence/test_sqlite_integration.py b/tests/integration/persistence/test_sqlite_integration.py new file mode 100644 index 0000000000..ab859f53ee --- /dev/null +++ b/tests/integration/persistence/test_sqlite_integration.py @@ -0,0 +1,117 @@ +"""Integration tests for SQLite persistence (on-disk).""" + +from pathlib import Path + +import aiosqlite +import pytest + +from ai_company.persistence.config import SQLiteConfig +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend +from tests.unit.persistence.conftest import make_message, make_task + +pytestmark = [pytest.mark.integration, pytest.mark.timeout(30)] + + +class TestSQLiteOnDisk: + async def test_wal_mode_enabled(self, db_path: str) -> None: + """WAL journal mode is enabled for the on-disk SQLite database.""" + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + + # Write some data to force WAL file creation + task = make_task() + await backend.tasks.save(task) + + assert Path(db_path).exists() # noqa: ASYNC240 + await backend.disconnect() + + # Verify WAL mode by querying the journal_mode pragma + async with aiosqlite.connect(db_path) as db: + cursor = await db.execute("PRAGMA journal_mode") + row = await cursor.fetchone() + assert row is not None + assert row[0] == "wal" + + async def test_data_persists_across_reconnect(self, db_path: str) -> None: + """Data written before disconnect is readable after reconnect.""" + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + + task = make_task(task_id="persist-test") + await backend.tasks.save(task) + await backend.disconnect() + + # Reconnect and verify data + backend2 = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend2.connect() + await backend2.migrate() + + result = await backend2.tasks.get("persist-test") + assert result is not None + assert result.id == "persist-test" + await backend2.disconnect() + + async def test_multiple_entity_types_persist( + self, + on_disk_backend: SQLitePersistenceBackend, + ) -> None: + """Tasks, cost records, and messages all persist together.""" + from datetime import UTC, datetime + + from ai_company.budget.cost_record import CostRecord + + backend = on_disk_backend + + # Save task + await backend.tasks.save(make_task(task_id="multi-t1")) + + # Save cost record + record = CostRecord( + agent_id="alice", + task_id="multi-t1", + provider="test-provider", + model="test-model-001", + input_tokens=500, + output_tokens=200, + cost_usd=0.03, + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + ) + await backend.cost_records.save(record) + + # Save message + await backend.messages.save(make_message(channel="test-channel")) + + # Verify all persist + tasks = await backend.tasks.list_tasks() + assert len(tasks) == 1 + + records = await backend.cost_records.query() + assert len(records) == 1 + + history = await backend.messages.get_history("test-channel") + assert len(history) == 1 + + async def test_concurrent_reads(self, db_path: str) -> None: + """Multiple connections can read concurrently with WAL mode.""" + import asyncio + + # Set up data + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + for i in range(10): + await backend.tasks.save(make_task(task_id=f"conc-{i}")) + await backend.disconnect() + + async def read_all() -> int: + b = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await b.connect() + tasks = await b.tasks.list_tasks() + await b.disconnect() + return len(tasks) + + # Run multiple readers concurrently + results = await asyncio.gather(read_all(), read_all(), read_all()) + assert all(r == 10 for r in results) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index f92d4903c3..753a8ac6e3 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -20,6 +20,7 @@ TaskAssignmentConfig, ) from ai_company.core.company import CompanyConfig +from ai_company.persistence.config import PersistenceConfig if TYPE_CHECKING: from pathlib import Path @@ -73,6 +74,7 @@ class RootConfigFactory(ModelFactory[RootConfig]): logging = None coordination_metrics = CoordinationMetricsConfig() task_assignment = TaskAssignmentConfig() + persistence = PersistenceConfig() # ── Sample YAML strings ────────────────────────────────────────── diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index b23649214c..e777a72057 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -327,6 +327,23 @@ def test_defaults_applied(self) -> None: assert cfg.communication.default_pattern.value == "hybrid" assert cfg.routing.strategy == "cost_aware" + def test_persistence_defaults(self) -> None: + cfg = RootConfig(company_name="X") + assert cfg.persistence.backend == "sqlite" + assert cfg.persistence.sqlite.path == "ai-company.db" + assert cfg.persistence.sqlite.wal_mode is True + + def test_persistence_custom_path(self) -> None: + from ai_company.persistence.config import PersistenceConfig, SQLiteConfig + + cfg = RootConfig( + company_name="X", + persistence=PersistenceConfig( + sqlite=SQLiteConfig(path="data/company-a.db"), + ), + ) + assert cfg.persistence.sqlite.path == "data/company-a.db" + def test_missing_company_name_rejected(self) -> None: with pytest.raises(ValidationError): RootConfig() # type: ignore[call-arg] diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 97b3a67d68..7afb95d653 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -188,6 +188,7 @@ def test_all_domain_modules_discovered(self) -> None: "task_routing", "template", "tool", + "persistence", "workspace", } discovered = {info.name for info in pkgutil.iter_modules(events.__path__)} @@ -365,3 +366,51 @@ def test_workspace_events_exist(self) -> None: WORKSPACE_SORT_WORKSPACES_APPENDED == "workspace.sort.workspaces.appended" ) assert WORKSPACE_GROUP_SETUP_FAILED == "workspace.group.setup.failed" + + @pytest.mark.parametrize( + ("constant_name", "expected"), + [ + ("PERSISTENCE_BACKEND_CONNECTING", "persistence.backend.connecting"), + ("PERSISTENCE_BACKEND_CONNECTED", "persistence.backend.connected"), + ("PERSISTENCE_BACKEND_DISCONNECTING", "persistence.backend.disconnecting"), + ("PERSISTENCE_BACKEND_DISCONNECTED", "persistence.backend.disconnected"), + ("PERSISTENCE_BACKEND_HEALTH_CHECK", "persistence.backend.health_check"), + ( + "PERSISTENCE_BACKEND_NOT_CONNECTED", + "persistence.backend.not_connected", + ), + ("PERSISTENCE_MIGRATION_STARTED", "persistence.migration.started"), + ("PERSISTENCE_MIGRATION_COMPLETED", "persistence.migration.completed"), + ("PERSISTENCE_MIGRATION_SKIPPED", "persistence.migration.skipped"), + ("PERSISTENCE_TASK_SAVED", "persistence.task.saved"), + ("PERSISTENCE_TASK_FETCHED", "persistence.task.fetched"), + ("PERSISTENCE_TASK_LISTED", "persistence.task.listed"), + ("PERSISTENCE_TASK_DELETED", "persistence.task.deleted"), + ( + "PERSISTENCE_TASK_DESERIALIZE_FAILED", + "persistence.task.deserialize_failed", + ), + ("PERSISTENCE_COST_RECORD_SAVED", "persistence.cost_record.saved"), + ( + "PERSISTENCE_COST_RECORD_QUERIED", + "persistence.cost_record.queried", + ), + ( + "PERSISTENCE_COST_RECORD_AGGREGATED", + "persistence.cost_record.aggregated", + ), + ("PERSISTENCE_MESSAGE_SAVED", "persistence.message.saved"), + ( + "PERSISTENCE_MESSAGE_HISTORY_FETCHED", + "persistence.message.history_fetched", + ), + ( + "PERSISTENCE_MESSAGE_DESERIALIZE_FAILED", + "persistence.message.deserialize_failed", + ), + ], + ) + def test_persistence_events_exist(self, constant_name: str, expected: str) -> None: + from ai_company.observability.events import persistence as mod + + assert getattr(mod, constant_name) == expected diff --git a/tests/unit/persistence/__init__.py b/tests/unit/persistence/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/persistence/conftest.py b/tests/unit/persistence/conftest.py new file mode 100644 index 0000000000..38e0dc8b96 --- /dev/null +++ b/tests/unit/persistence/conftest.py @@ -0,0 +1,84 @@ +"""Shared fixtures and helpers for persistence unit tests.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from ai_company.communication.enums import MessagePriority, MessageType +from ai_company.communication.message import Message, MessageMetadata +from ai_company.core.enums import Priority, TaskStatus, TaskType +from ai_company.core.task import Task + +if TYPE_CHECKING: + from uuid import UUID + + +_REQUIRES_ASSIGNEE = frozenset( + { + TaskStatus.ASSIGNED, + TaskStatus.IN_PROGRESS, + TaskStatus.IN_REVIEW, + TaskStatus.COMPLETED, + } +) + + +def make_task( # noqa: PLR0913 + *, + task_id: str = "task-001", + title: str = "Test task", + description: str = "A test task for persistence", + task_type: TaskType = TaskType.DEVELOPMENT, + priority: Priority = Priority.MEDIUM, + project: str = "test-project", + created_by: str = "alice", + assigned_to: str | None = None, + status: TaskStatus = TaskStatus.CREATED, +) -> Task: + """Build a Task with sensible defaults for persistence tests. + + Automatically fills ``assigned_to`` with ``"alice"`` when the + status requires an assignee and none was provided. + """ + effective_assigned_to = assigned_to + if effective_assigned_to is None and status in _REQUIRES_ASSIGNEE: + effective_assigned_to = "alice" + return Task( + id=task_id, + title=title, + description=description, + type=task_type, + priority=priority, + project=project, + created_by=created_by, + assigned_to=effective_assigned_to, + status=status, + ) + + +def make_message( # noqa: PLR0913 + *, + msg_id: UUID | None = None, + sender: str = "alice", + to: str = "bob", + channel: str = "general", + content: str = "Hello, world!", + msg_type: MessageType = MessageType.TASK_UPDATE, + priority: MessagePriority = MessagePriority.NORMAL, + timestamp: datetime | None = None, + metadata: MessageMetadata | None = None, +) -> Message: + """Build a Message with sensible defaults for persistence tests.""" + kwargs: dict[str, object] = { + "from": sender, + "to": to, + "channel": channel, + "content": content, + "type": msg_type, + "priority": priority, + "timestamp": timestamp or datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + } + if msg_id is not None: + kwargs["id"] = msg_id + if metadata is not None: + kwargs["metadata"] = metadata + return Message.model_validate(kwargs) diff --git a/tests/unit/persistence/sqlite/__init__.py b/tests/unit/persistence/sqlite/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/persistence/sqlite/conftest.py b/tests/unit/persistence/sqlite/conftest.py new file mode 100644 index 0000000000..8d08d27d83 --- /dev/null +++ b/tests/unit/persistence/sqlite/conftest.py @@ -0,0 +1,34 @@ +"""Fixtures for SQLite persistence unit tests.""" + +from typing import TYPE_CHECKING + +import aiosqlite +import pytest + +from ai_company.persistence.sqlite.migrations import run_migrations + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +@pytest.fixture +async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: + """Raw in-memory SQLite connection (no migrations).""" + db = await aiosqlite.connect(":memory:") + try: + db.row_factory = aiosqlite.Row + yield db + finally: + await db.close() + + +@pytest.fixture +async def migrated_db() -> AsyncGenerator[aiosqlite.Connection]: + """In-memory SQLite connection with v1 schema applied.""" + db = await aiosqlite.connect(":memory:") + try: + db.row_factory = aiosqlite.Row + await run_migrations(db) + yield db + finally: + await db.close() diff --git a/tests/unit/persistence/sqlite/test_backend.py b/tests/unit/persistence/sqlite/test_backend.py new file mode 100644 index 0000000000..ddd2339d83 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_backend.py @@ -0,0 +1,164 @@ +"""Tests for SQLitePersistenceBackend.""" + +import sqlite3 + +import aiosqlite +import pytest + +from ai_company.persistence.config import SQLiteConfig +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) + + +@pytest.mark.unit +class TestSQLitePersistenceBackend: + async def test_connect_and_disconnect(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert backend.is_connected is False + + await backend.connect() + assert backend.is_connected is True + + await backend.disconnect() # type: ignore[unreachable] + assert backend.is_connected is False + + async def test_connect_idempotent(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + await backend.connect() # should not raise + assert backend.is_connected is True + await backend.disconnect() + + async def test_disconnect_when_not_connected(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.disconnect() # should not raise + + async def test_health_check_connected(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + assert await backend.health_check() is True + await backend.disconnect() + + async def test_health_check_disconnected(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert await backend.health_check() is False + + async def test_backend_name(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert backend.backend_name == "sqlite" + + async def test_migrate_creates_tables(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + await backend.migrate() + + # Verify tables exist by accessing repos + assert isinstance(backend.tasks, SQLiteTaskRepository) + assert isinstance(backend.cost_records, SQLiteCostRecordRepository) + assert isinstance(backend.messages, SQLiteMessageRepository) + await backend.disconnect() + + async def test_migrate_without_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="not connected"): + await backend.migrate() + + async def test_tasks_before_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="Not connected"): + _ = backend.tasks + + async def test_cost_records_before_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="Not connected"): + _ = backend.cost_records + + async def test_messages_before_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="Not connected"): + _ = backend.messages + + async def test_wal_mode_enabled(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:", wal_mode=True)) + await backend.connect() + # WAL mode on :memory: may show as "memory" not "wal", + # but the PRAGMA succeeds without error + assert backend.is_connected is True + await backend.disconnect() + + async def test_wal_mode_disabled(self) -> None: + backend = SQLitePersistenceBackend( + SQLiteConfig(path=":memory:", wal_mode=False) + ) + await backend.connect() + assert backend.is_connected is True + await backend.disconnect() + + async def test_connect_failure_raises_connection_error(self) -> None: + """Non-existent path raises PersistenceConnectionError.""" + config = SQLiteConfig(path="/nonexistent/deeply/nested/dir/test.db") + backend = SQLitePersistenceBackend(config) + with pytest.raises(PersistenceConnectionError): + await backend.connect() + assert backend.is_connected is False + + async def test_health_check_returns_false_on_error(self) -> None: + """health_check returns False (not raises) when the db errors.""" + from unittest.mock import AsyncMock + + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + # Patch execute to simulate a database error + assert backend._db is not None + backend._db.execute = AsyncMock( # type: ignore[method-assign] + side_effect=sqlite3.OperationalError("disk I/O error") + ) + result = await backend.health_check() + assert result is False + await backend.disconnect() + + async def test_disconnect_cleans_up_on_close_error(self) -> None: + """If db.close() raises, state is still cleared.""" + from unittest.mock import AsyncMock + + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + assert backend._db is not None + backend._db.close = AsyncMock( # type: ignore[method-assign] + side_effect=sqlite3.OperationalError("close failed") + ) + await backend.disconnect() + assert backend.is_connected is False + + async def test_connect_pragma_failure_cleans_up(self) -> None: + """If PRAGMA fails after connect, backend cleans up and raises.""" + from unittest.mock import AsyncMock, patch + + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:", wal_mode=True)) + + real_connect = aiosqlite.connect + + async def patched_connect(*args: object, **kwargs: object) -> object: + conn = await real_connect(":memory:") + conn.execute = AsyncMock( # type: ignore[method-assign] + side_effect=sqlite3.OperationalError("PRAGMA failed") + ) + return conn + + with ( + patch("aiosqlite.connect", side_effect=patched_connect), + pytest.raises(PersistenceConnectionError), + ): + await backend.connect() + + assert backend.is_connected is False + + async def test_protocol_compliance(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert isinstance(backend, PersistenceBackend) diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py new file mode 100644 index 0000000000..35f61d0b0d --- /dev/null +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -0,0 +1,108 @@ +"""Tests for SQLite schema migrations.""" + +import sqlite3 +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, patch + +import pytest + +from ai_company.persistence.errors import MigrationError +from ai_company.persistence.sqlite.migrations import ( + SCHEMA_VERSION, + get_user_version, + run_migrations, + set_user_version, +) + +if TYPE_CHECKING: + import aiosqlite + + +@pytest.mark.unit +class TestUserVersion: + async def test_default_version_is_zero( + self, memory_db: aiosqlite.Connection + ) -> None: + assert await get_user_version(memory_db) == 0 + + async def test_set_and_get_version(self, memory_db: aiosqlite.Connection) -> None: + await set_user_version(memory_db, 42) + assert await get_user_version(memory_db) == 42 + + async def test_set_negative_version_raises( + self, memory_db: aiosqlite.Connection + ) -> None: + with pytest.raises(MigrationError, match="non-negative integer"): + await set_user_version(memory_db, -1) + + async def test_set_non_int_version_raises( + self, memory_db: aiosqlite.Connection + ) -> None: + with pytest.raises(MigrationError, match="non-negative integer"): + await set_user_version(memory_db, 2.5) # type: ignore[arg-type] + + +@pytest.mark.unit +class TestRunMigrations: + async def test_creates_tables(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [row[0] for row in await cursor.fetchall()] + assert "tasks" in tables + assert "cost_records" in tables + assert "messages" in tables + + async def test_sets_version(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + assert await get_user_version(memory_db) == SCHEMA_VERSION + + async def test_idempotent(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + await run_migrations(memory_db) + assert await get_user_version(memory_db) == SCHEMA_VERSION + + async def test_creates_indexes(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name LIKE 'idx_%' ORDER BY name" + ) + indexes = {row[0] for row in await cursor.fetchall()} + expected = { + "idx_tasks_status", + "idx_tasks_assigned_to", + "idx_tasks_project", + "idx_cost_records_agent_id", + "idx_cost_records_task_id", + "idx_messages_channel", + "idx_messages_timestamp", + } + assert expected.issubset(indexes) + + async def test_skips_when_already_at_version( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Running migrations on an already-migrated db is a no-op.""" + version_before = await get_user_version(migrated_db) + await run_migrations(migrated_db) + assert await get_user_version(migrated_db) == version_before + + async def test_migration_failure_raises_migration_error( + self, memory_db: aiosqlite.Connection + ) -> None: + """A failing migration step wraps the error as MigrationError.""" + failing_fn = AsyncMock( + side_effect=sqlite3.OperationalError("simulated migration failure") + ) + with ( + patch( + "ai_company.persistence.sqlite.migrations._MIGRATIONS", + [(1, failing_fn)], + ), + pytest.raises(MigrationError, match="Migration to version"), + ): + await run_migrations(memory_db) diff --git a/tests/unit/persistence/sqlite/test_repositories.py b/tests/unit/persistence/sqlite/test_repositories.py new file mode 100644 index 0000000000..9485d24ebb --- /dev/null +++ b/tests/unit/persistence/sqlite/test_repositories.py @@ -0,0 +1,554 @@ +"""Tests for SQLite repository implementations.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest + +if TYPE_CHECKING: + import aiosqlite + +from ai_company.budget.cost_record import CostRecord +from ai_company.communication.message import ( + Message, + MessageMetadata, +) +from ai_company.core.enums import ( + ArtifactType, + Complexity, + CoordinationTopology, + Priority, + TaskStatus, + TaskStructure, + TaskType, +) +from ai_company.core.task import AcceptanceCriterion, Task +from ai_company.persistence.errors import DuplicateRecordError +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) +from tests.unit.persistence.conftest import make_message, make_task + +# ── TaskRepository ─────────────────────────────────────────────── + + +@pytest.mark.unit +class TestSQLiteTaskRepository: + async def test_save_and_get(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + task = make_task() + await repo.save(task) + + result = await repo.get("task-001") + assert result is not None + assert result.id == task.id + assert result.title == task.title + assert result.type == task.type + assert result.status == TaskStatus.CREATED + + async def test_get_returns_none_for_missing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + assert await repo.get("nonexistent") is None + + async def test_save_upsert_updates_existing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + task = make_task() + await repo.save(task) + + updated = task.with_transition(TaskStatus.ASSIGNED, assigned_to="bob") + await repo.save(updated) + + result = await repo.get("task-001") + assert result is not None + assert result.status == TaskStatus.ASSIGNED + assert result.assigned_to == "bob" + + async def test_list_all(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task(task_id="t1")) + await repo.save(make_task(task_id="t2")) + + tasks = await repo.list_tasks() + assert len(tasks) == 2 + + async def test_list_filter_by_status( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task(task_id="t1")) + t2 = make_task(task_id="t2").with_transition( + TaskStatus.ASSIGNED, assigned_to="bob" + ) + await repo.save(t2) + + created = await repo.list_tasks(status=TaskStatus.CREATED) + assert len(created) == 1 + assert created[0].id == "t1" + + async def test_list_filter_by_assigned_to( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + t = make_task().with_transition(TaskStatus.ASSIGNED, assigned_to="bob") + await repo.save(t) + + result = await repo.list_tasks(assigned_to="bob") + assert len(result) == 1 + assert result[0].assigned_to == "bob" + + async def test_list_filter_by_project( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task(task_id="t1", project="proj-a")) + await repo.save(make_task(task_id="t2", project="proj-b")) + + result = await repo.list_tasks(project="proj-a") + assert len(result) == 1 + assert result[0].project == "proj-a" + + async def test_delete_existing(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task()) + assert await repo.delete("task-001") is True + assert await repo.get("task-001") is None + + async def test_delete_nonexistent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + assert await repo.delete("nonexistent") is False + + async def test_list_with_combined_filters( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Multiple filters combine with AND logic.""" + repo = SQLiteTaskRepository(migrated_db) + t1 = make_task(task_id="t1", project="proj-a") + t2 = make_task(task_id="t2", project="proj-a").with_transition( + TaskStatus.ASSIGNED, assigned_to="bob" + ) + t3 = make_task(task_id="t3", project="proj-b").with_transition( + TaskStatus.ASSIGNED, assigned_to="bob" + ) + await repo.save(t1) + await repo.save(t2) + await repo.save(t3) + + result = await repo.list_tasks(status=TaskStatus.ASSIGNED, project="proj-a") + assert len(result) == 1 + assert result[0].id == "t2" + + async def test_round_trip_with_nested_models( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Verify complex nested fields survive serialization.""" + from ai_company.core.artifact import ExpectedArtifact + + task = Task( + id="task-complex", + title="Complex task", + description="Task with nested models", + type=TaskType.DEVELOPMENT, + priority=Priority.HIGH, + project="test-project", + created_by="alice", + estimated_complexity=Complexity.COMPLEX, + budget_limit=50.0, + deadline="2026-12-31T23:59:59", + max_retries=3, + task_structure=TaskStructure.PARALLEL, + coordination_topology=CoordinationTopology.CENTRALIZED, + reviewers=("reviewer-1", "reviewer-2"), + dependencies=("dep-1", "dep-2"), + artifacts_expected=( + ExpectedArtifact(type=ArtifactType.CODE, path="src/main.py"), + ExpectedArtifact(type=ArtifactType.TESTS, path="tests/"), + ), + acceptance_criteria=( + AcceptanceCriterion(description="Tests pass"), + AcceptanceCriterion(description="Code reviewed", met=True), + ), + delegation_chain=("manager", "lead"), + ) + repo = SQLiteTaskRepository(migrated_db) + await repo.save(task) + + result = await repo.get("task-complex") + assert result is not None + assert result.reviewers == ("reviewer-1", "reviewer-2") + assert result.dependencies == ("dep-1", "dep-2") + assert len(result.artifacts_expected) == 2 + assert result.artifacts_expected[0].type == ArtifactType.CODE + assert result.artifacts_expected[0].path == "src/main.py" + assert len(result.acceptance_criteria) == 2 + assert result.acceptance_criteria[1].met is True + assert result.delegation_chain == ("manager", "lead") + assert result.task_structure == TaskStructure.PARALLEL + assert result.coordination_topology == CoordinationTopology.CENTRALIZED + assert result.budget_limit == 50.0 + assert result.deadline == "2026-12-31T23:59:59" + assert result.max_retries == 3 + + +# ── CostRecordRepository ──────────────────────────────────────── + + +@pytest.mark.unit +class TestSQLiteCostRecordRepository: + def _make_record( + self, + *, + agent_id: str = "alice", + task_id: str = "task-001", + cost_usd: float = 0.05, + ) -> CostRecord: + return CostRecord( + agent_id=agent_id, + task_id=task_id, + provider="test-provider", + model="test-model-001", + input_tokens=1000, + output_tokens=500, + cost_usd=cost_usd, + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + ) + + async def test_save_and_query(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + record = self._make_record() + await repo.save(record) + + results = await repo.query() + assert len(results) == 1 + assert results[0].agent_id == "alice" + assert results[0].cost_usd == 0.05 + + async def test_query_by_agent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(agent_id="alice")) + await repo.save(self._make_record(agent_id="bob")) + + results = await repo.query(agent_id="alice") + assert len(results) == 1 + assert results[0].agent_id == "alice" + + async def test_query_by_task(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(task_id="t1")) + await repo.save(self._make_record(task_id="t2")) + + results = await repo.query(task_id="t1") + assert len(results) == 1 + assert results[0].task_id == "t1" + + async def test_aggregate_all(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(cost_usd=0.10)) + await repo.save(self._make_record(cost_usd=0.20)) + + total = await repo.aggregate() + assert abs(total - 0.30) < 1e-9 + + async def test_aggregate_by_agent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(agent_id="alice", cost_usd=0.10)) + await repo.save(self._make_record(agent_id="bob", cost_usd=0.20)) + + total = await repo.aggregate(agent_id="alice") + assert abs(total - 0.10) < 1e-9 + + async def test_aggregate_by_task(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(task_id="t1", cost_usd=0.10)) + await repo.save(self._make_record(task_id="t2", cost_usd=0.20)) + + total = await repo.aggregate(task_id="t1") + assert abs(total - 0.10) < 1e-9 + + async def test_aggregate_by_agent_and_task( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save( + self._make_record(agent_id="alice", task_id="t1", cost_usd=0.10) + ) + await repo.save( + self._make_record(agent_id="alice", task_id="t2", cost_usd=0.20) + ) + await repo.save(self._make_record(agent_id="bob", task_id="t1", cost_usd=0.30)) + + total = await repo.aggregate(agent_id="alice", task_id="t1") + assert abs(total - 0.10) < 1e-9 + + async def test_aggregate_empty(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + total = await repo.aggregate() + assert total == 0.0 + + async def test_query_with_combined_filters( + self, migrated_db: aiosqlite.Connection + ) -> None: + """agent_id + task_id filters combine correctly.""" + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(agent_id="alice", task_id="t1")) + await repo.save(self._make_record(agent_id="alice", task_id="t2")) + await repo.save(self._make_record(agent_id="bob", task_id="t1")) + + results = await repo.query(agent_id="alice", task_id="t1") + assert len(results) == 1 + assert results[0].agent_id == "alice" + assert results[0].task_id == "t1" + + async def test_round_trip_with_call_category( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.budget.call_category import LLMCallCategory + + record = CostRecord( + agent_id="alice", + task_id="task-001", + provider="test-provider", + model="test-model-001", + input_tokens=1000, + output_tokens=500, + cost_usd=0.05, + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + call_category=LLMCallCategory.PRODUCTIVE, + ) + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(record) + + results = await repo.query() + assert results[0].call_category == LLMCallCategory.PRODUCTIVE + + async def test_round_trip_null_call_category( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record()) + + results = await repo.query() + assert results[0].call_category is None + + +# ── MessageRepository ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestSQLiteMessageRepository: + async def test_save_and_get_history( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + msg = make_message() + await repo.save(msg) + + history = await repo.get_history("general") + assert len(history) == 1 + assert history[0].content == "Hello, world!" + + async def test_history_ordered_newest_first( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + msg1 = make_message( + timestamp=datetime(2026, 3, 1, 10, 0, 0, tzinfo=UTC), + content="first", + ) + msg2 = make_message( + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + content="second", + ) + await repo.save(msg1) + await repo.save(msg2) + + history = await repo.get_history("general") + assert len(history) == 2 + assert history[0].content == "second" + assert history[1].content == "first" + + async def test_history_with_limit(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteMessageRepository(migrated_db) + for i in range(5): + await repo.save( + make_message( + timestamp=datetime(2026, 3, 1, i, 0, 0, tzinfo=UTC), + content=f"msg-{i}", + ) + ) + + history = await repo.get_history("general", limit=2) + assert len(history) == 2 + + async def test_history_filters_by_channel( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + await repo.save(make_message(channel="general")) + await repo.save(make_message(channel="engineering")) + + general = await repo.get_history("general") + assert len(general) == 1 + assert general[0].channel == "general" + + async def test_duplicate_message_rejected( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + fixed_id = uuid4() + msg = make_message(msg_id=fixed_id) + await repo.save(msg) + + with pytest.raises(DuplicateRecordError, match="already exists"): + await repo.save(make_message(msg_id=fixed_id)) + + async def test_round_trip_alias_from_field( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Verify the sender/'from' alias round-trips correctly.""" + repo = SQLiteMessageRepository(migrated_db) + msg = make_message(sender="charlie") + await repo.save(msg) + + history = await repo.get_history("general") + assert history[0].sender == "charlie" + + async def test_round_trip_with_attachments_and_metadata( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Verify nested JSON fields round-trip correctly.""" + from ai_company.communication.enums import AttachmentType + + msg = make_message( + metadata=MessageMetadata( + task_id="task-001", + project_id="proj-a", + tokens_used=100, + cost_usd=0.01, + extra=(("key1", "val1"),), + ), + ) + # Add attachments via model_validate to bypass frozen + msg_data = msg.model_dump(mode="json") + msg_data["attachments"] = [ + {"type": AttachmentType.FILE, "ref": "/path/to/file"}, + ] + msg_with_attach = Message.model_validate(msg_data) + + repo = SQLiteMessageRepository(migrated_db) + await repo.save(msg_with_attach) + + history = await repo.get_history("general") + result = history[0] + assert len(result.attachments) == 1 + assert result.attachments[0].type == AttachmentType.FILE + assert result.attachments[0].ref == "/path/to/file" + assert result.metadata.task_id == "task-001" + assert result.metadata.extra == (("key1", "val1"),) + + async def test_round_trip_uuid_id(self, migrated_db: aiosqlite.Connection) -> None: + """Verify UUID id survives round-trip.""" + repo = SQLiteMessageRepository(migrated_db) + msg = make_message() + original_id = msg.id + await repo.save(msg) + + history = await repo.get_history("general") + assert history[0].id == original_id + + async def test_get_history_invalid_limit( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Negative or zero limit raises QueryError.""" + from ai_company.persistence.errors import QueryError + + repo = SQLiteMessageRepository(migrated_db) + with pytest.raises(QueryError, match="positive integer"): + await repo.get_history("general", limit=0) + with pytest.raises(QueryError, match="positive integer"): + await repo.get_history("general", limit=-1) + + +@pytest.mark.unit +class TestSQLiteRepoProtocolCompliance: + """Verify SQLite repositories satisfy their protocol interfaces.""" + + async def test_task_repo_implements_protocol( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.persistence.repositories import TaskRepository + + repo = SQLiteTaskRepository(migrated_db) + assert isinstance(repo, TaskRepository) + + async def test_cost_record_repo_implements_protocol( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.persistence.repositories import CostRecordRepository + + repo = SQLiteCostRecordRepository(migrated_db) + assert isinstance(repo, CostRecordRepository) + + async def test_message_repo_implements_protocol( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.persistence.repositories import MessageRepository + + repo = SQLiteMessageRepository(migrated_db) + assert isinstance(repo, MessageRepository) + + +@pytest.mark.unit +class TestDeserializationFailures: + """Test deserialization error paths with corrupt data.""" + + async def test_row_to_task_corrupt_json( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Corrupt JSON in a tuple field raises QueryError.""" + from ai_company.persistence.errors import QueryError + + await migrated_db.execute( + """\ +INSERT INTO tasks ( + id, title, description, type, priority, project, + created_by, status, reviewers +) VALUES ( + 'corrupt-1', 'Test', 'Test', 'development', 'medium', + 'proj', 'alice', 'created', '{BAD JSON}' +)""" + ) + await migrated_db.commit() + + repo = SQLiteTaskRepository(migrated_db) + with pytest.raises(QueryError, match="deserialize task"): + await repo.get("corrupt-1") + + async def test_row_to_message_corrupt_json( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Corrupt JSON in attachments raises QueryError.""" + from ai_company.persistence.errors import QueryError + + await migrated_db.execute( + """\ +INSERT INTO messages ( + id, timestamp, sender, "to", type, priority, + channel, content, attachments, metadata +) VALUES ( + 'corrupt-msg', '2026-01-01T00:00:00+00:00', 'alice', + 'bob', 'task_update', 'normal', 'general', + 'hello', '{BAD}', '{}' +)""" + ) + await migrated_db.commit() + + repo = SQLiteMessageRepository(migrated_db) + with pytest.raises(QueryError, match="deserialize message"): + await repo.get_history("general") diff --git a/tests/unit/persistence/test_config.py b/tests/unit/persistence/test_config.py new file mode 100644 index 0000000000..15d2630097 --- /dev/null +++ b/tests/unit/persistence/test_config.py @@ -0,0 +1,86 @@ +"""Tests for persistence configuration models.""" + +import pytest +from pydantic import ValidationError + +from ai_company.persistence.config import PersistenceConfig, SQLiteConfig + + +@pytest.mark.unit +class TestSQLiteConfig: + def test_defaults(self) -> None: + cfg = SQLiteConfig() + assert cfg.path == "ai-company.db" + assert cfg.wal_mode is True + assert cfg.journal_size_limit == 67_108_864 + + def test_custom_values(self) -> None: + cfg = SQLiteConfig( + path="/data/test.db", + wal_mode=False, + journal_size_limit=1024, + ) + assert cfg.path == "/data/test.db" + assert cfg.wal_mode is False + assert cfg.journal_size_limit == 1024 + + def test_memory_path(self) -> None: + cfg = SQLiteConfig(path=":memory:") + assert cfg.path == ":memory:" + + def test_frozen(self) -> None: + cfg = SQLiteConfig() + with pytest.raises(ValidationError): + cfg.path = "other.db" # type: ignore[misc] + + def test_blank_path_rejected(self) -> None: + with pytest.raises(ValidationError): + SQLiteConfig(path="") + + def test_whitespace_path_rejected(self) -> None: + with pytest.raises(ValidationError, match="whitespace-only"): + SQLiteConfig(path=" ") + + def test_negative_journal_size_rejected(self) -> None: + with pytest.raises(ValidationError): + SQLiteConfig(journal_size_limit=-1) + + def test_traversal_rejected(self) -> None: + with pytest.raises(ValidationError, match="traversal"): + SQLiteConfig(path="../escape/test.db") + + def test_embedded_traversal_rejected(self) -> None: + with pytest.raises(ValidationError, match="traversal"): + SQLiteConfig(path="data/../../../etc/test.db") + + +@pytest.mark.unit +class TestPersistenceConfig: + def test_defaults(self) -> None: + cfg = PersistenceConfig() + assert cfg.backend == "sqlite" + assert isinstance(cfg.sqlite, SQLiteConfig) + + def test_sqlite_backend_valid(self) -> None: + cfg = PersistenceConfig(backend="sqlite") + assert cfg.backend == "sqlite" + + def test_unknown_backend_rejected(self) -> None: + with pytest.raises(ValidationError, match="Unknown persistence backend"): + PersistenceConfig(backend="postgres") + + def test_blank_backend_rejected(self) -> None: + with pytest.raises(ValidationError): + PersistenceConfig(backend="") + + def test_frozen(self) -> None: + cfg = PersistenceConfig() + with pytest.raises(ValidationError): + cfg.backend = "other" # type: ignore[misc] + + def test_custom_sqlite_config(self) -> None: + cfg = PersistenceConfig( + sqlite=SQLiteConfig(path="data/test.db", wal_mode=False), + ) + assert cfg.sqlite.path == "data/test.db" + assert cfg.sqlite.wal_mode is False diff --git a/tests/unit/persistence/test_errors.py b/tests/unit/persistence/test_errors.py new file mode 100644 index 0000000000..841fc08cc9 --- /dev/null +++ b/tests/unit/persistence/test_errors.py @@ -0,0 +1,52 @@ +"""Tests for persistence error hierarchy.""" + +import pytest + +from ai_company.persistence.errors import ( + DuplicateRecordError, + MigrationError, + PersistenceConnectionError, + PersistenceError, + QueryError, + RecordNotFoundError, +) + +_SUBCLASSES = [ + PersistenceConnectionError, + MigrationError, + RecordNotFoundError, + DuplicateRecordError, + QueryError, +] + + +@pytest.mark.unit +class TestPersistenceErrorHierarchy: + def test_base_is_exception(self) -> None: + assert issubclass(PersistenceError, Exception) + + @pytest.mark.parametrize("cls", _SUBCLASSES) + def test_subclass_inherits_from_base( + self, + cls: type[PersistenceError], + ) -> None: + """All error subclasses inherit from PersistenceError.""" + assert issubclass(cls, PersistenceError) + + @pytest.mark.parametrize("cls", _SUBCLASSES) + def test_catch_all_with_base( + self, + cls: type[PersistenceError], + ) -> None: + """All subclasses are caught by except PersistenceError.""" + msg = "test" + with pytest.raises(PersistenceError): + raise cls(msg) + + def test_error_message_preserved(self) -> None: + err = PersistenceConnectionError("db down") + assert str(err) == "db down" + + def test_does_not_shadow_builtin(self) -> None: + """Our error is NOT the builtin ConnectionError.""" + assert PersistenceConnectionError is not ConnectionError # type: ignore[comparison-overlap] diff --git a/tests/unit/persistence/test_factory.py b/tests/unit/persistence/test_factory.py new file mode 100644 index 0000000000..4d4a61db13 --- /dev/null +++ b/tests/unit/persistence/test_factory.py @@ -0,0 +1,77 @@ +"""Tests for persistence backend factory.""" + +import pytest + +from ai_company.persistence.config import PersistenceConfig, SQLiteConfig +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.factory import create_backend +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend + + +@pytest.mark.unit +class TestCreateBackend: + def test_creates_sqlite_backend(self) -> None: + config = PersistenceConfig( + backend="sqlite", + sqlite=SQLiteConfig(path=":memory:"), + ) + backend = create_backend(config) + assert isinstance(backend, SQLitePersistenceBackend) + assert backend.backend_name == "sqlite" + assert backend.is_connected is False + + def test_returns_protocol_type(self) -> None: + config = PersistenceConfig( + sqlite=SQLiteConfig(path=":memory:"), + ) + backend = create_backend(config) + assert isinstance(backend, PersistenceBackend) + + def test_passes_sqlite_config(self) -> None: + config = PersistenceConfig( + sqlite=SQLiteConfig(path="data/company.db", wal_mode=False), + ) + backend = create_backend(config) + assert isinstance(backend, SQLitePersistenceBackend) + + def test_unknown_backend_raises(self) -> None: + """Bypass validation via model_copy to test the factory guard.""" + config = PersistenceConfig() + bad_config = config.model_copy(update={"backend": "postgres"}) + with pytest.raises( + PersistenceConnectionError, + match="Unknown persistence backend", + ): + create_backend(bad_config) + + async def test_multi_tenancy_separate_databases(self) -> None: + """Each company config creates an isolated backend instance.""" + config_a = PersistenceConfig( + sqlite=SQLiteConfig(path=":memory:"), + ) + config_b = PersistenceConfig( + sqlite=SQLiteConfig(path=":memory:"), + ) + + backend_a = create_backend(config_a) + backend_b = create_backend(config_b) + + # They are separate instances + assert backend_a is not backend_b + + # Each can connect and operate independently + await backend_a.connect() + await backend_a.migrate() + await backend_b.connect() + await backend_b.migrate() + + # Verify isolation — data in one doesn't affect the other + from tests.unit.persistence.conftest import make_task + + await backend_a.tasks.save(make_task(task_id="company-a-task")) + assert await backend_a.tasks.get("company-a-task") is not None + assert await backend_b.tasks.get("company-a-task") is None + + await backend_a.disconnect() + await backend_b.disconnect() diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py new file mode 100644 index 0000000000..ec40d19452 --- /dev/null +++ b/tests/unit/persistence/test_protocol.py @@ -0,0 +1,116 @@ +"""Tests for persistence protocol compliance.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.repositories import ( + CostRecordRepository, + MessageRepository, + TaskRepository, +) + +if TYPE_CHECKING: + from ai_company.budget.cost_record import CostRecord + from ai_company.communication.message import Message + from ai_company.core.enums import TaskStatus + from ai_company.core.task import Task + + +class _FakeTaskRepository: + async def save(self, task: Task) -> None: + pass + + async def get(self, task_id: str) -> Task | None: + return None + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + return () + + async def delete(self, task_id: str) -> bool: + return False + + +class _FakeCostRecordRepository: + async def save(self, record: CostRecord) -> None: + pass + + async def query( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> tuple[CostRecord, ...]: + return () + + async def aggregate(self, *, agent_id: str | None = None) -> float: + return 0.0 + + +class _FakeMessageRepository: + async def save(self, message: Message) -> None: + pass + + async def get_history( + self, + channel: str, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + return () + + +class _FakeBackend: + async def connect(self) -> None: + pass + + async def disconnect(self) -> None: + pass + + async def health_check(self) -> bool: + return True + + async def migrate(self) -> None: + pass + + @property + def is_connected(self) -> bool: + return True + + @property + def backend_name(self) -> str: + return "fake" + + @property + def tasks(self) -> _FakeTaskRepository: + return _FakeTaskRepository() + + @property + def cost_records(self) -> _FakeCostRecordRepository: + return _FakeCostRecordRepository() + + @property + def messages(self) -> _FakeMessageRepository: + return _FakeMessageRepository() + + +@pytest.mark.unit +class TestProtocolCompliance: + def test_fake_backend_is_persistence_backend(self) -> None: + assert isinstance(_FakeBackend(), PersistenceBackend) + + def test_fake_task_repo_is_task_repository(self) -> None: + assert isinstance(_FakeTaskRepository(), TaskRepository) + + def test_fake_cost_repo_is_cost_record_repository(self) -> None: + assert isinstance(_FakeCostRecordRepository(), CostRecordRepository) + + def test_fake_message_repo_is_message_repository(self) -> None: + assert isinstance(_FakeMessageRepository(), MessageRepository) diff --git a/uv.lock b/uv.lock index 71ba180f77..ea7e2c73c6 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,7 @@ requires-python = ">=3.14" name = "ai-company" source = { editable = "." } dependencies = [ + { name = "aiosqlite" }, { name = "jinja2" }, { name = "jsonschema" }, { name = "litellm" }, @@ -44,6 +45,7 @@ test = [ [package.metadata] requires-dist = [ + { name = "aiosqlite", specifier = "==0.21.0" }, { name = "jinja2", specifier = "==3.1.6" }, { name = "jsonschema", specifier = "==4.26.0" }, { name = "litellm", specifier = "==1.82.0" }, @@ -152,6 +154,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4"