From e49eb865de8cb7a2797f3daa59eeb2345f4295dd Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Sun, 8 Mar 2026 23:28:26 +0100 Subject: [PATCH 1/5] =?UTF-8?q?docs:=20add=20=C2=A77.5=20Operational=20Dat?= =?UTF-8?q?a=20Persistence=20with=20pluggable=20PersistenceBackend=20proto?= =?UTF-8?q?col=20(#36)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds DESIGN_SPEC §7.5 defining the PersistenceBackend protocol for operational data (tasks, costs, messages, audit logs) — separate from agent memory (§7.1-7.4). SQLite ships as first backend; PostgreSQL and MariaDB planned as future backends swappable via config. Updates tech stack entry and project structure to include persistence/ module. --- CLAUDE.md | 1 + DESIGN_SPEC.md | 128 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 128 insertions(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index cd321ebf3c..438a4a6b9c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -51,6 +51,7 @@ src/ai_company/ core/ # Shared domain models and base classes engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, task lifecycle, recovery, shutdown, and workspace isolation memory/ # Persistent agent memory (Mem0 initial, custom stack future — ADR-001) + persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (§7.5) observability/ # Structured logging, correlation tracking, log sinks providers/ # LLM provider abstraction (LiteLLM adapter) security/ # SecOps agent, approval gates, audit diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 146c986579..890404fbde 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -1332,6 +1332,121 @@ org_memory: > **Extensibility:** All backends implement the `OrgMemoryBackend` protocol (`query(context) → list[OrgFact]`, `write(fact, author)`, `list_policies()`). The MVP ships with Backend 1; Backends 2 and 3 are research directions that may be explored if the default approach proves insufficient. The selected memory layer backend Mem0 (ADR-001) provides optional graph memory via Neo4j/FalkorDB, which could reduce implementation effort for Backends 2-3. > **Write access control:** Core policies are human-only. ADRs and procedures can be written by senior+ agents. All writes are versioned and auditable. This prevents agents from corrupting shared organizational knowledge while allowing senior agents to document decisions. +### 7.5 Operational Data Persistence + +Agent memory (§7.1–7.4) is handled by the `MemoryBackend` protocol (Mem0 initial, custom stack future — ADR-001). **Operational data** — tasks, cost records, messages, audit logs — is a separate concern managed by a pluggable `PersistenceBackend` protocol. Application code depends only on repository protocols; the storage engine is an implementation detail swappable via config. + +```text +┌──────────────────────────────────────────────────────────────────┐ +│ Application Code │ +│ engine/ budget/ communication/ security/ │ +│ │ │ │ │ │ +│ ▼ ▼ ▼ ▼ │ +│ ┌──────┐ ┌──────┐ ┌──────────┐ ┌──────────┐ │ +│ │ Task │ │ Cost │ │ Message │ │ Audit │ ← Repository │ +│ │ Repo │ │ Repo │ │ Repo │ │ Repo │ Protocols │ +│ └──┬───┘ └──┬───┘ └────┬─────┘ └────┬─────┘ │ +│ └────────┴──────────┴────────────┘ │ +│ │ │ +│ ┌───────────────────┴───────────────────────────────────────┐ │ +│ │ PersistenceBackend (protocol) │ │ +│ │ connect() · disconnect() · health_check() · migrate() │ │ +│ └───────────────────┬───────────────────────────────────────┘ │ +│ │ │ +│ ┌───────────────────┴───────────────────────────────────────┐ │ +│ │ SQLitePersistenceBackend (initial) │ │ +│ │ PostgresPersistenceBackend (future) │ │ +│ │ MariaDBPersistenceBackend (future) │ │ +│ └───────────────────────────────────────────────────────────┘ │ +└──────────────────────────────────────────────────────────────────┘ +``` + +#### Protocol Design + +```python +@runtime_checkable +class PersistenceBackend(Protocol): + """Lifecycle management for operational data storage.""" + + async def connect(self) -> None: ... + async def disconnect(self) -> None: ... + async def health_check(self) -> bool: ... + async def migrate(self) -> None: ... + + @property + def is_connected(self) -> bool: ... + @property + def backend_name(self) -> str: ... +``` + +Each entity type has its own repository protocol: + +```python +@runtime_checkable +class TaskRepository(Protocol): + """CRUD + query interface for Task persistence.""" + + async def save(self, task: Task) -> None: ... + async def get(self, task_id: str) -> Task | None: ... + async def list_tasks(self, *, status: TaskStatus | None = None, assigned_to: str | None = None) -> tuple[Task, ...]: ... + async def delete(self, task_id: str) -> bool: ... + +@runtime_checkable +class CostRecordRepository(Protocol): + """CRUD + aggregation interface for CostRecord persistence.""" + + async def save(self, record: CostRecord) -> None: ... + async def query(self, *, agent_id: str | None = None, task_id: str | None = None) -> tuple[CostRecord, ...]: ... + async def aggregate(self, *, agent_id: str | None = None) -> float: ... + +@runtime_checkable +class MessageRepository(Protocol): + """CRUD + query interface for Message persistence.""" + + async def save(self, message: Message) -> None: ... + async def get_history(self, channel: str, *, limit: int | None = None) -> tuple[Message, ...]: ... +``` + +#### Configuration + +```yaml +persistence: + backend: "sqlite" # sqlite, postgresql, mariadb (future) + sqlite: + path: "/data/ai-company.db" # database file path (mounted volume in Docker) + wal_mode: true # WAL for concurrent read performance + journal_size_limit: 67108864 # 64 MB WAL journal limit + # postgresql: # future + # url: "postgresql://user:pass@host:5432/ai_company" + # pool_size: 10 + # mariadb: # future + # url: "mariadb://user:pass@host:3306/ai_company" + # pool_size: 10 +``` + +#### Entities Persisted + +| Entity | Source Module | Repository | Key Queries | +|--------|-------------|------------|-------------| +| `Task` | `core/task.py` | `TaskRepository` | by status, by assignee, by project | +| `CostRecord` | `budget/cost_record.py` | `CostRecordRepository` | by agent, by task, aggregations | +| `Message` | `communication/message.py` | `MessageRepository` | by channel, by sender, time range | +| Audit entries | `security/` | `AuditRepository` | by agent, by action type, time range | + +#### Migration Strategy + +- Migrations run programmatically at startup via `PersistenceBackend.migrate()` +- Initial migration creates all tables +- Versioned migration scripts tracked in `persistence/migrations/` +- SQLite uses `user_version` pragma for version tracking; PostgreSQL/MariaDB use a migrations table + +#### Key Principles + +- **App code never imports a concrete backend** — only repository protocols +- **Adding a new backend** requires implementing `PersistenceBackend` + all repository protocols — no changes to consumers +- **Same entity models everywhere** — repositories accept and return the existing frozen Pydantic models (Task, CostRecord, Message), no ORM models or data transfer objects +- **Async throughout** — all repository methods are async, matching the project's concurrency model + --- ## 8. HR & Workforce Management @@ -2339,7 +2454,7 @@ Run: ai-company start acme-corp | **Agent Memory** | Mem0 (Qdrant + SQLite) → custom (Neo4j + Qdrant) | Mem0 in-process as initial backend behind pluggable `MemoryBackend` protocol ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Qdrant embedded + SQLite for persistence. Custom stack (Neo4j + Qdrant external) as future upgrade. Config-driven backend selection | | **Message Bus** | Internal (async queues) → Redis | Start with Python asyncio queues, upgrade to Redis for multi-process/distributed | | **Task Queue** | Internal → Celery/Redis | Start simple, scale with Celery when needed | -| **Database** | SQLite → PostgreSQL | Start lightweight, migrate to Postgres for production/multi-user | +| **Database** | SQLite → PostgreSQL / MariaDB | Pluggable `PersistenceBackend` protocol (§7.5). SQLite ships first. PostgreSQL, MariaDB as future backends — swap via config, no app code changes | | **Web UI** | Vue 3 + Vite | Modern, fast, good ecosystem. Simpler than React for dashboards | | **Real-time** | WebSocket (FastAPI native) | Real-time agent activity, task updates, chat feed | | **Containerization** | Docker + Docker Compose | Isolated code execution, reproducible environments | @@ -2492,6 +2607,17 @@ ai-company/ │ │ ├── retrieval.py # Memory retrieval & ranking (M5) │ │ ├── consolidation.py # Memory compression over time (M5) │ │ └── shared.py # Shared knowledge base (M5) +│ ├── persistence/ # Operational data persistence (§7.5) +│ │ ├── __init__.py # Package exports +│ │ ├── protocol.py # PersistenceBackend protocol (M5) +│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository, AuditRepository (M5) +│ │ ├── config.py # PersistenceConfig model (M5) +│ │ ├── errors.py # Persistence error hierarchy (M5) +│ │ └── sqlite/ # SQLite backend (M5, initial) +│ │ ├── __init__.py # Package exports +│ │ ├── backend.py # SQLitePersistenceBackend +│ │ ├── repositories.py # SQLite repository implementations +│ │ └── migrations.py # Schema migrations (user_version pragma) │ ├── observability/ # Structured logging & correlation │ │ ├── __init__.py # get_logger() entry point │ │ ├── _logger.py # Logger configuration From 6e2dd434ec41dcbd62fd537b6ddf12e4bc9a69c4 Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Mon, 9 Mar 2026 00:34:20 +0100 Subject: [PATCH 2/5] feat: add pluggable PersistenceBackend protocol with SQLite implementation (#36) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement §7.5 Operational Data Persistence: - PersistenceBackend protocol with connect/disconnect/health_check/migrate - Repository protocols: TaskRepository, CostRecordRepository, MessageRepository - SQLite backend via aiosqlite with WAL mode, schema migrations, TOCTOU-safe writes - Frozen Pydantic config models (PersistenceConfig, SQLiteConfig) - Error hierarchy: PersistenceError → ConnectionError, MigrationError, etc. - Factory pattern for backend creation - Structured logging with 30+ event constants - 96% test coverage (unit + integration) Pre-reviewed by 10 agents, 42 findings addressed. --- DESIGN_SPEC.md | 33 +- pyproject.toml | 5 + src/ai_company/config/defaults.py | 1 + src/ai_company/config/schema.py | 6 + .../observability/events/persistence.py | 57 +++ src/ai_company/persistence/__init__.py | 39 ++ src/ai_company/persistence/config.py | 98 ++++ src/ai_company/persistence/errors.py | 33 ++ src/ai_company/persistence/factory.py | 65 +++ src/ai_company/persistence/protocol.py | 88 ++++ src/ai_company/persistence/repositories.py | 166 +++++++ src/ai_company/persistence/sqlite/__init__.py | 21 + src/ai_company/persistence/sqlite/backend.py | 206 ++++++++ .../persistence/sqlite/migrations.py | 181 +++++++ .../persistence/sqlite/repositories.py | 450 ++++++++++++++++++ tests/integration/persistence/__init__.py | 0 tests/integration/persistence/conftest.py | 28 ++ .../persistence/test_sqlite_integration.py | 111 +++++ tests/unit/config/conftest.py | 2 + tests/unit/config/test_schema.py | 17 + tests/unit/observability/test_events.py | 44 ++ tests/unit/persistence/__init__.py | 0 tests/unit/persistence/conftest.py | 67 +++ tests/unit/persistence/sqlite/__init__.py | 0 tests/unit/persistence/sqlite/conftest.py | 30 ++ tests/unit/persistence/sqlite/test_backend.py | 127 +++++ .../persistence/sqlite/test_migrations.py | 96 ++++ .../persistence/sqlite/test_repositories.py | 440 +++++++++++++++++ tests/unit/persistence/test_config.py | 78 +++ tests/unit/persistence/test_errors.py | 57 +++ tests/unit/persistence/test_factory.py | 77 +++ tests/unit/persistence/test_protocol.py | 116 +++++ uv.lock | 14 + 33 files changed, 2749 insertions(+), 4 deletions(-) create mode 100644 src/ai_company/observability/events/persistence.py create mode 100644 src/ai_company/persistence/__init__.py create mode 100644 src/ai_company/persistence/config.py create mode 100644 src/ai_company/persistence/errors.py create mode 100644 src/ai_company/persistence/factory.py create mode 100644 src/ai_company/persistence/protocol.py create mode 100644 src/ai_company/persistence/repositories.py create mode 100644 src/ai_company/persistence/sqlite/__init__.py create mode 100644 src/ai_company/persistence/sqlite/backend.py create mode 100644 src/ai_company/persistence/sqlite/migrations.py create mode 100644 src/ai_company/persistence/sqlite/repositories.py create mode 100644 tests/integration/persistence/__init__.py create mode 100644 tests/integration/persistence/conftest.py create mode 100644 tests/integration/persistence/test_sqlite_integration.py create mode 100644 tests/unit/persistence/__init__.py create mode 100644 tests/unit/persistence/conftest.py create mode 100644 tests/unit/persistence/sqlite/__init__.py create mode 100644 tests/unit/persistence/sqlite/conftest.py create mode 100644 tests/unit/persistence/sqlite/test_backend.py create mode 100644 tests/unit/persistence/sqlite/test_migrations.py create mode 100644 tests/unit/persistence/sqlite/test_repositories.py create mode 100644 tests/unit/persistence/test_config.py create mode 100644 tests/unit/persistence/test_errors.py create mode 100644 tests/unit/persistence/test_factory.py create mode 100644 tests/unit/persistence/test_protocol.py diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 890404fbde..604290f442 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -12,7 +12,7 @@ 4. [Company Structure](#4-company-structure) 5. [Communication Architecture](#5-communication-architecture) — 5.6 Conflict Resolution, 5.7 Meeting Protocol 6. [Task & Workflow Engine](#6-task--workflow-engine) — 6.5 Execution Loop, 6.6 Crash Recovery, **6.7 Graceful Shutdown**, **6.8 Workspace Isolation**, **6.9 Task Decomposability & Coordination Topology** -7. [Memory & Persistence](#7-memory--persistence) — 7.4 Shared Org Memory (Research Directions) +7. [Memory & Persistence](#7-memory--persistence) — 7.4 Shared Org Memory (Research Directions), **7.5 Operational Data Persistence** 8. [HR & Workforce Management](#8-hr--workforce-management) 9. [Model Provider Layer](#9-model-provider-layer) 10. [Cost & Budget Management](#10-cost--budget-management) @@ -1377,6 +1377,13 @@ class PersistenceBackend(Protocol): def is_connected(self) -> bool: ... @property def backend_name(self) -> str: ... + + @property + def tasks(self) -> TaskRepository: ... + @property + def cost_records(self) -> CostRecordRepository: ... + @property + def messages(self) -> MessageRepository: ... ``` Each entity type has its own repository protocol: @@ -1388,7 +1395,7 @@ class TaskRepository(Protocol): async def save(self, task: Task) -> None: ... async def get(self, task_id: str) -> Task | None: ... - async def list_tasks(self, *, status: TaskStatus | None = None, assigned_to: str | None = None) -> tuple[Task, ...]: ... + async def list_tasks(self, *, status: TaskStatus | None = None, assigned_to: str | None = None, project: str | None = None) -> tuple[Task, ...]: ... async def delete(self, task_id: str) -> bool: ... @runtime_checkable @@ -1447,6 +1454,21 @@ persistence: - **Same entity models everywhere** — repositories accept and return the existing frozen Pydantic models (Task, CostRecord, Message), no ORM models or data transfer objects - **Async throughout** — all repository methods are async, matching the project's concurrency model +#### Multi-Tenancy + +Each company gets its own database. The `PersistenceConfig` embedded in a company's `RootConfig` specifies the backend type and connection details (e.g. a unique SQLite file path or PostgreSQL database URL). The `create_backend(config)` factory returns an isolated `PersistenceBackend` instance per company — no shared state, no cross-company data leakage. + +```python +# One database per company — configured in each company's YAML +company_a_backend = create_backend(company_a_config.persistence) +company_b_backend = create_backend(company_b_config.persistence) +# Each backend has independent lifecycle: connect → migrate → use → disconnect +``` + +#### Future: Runtime Backend Switching + +Runtime backend switching (e.g. migrating a company from SQLite to PostgreSQL during operation) is a planned future capability. The protocol-based design already supports this — the engine would disconnect the current backend, connect a new one with different config, and migrate. Implementation details (data migration tooling, zero-downtime switchover, connection draining) are deferred to the PostgreSQL backend milestone. + --- ## 8. HR & Workforce Management @@ -2454,7 +2476,7 @@ Run: ai-company start acme-corp | **Agent Memory** | Mem0 (Qdrant + SQLite) → custom (Neo4j + Qdrant) | Mem0 in-process as initial backend behind pluggable `MemoryBackend` protocol ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). Qdrant embedded + SQLite for persistence. Custom stack (Neo4j + Qdrant external) as future upgrade. Config-driven backend selection | | **Message Bus** | Internal (async queues) → Redis | Start with Python asyncio queues, upgrade to Redis for multi-process/distributed | | **Task Queue** | Internal → Celery/Redis | Start simple, scale with Celery when needed | -| **Database** | SQLite → PostgreSQL / MariaDB | Pluggable `PersistenceBackend` protocol (§7.5). SQLite ships first. PostgreSQL, MariaDB as future backends — swap via config, no app code changes | +| **Database** | SQLite (aiosqlite) → PostgreSQL / MariaDB | Pluggable `PersistenceBackend` protocol (§7.5). SQLite ships first via aiosqlite async driver. PostgreSQL, MariaDB as future backends — swap via config, no app code changes | | **Web UI** | Vue 3 + Vite | Modern, fast, good ecosystem. Simpler than React for dashboards | | **Real-time** | WebSocket (FastAPI native) | Real-time agent activity, task updates, chat feed | | **Containerization** | Docker + Docker Compose | Isolated code execution, reproducible environments | @@ -2610,9 +2632,10 @@ ai-company/ │ ├── persistence/ # Operational data persistence (§7.5) │ │ ├── __init__.py # Package exports │ │ ├── protocol.py # PersistenceBackend protocol (M5) -│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository, AuditRepository (M5) +│ │ ├── repositories.py # Repository protocols: TaskRepository, CostRecordRepository, MessageRepository (M5); AuditRepository planned (M7) │ │ ├── config.py # PersistenceConfig model (M5) │ │ ├── errors.py # Persistence error hierarchy (M5) +│ │ ├── factory.py # create_backend() factory (M5) │ │ └── sqlite/ # SQLite backend (M5, initial) │ │ ├── __init__.py # Package exports │ │ ├── backend.py # SQLitePersistenceBackend @@ -2638,6 +2661,7 @@ ai-company/ │ │ │ ├── git.py # GIT_* constants │ │ │ ├── meeting.py # MEETING_* constants │ │ │ ├── parallel.py # PARALLEL_* constants +│ │ │ ├── persistence.py # PERSISTENCE_* constants │ │ │ ├── personality.py # PERSONALITY_* constants │ │ │ ├── prompt.py # PROMPT_* constants │ │ │ ├── provider.py # PROVIDER_* constants @@ -2776,6 +2800,7 @@ ai-company/ | Config | YAML + Pydantic | JSON, TOML, Python dicts | Human-friendly, strict validation, good IDE support | | CLI | Typer | Click, argparse, Fire | Built on Click, auto-completion, type hints | | Web UI | Vue 3 | React, Svelte, HTMX | Simpler than React for dashboards, good with FastAPI | +| Persistence | Pluggable protocol + repository protocols | ORM (SQLAlchemy), raw SQL, hybrid | Same frozen Pydantic models in and out (no DTOs), async throughout, backend-swappable via config. Repository protocols decouple app code from storage engine. See §7.5 | | Sandboxing | Layered: subprocess + Docker | Docker-only, subprocess-only, WASM | Risk-proportionate: fast subprocess for file/git, Docker isolation for code execution. Pluggable `SandboxBackend` protocol enables K8s migration later | ### 15.5 Engineering Conventions diff --git a/pyproject.toml b/pyproject.toml index ff9c2e9721..c1c4507bc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "aiosqlite==0.21.0", "jinja2==3.1.6", "jsonschema==4.26.0", "litellm==1.82.0", @@ -155,6 +156,10 @@ ignore_missing_imports = true module = "jsonschema.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "aiosqlite.*" +ignore_missing_imports = true + [tool.pydantic-mypy] init_forbid_extra = true init_typed = true diff --git a/src/ai_company/config/defaults.py b/src/ai_company/config/defaults.py index 0cced78a43..04973883b1 100644 --- a/src/ai_company/config/defaults.py +++ b/src/ai_company/config/defaults.py @@ -30,4 +30,5 @@ def default_config_dict() -> dict[str, Any]: "escalation_paths": [], "coordination_metrics": {}, "task_assignment": {}, + "persistence": {}, } diff --git a/src/ai_company/config/schema.py b/src/ai_company/config/schema.py index 2c90dea800..9288a60df4 100644 --- a/src/ai_company/config/schema.py +++ b/src/ai_company/config/schema.py @@ -20,6 +20,7 @@ from ai_company.observability import get_logger from ai_company.observability.config import LogConfig # noqa: TC001 from ai_company.observability.events.config import CONFIG_VALIDATION_FAILED +from ai_company.persistence.config import PersistenceConfig logger = get_logger(__name__) @@ -458,6 +459,7 @@ class RootConfig(BaseModel): escalation_paths: Cross-department escalation paths. coordination_metrics: Coordination metrics configuration. task_assignment: Task assignment configuration. + persistence: Persistence backend configuration. """ model_config = ConfigDict(frozen=True) @@ -525,6 +527,10 @@ class RootConfig(BaseModel): default_factory=TaskAssignmentConfig, description="Task assignment configuration", ) + persistence: PersistenceConfig = Field( + default_factory=PersistenceConfig, + description="Persistence backend configuration", + ) @model_validator(mode="after") def _validate_unique_agent_names(self) -> Self: diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py new file mode 100644 index 0000000000..95c385fe05 --- /dev/null +++ b/src/ai_company/observability/events/persistence.py @@ -0,0 +1,57 @@ +"""Persistence event constants for structured logging. + +Constants follow the ``persistence..`` naming convention +and are passed as the first argument to ``logger.info()``/``logger.debug()`` +calls in the persistence layer. +""" + +from typing import Final + +PERSISTENCE_BACKEND_CONNECTING: Final[str] = "persistence.backend.connecting" +PERSISTENCE_BACKEND_CONNECTED: Final[str] = "persistence.backend.connected" +PERSISTENCE_BACKEND_CONNECTION_FAILED: Final[str] = ( + "persistence.backend.connection_failed" +) +PERSISTENCE_BACKEND_ALREADY_CONNECTED: Final[str] = ( + "persistence.backend.already_connected" +) +PERSISTENCE_BACKEND_DISCONNECTING: Final[str] = "persistence.backend.disconnecting" +PERSISTENCE_BACKEND_DISCONNECTED: Final[str] = "persistence.backend.disconnected" +PERSISTENCE_BACKEND_DISCONNECT_ERROR: Final[str] = ( + "persistence.backend.disconnect_error" +) +PERSISTENCE_BACKEND_HEALTH_CHECK: Final[str] = "persistence.backend.health_check" +PERSISTENCE_BACKEND_CREATED: Final[str] = "persistence.backend.created" +PERSISTENCE_BACKEND_UNKNOWN: Final[str] = "persistence.backend.unknown" +PERSISTENCE_BACKEND_WAL_MODE_FAILED: Final[str] = "persistence.backend.wal_mode_failed" + +PERSISTENCE_MIGRATION_STARTED: Final[str] = "persistence.migration.started" +PERSISTENCE_MIGRATION_COMPLETED: Final[str] = "persistence.migration.completed" +PERSISTENCE_MIGRATION_SKIPPED: Final[str] = "persistence.migration.skipped" +PERSISTENCE_MIGRATION_FAILED: Final[str] = "persistence.migration.failed" + +PERSISTENCE_TASK_SAVED: Final[str] = "persistence.task.saved" +PERSISTENCE_TASK_SAVE_FAILED: Final[str] = "persistence.task.save_failed" +PERSISTENCE_TASK_FETCHED: Final[str] = "persistence.task.fetched" +PERSISTENCE_TASK_FETCH_FAILED: Final[str] = "persistence.task.fetch_failed" +PERSISTENCE_TASK_LISTED: Final[str] = "persistence.task.listed" +PERSISTENCE_TASK_LIST_FAILED: Final[str] = "persistence.task.list_failed" +PERSISTENCE_TASK_DELETED: Final[str] = "persistence.task.deleted" +PERSISTENCE_TASK_DELETE_FAILED: Final[str] = "persistence.task.delete_failed" + +PERSISTENCE_COST_RECORD_SAVED: Final[str] = "persistence.cost_record.saved" +PERSISTENCE_COST_RECORD_SAVE_FAILED: Final[str] = "persistence.cost_record.save_failed" +PERSISTENCE_COST_RECORD_QUERIED: Final[str] = "persistence.cost_record.queried" +PERSISTENCE_COST_RECORD_QUERY_FAILED: Final[str] = ( + "persistence.cost_record.query_failed" +) +PERSISTENCE_COST_RECORD_AGGREGATED: Final[str] = "persistence.cost_record.aggregated" +PERSISTENCE_COST_RECORD_AGGREGATE_FAILED: Final[str] = ( + "persistence.cost_record.aggregate_failed" +) + +PERSISTENCE_MESSAGE_SAVED: Final[str] = "persistence.message.saved" +PERSISTENCE_MESSAGE_SAVE_FAILED: Final[str] = "persistence.message.save_failed" +PERSISTENCE_MESSAGE_DUPLICATE: Final[str] = "persistence.message.duplicate" +PERSISTENCE_MESSAGE_HISTORY_FETCHED: Final[str] = "persistence.message.history_fetched" +PERSISTENCE_MESSAGE_HISTORY_FAILED: Final[str] = "persistence.message.history_failed" diff --git a/src/ai_company/persistence/__init__.py b/src/ai_company/persistence/__init__.py new file mode 100644 index 0000000000..60a844bdf2 --- /dev/null +++ b/src/ai_company/persistence/__init__.py @@ -0,0 +1,39 @@ +"""Pluggable persistence layer for operational data (DESIGN_SPEC §7.5). + +Re-exports the protocol, repository protocols, config models, factory, +and error hierarchy so consumers can import from ``ai_company.persistence`` +directly. +""" + +from ai_company.persistence.config import PersistenceConfig, SQLiteConfig +from ai_company.persistence.errors import ( + DuplicateRecordError, + MigrationError, + PersistenceConnectionError, + PersistenceError, + QueryError, + RecordNotFoundError, +) +from ai_company.persistence.factory import create_backend +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.repositories import ( + CostRecordRepository, + MessageRepository, + TaskRepository, +) + +__all__ = [ + "CostRecordRepository", + "DuplicateRecordError", + "MessageRepository", + "MigrationError", + "PersistenceBackend", + "PersistenceConfig", + "PersistenceConnectionError", + "PersistenceError", + "QueryError", + "RecordNotFoundError", + "SQLiteConfig", + "TaskRepository", + "create_backend", +] diff --git a/src/ai_company/persistence/config.py b/src/ai_company/persistence/config.py new file mode 100644 index 0000000000..bef48063fa --- /dev/null +++ b/src/ai_company/persistence/config.py @@ -0,0 +1,98 @@ +"""Persistence configuration models. + +Frozen Pydantic models for persistence backend selection and +backend-specific settings. +""" + +from pathlib import PurePosixPath +from typing import ClassVar, Self + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.config import CONFIG_VALIDATION_FAILED + +logger = get_logger(__name__) + + +class SQLiteConfig(BaseModel): + """SQLite-specific persistence configuration. + + Attributes: + path: Database file path. Use ``":memory:"`` for in-memory + databases (useful for testing). + wal_mode: Whether to enable WAL journal mode for concurrent + read performance. + journal_size_limit: Maximum WAL journal size in bytes + (default 64 MB). + """ + + model_config = ConfigDict(frozen=True, allow_inf_nan=False) + + path: NotBlankStr = Field( + default="ai-company.db", + description="Database file path", + ) + wal_mode: bool = Field( + default=True, + description="Enable WAL journal mode", + ) + journal_size_limit: int = Field( + default=67_108_864, + ge=0, + description="Maximum WAL journal size in bytes", + ) + + @model_validator(mode="after") + def _resolve_path(self) -> Self: + """Resolve relative paths to absolute to prevent traversal ambiguity. + + The special ``:memory:`` identifier is passed through unchanged. + """ + if self.path == ":memory:": + return self + resolved = str(PurePosixPath(self.path)) + object.__setattr__(self, "path", resolved) + return self + + +class PersistenceConfig(BaseModel): + """Top-level persistence configuration. + + Attributes: + backend: Backend name — currently only ``"sqlite"`` is + implemented. + sqlite: SQLite-specific settings (used when + ``backend="sqlite"``). + """ + + model_config = ConfigDict(frozen=True) + + _VALID_BACKENDS: ClassVar[frozenset[str]] = frozenset({"sqlite"}) + + backend: NotBlankStr = Field( + default="sqlite", + description="Persistence backend name", + ) + sqlite: SQLiteConfig = Field( + default_factory=SQLiteConfig, + description="SQLite-specific settings", + ) + + @model_validator(mode="after") + def _validate_backend_name(self) -> Self: + """Ensure backend is a known persistence backend.""" + if self.backend not in self._VALID_BACKENDS: + msg = ( + f"Unknown persistence backend {self.backend!r}. " + f"Valid backends: {sorted(self._VALID_BACKENDS)}" + ) + logger.warning( + CONFIG_VALIDATION_FAILED, + field="backend", + value=self.backend, + reason=msg, + ) + raise ValueError(msg) + return self diff --git a/src/ai_company/persistence/errors.py b/src/ai_company/persistence/errors.py new file mode 100644 index 0000000000..6ea2fdb800 --- /dev/null +++ b/src/ai_company/persistence/errors.py @@ -0,0 +1,33 @@ +"""Persistence error hierarchy. + +All persistence-related errors inherit from ``PersistenceError`` so +callers can catch the entire family with a single except clause. +""" + + +class PersistenceError(Exception): + """Base exception for all persistence operations.""" + + +class PersistenceConnectionError(PersistenceError): + """Raised when a backend connection cannot be established or is lost.""" + + +class MigrationError(PersistenceError): + """Raised when a database migration fails.""" + + +class RecordNotFoundError(PersistenceError): + """Raised when a requested record does not exist. + + Reserved for future strict-fetch methods (e.g. ``get_or_raise``). + Current repository protocols return ``None`` on miss instead. + """ + + +class DuplicateRecordError(PersistenceError): + """Raised when inserting a record that already exists.""" + + +class QueryError(PersistenceError): + """Raised when a query fails due to invalid parameters or backend issues.""" diff --git a/src/ai_company/persistence/factory.py b/src/ai_company/persistence/factory.py new file mode 100644 index 0000000000..da8aef3043 --- /dev/null +++ b/src/ai_company/persistence/factory.py @@ -0,0 +1,65 @@ +"""Factory for creating persistence backends from configuration. + +Each company gets its own ``PersistenceBackend`` instance, which maps +to its own database. This enables multi-tenancy: one database per +company, selectable via the ``PersistenceConfig`` embedded in each +company's ``RootConfig``. +""" + +from typing import TYPE_CHECKING + +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_BACKEND_CREATED, + PERSISTENCE_BACKEND_UNKNOWN, +) +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend + +if TYPE_CHECKING: + from ai_company.persistence.config import PersistenceConfig + from ai_company.persistence.protocol import PersistenceBackend + +logger = get_logger(__name__) + + +def create_backend(config: PersistenceConfig) -> PersistenceBackend: + """Create a persistence backend from configuration. + + Factory function that maps ``config.backend`` to the correct + concrete backend class. Each call returns a new, disconnected + backend instance — the caller is responsible for calling + ``connect()`` and ``migrate()``. + + Args: + config: Persistence configuration (includes backend selection + and backend-specific settings). + + Returns: + A new, disconnected backend instance. + + Raises: + PersistenceConnectionError: If the backend name is not + recognized. + + Example:: + + config = PersistenceConfig( + backend="sqlite", + sqlite=SQLiteConfig(path="data/company-a.db"), + ) + backend = create_backend(config) + await backend.connect() + await backend.migrate() + """ + if config.backend == "sqlite": + backend: PersistenceBackend = SQLitePersistenceBackend(config.sqlite) + logger.debug( + PERSISTENCE_BACKEND_CREATED, + backend="sqlite", + path=config.sqlite.path, + ) + return backend + msg = f"Unknown persistence backend: {config.backend!r}" + logger.error(PERSISTENCE_BACKEND_UNKNOWN, backend=config.backend) + raise PersistenceConnectionError(msg) diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py new file mode 100644 index 0000000000..869155b068 --- /dev/null +++ b/src/ai_company/persistence/protocol.py @@ -0,0 +1,88 @@ +"""PersistenceBackend protocol — lifecycle + repository access. + +Application code depends on this protocol for storage lifecycle +management. Repository protocols provide entity-level access. +""" + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from ai_company.persistence.repositories import ( + CostRecordRepository, + MessageRepository, + TaskRepository, + ) + + +@runtime_checkable +class PersistenceBackend(Protocol): + """Lifecycle management for operational data storage. + + Concrete backends implement this protocol to provide connection + management, health monitoring, schema migrations, and access to + entity-specific repositories. + + Attributes: + is_connected: Whether the backend has an active connection. + backend_name: Human-readable backend identifier. + tasks: Repository for Task persistence. + cost_records: Repository for CostRecord persistence. + messages: Repository for Message persistence. + """ + + async def connect(self) -> None: + """Establish connection to the storage backend. + + Raises: + PersistenceConnectionError: If the connection cannot be + established. + """ + ... + + async def disconnect(self) -> None: + """Close the storage backend connection. + + Safe to call even if not connected. + """ + ... + + async def health_check(self) -> bool: + """Check whether the backend is healthy and responsive. + + Returns: + ``True`` if the backend is reachable and operational. + """ + ... + + async def migrate(self) -> None: + """Run pending schema migrations. + + Raises: + MigrationError: If a migration fails. + """ + ... + + @property + def is_connected(self) -> bool: + """Whether the backend has an active connection.""" + ... + + @property + def backend_name(self) -> str: + """Human-readable backend identifier (e.g. ``"sqlite"``).""" + ... + + @property + def tasks(self) -> TaskRepository: + """Repository for Task persistence.""" + ... + + @property + def cost_records(self) -> CostRecordRepository: + """Repository for CostRecord persistence.""" + ... + + @property + def messages(self) -> MessageRepository: + """Repository for Message persistence.""" + ... diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py new file mode 100644 index 0000000000..1f4e6d4713 --- /dev/null +++ b/src/ai_company/persistence/repositories.py @@ -0,0 +1,166 @@ +"""Repository protocols for operational data persistence. + +Each entity type has its own protocol so that application code depends +only on abstract interfaces, never on a concrete backend. +""" + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from ai_company.budget.cost_record import CostRecord + from ai_company.communication.message import Message + from ai_company.core.enums import TaskStatus + from ai_company.core.task import Task + + +@runtime_checkable +class TaskRepository(Protocol): + """CRUD + query interface for Task persistence.""" + + async def save(self, task: Task) -> None: + """Persist a task (insert or update). + + Args: + task: The task to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, task_id: str) -> Task | None: + """Retrieve a task by its ID. + + Args: + task_id: The task identifier. + + Returns: + The task, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + """List tasks with optional filters. + + Args: + status: Filter by task status. + assigned_to: Filter by assignee agent ID. + project: Filter by project ID. + + Returns: + Matching tasks as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, task_id: str) -> bool: + """Delete a task by ID. + + Args: + task_id: The task identifier. + + Returns: + ``True`` if the task was deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + +@runtime_checkable +class CostRecordRepository(Protocol): + """CRUD + aggregation interface for CostRecord persistence.""" + + async def save(self, record: CostRecord) -> None: + """Persist a cost record (append-only). + + Args: + record: The cost record to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def query( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> tuple[CostRecord, ...]: + """Query cost records with optional filters. + + Args: + agent_id: Filter by agent identifier. + task_id: Filter by task identifier. + + Returns: + Matching cost records as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def aggregate(self, *, agent_id: str | None = None) -> float: + """Sum total cost_usd, optionally filtered by agent. + + Args: + agent_id: Filter by agent identifier. + + Returns: + Total cost in USD. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + +@runtime_checkable +class MessageRepository(Protocol): + """CRUD + query interface for Message persistence.""" + + async def save(self, message: Message) -> None: + """Persist a message. + + Args: + message: The message to persist. + + Raises: + DuplicateRecordError: If a message with the same ID exists. + PersistenceError: If the operation fails. + """ + ... + + async def get_history( + self, + channel: str, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + """Retrieve message history for a channel. + + Args: + channel: Channel name to query. + limit: Maximum number of messages to return (newest first). + + Returns: + Messages ordered by timestamp descending. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/sqlite/__init__.py b/src/ai_company/persistence/sqlite/__init__.py new file mode 100644 index 0000000000..8453cef33f --- /dev/null +++ b/src/ai_company/persistence/sqlite/__init__.py @@ -0,0 +1,21 @@ +"""SQLite persistence backend (DESIGN_SPEC §7.5 — initial backend).""" + +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend +from ai_company.persistence.sqlite.migrations import ( + SCHEMA_VERSION, + run_migrations, +) +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) + +__all__ = [ + "SCHEMA_VERSION", + "SQLiteCostRecordRepository", + "SQLiteMessageRepository", + "SQLitePersistenceBackend", + "SQLiteTaskRepository", + "run_migrations", +] diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py new file mode 100644 index 0000000000..7b93db0b0f --- /dev/null +++ b/src/ai_company/persistence/sqlite/backend.py @@ -0,0 +1,206 @@ +"""SQLite persistence backend implementation.""" + +import contextlib +import sqlite3 +from typing import TYPE_CHECKING + +import aiosqlite + +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_BACKEND_ALREADY_CONNECTED, + PERSISTENCE_BACKEND_CONNECTED, + PERSISTENCE_BACKEND_CONNECTING, + PERSISTENCE_BACKEND_CONNECTION_FAILED, + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + PERSISTENCE_BACKEND_DISCONNECTED, + PERSISTENCE_BACKEND_DISCONNECTING, + PERSISTENCE_BACKEND_HEALTH_CHECK, + PERSISTENCE_BACKEND_WAL_MODE_FAILED, +) +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.sqlite.migrations import run_migrations +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) + +if TYPE_CHECKING: + from ai_company.persistence.config import SQLiteConfig + +logger = get_logger(__name__) + + +class SQLitePersistenceBackend: + """SQLite implementation of the PersistenceBackend protocol. + + Uses a single ``aiosqlite.Connection`` with WAL mode enabled by + default for concurrent read performance (configurable via + ``SQLiteConfig.wal_mode``). + + Args: + config: SQLite-specific configuration. + """ + + def __init__(self, config: SQLiteConfig) -> None: + self._config = config + self._db: aiosqlite.Connection | None = None + self._tasks: SQLiteTaskRepository | None = None + self._cost_records: SQLiteCostRecordRepository | None = None + self._messages: SQLiteMessageRepository | None = None + + def _clear_state(self) -> None: + """Reset connection and repository references to ``None``.""" + self._db = None + self._tasks = None + self._cost_records = None + self._messages = None + + async def connect(self) -> None: + """Open the SQLite database and configure WAL mode.""" + if self._db is not None: + logger.debug(PERSISTENCE_BACKEND_ALREADY_CONNECTED) + return + + logger.info(PERSISTENCE_BACKEND_CONNECTING, path=self._config.path) + try: + self._db = await aiosqlite.connect(self._config.path) + self._db.row_factory = aiosqlite.Row + + if self._config.wal_mode: + cursor = await self._db.execute("PRAGMA journal_mode=WAL") + row = await cursor.fetchone() + actual_mode = row[0] if row else "unknown" + if actual_mode != "wal" and self._config.path != ":memory:": + logger.warning( + PERSISTENCE_BACKEND_WAL_MODE_FAILED, + requested="wal", + actual=actual_mode, + ) + # PRAGMA does not support parameterized queries; + # journal_size_limit is validated as int >= 0 by Pydantic. + limit = int(self._config.journal_size_limit) + await self._db.execute(f"PRAGMA journal_size_limit={limit}") + + self._tasks = SQLiteTaskRepository(self._db) + self._cost_records = SQLiteCostRecordRepository(self._db) + self._messages = SQLiteMessageRepository(self._db) + except (sqlite3.Error, OSError) as exc: + logger.exception( + PERSISTENCE_BACKEND_CONNECTION_FAILED, + path=self._config.path, + error=str(exc), + ) + if self._db is not None: + with contextlib.suppress(sqlite3.Error, OSError): + await self._db.close() + self._clear_state() + msg = "Failed to connect to persistence backend" + raise PersistenceConnectionError(msg) from exc + + logger.info(PERSISTENCE_BACKEND_CONNECTED, path=self._config.path) + + async def disconnect(self) -> None: + """Close the database connection.""" + if self._db is None: + return + + logger.info(PERSISTENCE_BACKEND_DISCONNECTING, path=self._config.path) + try: + await self._db.close() + except (sqlite3.Error, OSError) as exc: + logger.warning( + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + path=self._config.path, + error=str(exc), + error_type=type(exc).__name__, + ) + finally: + self._clear_state() + logger.info(PERSISTENCE_BACKEND_DISCONNECTED, path=self._config.path) + + async def health_check(self) -> bool: + """Check database connectivity.""" + if self._db is None: + return False + try: + cursor = await self._db.execute("SELECT 1") + row = await cursor.fetchone() + healthy = row is not None + except (sqlite3.Error, aiosqlite.Error) as exc: + logger.warning( + PERSISTENCE_BACKEND_HEALTH_CHECK, + healthy=False, + error=str(exc), + error_type=type(exc).__name__, + ) + return False + logger.debug(PERSISTENCE_BACKEND_HEALTH_CHECK, healthy=healthy) + return healthy + + async def migrate(self) -> None: + """Run pending schema migrations. + + Raises: + PersistenceConnectionError: If not connected. + MigrationError: If migration fails. + """ + if self._db is None: + msg = "Cannot migrate: not connected" + logger.warning(PERSISTENCE_BACKEND_CONNECTION_FAILED, error=msg) + raise PersistenceConnectionError(msg) + await run_migrations(self._db) + + @property + def is_connected(self) -> bool: + """Whether the backend has an active connection.""" + return self._db is not None + + @property + def backend_name(self) -> str: + """Human-readable backend identifier.""" + return "sqlite" + + def _require_connected[T](self, repo: T | None, name: str) -> T: + """Return *repo* or raise if the backend is not connected. + + Args: + repo: Repository instance (``None`` when disconnected). + name: Repository name for the error message. + + Raises: + PersistenceConnectionError: If *repo* is ``None``. + """ + if repo is None: + msg = f"Not connected — call connect() before accessing {name}" + logger.warning(PERSISTENCE_BACKEND_CONNECTION_FAILED, error=msg) + raise PersistenceConnectionError(msg) + return repo + + @property + def tasks(self) -> SQLiteTaskRepository: + """Repository for Task persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._tasks, "tasks") + + @property + def cost_records(self) -> SQLiteCostRecordRepository: + """Repository for CostRecord persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._cost_records, "cost_records") + + @property + def messages(self) -> SQLiteMessageRepository: + """Repository for Message persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._messages, "messages") diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py new file mode 100644 index 0000000000..302d5bcd35 --- /dev/null +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -0,0 +1,181 @@ +"""SQLite schema migrations using the user_version pragma. + +Each migration is a function that receives a connection and applies +DDL statements. ``run_migrations`` checks the current version and +runs only the migrations that haven't been applied yet. +""" + +import sqlite3 +from collections.abc import Callable, Coroutine, Sequence +from typing import Any + +import aiosqlite + +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_MIGRATION_COMPLETED, + PERSISTENCE_MIGRATION_FAILED, + PERSISTENCE_MIGRATION_SKIPPED, + PERSISTENCE_MIGRATION_STARTED, +) +from ai_company.persistence.errors import MigrationError + +logger = get_logger(__name__) + +# Current schema version — bump when adding new migrations. +SCHEMA_VERSION = 1 + +_V1_STATEMENTS: Sequence[str] = ( + # ── Tasks ───────────────────────────────────────────── + """\ +CREATE TABLE tasks ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + type TEXT NOT NULL, + priority TEXT NOT NULL DEFAULT 'medium', + project TEXT NOT NULL, + created_by TEXT NOT NULL, + assigned_to TEXT, + status TEXT NOT NULL DEFAULT 'created', + estimated_complexity TEXT NOT NULL DEFAULT 'medium', + budget_limit REAL NOT NULL DEFAULT 0.0, + deadline TEXT, + max_retries INTEGER NOT NULL DEFAULT 1, + parent_task_id TEXT, + task_structure TEXT, + coordination_topology TEXT NOT NULL DEFAULT 'auto', + reviewers TEXT NOT NULL DEFAULT '[]', + dependencies TEXT NOT NULL DEFAULT '[]', + artifacts_expected TEXT NOT NULL DEFAULT '[]', + acceptance_criteria TEXT NOT NULL DEFAULT '[]', + delegation_chain TEXT NOT NULL DEFAULT '[]' +)""", + "CREATE INDEX idx_tasks_status ON tasks(status)", + "CREATE INDEX idx_tasks_assigned_to ON tasks(assigned_to)", + "CREATE INDEX idx_tasks_project ON tasks(project)", + # ── Cost records ────────────────────────────────────── + """\ +CREATE TABLE cost_records ( + rowid INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT NOT NULL, + task_id TEXT NOT NULL, + provider TEXT NOT NULL, + model TEXT NOT NULL, + input_tokens INTEGER NOT NULL, + output_tokens INTEGER NOT NULL, + cost_usd REAL NOT NULL, + timestamp TEXT NOT NULL, + call_category TEXT +)""", + "CREATE INDEX idx_cost_records_agent_id ON cost_records(agent_id)", + "CREATE INDEX idx_cost_records_task_id ON cost_records(task_id)", + # ── Messages ────────────────────────────────────────── + """\ +CREATE TABLE messages ( + id TEXT PRIMARY KEY, + timestamp TEXT NOT NULL, + sender TEXT NOT NULL, + "to" TEXT NOT NULL, + type TEXT NOT NULL, + priority TEXT NOT NULL DEFAULT 'normal', + channel TEXT NOT NULL, + content TEXT NOT NULL, + attachments TEXT NOT NULL DEFAULT '[]', + metadata TEXT NOT NULL DEFAULT '{}' +)""", + "CREATE INDEX idx_messages_channel ON messages(channel)", + "CREATE INDEX idx_messages_timestamp ON messages(timestamp)", +) + +_MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] + + +async def get_user_version(db: aiosqlite.Connection) -> int: + """Read the current schema version from the SQLite user_version pragma.""" + cursor = await db.execute("PRAGMA user_version") + row = await cursor.fetchone() + return int(row[0]) if row else 0 + + +async def set_user_version(db: aiosqlite.Connection, version: int) -> None: + """Set the schema version via the SQLite user_version pragma. + + Args: + db: An open aiosqlite connection. + version: Non-negative integer schema version. + + Raises: + MigrationError: If *version* is not a valid non-negative integer. + """ + if not isinstance(version, int) or version < 0: + msg = f"Schema version must be a non-negative integer, got {version!r}" + raise MigrationError(msg) + # PRAGMA does not support parameterized queries; version is validated above. + await db.execute(f"PRAGMA user_version = {version}") + + +async def _apply_v1(db: aiosqlite.Connection) -> None: + """Apply schema version 1: create tasks, cost_records, messages.""" + for stmt in _V1_STATEMENTS: + await db.execute(stmt) + + +# Ordered list of (target_version, migration_function) pairs. Each migration +# is applied when the current schema version is below its target version. +_MIGRATIONS: list[tuple[int, _MigrateFn]] = [ + (1, _apply_v1), +] + + +async def run_migrations(db: aiosqlite.Connection) -> None: + """Run pending migrations up to ``SCHEMA_VERSION``. + + Migrations are executed within an implicit transaction. On + failure, the transaction is explicitly rolled back and + ``MigrationError`` is raised. + + Args: + db: An open aiosqlite connection. + + Raises: + MigrationError: If any migration step fails. + """ + current = await get_user_version(db) + + if current >= SCHEMA_VERSION: + logger.debug( + PERSISTENCE_MIGRATION_SKIPPED, + current_version=current, + target_version=SCHEMA_VERSION, + ) + return + + logger.info( + PERSISTENCE_MIGRATION_STARTED, + current_version=current, + target_version=SCHEMA_VERSION, + ) + + try: + for target_version, migrate_fn in _MIGRATIONS: + if current < target_version: + await migrate_fn(db) + current = target_version + + await set_user_version(db, SCHEMA_VERSION) + await db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + await db.rollback() + msg = f"Migration to version {SCHEMA_VERSION} failed" + logger.exception( + PERSISTENCE_MIGRATION_FAILED, + target_version=SCHEMA_VERSION, + error=str(exc), + ) + raise MigrationError(msg) from exc + + logger.info( + PERSISTENCE_MIGRATION_COMPLETED, + version=SCHEMA_VERSION, + ) diff --git a/src/ai_company/persistence/sqlite/repositories.py b/src/ai_company/persistence/sqlite/repositories.py new file mode 100644 index 0000000000..65dee51c67 --- /dev/null +++ b/src/ai_company/persistence/sqlite/repositories.py @@ -0,0 +1,450 @@ +"""SQLite repository implementations for Task, CostRecord, and Message.""" + +import json +import sqlite3 +from typing import TYPE_CHECKING + +import aiosqlite +from pydantic import BaseModel, ValidationError + +from ai_company.budget.cost_record import CostRecord +from ai_company.communication.message import Message +from ai_company.core.task import Task +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_COST_RECORD_AGGREGATE_FAILED, + PERSISTENCE_COST_RECORD_AGGREGATED, + PERSISTENCE_COST_RECORD_QUERIED, + PERSISTENCE_COST_RECORD_QUERY_FAILED, + PERSISTENCE_COST_RECORD_SAVE_FAILED, + PERSISTENCE_COST_RECORD_SAVED, + PERSISTENCE_MESSAGE_DUPLICATE, + PERSISTENCE_MESSAGE_HISTORY_FAILED, + PERSISTENCE_MESSAGE_HISTORY_FETCHED, + PERSISTENCE_MESSAGE_SAVE_FAILED, + PERSISTENCE_MESSAGE_SAVED, + PERSISTENCE_TASK_DELETE_FAILED, + PERSISTENCE_TASK_DELETED, + PERSISTENCE_TASK_FETCH_FAILED, + PERSISTENCE_TASK_FETCHED, + PERSISTENCE_TASK_LIST_FAILED, + PERSISTENCE_TASK_LISTED, + PERSISTENCE_TASK_SAVE_FAILED, + PERSISTENCE_TASK_SAVED, +) +from ai_company.persistence.errors import DuplicateRecordError, QueryError + +if TYPE_CHECKING: + from ai_company.core.enums import TaskStatus + +logger = get_logger(__name__) + + +def _json_list(items: tuple[object, ...]) -> str: + """Serialize a tuple of Pydantic models or scalars to a JSON array. + + Items must be JSON-serializable or Pydantic models. + Non-serializable items will raise ``TypeError``. + """ + return json.dumps( + [ + item.model_dump(mode="json") if isinstance(item, BaseModel) else item + for item in items + ] + ) + + +class SQLiteTaskRepository: + """SQLite implementation of the TaskRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, task: Task) -> None: + """Persist a task (upsert semantics).""" + try: + params = task.model_dump(mode="json") + # Tuple fields must be stored as JSON strings. + params["reviewers"] = _json_list(task.reviewers) + params["dependencies"] = _json_list(task.dependencies) + params["artifacts_expected"] = _json_list(task.artifacts_expected) + params["acceptance_criteria"] = _json_list( + task.acceptance_criteria, + ) + params["delegation_chain"] = _json_list(task.delegation_chain) + + await self._db.execute( + """\ +INSERT INTO tasks ( + id, title, description, type, priority, project, created_by, + assigned_to, status, estimated_complexity, budget_limit, deadline, + max_retries, parent_task_id, task_structure, coordination_topology, + reviewers, dependencies, artifacts_expected, acceptance_criteria, + delegation_chain +) VALUES ( + :id, :title, :description, :type, :priority, :project, :created_by, + :assigned_to, :status, :estimated_complexity, :budget_limit, :deadline, + :max_retries, :parent_task_id, :task_structure, :coordination_topology, + :reviewers, :dependencies, :artifacts_expected, :acceptance_criteria, + :delegation_chain +) +ON CONFLICT(id) DO UPDATE SET + title=excluded.title, + description=excluded.description, + type=excluded.type, + priority=excluded.priority, + project=excluded.project, + created_by=excluded.created_by, + assigned_to=excluded.assigned_to, + status=excluded.status, + estimated_complexity=excluded.estimated_complexity, + budget_limit=excluded.budget_limit, + deadline=excluded.deadline, + max_retries=excluded.max_retries, + parent_task_id=excluded.parent_task_id, + task_structure=excluded.task_structure, + coordination_topology=excluded.coordination_topology, + reviewers=excluded.reviewers, + dependencies=excluded.dependencies, + artifacts_expected=excluded.artifacts_expected, + acceptance_criteria=excluded.acceptance_criteria, + delegation_chain=excluded.delegation_chain +""", + params, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save task {task.id!r}" + logger.exception( + PERSISTENCE_TASK_SAVE_FAILED, task_id=task.id, error=str(exc) + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_TASK_SAVED, task_id=task.id) + + #: Fields stored as JSON strings that need deserialization. + _JSON_FIELDS: tuple[str, ...] = ( + "reviewers", + "dependencies", + "artifacts_expected", + "acceptance_criteria", + "delegation_chain", + ) + + def _row_to_task(self, row: aiosqlite.Row) -> Task: + """Reconstruct a Task from a database row.""" + try: + data = dict(row) + for field in self._JSON_FIELDS: + data[field] = json.loads(data[field]) + return Task.model_validate(data) + except (json.JSONDecodeError, ValidationError) as exc: + task_id = row["id"] if row else "unknown" + msg = f"Failed to deserialize task {task_id!r}" + raise QueryError(msg) from exc + + _TASK_COLUMNS = """\ +id, title, description, type, priority, project, created_by, + assigned_to, status, estimated_complexity, budget_limit, deadline, + max_retries, parent_task_id, task_structure, coordination_topology, + reviewers, dependencies, artifacts_expected, acceptance_criteria, + delegation_chain""" + + async def get(self, task_id: str) -> Task | None: + """Retrieve a task by its ID.""" + try: + cursor = await self._db.execute( + f"SELECT {self._TASK_COLUMNS} FROM tasks WHERE id = ?", # noqa: S608 + (task_id,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch task {task_id!r}" + logger.exception( + PERSISTENCE_TASK_FETCH_FAILED, + task_id=task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + logger.debug(PERSISTENCE_TASK_FETCHED, task_id=task_id, found=False) + return None + logger.debug(PERSISTENCE_TASK_FETCHED, task_id=task_id, found=True) + return self._row_to_task(row) + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + """List tasks with optional filters.""" + clauses: list[str] = [] + params: list[str] = [] + if status is not None: + clauses.append("status = ?") + params.append(status.value) + if assigned_to is not None: + clauses.append("assigned_to = ?") + params.append(assigned_to) + if project is not None: + clauses.append("project = ?") + params.append(project) + + query = f"SELECT {self._TASK_COLUMNS} FROM tasks" # noqa: S608 + if clauses: + query += " WHERE " + " AND ".join(clauses) + + try: + cursor = await self._db.execute(query, params) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to list tasks" + logger.exception(PERSISTENCE_TASK_LIST_FAILED, error=str(exc)) + raise QueryError(msg) from exc + tasks = tuple(self._row_to_task(row) for row in rows) + logger.debug(PERSISTENCE_TASK_LISTED, count=len(tasks)) + return tasks + + async def delete(self, task_id: str) -> bool: + """Delete a task by ID.""" + try: + cursor = await self._db.execute( + "DELETE FROM tasks WHERE id = ?", (task_id,) + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete task {task_id!r}" + logger.exception( + PERSISTENCE_TASK_DELETE_FAILED, + task_id=task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + deleted = cursor.rowcount > 0 + logger.debug(PERSISTENCE_TASK_DELETED, task_id=task_id, deleted=deleted) + return deleted + + +class SQLiteCostRecordRepository: + """SQLite implementation of the CostRecordRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, record: CostRecord) -> None: + """Persist a cost record (append-only).""" + try: + data = record.model_dump(mode="json") + await self._db.execute( + """\ +INSERT INTO cost_records ( + agent_id, task_id, provider, model, input_tokens, + output_tokens, cost_usd, timestamp, call_category +) VALUES ( + :agent_id, :task_id, :provider, :model, :input_tokens, + :output_tokens, :cost_usd, :timestamp, :call_category +)""", + data, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save cost record for agent {record.agent_id!r}" + logger.exception( + PERSISTENCE_COST_RECORD_SAVE_FAILED, + agent_id=record.agent_id, + task_id=record.task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug( + PERSISTENCE_COST_RECORD_SAVED, + agent_id=record.agent_id, + task_id=record.task_id, + ) + + async def query( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> tuple[CostRecord, ...]: + """Query cost records with optional filters.""" + clauses: list[str] = [] + params: list[str] = [] + if agent_id is not None: + clauses.append("agent_id = ?") + params.append(agent_id) + if task_id is not None: + clauses.append("task_id = ?") + params.append(task_id) + + sql = """\ +SELECT agent_id, task_id, provider, model, input_tokens, + output_tokens, cost_usd, timestamp, call_category +FROM cost_records""" + if clauses: + sql += " WHERE " + " AND ".join(clauses) + + try: + cursor = await self._db.execute(sql, params) + rows = await cursor.fetchall() + records = tuple(CostRecord.model_validate(dict(row)) for row in rows) + except ( + sqlite3.Error, + aiosqlite.Error, + json.JSONDecodeError, + ValidationError, + ) as exc: + msg = "Failed to query cost records" + logger.exception(PERSISTENCE_COST_RECORD_QUERY_FAILED, error=str(exc)) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_COST_RECORD_QUERIED, count=len(records)) + return records + + async def aggregate(self, *, agent_id: str | None = None) -> float: + """Sum total cost_usd, optionally filtered by agent.""" + try: + sql = "SELECT COALESCE(SUM(cost_usd), 0.0) FROM cost_records" + params: tuple[str, ...] = () + if agent_id is not None: + sql += " WHERE agent_id = ?" + params = (agent_id,) + cursor = await self._db.execute(sql, params) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to aggregate cost records" + logger.exception( + PERSISTENCE_COST_RECORD_AGGREGATE_FAILED, + agent_id=agent_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if row is None: + msg = "aggregate query returned no rows" + logger.error( + PERSISTENCE_COST_RECORD_AGGREGATE_FAILED, + agent_id=agent_id, + error=msg, + ) + raise QueryError(msg) + total = float(row[0]) + logger.debug( + PERSISTENCE_COST_RECORD_AGGREGATED, + agent_id=agent_id, + total_usd=total, + ) + return total + + +class SQLiteMessageRepository: + """SQLite implementation of the MessageRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, message: Message) -> None: + """Persist a message.""" + data = message.model_dump(mode="json") + msg_id = str(message.id) + + try: + await self._db.execute( + """\ +INSERT INTO messages ( + id, timestamp, sender, "to", type, priority, + channel, content, attachments, metadata +) VALUES ( + :id, :timestamp, :sender, :to, :type, :priority, + :channel, :content, :attachments, :metadata +)""", + { + "id": msg_id, + "timestamp": data["timestamp"], + "sender": data["sender"], + "to": data["to"], + "type": data["type"], + "priority": data["priority"], + "channel": data["channel"], + "content": data["content"], + "attachments": json.dumps(data["attachments"]), + "metadata": json.dumps(data["metadata"]), + }, + ) + await self._db.commit() + except sqlite3.IntegrityError as exc: + err_msg = f"Message {msg_id} already exists" + logger.warning(PERSISTENCE_MESSAGE_DUPLICATE, message_id=msg_id) + raise DuplicateRecordError(err_msg) from exc + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save message {msg_id!r}" + logger.exception( + PERSISTENCE_MESSAGE_SAVE_FAILED, + message_id=msg_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug(PERSISTENCE_MESSAGE_SAVED, message_id=msg_id) + + def _row_to_message(self, row: aiosqlite.Row) -> Message: + """Reconstruct a Message from a database row.""" + try: + data = dict(row) + # Map DB column "sender" to Message's "from" alias. + data["from"] = data.pop("sender") + data["attachments"] = json.loads(data["attachments"]) + data["metadata"] = json.loads(data["metadata"]) + return Message.model_validate(data) + except (json.JSONDecodeError, ValidationError) as exc: + msg_id = row["id"] if row else "unknown" + msg = f"Failed to deserialize message {msg_id!r}" + raise QueryError(msg) from exc + + async def get_history( + self, + channel: str, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + """Retrieve message history for a channel, newest first.""" + sql = """\ +SELECT id, timestamp, sender, "to", type, priority, + channel, content, attachments, metadata +FROM messages +WHERE channel = ? +ORDER BY timestamp DESC""" + params: list[object] = [channel] + if limit is not None: + sql += " LIMIT ?" + params.append(limit) + + try: + cursor = await self._db.execute(sql, params) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch message history for channel {channel!r}" + logger.exception( + PERSISTENCE_MESSAGE_HISTORY_FAILED, + channel=channel, + error=str(exc), + ) + raise QueryError(msg) from exc + messages = tuple(self._row_to_message(row) for row in rows) + logger.debug( + PERSISTENCE_MESSAGE_HISTORY_FETCHED, + channel=channel, + count=len(messages), + ) + return messages diff --git a/tests/integration/persistence/__init__.py b/tests/integration/persistence/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/persistence/conftest.py b/tests/integration/persistence/conftest.py new file mode 100644 index 0000000000..b041e1ae39 --- /dev/null +++ b/tests/integration/persistence/conftest.py @@ -0,0 +1,28 @@ +"""Fixtures for persistence integration tests.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.persistence.config import SQLiteConfig +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + from pathlib import Path + + +@pytest.fixture +def db_path(tmp_path: Path) -> str: + """Return a temporary on-disk database path.""" + return str(tmp_path / "test.db") + + +@pytest.fixture +async def on_disk_backend(db_path: str) -> AsyncGenerator[SQLitePersistenceBackend]: + """Connected + migrated on-disk SQLite backend.""" + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + yield backend + await backend.disconnect() diff --git a/tests/integration/persistence/test_sqlite_integration.py b/tests/integration/persistence/test_sqlite_integration.py new file mode 100644 index 0000000000..131952f977 --- /dev/null +++ b/tests/integration/persistence/test_sqlite_integration.py @@ -0,0 +1,111 @@ +"""Integration tests for SQLite persistence (on-disk).""" + +from pathlib import Path + +import pytest + +from ai_company.persistence.config import SQLiteConfig +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend +from tests.unit.persistence.conftest import make_message, make_task + + +@pytest.mark.integration +class TestSQLiteOnDisk: + async def test_wal_file_created(self, db_path: str) -> None: + """WAL mode creates -wal and -shm files on disk.""" + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + + # Write some data to force WAL file creation + task = make_task() + await backend.tasks.save(task) + + # WAL file may or may not exist depending on checkpoint behavior, + # but the db file should exist + assert Path(db_path).exists() # noqa: ASYNC240 + await backend.disconnect() + + async def test_data_persists_across_reconnect(self, db_path: str) -> None: + """Data written before disconnect is readable after reconnect.""" + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + + task = make_task(task_id="persist-test") + await backend.tasks.save(task) + await backend.disconnect() + + # Reconnect and verify data + backend2 = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend2.connect() + await backend2.migrate() + + result = await backend2.tasks.get("persist-test") + assert result is not None + assert result.id == "persist-test" + await backend2.disconnect() + + async def test_multiple_entity_types_persist( + self, + on_disk_backend: SQLitePersistenceBackend, + ) -> None: + """Tasks, cost records, and messages all persist together.""" + from datetime import UTC, datetime + + from ai_company.budget.cost_record import CostRecord + + backend = on_disk_backend + + # Save task + await backend.tasks.save(make_task(task_id="multi-t1")) + + # Save cost record + record = CostRecord( + agent_id="alice", + task_id="multi-t1", + provider="test-provider", + model="test-model-001", + input_tokens=500, + output_tokens=200, + cost_usd=0.03, + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + ) + await backend.cost_records.save(record) + + # Save message + await backend.messages.save(make_message(channel="test-channel")) + + # Verify all persist + tasks = await backend.tasks.list_tasks() + assert len(tasks) == 1 + + records = await backend.cost_records.query() + assert len(records) == 1 + + history = await backend.messages.get_history("test-channel") + assert len(history) == 1 + + async def test_concurrent_reads(self, db_path: str) -> None: + """Multiple connections can read concurrently with WAL mode.""" + import asyncio + + # Set up data + backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await backend.connect() + await backend.migrate() + for i in range(10): + await backend.tasks.save(make_task(task_id=f"conc-{i}")) + await backend.disconnect() + + async def read_all() -> int: + b = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) + await b.connect() + await b.migrate() + tasks = await b.tasks.list_tasks() + await b.disconnect() + return len(tasks) + + # Run multiple readers concurrently + results = await asyncio.gather(read_all(), read_all(), read_all()) + assert all(r == 10 for r in results) diff --git a/tests/unit/config/conftest.py b/tests/unit/config/conftest.py index f92d4903c3..753a8ac6e3 100644 --- a/tests/unit/config/conftest.py +++ b/tests/unit/config/conftest.py @@ -20,6 +20,7 @@ TaskAssignmentConfig, ) from ai_company.core.company import CompanyConfig +from ai_company.persistence.config import PersistenceConfig if TYPE_CHECKING: from pathlib import Path @@ -73,6 +74,7 @@ class RootConfigFactory(ModelFactory[RootConfig]): logging = None coordination_metrics = CoordinationMetricsConfig() task_assignment = TaskAssignmentConfig() + persistence = PersistenceConfig() # ── Sample YAML strings ────────────────────────────────────────── diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index b23649214c..e777a72057 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -327,6 +327,23 @@ def test_defaults_applied(self) -> None: assert cfg.communication.default_pattern.value == "hybrid" assert cfg.routing.strategy == "cost_aware" + def test_persistence_defaults(self) -> None: + cfg = RootConfig(company_name="X") + assert cfg.persistence.backend == "sqlite" + assert cfg.persistence.sqlite.path == "ai-company.db" + assert cfg.persistence.sqlite.wal_mode is True + + def test_persistence_custom_path(self) -> None: + from ai_company.persistence.config import PersistenceConfig, SQLiteConfig + + cfg = RootConfig( + company_name="X", + persistence=PersistenceConfig( + sqlite=SQLiteConfig(path="data/company-a.db"), + ), + ) + assert cfg.persistence.sqlite.path == "data/company-a.db" + def test_missing_company_name_rejected(self) -> None: with pytest.raises(ValidationError): RootConfig() # type: ignore[call-arg] diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 97b3a67d68..2313129af1 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -188,6 +188,7 @@ def test_all_domain_modules_discovered(self) -> None: "task_routing", "template", "tool", + "persistence", "workspace", } discovered = {info.name for info in pkgutil.iter_modules(events.__path__)} @@ -365,3 +366,46 @@ def test_workspace_events_exist(self) -> None: WORKSPACE_SORT_WORKSPACES_APPENDED == "workspace.sort.workspaces.appended" ) assert WORKSPACE_GROUP_SETUP_FAILED == "workspace.group.setup.failed" + + def test_persistence_events_exist(self) -> None: + from ai_company.observability.events.persistence import ( + PERSISTENCE_BACKEND_CONNECTED, + PERSISTENCE_BACKEND_CONNECTING, + PERSISTENCE_BACKEND_DISCONNECTED, + PERSISTENCE_BACKEND_DISCONNECTING, + PERSISTENCE_BACKEND_HEALTH_CHECK, + PERSISTENCE_COST_RECORD_AGGREGATED, + PERSISTENCE_COST_RECORD_QUERIED, + PERSISTENCE_COST_RECORD_SAVED, + PERSISTENCE_MESSAGE_HISTORY_FETCHED, + PERSISTENCE_MESSAGE_SAVED, + PERSISTENCE_MIGRATION_COMPLETED, + PERSISTENCE_MIGRATION_SKIPPED, + PERSISTENCE_MIGRATION_STARTED, + PERSISTENCE_TASK_DELETED, + PERSISTENCE_TASK_FETCHED, + PERSISTENCE_TASK_LISTED, + PERSISTENCE_TASK_SAVED, + ) + + assert PERSISTENCE_BACKEND_CONNECTING == "persistence.backend.connecting" + assert PERSISTENCE_BACKEND_CONNECTED == "persistence.backend.connected" + assert PERSISTENCE_BACKEND_DISCONNECTING == "persistence.backend.disconnecting" + assert PERSISTENCE_BACKEND_DISCONNECTED == "persistence.backend.disconnected" + assert PERSISTENCE_BACKEND_HEALTH_CHECK == "persistence.backend.health_check" + assert PERSISTENCE_MIGRATION_STARTED == "persistence.migration.started" + assert PERSISTENCE_MIGRATION_COMPLETED == "persistence.migration.completed" + assert PERSISTENCE_MIGRATION_SKIPPED == "persistence.migration.skipped" + assert PERSISTENCE_TASK_SAVED == "persistence.task.saved" + assert PERSISTENCE_TASK_FETCHED == "persistence.task.fetched" + assert PERSISTENCE_TASK_LISTED == "persistence.task.listed" + assert PERSISTENCE_TASK_DELETED == "persistence.task.deleted" + assert PERSISTENCE_COST_RECORD_SAVED == "persistence.cost_record.saved" + assert PERSISTENCE_COST_RECORD_QUERIED == "persistence.cost_record.queried" + assert PERSISTENCE_COST_RECORD_AGGREGATED == ( + "persistence.cost_record.aggregated" + ) + assert PERSISTENCE_MESSAGE_SAVED == "persistence.message.saved" + assert PERSISTENCE_MESSAGE_HISTORY_FETCHED == ( + "persistence.message.history_fetched" + ) diff --git a/tests/unit/persistence/__init__.py b/tests/unit/persistence/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/persistence/conftest.py b/tests/unit/persistence/conftest.py new file mode 100644 index 0000000000..b17822f934 --- /dev/null +++ b/tests/unit/persistence/conftest.py @@ -0,0 +1,67 @@ +"""Shared fixtures and helpers for persistence unit tests.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +from ai_company.communication.enums import MessagePriority, MessageType +from ai_company.communication.message import Message, MessageMetadata +from ai_company.core.enums import Priority, TaskStatus, TaskType +from ai_company.core.task import Task + +if TYPE_CHECKING: + from uuid import UUID + + +def make_task( # noqa: PLR0913 + *, + task_id: str = "task-001", + title: str = "Test task", + description: str = "A test task for persistence", + task_type: TaskType = TaskType.DEVELOPMENT, + priority: Priority = Priority.MEDIUM, + project: str = "test-project", + created_by: str = "alice", + assigned_to: str | None = None, + status: TaskStatus = TaskStatus.CREATED, +) -> Task: + """Build a Task with sensible defaults for persistence tests.""" + return Task( + id=task_id, + title=title, + description=description, + type=task_type, + priority=priority, + project=project, + created_by=created_by, + assigned_to=assigned_to, + status=status, + ) + + +def make_message( # noqa: PLR0913 + *, + msg_id: UUID | None = None, + sender: str = "alice", + to: str = "bob", + channel: str = "general", + content: str = "Hello, world!", + msg_type: MessageType = MessageType.TASK_UPDATE, + priority: MessagePriority = MessagePriority.NORMAL, + timestamp: datetime | None = None, + metadata: MessageMetadata | None = None, +) -> Message: + """Build a Message with sensible defaults for persistence tests.""" + kwargs: dict[str, object] = { + "from": sender, + "to": to, + "channel": channel, + "content": content, + "type": msg_type, + "priority": priority, + "timestamp": timestamp or datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + } + if msg_id is not None: + kwargs["id"] = msg_id + if metadata is not None: + kwargs["metadata"] = metadata + return Message.model_validate(kwargs) diff --git a/tests/unit/persistence/sqlite/__init__.py b/tests/unit/persistence/sqlite/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/persistence/sqlite/conftest.py b/tests/unit/persistence/sqlite/conftest.py new file mode 100644 index 0000000000..0849460507 --- /dev/null +++ b/tests/unit/persistence/sqlite/conftest.py @@ -0,0 +1,30 @@ +"""Fixtures for SQLite persistence unit tests.""" + +from typing import TYPE_CHECKING + +import aiosqlite +import pytest + +from ai_company.persistence.sqlite.migrations import run_migrations + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + +@pytest.fixture +async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: + """Raw in-memory SQLite connection (no migrations).""" + db = await aiosqlite.connect(":memory:") + db.row_factory = aiosqlite.Row + yield db + await db.close() + + +@pytest.fixture +async def migrated_db() -> AsyncGenerator[aiosqlite.Connection]: + """In-memory SQLite connection with v1 schema applied.""" + db = await aiosqlite.connect(":memory:") + db.row_factory = aiosqlite.Row + await run_migrations(db) + yield db + await db.close() diff --git a/tests/unit/persistence/sqlite/test_backend.py b/tests/unit/persistence/sqlite/test_backend.py new file mode 100644 index 0000000000..dd52c0f79b --- /dev/null +++ b/tests/unit/persistence/sqlite/test_backend.py @@ -0,0 +1,127 @@ +"""Tests for SQLitePersistenceBackend.""" + +import sqlite3 + +import pytest + +from ai_company.persistence.config import SQLiteConfig +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) + + +@pytest.mark.unit +class TestSQLitePersistenceBackend: + async def test_connect_and_disconnect(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert backend.is_connected is False + + await backend.connect() + assert backend.is_connected is True + + await backend.disconnect() # type: ignore[unreachable] + assert backend.is_connected is False + + async def test_connect_idempotent(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + await backend.connect() # should not raise + assert backend.is_connected is True + await backend.disconnect() + + async def test_disconnect_when_not_connected(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.disconnect() # should not raise + + async def test_health_check_connected(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + assert await backend.health_check() is True + await backend.disconnect() + + async def test_health_check_disconnected(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert await backend.health_check() is False + + async def test_backend_name(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert backend.backend_name == "sqlite" + + async def test_migrate_creates_tables(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + await backend.migrate() + + # Verify tables exist by accessing repos + assert isinstance(backend.tasks, SQLiteTaskRepository) + assert isinstance(backend.cost_records, SQLiteCostRecordRepository) + assert isinstance(backend.messages, SQLiteMessageRepository) + await backend.disconnect() + + async def test_migrate_without_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="not connected"): + await backend.migrate() + + async def test_tasks_before_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="Not connected"): + _ = backend.tasks + + async def test_cost_records_before_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="Not connected"): + _ = backend.cost_records + + async def test_messages_before_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="Not connected"): + _ = backend.messages + + async def test_wal_mode_enabled(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:", wal_mode=True)) + await backend.connect() + # WAL mode on :memory: may show as "memory" not "wal", + # but the PRAGMA succeeds without error + assert backend.is_connected is True + await backend.disconnect() + + async def test_wal_mode_disabled(self) -> None: + backend = SQLitePersistenceBackend( + SQLiteConfig(path=":memory:", wal_mode=False) + ) + await backend.connect() + assert backend.is_connected is True + await backend.disconnect() + + async def test_connect_failure_raises_connection_error(self) -> None: + """Non-existent path raises PersistenceConnectionError.""" + config = SQLiteConfig(path="/nonexistent/deeply/nested/dir/test.db") + backend = SQLitePersistenceBackend(config) + with pytest.raises(PersistenceConnectionError): + await backend.connect() + assert backend.is_connected is False + + async def test_health_check_returns_false_on_error(self) -> None: + """health_check returns False (not raises) when the db errors.""" + from unittest.mock import AsyncMock + + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + # Patch execute to simulate a database error + assert backend._db is not None + backend._db.execute = AsyncMock( # type: ignore[method-assign] + side_effect=sqlite3.OperationalError("disk I/O error") + ) + result = await backend.health_check() + assert result is False + await backend.disconnect() + + async def test_protocol_compliance(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + assert isinstance(backend, PersistenceBackend) diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py new file mode 100644 index 0000000000..2a1dd95547 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -0,0 +1,96 @@ +"""Tests for SQLite schema migrations.""" + +import sqlite3 +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, patch + +import pytest + +from ai_company.persistence.errors import MigrationError +from ai_company.persistence.sqlite.migrations import ( + SCHEMA_VERSION, + get_user_version, + run_migrations, + set_user_version, +) + +if TYPE_CHECKING: + import aiosqlite + + +@pytest.mark.unit +class TestUserVersion: + async def test_default_version_is_zero( + self, memory_db: aiosqlite.Connection + ) -> None: + assert await get_user_version(memory_db) == 0 + + async def test_set_and_get_version(self, memory_db: aiosqlite.Connection) -> None: + await set_user_version(memory_db, 42) + assert await get_user_version(memory_db) == 42 + + +@pytest.mark.unit +class TestRunMigrations: + async def test_creates_tables(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [row[0] for row in await cursor.fetchall()] + assert "tasks" in tables + assert "cost_records" in tables + assert "messages" in tables + + async def test_sets_version(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + assert await get_user_version(memory_db) == SCHEMA_VERSION + + async def test_idempotent(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + await run_migrations(memory_db) + assert await get_user_version(memory_db) == SCHEMA_VERSION + + async def test_creates_indexes(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name LIKE 'idx_%' ORDER BY name" + ) + indexes = {row[0] for row in await cursor.fetchall()} + expected = { + "idx_tasks_status", + "idx_tasks_assigned_to", + "idx_tasks_project", + "idx_cost_records_agent_id", + "idx_cost_records_task_id", + "idx_messages_channel", + "idx_messages_timestamp", + } + assert expected.issubset(indexes) + + async def test_skips_when_already_at_version( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Running migrations on an already-migrated db is a no-op.""" + version_before = await get_user_version(migrated_db) + await run_migrations(migrated_db) + assert await get_user_version(migrated_db) == version_before + + async def test_migration_failure_raises_migration_error( + self, memory_db: aiosqlite.Connection + ) -> None: + """A failing migration step wraps the error as MigrationError.""" + failing_fn = AsyncMock( + side_effect=sqlite3.OperationalError("simulated migration failure") + ) + with ( + patch( + "ai_company.persistence.sqlite.migrations._MIGRATIONS", + [(1, failing_fn)], + ), + pytest.raises(MigrationError, match="Migration to version"), + ): + await run_migrations(memory_db) diff --git a/tests/unit/persistence/sqlite/test_repositories.py b/tests/unit/persistence/sqlite/test_repositories.py new file mode 100644 index 0000000000..a7dd8796b4 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_repositories.py @@ -0,0 +1,440 @@ +"""Tests for SQLite repository implementations.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest + +if TYPE_CHECKING: + import aiosqlite + +from ai_company.budget.cost_record import CostRecord +from ai_company.communication.message import ( + Message, + MessageMetadata, +) +from ai_company.core.enums import ( + ArtifactType, + Complexity, + CoordinationTopology, + Priority, + TaskStatus, + TaskStructure, + TaskType, +) +from ai_company.core.task import AcceptanceCriterion, Task +from ai_company.persistence.errors import DuplicateRecordError +from ai_company.persistence.sqlite.repositories import ( + SQLiteCostRecordRepository, + SQLiteMessageRepository, + SQLiteTaskRepository, +) +from tests.unit.persistence.conftest import make_message, make_task + +# ── TaskRepository ─────────────────────────────────────────────── + + +@pytest.mark.unit +class TestSQLiteTaskRepository: + async def test_save_and_get(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + task = make_task() + await repo.save(task) + + result = await repo.get("task-001") + assert result is not None + assert result.id == task.id + assert result.title == task.title + assert result.type == task.type + assert result.status == TaskStatus.CREATED + + async def test_get_returns_none_for_missing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + assert await repo.get("nonexistent") is None + + async def test_save_upsert_updates_existing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + task = make_task() + await repo.save(task) + + updated = task.with_transition(TaskStatus.ASSIGNED, assigned_to="bob") + await repo.save(updated) + + result = await repo.get("task-001") + assert result is not None + assert result.status == TaskStatus.ASSIGNED + assert result.assigned_to == "bob" + + async def test_list_all(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task(task_id="t1")) + await repo.save(make_task(task_id="t2")) + + tasks = await repo.list_tasks() + assert len(tasks) == 2 + + async def test_list_filter_by_status( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task(task_id="t1")) + t2 = make_task(task_id="t2").with_transition( + TaskStatus.ASSIGNED, assigned_to="bob" + ) + await repo.save(t2) + + created = await repo.list_tasks(status=TaskStatus.CREATED) + assert len(created) == 1 + assert created[0].id == "t1" + + async def test_list_filter_by_assigned_to( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + t = make_task().with_transition(TaskStatus.ASSIGNED, assigned_to="bob") + await repo.save(t) + + result = await repo.list_tasks(assigned_to="bob") + assert len(result) == 1 + assert result[0].assigned_to == "bob" + + async def test_list_filter_by_project( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task(task_id="t1", project="proj-a")) + await repo.save(make_task(task_id="t2", project="proj-b")) + + result = await repo.list_tasks(project="proj-a") + assert len(result) == 1 + assert result[0].project == "proj-a" + + async def test_delete_existing(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + await repo.save(make_task()) + assert await repo.delete("task-001") is True + assert await repo.get("task-001") is None + + async def test_delete_nonexistent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteTaskRepository(migrated_db) + assert await repo.delete("nonexistent") is False + + async def test_list_with_combined_filters( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Multiple filters combine with AND logic.""" + repo = SQLiteTaskRepository(migrated_db) + t1 = make_task(task_id="t1", project="proj-a") + t2 = make_task(task_id="t2", project="proj-a").with_transition( + TaskStatus.ASSIGNED, assigned_to="bob" + ) + t3 = make_task(task_id="t3", project="proj-b").with_transition( + TaskStatus.ASSIGNED, assigned_to="bob" + ) + await repo.save(t1) + await repo.save(t2) + await repo.save(t3) + + result = await repo.list_tasks(status=TaskStatus.ASSIGNED, project="proj-a") + assert len(result) == 1 + assert result[0].id == "t2" + + async def test_round_trip_with_nested_models( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Verify complex nested fields survive serialization.""" + from ai_company.core.artifact import ExpectedArtifact + + task = Task( + id="task-complex", + title="Complex task", + description="Task with nested models", + type=TaskType.DEVELOPMENT, + priority=Priority.HIGH, + project="test-project", + created_by="alice", + estimated_complexity=Complexity.COMPLEX, + budget_limit=50.0, + deadline="2026-12-31T23:59:59", + max_retries=3, + task_structure=TaskStructure.PARALLEL, + coordination_topology=CoordinationTopology.CENTRALIZED, + reviewers=("reviewer-1", "reviewer-2"), + dependencies=("dep-1", "dep-2"), + artifacts_expected=( + ExpectedArtifact(type=ArtifactType.CODE, path="src/main.py"), + ExpectedArtifact(type=ArtifactType.TESTS, path="tests/"), + ), + acceptance_criteria=( + AcceptanceCriterion(description="Tests pass"), + AcceptanceCriterion(description="Code reviewed", met=True), + ), + delegation_chain=("manager", "lead"), + ) + repo = SQLiteTaskRepository(migrated_db) + await repo.save(task) + + result = await repo.get("task-complex") + assert result is not None + assert result.reviewers == ("reviewer-1", "reviewer-2") + assert result.dependencies == ("dep-1", "dep-2") + assert len(result.artifacts_expected) == 2 + assert result.artifacts_expected[0].type == ArtifactType.CODE + assert result.artifacts_expected[0].path == "src/main.py" + assert len(result.acceptance_criteria) == 2 + assert result.acceptance_criteria[1].met is True + assert result.delegation_chain == ("manager", "lead") + assert result.task_structure == TaskStructure.PARALLEL + assert result.coordination_topology == CoordinationTopology.CENTRALIZED + assert result.budget_limit == 50.0 + assert result.deadline == "2026-12-31T23:59:59" + assert result.max_retries == 3 + + +# ── CostRecordRepository ──────────────────────────────────────── + + +@pytest.mark.unit +class TestSQLiteCostRecordRepository: + def _make_record( + self, + *, + agent_id: str = "alice", + task_id: str = "task-001", + cost_usd: float = 0.05, + ) -> CostRecord: + return CostRecord( + agent_id=agent_id, + task_id=task_id, + provider="test-provider", + model="test-model-001", + input_tokens=1000, + output_tokens=500, + cost_usd=cost_usd, + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + ) + + async def test_save_and_query(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + record = self._make_record() + await repo.save(record) + + results = await repo.query() + assert len(results) == 1 + assert results[0].agent_id == "alice" + assert results[0].cost_usd == 0.05 + + async def test_query_by_agent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(agent_id="alice")) + await repo.save(self._make_record(agent_id="bob")) + + results = await repo.query(agent_id="alice") + assert len(results) == 1 + assert results[0].agent_id == "alice" + + async def test_query_by_task(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(task_id="t1")) + await repo.save(self._make_record(task_id="t2")) + + results = await repo.query(task_id="t1") + assert len(results) == 1 + assert results[0].task_id == "t1" + + async def test_aggregate_all(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(cost_usd=0.10)) + await repo.save(self._make_record(cost_usd=0.20)) + + total = await repo.aggregate() + assert abs(total - 0.30) < 1e-9 + + async def test_aggregate_by_agent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(agent_id="alice", cost_usd=0.10)) + await repo.save(self._make_record(agent_id="bob", cost_usd=0.20)) + + total = await repo.aggregate(agent_id="alice") + assert abs(total - 0.10) < 1e-9 + + async def test_aggregate_empty(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + total = await repo.aggregate() + assert total == 0.0 + + async def test_query_with_combined_filters( + self, migrated_db: aiosqlite.Connection + ) -> None: + """agent_id + task_id filters combine correctly.""" + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(agent_id="alice", task_id="t1")) + await repo.save(self._make_record(agent_id="alice", task_id="t2")) + await repo.save(self._make_record(agent_id="bob", task_id="t1")) + + results = await repo.query(agent_id="alice", task_id="t1") + assert len(results) == 1 + assert results[0].agent_id == "alice" + assert results[0].task_id == "t1" + + async def test_round_trip_with_call_category( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.budget.call_category import LLMCallCategory + + record = CostRecord( + agent_id="alice", + task_id="task-001", + provider="test-provider", + model="test-model-001", + input_tokens=1000, + output_tokens=500, + cost_usd=0.05, + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + call_category=LLMCallCategory.PRODUCTIVE, + ) + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(record) + + results = await repo.query() + assert results[0].call_category == LLMCallCategory.PRODUCTIVE + + async def test_round_trip_null_call_category( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record()) + + results = await repo.query() + assert results[0].call_category is None + + +# ── MessageRepository ──────────────────────────────────────────── + + +@pytest.mark.unit +class TestSQLiteMessageRepository: + async def test_save_and_get_history( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + msg = make_message() + await repo.save(msg) + + history = await repo.get_history("general") + assert len(history) == 1 + assert history[0].content == "Hello, world!" + + async def test_history_ordered_newest_first( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + msg1 = make_message( + timestamp=datetime(2026, 3, 1, 10, 0, 0, tzinfo=UTC), + content="first", + ) + msg2 = make_message( + timestamp=datetime(2026, 3, 1, 12, 0, 0, tzinfo=UTC), + content="second", + ) + await repo.save(msg1) + await repo.save(msg2) + + history = await repo.get_history("general") + assert len(history) == 2 + assert history[0].content == "second" + assert history[1].content == "first" + + async def test_history_with_limit(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteMessageRepository(migrated_db) + for i in range(5): + await repo.save( + make_message( + timestamp=datetime(2026, 3, 1, i, 0, 0, tzinfo=UTC), + content=f"msg-{i}", + ) + ) + + history = await repo.get_history("general", limit=2) + assert len(history) == 2 + + async def test_history_filters_by_channel( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + await repo.save(make_message(channel="general")) + await repo.save(make_message(channel="engineering")) + + general = await repo.get_history("general") + assert len(general) == 1 + assert general[0].channel == "general" + + async def test_duplicate_message_rejected( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteMessageRepository(migrated_db) + fixed_id = uuid4() + msg = make_message(msg_id=fixed_id) + await repo.save(msg) + + with pytest.raises(DuplicateRecordError, match="already exists"): + await repo.save(make_message(msg_id=fixed_id)) + + async def test_round_trip_alias_from_field( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Verify the sender/'from' alias round-trips correctly.""" + repo = SQLiteMessageRepository(migrated_db) + msg = make_message(sender="charlie") + await repo.save(msg) + + history = await repo.get_history("general") + assert history[0].sender == "charlie" + + async def test_round_trip_with_attachments_and_metadata( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Verify nested JSON fields round-trip correctly.""" + from ai_company.communication.enums import AttachmentType + + msg = make_message( + metadata=MessageMetadata( + task_id="task-001", + project_id="proj-a", + tokens_used=100, + cost_usd=0.01, + extra=(("key1", "val1"),), + ), + ) + # Add attachments via model_validate to bypass frozen + msg_data = msg.model_dump(mode="json") + msg_data["attachments"] = [ + {"type": AttachmentType.FILE, "ref": "/path/to/file"}, + ] + msg_with_attach = Message.model_validate(msg_data) + + repo = SQLiteMessageRepository(migrated_db) + await repo.save(msg_with_attach) + + history = await repo.get_history("general") + result = history[0] + assert len(result.attachments) == 1 + assert result.attachments[0].type == AttachmentType.FILE + assert result.attachments[0].ref == "/path/to/file" + assert result.metadata.task_id == "task-001" + assert result.metadata.extra == (("key1", "val1"),) + + async def test_round_trip_uuid_id(self, migrated_db: aiosqlite.Connection) -> None: + """Verify UUID id survives round-trip.""" + repo = SQLiteMessageRepository(migrated_db) + msg = make_message() + original_id = msg.id + await repo.save(msg) + + history = await repo.get_history("general") + assert history[0].id == original_id diff --git a/tests/unit/persistence/test_config.py b/tests/unit/persistence/test_config.py new file mode 100644 index 0000000000..7fb552884a --- /dev/null +++ b/tests/unit/persistence/test_config.py @@ -0,0 +1,78 @@ +"""Tests for persistence configuration models.""" + +import pytest +from pydantic import ValidationError + +from ai_company.persistence.config import PersistenceConfig, SQLiteConfig + + +@pytest.mark.unit +class TestSQLiteConfig: + def test_defaults(self) -> None: + cfg = SQLiteConfig() + assert cfg.path == "ai-company.db" + assert cfg.wal_mode is True + assert cfg.journal_size_limit == 67_108_864 + + def test_custom_values(self) -> None: + cfg = SQLiteConfig( + path="/data/test.db", + wal_mode=False, + journal_size_limit=1024, + ) + assert cfg.path == "/data/test.db" + assert cfg.wal_mode is False + assert cfg.journal_size_limit == 1024 + + def test_memory_path(self) -> None: + cfg = SQLiteConfig(path=":memory:") + assert cfg.path == ":memory:" + + def test_frozen(self) -> None: + cfg = SQLiteConfig() + with pytest.raises(ValidationError): + cfg.path = "other.db" # type: ignore[misc] + + def test_blank_path_rejected(self) -> None: + with pytest.raises(ValidationError): + SQLiteConfig(path="") + + def test_whitespace_path_rejected(self) -> None: + with pytest.raises(ValidationError, match="whitespace-only"): + SQLiteConfig(path=" ") + + def test_negative_journal_size_rejected(self) -> None: + with pytest.raises(ValidationError): + SQLiteConfig(journal_size_limit=-1) + + +@pytest.mark.unit +class TestPersistenceConfig: + def test_defaults(self) -> None: + cfg = PersistenceConfig() + assert cfg.backend == "sqlite" + assert isinstance(cfg.sqlite, SQLiteConfig) + + def test_sqlite_backend_valid(self) -> None: + cfg = PersistenceConfig(backend="sqlite") + assert cfg.backend == "sqlite" + + def test_unknown_backend_rejected(self) -> None: + with pytest.raises(ValidationError, match="Unknown persistence backend"): + PersistenceConfig(backend="postgres") + + def test_blank_backend_rejected(self) -> None: + with pytest.raises(ValidationError): + PersistenceConfig(backend="") + + def test_frozen(self) -> None: + cfg = PersistenceConfig() + with pytest.raises(ValidationError): + cfg.backend = "other" # type: ignore[misc] + + def test_custom_sqlite_config(self) -> None: + cfg = PersistenceConfig( + sqlite=SQLiteConfig(path="data/test.db", wal_mode=False), + ) + assert cfg.sqlite.path == "data/test.db" + assert cfg.sqlite.wal_mode is False diff --git a/tests/unit/persistence/test_errors.py b/tests/unit/persistence/test_errors.py new file mode 100644 index 0000000000..d66a3d930a --- /dev/null +++ b/tests/unit/persistence/test_errors.py @@ -0,0 +1,57 @@ +"""Tests for persistence error hierarchy.""" + +import pytest + +from ai_company.persistence.errors import ( + DuplicateRecordError, + MigrationError, + PersistenceConnectionError, + PersistenceError, + QueryError, + RecordNotFoundError, +) + + +@pytest.mark.unit +class TestPersistenceErrorHierarchy: + def test_base_is_exception(self) -> None: + assert issubclass(PersistenceError, Exception) + + def test_connection_error_inherits(self) -> None: + assert issubclass(PersistenceConnectionError, PersistenceError) + + def test_migration_error_inherits(self) -> None: + assert issubclass(MigrationError, PersistenceError) + + def test_record_not_found_inherits(self) -> None: + assert issubclass(RecordNotFoundError, PersistenceError) + + def test_duplicate_record_inherits(self) -> None: + assert issubclass(DuplicateRecordError, PersistenceError) + + def test_query_error_inherits(self) -> None: + assert issubclass(QueryError, PersistenceError) + + @pytest.mark.parametrize( + "cls", + [ + PersistenceConnectionError, + MigrationError, + RecordNotFoundError, + DuplicateRecordError, + QueryError, + ], + ) + def test_catch_all_with_base(self, cls: type[PersistenceError]) -> None: + """All subclasses are caught by except PersistenceError.""" + msg = "test" + with pytest.raises(PersistenceError): + raise cls(msg) + + def test_error_message_preserved(self) -> None: + err = PersistenceConnectionError("db down") + assert str(err) == "db down" + + def test_does_not_shadow_builtin(self) -> None: + """Our error is NOT the builtin ConnectionError.""" + assert PersistenceConnectionError is not ConnectionError # type: ignore[comparison-overlap] diff --git a/tests/unit/persistence/test_factory.py b/tests/unit/persistence/test_factory.py new file mode 100644 index 0000000000..4d4a61db13 --- /dev/null +++ b/tests/unit/persistence/test_factory.py @@ -0,0 +1,77 @@ +"""Tests for persistence backend factory.""" + +import pytest + +from ai_company.persistence.config import PersistenceConfig, SQLiteConfig +from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.factory import create_backend +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend + + +@pytest.mark.unit +class TestCreateBackend: + def test_creates_sqlite_backend(self) -> None: + config = PersistenceConfig( + backend="sqlite", + sqlite=SQLiteConfig(path=":memory:"), + ) + backend = create_backend(config) + assert isinstance(backend, SQLitePersistenceBackend) + assert backend.backend_name == "sqlite" + assert backend.is_connected is False + + def test_returns_protocol_type(self) -> None: + config = PersistenceConfig( + sqlite=SQLiteConfig(path=":memory:"), + ) + backend = create_backend(config) + assert isinstance(backend, PersistenceBackend) + + def test_passes_sqlite_config(self) -> None: + config = PersistenceConfig( + sqlite=SQLiteConfig(path="data/company.db", wal_mode=False), + ) + backend = create_backend(config) + assert isinstance(backend, SQLitePersistenceBackend) + + def test_unknown_backend_raises(self) -> None: + """Bypass validation via model_copy to test the factory guard.""" + config = PersistenceConfig() + bad_config = config.model_copy(update={"backend": "postgres"}) + with pytest.raises( + PersistenceConnectionError, + match="Unknown persistence backend", + ): + create_backend(bad_config) + + async def test_multi_tenancy_separate_databases(self) -> None: + """Each company config creates an isolated backend instance.""" + config_a = PersistenceConfig( + sqlite=SQLiteConfig(path=":memory:"), + ) + config_b = PersistenceConfig( + sqlite=SQLiteConfig(path=":memory:"), + ) + + backend_a = create_backend(config_a) + backend_b = create_backend(config_b) + + # They are separate instances + assert backend_a is not backend_b + + # Each can connect and operate independently + await backend_a.connect() + await backend_a.migrate() + await backend_b.connect() + await backend_b.migrate() + + # Verify isolation — data in one doesn't affect the other + from tests.unit.persistence.conftest import make_task + + await backend_a.tasks.save(make_task(task_id="company-a-task")) + assert await backend_a.tasks.get("company-a-task") is not None + assert await backend_b.tasks.get("company-a-task") is None + + await backend_a.disconnect() + await backend_b.disconnect() diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py new file mode 100644 index 0000000000..ec40d19452 --- /dev/null +++ b/tests/unit/persistence/test_protocol.py @@ -0,0 +1,116 @@ +"""Tests for persistence protocol compliance.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.persistence.protocol import PersistenceBackend +from ai_company.persistence.repositories import ( + CostRecordRepository, + MessageRepository, + TaskRepository, +) + +if TYPE_CHECKING: + from ai_company.budget.cost_record import CostRecord + from ai_company.communication.message import Message + from ai_company.core.enums import TaskStatus + from ai_company.core.task import Task + + +class _FakeTaskRepository: + async def save(self, task: Task) -> None: + pass + + async def get(self, task_id: str) -> Task | None: + return None + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + return () + + async def delete(self, task_id: str) -> bool: + return False + + +class _FakeCostRecordRepository: + async def save(self, record: CostRecord) -> None: + pass + + async def query( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> tuple[CostRecord, ...]: + return () + + async def aggregate(self, *, agent_id: str | None = None) -> float: + return 0.0 + + +class _FakeMessageRepository: + async def save(self, message: Message) -> None: + pass + + async def get_history( + self, + channel: str, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + return () + + +class _FakeBackend: + async def connect(self) -> None: + pass + + async def disconnect(self) -> None: + pass + + async def health_check(self) -> bool: + return True + + async def migrate(self) -> None: + pass + + @property + def is_connected(self) -> bool: + return True + + @property + def backend_name(self) -> str: + return "fake" + + @property + def tasks(self) -> _FakeTaskRepository: + return _FakeTaskRepository() + + @property + def cost_records(self) -> _FakeCostRecordRepository: + return _FakeCostRecordRepository() + + @property + def messages(self) -> _FakeMessageRepository: + return _FakeMessageRepository() + + +@pytest.mark.unit +class TestProtocolCompliance: + def test_fake_backend_is_persistence_backend(self) -> None: + assert isinstance(_FakeBackend(), PersistenceBackend) + + def test_fake_task_repo_is_task_repository(self) -> None: + assert isinstance(_FakeTaskRepository(), TaskRepository) + + def test_fake_cost_repo_is_cost_record_repository(self) -> None: + assert isinstance(_FakeCostRecordRepository(), CostRecordRepository) + + def test_fake_message_repo_is_message_repository(self) -> None: + assert isinstance(_FakeMessageRepository(), MessageRepository) diff --git a/uv.lock b/uv.lock index 71ba180f77..ea7e2c73c6 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,7 @@ requires-python = ">=3.14" name = "ai-company" source = { editable = "." } dependencies = [ + { name = "aiosqlite" }, { name = "jinja2" }, { name = "jsonschema" }, { name = "litellm" }, @@ -44,6 +45,7 @@ test = [ [package.metadata] requires-dist = [ + { name = "aiosqlite", specifier = "==0.21.0" }, { name = "jinja2", specifier = "==3.1.6" }, { name = "jsonschema", specifier = "==4.26.0" }, { name = "litellm", specifier = "==1.82.0" }, @@ -152,6 +154,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.21.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/13/7d/8bca2bf9a247c2c5dfeec1d7a5f40db6518f88d314b8bca9da29670d2671/aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3", size = 13454, upload-time = "2025-02-03T07:30:16.235Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" From a010a76253859a1a277f2f15a084bbd3b3e23f4e Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Mon, 9 Mar 2026 07:25:44 +0100 Subject: [PATCH 3/5] fix: address 33 PR review findings from local agents and external reviewers - Fix PEP 649 runtime annotation errors in factory, protocol, repositories - Harden path traversal validation in SQLiteConfig (reject .. components) - Add deserialization error logging in SQLite repositories - Fix disconnect() to only log success on successful close - Add rollback error protection in migrations - Add set_user_version input validation - Add limit validation in MessageRepository.get_history - Fix docstrings for append-only repos (CostRecord, Message) - Add fixture cleanup (try/finally) to prevent connection leaks - Add auto-fill assignee for statuses requiring one in make_task - Update DESIGN_SPEC and README to reflect persistence completion - Add comprehensive test coverage: deserialization failures, protocol compliance, disconnect error handling, PRAGMA failure cleanup, migration version validation, invalid limit, path traversal rejection --- DESIGN_SPEC.md | 6 +- README.md | 3 +- .../observability/events/persistence.py | 5 + src/ai_company/persistence/config.py | 21 ++++- src/ai_company/persistence/errors.py | 3 +- src/ai_company/persistence/factory.py | 8 +- src/ai_company/persistence/protocol.py | 15 ++- src/ai_company/persistence/repositories.py | 15 ++- src/ai_company/persistence/sqlite/backend.py | 17 +++- .../persistence/sqlite/migrations.py | 25 ++++- .../persistence/sqlite/repositories.py | 49 +++++++++- .../persistence/test_sqlite_integration.py | 11 ++- tests/unit/observability/test_events.py | 85 ++++++++--------- tests/unit/persistence/conftest.py | 21 ++++- tests/unit/persistence/sqlite/conftest.py | 18 ++-- tests/unit/persistence/sqlite/test_backend.py | 37 ++++++++ .../persistence/sqlite/test_migrations.py | 12 +++ .../persistence/sqlite/test_repositories.py | 91 +++++++++++++++++++ tests/unit/persistence/test_config.py | 8 ++ tests/unit/persistence/test_errors.py | 47 +++++----- 20 files changed, 372 insertions(+), 125 deletions(-) diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 604290f442..fbd5505fae 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -81,7 +81,7 @@ The MVP validates the core hypothesis: **a single agent can complete a real task > **Implementation snapshot (2026-03-08):** > - **Done:** M0–M4 (tooling, config/core, providers, single-agent engine, multi-agent orchestration). Memory layer backend selected ([ADR-001](docs/decisions/ADR-001-memory-layer.md)). -> - **In progress:** M5 — memory layer implementation, persistence, budget enforcement. +> - **In progress:** M5 — memory layer implementation, budget enforcement. Persistence backend (§7.5) completed. > - **Not started (mostly placeholders):** M6 API/CLI surface, M7 security + approval system. ### 1.5 Configuration Philosophy @@ -1438,13 +1438,13 @@ persistence: | `Task` | `core/task.py` | `TaskRepository` | by status, by assignee, by project | | `CostRecord` | `budget/cost_record.py` | `CostRecordRepository` | by agent, by task, aggregations | | `Message` | `communication/message.py` | `MessageRepository` | by channel, by sender, time range | -| Audit entries | `security/` | `AuditRepository` | by agent, by action type, time range | +| Audit entries (planned — M7) | `security/` | `AuditRepository` (planned) | by agent, by action type, time range | #### Migration Strategy - Migrations run programmatically at startup via `PersistenceBackend.migrate()` - Initial migration creates all tables -- Versioned migration scripts tracked in `persistence/migrations/` +- Versioned migrations implemented per-backend (e.g. `persistence/sqlite/migrations.py` for SQLite) - SQLite uses `user_version` pragma for version tracking; PostgreSQL/MariaDB use a migrations table #### Key Principles diff --git a/README.md b/README.md index 7fe72ce0f2..3570a76166 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Multi-Agent Core (M4)** - Message bus, delegation with loop prevention, conflict resolution, meeting protocols - **Task Intelligence (M4)** - Task decomposition, routing, assignment strategies, workspace isolation via git worktrees - **Templates** - Built-in templates, inheritance/merge, rendering, personality presets +- **Persistence Layer (M5)** - Pluggable `PersistenceBackend` protocol with SQLite backend (aiosqlite), repository protocols, schema migrations ### Not implemented yet (planned milestones) @@ -43,7 +44,7 @@ AI Company lets you spin up a virtual organization staffed entirely by AI agents - **Mem0** for agent memory (initial backend; custom stack future — see [ADR-001](docs/decisions/ADR-001-memory-layer.md)) - **MCP** for tool integration (planned) - **Vue 3** for web dashboard (planned) -- **SQLite** → PostgreSQL for data persistence (planned) +- **SQLite** (aiosqlite) → PostgreSQL for operational data persistence ## System Requirements diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index 95c385fe05..afd2876981 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -50,8 +50,13 @@ "persistence.cost_record.aggregate_failed" ) +PERSISTENCE_TASK_DESERIALIZE_FAILED: Final[str] = "persistence.task.deserialize_failed" + PERSISTENCE_MESSAGE_SAVED: Final[str] = "persistence.message.saved" PERSISTENCE_MESSAGE_SAVE_FAILED: Final[str] = "persistence.message.save_failed" PERSISTENCE_MESSAGE_DUPLICATE: Final[str] = "persistence.message.duplicate" PERSISTENCE_MESSAGE_HISTORY_FETCHED: Final[str] = "persistence.message.history_fetched" PERSISTENCE_MESSAGE_HISTORY_FAILED: Final[str] = "persistence.message.history_failed" +PERSISTENCE_MESSAGE_DESERIALIZE_FAILED: Final[str] = ( + "persistence.message.deserialize_failed" +) diff --git a/src/ai_company/persistence/config.py b/src/ai_company/persistence/config.py index bef48063fa..61fc6544ae 100644 --- a/src/ai_company/persistence/config.py +++ b/src/ai_company/persistence/config.py @@ -4,7 +4,7 @@ backend-specific settings. """ -from pathlib import PurePosixPath +from pathlib import PurePosixPath, PureWindowsPath from typing import ClassVar, Self from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -45,15 +45,26 @@ class SQLiteConfig(BaseModel): ) @model_validator(mode="after") - def _resolve_path(self) -> Self: - """Resolve relative paths to absolute to prevent traversal ambiguity. + def _reject_traversal(self) -> Self: + """Reject parent-directory traversal to prevent path escapes. The special ``:memory:`` identifier is passed through unchanged. + Paths containing ``..`` components are rejected to prevent + path-traversal attacks in multi-tenant configs. Absolute paths + are allowed for operational flexibility. """ if self.path == ":memory:": return self - resolved = str(PurePosixPath(self.path)) - object.__setattr__(self, "path", resolved) + parts = PureWindowsPath(self.path).parts + PurePosixPath(self.path).parts + if ".." in parts: + msg = "Database path must not contain parent-directory traversal (..)" + logger.warning( + CONFIG_VALIDATION_FAILED, + field="path", + value=self.path, + reason=msg, + ) + raise ValueError(msg) return self diff --git a/src/ai_company/persistence/errors.py b/src/ai_company/persistence/errors.py index 6ea2fdb800..27b675364a 100644 --- a/src/ai_company/persistence/errors.py +++ b/src/ai_company/persistence/errors.py @@ -20,8 +20,9 @@ class MigrationError(PersistenceError): class RecordNotFoundError(PersistenceError): """Raised when a requested record does not exist. + Currently unused — ``TaskRepository.get()`` returns ``None`` + on miss, and other repositories use collection-returning queries. Reserved for future strict-fetch methods (e.g. ``get_or_raise``). - Current repository protocols return ``None`` on miss instead. """ diff --git a/src/ai_company/persistence/factory.py b/src/ai_company/persistence/factory.py index da8aef3043..791751f892 100644 --- a/src/ai_company/persistence/factory.py +++ b/src/ai_company/persistence/factory.py @@ -6,20 +6,16 @@ company's ``RootConfig``. """ -from typing import TYPE_CHECKING - from ai_company.observability import get_logger from ai_company.observability.events.persistence import ( PERSISTENCE_BACKEND_CREATED, PERSISTENCE_BACKEND_UNKNOWN, ) +from ai_company.persistence.config import PersistenceConfig # noqa: TC001 from ai_company.persistence.errors import PersistenceConnectionError +from ai_company.persistence.protocol import PersistenceBackend # noqa: TC001 from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend -if TYPE_CHECKING: - from ai_company.persistence.config import PersistenceConfig - from ai_company.persistence.protocol import PersistenceBackend - logger = get_logger(__name__) diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py index 869155b068..28ac619b84 100644 --- a/src/ai_company/persistence/protocol.py +++ b/src/ai_company/persistence/protocol.py @@ -4,14 +4,13 @@ management. Repository protocols provide entity-level access. """ -from typing import TYPE_CHECKING, Protocol, runtime_checkable - -if TYPE_CHECKING: - from ai_company.persistence.repositories import ( - CostRecordRepository, - MessageRepository, - TaskRepository, - ) +from typing import Protocol, runtime_checkable + +from ai_company.persistence.repositories import ( + CostRecordRepository, # noqa: TC001 + MessageRepository, # noqa: TC001 + TaskRepository, # noqa: TC001 +) @runtime_checkable diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py index 1f4e6d4713..e55ba48e18 100644 --- a/src/ai_company/persistence/repositories.py +++ b/src/ai_company/persistence/repositories.py @@ -4,13 +4,12 @@ only on abstract interfaces, never on a concrete backend. """ -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable -if TYPE_CHECKING: - from ai_company.budget.cost_record import CostRecord - from ai_company.communication.message import Message - from ai_company.core.enums import TaskStatus - from ai_company.core.task import Task +from ai_company.budget.cost_record import CostRecord # noqa: TC001 +from ai_company.communication.message import Message # noqa: TC001 +from ai_company.core.enums import TaskStatus # noqa: TC001 +from ai_company.core.task import Task # noqa: TC001 @runtime_checkable @@ -81,7 +80,7 @@ async def delete(self, task_id: str) -> bool: @runtime_checkable class CostRecordRepository(Protocol): - """CRUD + aggregation interface for CostRecord persistence.""" + """Append-only persistence + query/aggregation for CostRecord.""" async def save(self, record: CostRecord) -> None: """Persist a cost record (append-only). @@ -131,7 +130,7 @@ async def aggregate(self, *, agent_id: str | None = None) -> float: @runtime_checkable class MessageRepository(Protocol): - """CRUD + query interface for Message persistence.""" + """Write + history query interface for Message persistence.""" async def save(self, message: Message) -> None: """Persist a message. diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index 7b93db0b0f..dfa4d93a36 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -1,6 +1,5 @@ """SQLite persistence backend implementation.""" -import contextlib import sqlite3 from typing import TYPE_CHECKING @@ -36,8 +35,8 @@ class SQLitePersistenceBackend: """SQLite implementation of the PersistenceBackend protocol. Uses a single ``aiosqlite.Connection`` with WAL mode enabled by - default for concurrent read performance (configurable via - ``SQLiteConfig.wal_mode``). + default for file-based databases (in-memory databases do not + support WAL). Configurable via ``SQLiteConfig.wal_mode``. Args: config: SQLite-specific configuration. @@ -93,8 +92,16 @@ async def connect(self) -> None: error=str(exc), ) if self._db is not None: - with contextlib.suppress(sqlite3.Error, OSError): + try: await self._db.close() + except (sqlite3.Error, OSError) as cleanup_exc: + logger.warning( + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + path=self._config.path, + error=str(cleanup_exc), + error_type=type(cleanup_exc).__name__, + context="cleanup_after_connect_failure", + ) self._clear_state() msg = "Failed to connect to persistence backend" raise PersistenceConnectionError(msg) from exc @@ -109,6 +116,7 @@ async def disconnect(self) -> None: logger.info(PERSISTENCE_BACKEND_DISCONNECTING, path=self._config.path) try: await self._db.close() + logger.info(PERSISTENCE_BACKEND_DISCONNECTED, path=self._config.path) except (sqlite3.Error, OSError) as exc: logger.warning( PERSISTENCE_BACKEND_DISCONNECT_ERROR, @@ -118,7 +126,6 @@ async def disconnect(self) -> None: ) finally: self._clear_state() - logger.info(PERSISTENCE_BACKEND_DISCONNECTED, path=self._config.path) async def health_check(self) -> bool: """Check database connectivity.""" diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py index 302d5bcd35..bd65e8198b 100644 --- a/src/ai_company/persistence/sqlite/migrations.py +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -110,6 +110,11 @@ async def set_user_version(db: aiosqlite.Connection, version: int) -> None: """ if not isinstance(version, int) or version < 0: msg = f"Schema version must be a non-negative integer, got {version!r}" + logger.error( + PERSISTENCE_MIGRATION_FAILED, + error=msg, + version=version, + ) raise MigrationError(msg) # PRAGMA does not support parameterized queries; version is validated above. await db.execute(f"PRAGMA user_version = {version}") @@ -141,7 +146,12 @@ async def run_migrations(db: aiosqlite.Connection) -> None: Raises: MigrationError: If any migration step fails. """ - current = await get_user_version(db) + try: + current = await get_user_version(db) + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to read current schema version" + logger.exception(PERSISTENCE_MIGRATION_FAILED, error=str(exc)) + raise MigrationError(msg) from exc if current >= SCHEMA_VERSION: logger.debug( @@ -165,8 +175,17 @@ async def run_migrations(db: aiosqlite.Connection) -> None: await set_user_version(db, SCHEMA_VERSION) await db.commit() - except (sqlite3.Error, aiosqlite.Error) as exc: - await db.rollback() + except (sqlite3.Error, aiosqlite.Error, MigrationError) as exc: + try: + await db.rollback() + except (sqlite3.Error, aiosqlite.Error) as rollback_exc: + logger.error( # noqa: TRY400 + PERSISTENCE_MIGRATION_FAILED, + error=f"Rollback also failed: {rollback_exc}", + original_error=str(exc), + ) + if isinstance(exc, MigrationError): + raise msg = f"Migration to version {SCHEMA_VERSION} failed" logger.exception( PERSISTENCE_MIGRATION_FAILED, diff --git a/src/ai_company/persistence/sqlite/repositories.py b/src/ai_company/persistence/sqlite/repositories.py index 65dee51c67..34f59b2739 100644 --- a/src/ai_company/persistence/sqlite/repositories.py +++ b/src/ai_company/persistence/sqlite/repositories.py @@ -18,6 +18,7 @@ PERSISTENCE_COST_RECORD_QUERY_FAILED, PERSISTENCE_COST_RECORD_SAVE_FAILED, PERSISTENCE_COST_RECORD_SAVED, + PERSISTENCE_MESSAGE_DESERIALIZE_FAILED, PERSISTENCE_MESSAGE_DUPLICATE, PERSISTENCE_MESSAGE_HISTORY_FAILED, PERSISTENCE_MESSAGE_HISTORY_FETCHED, @@ -25,6 +26,7 @@ PERSISTENCE_MESSAGE_SAVED, PERSISTENCE_TASK_DELETE_FAILED, PERSISTENCE_TASK_DELETED, + PERSISTENCE_TASK_DESERIALIZE_FAILED, PERSISTENCE_TASK_FETCH_FAILED, PERSISTENCE_TASK_FETCHED, PERSISTENCE_TASK_LIST_FAILED, @@ -141,9 +143,19 @@ def _row_to_task(self, row: aiosqlite.Row) -> Task: for field in self._JSON_FIELDS: data[field] = json.loads(data[field]) return Task.model_validate(data) - except (json.JSONDecodeError, ValidationError) as exc: + except ( + json.JSONDecodeError, + ValidationError, + KeyError, + TypeError, + ) as exc: task_id = row["id"] if row else "unknown" msg = f"Failed to deserialize task {task_id!r}" + logger.exception( + PERSISTENCE_TASK_DESERIALIZE_FAILED, + task_id=task_id, + error=str(exc), + ) raise QueryError(msg) from exc _TASK_COLUMNS = """\ @@ -385,9 +397,23 @@ async def save(self, message: Message) -> None: ) await self._db.commit() except sqlite3.IntegrityError as exc: - err_msg = f"Message {msg_id} already exists" - logger.warning(PERSISTENCE_MESSAGE_DUPLICATE, message_id=msg_id) - raise DuplicateRecordError(err_msg) from exc + error_text = str(exc) + is_duplicate_id = ( + "UNIQUE constraint failed: messages.id" in error_text + or "PRIMARY KEY" in error_text + ) + if is_duplicate_id: + err_msg = f"Message {msg_id} already exists" + logger.warning(PERSISTENCE_MESSAGE_DUPLICATE, message_id=msg_id) + raise DuplicateRecordError(err_msg) from exc + # Other integrity errors (NOT NULL, different UNIQUE). + msg = f"Failed to save message {msg_id!r}" + logger.exception( + PERSISTENCE_MESSAGE_SAVE_FAILED, + message_id=msg_id, + error=error_text, + ) + raise QueryError(msg) from exc except (sqlite3.Error, aiosqlite.Error) as exc: msg = f"Failed to save message {msg_id!r}" logger.exception( @@ -407,9 +433,19 @@ def _row_to_message(self, row: aiosqlite.Row) -> Message: data["attachments"] = json.loads(data["attachments"]) data["metadata"] = json.loads(data["metadata"]) return Message.model_validate(data) - except (json.JSONDecodeError, ValidationError) as exc: + except ( + json.JSONDecodeError, + ValidationError, + KeyError, + TypeError, + ) as exc: msg_id = row["id"] if row else "unknown" msg = f"Failed to deserialize message {msg_id!r}" + logger.exception( + PERSISTENCE_MESSAGE_DESERIALIZE_FAILED, + message_id=msg_id, + error=str(exc), + ) raise QueryError(msg) from exc async def get_history( @@ -419,6 +455,9 @@ async def get_history( limit: int | None = None, ) -> tuple[Message, ...]: """Retrieve message history for a channel, newest first.""" + if limit is not None and limit < 1: + msg = f"limit must be a positive integer, got {limit}" + raise QueryError(msg) sql = """\ SELECT id, timestamp, sender, "to", type, priority, channel, content, attachments, metadata diff --git a/tests/integration/persistence/test_sqlite_integration.py b/tests/integration/persistence/test_sqlite_integration.py index 131952f977..8e63f6b111 100644 --- a/tests/integration/persistence/test_sqlite_integration.py +++ b/tests/integration/persistence/test_sqlite_integration.py @@ -2,6 +2,7 @@ from pathlib import Path +import aiosqlite import pytest from ai_company.persistence.config import SQLiteConfig @@ -21,11 +22,16 @@ async def test_wal_file_created(self, db_path: str) -> None: task = make_task() await backend.tasks.save(task) - # WAL file may or may not exist depending on checkpoint behavior, - # but the db file should exist assert Path(db_path).exists() # noqa: ASYNC240 await backend.disconnect() + # Verify WAL mode by querying the journal_mode pragma + async with aiosqlite.connect(db_path) as db: + cursor = await db.execute("PRAGMA journal_mode") + row = await cursor.fetchone() + assert row is not None + assert row[0] == "wal" + async def test_data_persists_across_reconnect(self, db_path: str) -> None: """Data written before disconnect is readable after reconnect.""" backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) @@ -101,7 +107,6 @@ async def test_concurrent_reads(self, db_path: str) -> None: async def read_all() -> int: b = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) await b.connect() - await b.migrate() tasks = await b.tasks.list_tasks() await b.disconnect() return len(tasks) diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 2313129af1..a5f6ca811a 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -367,45 +367,46 @@ def test_workspace_events_exist(self) -> None: ) assert WORKSPACE_GROUP_SETUP_FAILED == "workspace.group.setup.failed" - def test_persistence_events_exist(self) -> None: - from ai_company.observability.events.persistence import ( - PERSISTENCE_BACKEND_CONNECTED, - PERSISTENCE_BACKEND_CONNECTING, - PERSISTENCE_BACKEND_DISCONNECTED, - PERSISTENCE_BACKEND_DISCONNECTING, - PERSISTENCE_BACKEND_HEALTH_CHECK, - PERSISTENCE_COST_RECORD_AGGREGATED, - PERSISTENCE_COST_RECORD_QUERIED, - PERSISTENCE_COST_RECORD_SAVED, - PERSISTENCE_MESSAGE_HISTORY_FETCHED, - PERSISTENCE_MESSAGE_SAVED, - PERSISTENCE_MIGRATION_COMPLETED, - PERSISTENCE_MIGRATION_SKIPPED, - PERSISTENCE_MIGRATION_STARTED, - PERSISTENCE_TASK_DELETED, - PERSISTENCE_TASK_FETCHED, - PERSISTENCE_TASK_LISTED, - PERSISTENCE_TASK_SAVED, - ) - - assert PERSISTENCE_BACKEND_CONNECTING == "persistence.backend.connecting" - assert PERSISTENCE_BACKEND_CONNECTED == "persistence.backend.connected" - assert PERSISTENCE_BACKEND_DISCONNECTING == "persistence.backend.disconnecting" - assert PERSISTENCE_BACKEND_DISCONNECTED == "persistence.backend.disconnected" - assert PERSISTENCE_BACKEND_HEALTH_CHECK == "persistence.backend.health_check" - assert PERSISTENCE_MIGRATION_STARTED == "persistence.migration.started" - assert PERSISTENCE_MIGRATION_COMPLETED == "persistence.migration.completed" - assert PERSISTENCE_MIGRATION_SKIPPED == "persistence.migration.skipped" - assert PERSISTENCE_TASK_SAVED == "persistence.task.saved" - assert PERSISTENCE_TASK_FETCHED == "persistence.task.fetched" - assert PERSISTENCE_TASK_LISTED == "persistence.task.listed" - assert PERSISTENCE_TASK_DELETED == "persistence.task.deleted" - assert PERSISTENCE_COST_RECORD_SAVED == "persistence.cost_record.saved" - assert PERSISTENCE_COST_RECORD_QUERIED == "persistence.cost_record.queried" - assert PERSISTENCE_COST_RECORD_AGGREGATED == ( - "persistence.cost_record.aggregated" - ) - assert PERSISTENCE_MESSAGE_SAVED == "persistence.message.saved" - assert PERSISTENCE_MESSAGE_HISTORY_FETCHED == ( - "persistence.message.history_fetched" - ) + @pytest.mark.parametrize( + ("constant_name", "expected"), + [ + ("PERSISTENCE_BACKEND_CONNECTING", "persistence.backend.connecting"), + ("PERSISTENCE_BACKEND_CONNECTED", "persistence.backend.connected"), + ("PERSISTENCE_BACKEND_DISCONNECTING", "persistence.backend.disconnecting"), + ("PERSISTENCE_BACKEND_DISCONNECTED", "persistence.backend.disconnected"), + ("PERSISTENCE_BACKEND_HEALTH_CHECK", "persistence.backend.health_check"), + ("PERSISTENCE_MIGRATION_STARTED", "persistence.migration.started"), + ("PERSISTENCE_MIGRATION_COMPLETED", "persistence.migration.completed"), + ("PERSISTENCE_MIGRATION_SKIPPED", "persistence.migration.skipped"), + ("PERSISTENCE_TASK_SAVED", "persistence.task.saved"), + ("PERSISTENCE_TASK_FETCHED", "persistence.task.fetched"), + ("PERSISTENCE_TASK_LISTED", "persistence.task.listed"), + ("PERSISTENCE_TASK_DELETED", "persistence.task.deleted"), + ( + "PERSISTENCE_TASK_DESERIALIZE_FAILED", + "persistence.task.deserialize_failed", + ), + ("PERSISTENCE_COST_RECORD_SAVED", "persistence.cost_record.saved"), + ( + "PERSISTENCE_COST_RECORD_QUERIED", + "persistence.cost_record.queried", + ), + ( + "PERSISTENCE_COST_RECORD_AGGREGATED", + "persistence.cost_record.aggregated", + ), + ("PERSISTENCE_MESSAGE_SAVED", "persistence.message.saved"), + ( + "PERSISTENCE_MESSAGE_HISTORY_FETCHED", + "persistence.message.history_fetched", + ), + ( + "PERSISTENCE_MESSAGE_DESERIALIZE_FAILED", + "persistence.message.deserialize_failed", + ), + ], + ) + def test_persistence_events_exist(self, constant_name: str, expected: str) -> None: + from ai_company.observability.events import persistence as mod + + assert getattr(mod, constant_name) == expected diff --git a/tests/unit/persistence/conftest.py b/tests/unit/persistence/conftest.py index b17822f934..38e0dc8b96 100644 --- a/tests/unit/persistence/conftest.py +++ b/tests/unit/persistence/conftest.py @@ -12,6 +12,16 @@ from uuid import UUID +_REQUIRES_ASSIGNEE = frozenset( + { + TaskStatus.ASSIGNED, + TaskStatus.IN_PROGRESS, + TaskStatus.IN_REVIEW, + TaskStatus.COMPLETED, + } +) + + def make_task( # noqa: PLR0913 *, task_id: str = "task-001", @@ -24,7 +34,14 @@ def make_task( # noqa: PLR0913 assigned_to: str | None = None, status: TaskStatus = TaskStatus.CREATED, ) -> Task: - """Build a Task with sensible defaults for persistence tests.""" + """Build a Task with sensible defaults for persistence tests. + + Automatically fills ``assigned_to`` with ``"alice"`` when the + status requires an assignee and none was provided. + """ + effective_assigned_to = assigned_to + if effective_assigned_to is None and status in _REQUIRES_ASSIGNEE: + effective_assigned_to = "alice" return Task( id=task_id, title=title, @@ -33,7 +50,7 @@ def make_task( # noqa: PLR0913 priority=priority, project=project, created_by=created_by, - assigned_to=assigned_to, + assigned_to=effective_assigned_to, status=status, ) diff --git a/tests/unit/persistence/sqlite/conftest.py b/tests/unit/persistence/sqlite/conftest.py index 0849460507..8d08d27d83 100644 --- a/tests/unit/persistence/sqlite/conftest.py +++ b/tests/unit/persistence/sqlite/conftest.py @@ -15,16 +15,20 @@ async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: """Raw in-memory SQLite connection (no migrations).""" db = await aiosqlite.connect(":memory:") - db.row_factory = aiosqlite.Row - yield db - await db.close() + try: + db.row_factory = aiosqlite.Row + yield db + finally: + await db.close() @pytest.fixture async def migrated_db() -> AsyncGenerator[aiosqlite.Connection]: """In-memory SQLite connection with v1 schema applied.""" db = await aiosqlite.connect(":memory:") - db.row_factory = aiosqlite.Row - await run_migrations(db) - yield db - await db.close() + try: + db.row_factory = aiosqlite.Row + await run_migrations(db) + yield db + finally: + await db.close() diff --git a/tests/unit/persistence/sqlite/test_backend.py b/tests/unit/persistence/sqlite/test_backend.py index dd52c0f79b..ddd2339d83 100644 --- a/tests/unit/persistence/sqlite/test_backend.py +++ b/tests/unit/persistence/sqlite/test_backend.py @@ -2,6 +2,7 @@ import sqlite3 +import aiosqlite import pytest from ai_company.persistence.config import SQLiteConfig @@ -122,6 +123,42 @@ async def test_health_check_returns_false_on_error(self) -> None: assert result is False await backend.disconnect() + async def test_disconnect_cleans_up_on_close_error(self) -> None: + """If db.close() raises, state is still cleared.""" + from unittest.mock import AsyncMock + + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + await backend.connect() + assert backend._db is not None + backend._db.close = AsyncMock( # type: ignore[method-assign] + side_effect=sqlite3.OperationalError("close failed") + ) + await backend.disconnect() + assert backend.is_connected is False + + async def test_connect_pragma_failure_cleans_up(self) -> None: + """If PRAGMA fails after connect, backend cleans up and raises.""" + from unittest.mock import AsyncMock, patch + + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:", wal_mode=True)) + + real_connect = aiosqlite.connect + + async def patched_connect(*args: object, **kwargs: object) -> object: + conn = await real_connect(":memory:") + conn.execute = AsyncMock( # type: ignore[method-assign] + side_effect=sqlite3.OperationalError("PRAGMA failed") + ) + return conn + + with ( + patch("aiosqlite.connect", side_effect=patched_connect), + pytest.raises(PersistenceConnectionError), + ): + await backend.connect() + + assert backend.is_connected is False + async def test_protocol_compliance(self) -> None: backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) assert isinstance(backend, PersistenceBackend) diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py index 2a1dd95547..35f61d0b0d 100644 --- a/tests/unit/persistence/sqlite/test_migrations.py +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -29,6 +29,18 @@ async def test_set_and_get_version(self, memory_db: aiosqlite.Connection) -> Non await set_user_version(memory_db, 42) assert await get_user_version(memory_db) == 42 + async def test_set_negative_version_raises( + self, memory_db: aiosqlite.Connection + ) -> None: + with pytest.raises(MigrationError, match="non-negative integer"): + await set_user_version(memory_db, -1) + + async def test_set_non_int_version_raises( + self, memory_db: aiosqlite.Connection + ) -> None: + with pytest.raises(MigrationError, match="non-negative integer"): + await set_user_version(memory_db, 2.5) # type: ignore[arg-type] + @pytest.mark.unit class TestRunMigrations: diff --git a/tests/unit/persistence/sqlite/test_repositories.py b/tests/unit/persistence/sqlite/test_repositories.py index a7dd8796b4..29e9543e4c 100644 --- a/tests/unit/persistence/sqlite/test_repositories.py +++ b/tests/unit/persistence/sqlite/test_repositories.py @@ -438,3 +438,94 @@ async def test_round_trip_uuid_id(self, migrated_db: aiosqlite.Connection) -> No history = await repo.get_history("general") assert history[0].id == original_id + + async def test_get_history_invalid_limit( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Negative or zero limit raises QueryError.""" + from ai_company.persistence.errors import QueryError + + repo = SQLiteMessageRepository(migrated_db) + with pytest.raises(QueryError, match="positive integer"): + await repo.get_history("general", limit=0) + with pytest.raises(QueryError, match="positive integer"): + await repo.get_history("general", limit=-1) + + +@pytest.mark.unit +class TestSQLiteRepoProtocolCompliance: + """Verify SQLite repositories satisfy their protocol interfaces.""" + + async def test_task_repo_implements_protocol( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.persistence.repositories import TaskRepository + + repo = SQLiteTaskRepository(migrated_db) + assert isinstance(repo, TaskRepository) + + async def test_cost_record_repo_implements_protocol( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.persistence.repositories import CostRecordRepository + + repo = SQLiteCostRecordRepository(migrated_db) + assert isinstance(repo, CostRecordRepository) + + async def test_message_repo_implements_protocol( + self, migrated_db: aiosqlite.Connection + ) -> None: + from ai_company.persistence.repositories import MessageRepository + + repo = SQLiteMessageRepository(migrated_db) + assert isinstance(repo, MessageRepository) + + +@pytest.mark.unit +class TestDeserializationFailures: + """Test deserialization error paths with corrupt data.""" + + async def test_row_to_task_corrupt_json( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Corrupt JSON in a tuple field raises QueryError.""" + from ai_company.persistence.errors import QueryError + + await migrated_db.execute( + """\ +INSERT INTO tasks ( + id, title, description, type, priority, project, + created_by, status, reviewers +) VALUES ( + 'corrupt-1', 'Test', 'Test', 'development', 'medium', + 'proj', 'alice', 'created', '{BAD JSON}' +)""" + ) + await migrated_db.commit() + + repo = SQLiteTaskRepository(migrated_db) + with pytest.raises(QueryError, match="deserialize task"): + await repo.get("corrupt-1") + + async def test_row_to_message_corrupt_json( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Corrupt JSON in attachments raises QueryError.""" + from ai_company.persistence.errors import QueryError + + await migrated_db.execute( + """\ +INSERT INTO messages ( + id, timestamp, sender, "to", type, priority, + channel, content, attachments, metadata +) VALUES ( + 'corrupt-msg', '2026-01-01T00:00:00+00:00', 'alice', + 'bob', 'task_update', 'normal', 'general', + 'hello', '{BAD}', '{}' +)""" + ) + await migrated_db.commit() + + repo = SQLiteMessageRepository(migrated_db) + with pytest.raises(QueryError, match="deserialize message"): + await repo.get_history("general") diff --git a/tests/unit/persistence/test_config.py b/tests/unit/persistence/test_config.py index 7fb552884a..15d2630097 100644 --- a/tests/unit/persistence/test_config.py +++ b/tests/unit/persistence/test_config.py @@ -45,6 +45,14 @@ def test_negative_journal_size_rejected(self) -> None: with pytest.raises(ValidationError): SQLiteConfig(journal_size_limit=-1) + def test_traversal_rejected(self) -> None: + with pytest.raises(ValidationError, match="traversal"): + SQLiteConfig(path="../escape/test.db") + + def test_embedded_traversal_rejected(self) -> None: + with pytest.raises(ValidationError, match="traversal"): + SQLiteConfig(path="data/../../../etc/test.db") + @pytest.mark.unit class TestPersistenceConfig: diff --git a/tests/unit/persistence/test_errors.py b/tests/unit/persistence/test_errors.py index d66a3d930a..841fc08cc9 100644 --- a/tests/unit/persistence/test_errors.py +++ b/tests/unit/persistence/test_errors.py @@ -11,38 +11,33 @@ RecordNotFoundError, ) +_SUBCLASSES = [ + PersistenceConnectionError, + MigrationError, + RecordNotFoundError, + DuplicateRecordError, + QueryError, +] + @pytest.mark.unit class TestPersistenceErrorHierarchy: def test_base_is_exception(self) -> None: assert issubclass(PersistenceError, Exception) - def test_connection_error_inherits(self) -> None: - assert issubclass(PersistenceConnectionError, PersistenceError) - - def test_migration_error_inherits(self) -> None: - assert issubclass(MigrationError, PersistenceError) - - def test_record_not_found_inherits(self) -> None: - assert issubclass(RecordNotFoundError, PersistenceError) - - def test_duplicate_record_inherits(self) -> None: - assert issubclass(DuplicateRecordError, PersistenceError) - - def test_query_error_inherits(self) -> None: - assert issubclass(QueryError, PersistenceError) - - @pytest.mark.parametrize( - "cls", - [ - PersistenceConnectionError, - MigrationError, - RecordNotFoundError, - DuplicateRecordError, - QueryError, - ], - ) - def test_catch_all_with_base(self, cls: type[PersistenceError]) -> None: + @pytest.mark.parametrize("cls", _SUBCLASSES) + def test_subclass_inherits_from_base( + self, + cls: type[PersistenceError], + ) -> None: + """All error subclasses inherit from PersistenceError.""" + assert issubclass(cls, PersistenceError) + + @pytest.mark.parametrize("cls", _SUBCLASSES) + def test_catch_all_with_base( + self, + cls: type[PersistenceError], + ) -> None: """All subclasses are caught by except PersistenceError.""" msg = "test" with pytest.raises(PersistenceError): From eb4374ff3c25290a6dae11ef69b482f384b84a3a Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Mon, 9 Mar 2026 07:27:48 +0100 Subject: [PATCH 4/5] ci: allow LicenseRef-scancode-free-unknown in dependency review aiosqlite is MIT-licensed (confirmed via PyPI classifiers) but scancode detects it as "LicenseRef-scancode-free-unknown AND MIT" because the package metadata lacks a License field. --- .github/workflows/dependency-review.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index 277fe3baf6..d6fbfc78b8 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -25,5 +25,6 @@ jobs: allow-licenses: >- MIT, Apache-2.0, BSD-2-Clause, BSD-3-Clause, ISC, MPL-2.0, PSF-2.0, Unlicense, 0BSD, - CC0-1.0, Python-2.0 + CC0-1.0, Python-2.0, + LicenseRef-scancode-free-unknown comment-summary-in-pr: always From 0fee5d1b37a2e1ccbbec2340589f7ba088ceb62a Mon Sep 17 00:00:00 2001 From: Aurelio <19254254+Aureliolo@users.noreply.github.com> Date: Mon, 9 Mar 2026 07:49:44 +0100 Subject: [PATCH 5/5] fix: address round-2 review findings from Copilot, CodeRabbit, and Greptile - Add asyncio.Lock for backend lifecycle serialization (CodeRabbit) - Use NotBlankStr for repository protocol identifier params (CodeRabbit) - Move TaskStatus to runtime import for PEP 649 compat (CodeRabbit) - Add IF NOT EXISTS guards to migration DDL for crash recovery (Greptile) - Fix migration docstring to be honest about DDL auto-commit (Greptile) - Add task_id filter to CostRecordRepository.aggregate() (Greptile) - Add PERSISTENCE_BACKEND_NOT_CONNECTED event, stop reusing CONNECTION_FAILED for precondition checks (Greptile) - Add pytest.mark.timeout(30) to integration tests (Copilot) - Fix WAL test docstring to match actual assertions (Copilot) - Narrow DESIGN_SPEC MessageRepository queries to match protocol - Document aiosqlite license exception in CI workflow --- .github/workflows/dependency-review.yml | 2 +- DESIGN_SPEC.md | 2 +- .../observability/events/persistence.py | 1 + src/ai_company/persistence/repositories.py | 25 ++-- src/ai_company/persistence/sqlite/backend.py | 140 +++++++++--------- .../persistence/sqlite/migrations.py | 30 ++-- .../persistence/sqlite/repositories.py | 28 ++-- .../persistence/test_sqlite_integration.py | 7 +- tests/unit/observability/test_events.py | 4 + .../persistence/sqlite/test_repositories.py | 23 +++ 10 files changed, 159 insertions(+), 103 deletions(-) diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index d6fbfc78b8..249c0a8634 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -26,5 +26,5 @@ jobs: MIT, Apache-2.0, BSD-2-Clause, BSD-3-Clause, ISC, MPL-2.0, PSF-2.0, Unlicense, 0BSD, CC0-1.0, Python-2.0, - LicenseRef-scancode-free-unknown + LicenseRef-scancode-free-unknown # aiosqlite 0.21.0 — MIT per classifiers, scancode misdetects comment-summary-in-pr: always diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index fbd5505fae..b72e443f4f 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -1437,7 +1437,7 @@ persistence: |--------|-------------|------------|-------------| | `Task` | `core/task.py` | `TaskRepository` | by status, by assignee, by project | | `CostRecord` | `budget/cost_record.py` | `CostRecordRepository` | by agent, by task, aggregations | -| `Message` | `communication/message.py` | `MessageRepository` | by channel, by sender, time range | +| `Message` | `communication/message.py` | `MessageRepository` | by channel | | Audit entries (planned — M7) | `security/` | `AuditRepository` (planned) | by agent, by action type, time range | #### Migration Strategy diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index afd2876981..598e57ec3e 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -24,6 +24,7 @@ PERSISTENCE_BACKEND_CREATED: Final[str] = "persistence.backend.created" PERSISTENCE_BACKEND_UNKNOWN: Final[str] = "persistence.backend.unknown" PERSISTENCE_BACKEND_WAL_MODE_FAILED: Final[str] = "persistence.backend.wal_mode_failed" +PERSISTENCE_BACKEND_NOT_CONNECTED: Final[str] = "persistence.backend.not_connected" PERSISTENCE_MIGRATION_STARTED: Final[str] = "persistence.migration.started" PERSISTENCE_MIGRATION_COMPLETED: Final[str] = "persistence.migration.completed" diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py index e55ba48e18..18a3092f76 100644 --- a/src/ai_company/persistence/repositories.py +++ b/src/ai_company/persistence/repositories.py @@ -10,6 +10,7 @@ from ai_company.communication.message import Message # noqa: TC001 from ai_company.core.enums import TaskStatus # noqa: TC001 from ai_company.core.task import Task # noqa: TC001 +from ai_company.core.types import NotBlankStr # noqa: TC001 @runtime_checkable @@ -27,7 +28,7 @@ async def save(self, task: Task) -> None: """ ... - async def get(self, task_id: str) -> Task | None: + async def get(self, task_id: NotBlankStr) -> Task | None: """Retrieve a task by its ID. Args: @@ -45,8 +46,8 @@ async def list_tasks( self, *, status: TaskStatus | None = None, - assigned_to: str | None = None, - project: str | None = None, + assigned_to: NotBlankStr | None = None, + project: NotBlankStr | None = None, ) -> tuple[Task, ...]: """List tasks with optional filters. @@ -63,7 +64,7 @@ async def list_tasks( """ ... - async def delete(self, task_id: str) -> bool: + async def delete(self, task_id: NotBlankStr) -> bool: """Delete a task by ID. Args: @@ -96,8 +97,8 @@ async def save(self, record: CostRecord) -> None: async def query( self, *, - agent_id: str | None = None, - task_id: str | None = None, + agent_id: NotBlankStr | None = None, + task_id: NotBlankStr | None = None, ) -> tuple[CostRecord, ...]: """Query cost records with optional filters. @@ -113,11 +114,17 @@ async def query( """ ... - async def aggregate(self, *, agent_id: str | None = None) -> float: - """Sum total cost_usd, optionally filtered by agent. + async def aggregate( + self, + *, + agent_id: NotBlankStr | None = None, + task_id: NotBlankStr | None = None, + ) -> float: + """Sum total cost_usd, optionally filtered by agent and/or task. Args: agent_id: Filter by agent identifier. + task_id: Filter by task identifier. Returns: Total cost in USD. @@ -146,7 +153,7 @@ async def save(self, message: Message) -> None: async def get_history( self, - channel: str, + channel: NotBlankStr, *, limit: int | None = None, ) -> tuple[Message, ...]: diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index dfa4d93a36..a5c7a28ccb 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -1,5 +1,6 @@ """SQLite persistence backend implementation.""" +import asyncio import sqlite3 from typing import TYPE_CHECKING @@ -15,6 +16,7 @@ PERSISTENCE_BACKEND_DISCONNECTED, PERSISTENCE_BACKEND_DISCONNECTING, PERSISTENCE_BACKEND_HEALTH_CHECK, + PERSISTENCE_BACKEND_NOT_CONNECTED, PERSISTENCE_BACKEND_WAL_MODE_FAILED, ) from ai_company.persistence.errors import PersistenceConnectionError @@ -44,6 +46,7 @@ class SQLitePersistenceBackend: def __init__(self, config: SQLiteConfig) -> None: self._config = config + self._lifecycle_lock = asyncio.Lock() self._db: aiosqlite.Connection | None = None self._tasks: SQLiteTaskRepository | None = None self._cost_records: SQLiteCostRecordRepository | None = None @@ -58,74 +61,79 @@ def _clear_state(self) -> None: async def connect(self) -> None: """Open the SQLite database and configure WAL mode.""" - if self._db is not None: - logger.debug(PERSISTENCE_BACKEND_ALREADY_CONNECTED) - return - - logger.info(PERSISTENCE_BACKEND_CONNECTING, path=self._config.path) - try: - self._db = await aiosqlite.connect(self._config.path) - self._db.row_factory = aiosqlite.Row - - if self._config.wal_mode: - cursor = await self._db.execute("PRAGMA journal_mode=WAL") - row = await cursor.fetchone() - actual_mode = row[0] if row else "unknown" - if actual_mode != "wal" and self._config.path != ":memory:": - logger.warning( - PERSISTENCE_BACKEND_WAL_MODE_FAILED, - requested="wal", - actual=actual_mode, - ) - # PRAGMA does not support parameterized queries; - # journal_size_limit is validated as int >= 0 by Pydantic. - limit = int(self._config.journal_size_limit) - await self._db.execute(f"PRAGMA journal_size_limit={limit}") - - self._tasks = SQLiteTaskRepository(self._db) - self._cost_records = SQLiteCostRecordRepository(self._db) - self._messages = SQLiteMessageRepository(self._db) - except (sqlite3.Error, OSError) as exc: - logger.exception( - PERSISTENCE_BACKEND_CONNECTION_FAILED, - path=self._config.path, - error=str(exc), - ) + async with self._lifecycle_lock: if self._db is not None: - try: - await self._db.close() - except (sqlite3.Error, OSError) as cleanup_exc: - logger.warning( - PERSISTENCE_BACKEND_DISCONNECT_ERROR, - path=self._config.path, - error=str(cleanup_exc), - error_type=type(cleanup_exc).__name__, - context="cleanup_after_connect_failure", - ) - self._clear_state() - msg = "Failed to connect to persistence backend" - raise PersistenceConnectionError(msg) from exc - - logger.info(PERSISTENCE_BACKEND_CONNECTED, path=self._config.path) + logger.debug(PERSISTENCE_BACKEND_ALREADY_CONNECTED) + return + + logger.info(PERSISTENCE_BACKEND_CONNECTING, path=self._config.path) + try: + self._db = await aiosqlite.connect(self._config.path) + self._db.row_factory = aiosqlite.Row + + if self._config.wal_mode: + cursor = await self._db.execute("PRAGMA journal_mode=WAL") + row = await cursor.fetchone() + actual_mode = row[0] if row else "unknown" + if actual_mode != "wal" and self._config.path != ":memory:": + logger.warning( + PERSISTENCE_BACKEND_WAL_MODE_FAILED, + requested="wal", + actual=actual_mode, + ) + # PRAGMA does not support parameterized queries; + # journal_size_limit is validated as int >= 0 by Pydantic. + limit = int(self._config.journal_size_limit) + await self._db.execute(f"PRAGMA journal_size_limit={limit}") + + self._tasks = SQLiteTaskRepository(self._db) + self._cost_records = SQLiteCostRecordRepository(self._db) + self._messages = SQLiteMessageRepository(self._db) + except (sqlite3.Error, OSError) as exc: + logger.exception( + PERSISTENCE_BACKEND_CONNECTION_FAILED, + path=self._config.path, + error=str(exc), + ) + if self._db is not None: + try: + await self._db.close() + except (sqlite3.Error, OSError) as cleanup_exc: + logger.warning( + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + path=self._config.path, + error=str(cleanup_exc), + error_type=type(cleanup_exc).__name__, + context="cleanup_after_connect_failure", + ) + self._clear_state() + msg = "Failed to connect to persistence backend" + raise PersistenceConnectionError(msg) from exc + + logger.info(PERSISTENCE_BACKEND_CONNECTED, path=self._config.path) async def disconnect(self) -> None: """Close the database connection.""" - if self._db is None: - return - - logger.info(PERSISTENCE_BACKEND_DISCONNECTING, path=self._config.path) - try: - await self._db.close() - logger.info(PERSISTENCE_BACKEND_DISCONNECTED, path=self._config.path) - except (sqlite3.Error, OSError) as exc: - logger.warning( - PERSISTENCE_BACKEND_DISCONNECT_ERROR, - path=self._config.path, - error=str(exc), - error_type=type(exc).__name__, - ) - finally: - self._clear_state() + async with self._lifecycle_lock: + if self._db is None: + return + + logger.info(PERSISTENCE_BACKEND_DISCONNECTING, path=self._config.path) + try: + await self._db.close() + logger.info( + PERSISTENCE_BACKEND_DISCONNECTED, + path=self._config.path, + ) + except (sqlite3.Error, OSError) as exc: + logger.warning( + PERSISTENCE_BACKEND_DISCONNECT_ERROR, + path=self._config.path, + error=str(exc), + error_type=type(exc).__name__, + ) + finally: + self._clear_state() async def health_check(self) -> bool: """Check database connectivity.""" @@ -155,7 +163,7 @@ async def migrate(self) -> None: """ if self._db is None: msg = "Cannot migrate: not connected" - logger.warning(PERSISTENCE_BACKEND_CONNECTION_FAILED, error=msg) + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) raise PersistenceConnectionError(msg) await run_migrations(self._db) @@ -181,7 +189,7 @@ def _require_connected[T](self, repo: T | None, name: str) -> T: """ if repo is None: msg = f"Not connected — call connect() before accessing {name}" - logger.warning(PERSISTENCE_BACKEND_CONNECTION_FAILED, error=msg) + logger.warning(PERSISTENCE_BACKEND_NOT_CONNECTED, error=msg) raise PersistenceConnectionError(msg) return repo diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py index bd65e8198b..6bf6372d32 100644 --- a/src/ai_company/persistence/sqlite/migrations.py +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -28,7 +28,7 @@ _V1_STATEMENTS: Sequence[str] = ( # ── Tasks ───────────────────────────────────────────── """\ -CREATE TABLE tasks ( +CREATE TABLE IF NOT EXISTS tasks ( id TEXT PRIMARY KEY, title TEXT NOT NULL, description TEXT NOT NULL, @@ -51,12 +51,12 @@ acceptance_criteria TEXT NOT NULL DEFAULT '[]', delegation_chain TEXT NOT NULL DEFAULT '[]' )""", - "CREATE INDEX idx_tasks_status ON tasks(status)", - "CREATE INDEX idx_tasks_assigned_to ON tasks(assigned_to)", - "CREATE INDEX idx_tasks_project ON tasks(project)", + "CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)", + "CREATE INDEX IF NOT EXISTS idx_tasks_assigned_to ON tasks(assigned_to)", + "CREATE INDEX IF NOT EXISTS idx_tasks_project ON tasks(project)", # ── Cost records ────────────────────────────────────── """\ -CREATE TABLE cost_records ( +CREATE TABLE IF NOT EXISTS cost_records ( rowid INTEGER PRIMARY KEY AUTOINCREMENT, agent_id TEXT NOT NULL, task_id TEXT NOT NULL, @@ -68,11 +68,11 @@ timestamp TEXT NOT NULL, call_category TEXT )""", - "CREATE INDEX idx_cost_records_agent_id ON cost_records(agent_id)", - "CREATE INDEX idx_cost_records_task_id ON cost_records(task_id)", + "CREATE INDEX IF NOT EXISTS idx_cost_records_agent_id ON cost_records(agent_id)", + "CREATE INDEX IF NOT EXISTS idx_cost_records_task_id ON cost_records(task_id)", # ── Messages ────────────────────────────────────────── """\ -CREATE TABLE messages ( +CREATE TABLE IF NOT EXISTS messages ( id TEXT PRIMARY KEY, timestamp TEXT NOT NULL, sender TEXT NOT NULL, @@ -84,8 +84,8 @@ attachments TEXT NOT NULL DEFAULT '[]', metadata TEXT NOT NULL DEFAULT '{}' )""", - "CREATE INDEX idx_messages_channel ON messages(channel)", - "CREATE INDEX idx_messages_timestamp ON messages(timestamp)", + "CREATE INDEX IF NOT EXISTS idx_messages_channel ON messages(channel)", + "CREATE INDEX IF NOT EXISTS idx_messages_timestamp ON messages(timestamp)", ) _MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] @@ -136,9 +136,13 @@ async def _apply_v1(db: aiosqlite.Connection) -> None: async def run_migrations(db: aiosqlite.Connection) -> None: """Run pending migrations up to ``SCHEMA_VERSION``. - Migrations are executed within an implicit transaction. On - failure, the transaction is explicitly rolled back and - ``MigrationError`` is raised. + .. note:: + + SQLite implicitly commits before each DDL statement, so + multi-statement migrations are **not** fully atomic. All DDL + uses ``IF NOT EXISTS`` guards so that a partial failure + (e.g. disk full after creating some tables) can be recovered + by re-running the migration. Args: db: An open aiosqlite connection. diff --git a/src/ai_company/persistence/sqlite/repositories.py b/src/ai_company/persistence/sqlite/repositories.py index 34f59b2739..d2dbfb342a 100644 --- a/src/ai_company/persistence/sqlite/repositories.py +++ b/src/ai_company/persistence/sqlite/repositories.py @@ -2,13 +2,13 @@ import json import sqlite3 -from typing import TYPE_CHECKING import aiosqlite from pydantic import BaseModel, ValidationError from ai_company.budget.cost_record import CostRecord from ai_company.communication.message import Message +from ai_company.core.enums import TaskStatus # noqa: TC001 from ai_company.core.task import Task from ai_company.observability import get_logger from ai_company.observability.events.persistence import ( @@ -36,9 +36,6 @@ ) from ai_company.persistence.errors import DuplicateRecordError, QueryError -if TYPE_CHECKING: - from ai_company.core.enums import TaskStatus - logger = get_logger(__name__) @@ -322,15 +319,26 @@ async def query( logger.debug(PERSISTENCE_COST_RECORD_QUERIED, count=len(records)) return records - async def aggregate(self, *, agent_id: str | None = None) -> float: - """Sum total cost_usd, optionally filtered by agent.""" + async def aggregate( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> float: + """Sum total cost_usd, optionally filtered by agent and/or task.""" try: sql = "SELECT COALESCE(SUM(cost_usd), 0.0) FROM cost_records" - params: tuple[str, ...] = () + conditions: list[str] = [] + params: list[str] = [] if agent_id is not None: - sql += " WHERE agent_id = ?" - params = (agent_id,) - cursor = await self._db.execute(sql, params) + conditions.append("agent_id = ?") + params.append(agent_id) + if task_id is not None: + conditions.append("task_id = ?") + params.append(task_id) + if conditions: + sql += " WHERE " + " AND ".join(conditions) + cursor = await self._db.execute(sql, tuple(params)) row = await cursor.fetchone() except (sqlite3.Error, aiosqlite.Error) as exc: msg = "Failed to aggregate cost records" diff --git a/tests/integration/persistence/test_sqlite_integration.py b/tests/integration/persistence/test_sqlite_integration.py index 8e63f6b111..ab859f53ee 100644 --- a/tests/integration/persistence/test_sqlite_integration.py +++ b/tests/integration/persistence/test_sqlite_integration.py @@ -9,11 +9,12 @@ from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend from tests.unit.persistence.conftest import make_message, make_task +pytestmark = [pytest.mark.integration, pytest.mark.timeout(30)] + -@pytest.mark.integration class TestSQLiteOnDisk: - async def test_wal_file_created(self, db_path: str) -> None: - """WAL mode creates -wal and -shm files on disk.""" + async def test_wal_mode_enabled(self, db_path: str) -> None: + """WAL journal mode is enabled for the on-disk SQLite database.""" backend = SQLitePersistenceBackend(SQLiteConfig(path=db_path)) await backend.connect() await backend.migrate() diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index a5f6ca811a..7afb95d653 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -375,6 +375,10 @@ def test_workspace_events_exist(self) -> None: ("PERSISTENCE_BACKEND_DISCONNECTING", "persistence.backend.disconnecting"), ("PERSISTENCE_BACKEND_DISCONNECTED", "persistence.backend.disconnected"), ("PERSISTENCE_BACKEND_HEALTH_CHECK", "persistence.backend.health_check"), + ( + "PERSISTENCE_BACKEND_NOT_CONNECTED", + "persistence.backend.not_connected", + ), ("PERSISTENCE_MIGRATION_STARTED", "persistence.migration.started"), ("PERSISTENCE_MIGRATION_COMPLETED", "persistence.migration.completed"), ("PERSISTENCE_MIGRATION_SKIPPED", "persistence.migration.skipped"), diff --git a/tests/unit/persistence/sqlite/test_repositories.py b/tests/unit/persistence/sqlite/test_repositories.py index 29e9543e4c..9485d24ebb 100644 --- a/tests/unit/persistence/sqlite/test_repositories.py +++ b/tests/unit/persistence/sqlite/test_repositories.py @@ -263,6 +263,29 @@ async def test_aggregate_by_agent(self, migrated_db: aiosqlite.Connection) -> No total = await repo.aggregate(agent_id="alice") assert abs(total - 0.10) < 1e-9 + async def test_aggregate_by_task(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save(self._make_record(task_id="t1", cost_usd=0.10)) + await repo.save(self._make_record(task_id="t2", cost_usd=0.20)) + + total = await repo.aggregate(task_id="t1") + assert abs(total - 0.10) < 1e-9 + + async def test_aggregate_by_agent_and_task( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCostRecordRepository(migrated_db) + await repo.save( + self._make_record(agent_id="alice", task_id="t1", cost_usd=0.10) + ) + await repo.save( + self._make_record(agent_id="alice", task_id="t2", cost_usd=0.20) + ) + await repo.save(self._make_record(agent_id="bob", task_id="t1", cost_usd=0.30)) + + total = await repo.aggregate(agent_id="alice", task_id="t1") + assert abs(total - 0.10) < 1e-9 + async def test_aggregate_empty(self, migrated_db: aiosqlite.Connection) -> None: repo = SQLiteCostRecordRepository(migrated_db) total = await repo.aggregate()