diff --git a/CLAUDE.md b/CLAUDE.md index c93848ebd4..fbcc2f12f9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -120,7 +120,7 @@ src/synthorg/ meeting/ # Meeting protocol (round-robin, position papers, structured phases), scheduler (frequency, participant resolver), orchestrator config/ # YAML company config loading and validation core/ # Shared domain models, base classes, and resilience config (RetryConfig, RateLimiterConfig) - engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination (multi-agent pipeline: TopologyDispatcher protocol, 4 dispatchers — SAS/centralized/decentralized/context-dependent, wave execution, workspace lifecycle integration, CoordinationSectionConfig company config bridge, build_coordinator factory), coordination error classification, prompt policy validation, checkpoint recovery (checkpoint/, per-turn persistence, heartbeat detection, CheckpointRecoveryStrategy), approval gate (escalation detection, context parking/resume, EscalationInfo/ResumePayload models), stagnation detection (stagnation/, StagnationDetector protocol, ToolRepetitionDetector, dual-signal analysis, corrective prompt injection) + engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination (multi-agent pipeline: TopologyDispatcher protocol, 4 dispatchers — SAS/centralized/decentralized/context-dependent, wave execution, workspace lifecycle integration, CoordinationSectionConfig company config bridge, build_coordinator factory), coordination error classification, prompt policy validation, checkpoint recovery (checkpoint/, per-turn persistence, heartbeat detection, CheckpointRecoveryStrategy), approval gate (escalation detection, context parking/resume, EscalationInfo/ResumePayload models), stagnation detection (stagnation/, StagnationDetector protocol, ToolRepetitionDetector, dual-signal analysis, corrective prompt injection), agent runtime state (AgentRuntimeState, lightweight per-agent execution status for dashboard queries and recovery) hr/ # HR engine: hiring, firing, onboarding, offboarding, agent registry, performance tracking (task metrics, collaboration scoring, trend detection), promotion/demotion (criteria evaluation, approval strategies, model mapping) memory/ # Persistent agent memory (pluggable MemoryBackend protocol), backends/ (Mem0 adapter: backends/mem0/), retrieval pipeline (ranking, injection, context formatting, non-inferable filtering), shared org memory (org/), consolidation/archival (consolidation/) persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (see Memory & Persistence design page) @@ -189,7 +189,7 @@ site/ # Astro landing page (synthorg.io) - **Every module** with business logic MUST have: `from synthorg.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `synthorg.observability.events` (e.g., `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `MEETING_SCHEDULER_STARTED` from `events.meeting`, `MEETING_SCHEDULER_ERROR` from `events.meeting`, `MEETING_SCHEDULER_STOPPED` from `events.meeting`, `MEETING_PERIODIC_TRIGGERED` from `events.meeting`, `MEETING_EVENT_TRIGGERED` from `events.meeting`, `MEETING_PARTICIPANTS_RESOLVED` from `events.meeting`, `MEETING_NO_PARTICIPANTS` from `events.meeting`, `MEETING_NOT_FOUND` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `API_REQUEST_COMPLETED` from `events.api`, `API_REQUEST_ERROR` from `events.api`, `API_ROUTE_NOT_FOUND` from `events.api`, `API_COORDINATION_STARTED` from `events.api`, `API_COORDINATION_COMPLETED` from `events.api`, `API_COORDINATION_FAILED` from `events.api`, `API_COORDINATION_AGENT_RESOLVE_FAILED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `MEMORY_BACKEND_CONNECTED` from `events.memory`, `MEMORY_ENTRY_STORED` from `events.memory`, `MEMORY_BACKEND_SYSTEM_ERROR` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`, `TASK_ENGINE_STARTED` from `events.task_engine`, `COORDINATION_STARTED` from `events.coordination`, `COORDINATION_FACTORY_BUILT` from `events.coordination`, `COMMUNICATION_DISPATCH_START` from `events.communication`, `COMPANY_STARTED` from `events.company`, `CONFIG_LOADED` from `events.config`, `CORRELATION_ID_CREATED` from `events.correlation`, `DECOMPOSITION_STARTED` from `events.decomposition`, `DELEGATION_STARTED` from `events.delegation`, `EXECUTION_LOOP_START` from `events.execution`, `CHECKPOINT_SAVED` from `events.checkpoint`, `PERSISTENCE_CHECKPOINT_SAVED` from `events.persistence`, `GIT_OPERATION_START` from `events.git`, `PARALLEL_GROUP_START` from `events.parallel`, `PERSONALITY_LOADED` from `events.personality`, `QUOTA_CHECKED` from `events.quota`, `ROLE_ASSIGNED` from `events.role`, `ROUTING_STARTED` from `events.routing`, `SANDBOX_EXECUTE_START` from `events.sandbox`, `TASK_CREATED` from `events.task`, `TASK_ASSIGNMENT_STARTED` from `events.task_assignment`, `TASK_ROUTING_STARTED` from `events.task_routing`, `TEMPLATE_LOADED` from `events.template`, `TOOL_INVOKE_START` from `events.tool`, `TOOL_OUTPUT_WITHHELD` from `events.tool`, `WORKSPACE_CREATED` from `events.workspace`, `APPROVAL_GATE_ESCALATION_DETECTED` from `events.approval_gate`, `APPROVAL_GATE_ESCALATION_FAILED` from `events.approval_gate`, `APPROVAL_GATE_INITIALIZED` from `events.approval_gate`, `APPROVAL_GATE_RISK_CLASSIFIED` from `events.approval_gate`, `APPROVAL_GATE_RISK_CLASSIFY_FAILED` from `events.approval_gate`, `APPROVAL_GATE_CONTEXT_PARKED` from `events.approval_gate`, `APPROVAL_GATE_CONTEXT_PARK_FAILED` from `events.approval_gate`, `APPROVAL_GATE_PARK_TASKLESS` from `events.approval_gate`, `APPROVAL_GATE_RESUME_STARTED` from `events.approval_gate`, `APPROVAL_GATE_CONTEXT_RESUMED` from `events.approval_gate`, `APPROVAL_GATE_RESUME_FAILED` from `events.approval_gate`, `APPROVAL_GATE_RESUME_DELETE_FAILED` from `events.approval_gate`, `APPROVAL_GATE_RESUME_TRIGGERED` from `events.approval_gate`, `APPROVAL_GATE_NO_PARKED_CONTEXT` from `events.approval_gate`, `APPROVAL_GATE_LOOP_WIRING_WARNING` from `events.approval_gate`, `STAGNATION_CHECK_PERFORMED` from `events.stagnation`, `STAGNATION_DETECTED` from `events.stagnation`, `STAGNATION_CORRECTION_INJECTED` from `events.stagnation`, `STAGNATION_TERMINATED` from `events.stagnation`). Import directly: `from synthorg.observability.events. import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `synthorg.observability.events` (e.g., `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `MEETING_SCHEDULER_STARTED` from `events.meeting`, `MEETING_SCHEDULER_ERROR` from `events.meeting`, `MEETING_SCHEDULER_STOPPED` from `events.meeting`, `MEETING_PERIODIC_TRIGGERED` from `events.meeting`, `MEETING_EVENT_TRIGGERED` from `events.meeting`, `MEETING_PARTICIPANTS_RESOLVED` from `events.meeting`, `MEETING_NO_PARTICIPANTS` from `events.meeting`, `MEETING_NOT_FOUND` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `API_REQUEST_COMPLETED` from `events.api`, `API_REQUEST_ERROR` from `events.api`, `API_ROUTE_NOT_FOUND` from `events.api`, `API_COORDINATION_STARTED` from `events.api`, `API_COORDINATION_COMPLETED` from `events.api`, `API_COORDINATION_FAILED` from `events.api`, `API_COORDINATION_AGENT_RESOLVE_FAILED` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `MEMORY_BACKEND_CONNECTED` from `events.memory`, `MEMORY_ENTRY_STORED` from `events.memory`, `MEMORY_BACKEND_SYSTEM_ERROR` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`, `TASK_ENGINE_STARTED` from `events.task_engine`, `COORDINATION_STARTED` from `events.coordination`, `COORDINATION_FACTORY_BUILT` from `events.coordination`, `COMMUNICATION_DISPATCH_START` from `events.communication`, `COMPANY_STARTED` from `events.company`, `CONFIG_LOADED` from `events.config`, `CORRELATION_ID_CREATED` from `events.correlation`, `DECOMPOSITION_STARTED` from `events.decomposition`, `DELEGATION_STARTED` from `events.delegation`, `EXECUTION_LOOP_START` from `events.execution`, `CHECKPOINT_SAVED` from `events.checkpoint`, `PERSISTENCE_CHECKPOINT_SAVED` from `events.persistence`, `GIT_OPERATION_START` from `events.git`, `PARALLEL_GROUP_START` from `events.parallel`, `PERSONALITY_LOADED` from `events.personality`, `QUOTA_CHECKED` from `events.quota`, `ROLE_ASSIGNED` from `events.role`, `ROUTING_STARTED` from `events.routing`, `SANDBOX_EXECUTE_START` from `events.sandbox`, `TASK_CREATED` from `events.task`, `TASK_ASSIGNMENT_STARTED` from `events.task_assignment`, `TASK_ROUTING_STARTED` from `events.task_routing`, `TEMPLATE_LOADED` from `events.template`, `TOOL_INVOKE_START` from `events.tool`, `TOOL_OUTPUT_WITHHELD` from `events.tool`, `WORKSPACE_CREATED` from `events.workspace`, `APPROVAL_GATE_ESCALATION_DETECTED` from `events.approval_gate`, `APPROVAL_GATE_ESCALATION_FAILED` from `events.approval_gate`, `APPROVAL_GATE_INITIALIZED` from `events.approval_gate`, `APPROVAL_GATE_RISK_CLASSIFIED` from `events.approval_gate`, `APPROVAL_GATE_RISK_CLASSIFY_FAILED` from `events.approval_gate`, `APPROVAL_GATE_CONTEXT_PARKED` from `events.approval_gate`, `APPROVAL_GATE_CONTEXT_PARK_FAILED` from `events.approval_gate`, `APPROVAL_GATE_PARK_TASKLESS` from `events.approval_gate`, `APPROVAL_GATE_RESUME_STARTED` from `events.approval_gate`, `APPROVAL_GATE_CONTEXT_RESUMED` from `events.approval_gate`, `APPROVAL_GATE_RESUME_FAILED` from `events.approval_gate`, `APPROVAL_GATE_RESUME_DELETE_FAILED` from `events.approval_gate`, `APPROVAL_GATE_RESUME_TRIGGERED` from `events.approval_gate`, `APPROVAL_GATE_NO_PARKED_CONTEXT` from `events.approval_gate`, `APPROVAL_GATE_LOOP_WIRING_WARNING` from `events.approval_gate`, `STAGNATION_CHECK_PERFORMED` from `events.stagnation`, `STAGNATION_DETECTED` from `events.stagnation`, `STAGNATION_CORRECTION_INJECTED` from `events.stagnation`, `STAGNATION_TERMINATED` from `events.stagnation`, `PERSISTENCE_AGENT_STATE_SAVED` from `events.persistence`, `PERSISTENCE_AGENT_STATE_FETCHED` from `events.persistence`, `PERSISTENCE_AGENT_STATE_ACTIVE_QUERIED` from `events.persistence`, `PERSISTENCE_AGENT_STATE_DELETED` from `events.persistence`). Import directly: `from synthorg.observability.events. import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` — never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/docs/design/agents.md b/docs/design/agents.md index 19a4bc0f85..0281df4058 100644 --- a/docs/design/agents.md +++ b/docs/design/agents.md @@ -144,6 +144,9 @@ with `model_copy`: accumulated cost (`TokenUsage`), turn count, and timestamps. - **AgentContext** wraps `AgentIdentity` + `TaskExecution` with a unique execution ID, conversation history, cost accumulation, turn limits, and timing. +- **AgentRuntimeState** provides a lightweight per-agent execution status snapshot + (idle / executing / paused) for dashboard queries and graceful-shutdown discovery. + Persisted via `AgentStateRepository`, independent of the checkpoint system. --- diff --git a/docs/design/engine.md b/docs/design/engine.md index 1fd0ad3f48..4d4095c271 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -264,6 +264,24 @@ reach the `MessageBusBridge` and WebSocket consumers. --- +## Agent Execution Status + +The `ExecutionStatus` enum (in `core/enums.py`) tracks the per-agent runtime +execution state: + +| Status | Meaning | +|--------|---------| +| `IDLE` | Agent is not currently executing — no active task or execution run. | +| `EXECUTING` | Agent is actively processing a task within an execution loop. | +| `PAUSED` | Agent is waiting for an external event (e.g. approval gate). | + +`ExecutionStatus` is consumed by `AgentRuntimeState` (in `engine/agent_state.py`), +which is persisted via `AgentStateRepository` for dashboard queries and +graceful-shutdown discovery. See the [Agents design page](agents.md#runtime-state) +for how `AgentRuntimeState` fits into the runtime state layer. + +--- + ## Agent Execution Loop The agent execution loop defines how an agent processes a task from start to diff --git a/docs/design/memory.md b/docs/design/memory.md index 39d116be3b..6c915062b5 100644 --- a/docs/design/memory.md +++ b/docs/design/memory.md @@ -390,7 +390,8 @@ class PersistenceBackend(Protocol): @property def messages(self) -> MessageRepository: ... # ... plus lifecycle_events, task_metrics, collaboration_metrics, - # parked_contexts, audit_entries + # parked_contexts, audit_entries, users, api_keys, checkpoints, + # heartbeats, agent_states ``` Each entity type has its own repository protocol: @@ -447,7 +448,7 @@ persistence: | `Message` | `communication/message.py` | `MessageRepository` | by channel | | `AuditEntry` | `security/models.py` | `AuditRepository` | by agent, by action type, by verdict, by risk level, time range | | `ParkedContext` | `security/timeout/parked_context.py` | `ParkedContextRepository` | by execution_id, by agent_id, by task_id | -| Agent runtime state (planned) | `engine/` | `AgentStateRepository` (planned) | by agent_id, active agents | +| `AgentRuntimeState` | `engine/agent_state.py` | `AgentStateRepository` | by agent_id, active agents | ### Migration Strategy diff --git a/src/synthorg/core/enums.py b/src/synthorg/core/enums.py index 7726f749cd..a84c47cde8 100644 --- a/src/synthorg/core/enums.py +++ b/src/synthorg/core/enums.py @@ -502,6 +502,19 @@ class DowngradeReason(StrEnum): SECURITY_INCIDENT = "security_incident" +class ExecutionStatus(StrEnum): + """Runtime execution status of an agent. + + Tracks whether an agent is currently executing, paused (e.g. waiting + for approval), or idle. Used by ``AgentRuntimeState`` for dashboard + queries and graceful-shutdown discovery. + """ + + IDLE = "idle" + EXECUTING = "executing" + PAUSED = "paused" + + class TimeoutActionType(StrEnum): """Action to take when an approval item times out (see Operations design page).""" diff --git a/src/synthorg/engine/__init__.py b/src/synthorg/engine/__init__.py index 522ad513ac..340b816d1d 100644 --- a/src/synthorg/engine/__init__.py +++ b/src/synthorg/engine/__init__.py @@ -6,6 +6,7 @@ """ from synthorg.engine.agent_engine import AgentEngine +from synthorg.engine.agent_state import AgentRuntimeState from synthorg.engine.approval_gate import ApprovalGate from synthorg.engine.approval_gate_models import EscalationInfo, ResumePayload from synthorg.engine.assignment import ( @@ -215,6 +216,7 @@ "AgentEngine", "AgentOutcome", "AgentRunResult", + "AgentRuntimeState", "AgentTaskScorer", "AgentWorkload", "ApprovalGate", diff --git a/src/synthorg/engine/agent_state.py b/src/synthorg/engine/agent_state.py new file mode 100644 index 0000000000..6168b5fe49 --- /dev/null +++ b/src/synthorg/engine/agent_state.py @@ -0,0 +1,150 @@ +"""Lightweight per-agent runtime state for dashboard queries and recovery. + +``AgentRuntimeState`` is a frozen Pydantic model that captures an agent's +current execution status (idle / executing / paused), the associated +execution and task identifiers, cost, and turn count. It is persisted +via :class:`~synthorg.persistence.repositories.AgentStateRepository` and +is independent of the heavier checkpoint system. +""" + +from datetime import UTC, datetime + +from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, model_validator + +from synthorg.core.enums import ExecutionStatus +from synthorg.core.types import NotBlankStr # noqa: TC001 +from synthorg.engine.context import AgentContext # noqa: TC001 + + +class AgentRuntimeState(BaseModel): + """Frozen snapshot of an agent's runtime execution state. + + Attributes: + agent_id: Primary key -- the agent identifier. + execution_id: Current execution run identifier (``None`` when idle). + task_id: Current task identifier (``None`` when idle or taskless). + status: Execution status (idle / executing / paused). + turn_count: Turns completed in the current execution. + accumulated_cost_usd: Cost accumulated in the current execution. + last_activity_at: Timestamp of the last state update. + started_at: When the current execution started (``None`` when idle). + """ + + model_config = ConfigDict(frozen=True) + + agent_id: NotBlankStr = Field(description="Agent identifier (primary key)") + execution_id: NotBlankStr | None = Field( + default=None, + description="Current execution run identifier", + ) + task_id: NotBlankStr | None = Field( + default=None, + description="Current task identifier", + ) + status: ExecutionStatus = Field(description="Execution status") + turn_count: int = Field(default=0, ge=0, description="Turns completed") + accumulated_cost_usd: float = Field( + default=0.0, + ge=0.0, + description="Cost in current execution (USD)", + ) + last_activity_at: AwareDatetime = Field( + description="Timestamp of last state update", + ) + started_at: AwareDatetime | None = Field( + default=None, + description="When the current execution started", + ) + + def _idle_violations(self) -> list[str]: + """Collect field violations for IDLE status.""" + violations: list[str] = [] + if self.execution_id is not None: + violations.append("execution_id must be None") + if self.task_id is not None: + violations.append("task_id must be None") + if self.started_at is not None: + violations.append("started_at must be None") + if self.turn_count != 0: + violations.append("turn_count must be 0") + if self.accumulated_cost_usd != 0.0: + violations.append("accumulated_cost_usd must be 0.0") + return violations + + @model_validator(mode="after") + def _validate_status_invariants(self) -> AgentRuntimeState: + """Enforce status-dependent field invariants. + + * **IDLE** requires ``execution_id``, ``task_id``, and + ``started_at`` to be ``None``, and ``turn_count`` and + ``accumulated_cost_usd`` to be zero. + * **EXECUTING** / **PAUSED** require ``execution_id`` and + ``started_at`` to be set. + """ + if self.status == ExecutionStatus.IDLE: + violations = self._idle_violations() + if violations: + msg = f"IDLE state invariant violated: {'; '.join(violations)}" + raise ValueError(msg) + else: + active_violations: list[str] = [] + if self.execution_id is None: + active_violations.append("execution_id is required") + if self.started_at is None: + active_violations.append("started_at is required") + if active_violations: + msg = ( + f"{self.status.value.upper()} state invariant violated: " + f"{'; '.join(active_violations)}" + ) + raise ValueError(msg) + return self + + @classmethod + def idle(cls, agent_id: NotBlankStr) -> AgentRuntimeState: + """Create an IDLE state for the given agent. + + Args: + agent_id: The agent identifier. + + Returns: + A new ``AgentRuntimeState`` in IDLE status. + """ + return cls( + agent_id=agent_id, + status=ExecutionStatus.IDLE, + last_activity_at=datetime.now(UTC), + ) + + @classmethod + def from_context( + cls, + context: AgentContext, + status: ExecutionStatus, + ) -> AgentRuntimeState: + """Create a runtime state from an ``AgentContext``. + + Args: + context: The agent execution context. + status: Must be ``EXECUTING`` or ``PAUSED`` (not ``IDLE``). + + Returns: + A new ``AgentRuntimeState`` derived from the context. + + Raises: + ValueError: If *status* is ``IDLE``. + """ + if status == ExecutionStatus.IDLE: + msg = "Cannot create from_context with IDLE status; use idle() instead" + raise ValueError(msg) + te = context.task_execution + return cls( + agent_id=str(context.identity.id), + execution_id=context.execution_id, + task_id=te.task.id if te is not None else None, + status=status, + turn_count=context.turn_count, + accumulated_cost_usd=context.accumulated_cost.cost_usd, + last_activity_at=datetime.now(UTC), + started_at=context.started_at, + ) diff --git a/src/synthorg/observability/events/persistence.py b/src/synthorg/observability/events/persistence.py index 1fe5bd47b6..978ad7942b 100644 --- a/src/synthorg/observability/events/persistence.py +++ b/src/synthorg/observability/events/persistence.py @@ -174,3 +174,25 @@ PERSISTENCE_HEARTBEAT_DESERIALIZE_FAILED: Final[str] = ( "persistence.heartbeat.deserialize_failed" ) + +# Agent state events +PERSISTENCE_AGENT_STATE_SAVED: Final[str] = "persistence.agent_state.saved" +PERSISTENCE_AGENT_STATE_SAVE_FAILED: Final[str] = "persistence.agent_state.save_failed" +PERSISTENCE_AGENT_STATE_FETCHED: Final[str] = "persistence.agent_state.fetched" +PERSISTENCE_AGENT_STATE_FETCH_FAILED: Final[str] = ( + "persistence.agent_state.fetch_failed" +) +PERSISTENCE_AGENT_STATE_NOT_FOUND: Final[str] = "persistence.agent_state.not_found" +PERSISTENCE_AGENT_STATE_ACTIVE_QUERIED: Final[str] = ( + "persistence.agent_state.active_queried" +) +PERSISTENCE_AGENT_STATE_ACTIVE_QUERY_FAILED: Final[str] = ( + "persistence.agent_state.active_query_failed" +) +PERSISTENCE_AGENT_STATE_DELETED: Final[str] = "persistence.agent_state.deleted" +PERSISTENCE_AGENT_STATE_DELETE_FAILED: Final[str] = ( + "persistence.agent_state.delete_failed" +) +PERSISTENCE_AGENT_STATE_DESERIALIZE_FAILED: Final[str] = ( + "persistence.agent_state.deserialize_failed" +) diff --git a/src/synthorg/persistence/__init__.py b/src/synthorg/persistence/__init__.py index deb802bf05..1f16b102b5 100644 --- a/src/synthorg/persistence/__init__.py +++ b/src/synthorg/persistence/__init__.py @@ -17,6 +17,7 @@ from synthorg.persistence.factory import create_backend from synthorg.persistence.protocol import PersistenceBackend from synthorg.persistence.repositories import ( + AgentStateRepository, AuditRepository, CostRecordRepository, MessageRepository, @@ -25,6 +26,7 @@ ) __all__ = [ + "AgentStateRepository", "AuditRepository", "CostRecordRepository", "DuplicateRecordError", diff --git a/src/synthorg/persistence/protocol.py b/src/synthorg/persistence/protocol.py index 548c2305ac..5c972aab81 100644 --- a/src/synthorg/persistence/protocol.py +++ b/src/synthorg/persistence/protocol.py @@ -13,6 +13,7 @@ TaskMetricRepository, # noqa: TC001 ) from synthorg.persistence.repositories import ( + AgentStateRepository, # noqa: TC001 ApiKeyRepository, # noqa: TC001 AuditRepository, # noqa: TC001 CheckpointRepository, # noqa: TC001 @@ -48,6 +49,7 @@ class PersistenceBackend(Protocol): api_keys: Repository for ApiKey persistence. checkpoints: Repository for Checkpoint persistence. heartbeats: Repository for Heartbeat persistence. + agent_states: Repository for AgentRuntimeState persistence. """ async def connect(self) -> None: @@ -152,6 +154,11 @@ def heartbeats(self) -> HeartbeatRepository: """Repository for Heartbeat persistence.""" ... + @property + def agent_states(self) -> AgentStateRepository: + """Repository for AgentRuntimeState persistence.""" + ... + async def get_setting(self, key: NotBlankStr) -> str | None: """Retrieve a setting value by key. diff --git a/src/synthorg/persistence/repositories.py b/src/synthorg/persistence/repositories.py index 489d28f8e6..4a2647b483 100644 --- a/src/synthorg/persistence/repositories.py +++ b/src/synthorg/persistence/repositories.py @@ -23,9 +23,11 @@ from synthorg.security.timeout.parked_context import ParkedContext # noqa: TC001 if TYPE_CHECKING: + from synthorg.engine.agent_state import AgentRuntimeState from synthorg.engine.checkpoint.models import Checkpoint, Heartbeat __all__ = [ + "AgentStateRepository", "ApiKeyRepository", "AuditRepository", "CheckpointRepository", @@ -593,3 +595,66 @@ async def delete(self, execution_id: NotBlankStr) -> bool: PersistenceError: If the operation fails. """ ... + + +@runtime_checkable +class AgentStateRepository(Protocol): + """CRUD + query interface for agent runtime state persistence. + + Provides a lightweight per-agent registry of execution state for + dashboard queries, graceful shutdown discovery, and cross-restart + recovery. + """ + + async def save(self, state: AgentRuntimeState) -> None: + """Upsert an agent runtime state by ``agent_id``. + + Args: + state: The agent runtime state to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, agent_id: NotBlankStr) -> AgentRuntimeState | None: + """Retrieve an agent runtime state by agent ID. + + Args: + agent_id: The agent identifier. + + Returns: + The agent state, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_active(self) -> tuple[AgentRuntimeState, ...]: + """Retrieve all non-idle agent states. + + Returns states where ``status != 'idle'``, ordered by + ``last_activity_at`` descending (most recent first). + + Returns: + Active agent states as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, agent_id: NotBlankStr) -> bool: + """Delete an agent runtime state by agent ID. + + Args: + agent_id: The agent identifier. + + Returns: + ``True`` if deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/synthorg/persistence/sqlite/__init__.py b/src/synthorg/persistence/sqlite/__init__.py index e2b2f747a0..3decb95d32 100644 --- a/src/synthorg/persistence/sqlite/__init__.py +++ b/src/synthorg/persistence/sqlite/__init__.py @@ -1,5 +1,8 @@ """SQLite persistence backend (see Memory design page — initial backend).""" +from synthorg.persistence.sqlite.agent_state_repo import ( + SQLiteAgentStateRepository, +) from synthorg.persistence.sqlite.audit_repository import ( SQLiteAuditRepository, ) @@ -22,6 +25,7 @@ __all__ = [ "SCHEMA_VERSION", + "SQLiteAgentStateRepository", "SQLiteAuditRepository", "SQLiteCheckpointRepository", "SQLiteCostRecordRepository", diff --git a/src/synthorg/persistence/sqlite/agent_state_repo.py b/src/synthorg/persistence/sqlite/agent_state_repo.py new file mode 100644 index 0000000000..f1c31f8ec5 --- /dev/null +++ b/src/synthorg/persistence/sqlite/agent_state_repo.py @@ -0,0 +1,173 @@ +"""SQLite repository implementation for agent runtime state persistence.""" + +import sqlite3 + +import aiosqlite +from pydantic import ValidationError + +from synthorg.core.enums import ExecutionStatus +from synthorg.core.types import NotBlankStr # noqa: TC001 +from synthorg.engine.agent_state import AgentRuntimeState +from synthorg.observability import get_logger +from synthorg.observability.events.persistence import ( + PERSISTENCE_AGENT_STATE_ACTIVE_QUERIED, + PERSISTENCE_AGENT_STATE_ACTIVE_QUERY_FAILED, + PERSISTENCE_AGENT_STATE_DELETE_FAILED, + PERSISTENCE_AGENT_STATE_DELETED, + PERSISTENCE_AGENT_STATE_DESERIALIZE_FAILED, + PERSISTENCE_AGENT_STATE_FETCH_FAILED, + PERSISTENCE_AGENT_STATE_FETCHED, + PERSISTENCE_AGENT_STATE_NOT_FOUND, + PERSISTENCE_AGENT_STATE_SAVE_FAILED, + PERSISTENCE_AGENT_STATE_SAVED, +) +from synthorg.persistence.errors import QueryError + +logger = get_logger(__name__) + + +class SQLiteAgentStateRepository: + """SQLite implementation of the AgentStateRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, state: AgentRuntimeState) -> None: + """Persist an agent runtime state (upsert by agent_id).""" + try: + data = state.model_dump(mode="json") + await self._db.execute( + """\ +INSERT OR REPLACE INTO agent_states ( + agent_id, execution_id, task_id, status, turn_count, + accumulated_cost_usd, last_activity_at, started_at +) VALUES ( + :agent_id, :execution_id, :task_id, :status, :turn_count, + :accumulated_cost_usd, :last_activity_at, :started_at +)""", + data, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save agent state for {state.agent_id!r}" + logger.exception( + PERSISTENCE_AGENT_STATE_SAVE_FAILED, + agent_id=state.agent_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.info( + PERSISTENCE_AGENT_STATE_SAVED, + agent_id=state.agent_id, + status=state.status.value, + ) + + async def get(self, agent_id: NotBlankStr) -> AgentRuntimeState | None: + """Retrieve an agent runtime state by agent ID.""" + try: + cursor = await self._db.execute( + "SELECT agent_id, execution_id, task_id, status, " + "turn_count, accumulated_cost_usd, last_activity_at, started_at " + "FROM agent_states WHERE agent_id = ?", + (agent_id,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to fetch agent state for {agent_id!r}" + logger.exception( + PERSISTENCE_AGENT_STATE_FETCH_FAILED, + agent_id=agent_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + if row is None: + logger.debug( + PERSISTENCE_AGENT_STATE_NOT_FOUND, + agent_id=agent_id, + ) + return None + + state = self._row_to_model(dict(row)) + logger.debug( + PERSISTENCE_AGENT_STATE_FETCHED, + agent_id=state.agent_id, + status=state.status.value, + ) + return state + + async def get_active(self) -> tuple[AgentRuntimeState, ...]: + """Retrieve all non-idle agent states, ordered by last_activity_at DESC.""" + try: + cursor = await self._db.execute( + "SELECT agent_id, execution_id, task_id, status, " + "turn_count, accumulated_cost_usd, last_activity_at, started_at " + "FROM agent_states WHERE status != ? " + "ORDER BY last_activity_at DESC", + (ExecutionStatus.IDLE.value,), + ) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to query active agent states" + logger.exception( + PERSISTENCE_AGENT_STATE_ACTIVE_QUERY_FAILED, + error=str(exc), + ) + raise QueryError(msg) from exc + + states = tuple(self._row_to_model(dict(row)) for row in rows) + logger.debug( + PERSISTENCE_AGENT_STATE_ACTIVE_QUERIED, + count=len(states), + ) + return states + + async def delete(self, agent_id: NotBlankStr) -> bool: + """Delete an agent runtime state by agent ID.""" + try: + cursor = await self._db.execute( + "DELETE FROM agent_states WHERE agent_id = ?", + (agent_id,), + ) + deleted = cursor.rowcount > 0 + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete agent state for {agent_id!r}" + logger.exception( + PERSISTENCE_AGENT_STATE_DELETE_FAILED, + agent_id=agent_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if deleted: + logger.info( + PERSISTENCE_AGENT_STATE_DELETED, + agent_id=agent_id, + ) + else: + logger.debug( + PERSISTENCE_AGENT_STATE_NOT_FOUND, + agent_id=agent_id, + ) + return deleted + + def _row_to_model(self, row: dict[str, object]) -> AgentRuntimeState: + """Convert a database row to an ``AgentRuntimeState`` model. + + Raises: + QueryError: If the row cannot be deserialized. + """ + try: + return AgentRuntimeState.model_validate(row) + except ValidationError as exc: + msg = f"Failed to deserialize agent state {row.get('agent_id')!r}" + logger.exception( + PERSISTENCE_AGENT_STATE_DESERIALIZE_FAILED, + agent_id=row.get("agent_id"), + error=str(exc), + ) + raise QueryError(msg) from exc diff --git a/src/synthorg/persistence/sqlite/backend.py b/src/synthorg/persistence/sqlite/backend.py index 523b941529..9ad590fffd 100644 --- a/src/synthorg/persistence/sqlite/backend.py +++ b/src/synthorg/persistence/sqlite/backend.py @@ -26,6 +26,9 @@ PersistenceConnectionError, QueryError, ) +from synthorg.persistence.sqlite.agent_state_repo import ( + SQLiteAgentStateRepository, +) from synthorg.persistence.sqlite.audit_repository import ( SQLiteAuditRepository, ) @@ -87,6 +90,7 @@ def __init__(self, config: SQLiteConfig) -> None: self._api_keys: SQLiteApiKeyRepository | None = None self._checkpoints: SQLiteCheckpointRepository | None = None self._heartbeats: SQLiteHeartbeatRepository | None = None + self._agent_states: SQLiteAgentStateRepository | None = None def _clear_state(self) -> None: """Reset connection and repository references to ``None``.""" @@ -103,6 +107,7 @@ def _clear_state(self) -> None: self._api_keys = None self._checkpoints = None self._heartbeats = None + self._agent_states = None async def connect(self) -> None: """Open the SQLite database and configure WAL mode.""" @@ -169,6 +174,7 @@ def _create_repositories(self) -> None: self._api_keys = SQLiteApiKeyRepository(self._db) self._checkpoints = SQLiteCheckpointRepository(self._db) self._heartbeats = SQLiteHeartbeatRepository(self._db) + self._agent_states = SQLiteAgentStateRepository(self._db) async def _cleanup_failed_connect(self, exc: sqlite3.Error | OSError) -> None: """Log failure, close partial connection, and raise. @@ -387,6 +393,15 @@ def heartbeats(self) -> SQLiteHeartbeatRepository: """ return self._require_connected(self._heartbeats, "heartbeats") + @property + def agent_states(self) -> SQLiteAgentStateRepository: + """Repository for AgentRuntimeState persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._agent_states, "agent_states") + async def get_setting(self, key: str) -> str | None: """Retrieve a setting value by key. diff --git a/src/synthorg/persistence/sqlite/migrations.py b/src/synthorg/persistence/sqlite/migrations.py index 041e385d75..d9b0fc544b 100644 --- a/src/synthorg/persistence/sqlite/migrations.py +++ b/src/synthorg/persistence/sqlite/migrations.py @@ -23,7 +23,7 @@ logger = get_logger(__name__) # Current schema version — bump when adding new migrations. -SCHEMA_VERSION = 7 +SCHEMA_VERSION = 8 _V1_STATEMENTS: Sequence[str] = ( # ── Tasks ───────────────────────────────────────────── @@ -404,6 +404,44 @@ async def _apply_v7(db: aiosqlite.Connection) -> None: ) +_V8_STATEMENTS: Sequence[str] = ( + # ── Agent states ────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS agent_states ( + agent_id TEXT PRIMARY KEY, + execution_id TEXT, + task_id TEXT, + status TEXT NOT NULL DEFAULT 'idle' + CHECK (status IN ('idle', 'executing', 'paused')), + turn_count INTEGER NOT NULL DEFAULT 0 CHECK (turn_count >= 0), + accumulated_cost_usd REAL NOT NULL DEFAULT 0.0 + CHECK (accumulated_cost_usd >= 0.0), + last_activity_at TEXT NOT NULL, + started_at TEXT, + CHECK ( + (status = 'idle' + AND execution_id IS NULL + AND task_id IS NULL + AND started_at IS NULL + AND turn_count = 0 + AND accumulated_cost_usd = 0.0) + OR + (status IN ('executing', 'paused') + AND execution_id IS NOT NULL + AND started_at IS NOT NULL) + ) +)""", + "CREATE INDEX IF NOT EXISTS idx_as_status_activity " + "ON agent_states(status, last_activity_at DESC)", +) + + +async def _apply_v8(db: aiosqlite.Connection) -> None: + """Apply schema v8: agent_states.""" + for stmt in _V8_STATEMENTS: + await db.execute(stmt) + + # Ordered list of (target_version, migration_function) pairs. Each migration # is applied when the current schema version is below its target version. _MIGRATIONS: list[tuple[int, _MigrateFn]] = [ @@ -414,6 +452,7 @@ async def _apply_v7(db: aiosqlite.Connection) -> None: (5, _apply_v5), (6, _apply_v6), (7, _apply_v7), + (8, _apply_v8), ] diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index b3f44fd9fb..8ddde86af1 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -1,6 +1,5 @@ """Shared fixtures for API unit tests.""" -import asyncio import uuid from collections.abc import Generator from datetime import UTC, datetime, timedelta @@ -12,13 +11,10 @@ from synthorg.api.app import create_app from synthorg.api.approval_store import ApprovalStore from synthorg.api.auth.config import AuthConfig -from synthorg.api.auth.models import ApiKey, User +from synthorg.api.auth.models import User from synthorg.api.auth.service import AuthService from synthorg.api.guards import HumanRole -from synthorg.budget.cost_record import CostRecord from synthorg.budget.tracker import CostTracker -from synthorg.communication.channel import Channel -from synthorg.communication.message import Message from synthorg.config.schema import RootConfig from synthorg.core.approval import ApprovalItem from synthorg.core.enums import ( @@ -27,17 +23,13 @@ TaskStatus, ) from synthorg.core.task import Task -from synthorg.engine.checkpoint.models import Checkpoint, Heartbeat from synthorg.engine.task_engine import TaskEngine -from synthorg.hr.enums import LifecycleEventType -from synthorg.hr.models import AgentLifecycleEvent -from synthorg.hr.performance.models import ( - CollaborationMetricRecord, - TaskMetricRecord, +from tests.unit.api.fakes import ( + FakeMessageBus, + FakePersistenceBackend, ) -from synthorg.persistence.errors import DuplicateRecordError, QueryError -from synthorg.security.models import AuditEntry, AuditVerdictStr -from synthorg.security.timeout.parked_context import ParkedContext + +__all__ = ["FakeMessageBus", "FakePersistenceBackend"] # ── Test auth constants ─────────────────────────────────────── @@ -45,515 +37,6 @@ _TEST_USER_ID = "test-user-001" _TEST_USERNAME = "testadmin" -# ── Fake Repositories ──────────────────────────────────────────── - - -class FakeTaskRepository: - """In-memory task repository for tests.""" - - def __init__(self) -> None: - self._tasks: dict[str, Task] = {} - - async def save(self, task: Task) -> None: - self._tasks[task.id] = task - - async def get(self, task_id: str) -> Task | None: - return self._tasks.get(task_id) - - async def list_tasks( - self, - *, - status: TaskStatus | None = None, - assigned_to: str | None = None, - project: str | None = None, - ) -> tuple[Task, ...]: - result = list(self._tasks.values()) - if status is not None: - result = [t for t in result if t.status == status] - if assigned_to is not None: - result = [t for t in result if t.assigned_to == assigned_to] - if project is not None: - result = [t for t in result if t.project == project] - return tuple(result) - - async def delete(self, task_id: str) -> bool: - return self._tasks.pop(task_id, None) is not None - - -class FakeCostRecordRepository: - """In-memory cost record repository for tests.""" - - def __init__(self) -> None: - self._records: list[CostRecord] = [] - - async def save(self, record: CostRecord) -> None: - self._records.append(record) - - async def query( - self, - *, - agent_id: str | None = None, - task_id: str | None = None, - ) -> tuple[CostRecord, ...]: - result = self._records - if agent_id is not None: - result = [r for r in result if r.agent_id == agent_id] - if task_id is not None: - result = [r for r in result if r.task_id == task_id] - return tuple(result) - - async def aggregate( - self, - *, - agent_id: str | None = None, - task_id: str | None = None, - ) -> float: - records = await self.query(agent_id=agent_id, task_id=task_id) - return sum(r.cost_usd for r in records) - - -class FakeMessageRepository: - """In-memory message repository for tests.""" - - def __init__(self) -> None: - self._messages: list[Message] = [] - - async def save(self, message: Message) -> None: - self._messages.append(message) - - async def get_history( - self, - channel: str, - *, - limit: int | None = None, - ) -> tuple[Message, ...]: - result = [m for m in self._messages if m.channel == channel] - if limit is not None and limit > 0: - result = result[-limit:] - return tuple(result) - - -# ── Fake Persistence Backend ──────────────────────────────────── - - -class FakeLifecycleEventRepository: - """In-memory lifecycle event repository for tests.""" - - def __init__(self) -> None: - self._events: list[AgentLifecycleEvent] = [] - - async def save(self, event: AgentLifecycleEvent) -> None: - self._events.append(event) - - async def list_events( - self, - *, - agent_id: str | None = None, - event_type: LifecycleEventType | None = None, - since: datetime | None = None, - ) -> tuple[AgentLifecycleEvent, ...]: - result = self._events - if agent_id is not None: - result = [e for e in result if e.agent_id == agent_id] - if event_type is not None: - result = [e for e in result if e.event_type == event_type] - if since is not None: - result = [e for e in result if e.timestamp >= since] - return tuple(result) - - -class FakeTaskMetricRepository: - """In-memory task metric repository for tests.""" - - def __init__(self) -> None: - self._records: list[TaskMetricRecord] = [] - - async def save(self, record: TaskMetricRecord) -> None: - self._records.append(record) - - async def query( - self, - *, - agent_id: str | None = None, - since: datetime | None = None, - until: datetime | None = None, - ) -> tuple[TaskMetricRecord, ...]: - result = self._records - if agent_id is not None: - result = [r for r in result if r.agent_id == agent_id] - if since is not None: - result = [r for r in result if r.completed_at >= since] - if until is not None: - result = [r for r in result if r.completed_at <= until] - return tuple(result) - - -class FakeCollaborationMetricRepository: - """In-memory collaboration metric repository for tests.""" - - def __init__(self) -> None: - self._records: list[CollaborationMetricRecord] = [] - - async def save(self, record: CollaborationMetricRecord) -> None: - self._records.append(record) - - async def query( - self, - *, - agent_id: str | None = None, - since: datetime | None = None, - ) -> tuple[CollaborationMetricRecord, ...]: - result = self._records - if agent_id is not None: - result = [r for r in result if r.agent_id == agent_id] - if since is not None: - result = [r for r in result if r.recorded_at >= since] - return tuple(result) - - -class FakeParkedContextRepository: - """In-memory parked context repository for tests.""" - - def __init__(self) -> None: - self._contexts: dict[str, ParkedContext] = {} - - async def save(self, context: ParkedContext) -> None: - self._contexts[context.id] = context - - async def get(self, parked_id: str) -> ParkedContext | None: - return self._contexts.get(parked_id) - - async def get_by_approval(self, approval_id: str) -> ParkedContext | None: - for ctx in self._contexts.values(): - if ctx.approval_id == approval_id: - return ctx - return None - - async def get_by_agent(self, agent_id: str) -> tuple[ParkedContext, ...]: - return tuple(ctx for ctx in self._contexts.values() if ctx.agent_id == agent_id) - - async def delete(self, parked_id: str) -> bool: - return self._contexts.pop(parked_id, None) is not None - - -class FakeAuditRepository: - """In-memory audit entry repository for tests.""" - - def __init__(self) -> None: - self._entries: dict[str, AuditEntry] = {} - - async def save(self, entry: AuditEntry) -> None: - if entry.id in self._entries: - msg = f"Duplicate audit entry {entry.id!r}" - raise DuplicateRecordError(msg) - self._entries[entry.id] = entry - - async def query( # noqa: PLR0913 - self, - *, - agent_id: str | None = None, - action_type: str | None = None, - verdict: AuditVerdictStr | None = None, - risk_level: ApprovalRiskLevel | None = None, - since: datetime | None = None, - until: datetime | None = None, - limit: int = 100, - ) -> tuple[AuditEntry, ...]: - if limit < 1: - msg = "limit must be >= 1" - raise QueryError(msg) - results = sorted( - self._entries.values(), - key=lambda e: e.timestamp, - reverse=True, - ) - if agent_id is not None: - results = [e for e in results if e.agent_id == agent_id] - if action_type is not None: - results = [e for e in results if e.action_type == action_type] - if verdict is not None: - results = [e for e in results if e.verdict == verdict] - if risk_level is not None: - results = [e for e in results if e.risk_level == risk_level] - if since is not None: - results = [e for e in results if e.timestamp >= since] - if until is not None: - results = [e for e in results if e.timestamp <= until] - return tuple(results[:limit]) - - -class FakeUserRepository: - """In-memory user repository for tests.""" - - def __init__(self) -> None: - self._users: dict[str, User] = {} - - async def save(self, user: User) -> None: - self._users[user.id] = user - - async def get(self, user_id: str) -> User | None: - return self._users.get(user_id) - - async def get_by_username(self, username: str) -> User | None: - for user in self._users.values(): - if user.username == username: - return user - return None - - async def list_users(self) -> tuple[User, ...]: - return tuple(self._users.values()) - - async def count(self) -> int: - return len(self._users) - - async def delete(self, user_id: str) -> bool: - return self._users.pop(user_id, None) is not None - - -class FakeApiKeyRepository: - """In-memory API key repository for tests.""" - - def __init__(self) -> None: - self._keys: dict[str, ApiKey] = {} - - async def save(self, key: ApiKey) -> None: - self._keys[key.id] = key - - async def get(self, key_id: str) -> ApiKey | None: - return self._keys.get(key_id) - - async def get_by_hash(self, key_hash: str) -> ApiKey | None: - for key in self._keys.values(): - if key.key_hash == key_hash: - return key - return None - - async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: - return tuple(k for k in self._keys.values() if k.user_id == user_id) - - async def delete(self, key_id: str) -> bool: - return self._keys.pop(key_id, None) is not None - - -class FakeCheckpointRepository: - """In-memory checkpoint repository for tests.""" - - def __init__(self) -> None: - self._checkpoints: dict[str, Checkpoint] = {} - - async def save(self, checkpoint: Checkpoint) -> None: - self._checkpoints[checkpoint.id] = checkpoint - - async def get_latest( - self, - *, - execution_id: str | None = None, - task_id: str | None = None, - ) -> Checkpoint | None: - if execution_id is None and task_id is None: - msg = "At least one of execution_id or task_id is required" - raise ValueError(msg) - candidates = list(self._checkpoints.values()) - if execution_id is not None: - candidates = [c for c in candidates if c.execution_id == execution_id] - if task_id is not None: - candidates = [c for c in candidates if c.task_id == task_id] - if not candidates: - return None - return max(candidates, key=lambda c: c.turn_number) - - async def delete_by_execution(self, execution_id: str) -> int: - to_delete = [ - k for k, v in self._checkpoints.items() if v.execution_id == execution_id - ] - for k in to_delete: - del self._checkpoints[k] - return len(to_delete) - - -class FakeHeartbeatRepository: - """In-memory heartbeat repository for tests.""" - - def __init__(self) -> None: - self._heartbeats: dict[str, Heartbeat] = {} - - async def save(self, heartbeat: Heartbeat) -> None: - self._heartbeats[heartbeat.execution_id] = heartbeat - - async def get(self, execution_id: str) -> Heartbeat | None: - return self._heartbeats.get(execution_id) - - async def get_stale(self, threshold: datetime) -> tuple[Heartbeat, ...]: - stale = [ - h for h in self._heartbeats.values() if h.last_heartbeat_at < threshold - ] - stale.sort(key=lambda h: h.last_heartbeat_at) - return tuple(stale) - - async def delete(self, execution_id: str) -> bool: - return self._heartbeats.pop(execution_id, None) is not None - - -class FakePersistenceBackend: - """In-memory persistence backend for tests.""" - - def __init__(self) -> None: - self._tasks = FakeTaskRepository() - self._cost_records = FakeCostRecordRepository() - self._messages = FakeMessageRepository() - self._lifecycle_events = FakeLifecycleEventRepository() - self._task_metrics = FakeTaskMetricRepository() - self._collaboration_metrics = FakeCollaborationMetricRepository() - self._parked_contexts = FakeParkedContextRepository() - self._audit_entries = FakeAuditRepository() - self._users = FakeUserRepository() - self._api_keys = FakeApiKeyRepository() - self._checkpoints = FakeCheckpointRepository() - self._heartbeats_repo = FakeHeartbeatRepository() - self._settings: dict[str, str] = {} - self._connected = False - - async def connect(self) -> None: - self._connected = True - - async def disconnect(self) -> None: - self._connected = False - - async def health_check(self) -> bool: - return self._connected - - async def migrate(self) -> None: - pass - - @property - def is_connected(self) -> bool: - return self._connected - - @property - def backend_name(self) -> str: - return "fake" - - @property - def tasks(self) -> FakeTaskRepository: - return self._tasks - - @property - def cost_records(self) -> FakeCostRecordRepository: - return self._cost_records - - @property - def messages(self) -> FakeMessageRepository: - return self._messages - - @property - def lifecycle_events(self) -> FakeLifecycleEventRepository: - return self._lifecycle_events - - @property - def task_metrics(self) -> FakeTaskMetricRepository: - return self._task_metrics - - @property - def collaboration_metrics(self) -> FakeCollaborationMetricRepository: - return self._collaboration_metrics - - @property - def parked_contexts(self) -> FakeParkedContextRepository: - return self._parked_contexts - - @property - def audit_entries(self) -> FakeAuditRepository: - return self._audit_entries - - @property - def users(self) -> FakeUserRepository: - return self._users - - @property - def api_keys(self) -> FakeApiKeyRepository: - return self._api_keys - - @property - def checkpoints(self) -> FakeCheckpointRepository: - return self._checkpoints - - @property - def heartbeats(self) -> FakeHeartbeatRepository: - return self._heartbeats_repo - - async def get_setting(self, key: str) -> str | None: - return self._settings.get(key) - - async def set_setting(self, key: str, value: str) -> None: - self._settings[key] = value - - -# ── Fake Message Bus ──────────────────────────────────────────── - - -class FakeMessageBus: - """In-memory message bus for tests.""" - - def __init__(self) -> None: - self._running = False - self._channels: list[Channel] = [] - - async def start(self) -> None: - self._running = True - - async def stop(self) -> None: - self._running = False - - @property - def is_running(self) -> bool: - return self._running - - async def publish(self, message: Message) -> None: - pass - - async def send_direct(self, message: Message, *, recipient: str) -> None: - pass - - async def subscribe(self, channel_name: str, subscriber_id: str) -> Any: - return None - - async def unsubscribe(self, channel_name: str, subscriber_id: str) -> None: - pass - - async def receive( - self, - channel_name: str, - subscriber_id: str, - *, - timeout: float | None = None, # noqa: ASYNC109 - ) -> Any: - # Simulate waiting for a message (yields to event loop) - if timeout is not None: - await asyncio.sleep(min(timeout, 0.01)) - return None - - async def create_channel(self, channel: Channel) -> Channel: - self._channels.append(channel) - return channel - - async def get_channel(self, channel_name: str) -> Channel: - for ch in self._channels: - if ch.name == channel_name: - return ch - msg = f"Channel {channel_name!r} not found" - raise ValueError(msg) - - async def list_channels(self) -> tuple[Channel, ...]: - return tuple(self._channels) - - async def get_channel_history( - self, - channel_name: str, - *, - limit: int | None = None, - ) -> tuple[Message, ...]: - return () - # ── Auth helpers ──────────────────────────────────────────────── diff --git a/tests/unit/api/fakes.py b/tests/unit/api/fakes.py new file mode 100644 index 0000000000..dd2e118dc1 --- /dev/null +++ b/tests/unit/api/fakes.py @@ -0,0 +1,567 @@ +"""In-memory fake implementations for API unit tests.""" + +import asyncio +from datetime import datetime +from typing import Any + +from synthorg.api.auth.models import ApiKey, User +from synthorg.budget.cost_record import CostRecord +from synthorg.communication.channel import Channel +from synthorg.communication.message import Message +from synthorg.core.enums import ( + ApprovalRiskLevel, + ExecutionStatus, + TaskStatus, +) +from synthorg.core.task import Task +from synthorg.engine.agent_state import AgentRuntimeState +from synthorg.engine.checkpoint.models import Checkpoint, Heartbeat +from synthorg.hr.enums import LifecycleEventType +from synthorg.hr.models import AgentLifecycleEvent +from synthorg.hr.performance.models import ( + CollaborationMetricRecord, + TaskMetricRecord, +) +from synthorg.persistence.errors import DuplicateRecordError, QueryError +from synthorg.security.models import AuditEntry, AuditVerdictStr +from synthorg.security.timeout.parked_context import ParkedContext + +# ── Fake Repositories ──────────────────────────────────────────── + + +class FakeTaskRepository: + """In-memory task repository for tests.""" + + def __init__(self) -> None: + self._tasks: dict[str, Task] = {} + + async def save(self, task: Task) -> None: + self._tasks[task.id] = task + + async def get(self, task_id: str) -> Task | None: + return self._tasks.get(task_id) + + async def list_tasks( + self, + *, + status: TaskStatus | None = None, + assigned_to: str | None = None, + project: str | None = None, + ) -> tuple[Task, ...]: + result = list(self._tasks.values()) + if status is not None: + result = [t for t in result if t.status == status] + if assigned_to is not None: + result = [t for t in result if t.assigned_to == assigned_to] + if project is not None: + result = [t for t in result if t.project == project] + return tuple(result) + + async def delete(self, task_id: str) -> bool: + return self._tasks.pop(task_id, None) is not None + + +class FakeCostRecordRepository: + """In-memory cost record repository for tests.""" + + def __init__(self) -> None: + self._records: list[CostRecord] = [] + + async def save(self, record: CostRecord) -> None: + self._records.append(record) + + async def query( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> tuple[CostRecord, ...]: + result = self._records + if agent_id is not None: + result = [r for r in result if r.agent_id == agent_id] + if task_id is not None: + result = [r for r in result if r.task_id == task_id] + return tuple(result) + + async def aggregate( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> float: + records = await self.query(agent_id=agent_id, task_id=task_id) + return sum(r.cost_usd for r in records) + + +class FakeMessageRepository: + """In-memory message repository for tests.""" + + def __init__(self) -> None: + self._messages: list[Message] = [] + + async def save(self, message: Message) -> None: + if any(m.id == message.id for m in self._messages): + msg = f"Message {message.id} already exists" + raise DuplicateRecordError(msg) + self._messages.append(message) + + async def get_history( + self, + channel: str, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + if limit is not None and limit < 1: + msg = f"limit must be a positive integer, got {limit}" + raise QueryError(msg) + result = sorted( + (m for m in self._messages if m.channel == channel), + key=lambda m: m.timestamp, + reverse=True, + ) + if limit is not None: + result = result[:limit] + return tuple(result) + + +class FakeLifecycleEventRepository: + """In-memory lifecycle event repository for tests.""" + + def __init__(self) -> None: + self._events: list[AgentLifecycleEvent] = [] + + async def save(self, event: AgentLifecycleEvent) -> None: + self._events.append(event) + + async def list_events( + self, + *, + agent_id: str | None = None, + event_type: LifecycleEventType | None = None, + since: datetime | None = None, + ) -> tuple[AgentLifecycleEvent, ...]: + result = self._events + if agent_id is not None: + result = [e for e in result if e.agent_id == agent_id] + if event_type is not None: + result = [e for e in result if e.event_type == event_type] + if since is not None: + result = [e for e in result if e.timestamp >= since] + return tuple(result) + + +class FakeTaskMetricRepository: + """In-memory task metric repository for tests.""" + + def __init__(self) -> None: + self._records: list[TaskMetricRecord] = [] + + async def save(self, record: TaskMetricRecord) -> None: + self._records.append(record) + + async def query( + self, + *, + agent_id: str | None = None, + since: datetime | None = None, + until: datetime | None = None, + ) -> tuple[TaskMetricRecord, ...]: + result = self._records + if agent_id is not None: + result = [r for r in result if r.agent_id == agent_id] + if since is not None: + result = [r for r in result if r.completed_at >= since] + if until is not None: + result = [r for r in result if r.completed_at <= until] + return tuple(result) + + +class FakeCollaborationMetricRepository: + """In-memory collaboration metric repository for tests.""" + + def __init__(self) -> None: + self._records: list[CollaborationMetricRecord] = [] + + async def save(self, record: CollaborationMetricRecord) -> None: + self._records.append(record) + + async def query( + self, + *, + agent_id: str | None = None, + since: datetime | None = None, + ) -> tuple[CollaborationMetricRecord, ...]: + result = self._records + if agent_id is not None: + result = [r for r in result if r.agent_id == agent_id] + if since is not None: + result = [r for r in result if r.recorded_at >= since] + return tuple(result) + + +class FakeParkedContextRepository: + """In-memory parked context repository for tests.""" + + def __init__(self) -> None: + self._contexts: dict[str, ParkedContext] = {} + + async def save(self, context: ParkedContext) -> None: + self._contexts[context.id] = context + + async def get(self, parked_id: str) -> ParkedContext | None: + return self._contexts.get(parked_id) + + async def get_by_approval(self, approval_id: str) -> ParkedContext | None: + for ctx in self._contexts.values(): + if ctx.approval_id == approval_id: + return ctx + return None + + async def get_by_agent(self, agent_id: str) -> tuple[ParkedContext, ...]: + return tuple(ctx for ctx in self._contexts.values() if ctx.agent_id == agent_id) + + async def delete(self, parked_id: str) -> bool: + return self._contexts.pop(parked_id, None) is not None + + +class FakeAuditRepository: + """In-memory audit entry repository for tests.""" + + def __init__(self) -> None: + self._entries: dict[str, AuditEntry] = {} + + async def save(self, entry: AuditEntry) -> None: + if entry.id in self._entries: + msg = f"Duplicate audit entry {entry.id!r}" + raise DuplicateRecordError(msg) + self._entries[entry.id] = entry + + async def query( # noqa: PLR0913 + self, + *, + agent_id: str | None = None, + action_type: str | None = None, + verdict: AuditVerdictStr | None = None, + risk_level: ApprovalRiskLevel | None = None, + since: datetime | None = None, + until: datetime | None = None, + limit: int = 100, + ) -> tuple[AuditEntry, ...]: + if limit < 1: + msg = "limit must be >= 1" + raise QueryError(msg) + if since is not None and until is not None and until < since: + msg = "until must not be earlier than since" + raise QueryError(msg) + results = sorted( + self._entries.values(), + key=lambda e: e.timestamp, + reverse=True, + ) + if agent_id is not None: + results = [e for e in results if e.agent_id == agent_id] + if action_type is not None: + results = [e for e in results if e.action_type == action_type] + if verdict is not None: + results = [e for e in results if e.verdict == verdict] + if risk_level is not None: + results = [e for e in results if e.risk_level == risk_level] + if since is not None: + results = [e for e in results if e.timestamp >= since] + if until is not None: + results = [e for e in results if e.timestamp <= until] + return tuple(results[:limit]) + + +class FakeUserRepository: + """In-memory user repository for tests.""" + + def __init__(self) -> None: + self._users: dict[str, User] = {} + + async def save(self, user: User) -> None: + self._users[user.id] = user + + async def get(self, user_id: str) -> User | None: + return self._users.get(user_id) + + async def get_by_username(self, username: str) -> User | None: + for user in self._users.values(): + if user.username == username: + return user + return None + + async def list_users(self) -> tuple[User, ...]: + return tuple(self._users.values()) + + async def count(self) -> int: + return len(self._users) + + async def delete(self, user_id: str) -> bool: + return self._users.pop(user_id, None) is not None + + +class FakeApiKeyRepository: + """In-memory API key repository for tests.""" + + def __init__(self) -> None: + self._keys: dict[str, ApiKey] = {} + + async def save(self, key: ApiKey) -> None: + self._keys[key.id] = key + + async def get(self, key_id: str) -> ApiKey | None: + return self._keys.get(key_id) + + async def get_by_hash(self, key_hash: str) -> ApiKey | None: + for key in self._keys.values(): + if key.key_hash == key_hash: + return key + return None + + async def list_by_user(self, user_id: str) -> tuple[ApiKey, ...]: + return tuple(k for k in self._keys.values() if k.user_id == user_id) + + async def delete(self, key_id: str) -> bool: + return self._keys.pop(key_id, None) is not None + + +class FakeCheckpointRepository: + """In-memory checkpoint repository for tests.""" + + def __init__(self) -> None: + self._checkpoints: dict[str, Checkpoint] = {} + + async def save(self, checkpoint: Checkpoint) -> None: + self._checkpoints[checkpoint.id] = checkpoint + + async def get_latest( + self, + *, + execution_id: str | None = None, + task_id: str | None = None, + ) -> Checkpoint | None: + if execution_id is None and task_id is None: + msg = "At least one of execution_id or task_id is required" + raise ValueError(msg) + candidates = list(self._checkpoints.values()) + if execution_id is not None: + candidates = [c for c in candidates if c.execution_id == execution_id] + if task_id is not None: + candidates = [c for c in candidates if c.task_id == task_id] + if not candidates: + return None + return max(candidates, key=lambda c: c.turn_number) + + async def delete_by_execution(self, execution_id: str) -> int: + to_delete = [ + k for k, v in self._checkpoints.items() if v.execution_id == execution_id + ] + for k in to_delete: + del self._checkpoints[k] + return len(to_delete) + + +class FakeHeartbeatRepository: + """In-memory heartbeat repository for tests.""" + + def __init__(self) -> None: + self._heartbeats: dict[str, Heartbeat] = {} + + async def save(self, heartbeat: Heartbeat) -> None: + self._heartbeats[heartbeat.execution_id] = heartbeat + + async def get(self, execution_id: str) -> Heartbeat | None: + return self._heartbeats.get(execution_id) + + async def get_stale(self, threshold: datetime) -> tuple[Heartbeat, ...]: + stale = [ + h for h in self._heartbeats.values() if h.last_heartbeat_at < threshold + ] + stale.sort(key=lambda h: h.last_heartbeat_at) + return tuple(stale) + + async def delete(self, execution_id: str) -> bool: + return self._heartbeats.pop(execution_id, None) is not None + + +class FakeAgentStateRepository: + """In-memory agent state repository for tests.""" + + def __init__(self) -> None: + self._states: dict[str, AgentRuntimeState] = {} + + async def save(self, state: AgentRuntimeState) -> None: + self._states[state.agent_id] = state + + async def get(self, agent_id: str) -> AgentRuntimeState | None: + return self._states.get(agent_id) + + async def get_active(self) -> tuple[AgentRuntimeState, ...]: + active = (s for s in self._states.values() if s.status != ExecutionStatus.IDLE) + return tuple(sorted(active, key=lambda s: s.last_activity_at, reverse=True)) + + async def delete(self, agent_id: str) -> bool: + return self._states.pop(agent_id, None) is not None + + +class FakePersistenceBackend: + """In-memory persistence backend for tests.""" + + def __init__(self) -> None: + self._tasks = FakeTaskRepository() + self._cost_records = FakeCostRecordRepository() + self._messages = FakeMessageRepository() + self._lifecycle_events = FakeLifecycleEventRepository() + self._task_metrics = FakeTaskMetricRepository() + self._collaboration_metrics = FakeCollaborationMetricRepository() + self._parked_contexts = FakeParkedContextRepository() + self._audit_entries = FakeAuditRepository() + self._users = FakeUserRepository() + self._api_keys = FakeApiKeyRepository() + self._checkpoints = FakeCheckpointRepository() + self._heartbeats = FakeHeartbeatRepository() + self._agent_states = FakeAgentStateRepository() + self._settings: dict[str, str] = {} + self._connected = False + + async def connect(self) -> None: + self._connected = True + + async def disconnect(self) -> None: + self._connected = False + + async def health_check(self) -> bool: + return self._connected + + async def migrate(self) -> None: + pass + + @property + def is_connected(self) -> bool: + return self._connected + + @property + def backend_name(self) -> str: + return "fake" + + @property + def tasks(self) -> FakeTaskRepository: + return self._tasks + + @property + def cost_records(self) -> FakeCostRecordRepository: + return self._cost_records + + @property + def messages(self) -> FakeMessageRepository: + return self._messages + + @property + def lifecycle_events(self) -> FakeLifecycleEventRepository: + return self._lifecycle_events + + @property + def task_metrics(self) -> FakeTaskMetricRepository: + return self._task_metrics + + @property + def collaboration_metrics(self) -> FakeCollaborationMetricRepository: + return self._collaboration_metrics + + @property + def parked_contexts(self) -> FakeParkedContextRepository: + return self._parked_contexts + + @property + def audit_entries(self) -> FakeAuditRepository: + return self._audit_entries + + @property + def users(self) -> FakeUserRepository: + return self._users + + @property + def api_keys(self) -> FakeApiKeyRepository: + return self._api_keys + + @property + def checkpoints(self) -> FakeCheckpointRepository: + return self._checkpoints + + @property + def heartbeats(self) -> FakeHeartbeatRepository: + return self._heartbeats + + @property + def agent_states(self) -> FakeAgentStateRepository: + return self._agent_states + + async def get_setting(self, key: str) -> str | None: + return self._settings.get(key) + + async def set_setting(self, key: str, value: str) -> None: + self._settings[key] = value + + +class FakeMessageBus: + """In-memory message bus for tests.""" + + def __init__(self) -> None: + self._running = False + self._channels: list[Channel] = [] + + async def start(self) -> None: + self._running = True + + async def stop(self) -> None: + self._running = False + + @property + def is_running(self) -> bool: + return self._running + + async def publish(self, message: Message) -> None: + pass + + async def send_direct(self, message: Message, *, recipient: str) -> None: + pass + + async def subscribe(self, channel_name: str, subscriber_id: str) -> Any: + return None + + async def unsubscribe(self, channel_name: str, subscriber_id: str) -> None: + pass + + async def receive( + self, + channel_name: str, + subscriber_id: str, + *, + timeout: float | None = None, # noqa: ASYNC109 + ) -> Any: + # Yield to event loop without real delay (deterministic in tests) + await asyncio.sleep(0) + return None + + async def create_channel(self, channel: Channel) -> Channel: + self._channels.append(channel) + return channel + + async def get_channel(self, channel_name: str) -> Channel: + for ch in self._channels: + if ch.name == channel_name: + return ch + msg = f"Channel {channel_name!r} not found" + raise ValueError(msg) + + async def list_channels(self) -> tuple[Channel, ...]: + return tuple(self._channels) + + async def get_channel_history( + self, + channel_name: str, + *, + limit: int | None = None, + ) -> tuple[Message, ...]: + return () diff --git a/tests/unit/engine/test_agent_state.py b/tests/unit/engine/test_agent_state.py new file mode 100644 index 0000000000..6ce7c881a9 --- /dev/null +++ b/tests/unit/engine/test_agent_state.py @@ -0,0 +1,285 @@ +"""Tests for AgentRuntimeState model.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING +from uuid import uuid4 + +import pytest +from pydantic import AwareDatetime + +from synthorg.core.enums import ExecutionStatus +from synthorg.engine.agent_state import AgentRuntimeState + +if TYPE_CHECKING: + from synthorg.engine.context import AgentContext + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_NOW = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + + +def _make_executing_state( # noqa: PLR0913 + *, + agent_id: str = "agent-001", + execution_id: str = "exec-001", + task_id: str | None = "task-001", + status: ExecutionStatus = ExecutionStatus.EXECUTING, + turn_count: int = 3, + accumulated_cost_usd: float = 0.05, + last_activity_at: AwareDatetime = _NOW, + started_at: AwareDatetime = _NOW, +) -> AgentRuntimeState: + return AgentRuntimeState( + agent_id=agent_id, + execution_id=execution_id, + task_id=task_id, + status=status, + turn_count=turn_count, + accumulated_cost_usd=accumulated_cost_usd, + last_activity_at=last_activity_at, + started_at=started_at, + ) + + +def _make_context( + *, + agent_id: str = "agent-ctx", + task_id: str | None = "task-ctx", + turn_count: int = 5, + cost_usd: float = 0.10, +) -> AgentContext: + """Build a minimal AgentContext for testing from_context.""" + from datetime import date + from uuid import UUID + + from synthorg.core.agent import AgentIdentity, ModelConfig + from synthorg.core.enums import TaskType + from synthorg.core.task import Task + from synthorg.engine.context import AgentContext + from synthorg.engine.task_execution import TaskExecution + from synthorg.providers.models import ZERO_TOKEN_USAGE, TokenUsage + + identity = AgentIdentity( + id=UUID(int=0) if agent_id == "agent-ctx" else uuid4(), + name=agent_id, + role="engineer", + department="engineering", + model=ModelConfig(provider="test-provider", model_id="test-small-001"), + hiring_date=date(2026, 1, 1), + ) + + task_execution = None + if task_id is not None: + task = Task( + id=task_id, + title="Test task", + description="A test task", + type=TaskType.DEVELOPMENT, + project="test-project", + created_by=str(identity.id), + ) + task_execution = TaskExecution.from_task(task) + + usage = ( + TokenUsage( + input_tokens=100, + output_tokens=50, + cost_usd=cost_usd, + ) + if cost_usd > 0 + else ZERO_TOKEN_USAGE + ) + + return AgentContext( + execution_id=str(uuid4()), + identity=identity, + task_execution=task_execution, + turn_count=turn_count, + accumulated_cost=usage, + started_at=_NOW, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestAgentRuntimeStateIdle: + """Tests for the idle() factory.""" + + def test_idle_creates_idle_state(self) -> None: + state = AgentRuntimeState.idle("agent-idle") + assert state.agent_id == "agent-idle" + assert state.status == ExecutionStatus.IDLE + assert state.execution_id is None + assert state.task_id is None + assert state.started_at is None + assert state.turn_count == 0 + assert state.accumulated_cost_usd == 0.0 + + def test_idle_sets_last_activity_at(self) -> None: + state = AgentRuntimeState.idle("agent-idle") + assert state.last_activity_at is not None + assert state.last_activity_at.tzinfo is not None + + def test_idle_with_blank_agent_id_raises(self) -> None: + with pytest.raises(ValueError, match="whitespace"): + AgentRuntimeState.idle(" ") + + +@pytest.mark.unit +class TestAgentRuntimeStateFromContext: + """Tests for the from_context() factory.""" + + def test_from_context_executing(self) -> None: + ctx = _make_context() + state = AgentRuntimeState.from_context(ctx, ExecutionStatus.EXECUTING) + assert state.agent_id == str(ctx.identity.id) + assert state.execution_id == ctx.execution_id + assert state.status == ExecutionStatus.EXECUTING + assert state.turn_count == ctx.turn_count + assert state.accumulated_cost_usd == ctx.accumulated_cost.cost_usd + assert state.started_at == ctx.started_at + + def test_from_context_paused(self) -> None: + ctx = _make_context() + state = AgentRuntimeState.from_context(ctx, ExecutionStatus.PAUSED) + assert state.status == ExecutionStatus.PAUSED + + def test_from_context_with_task(self) -> None: + ctx = _make_context(task_id="my-task") + state = AgentRuntimeState.from_context(ctx, ExecutionStatus.EXECUTING) + assert state.task_id == "my-task" + + def test_from_context_without_task(self) -> None: + ctx = _make_context(task_id=None) + state = AgentRuntimeState.from_context(ctx, ExecutionStatus.EXECUTING) + assert state.task_id is None + + def test_from_context_rejects_idle(self) -> None: + ctx = _make_context() + with pytest.raises(ValueError, match="IDLE"): + AgentRuntimeState.from_context(ctx, ExecutionStatus.IDLE) + + def test_from_context_with_zero_cost(self) -> None: + ctx = _make_context(cost_usd=0.0) + state = AgentRuntimeState.from_context(ctx, ExecutionStatus.EXECUTING) + assert state.accumulated_cost_usd == 0.0 + + +@pytest.mark.unit +class TestAgentRuntimeStateValidation: + """Tests for status invariant validation.""" + + @pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"execution_id": "e"}, "execution_id must be None"), + ({"task_id": "t"}, "task_id must be None"), + ({"started_at": _NOW}, "started_at must be None"), + ({"turn_count": 1}, "turn_count must be 0"), + ({"accumulated_cost_usd": 0.01}, r"accumulated_cost_usd must be 0\.0"), + ], + ) + def test_idle_single_violation_raises( + self, kwargs: dict[str, object], match: str + ) -> None: + fields = { + "agent_id": "a", + "status": ExecutionStatus.IDLE, + "last_activity_at": _NOW, + **kwargs, + } + with pytest.raises(ValueError, match=match): + AgentRuntimeState.model_validate(fields) + + def test_executing_without_execution_id_raises(self) -> None: + with pytest.raises(ValueError, match="execution_id is required"): + AgentRuntimeState( + agent_id="a", + status=ExecutionStatus.EXECUTING, + started_at=_NOW, + last_activity_at=_NOW, + ) + + def test_executing_without_started_at_raises(self) -> None: + with pytest.raises(ValueError, match="started_at is required"): + AgentRuntimeState( + agent_id="a", + execution_id="e", + status=ExecutionStatus.EXECUTING, + last_activity_at=_NOW, + ) + + def test_paused_without_execution_id_raises(self) -> None: + with pytest.raises(ValueError, match="execution_id is required"): + AgentRuntimeState( + agent_id="a", + status=ExecutionStatus.PAUSED, + started_at=_NOW, + last_activity_at=_NOW, + ) + + def test_paused_without_started_at_raises(self) -> None: + with pytest.raises(ValueError, match="started_at is required"): + AgentRuntimeState( + agent_id="a", + execution_id="e", + status=ExecutionStatus.PAUSED, + last_activity_at=_NOW, + ) + + def test_multiple_idle_violations_reported(self) -> None: + """Multiple violations appear in a single error message.""" + with pytest.raises(ValueError, match=r"execution_id.*task_id"): + AgentRuntimeState( + agent_id="a", + execution_id="e", + task_id="t", + status=ExecutionStatus.IDLE, + last_activity_at=_NOW, + ) + + def test_negative_turn_count_raises(self) -> None: + with pytest.raises(ValueError, match="greater than or equal to 0"): + _make_executing_state(turn_count=-1) + + def test_negative_cost_raises(self) -> None: + with pytest.raises(ValueError, match="greater than or equal to 0"): + _make_executing_state(accumulated_cost_usd=-0.01) + + def test_blank_agent_id_raises(self) -> None: + with pytest.raises(ValueError, match="whitespace"): + _make_executing_state(agent_id=" ") + + +@pytest.mark.unit +class TestAgentRuntimeStateImmutability: + """Tests for frozen model behavior and serialization.""" + + def test_frozen(self) -> None: + from pydantic import ValidationError + + state = _make_executing_state() + with pytest.raises(ValidationError): + state.turn_count = 99 # type: ignore[misc] + + def test_json_roundtrip(self) -> None: + state = _make_executing_state() + data = state.model_dump(mode="json") + restored = AgentRuntimeState.model_validate(data) + assert restored == state + + def test_json_roundtrip_idle(self) -> None: + state = AgentRuntimeState.idle("agent-rt") + data = state.model_dump(mode="json") + restored = AgentRuntimeState.model_validate(data) + assert restored == state + assert restored.status == ExecutionStatus.IDLE diff --git a/tests/unit/persistence/sqlite/test_agent_state_repo.py b/tests/unit/persistence/sqlite/test_agent_state_repo.py new file mode 100644 index 0000000000..08e38dceb8 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_agent_state_repo.py @@ -0,0 +1,332 @@ +"""Tests for SQLiteAgentStateRepository.""" + +from datetime import UTC, datetime +from typing import TYPE_CHECKING + +import pytest + +from synthorg.core.enums import ExecutionStatus +from synthorg.engine.agent_state import AgentRuntimeState +from synthorg.persistence.sqlite.agent_state_repo import ( + SQLiteAgentStateRepository, +) + +if TYPE_CHECKING: + import aiosqlite + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_T0 = datetime(2026, 3, 15, 10, 0, 0, tzinfo=UTC) +_T1 = datetime(2026, 3, 15, 11, 0, 0, tzinfo=UTC) +_T2 = datetime(2026, 3, 15, 12, 0, 0, tzinfo=UTC) + + +def _make_state( # noqa: PLR0913 + *, + agent_id: str = "agent-001", + execution_id: str | None = "exec-001", + task_id: str | None = "task-001", + status: ExecutionStatus = ExecutionStatus.EXECUTING, + turn_count: int = 3, + accumulated_cost_usd: float = 0.05, + last_activity_at: datetime = _T0, + started_at: datetime | None = _T0, +) -> AgentRuntimeState: + if status == ExecutionStatus.IDLE: + return AgentRuntimeState( + agent_id=agent_id, + status=ExecutionStatus.IDLE, + last_activity_at=last_activity_at, + ) + return AgentRuntimeState( + agent_id=agent_id, + execution_id=execution_id, + task_id=task_id, + status=status, + turn_count=turn_count, + accumulated_cost_usd=accumulated_cost_usd, + last_activity_at=last_activity_at, + started_at=started_at, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestSQLiteAgentStateRepository: + async def test_save_and_get_roundtrip( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + state = _make_state() + await repo.save(state) + + result = await repo.get("agent-001") + assert result is not None + assert result.agent_id == state.agent_id + assert result.execution_id == state.execution_id + assert result.task_id == state.task_id + assert result.status == state.status + assert result.turn_count == state.turn_count + assert result.accumulated_cost_usd == state.accumulated_cost_usd + assert result.last_activity_at == state.last_activity_at + assert result.started_at == state.started_at + + async def test_save_idle_roundtrip(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + state = _make_state( + agent_id="agent-idle", + status=ExecutionStatus.IDLE, + ) + await repo.save(state) + + result = await repo.get("agent-idle") + assert result is not None + assert result.status == ExecutionStatus.IDLE + assert result.execution_id is None + assert result.task_id is None + assert result.started_at is None + assert result.turn_count == 0 + assert result.accumulated_cost_usd == 0.0 + + async def test_upsert_overwrites(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + v1 = _make_state(turn_count=1) + await repo.save(v1) + + v2 = _make_state(turn_count=5, accumulated_cost_usd=0.10) + await repo.save(v2) + + result = await repo.get("agent-001") + assert result is not None + assert result.turn_count == 5 + assert result.accumulated_cost_usd == pytest.approx(0.10) + + async def test_get_returns_none_when_not_found( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + result = await repo.get("nonexistent") + assert result is None + + async def test_get_active_filters_idle( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + executing = _make_state(agent_id="active-1", last_activity_at=_T1) + paused = _make_state( + agent_id="paused-1", + status=ExecutionStatus.PAUSED, + last_activity_at=_T2, + ) + idle = _make_state( + agent_id="idle-1", + status=ExecutionStatus.IDLE, + last_activity_at=_T0, + ) + await repo.save(executing) + await repo.save(paused) + await repo.save(idle) + + active = await repo.get_active() + agent_ids = [s.agent_id for s in active] + assert "active-1" in agent_ids + assert "paused-1" in agent_ids + assert "idle-1" not in agent_ids + + async def test_get_active_ordered_by_last_activity_desc( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + older = _make_state(agent_id="older", last_activity_at=_T0) + newer = _make_state(agent_id="newer", last_activity_at=_T2) + middle = _make_state(agent_id="middle", last_activity_at=_T1) + await repo.save(older) + await repo.save(newer) + await repo.save(middle) + + active = await repo.get_active() + assert [s.agent_id for s in active] == ["newer", "middle", "older"] + + async def test_get_active_returns_empty_when_all_idle( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + idle = _make_state(agent_id="idle-only", status=ExecutionStatus.IDLE) + await repo.save(idle) + + active = await repo.get_active() + assert active == () + + async def test_delete_existing(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + state = _make_state() + await repo.save(state) + + deleted = await repo.delete("agent-001") + assert deleted is True + assert await repo.get("agent-001") is None + + async def test_delete_nonexistent(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + deleted = await repo.delete("nonexistent") + assert deleted is False + + async def test_delete_does_not_affect_other_agents( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + keep = _make_state(agent_id="keep") + remove = _make_state(agent_id="remove") + await repo.save(keep) + await repo.save(remove) + + await repo.delete("remove") + assert await repo.get("keep") is not None + assert await repo.get("remove") is None + + async def test_get_active_returns_empty_on_empty_table( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteAgentStateRepository(migrated_db) + active = await repo.get_active() + assert active == () + + async def test_lifecycle_idle_to_executing_to_idle( + self, migrated_db: aiosqlite.Connection + ) -> None: + """Full lifecycle: idle → executing → idle roundtrip.""" + repo = SQLiteAgentStateRepository(migrated_db) + idle = _make_state(agent_id="lifecycle", status=ExecutionStatus.IDLE) + await repo.save(idle) + + result = await repo.get("lifecycle") + assert result is not None + assert result.status == ExecutionStatus.IDLE + + executing = _make_state( + agent_id="lifecycle", + status=ExecutionStatus.EXECUTING, + last_activity_at=_T1, + ) + await repo.save(executing) + + result = await repo.get("lifecycle") + assert result is not None + assert result.status == ExecutionStatus.EXECUTING + assert result.execution_id == "exec-001" + assert result.started_at == _T0 + + idle_again = _make_state( + agent_id="lifecycle", + status=ExecutionStatus.IDLE, + last_activity_at=_T2, + ) + await repo.save(idle_again) + + result = await repo.get("lifecycle") + assert result is not None + assert result.status == ExecutionStatus.IDLE + assert result.execution_id is None + assert result.task_id is None + assert result.started_at is None + assert result.turn_count == 0 + assert result.accumulated_cost_usd == 0.0 + + +@pytest.mark.unit +class TestSQLiteAgentStateRepositoryErrors: + """Error paths raise QueryError.""" + + @pytest.mark.parametrize( + ("method", "args", "match"), + [ + ("save", (_make_state(),), "Failed to save"), + ("get", ("agent-001",), "Failed to fetch"), + ("get_active", (), "Failed to query"), + ("delete", ("agent-001",), "Failed to delete"), + ], + ) + async def test_crud_raises_query_error_on_db_error( + self, + memory_db: aiosqlite.Connection, + method: str, + args: tuple[object, ...], + match: str, + ) -> None: + from synthorg.persistence.errors import QueryError + + repo = SQLiteAgentStateRepository(memory_db) + with pytest.raises(QueryError, match=match) as exc_info: + await getattr(repo, method)(*args) + assert exc_info.value.__cause__ is not None + + async def test_row_to_model_raises_query_error_on_invalid_row( + self, migrated_db: aiosqlite.Connection + ) -> None: + from synthorg.persistence.errors import QueryError + + repo = SQLiteAgentStateRepository(migrated_db) + # Insert a row with a malformed datetime to trigger deserialization + # failure (passes CHECK constraints but fails Pydantic AwareDatetime) + await migrated_db.execute( + "INSERT INTO agent_states " + "(agent_id, execution_id, task_id, status, turn_count, " + "accumulated_cost_usd, last_activity_at, started_at) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + "agent-bad", + "exec-bad", + None, + "executing", + 0, + 0.0, + "not-a-datetime", + "2026-01-01T00:00:00+00:00", + ), + ) + await migrated_db.commit() + + with pytest.raises(QueryError, match="Failed to deserialize"): + await repo.get("agent-bad") + + async def test_get_active_raises_query_error_on_corrupt_row( + self, migrated_db: aiosqlite.Connection + ) -> None: + """get_active() fails when any row has corrupt data.""" + from synthorg.persistence.errors import QueryError + + repo = SQLiteAgentStateRepository(migrated_db) + # Insert a valid executing row + valid = _make_state(agent_id="agent-ok") + await repo.save(valid) + # Insert a corrupt row with a malformed datetime (passes CHECK + # constraints but fails Pydantic AwareDatetime validation) + await migrated_db.execute( + "INSERT INTO agent_states " + "(agent_id, execution_id, task_id, status, turn_count, " + "accumulated_cost_usd, last_activity_at, started_at) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + "agent-corrupt", + "exec-corrupt", + None, + "executing", + 0, + 0.0, + "not-a-datetime", + "2026-01-01T00:00:00+00:00", + ), + ) + await migrated_db.commit() + + with pytest.raises(QueryError, match="Failed to deserialize"): + await repo.get_active() diff --git a/tests/unit/persistence/sqlite/test_backend.py b/tests/unit/persistence/sqlite/test_backend.py index 20ae40b62e..68adb2282a 100644 --- a/tests/unit/persistence/sqlite/test_backend.py +++ b/tests/unit/persistence/sqlite/test_backend.py @@ -89,6 +89,11 @@ async def test_audit_entries_before_connect_raises(self) -> None: with pytest.raises(PersistenceConnectionError, match="Not connected"): _ = backend.audit_entries + async def test_agent_states_before_connect_raises(self) -> None: + backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:")) + with pytest.raises(PersistenceConnectionError, match="Not connected"): + _ = backend.agent_states + async def test_wal_mode_enabled(self) -> None: backend = SQLitePersistenceBackend(SQLiteConfig(path=":memory:", wal_mode=True)) await backend.connect() diff --git a/tests/unit/persistence/sqlite/test_migrations.py b/tests/unit/persistence/sqlite/test_migrations.py index 96ad175cbf..0c05f8b9ab 100644 --- a/tests/unit/persistence/sqlite/test_migrations.py +++ b/tests/unit/persistence/sqlite/test_migrations.py @@ -173,6 +173,92 @@ async def test_v5_creates_user_indexes( indexes = {row[0] for row in await cursor.fetchall()} assert "idx_api_keys_user_id" in indexes + async def test_v8_creates_agent_states_table( + self, memory_db: aiosqlite.Connection + ) -> None: + """V8 migration creates the agent_states table.""" + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='agent_states'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_v8_creates_agent_states_columns( + self, memory_db: aiosqlite.Connection + ) -> None: + """V8 migration creates the agent_states table with correct columns.""" + await run_migrations(memory_db) + cursor = await memory_db.execute("PRAGMA table_info('agent_states')") + columns = {row[1] for row in await cursor.fetchall()} + expected = { + "agent_id", + "execution_id", + "task_id", + "status", + "turn_count", + "accumulated_cost_usd", + "last_activity_at", + "started_at", + } + assert columns == expected + + async def test_v8_agent_states_ddl_has_check_constraints( + self, memory_db: aiosqlite.Connection + ) -> None: + """V8 DDL includes CHECK constraints for status, counters, and invariant.""" + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='agent_states'" + ) + row = await cursor.fetchone() + assert row is not None + ddl = row[0] + assert "CHECK (status IN ('idle', 'executing', 'paused'))" in ddl + assert "CHECK (turn_count >= 0)" in ddl + assert "CHECK (accumulated_cost_usd >= 0.0)" in ddl + # Cross-field invariant CHECK + assert "status = 'idle'" in ddl + assert "execution_id IS NULL" in ddl + assert "started_at IS NOT NULL" in ddl + + async def test_v8_check_constraint_rejects_invalid_status( + self, memory_db: aiosqlite.Connection + ) -> None: + """CHECK constraint rejects rows with invalid status values.""" + await run_migrations(memory_db) + with pytest.raises(sqlite3.IntegrityError, match="CHECK"): + await memory_db.execute( + "INSERT INTO agent_states " + "(agent_id, status, last_activity_at) " + "VALUES (?, ?, ?)", + ("a", "invalid", "2026-01-01T00:00:00+00:00"), + ) + + async def test_v8_creates_agent_states_composite_index( + self, memory_db: aiosqlite.Connection + ) -> None: + """V8 migration creates the composite status+activity index.""" + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name = 'idx_as_status_activity'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_v8_composite_index_covers_status_and_activity( + self, memory_db: aiosqlite.Connection + ) -> None: + """Composite index covers (status, last_activity_at DESC).""" + await run_migrations(memory_db) + cursor = await memory_db.execute("PRAGMA index_xinfo('idx_as_status_activity')") + rows = await cursor.fetchall() + # index_xinfo columns: seqno, cid, name, desc, coll, key + indexed = [(row[2], row[3]) for row in rows if row[5]] + assert ("status", 0) in indexed # status ASC + assert ("last_activity_at", 1) in indexed # last_activity_at DESC + async def test_migration_failure_raises_migration_error( self, memory_db: aiosqlite.Connection ) -> None: diff --git a/tests/unit/persistence/sqlite/test_migrations_v6.py b/tests/unit/persistence/sqlite/test_migrations_v6.py index 6b24a251b8..22ab4960cf 100644 --- a/tests/unit/persistence/sqlite/test_migrations_v6.py +++ b/tests/unit/persistence/sqlite/test_migrations_v6.py @@ -17,8 +17,8 @@ @pytest.mark.unit class TestSchemaVersion: - def test_schema_version_is_seven(self) -> None: - assert SCHEMA_VERSION == 7 + def test_schema_version_is_eight(self) -> None: + assert SCHEMA_VERSION == 8 @pytest.mark.unit diff --git a/tests/unit/persistence/test_migrations_v2.py b/tests/unit/persistence/test_migrations_v2.py index bf294445a4..259f1c3ef6 100644 --- a/tests/unit/persistence/test_migrations_v2.py +++ b/tests/unit/persistence/test_migrations_v2.py @@ -29,8 +29,8 @@ async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: class TestSchemaMigrations: - async def test_schema_version_is_seven(self) -> None: - assert SCHEMA_VERSION == 7 + async def test_schema_version_is_eight(self) -> None: + assert SCHEMA_VERSION == 8 async def test_fresh_db_creates_all_v2_tables( self, memory_db: aiosqlite.Connection @@ -122,7 +122,7 @@ async def test_v7_makes_task_id_nullable( # Run migrations (applies v7) await run_migrations(memory_db) - assert await get_user_version(memory_db) == 7 + assert await get_user_version(memory_db) == SCHEMA_VERSION # Verify task_id is now nullable cursor = await memory_db.execute("PRAGMA table_info('parked_contexts')") diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py index ca181419a7..40225aa430 100644 --- a/tests/unit/persistence/test_protocol.py +++ b/tests/unit/persistence/test_protocol.py @@ -12,6 +12,7 @@ ) from synthorg.persistence.protocol import PersistenceBackend from synthorg.persistence.repositories import ( + AgentStateRepository, ApiKeyRepository, AuditRepository, CheckpointRepository, @@ -31,6 +32,7 @@ from synthorg.communication.message import Message from synthorg.core.enums import ApprovalRiskLevel, TaskStatus from synthorg.core.task import Task + from synthorg.engine.agent_state import AgentRuntimeState from synthorg.engine.checkpoint.models import Checkpoint, Heartbeat from synthorg.hr.models import AgentLifecycleEvent from synthorg.hr.performance.models import ( @@ -244,6 +246,20 @@ async def delete(self, execution_id: str) -> bool: return False +class _FakeAgentStateRepository: + async def save(self, state: AgentRuntimeState) -> None: + pass + + async def get(self, agent_id: str) -> AgentRuntimeState | None: + return None + + async def get_active(self) -> tuple[AgentRuntimeState, ...]: + return () + + async def delete(self, agent_id: str) -> bool: + return False + + class _FakeBackend: async def connect(self) -> None: pass @@ -313,6 +329,10 @@ def checkpoints(self) -> _FakeCheckpointRepository: def heartbeats(self) -> _FakeHeartbeatRepository: return _FakeHeartbeatRepository() + @property + def agent_states(self) -> _FakeAgentStateRepository: + return _FakeAgentStateRepository() + async def get_setting(self, key: str) -> str | None: return None @@ -370,3 +390,6 @@ def test_fake_checkpoint_repo_is_checkpoint_repository(self) -> None: def test_fake_heartbeat_repo_is_heartbeat_repository(self) -> None: assert isinstance(_FakeHeartbeatRepository(), HeartbeatRepository) + + def test_fake_agent_state_repo_is_agent_state_repository(self) -> None: + assert isinstance(_FakeAgentStateRepository(), AgentStateRepository)