diff --git a/CLAUDE.md b/CLAUDE.md index 2471d1a0d4..b3c42647d3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -104,7 +104,7 @@ src/ai_company/ communication/ # Message bus, dispatcher, messenger, channels, delegation, loop prevention, conflict resolution, meeting protocol config/ # YAML company config loading and validation core/ # Shared domain models, base classes, and resilience config (RetryConfig, RateLimiterConfig) - engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination (multi-agent pipeline: TopologyDispatcher protocol, 4 dispatchers — SAS/centralized/decentralized/context-dependent, wave execution, workspace lifecycle integration), coordination error classification, and prompt policy validation + engine/ # Agent orchestration, execution loops, parallel execution, task decomposition, routing, task assignment, centralized single-writer task state engine (TaskEngine), task lifecycle, recovery, shutdown, workspace isolation, coordination (multi-agent pipeline: TopologyDispatcher protocol, 4 dispatchers — SAS/centralized/decentralized/context-dependent, wave execution, workspace lifecycle integration), coordination error classification, prompt policy validation, checkpoint recovery (checkpoint/, per-turn persistence, heartbeat detection, CheckpointRecoveryStrategy) hr/ # HR engine: hiring, firing, onboarding, offboarding, agent registry, performance tracking (task metrics, collaboration scoring, trend detection), promotion/demotion (criteria evaluation, approval strategies, model mapping) memory/ # Persistent agent memory (pluggable MemoryBackend protocol), backends/ (Mem0 adapter: backends/mem0/), retrieval pipeline (ranking, injection, context formatting, non-inferable filtering), shared org memory (org/), consolidation/archival (consolidation/) persistence/ # Operational data persistence — pluggable PersistenceBackend protocol, SQLite initial (see Memory & Persistence design page) @@ -151,7 +151,7 @@ web/ # Vue 3 + PrimeVue + Tailwind CSS dashboard - **Every module** with business logic MUST have: `from ai_company.observability import get_logger` then `logger = get_logger(__name__)` - **Never** use `import logging` / `logging.getLogger()` / `print()` in application code - **Variable name**: always `logger` (not `_logger`, not `log`) -- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `API_ROUTE_NOT_FOUND` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `MEMORY_BACKEND_CONNECTED` from `events.memory`, `MEMORY_ENTRY_STORED` from `events.memory`, `MEMORY_BACKEND_SYSTEM_ERROR` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`, `TASK_ENGINE_STARTED` from `events.task_engine`, `COORDINATION_STARTED` from `events.coordination`, `COMMUNICATION_DISPATCH_START` from `events.communication`, `COMPANY_STARTED` from `events.company`, `CONFIG_LOADED` from `events.config`, `CORRELATION_ID_CREATED` from `events.correlation`, `DECOMPOSITION_STARTED` from `events.decomposition`, `DELEGATION_STARTED` from `events.delegation`, `EXECUTION_LOOP_STARTED` from `events.execution`, `GIT_OPERATION_START` from `events.git`, `PARALLEL_EXECUTION_STARTED` from `events.parallel`, `PERSONALITY_LOADED` from `events.personality`, `QUOTA_CHECKED` from `events.quota`, `ROLE_ASSIGNED` from `events.role`, `ROUTING_STARTED` from `events.routing`, `SANDBOX_EXECUTE_START` from `events.sandbox`, `TASK_CREATED` from `events.task`, `TASK_ASSIGNMENT_STARTED` from `events.task_assignment`, `TASK_ROUTING_STARTED` from `events.task_routing`, `TEMPLATE_LOADED` from `events.template`, `TOOL_INVOKE_START` from `events.tool`, `WORKSPACE_CREATED` from `events.workspace`). Import directly: `from ai_company.observability.events. import EVENT_CONSTANT` +- **Event names**: always use constants from the domain-specific module under `ai_company.observability.events` (e.g. `PROVIDER_CALL_START` from `events.provider`, `BUDGET_RECORD_ADDED` from `events.budget`, `CFO_ANOMALY_DETECTED` from `events.cfo`, `CONFLICT_DETECTED` from `events.conflict`, `MEETING_STARTED` from `events.meeting`, `CLASSIFICATION_START` from `events.classification`, `CONSOLIDATION_START` from `events.consolidation`, `ORG_MEMORY_QUERY_START` from `events.org_memory`, `API_REQUEST_STARTED` from `events.api`, `API_ROUTE_NOT_FOUND` from `events.api`, `CODE_RUNNER_EXECUTE_START` from `events.code_runner`, `DOCKER_EXECUTE_START` from `events.docker`, `MCP_INVOKE_START` from `events.mcp`, `SECURITY_EVALUATE_START` from `events.security`, `HR_HIRING_REQUEST_CREATED` from `events.hr`, `PERF_METRIC_RECORDED` from `events.performance`, `TRUST_EVALUATE_START` from `events.trust`, `PROMOTION_EVALUATE_START` from `events.promotion`, `PROMPT_BUILD_START` from `events.prompt`, `MEMORY_RETRIEVAL_START` from `events.memory`, `MEMORY_BACKEND_CONNECTED` from `events.memory`, `MEMORY_ENTRY_STORED` from `events.memory`, `MEMORY_BACKEND_SYSTEM_ERROR` from `events.memory`, `AUTONOMY_ACTION_AUTO_APPROVED` from `events.autonomy`, `TIMEOUT_POLICY_EVALUATED` from `events.timeout`, `PERSISTENCE_AUDIT_ENTRY_SAVED` from `events.persistence`, `TASK_ENGINE_STARTED` from `events.task_engine`, `COORDINATION_STARTED` from `events.coordination`, `COMMUNICATION_DISPATCH_START` from `events.communication`, `COMPANY_STARTED` from `events.company`, `CONFIG_LOADED` from `events.config`, `CORRELATION_ID_CREATED` from `events.correlation`, `DECOMPOSITION_STARTED` from `events.decomposition`, `DELEGATION_STARTED` from `events.delegation`, `EXECUTION_LOOP_START` from `events.execution`, `CHECKPOINT_SAVED` from `events.checkpoint`, `PERSISTENCE_CHECKPOINT_SAVED` from `events.persistence`, `GIT_OPERATION_START` from `events.git`, `PARALLEL_GROUP_START` from `events.parallel`, `PERSONALITY_LOADED` from `events.personality`, `QUOTA_CHECKED` from `events.quota`, `ROLE_ASSIGNED` from `events.role`, `ROUTING_STARTED` from `events.routing`, `SANDBOX_EXECUTE_START` from `events.sandbox`, `TASK_CREATED` from `events.task`, `TASK_ASSIGNMENT_STARTED` from `events.task_assignment`, `TASK_ROUTING_STARTED` from `events.task_routing`, `TEMPLATE_LOADED` from `events.template`, `TOOL_INVOKE_START` from `events.tool`, `WORKSPACE_CREATED` from `events.workspace`). Import directly: `from ai_company.observability.events. import EVENT_CONSTANT` - **Structured kwargs**: always `logger.info(EVENT, key=value)` — never `logger.info("msg %s", val)` - **All error paths** must log at WARNING or ERROR with context before raising - **All state transitions** must log at INFO diff --git a/README.md b/README.md index 5119978bf6..da6c4f7d47 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ The framework is provider-agnostic (any LLM via LiteLLM), configuration-driven ( **Agent Orchestration** -Define agents with roles, models, and tools. The engine handles task decomposition, routing, execution loops (ReAct, Plan-and-Execute), and multi-agent coordination. +Define agents with roles, models, and tools. The engine handles task decomposition, routing, execution loops (ReAct, Plan-and-Execute), crash recovery (checkpoint resume), and multi-agent coordination. @@ -111,6 +111,7 @@ graph TB Observability[Observability] -.-> Engine Persistence[Persistence] -.-> HR Persistence -.-> Security + Persistence -.-> Engine ``` ## Documentation diff --git a/docs/design/engine.md b/docs/design/engine.md index 1ec1922a18..936f70bf61 100644 --- a/docs/design/engine.md +++ b/docs/design/engine.md @@ -508,6 +508,9 @@ implemented behind a `RecoveryStrategy` protocol, making the system pluggable. | `strategy_type` | `NotBlankStr` | Strategy identifier | | `context_snapshot` | `AgentContextSnapshot` | Redacted snapshot (turn count, accumulated cost, message count, max turns -- no message contents) | | `error_message` | `NotBlankStr` | Error that triggered recovery | +| `checkpoint_context_json` | `str \| None` | Serialized `AgentContext` for resume (`None` for non-checkpoint strategies) | +| `resume_attempt` | `int` (ge=0) | Current resume attempt number (0 when not resuming) | +| `can_resume` | `bool` (computed) | `checkpoint_context_json is not None` | | `can_reassign` | `bool` (computed) | `retry_count < task.max_retries` | ### Recovery Strategies @@ -547,9 +550,6 @@ implemented behind a `RecoveryStrategy` protocol, making the system pluggable. === "Strategy 2: Checkpoint Recovery" - !!! warning "Planned" - Checkpoint recovery is planned for a future release. - The engine persists an `AgentContext` snapshot after each completed turn. On crash, the framework detects the failure (via heartbeat timeout or exception), loads the last checkpoint, and resumes execution from the exact @@ -562,21 +562,21 @@ implemented behind a `RecoveryStrategy` protocol, making the system pluggable. strategy: "checkpoint" checkpoint: persist_every_n_turns: 1 # checkpoint frequency - storage: "sqlite" # sqlite, filesystem + # Storage backend determined by the injected CheckpointRepository heartbeat_interval_seconds: 30 # detect unresponsive agents max_resume_attempts: 2 # retry limit before falling back to fail_reassign ``` - Preserves progress -- critical for long tasks (multi-step plans, epic-level work) - - Requires persistence layer and environment state reconciliation on resume + - Requires persistence layer and reconciliation message on resume - Natural fit with the existing immutable state model - When resuming from a checkpoint, the agent's tools and workspace may have - changed (other agents modified files, external state drifted). The - checkpoint strategy includes a reconciliation step: the resumed agent - receives a summary of changes since the checkpoint timestamp and can adapt - its plan accordingly. + When resuming from a checkpoint, the agent receives a system message + informing it of the resume point (turn number) and the error that triggered + recovery. This reconciliation message allows the agent to review its + progress and adapt. Richer reconciliation (e.g. workspace change + detection) is planned for a future iteration. --- diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index ab4e4b4f6a..1b34c373f8 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -28,6 +28,14 @@ TaskAssignmentStrategy, build_strategy_map, ) +from ai_company.engine.checkpoint import ( + Checkpoint, + CheckpointCallback, + CheckpointConfig, + CheckpointRecoveryStrategy, + Heartbeat, + make_checkpoint_callback, +) from ai_company.engine.classification import ( ClassificationResult, ErrorFinding, @@ -208,6 +216,10 @@ "BudgetChecker", "CancelTaskMutation", "CentralizedDispatcher", + "Checkpoint", + "CheckpointCallback", + "CheckpointConfig", + "CheckpointRecoveryStrategy", "ClassificationResult", "CleanupCallback", "ContextDependentDispatcher", @@ -243,6 +255,7 @@ "ExecutionResult", "ExecutionStateError", "FailAndReassignStrategy", + "Heartbeat", "HierarchicalAssignmentStrategy", "InMemoryResourceLock", "LlmDecompositionConfig", @@ -332,5 +345,6 @@ "build_strategy_map", "build_system_prompt", "classify_execution_errors", + "make_checkpoint_callback", "select_dispatcher", ] diff --git a/src/ai_company/engine/agent_engine.py b/src/ai_company/engine/agent_engine.py index 17d7579b0e..c46d1ff68f 100644 --- a/src/ai_company/engine/agent_engine.py +++ b/src/ai_company/engine/agent_engine.py @@ -15,6 +15,13 @@ validate_run_inputs, validate_task, ) +from ai_company.engine.checkpoint.models import CheckpointConfig +from ai_company.engine.checkpoint.resume import ( + cleanup_after_resume, + deserialize_and_reconcile, + make_loop_with_callback, +) +from ai_company.engine.checkpoint.strategy import CheckpointRecoveryStrategy from ai_company.engine.classification.pipeline import classify_execution_errors from ai_company.engine.context import DEFAULT_MAX_TURNS, AgentContext from ai_company.engine.cost_recording import record_execution_costs @@ -32,7 +39,11 @@ format_task_instruction, ) from ai_company.engine.react_loop import ReactLoop -from ai_company.engine.recovery import FailAndReassignStrategy, RecoveryStrategy +from ai_company.engine.recovery import ( + FailAndReassignStrategy, + RecoveryResult, + RecoveryStrategy, +) from ai_company.engine.run_result import AgentRunResult from ai_company.engine.task_sync import ( apply_post_execution_transitions, @@ -51,6 +62,9 @@ EXECUTION_ENGINE_TASK_TRANSITION, EXECUTION_ENGINE_TIMEOUT, EXECUTION_RECOVERY_FAILED, + EXECUTION_RESUME_COMPLETE, + EXECUTION_RESUME_FAILED, + EXECUTION_RESUME_START, ) from ai_company.observability.events.prompt import PROMPT_TOKEN_RATIO_HIGH from ai_company.observability.events.security import SECURITY_DISABLED @@ -89,6 +103,10 @@ ShutdownChecker, ) from ai_company.engine.task_engine import TaskEngine + from ai_company.persistence.repositories import ( + CheckpointRepository, + HeartbeatRepository, + ) from ai_company.providers.models import CompletionConfig from ai_company.providers.protocol import CompletionProvider from ai_company.security.config import SecurityConfig @@ -144,11 +162,23 @@ def __init__( # noqa: PLR0913 security_config: SecurityConfig | None = None, approval_store: ApprovalStore | None = None, task_engine: TaskEngine | None = None, + checkpoint_repo: CheckpointRepository | None = None, + heartbeat_repo: HeartbeatRepository | None = None, + checkpoint_config: CheckpointConfig | None = None, ) -> None: self._provider = provider self._loop: ExecutionLoop = execution_loop or ReactLoop() self._tool_registry = tool_registry self._budget_enforcer = budget_enforcer + if (checkpoint_repo is None) != (heartbeat_repo is None): + msg = ( + "checkpoint_repo and heartbeat_repo must both be " + "provided or both omitted" + ) + raise ValueError(msg) + self._checkpoint_repo = checkpoint_repo + self._heartbeat_repo = heartbeat_repo + self._checkpoint_config = checkpoint_config or CheckpointConfig() self._cost_tracker: CostTracker | None if budget_enforcer is not None: if ( @@ -252,6 +282,7 @@ async def run( # noqa: PLR0913 start=start, timeout_seconds=timeout_seconds, tool_invoker=tool_invoker, + effective_autonomy=effective_autonomy, ) except MemoryError, RecursionError: logger.exception( @@ -297,6 +328,7 @@ async def _execute( # noqa: PLR0913 start: float, timeout_seconds: float | None = None, tool_invoker: ToolInvoker | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> AgentRunResult: """Run execution loop, record costs, apply transitions, and build result.""" budget_checker: BudgetChecker | None @@ -331,6 +363,8 @@ async def _execute( # noqa: PLR0913 identity, agent_id, task_id, + completion_config=completion_config, + effective_autonomy=effective_autonomy, ) return self._build_and_log_result( @@ -341,12 +375,15 @@ async def _execute( # noqa: PLR0913 task_id, ) - async def _post_execution_pipeline( + async def _post_execution_pipeline( # noqa: PLR0913 self, execution_result: ExecutionResult, identity: AgentIdentity, agent_id: str, task_id: str, + *, + completion_config: CompletionConfig | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> ExecutionResult: """Post-execution: costs, transitions, recovery, classify. @@ -378,6 +415,8 @@ async def _post_execution_pipeline( execution_result, agent_id, task_id, + completion_config=completion_config, + effective_autonomy=effective_autonomy, ) # Sync post-recovery status to TaskEngine (typically FAILED, # depends on recovery strategy). @@ -452,6 +491,21 @@ def _build_and_log_result( ) return result + def _make_loop_with_callback( + self, + agent_id: str, + task_id: str, + ) -> ExecutionLoop: + """Return the execution loop with a checkpoint callback if configured.""" + return make_loop_with_callback( + self._loop, + self._checkpoint_repo, + self._heartbeat_repo, + self._checkpoint_config, + agent_id, + task_id, + ) + async def _run_loop_with_timeout( # noqa: PLR0913 self, *, @@ -470,7 +524,8 @@ async def _run_loop_with_timeout( # noqa: PLR0913 ``TimeoutError`` raised inside the loop propagates normally and is not conflated with the engine's wall-clock deadline. """ - coro = self._loop.execute( + loop = self._make_loop_with_callback(agent_id, task_id) + coro = loop.execute( context=ctx, provider=self._provider, tool_invoker=tool_invoker, @@ -564,12 +619,15 @@ async def _apply_recovery( execution_result: ExecutionResult, agent_id: str, task_id: str, + *, + completion_config: CompletionConfig | None = None, + effective_autonomy: EffectiveAutonomy | None = None, ) -> ExecutionResult: """Invoke the configured recovery strategy on error outcomes. - The default strategy transitions the task to FAILED; other - strategies may behave differently. If no strategy is set or - no task execution exists, returns the result unchanged. + The default strategy transitions the task to FAILED; checkpoint + recovery may resume from a persisted checkpoint. If no strategy + is set or no task execution exists, returns the result unchanged. Recovery failures are logged but never block the error result. """ if self._recovery_strategy is None: @@ -585,6 +643,17 @@ async def _apply_recovery( error_message=error_msg, context=ctx, ) + + # Checkpoint resume path + if recovery_result.can_resume: + return await self._resume_from_checkpoint( + recovery_result, + agent_id, + task_id, + completion_config=completion_config, + effective_autonomy=effective_autonomy, + ) + updated_ctx = ctx.model_copy( update={"task_execution": recovery_result.task_execution}, ) @@ -602,6 +671,138 @@ async def _apply_recovery( ) return execution_result + async def _resume_from_checkpoint( + self, + recovery_result: RecoveryResult, + agent_id: str, + task_id: str, + *, + completion_config: CompletionConfig | None = None, + effective_autonomy: EffectiveAutonomy | None = None, + ) -> ExecutionResult: + """Resume execution from a checkpoint. + + Delegates to ``deserialize_and_reconcile`` for context + reconstruction and ``cleanup_after_resume`` for post-resume + housekeeping. Budget checking is constructed ad-hoc (same + approach as ``_execute``); timeout is not applied because + the original wall-clock deadline is no longer available. + """ + if recovery_result.checkpoint_context_json is None: + logger.error( + EXECUTION_RESUME_FAILED, + agent_id=agent_id, + task_id=task_id, + error="checkpoint_context_json is None but can_resume was True", + ) + msg = "checkpoint_context_json is None but can_resume was True" + raise RuntimeError(msg) + + logger.info( + EXECUTION_RESUME_START, + agent_id=agent_id, + task_id=task_id, + resume_attempt=recovery_result.resume_attempt, + ) + + try: + checkpoint_ctx = deserialize_and_reconcile( + recovery_result.checkpoint_context_json, + recovery_result.error_message, + agent_id, + task_id, + ) + result = await self._execute_resumed_loop( + checkpoint_ctx, + agent_id, + task_id, + completion_config=completion_config, + effective_autonomy=effective_autonomy, + ) + except MemoryError, RecursionError: + raise + except Exception as exc: + logger.exception( + EXECUTION_RESUME_FAILED, + agent_id=agent_id, + task_id=task_id, + error=f"{type(exc).__name__}: {exc}", + ) + raise + else: + await self._finalize_resume( + result, + checkpoint_ctx.execution_id, + agent_id, + task_id, + ) + return result + + async def _execute_resumed_loop( + self, + checkpoint_ctx: AgentContext, + agent_id: str, + task_id: str, + *, + completion_config: CompletionConfig | None = None, + effective_autonomy: EffectiveAutonomy | None = None, + ) -> ExecutionResult: + """Run the execution loop on a reconstituted checkpoint context.""" + budget_checker: BudgetChecker | None + if checkpoint_ctx.task_execution is None: + budget_checker = None + elif self._budget_enforcer: + budget_checker = await self._budget_enforcer.make_budget_checker( + checkpoint_ctx.task_execution.task, + agent_id, + ) + else: + budget_checker = make_budget_checker( + checkpoint_ctx.task_execution.task, + ) + + loop = self._make_loop_with_callback(agent_id, task_id) + return await loop.execute( + context=checkpoint_ctx, + provider=self._provider, + tool_invoker=self._make_tool_invoker( + checkpoint_ctx.identity, + task_id=task_id, + effective_autonomy=effective_autonomy, + ), + budget_checker=budget_checker, + shutdown_checker=self._shutdown_checker, + completion_config=completion_config, + ) + + async def _finalize_resume( + self, + result: ExecutionResult, + execution_id: str, + agent_id: str, + task_id: str, + ) -> None: + """Log completion and clean up after a successful resume.""" + logger.info( + EXECUTION_RESUME_COMPLETE, + agent_id=agent_id, + task_id=task_id, + termination_reason=result.termination_reason.value, + ) + if result.termination_reason != TerminationReason.ERROR: + if isinstance( + self._recovery_strategy, + CheckpointRecoveryStrategy, + ): + await self._recovery_strategy.clear_resume_count( + execution_id, + ) + await cleanup_after_resume( + self._checkpoint_repo, + self._heartbeat_repo, + execution_id, + ) + def _make_security_interceptor( self, effective_autonomy: EffectiveAutonomy | None = None, diff --git a/src/ai_company/engine/checkpoint/__init__.py b/src/ai_company/engine/checkpoint/__init__.py new file mode 100644 index 0000000000..df21ced86b --- /dev/null +++ b/src/ai_company/engine/checkpoint/__init__.py @@ -0,0 +1,29 @@ +"""Checkpoint recovery for agent crash recovery. + +Persists ``AgentContext`` snapshots at configurable turn intervals +and resumes from the last checkpoint on crash, preserving progress. +""" + +from ai_company.engine.checkpoint.callback import CheckpointCallback +from ai_company.engine.checkpoint.callback_factory import make_checkpoint_callback +from ai_company.engine.checkpoint.models import ( + Checkpoint, + CheckpointConfig, + Heartbeat, +) +from ai_company.engine.checkpoint.resume import ( + cleanup_after_resume, + deserialize_and_reconcile, +) +from ai_company.engine.checkpoint.strategy import CheckpointRecoveryStrategy + +__all__ = [ + "Checkpoint", + "CheckpointCallback", + "CheckpointConfig", + "CheckpointRecoveryStrategy", + "Heartbeat", + "cleanup_after_resume", + "deserialize_and_reconcile", + "make_checkpoint_callback", +] diff --git a/src/ai_company/engine/checkpoint/callback.py b/src/ai_company/engine/checkpoint/callback.py new file mode 100644 index 0000000000..909bdc32c1 --- /dev/null +++ b/src/ai_company/engine/checkpoint/callback.py @@ -0,0 +1,14 @@ +"""Checkpoint callback type alias. + +The callback is invoked after each completed turn with the current +``AgentContext``. The implementation decides whether to persist +based on configuration (e.g. every N turns). +""" + +from collections.abc import Callable, Coroutine +from typing import Any + +from ai_company.engine.context import AgentContext + +CheckpointCallback = Callable[[AgentContext], Coroutine[Any, Any, None]] +"""Async callback invoked after each turn; may skip persistence based on config.""" diff --git a/src/ai_company/engine/checkpoint/callback_factory.py b/src/ai_company/engine/checkpoint/callback_factory.py new file mode 100644 index 0000000000..5f6a1c96b8 --- /dev/null +++ b/src/ai_company/engine/checkpoint/callback_factory.py @@ -0,0 +1,133 @@ +"""Factory for creating checkpoint callbacks. + +Produces a closure that persists checkpoints and heartbeats after +each completed turn. Errors are logged but never propagated +(best-effort) to avoid crashing the execution loop. +""" + +from datetime import UTC, datetime + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.checkpoint.callback import CheckpointCallback # noqa: TC001 +from ai_company.engine.checkpoint.models import ( + Checkpoint, + CheckpointConfig, + Heartbeat, +) +from ai_company.engine.context import AgentContext # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.checkpoint import ( + CHECKPOINT_SAVE_FAILED, + CHECKPOINT_SAVED, + CHECKPOINT_SKIPPED, + HEARTBEAT_UPDATE_FAILED, + HEARTBEAT_UPDATED, +) +from ai_company.persistence.repositories import ( + CheckpointRepository, # noqa: TC001 + HeartbeatRepository, # noqa: TC001 +) + +logger = get_logger(__name__) + + +def make_checkpoint_callback( + *, + checkpoint_repo: CheckpointRepository, + heartbeat_repo: HeartbeatRepository, + config: CheckpointConfig, + agent_id: NotBlankStr, + task_id: NotBlankStr, +) -> CheckpointCallback: + """Create a checkpoint callback closure. + + The returned callback: + 1. Skips turn 0 (no work done yet) and non-boundary turns + where ``turn_count % persist_every_n_turns != 0``. + 2. Serializes the ``AgentContext`` to JSON and saves a checkpoint. + 3. Updates the heartbeat timestamp. + 4. Errors are logged but never propagated (except ``MemoryError`` + and ``RecursionError``). + + Args: + checkpoint_repo: Repository for persisting checkpoints. + heartbeat_repo: Repository for persisting heartbeats. + config: Checkpoint configuration. + agent_id: Agent identifier for the checkpoint. + task_id: Task identifier for the checkpoint. + + Returns: + An async callback suitable for injection into execution loops. + """ + + async def _checkpoint_callback(ctx: AgentContext) -> None: + turn = ctx.turn_count + if turn == 0 or turn % config.persist_every_n_turns != 0: + logger.debug( + CHECKPOINT_SKIPPED, + execution_id=ctx.execution_id, + turn_number=turn, + persist_every_n_turns=config.persist_every_n_turns, + ) + return + + checkpoint_saved = await _save_checkpoint(ctx, turn) + if checkpoint_saved: + await _save_heartbeat(ctx) + + async def _save_checkpoint(ctx: AgentContext, turn: int) -> bool: + """Persist checkpoint (best-effort). Return True on success.""" + try: + checkpoint = Checkpoint( + execution_id=ctx.execution_id, + agent_id=agent_id, + task_id=task_id, + turn_number=turn, + context_json=ctx.model_dump_json(), + ) + await checkpoint_repo.save(checkpoint) + logger.info( + CHECKPOINT_SAVED, + execution_id=ctx.execution_id, + turn_number=turn, + checkpoint_id=checkpoint.id, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + CHECKPOINT_SAVE_FAILED, + execution_id=ctx.execution_id, + turn_number=turn, + ) + return False + return True + + async def _save_heartbeat(ctx: AgentContext) -> None: + """Update heartbeat (best-effort). + + Only called after checkpoint save succeeds, preventing the + limbo state where a fresh heartbeat exists but there is no + checkpoint to resume from. + """ + try: + heartbeat = Heartbeat( + execution_id=ctx.execution_id, + agent_id=agent_id, + task_id=task_id, + last_heartbeat_at=datetime.now(UTC), + ) + await heartbeat_repo.save(heartbeat) + logger.debug( + HEARTBEAT_UPDATED, + execution_id=ctx.execution_id, + ) + except MemoryError, RecursionError: + raise + except Exception: + logger.exception( + HEARTBEAT_UPDATE_FAILED, + execution_id=ctx.execution_id, + ) + + return _checkpoint_callback diff --git a/src/ai_company/engine/checkpoint/models.py b/src/ai_company/engine/checkpoint/models.py new file mode 100644 index 0000000000..929cffc807 --- /dev/null +++ b/src/ai_company/engine/checkpoint/models.py @@ -0,0 +1,120 @@ +"""Checkpoint and heartbeat models for crash recovery. + +``Checkpoint`` persists a serialized ``AgentContext`` after each completed +turn so that execution can resume from the last checkpoint on crash. +``Heartbeat`` tracks liveness for stale-execution detection. +``CheckpointConfig`` controls checkpoint frequency and resume limits. +""" + +import json +from datetime import UTC, datetime +from typing import Self +from uuid import uuid4 + +from pydantic import ( + AwareDatetime, + BaseModel, + ConfigDict, + Field, + model_validator, +) + +from ai_company.core.types import NotBlankStr # noqa: TC001 + + +class Checkpoint(BaseModel): + """Serialized snapshot of an agent execution at a given turn. + + Attributes: + id: Unique checkpoint identifier. + execution_id: The execution run ID from ``AgentContext``. + agent_id: Agent whose execution was checkpointed. + task_id: Task the agent was working on. + turn_number: Turn index at which this checkpoint was taken. + context_json: JSON-serialized ``AgentContext``. + created_at: Timestamp when the checkpoint was created. + """ + + model_config = ConfigDict(frozen=True) + + id: NotBlankStr = Field( + default_factory=lambda: str(uuid4()), + description="Unique checkpoint identifier", + ) + execution_id: NotBlankStr = Field(description="Execution run identifier") + agent_id: NotBlankStr = Field(description="Agent identifier") + task_id: NotBlankStr = Field(description="Task identifier") + turn_number: int = Field(ge=0, description="Turn index of this checkpoint") + context_json: str = Field(description="JSON-serialized AgentContext") + created_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(UTC), + description="When the checkpoint was created", + ) + + @model_validator(mode="after") + def _validate_context_json(self) -> Self: + """Validate that context_json is a valid JSON object.""" + try: + parsed = json.loads(self.context_json) + except (json.JSONDecodeError, TypeError) as exc: + msg = f"context_json must be valid JSON: {exc}" + raise ValueError(msg) from exc + if not isinstance(parsed, dict): + msg = "context_json must be a JSON object, not a primitive or array" + raise ValueError(msg) # noqa: TRY004 + return self + + +class Heartbeat(BaseModel): + """Liveness signal for a running agent execution. + + Attributes: + execution_id: The execution run identifier (unique key). + agent_id: Agent whose execution is being tracked. + task_id: Task the agent was working on. + last_heartbeat_at: Timestamp of the last heartbeat update. + """ + + model_config = ConfigDict(frozen=True) + + execution_id: NotBlankStr = Field(description="Execution run identifier") + agent_id: NotBlankStr = Field(description="Agent identifier") + task_id: NotBlankStr = Field(description="Task identifier") + last_heartbeat_at: AwareDatetime = Field( + description="Timestamp of the last heartbeat", + ) + + +class CheckpointConfig(BaseModel): + """Configuration for checkpoint persistence and resume behavior. + + Attributes: + persist_every_n_turns: Save a checkpoint every N turns. + heartbeat_interval_seconds: Heartbeat update interval (reserved + for future background heartbeat loop; not used by the + per-turn callback). + max_resume_attempts: Maximum number of resume attempts before + falling back to fail-and-reassign. + """ + + model_config = ConfigDict(frozen=True) + + persist_every_n_turns: int = Field( + default=1, + gt=0, + description="Save a checkpoint every N turns", + ) + heartbeat_interval_seconds: float = Field( + default=30.0, + gt=0, + description=( + "Heartbeat update interval in seconds (reserved for " + "future background heartbeat loop; not used by the " + "per-turn callback)" + ), + ) + max_resume_attempts: int = Field( + default=2, + ge=0, + description="Max resume attempts before fallback", + ) diff --git a/src/ai_company/engine/checkpoint/resume.py b/src/ai_company/engine/checkpoint/resume.py new file mode 100644 index 0000000000..564608f5cd --- /dev/null +++ b/src/ai_company/engine/checkpoint/resume.py @@ -0,0 +1,172 @@ +"""Checkpoint resume helpers. + +Standalone functions for deserializing checkpoint context, +injecting reconciliation messages, building loop instances with +checkpoint callbacks, and cleaning up after a successful resume. +Used by ``AgentEngine`` to keep resume orchestration concise. +""" + +from typing import TYPE_CHECKING + +from ai_company.engine.checkpoint.callback_factory import make_checkpoint_callback +from ai_company.engine.checkpoint.models import CheckpointConfig # noqa: TC001 +from ai_company.engine.context import AgentContext +from ai_company.engine.plan_execute_loop import PlanExecuteLoop +from ai_company.engine.react_loop import ReactLoop +from ai_company.observability import get_logger +from ai_company.observability.events.checkpoint import ( + CHECKPOINT_DELETE_FAILED, + CHECKPOINT_DELETED, + CHECKPOINT_RECOVERY_RECONCILIATION, + CHECKPOINT_UNSUPPORTED_LOOP, + HEARTBEAT_DELETE_FAILED, + HEARTBEAT_DELETED, +) +from ai_company.providers.enums import MessageRole +from ai_company.providers.models import ChatMessage + +if TYPE_CHECKING: + from ai_company.engine.loop_protocol import ExecutionLoop + from ai_company.persistence.repositories import ( + CheckpointRepository, + HeartbeatRepository, + ) + +logger = get_logger(__name__) + + +def deserialize_and_reconcile( + checkpoint_json: str, + error_message: str, + agent_id: str, + task_id: str, +) -> AgentContext: + """Deserialize checkpoint context and inject reconciliation message. + + Args: + checkpoint_json: JSON-serialized ``AgentContext``. + error_message: The error that triggered recovery (included + in the reconciliation message so the agent is aware of the + specific failure that preceded the resume). + agent_id: Agent identifier (for logging). + task_id: Task identifier (for logging). + + Returns: + Reconstituted ``AgentContext`` with reconciliation message. + + Raises: + ValueError: If deserialization fails. + """ + try: + checkpoint_ctx = AgentContext.model_validate_json(checkpoint_json) + except ValueError: + logger.exception( + CHECKPOINT_RECOVERY_RECONCILIATION, + agent_id=agent_id, + task_id=task_id, + error="Failed to deserialize checkpoint context", + ) + raise + + reconciliation_msg = ChatMessage( + role=MessageRole.SYSTEM, + content=( + f"Execution resumed from checkpoint at turn " + f"{checkpoint_ctx.turn_count}. Previous error: " + f"{error_message}. " + "Review progress and continue." + ), + ) + logger.debug( + CHECKPOINT_RECOVERY_RECONCILIATION, + agent_id=agent_id, + task_id=task_id, + turn_count=checkpoint_ctx.turn_count, + ) + return checkpoint_ctx.with_message(reconciliation_msg) + + +def make_loop_with_callback( # noqa: PLR0913 + loop: ExecutionLoop, + checkpoint_repo: CheckpointRepository | None, + heartbeat_repo: HeartbeatRepository | None, + checkpoint_config: CheckpointConfig, + agent_id: str, + task_id: str, +) -> ExecutionLoop: + """Return the execution loop with a checkpoint callback if configured. + + If ``checkpoint_repo`` and ``heartbeat_repo`` are both set, + creates a checkpoint callback and returns a new loop instance + with it injected. Otherwise returns the original loop unchanged. + """ + if checkpoint_repo is None or heartbeat_repo is None: + return loop + + callback = make_checkpoint_callback( + checkpoint_repo=checkpoint_repo, + heartbeat_repo=heartbeat_repo, + config=checkpoint_config, + agent_id=agent_id, + task_id=task_id, + ) + + if isinstance(loop, ReactLoop): + return ReactLoop(checkpoint_callback=callback) + if isinstance(loop, PlanExecuteLoop): + return PlanExecuteLoop( + config=loop.config, + checkpoint_callback=callback, + ) + logger.warning( + CHECKPOINT_UNSUPPORTED_LOOP, + loop_type=type(loop).__name__, + error="Unsupported loop type for checkpoint callback injection", + ) + return loop + + +async def cleanup_after_resume( + checkpoint_repo: CheckpointRepository | None, + heartbeat_repo: HeartbeatRepository | None, + execution_id: str, +) -> None: + """Delete checkpoints and heartbeat after a successful resume. + + Best-effort: errors are logged but never propagated. + + Args: + checkpoint_repo: Checkpoint repository (may be ``None``). + heartbeat_repo: Heartbeat repository (may be ``None``). + execution_id: The execution whose data should be cleaned up. + """ + if checkpoint_repo is not None: + try: + count = await checkpoint_repo.delete_by_execution(execution_id) + logger.debug( + CHECKPOINT_DELETED, + execution_id=execution_id, + deleted_count=count, + ) + except Exception: + logger.warning( + CHECKPOINT_DELETE_FAILED, + execution_id=execution_id, + error="Failed to clean up checkpoints after resume", + exc_info=True, + ) + + if heartbeat_repo is not None: + try: + await heartbeat_repo.delete(execution_id) + logger.debug( + HEARTBEAT_DELETED, + execution_id=execution_id, + ) + except Exception: + logger.warning( + HEARTBEAT_DELETE_FAILED, + execution_id=execution_id, + error="Failed to clean up heartbeat after resume", + exc_info=True, + ) diff --git a/src/ai_company/engine/checkpoint/strategy.py b/src/ai_company/engine/checkpoint/strategy.py new file mode 100644 index 0000000000..9adfabce82 --- /dev/null +++ b/src/ai_company/engine/checkpoint/strategy.py @@ -0,0 +1,284 @@ +"""Checkpoint recovery strategy. + +Resumes execution from the last persisted checkpoint on crash. +After ``max_resume_attempts`` resume attempts, falls back to the +``FailAndReassignStrategy``. +""" + +import asyncio +from typing import Final + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.checkpoint.models import ( + Checkpoint, # noqa: TC001 + CheckpointConfig, # noqa: TC001 +) +from ai_company.engine.checkpoint.resume import cleanup_after_resume +from ai_company.engine.context import AgentContext # noqa: TC001 +from ai_company.engine.recovery import ( + FailAndReassignStrategy, + RecoveryResult, + RecoveryStrategy, +) +from ai_company.engine.task_execution import TaskExecution # noqa: TC001 +from ai_company.observability import get_logger +from ai_company.observability.events.checkpoint import ( + CHECKPOINT_LOAD_FAILED, + CHECKPOINT_LOADED, + CHECKPOINT_RECOVERY_FALLBACK, + CHECKPOINT_RECOVERY_NO_CHECKPOINT, + CHECKPOINT_RECOVERY_RESUME, + CHECKPOINT_RECOVERY_START, +) +from ai_company.persistence.errors import PersistenceError +from ai_company.persistence.repositories import ( + CheckpointRepository, # noqa: TC001 + HeartbeatRepository, # noqa: TC001 +) + +logger = get_logger(__name__) + +_MAX_TRACKED_EXECUTIONS: Final[int] = 10_000 +"""Safety bound on the ``_resume_counts`` dict to prevent unbounded growth.""" + + +class CheckpointRecoveryStrategy: + """Resume from the last checkpoint on crash. + + Loads the latest checkpoint for the execution and returns a + ``RecoveryResult`` with the serialized checkpoint context + (making ``can_resume`` evaluate to ``True``). After + ``max_resume_attempts`` resume attempts, delegates to the + fallback strategy (default: fail-and-reassign). + + Args: + checkpoint_repo: Repository for loading checkpoints. + heartbeat_repo: Repository for heartbeat cleanup on fallback. + config: Checkpoint configuration (controls max_resume_attempts). + fallback: Fallback recovery strategy; defaults to + ``FailAndReassignStrategy``. + """ + + STRATEGY_TYPE: Final[str] = "checkpoint" + + def __init__( + self, + *, + checkpoint_repo: CheckpointRepository, + heartbeat_repo: HeartbeatRepository | None = None, + config: CheckpointConfig, + fallback: RecoveryStrategy | None = None, + ) -> None: + self._checkpoint_repo = checkpoint_repo + self._heartbeat_repo = heartbeat_repo + self._config = config + self._fallback: RecoveryStrategy = fallback or FailAndReassignStrategy() + # NOTE: _resume_counts is in-memory only. Across process + # restarts the counter resets, allowing up to max_resume_attempts + # fresh attempts per lifetime. Persisting the counter (e.g. in + # the checkpoints table) is a planned improvement. + self._resume_counts: dict[str, int] = {} + self._resume_lock = asyncio.Lock() + + async def recover( + self, + *, + task_execution: TaskExecution, + error_message: str, + context: AgentContext, + ) -> RecoveryResult: + """Apply checkpoint recovery. + + 1. Load the latest checkpoint for the execution. + 2. If no checkpoint exists, delegate to fallback. + 3. If resume attempts are exhausted, delegate to fallback. + 4. Otherwise, return a ``RecoveryResult`` with the checkpoint + context for resume. + + Args: + task_execution: Current execution state. + error_message: Description of the failure. + context: Full agent context at the time of failure. + + Returns: + ``RecoveryResult`` — either resumable or fallback. + """ + execution_id = context.execution_id + task_id = task_execution.task.id + + logger.info( + CHECKPOINT_RECOVERY_START, + execution_id=execution_id, + task_id=task_id, + strategy=self.STRATEGY_TYPE, + ) + + checkpoint = await self._load_latest_checkpoint( + execution_id, + task_id, + ) + if checkpoint is None: + return await self._delegate_to_fallback( + task_execution=task_execution, + error_message=error_message, + context=context, + ) + + should_fallback = await self._reserve_resume_attempt( + execution_id, + task_id, + ) + if should_fallback: + return await self._delegate_to_fallback( + task_execution=task_execution, + error_message=error_message, + context=context, + ) + + return self._build_resume_result( + task_execution=task_execution, + error_message=error_message, + context=context, + checkpoint=checkpoint, + ) + + def get_strategy_type(self) -> str: + """Return the strategy type identifier.""" + return self.STRATEGY_TYPE + + async def clear_resume_count( + self, + execution_id: NotBlankStr, + ) -> None: + """Clear the resume counter for a completed execution. + + Called after successful completion to reset the counter. + Safe to call with unknown execution IDs (no-op). + + Args: + execution_id: The execution identifier to clear. + """ + async with self._resume_lock: + self._resume_counts.pop(execution_id, None) + + # ── Private helpers ────────────────────────────────────────── + + async def _load_latest_checkpoint( + self, + execution_id: str, + task_id: str, + ) -> Checkpoint | None: + """Load the latest checkpoint, returning ``None`` on failure.""" + try: + checkpoint = await self._checkpoint_repo.get_latest( + execution_id=execution_id, + ) + except MemoryError, RecursionError: + raise + except PersistenceError: + logger.exception( + CHECKPOINT_LOAD_FAILED, + execution_id=execution_id, + task_id=task_id, + ) + return None + + if checkpoint is None: + logger.info( + CHECKPOINT_RECOVERY_NO_CHECKPOINT, + execution_id=execution_id, + task_id=task_id, + ) + return None + + logger.debug( + CHECKPOINT_LOADED, + execution_id=execution_id, + checkpoint_id=checkpoint.id, + turn_number=checkpoint.turn_number, + ) + return checkpoint + + async def _reserve_resume_attempt( + self, + execution_id: str, + task_id: str, + ) -> bool: + """Reserve a resume attempt, returning ``True`` when exhausted.""" + async with self._resume_lock: + resume_count = self._resume_counts.get(execution_id, 0) + if resume_count >= self._config.max_resume_attempts: + logger.info( + CHECKPOINT_RECOVERY_FALLBACK, + execution_id=execution_id, + task_id=task_id, + resume_count=resume_count, + max_resume_attempts=self._config.max_resume_attempts, + reason="max_resume_attempts_exhausted", + ) + self._resume_counts.pop(execution_id, None) + return True + + self._resume_counts[execution_id] = resume_count + 1 + + # Evict oldest entries when the dict grows too large + if len(self._resume_counts) > _MAX_TRACKED_EXECUTIONS: + oldest = next(iter(self._resume_counts)) + self._resume_counts.pop(oldest, None) + + return False + + def _build_resume_result( + self, + *, + task_execution: TaskExecution, + error_message: str, + context: AgentContext, + checkpoint: Checkpoint, + ) -> RecoveryResult: + """Build a resumable ``RecoveryResult``.""" + execution_id = context.execution_id + resume_attempt = self._resume_counts.get(execution_id, 1) + + snapshot = context.to_snapshot() + logger.info( + CHECKPOINT_RECOVERY_RESUME, + execution_id=execution_id, + task_id=task_execution.task.id, + checkpoint_id=checkpoint.id, + turn_number=checkpoint.turn_number, + resume_attempt=resume_attempt, + max_resume_attempts=self._config.max_resume_attempts, + ) + + return RecoveryResult( + task_execution=task_execution, + strategy_type=self.STRATEGY_TYPE, + context_snapshot=snapshot, + error_message=error_message, + checkpoint_context_json=checkpoint.context_json, + resume_attempt=resume_attempt, + ) + + async def _delegate_to_fallback( + self, + *, + task_execution: TaskExecution, + error_message: str, + context: AgentContext, + ) -> RecoveryResult: + """Delegate recovery to the fallback strategy. + + Also cleans up any orphaned checkpoint/heartbeat rows for + this execution, since the resume path will not be entered. + """ + await cleanup_after_resume( + self._checkpoint_repo, + self._heartbeat_repo, + context.execution_id, + ) + return await self._fallback.recover( + task_execution=task_execution, + error_message=error_message, + context=context, + ) diff --git a/src/ai_company/engine/plan_execute_loop.py b/src/ai_company/engine/plan_execute_loop.py index 85819813d9..6191ba3fa3 100644 --- a/src/ai_company/engine/plan_execute_loop.py +++ b/src/ai_company/engine/plan_execute_loop.py @@ -15,6 +15,7 @@ from ai_company.budget.call_category import LLMCallCategory from ai_company.observability import get_logger from ai_company.observability.events.execution import ( + EXECUTION_CHECKPOINT_CALLBACK_FAILED, EXECUTION_LOOP_START, EXECUTION_LOOP_TERMINATED, EXECUTION_LOOP_TURN_COMPLETE, @@ -66,6 +67,7 @@ ) if TYPE_CHECKING: + from ai_company.engine.checkpoint.callback import CheckpointCallback from ai_company.engine.context import AgentContext from ai_company.providers.models import ToolDefinition from ai_company.providers.protocol import CompletionProvider @@ -81,8 +83,18 @@ class PlanExecuteLoop: step with a mini-ReAct sub-loop. Supports re-planning on failure. """ - def __init__(self, config: PlanExecuteConfig | None = None) -> None: + def __init__( + self, + config: PlanExecuteConfig | None = None, + checkpoint_callback: CheckpointCallback | None = None, + ) -> None: self._config = config or PlanExecuteConfig() + self._checkpoint_callback = checkpoint_callback + + @property + def config(self) -> PlanExecuteConfig: + """Return the loop configuration.""" + return self._config def get_loop_type(self) -> str: """Return the loop type identifier.""" @@ -559,6 +571,8 @@ async def _call_planner( # noqa: PLR0913 tool_call_count=0, ) + await self._invoke_checkpoint_callback(ctx, turn_number) + plan = parse_plan( response, ctx.execution_id, @@ -694,6 +708,8 @@ async def _run_step_turn( # noqa: PLR0913 tool_call_count=len(response.tool_calls), ) + await self._invoke_checkpoint_callback(ctx, turn_number) + if not response.tool_calls: return self._handle_step_completion(ctx, response, turn_number) @@ -749,6 +765,32 @@ async def _handle_step_tool_calls( # noqa: PLR0913 turns, ) + # ── Checkpoint ────────────────────────────────────────────────── + + async def _invoke_checkpoint_callback( + self, + ctx: AgentContext, + turn_number: int, + ) -> None: + """Invoke the checkpoint callback if configured. + + Errors are logged but never propagated — checkpointing must + not interrupt execution. + """ + if self._checkpoint_callback is None: + return + try: + await self._checkpoint_callback(ctx) + except MemoryError, RecursionError: + raise + except Exception as exc: + logger.exception( + EXECUTION_CHECKPOINT_CALLBACK_FAILED, + execution_id=ctx.execution_id, + turn=turn_number, + error=f"{type(exc).__name__}: {exc}", + ) + # ── Utilities ─────────────────────────────────────────────────── @staticmethod diff --git a/src/ai_company/engine/react_loop.py b/src/ai_company/engine/react_loop.py index 59c4b2e274..9c923c4a5f 100644 --- a/src/ai_company/engine/react_loop.py +++ b/src/ai_company/engine/react_loop.py @@ -11,6 +11,7 @@ from ai_company.budget.call_category import LLMCallCategory from ai_company.observability import get_logger from ai_company.observability.events.execution import ( + EXECUTION_CHECKPOINT_CALLBACK_FAILED, EXECUTION_LOOP_ERROR, EXECUTION_LOOP_START, EXECUTION_LOOP_TERMINATED, @@ -43,6 +44,7 @@ ) if TYPE_CHECKING: + from ai_company.engine.checkpoint.callback import CheckpointCallback from ai_company.engine.context import AgentContext from ai_company.providers.models import ToolDefinition from ai_company.providers.protocol import CompletionProvider @@ -59,8 +61,18 @@ class ReactLoop: feeds results back, and repeats until the LLM signals completion, the turn limit is reached, the budget is exhausted, a shutdown is requested, or an error occurs. + + Args: + checkpoint_callback: Optional async callback invoked after each + completed turn; the callback itself decides whether to persist. """ + def __init__( + self, + checkpoint_callback: CheckpointCallback | None = None, + ) -> None: + self._checkpoint_callback = checkpoint_callback + def get_loop_type(self) -> str: """Return the loop type identifier.""" return "react" @@ -194,6 +206,25 @@ async def _process_turn_response( # noqa: PLR0913 tool_call_count=len(response.tool_calls), ) + # Checkpoint is saved after the LLM response is recorded but + # before tool execution. This is intentional: if a crash + # happens during tool execution, the agent resumes with the + # LLM response and can detect whether tools already ran. The + # alternative (after tools) would lose the entire LLM call on + # a mid-tool crash. Tools should be idempotent by design. + if self._checkpoint_callback is not None: + try: + await self._checkpoint_callback(ctx) + except MemoryError, RecursionError: + raise + except Exception as exc: + logger.exception( + EXECUTION_CHECKPOINT_CALLBACK_FAILED, + execution_id=ctx.execution_id, + turn=turn_number, + error=f"{type(exc).__name__}: {exc}", + ) + if not response.tool_calls: return self._handle_completion(ctx, response, turns) diff --git a/src/ai_company/engine/recovery.py b/src/ai_company/engine/recovery.py index 32db5b199c..03c061229c 100644 --- a/src/ai_company/engine/recovery.py +++ b/src/ai_company/engine/recovery.py @@ -9,9 +9,10 @@ See the Crash Recovery section of the Engine design page. """ -from typing import Final, Protocol, runtime_checkable +import json +from typing import Final, Protocol, Self, runtime_checkable -from pydantic import BaseModel, ConfigDict, Field, computed_field +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator from ai_company.core.enums import TaskStatus from ai_company.core.types import NotBlankStr # noqa: TC001 @@ -31,20 +32,23 @@ class RecoveryResult(BaseModel): """Frozen result of a recovery strategy invocation. Attributes: - task_execution: Updated execution after recovery (typically - ``FAILED`` for the default strategy). + task_execution: Execution state after recovery (``FAILED`` for + fail-and-reassign, original state for checkpoint resume). strategy_type: Identifier of the strategy used (e.g. ``"fail_reassign"``). can_reassign: Computed — ``True`` when retry_count < task.max_retries. The caller (task router) is responsible for incrementing ``retry_count`` when creating the next ``TaskExecution``. context_snapshot: Redacted snapshot (no message contents). error_message: The error that triggered recovery. + checkpoint_context_json: Serialized ``AgentContext`` for resume + (set by ``CheckpointRecoveryStrategy``, ``None`` otherwise). + resume_attempt: Current resume attempt number (0 when not resuming). """ model_config = ConfigDict(frozen=True) task_execution: TaskExecution = Field( - description="Updated execution with FAILED status", + description="Execution state after recovery", ) strategy_type: NotBlankStr = Field( description="Identifier of the recovery strategy used", @@ -55,6 +59,37 @@ class RecoveryResult(BaseModel): error_message: NotBlankStr = Field( description="The error that triggered recovery", ) + checkpoint_context_json: str | None = Field( + default=None, + description="Serialized AgentContext from checkpoint for resume", + ) + resume_attempt: int = Field( + default=0, + ge=0, + description="Current resume attempt number", + ) + + @model_validator(mode="after") + def _validate_checkpoint_consistency(self) -> Self: + """Validate checkpoint_context_json and resume_attempt are consistent.""" + has_json = self.checkpoint_context_json is not None + has_attempt = self.resume_attempt > 0 + if has_json != has_attempt: + msg = ( + "checkpoint_context_json and resume_attempt must be " + "consistent: both set or both at default" + ) + raise ValueError(msg) + if self.checkpoint_context_json is not None: + try: + parsed = json.loads(self.checkpoint_context_json) + except json.JSONDecodeError as exc: + msg = f"checkpoint_context_json must be valid JSON: {exc}" + raise ValueError(msg) from exc + if not isinstance(parsed, dict): + msg = "checkpoint_context_json must be a JSON object" + raise ValueError(msg) + return self @computed_field( # type: ignore[prop-decorator] description="Whether the task can be reassigned for retry", @@ -68,14 +103,22 @@ def can_reassign(self) -> bool: """ return self.task_execution.retry_count < self.task_execution.task.max_retries + @computed_field( # type: ignore[prop-decorator] + description="Whether execution can resume from a checkpoint", + ) + @property + def can_resume(self) -> bool: + """Whether execution can resume from a persisted checkpoint.""" + return self.checkpoint_context_json is not None + @runtime_checkable class RecoveryStrategy(Protocol): """Protocol for crash recovery strategies. - Implementations decide how to handle a failed task execution: - transition the task, capture diagnostics, and report whether - reassignment is possible. + Implementations decide how to handle a failed task execution. + Strategies may transition the task status, capture diagnostics, + and report recovery options (e.g. reassignment, checkpoint resume). """ async def recover( diff --git a/src/ai_company/observability/events/checkpoint.py b/src/ai_company/observability/events/checkpoint.py new file mode 100644 index 0000000000..4a1cb3c13b --- /dev/null +++ b/src/ai_company/observability/events/checkpoint.py @@ -0,0 +1,29 @@ +"""Checkpoint recovery event constants for structured logging.""" + +from typing import Final + +# Checkpoint lifecycle +CHECKPOINT_SAVED: Final[str] = "checkpoint.saved" +CHECKPOINT_SAVE_FAILED: Final[str] = "checkpoint.save_failed" +CHECKPOINT_LOADED: Final[str] = "checkpoint.loaded" +CHECKPOINT_LOAD_FAILED: Final[str] = "checkpoint.load_failed" +CHECKPOINT_DELETED: Final[str] = "checkpoint.deleted" +CHECKPOINT_DELETE_FAILED: Final[str] = "checkpoint.delete_failed" +CHECKPOINT_SKIPPED: Final[str] = "checkpoint.skipped" + +# Heartbeat lifecycle +HEARTBEAT_UPDATED: Final[str] = "heartbeat.updated" +HEARTBEAT_UPDATE_FAILED: Final[str] = "heartbeat.update_failed" +HEARTBEAT_STALE_DETECTED: Final[str] = "heartbeat.stale_detected" +HEARTBEAT_DELETED: Final[str] = "heartbeat.deleted" +HEARTBEAT_DELETE_FAILED: Final[str] = "heartbeat.delete_failed" + +# Loop integration +CHECKPOINT_UNSUPPORTED_LOOP: Final[str] = "checkpoint.unsupported_loop" + +# Recovery flow +CHECKPOINT_RECOVERY_START: Final[str] = "checkpoint.recovery.start" +CHECKPOINT_RECOVERY_RESUME: Final[str] = "checkpoint.recovery.resume" +CHECKPOINT_RECOVERY_FALLBACK: Final[str] = "checkpoint.recovery.fallback" +CHECKPOINT_RECOVERY_NO_CHECKPOINT: Final[str] = "checkpoint.recovery.no_checkpoint" +CHECKPOINT_RECOVERY_RECONCILIATION: Final[str] = "checkpoint.recovery.reconciliation" diff --git a/src/ai_company/observability/events/execution.py b/src/ai_company/observability/events/execution.py index e519866d5e..16d53ca364 100644 --- a/src/ai_company/observability/events/execution.py +++ b/src/ai_company/observability/events/execution.py @@ -64,3 +64,12 @@ EXECUTION_RECOVERY_COMPLETE: Final[str] = "execution.recovery.complete" EXECUTION_RECOVERY_FAILED: Final[str] = "execution.recovery.failed" EXECUTION_RECOVERY_SNAPSHOT: Final[str] = "execution.recovery.snapshot" + +# Checkpoint callback & resume events +EXECUTION_CHECKPOINT_CALLBACK: Final[str] = "execution.checkpoint.callback" +EXECUTION_CHECKPOINT_CALLBACK_FAILED: Final[str] = ( + "execution.checkpoint.callback_failed" +) +EXECUTION_RESUME_START: Final[str] = "execution.resume.start" +EXECUTION_RESUME_COMPLETE: Final[str] = "execution.resume.complete" +EXECUTION_RESUME_FAILED: Final[str] = "execution.resume.failed" diff --git a/src/ai_company/observability/events/persistence.py b/src/ai_company/observability/events/persistence.py index b2c048ac2a..1fe5bd47b6 100644 --- a/src/ai_company/observability/events/persistence.py +++ b/src/ai_company/observability/events/persistence.py @@ -148,3 +148,29 @@ PERSISTENCE_SETTING_FETCH_FAILED: Final[str] = "persistence.setting.fetch_failed" PERSISTENCE_SETTING_SAVED: Final[str] = "persistence.setting.saved" PERSISTENCE_SETTING_SAVE_FAILED: Final[str] = "persistence.setting.save_failed" + +# Checkpoint events +PERSISTENCE_CHECKPOINT_SAVED: Final[str] = "persistence.checkpoint.saved" +PERSISTENCE_CHECKPOINT_SAVE_FAILED: Final[str] = "persistence.checkpoint.save_failed" +PERSISTENCE_CHECKPOINT_QUERIED: Final[str] = "persistence.checkpoint.queried" +PERSISTENCE_CHECKPOINT_QUERY_FAILED: Final[str] = "persistence.checkpoint.query_failed" +PERSISTENCE_CHECKPOINT_NOT_FOUND: Final[str] = "persistence.checkpoint.not_found" +PERSISTENCE_CHECKPOINT_DELETED: Final[str] = "persistence.checkpoint.deleted" +PERSISTENCE_CHECKPOINT_DELETE_FAILED: Final[str] = ( + "persistence.checkpoint.delete_failed" +) +PERSISTENCE_CHECKPOINT_DESERIALIZE_FAILED: Final[str] = ( + "persistence.checkpoint.deserialize_failed" +) + +# Heartbeat events +PERSISTENCE_HEARTBEAT_SAVED: Final[str] = "persistence.heartbeat.saved" +PERSISTENCE_HEARTBEAT_SAVE_FAILED: Final[str] = "persistence.heartbeat.save_failed" +PERSISTENCE_HEARTBEAT_QUERIED: Final[str] = "persistence.heartbeat.queried" +PERSISTENCE_HEARTBEAT_QUERY_FAILED: Final[str] = "persistence.heartbeat.query_failed" +PERSISTENCE_HEARTBEAT_NOT_FOUND: Final[str] = "persistence.heartbeat.not_found" +PERSISTENCE_HEARTBEAT_DELETED: Final[str] = "persistence.heartbeat.deleted" +PERSISTENCE_HEARTBEAT_DELETE_FAILED: Final[str] = "persistence.heartbeat.delete_failed" +PERSISTENCE_HEARTBEAT_DESERIALIZE_FAILED: Final[str] = ( + "persistence.heartbeat.deserialize_failed" +) diff --git a/src/ai_company/persistence/protocol.py b/src/ai_company/persistence/protocol.py index 3cfcd03074..2c2c0007b4 100644 --- a/src/ai_company/persistence/protocol.py +++ b/src/ai_company/persistence/protocol.py @@ -15,7 +15,9 @@ from ai_company.persistence.repositories import ( ApiKeyRepository, # noqa: TC001 AuditRepository, # noqa: TC001 + CheckpointRepository, # noqa: TC001 CostRecordRepository, # noqa: TC001 + HeartbeatRepository, # noqa: TC001 MessageRepository, # noqa: TC001 ParkedContextRepository, # noqa: TC001 TaskRepository, # noqa: TC001 @@ -44,6 +46,8 @@ class PersistenceBackend(Protocol): audit_entries: Repository for AuditEntry persistence. users: Repository for User persistence. api_keys: Repository for ApiKey persistence. + checkpoints: Repository for Checkpoint persistence. + heartbeats: Repository for Heartbeat persistence. """ async def connect(self) -> None: @@ -138,6 +142,16 @@ def api_keys(self) -> ApiKeyRepository: """Repository for ApiKey persistence.""" ... + @property + def checkpoints(self) -> CheckpointRepository: + """Repository for Checkpoint persistence.""" + ... + + @property + def heartbeats(self) -> HeartbeatRepository: + """Repository for Heartbeat persistence.""" + ... + async def get_setting(self, key: NotBlankStr) -> str | None: """Retrieve a setting value by key. diff --git a/src/ai_company/persistence/repositories.py b/src/ai_company/persistence/repositories.py index ab9d2a6c10..1047612e09 100644 --- a/src/ai_company/persistence/repositories.py +++ b/src/ai_company/persistence/repositories.py @@ -4,7 +4,7 @@ only on abstract interfaces, never on a concrete backend. """ -from typing import Protocol, runtime_checkable +from typing import TYPE_CHECKING, Protocol, runtime_checkable from pydantic import AwareDatetime # noqa: TC002 @@ -22,11 +22,16 @@ from ai_company.security.models import AuditEntry, AuditVerdictStr # noqa: TC001 from ai_company.security.timeout.parked_context import ParkedContext # noqa: TC001 +if TYPE_CHECKING: + from ai_company.engine.checkpoint.models import Checkpoint, Heartbeat + __all__ = [ "ApiKeyRepository", "AuditRepository", + "CheckpointRepository", "CollaborationMetricRepository", "CostRecordRepository", + "HeartbeatRepository", "LifecycleEventRepository", "MessageRepository", "ParkedContextRepository", @@ -473,3 +478,118 @@ async def delete(self, key_id: NotBlankStr) -> bool: PersistenceError: If the operation fails. """ ... + + +@runtime_checkable +class CheckpointRepository(Protocol): + """CRUD interface for checkpoint persistence.""" + + async def save(self, checkpoint: Checkpoint) -> None: + """Persist a checkpoint (insert or replace by ID). + + Args: + checkpoint: The checkpoint to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_latest( + self, + *, + execution_id: NotBlankStr | None = None, + task_id: NotBlankStr | None = None, + ) -> Checkpoint | None: + """Retrieve the latest checkpoint by turn_number. + + At least one filter (``execution_id`` or ``task_id``) is required. + + Args: + execution_id: Filter by execution identifier. + task_id: Filter by task identifier. + + Returns: + The checkpoint with the highest turn_number, or ``None``. + + Raises: + PersistenceError: If the operation fails. + ValueError: If neither filter is provided. + """ + ... + + async def delete_by_execution(self, execution_id: NotBlankStr) -> int: + """Delete all checkpoints for an execution. + + Args: + execution_id: The execution identifier. + + Returns: + Number of checkpoints deleted. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + +@runtime_checkable +class HeartbeatRepository(Protocol): + """CRUD interface for heartbeat persistence.""" + + async def save(self, heartbeat: Heartbeat) -> None: + """Persist a heartbeat (upsert by execution_id). + + Args: + heartbeat: The heartbeat to persist. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get(self, execution_id: NotBlankStr) -> Heartbeat | None: + """Retrieve a heartbeat by execution ID. + + Args: + execution_id: The execution identifier. + + Returns: + The heartbeat, or ``None`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def get_stale( + self, + threshold: AwareDatetime, + ) -> tuple[Heartbeat, ...]: + """Retrieve heartbeats older than the threshold. + + Args: + threshold: Heartbeats with ``last_heartbeat_at`` before + this timestamp are considered stale. + + Returns: + Stale heartbeats as a tuple. + + Raises: + PersistenceError: If the operation fails. + """ + ... + + async def delete(self, execution_id: NotBlankStr) -> bool: + """Delete a heartbeat by execution ID. + + Args: + execution_id: The execution identifier. + + Returns: + ``True`` if deleted, ``False`` if not found. + + Raises: + PersistenceError: If the operation fails. + """ + ... diff --git a/src/ai_company/persistence/sqlite/__init__.py b/src/ai_company/persistence/sqlite/__init__.py index 3b5d1f5a9a..6078c86751 100644 --- a/src/ai_company/persistence/sqlite/__init__.py +++ b/src/ai_company/persistence/sqlite/__init__.py @@ -4,6 +4,12 @@ SQLiteAuditRepository, ) from ai_company.persistence.sqlite.backend import SQLitePersistenceBackend +from ai_company.persistence.sqlite.checkpoint_repo import ( + SQLiteCheckpointRepository, +) +from ai_company.persistence.sqlite.heartbeat_repo import ( + SQLiteHeartbeatRepository, +) from ai_company.persistence.sqlite.migrations import ( SCHEMA_VERSION, run_migrations, @@ -17,7 +23,9 @@ __all__ = [ "SCHEMA_VERSION", "SQLiteAuditRepository", + "SQLiteCheckpointRepository", "SQLiteCostRecordRepository", + "SQLiteHeartbeatRepository", "SQLiteMessageRepository", "SQLitePersistenceBackend", "SQLiteTaskRepository", diff --git a/src/ai_company/persistence/sqlite/backend.py b/src/ai_company/persistence/sqlite/backend.py index 9033eee3d8..501f972c72 100644 --- a/src/ai_company/persistence/sqlite/backend.py +++ b/src/ai_company/persistence/sqlite/backend.py @@ -29,6 +29,12 @@ from ai_company.persistence.sqlite.audit_repository import ( SQLiteAuditRepository, ) +from ai_company.persistence.sqlite.checkpoint_repo import ( + SQLiteCheckpointRepository, +) +from ai_company.persistence.sqlite.heartbeat_repo import ( + SQLiteHeartbeatRepository, +) from ai_company.persistence.sqlite.hr_repositories import ( SQLiteCollaborationMetricRepository, SQLiteLifecycleEventRepository, @@ -79,6 +85,8 @@ def __init__(self, config: SQLiteConfig) -> None: self._audit_entries: SQLiteAuditRepository | None = None self._users: SQLiteUserRepository | None = None self._api_keys: SQLiteApiKeyRepository | None = None + self._checkpoints: SQLiteCheckpointRepository | None = None + self._heartbeats: SQLiteHeartbeatRepository | None = None def _clear_state(self) -> None: """Reset connection and repository references to ``None``.""" @@ -93,6 +101,8 @@ def _clear_state(self) -> None: self._audit_entries = None self._users = None self._api_keys = None + self._checkpoints = None + self._heartbeats = None async def connect(self) -> None: """Open the SQLite database and configure WAL mode.""" @@ -157,6 +167,8 @@ def _create_repositories(self) -> None: self._audit_entries = SQLiteAuditRepository(self._db) self._users = SQLiteUserRepository(self._db) self._api_keys = SQLiteApiKeyRepository(self._db) + self._checkpoints = SQLiteCheckpointRepository(self._db) + self._heartbeats = SQLiteHeartbeatRepository(self._db) async def _cleanup_failed_connect(self, exc: sqlite3.Error | OSError) -> None: """Log failure, close partial connection, and raise. @@ -357,6 +369,24 @@ def api_keys(self) -> SQLiteApiKeyRepository: """ return self._require_connected(self._api_keys, "api_keys") + @property + def checkpoints(self) -> SQLiteCheckpointRepository: + """Repository for Checkpoint persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._checkpoints, "checkpoints") + + @property + def heartbeats(self) -> SQLiteHeartbeatRepository: + """Repository for Heartbeat persistence. + + Raises: + PersistenceConnectionError: If not connected. + """ + return self._require_connected(self._heartbeats, "heartbeats") + async def get_setting(self, key: str) -> str | None: """Retrieve a setting value by key. diff --git a/src/ai_company/persistence/sqlite/checkpoint_repo.py b/src/ai_company/persistence/sqlite/checkpoint_repo.py new file mode 100644 index 0000000000..9c5da49e1b --- /dev/null +++ b/src/ai_company/persistence/sqlite/checkpoint_repo.py @@ -0,0 +1,174 @@ +"""SQLite repository implementation for checkpoint persistence.""" +# ruff: noqa: S608 — dynamic WHERE built from hardcoded column names only + +import sqlite3 + +import aiosqlite +from pydantic import ValidationError + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.checkpoint.models import Checkpoint +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_CHECKPOINT_DELETE_FAILED, + PERSISTENCE_CHECKPOINT_DELETED, + PERSISTENCE_CHECKPOINT_DESERIALIZE_FAILED, + PERSISTENCE_CHECKPOINT_NOT_FOUND, + PERSISTENCE_CHECKPOINT_QUERIED, + PERSISTENCE_CHECKPOINT_QUERY_FAILED, + PERSISTENCE_CHECKPOINT_SAVE_FAILED, + PERSISTENCE_CHECKPOINT_SAVED, +) +from ai_company.persistence.errors import QueryError + +logger = get_logger(__name__) + + +class SQLiteCheckpointRepository: + """SQLite implementation of the CheckpointRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, checkpoint: Checkpoint) -> None: + """Persist a checkpoint (upsert).""" + try: + data = checkpoint.model_dump(mode="json") + await self._db.execute( + """\ +INSERT OR REPLACE INTO checkpoints ( + id, execution_id, agent_id, task_id, turn_number, + context_json, created_at +) VALUES ( + :id, :execution_id, :agent_id, :task_id, :turn_number, + :context_json, :created_at +)""", + data, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save checkpoint {checkpoint.id!r}" + logger.exception( + PERSISTENCE_CHECKPOINT_SAVE_FAILED, + checkpoint_id=checkpoint.id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug( + PERSISTENCE_CHECKPOINT_SAVED, + checkpoint_id=checkpoint.id, + execution_id=checkpoint.execution_id, + turn_number=checkpoint.turn_number, + ) + + async def get_latest( + self, + *, + execution_id: NotBlankStr | None = None, + task_id: NotBlankStr | None = None, + ) -> Checkpoint | None: + """Retrieve the latest checkpoint by turn_number. + + At least one filter is required. + + Raises: + ValueError: If neither filter is provided. + """ + if execution_id is None and task_id is None: + msg = "At least one of execution_id or task_id is required" + raise ValueError(msg) + + conditions: list[str] = [] + params: list[str] = [] + + if execution_id is not None: + conditions.append("execution_id = ?") + params.append(execution_id) + if task_id is not None: + conditions.append("task_id = ?") + params.append(task_id) + + where = " AND ".join(conditions) + # where is built from hardcoded column names; only values + # use parameterized placeholders — no injection risk. + query = ( + "SELECT id, execution_id, agent_id, task_id, " + "turn_number, context_json, created_at " + f"FROM checkpoints WHERE {where} " + "ORDER BY turn_number DESC LIMIT 1" + ) + + try: + cursor = await self._db.execute(query, params) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to query latest checkpoint" + logger.exception( + PERSISTENCE_CHECKPOINT_QUERY_FAILED, + execution_id=execution_id, + task_id=task_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + if row is None: + logger.debug( + PERSISTENCE_CHECKPOINT_NOT_FOUND, + execution_id=execution_id, + task_id=task_id, + ) + return None + + checkpoint = self._row_to_model(dict(row)) + logger.debug( + PERSISTENCE_CHECKPOINT_QUERIED, + checkpoint_id=checkpoint.id, + turn_number=checkpoint.turn_number, + ) + return checkpoint + + async def delete_by_execution(self, execution_id: NotBlankStr) -> int: + """Delete all checkpoints for an execution.""" + try: + cursor = await self._db.execute( + "DELETE FROM checkpoints WHERE execution_id = ?", + (execution_id,), + ) + count = cursor.rowcount + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete checkpoints for execution {execution_id!r}" + logger.exception( + PERSISTENCE_CHECKPOINT_DELETE_FAILED, + execution_id=execution_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if count > 0: + logger.debug( + PERSISTENCE_CHECKPOINT_DELETED, + execution_id=execution_id, + count=count, + ) + return count + + def _row_to_model(self, row: dict[str, object]) -> Checkpoint: + """Convert a database row to a ``Checkpoint`` model. + + Raises: + QueryError: If the row cannot be deserialized. + """ + try: + return Checkpoint.model_validate(row) + except ValidationError as exc: + msg = f"Failed to deserialize checkpoint {row.get('id')!r}" + logger.exception( + PERSISTENCE_CHECKPOINT_DESERIALIZE_FAILED, + checkpoint_id=row.get("id"), + error=str(exc), + ) + raise QueryError(msg) from exc diff --git a/src/ai_company/persistence/sqlite/heartbeat_repo.py b/src/ai_company/persistence/sqlite/heartbeat_repo.py new file mode 100644 index 0000000000..83d94e542c --- /dev/null +++ b/src/ai_company/persistence/sqlite/heartbeat_repo.py @@ -0,0 +1,168 @@ +"""SQLite repository implementation for heartbeat persistence.""" + +import sqlite3 +from datetime import UTC + +import aiosqlite +from pydantic import AwareDatetime, ValidationError + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.checkpoint.models import Heartbeat +from ai_company.observability import get_logger +from ai_company.observability.events.persistence import ( + PERSISTENCE_HEARTBEAT_DELETE_FAILED, + PERSISTENCE_HEARTBEAT_DELETED, + PERSISTENCE_HEARTBEAT_DESERIALIZE_FAILED, + PERSISTENCE_HEARTBEAT_NOT_FOUND, + PERSISTENCE_HEARTBEAT_QUERIED, + PERSISTENCE_HEARTBEAT_QUERY_FAILED, + PERSISTENCE_HEARTBEAT_SAVE_FAILED, + PERSISTENCE_HEARTBEAT_SAVED, +) +from ai_company.persistence.errors import QueryError + +logger = get_logger(__name__) + + +class SQLiteHeartbeatRepository: + """SQLite implementation of the HeartbeatRepository protocol. + + Args: + db: An open aiosqlite connection. + """ + + def __init__(self, db: aiosqlite.Connection) -> None: + self._db = db + + async def save(self, heartbeat: Heartbeat) -> None: + """Persist a heartbeat (upsert by execution_id).""" + try: + data = heartbeat.model_dump(mode="json") + # Normalize to UTC so lexicographic comparisons in + # get_stale() work correctly regardless of input timezone. + data["last_heartbeat_at"] = heartbeat.last_heartbeat_at.astimezone( + UTC + ).isoformat() + await self._db.execute( + """\ +INSERT OR REPLACE INTO heartbeats ( + execution_id, agent_id, task_id, last_heartbeat_at +) VALUES ( + :execution_id, :agent_id, :task_id, :last_heartbeat_at +)""", + data, + ) + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to save heartbeat for execution {heartbeat.execution_id!r}" + logger.exception( + PERSISTENCE_HEARTBEAT_SAVE_FAILED, + execution_id=heartbeat.execution_id, + error=str(exc), + ) + raise QueryError(msg) from exc + logger.debug( + PERSISTENCE_HEARTBEAT_SAVED, + execution_id=heartbeat.execution_id, + ) + + async def get(self, execution_id: NotBlankStr) -> Heartbeat | None: + """Retrieve a heartbeat by execution ID.""" + try: + cursor = await self._db.execute( + "SELECT execution_id, agent_id, task_id, last_heartbeat_at " + "FROM heartbeats WHERE execution_id = ?", + (execution_id,), + ) + row = await cursor.fetchone() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to query heartbeat {execution_id!r}" + logger.exception( + PERSISTENCE_HEARTBEAT_QUERY_FAILED, + execution_id=execution_id, + error=str(exc), + ) + raise QueryError(msg) from exc + + if row is None: + logger.debug( + PERSISTENCE_HEARTBEAT_NOT_FOUND, + execution_id=execution_id, + ) + return None + + return self._row_to_model(dict(row)) + + async def get_stale(self, threshold: AwareDatetime) -> tuple[Heartbeat, ...]: + """Retrieve heartbeats older than the threshold. + + Args: + threshold: Heartbeats with ``last_heartbeat_at`` before + this timestamp are considered stale. + """ + threshold_iso = threshold.astimezone(UTC).isoformat() + try: + cursor = await self._db.execute( + "SELECT execution_id, agent_id, task_id, last_heartbeat_at " + "FROM heartbeats WHERE last_heartbeat_at < ? " + "ORDER BY last_heartbeat_at", + (threshold_iso,), + ) + rows = await cursor.fetchall() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = "Failed to query stale heartbeats" + logger.exception( + PERSISTENCE_HEARTBEAT_QUERY_FAILED, + threshold=threshold, + error=str(exc), + ) + raise QueryError(msg) from exc + + results = tuple(self._row_to_model(dict(row)) for row in rows) + logger.debug( + PERSISTENCE_HEARTBEAT_QUERIED, + threshold=threshold, + count=len(results), + ) + return results + + async def delete(self, execution_id: NotBlankStr) -> bool: + """Delete a heartbeat by execution ID.""" + try: + cursor = await self._db.execute( + "DELETE FROM heartbeats WHERE execution_id = ?", + (execution_id,), + ) + deleted = cursor.rowcount > 0 + await self._db.commit() + except (sqlite3.Error, aiosqlite.Error) as exc: + msg = f"Failed to delete heartbeat {execution_id!r}" + logger.exception( + PERSISTENCE_HEARTBEAT_DELETE_FAILED, + execution_id=execution_id, + error=str(exc), + ) + raise QueryError(msg) from exc + if deleted: + logger.debug( + PERSISTENCE_HEARTBEAT_DELETED, + execution_id=execution_id, + ) + return deleted + + def _row_to_model(self, row: dict[str, object]) -> Heartbeat: + """Convert a database row to a ``Heartbeat`` model. + + Raises: + QueryError: If the row cannot be deserialized. + """ + try: + return Heartbeat.model_validate(row) + except ValidationError as exc: + msg = f"Failed to deserialize heartbeat {row.get('execution_id')!r}" + logger.exception( + PERSISTENCE_HEARTBEAT_DESERIALIZE_FAILED, + execution_id=row.get("execution_id"), + error=str(exc), + ) + raise QueryError(msg) from exc diff --git a/src/ai_company/persistence/sqlite/migrations.py b/src/ai_company/persistence/sqlite/migrations.py index 66cb1bfbd6..06835e12aa 100644 --- a/src/ai_company/persistence/sqlite/migrations.py +++ b/src/ai_company/persistence/sqlite/migrations.py @@ -23,7 +23,7 @@ logger = get_logger(__name__) # Current schema version — bump when adding new migrations. -SCHEMA_VERSION = 5 +SCHEMA_VERSION = 6 _V1_STATEMENTS: Sequence[str] = ( # ── Tasks ───────────────────────────────────────────── @@ -221,6 +221,40 @@ "CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id)", ) +_V6_STATEMENTS: Sequence[str] = ( + # ── Checkpoints ──────────────────────────────────────── + """\ +CREATE TABLE IF NOT EXISTS checkpoints ( + id TEXT PRIMARY KEY, + execution_id TEXT NOT NULL, + agent_id TEXT NOT NULL, + task_id TEXT NOT NULL, + turn_number INTEGER NOT NULL CHECK (turn_number >= 0), + context_json TEXT NOT NULL, + created_at TEXT NOT NULL +)""", + "CREATE INDEX IF NOT EXISTS idx_cp_execution_id ON checkpoints(execution_id)", + "CREATE INDEX IF NOT EXISTS idx_cp_task_id ON checkpoints(task_id)", + # Ascending index — SQLite can reverse-scan efficiently for + # ORDER BY turn_number DESC LIMIT 1. DESC modifier silently + # ignored on SQLite < 3.47 so we use ascending for portability. + "CREATE INDEX IF NOT EXISTS idx_cp_exec_turn" + " ON checkpoints(execution_id, turn_number)", + "CREATE INDEX IF NOT EXISTS idx_cp_task_turn ON checkpoints(task_id, turn_number)", + # ── Heartbeats ───────────────────────────────────────── + # No FK to tasks — checkpoints/heartbeats are ephemeral recovery + # data that may outlive their tasks. Cleanup is the engine's + # responsibility (delete_by_execution after completion). + """\ +CREATE TABLE IF NOT EXISTS heartbeats ( + execution_id TEXT PRIMARY KEY, + agent_id TEXT NOT NULL, + task_id TEXT NOT NULL, + last_heartbeat_at TEXT NOT NULL +)""", + "CREATE INDEX IF NOT EXISTS idx_hb_last_heartbeat ON heartbeats(last_heartbeat_at)", +) + _MigrateFn = Callable[[aiosqlite.Connection], Coroutine[Any, Any, None]] @@ -283,6 +317,12 @@ async def _apply_v5(db: aiosqlite.Connection) -> None: await db.execute(stmt) +async def _apply_v6(db: aiosqlite.Connection) -> None: + """Apply schema v6: checkpoints, heartbeats.""" + for stmt in _V6_STATEMENTS: + await db.execute(stmt) + + # Ordered list of (target_version, migration_function) pairs. Each migration # is applied when the current schema version is below its target version. _MIGRATIONS: list[tuple[int, _MigrateFn]] = [ @@ -291,6 +331,7 @@ async def _apply_v5(db: aiosqlite.Connection) -> None: (3, _apply_v3), (4, _apply_v4), (5, _apply_v5), + (6, _apply_v6), ] diff --git a/tests/unit/api/conftest.py b/tests/unit/api/conftest.py index feed2d053a..30190d7ffb 100644 --- a/tests/unit/api/conftest.py +++ b/tests/unit/api/conftest.py @@ -27,6 +27,7 @@ TaskStatus, ) from ai_company.core.task import Task +from ai_company.engine.checkpoint.models import Checkpoint, Heartbeat from ai_company.engine.task_engine import TaskEngine from ai_company.hr.enums import LifecycleEventType from ai_company.hr.models import AgentLifecycleEvent @@ -334,6 +335,65 @@ async def delete(self, key_id: str) -> bool: return self._keys.pop(key_id, None) is not None +class FakeCheckpointRepository: + """In-memory checkpoint repository for tests.""" + + def __init__(self) -> None: + self._checkpoints: dict[str, Checkpoint] = {} + + async def save(self, checkpoint: Checkpoint) -> None: + self._checkpoints[checkpoint.id] = checkpoint + + async def get_latest( + self, + *, + execution_id: str | None = None, + task_id: str | None = None, + ) -> Checkpoint | None: + if execution_id is None and task_id is None: + msg = "At least one of execution_id or task_id is required" + raise ValueError(msg) + candidates = list(self._checkpoints.values()) + if execution_id is not None: + candidates = [c for c in candidates if c.execution_id == execution_id] + if task_id is not None: + candidates = [c for c in candidates if c.task_id == task_id] + if not candidates: + return None + return max(candidates, key=lambda c: c.turn_number) + + async def delete_by_execution(self, execution_id: str) -> int: + to_delete = [ + k for k, v in self._checkpoints.items() if v.execution_id == execution_id + ] + for k in to_delete: + del self._checkpoints[k] + return len(to_delete) + + +class FakeHeartbeatRepository: + """In-memory heartbeat repository for tests.""" + + def __init__(self) -> None: + self._heartbeats: dict[str, Heartbeat] = {} + + async def save(self, heartbeat: Heartbeat) -> None: + self._heartbeats[heartbeat.execution_id] = heartbeat + + async def get(self, execution_id: str) -> Heartbeat | None: + return self._heartbeats.get(execution_id) + + async def get_stale(self, threshold: datetime) -> tuple[Heartbeat, ...]: + stale = [ + h for h in self._heartbeats.values() if h.last_heartbeat_at < threshold + ] + stale.sort(key=lambda h: h.last_heartbeat_at) + return tuple(stale) + + async def delete(self, execution_id: str) -> bool: + return self._heartbeats.pop(execution_id, None) is not None + + class FakePersistenceBackend: """In-memory persistence backend for tests.""" @@ -348,6 +408,8 @@ def __init__(self) -> None: self._audit_entries = FakeAuditRepository() self._users = FakeUserRepository() self._api_keys = FakeApiKeyRepository() + self._checkpoints = FakeCheckpointRepository() + self._heartbeats_repo = FakeHeartbeatRepository() self._settings: dict[str, str] = {} self._connected = False @@ -411,6 +473,14 @@ def users(self) -> FakeUserRepository: def api_keys(self) -> FakeApiKeyRepository: return self._api_keys + @property + def checkpoints(self) -> FakeCheckpointRepository: + return self._checkpoints + + @property + def heartbeats(self) -> FakeHeartbeatRepository: + return self._heartbeats_repo + async def get_setting(self, key: str) -> str | None: return self._settings.get(key) diff --git a/tests/unit/engine/checkpoint/__init__.py b/tests/unit/engine/checkpoint/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/engine/checkpoint/test_callback_factory.py b/tests/unit/engine/checkpoint/test_callback_factory.py new file mode 100644 index 0000000000..2f147ae0b6 --- /dev/null +++ b/tests/unit/engine/checkpoint/test_callback_factory.py @@ -0,0 +1,384 @@ +"""Tests for make_checkpoint_callback factory.""" + +from datetime import date +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest + +from ai_company.core.agent import AgentIdentity, ModelConfig, SkillSet +from ai_company.core.enums import SeniorityLevel, TaskStatus, TaskType +from ai_company.core.task import Task +from ai_company.engine.checkpoint.callback_factory import make_checkpoint_callback +from ai_company.engine.checkpoint.models import CheckpointConfig +from ai_company.engine.context import AgentContext + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_agent() -> AgentIdentity: + """Build a test agent identity.""" + return AgentIdentity( + id=uuid4(), + name="Test Agent", + role="Developer", + department="Engineering", + level=SeniorityLevel.MID, + model=ModelConfig(provider="test-provider", model_id="test-small-001"), + hiring_date=date(2026, 1, 1), + skills=SkillSet(primary=("python",)), + ) + + +def _make_task(agent: AgentIdentity) -> Task: + """Build a test task.""" + return Task( + id="task-cb-001", + title="Callback test", + description="Test task for callback factory.", + type=TaskType.DEVELOPMENT, + project="proj-001", + created_by="manager", + assigned_to=str(agent.id), + status=TaskStatus.ASSIGNED, + ) + + +def _make_ctx_at_turn( + agent: AgentIdentity, + task: Task, + turn: int, +) -> AgentContext: + """Build an AgentContext at a given turn count.""" + ctx = AgentContext.from_identity(agent, task=task) + # Use model_copy to set the desired turn count + return ctx.model_copy(update={"turn_count": turn}) + + +def _make_repos() -> tuple[AsyncMock, AsyncMock]: + """Build mock checkpoint and heartbeat repositories.""" + checkpoint_repo = AsyncMock() + checkpoint_repo.save = AsyncMock() + heartbeat_repo = AsyncMock() + heartbeat_repo.save = AsyncMock() + return checkpoint_repo, heartbeat_repo + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckpointCallbackBoundaryTurns: + """Checkpoint is saved on boundary turns based on persist_every_n_turns.""" + + async def test_saves_on_every_turn_default(self) -> None: + """persist_every_n_turns=1 saves every turn with correct fields.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx1 = _make_ctx_at_turn(agent, task, 1) + await callback(ctx1) + assert cp_repo.save.await_count == 1 + + # Verify checkpoint model content + saved_cp = cp_repo.save.call_args_list[0][0][0] + assert saved_cp.agent_id == str(agent.id) + assert saved_cp.task_id == task.id + assert saved_cp.turn_number == 1 + assert saved_cp.execution_id == ctx1.execution_id + + # Verify heartbeat model content + saved_hb = hb_repo.save.call_args_list[0][0][0] + assert saved_hb.agent_id == str(agent.id) + assert saved_hb.task_id == task.id + assert saved_hb.execution_id == ctx1.execution_id + + ctx2 = _make_ctx_at_turn(agent, task, 2) + await callback(ctx2) + assert cp_repo.save.await_count == 2 + + async def test_skips_non_boundary_turns(self) -> None: + """persist_every_n_turns=2 skips odd turns.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + config = CheckpointConfig(persist_every_n_turns=2) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + # Turn 1: 1 % 2 != 0 → skip + ctx1 = _make_ctx_at_turn(agent, task, 1) + await callback(ctx1) + assert cp_repo.save.await_count == 0 + + # Turn 2: 2 % 2 == 0 → save + ctx2 = _make_ctx_at_turn(agent, task, 2) + await callback(ctx2) + assert cp_repo.save.await_count == 1 + + # Turn 3: 3 % 2 != 0 → skip + ctx3 = _make_ctx_at_turn(agent, task, 3) + await callback(ctx3) + assert cp_repo.save.await_count == 1 + + @pytest.mark.parametrize( + ("persist_every_n", "turn", "should_save"), + [ + (1, 0, False), # Turn 0 always skipped + (1, 1, True), + (1, 5, True), + (2, 0, False), # Turn 0 always skipped + (2, 1, False), + (2, 2, True), + (2, 4, True), + (3, 1, False), + (3, 2, False), + (3, 3, True), + (3, 6, True), + (5, 3, False), + (5, 5, True), + ], + ) + async def test_persist_boundary_parametrized( + self, + persist_every_n: int, + turn: int, + should_save: bool, + ) -> None: + """Verify boundary turn detection with various configurations.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + config = CheckpointConfig(persist_every_n_turns=persist_every_n) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, turn) + await callback(ctx) + + expected_count = 1 if should_save else 0 + assert cp_repo.save.await_count == expected_count + + +@pytest.mark.unit +class TestCheckpointCallbackHeartbeat: + """Heartbeat is updated alongside checkpoint.""" + + async def test_heartbeat_updated_on_save(self) -> None: + """Heartbeat repo is called when checkpoint is saved.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + await callback(ctx) + + hb_repo.save.assert_awaited_once() + + async def test_heartbeat_not_called_on_skip(self) -> None: + """Heartbeat not updated when turn is skipped.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + config = CheckpointConfig(persist_every_n_turns=2) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + # Turn 1 is skipped with persist_every_n_turns=2 + ctx = _make_ctx_at_turn(agent, task, 1) + await callback(ctx) + + hb_repo.save.assert_not_awaited() + + +@pytest.mark.unit +class TestCheckpointCallbackErrorHandling: + """Errors are swallowed except MemoryError and RecursionError.""" + + async def test_checkpoint_repo_error_swallowed(self) -> None: + """Checkpoint repo error is logged but not propagated.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + cp_repo.save = AsyncMock(side_effect=RuntimeError("DB write failed")) + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + # Should not raise + await callback(ctx) + + async def test_heartbeat_repo_error_swallowed(self) -> None: + """Heartbeat repo error is logged but not propagated.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + hb_repo.save = AsyncMock(side_effect=RuntimeError("Heartbeat write failed")) + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + # Should not raise + await callback(ctx) + + async def test_memory_error_not_swallowed_from_checkpoint(self) -> None: + """MemoryError from checkpoint save propagates.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + cp_repo.save = AsyncMock(side_effect=MemoryError) + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + with pytest.raises(MemoryError): + await callback(ctx) + + async def test_recursion_error_not_swallowed_from_checkpoint(self) -> None: + """RecursionError from checkpoint save propagates.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + cp_repo.save = AsyncMock(side_effect=RecursionError) + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + with pytest.raises(RecursionError): + await callback(ctx) + + async def test_memory_error_not_swallowed_from_heartbeat(self) -> None: + """MemoryError from heartbeat save propagates.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + hb_repo.save = AsyncMock(side_effect=MemoryError) + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + with pytest.raises(MemoryError): + await callback(ctx) + + async def test_recursion_error_not_swallowed_from_heartbeat(self) -> None: + """RecursionError from heartbeat save propagates.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + hb_repo.save = AsyncMock(side_effect=RecursionError) + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + with pytest.raises(RecursionError): + await callback(ctx) + + async def test_checkpoint_error_skips_heartbeat(self) -> None: + """When checkpoint save fails, heartbeat is skipped to avoid limbo state.""" + agent = _make_agent() + task = _make_task(agent) + cp_repo, hb_repo = _make_repos() + cp_repo.save = AsyncMock(side_effect=RuntimeError("checkpoint failed")) + config = CheckpointConfig(persist_every_n_turns=1) + + callback = make_checkpoint_callback( + checkpoint_repo=cp_repo, + heartbeat_repo=hb_repo, + config=config, + agent_id=str(agent.id), + task_id=task.id, + ) + + ctx = _make_ctx_at_turn(agent, task, 1) + await callback(ctx) + + # Checkpoint save was attempted + cp_repo.save.assert_awaited_once() + # Heartbeat should NOT be called when checkpoint failed + hb_repo.save.assert_not_awaited() diff --git a/tests/unit/engine/checkpoint/test_models.py b/tests/unit/engine/checkpoint/test_models.py new file mode 100644 index 0000000000..efd26856d0 --- /dev/null +++ b/tests/unit/engine/checkpoint/test_models.py @@ -0,0 +1,272 @@ +"""Tests for checkpoint and heartbeat Pydantic models.""" + +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError + +from ai_company.engine.checkpoint.models import ( + Checkpoint, + CheckpointConfig, + Heartbeat, +) + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_checkpoint(**overrides: object) -> Checkpoint: + """Build a Checkpoint with sensible defaults.""" + defaults: dict[str, object] = { + "execution_id": "exec-001", + "agent_id": "agent-001", + "task_id": "task-001", + "turn_number": 1, + "context_json": '{"state": "running"}', + } + defaults.update(overrides) + return Checkpoint(**defaults) # type: ignore[arg-type] + + +def _make_heartbeat(**overrides: object) -> Heartbeat: + """Build a Heartbeat with sensible defaults.""" + defaults: dict[str, object] = { + "execution_id": "exec-001", + "agent_id": "agent-001", + "task_id": "task-001", + "last_heartbeat_at": datetime.now(UTC), + } + defaults.update(overrides) + return Heartbeat(**defaults) # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# Checkpoint tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckpointCreation: + """Checkpoint model creation and defaults.""" + + def test_auto_generates_id(self) -> None: + cp = _make_checkpoint() + assert cp.id + assert len(cp.id) > 0 + + def test_auto_generates_unique_ids(self) -> None: + cp1 = _make_checkpoint() + cp2 = _make_checkpoint() + assert cp1.id != cp2.id + + def test_auto_generates_created_at(self) -> None: + before = datetime.now(UTC) + cp = _make_checkpoint() + after = datetime.now(UTC) + assert before <= cp.created_at <= after + + def test_explicit_id_preserved(self) -> None: + cp = _make_checkpoint(id="custom-id") + assert cp.id == "custom-id" + + def test_all_fields_set(self) -> None: + cp = _make_checkpoint( + execution_id="exec-x", + agent_id="agent-x", + task_id="task-x", + turn_number=5, + context_json='{"key": "val"}', + ) + assert cp.execution_id == "exec-x" + assert cp.agent_id == "agent-x" + assert cp.task_id == "task-x" + assert cp.turn_number == 5 + assert cp.context_json == '{"key": "val"}' + + +@pytest.mark.unit +class TestCheckpointFrozen: + """Checkpoint model is immutable.""" + + def test_cannot_mutate_field(self) -> None: + cp = _make_checkpoint() + with pytest.raises(ValidationError, match="frozen"): + cp.turn_number = 99 # type: ignore[misc] + + def test_cannot_mutate_id(self) -> None: + cp = _make_checkpoint() + with pytest.raises(ValidationError, match="frozen"): + cp.id = "changed" # type: ignore[misc] + + +@pytest.mark.unit +class TestCheckpointContextJsonValidation: + """context_json must be valid JSON.""" + + def test_valid_json_passes(self) -> None: + cp = _make_checkpoint(context_json='{"a": 1, "b": [2, 3]}') + assert cp.context_json == '{"a": 1, "b": [2, 3]}' + + def test_empty_object_passes(self) -> None: + cp = _make_checkpoint(context_json="{}") + assert cp.context_json == "{}" + + def test_json_array_rejected(self) -> None: + with pytest.raises(ValidationError, match="JSON object"): + _make_checkpoint(context_json="[1, 2, 3]") + + def test_json_primitive_rejected(self) -> None: + with pytest.raises(ValidationError, match="JSON object"): + _make_checkpoint(context_json='"hello"') + + def test_invalid_json_raises(self) -> None: + with pytest.raises(ValidationError, match="context_json must be valid JSON"): + _make_checkpoint(context_json="{bad json") + + def test_empty_string_raises(self) -> None: + with pytest.raises(ValidationError, match="context_json must be valid JSON"): + _make_checkpoint(context_json="") + + def test_non_json_string_raises(self) -> None: + with pytest.raises(ValidationError, match="context_json must be valid JSON"): + _make_checkpoint(context_json="not json at all") + + +@pytest.mark.unit +class TestCheckpointTurnNumberConstraint: + """turn_number must be >= 0.""" + + def test_zero_is_valid(self) -> None: + cp = _make_checkpoint(turn_number=0) + assert cp.turn_number == 0 + + def test_positive_is_valid(self) -> None: + cp = _make_checkpoint(turn_number=42) + assert cp.turn_number == 42 + + def test_negative_raises(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 0"): + _make_checkpoint(turn_number=-1) + + +@pytest.mark.unit +class TestCheckpointBlankFieldRejection: + """NotBlankStr fields reject blank strings.""" + + @pytest.mark.parametrize( + "field", + ["execution_id", "agent_id", "task_id"], + ) + def test_blank_string_rejected(self, field: str) -> None: + with pytest.raises(ValidationError): + _make_checkpoint(**{field: ""}) + + @pytest.mark.parametrize( + "field", + ["execution_id", "agent_id", "task_id"], + ) + def test_whitespace_only_rejected(self, field: str) -> None: + with pytest.raises(ValidationError): + _make_checkpoint(**{field: " "}) + + +# --------------------------------------------------------------------------- +# Heartbeat tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestHeartbeatCreation: + """Heartbeat model creation.""" + + def test_all_fields_set(self) -> None: + now = datetime.now(UTC) + hb = _make_heartbeat( + execution_id="exec-hb", + agent_id="agent-hb", + task_id="task-hb", + last_heartbeat_at=now, + ) + assert hb.execution_id == "exec-hb" + assert hb.agent_id == "agent-hb" + assert hb.task_id == "task-hb" + assert hb.last_heartbeat_at == now + + +@pytest.mark.unit +class TestHeartbeatFrozen: + """Heartbeat model is immutable.""" + + def test_cannot_mutate_field(self) -> None: + hb = _make_heartbeat() + with pytest.raises(ValidationError, match="frozen"): + hb.execution_id = "changed" # type: ignore[misc] + + +# --------------------------------------------------------------------------- +# CheckpointConfig tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckpointConfigDefaults: + """CheckpointConfig default values.""" + + def test_defaults(self) -> None: + cfg = CheckpointConfig() + assert cfg.persist_every_n_turns == 1 + assert cfg.heartbeat_interval_seconds == 30.0 + assert cfg.max_resume_attempts == 2 + + +@pytest.mark.unit +class TestCheckpointConfigCustom: + """CheckpointConfig with custom values.""" + + def test_custom_values(self) -> None: + cfg = CheckpointConfig( + persist_every_n_turns=5, + heartbeat_interval_seconds=60.0, + max_resume_attempts=10, + ) + assert cfg.persist_every_n_turns == 5 + assert cfg.heartbeat_interval_seconds == 60.0 + assert cfg.max_resume_attempts == 10 + + def test_frozen(self) -> None: + cfg = CheckpointConfig() + with pytest.raises(ValidationError, match="frozen"): + cfg.persist_every_n_turns = 99 # type: ignore[misc] + + +@pytest.mark.unit +class TestCheckpointConfigValidation: + """CheckpointConfig field constraints.""" + + def test_persist_every_n_turns_must_be_positive(self) -> None: + with pytest.raises(ValidationError, match="greater than 0"): + CheckpointConfig(persist_every_n_turns=0) + + def test_persist_every_n_turns_negative_rejected(self) -> None: + with pytest.raises(ValidationError, match="greater than 0"): + CheckpointConfig(persist_every_n_turns=-1) + + def test_heartbeat_interval_must_be_positive(self) -> None: + with pytest.raises(ValidationError, match="greater than 0"): + CheckpointConfig(heartbeat_interval_seconds=0.0) + + def test_heartbeat_interval_negative_rejected(self) -> None: + with pytest.raises(ValidationError, match="greater than 0"): + CheckpointConfig(heartbeat_interval_seconds=-1.0) + + def test_max_resume_attempts_zero_valid(self) -> None: + cfg = CheckpointConfig(max_resume_attempts=0) + assert cfg.max_resume_attempts == 0 + + def test_max_resume_attempts_negative_rejected(self) -> None: + with pytest.raises(ValidationError, match="greater than or equal to 0"): + CheckpointConfig(max_resume_attempts=-1) diff --git a/tests/unit/engine/checkpoint/test_strategy.py b/tests/unit/engine/checkpoint/test_strategy.py new file mode 100644 index 0000000000..2cc490d37a --- /dev/null +++ b/tests/unit/engine/checkpoint/test_strategy.py @@ -0,0 +1,439 @@ +"""Tests for CheckpointRecoveryStrategy.""" + +from typing import TYPE_CHECKING +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ai_company.core.enums import TaskStatus, TaskType +from ai_company.core.task import Task +from ai_company.engine.checkpoint.models import Checkpoint, CheckpointConfig +from ai_company.engine.checkpoint.strategy import CheckpointRecoveryStrategy +from ai_company.engine.context import AgentContext +from ai_company.engine.recovery import ( + FailAndReassignStrategy, + RecoveryResult, + RecoveryStrategy, +) +from ai_company.persistence.errors import QueryError + +if TYPE_CHECKING: + from ai_company.core.agent import AgentIdentity + from ai_company.engine.task_execution import TaskExecution + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_checkpoint( + *, + execution_id: str = "exec-001", + turn_number: int = 3, + context_json: str = '{"state": "resumed"}', +) -> Checkpoint: + """Build a Checkpoint with sensible defaults.""" + return Checkpoint( + execution_id=execution_id, + agent_id="agent-001", + task_id="task-001", + turn_number=turn_number, + context_json=context_json, + ) + + +def _make_mock_repo( + checkpoint: Checkpoint | None = None, + *, + error: Exception | None = None, +) -> AsyncMock: + """Build a mock CheckpointRepository.""" + repo = AsyncMock() + if error is not None: + repo.get_latest = AsyncMock(side_effect=error) + else: + repo.get_latest = AsyncMock(return_value=checkpoint) + return repo + + +def _make_strategy( + repo: AsyncMock, + *, + config: CheckpointConfig | None = None, + fallback: RecoveryStrategy | None = None, +) -> CheckpointRecoveryStrategy: + """Build a CheckpointRecoveryStrategy.""" + return CheckpointRecoveryStrategy( + checkpoint_repo=repo, + config=config or CheckpointConfig(), + fallback=fallback, + ) + + +def _make_in_progress_ctx( + agent: AgentIdentity, + task: Task, +) -> tuple[AgentContext, TaskExecution]: + """Build a context with IN_PROGRESS task execution.""" + ctx = AgentContext.from_identity(agent, task=task) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="starting") + assert ctx.task_execution is not None + return ctx, ctx.task_execution + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCheckpointRecoveryProtocol: + """CheckpointRecoveryStrategy satisfies RecoveryStrategy protocol.""" + + def test_is_runtime_checkable(self) -> None: + repo = _make_mock_repo() + strategy = _make_strategy(repo) + assert isinstance(strategy, RecoveryStrategy) + + def test_get_strategy_type(self) -> None: + repo = _make_mock_repo() + strategy = _make_strategy(repo) + assert strategy.get_strategy_type() == "checkpoint" + + +@pytest.mark.unit +class TestCheckpointRecoveryResume: + """Resume from a valid checkpoint.""" + + async def test_resume_with_valid_checkpoint( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Returns RecoveryResult with can_resume=True when checkpoint exists.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + checkpoint = _make_checkpoint(execution_id=ctx.execution_id) + repo = _make_mock_repo(checkpoint) + strategy = _make_strategy(repo) + + result = await strategy.recover( + task_execution=task_exec, + error_message="LLM crashed", + context=ctx, + ) + + assert isinstance(result, RecoveryResult) + assert result.can_resume is True + assert result.checkpoint_context_json == checkpoint.context_json + assert result.strategy_type == "checkpoint" + assert result.error_message == "LLM crashed" + assert result.resume_attempt == 1 + + async def test_task_not_transitioned_to_failed( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Task execution is NOT transitioned to FAILED (unlike FailAndReassign).""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + checkpoint = _make_checkpoint(execution_id=ctx.execution_id) + repo = _make_mock_repo(checkpoint) + strategy = _make_strategy(repo) + + result = await strategy.recover( + task_execution=task_exec, + error_message="crash", + context=ctx, + ) + + # The checkpoint strategy preserves the original task_execution + # (still IN_PROGRESS), not transitioning to FAILED. + assert result.task_execution.status is TaskStatus.IN_PROGRESS + + +@pytest.mark.unit +class TestCheckpointRecoveryFallback: + """Fallback to FailAndReassignStrategy when no checkpoint or max attempts.""" + + async def test_no_checkpoint_delegates_to_fallback( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Falls back when no checkpoint found.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + repo = _make_mock_repo(checkpoint=None) + strategy = _make_strategy(repo) + + result = await strategy.recover( + task_execution=task_exec, + error_message="crash", + context=ctx, + ) + + # Fallback is FailAndReassignStrategy which transitions to FAILED + assert result.task_execution.status is TaskStatus.FAILED + assert result.strategy_type == "fail_reassign" + assert result.can_resume is False + + async def test_max_resume_attempts_exhausted( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Falls back after max_resume_attempts reached.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + checkpoint = _make_checkpoint(execution_id=ctx.execution_id) + repo = _make_mock_repo(checkpoint) + config = CheckpointConfig(max_resume_attempts=2) + strategy = _make_strategy(repo, config=config) + + # First two recoveries succeed (resume_attempt 1 and 2) + result1 = await strategy.recover( + task_execution=task_exec, + error_message="crash 1", + context=ctx, + ) + assert result1.can_resume is True + assert result1.resume_attempt == 1 + + result2 = await strategy.recover( + task_execution=task_exec, + error_message="crash 2", + context=ctx, + ) + assert result2.can_resume is True + assert result2.resume_attempt == 2 + + # Third recovery should fall back + result3 = await strategy.recover( + task_execution=task_exec, + error_message="crash 3", + context=ctx, + ) + assert result3.can_resume is False + assert result3.strategy_type == "fail_reassign" + + async def test_zero_max_resume_attempts_always_fallback( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """max_resume_attempts=0 always falls back.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + checkpoint = _make_checkpoint(execution_id=ctx.execution_id) + repo = _make_mock_repo(checkpoint) + config = CheckpointConfig(max_resume_attempts=0) + strategy = _make_strategy(repo, config=config) + + result = await strategy.recover( + task_execution=task_exec, + error_message="crash", + context=ctx, + ) + + assert result.can_resume is False + assert result.strategy_type == "fail_reassign" + + async def test_repo_error_delegates_to_fallback( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Falls back when checkpoint repo raises an exception.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + repo = _make_mock_repo(error=QueryError("DB connection lost")) + strategy = _make_strategy(repo) + + result = await strategy.recover( + task_execution=task_exec, + error_message="crash", + context=ctx, + ) + + assert result.can_resume is False + assert result.strategy_type == "fail_reassign" + assert result.task_execution.status is TaskStatus.FAILED + + async def test_custom_fallback_used( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Custom fallback strategy is used when provided.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + repo = _make_mock_repo(checkpoint=None) + + mock_fallback = MagicMock(spec=FailAndReassignStrategy) + snapshot = ctx.to_snapshot() + fallback_result = RecoveryResult( + task_execution=task_exec.with_transition( + TaskStatus.FAILED, reason="custom" + ), + strategy_type="custom_fallback", + context_snapshot=snapshot, + error_message="crash", + ) + mock_fallback.recover = AsyncMock(return_value=fallback_result) + + strategy = _make_strategy(repo, fallback=mock_fallback) + + result = await strategy.recover( + task_execution=task_exec, + error_message="crash", + context=ctx, + ) + + assert result.strategy_type == "custom_fallback" + mock_fallback.recover.assert_awaited_once() + + +@pytest.mark.unit +class TestCheckpointRecoveryCounter: + """Resume counter tracking and reset.""" + + async def test_counter_increments_per_execution( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Resume attempts are tracked per execution_id.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + checkpoint = _make_checkpoint(execution_id=ctx.execution_id) + repo = _make_mock_repo(checkpoint) + config = CheckpointConfig(max_resume_attempts=5) + strategy = _make_strategy(repo, config=config) + + result1 = await strategy.recover( + task_execution=task_exec, + error_message="crash 1", + context=ctx, + ) + assert result1.resume_attempt == 1 + + result2 = await strategy.recover( + task_execution=task_exec, + error_message="crash 2", + context=ctx, + ) + assert result2.resume_attempt == 2 + + async def test_clear_resume_count_resets( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """clear_resume_count resets the counter for an execution.""" + ctx, task_exec = _make_in_progress_ctx( + sample_agent_with_personality, + sample_task_with_criteria, + ) + checkpoint = _make_checkpoint(execution_id=ctx.execution_id) + repo = _make_mock_repo(checkpoint) + config = CheckpointConfig(max_resume_attempts=5) + strategy = _make_strategy(repo, config=config) + + # Use up one attempt + await strategy.recover( + task_execution=task_exec, + error_message="crash", + context=ctx, + ) + + # Clear and retry + await strategy.clear_resume_count(ctx.execution_id) + + result = await strategy.recover( + task_execution=task_exec, + error_message="crash again", + context=ctx, + ) + assert result.resume_attempt == 1 # Reset to 1, not 2 + + async def test_clear_resume_count_noop_for_unknown(self) -> None: + """Clearing a nonexistent execution_id is a safe no-op.""" + repo = _make_mock_repo() + strategy = _make_strategy(repo) + await strategy.clear_resume_count("nonexistent-exec") # Should not raise + + async def test_independent_counters_per_execution( + self, + sample_agent_with_personality: AgentIdentity, + ) -> None: + """Different execution IDs have independent counters.""" + task_a = Task( + id="task-a", + title="Task A", + description="First task", + type=TaskType.DEVELOPMENT, + project="proj-001", + created_by="manager", + assigned_to=str(sample_agent_with_personality.id), + status=TaskStatus.ASSIGNED, + ) + task_b = Task( + id="task-b", + title="Task B", + description="Second task", + type=TaskType.DEVELOPMENT, + project="proj-001", + created_by="manager", + assigned_to=str(sample_agent_with_personality.id), + status=TaskStatus.ASSIGNED, + ) + + ctx_a, exec_a = _make_in_progress_ctx(sample_agent_with_personality, task_a) + ctx_b, exec_b = _make_in_progress_ctx(sample_agent_with_personality, task_b) + + cp_a = _make_checkpoint(execution_id=ctx_a.execution_id) + cp_b = _make_checkpoint(execution_id=ctx_b.execution_id) + + repo = AsyncMock() + repo.get_latest = AsyncMock( + side_effect=lambda execution_id=None, task_id=None: ( + cp_a if execution_id == ctx_a.execution_id else cp_b + ) + ) + config = CheckpointConfig(max_resume_attempts=5) + strategy = _make_strategy(repo, config=config) + + result_a = await strategy.recover( + task_execution=exec_a, + error_message="crash a", + context=ctx_a, + ) + assert result_a.resume_attempt == 1 + + result_b = await strategy.recover( + task_execution=exec_b, + error_message="crash b", + context=ctx_b, + ) + assert result_b.resume_attempt == 1 # Independent counter diff --git a/tests/unit/engine/test_recovery_checkpoint_fields.py b/tests/unit/engine/test_recovery_checkpoint_fields.py new file mode 100644 index 0000000000..33bc9cdf32 --- /dev/null +++ b/tests/unit/engine/test_recovery_checkpoint_fields.py @@ -0,0 +1,184 @@ +"""Tests for RecoveryResult checkpoint-related fields (can_resume, resume_attempt).""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.core.enums import TaskStatus +from ai_company.engine.context import AgentContext +from ai_company.engine.recovery import ( + FailAndReassignStrategy, + RecoveryResult, +) + +if TYPE_CHECKING: + from ai_company.core.agent import AgentIdentity + from ai_company.core.task import Task + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestCanResumeField: + """RecoveryResult.can_resume computed field.""" + + def test_can_resume_true_when_checkpoint_context_set( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """can_resume is True when checkpoint_context_json is provided.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="starting") + assert ctx.task_execution is not None + + result = RecoveryResult( + task_execution=ctx.task_execution, + strategy_type="checkpoint", + context_snapshot=ctx.to_snapshot(), + error_message="crash", + checkpoint_context_json='{"state": "partial"}', + resume_attempt=1, + ) + + assert result.can_resume is True + + def test_can_resume_false_when_checkpoint_context_none( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """can_resume is False when checkpoint_context_json is None (default).""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="starting") + assert ctx.task_execution is not None + + result = RecoveryResult( + task_execution=ctx.task_execution.with_transition( + TaskStatus.FAILED, reason="crash" + ), + strategy_type="fail_reassign", + context_snapshot=ctx.to_snapshot(), + error_message="crash", + ) + + assert result.can_resume is False + assert result.checkpoint_context_json is None + + +@pytest.mark.unit +class TestCheckpointConsistencyValidator: + """RecoveryResult rejects inconsistent checkpoint_context_json / resume_attempt.""" + + def test_json_set_but_attempt_zero_raises( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Setting checkpoint_context_json without resume_attempt > 0 raises.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="starting") + assert ctx.task_execution is not None + + with pytest.raises(ValueError, match="must be consistent"): + RecoveryResult( + task_execution=ctx.task_execution, + strategy_type="checkpoint", + context_snapshot=ctx.to_snapshot(), + error_message="crash", + checkpoint_context_json='{"state": "partial"}', + ) + + def test_attempt_set_but_json_none_raises( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """Setting resume_attempt > 0 without checkpoint_context_json raises.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="starting") + assert ctx.task_execution is not None + + with pytest.raises(ValueError, match="must be consistent"): + RecoveryResult( + task_execution=ctx.task_execution, + strategy_type="checkpoint", + context_snapshot=ctx.to_snapshot(), + error_message="crash", + resume_attempt=1, + ) + + +@pytest.mark.unit +class TestResumeAttemptDefault: + """RecoveryResult.resume_attempt defaults to 0.""" + + def test_defaults_to_zero( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="starting") + assert ctx.task_execution is not None + + result = RecoveryResult( + task_execution=ctx.task_execution.with_transition( + TaskStatus.FAILED, reason="crash" + ), + strategy_type="fail_reassign", + context_snapshot=ctx.to_snapshot(), + error_message="crash", + ) + + assert result.resume_attempt == 0 + + +@pytest.mark.unit +class TestBackwardCompatibility: + """Existing FailAndReassignStrategy produces can_resume=False.""" + + async def test_fail_and_reassign_has_no_resume( + self, + sample_agent_with_personality: AgentIdentity, + sample_task_with_criteria: Task, + ) -> None: + """FailAndReassignStrategy result has can_resume=False and resume_attempt=0.""" + ctx = AgentContext.from_identity( + sample_agent_with_personality, + task=sample_task_with_criteria, + ) + ctx = ctx.with_task_transition(TaskStatus.IN_PROGRESS, reason="starting") + assert ctx.task_execution is not None + + strategy = FailAndReassignStrategy() + result = await strategy.recover( + task_execution=ctx.task_execution, + error_message="LLM crashed", + context=ctx, + ) + + assert result.can_resume is False + assert result.checkpoint_context_json is None + assert result.resume_attempt == 0 + assert result.strategy_type == "fail_reassign" diff --git a/tests/unit/observability/test_events.py b/tests/unit/observability/test_events.py index 3a3c52b425..1fdce5e805 100644 --- a/tests/unit/observability/test_events.py +++ b/tests/unit/observability/test_events.py @@ -8,6 +8,11 @@ from ai_company.observability import events from ai_company.observability.events.budget import BUDGET_RECORD_ADDED +from ai_company.observability.events.checkpoint import ( + CHECKPOINT_RECOVERY_START, + CHECKPOINT_SAVED, + HEARTBEAT_UPDATED, +) from ai_company.observability.events.classification import ( CLASSIFICATION_COMPLETE, CLASSIFICATION_ERROR, @@ -219,6 +224,7 @@ def test_all_domain_modules_discovered(self) -> None: "workspace", "trust", "promotion", + "checkpoint", } discovered = {info.name for info in pkgutil.iter_modules(events.__path__)} assert discovered == expected @@ -245,6 +251,11 @@ def test_role_events_exist(self) -> None: def test_budget_events_exist(self) -> None: assert BUDGET_RECORD_ADDED == "budget.record.added" + def test_checkpoint_events_exist(self) -> None: + assert CHECKPOINT_SAVED == "checkpoint.saved" + assert HEARTBEAT_UPDATED == "heartbeat.updated" + assert CHECKPOINT_RECOVERY_START == "checkpoint.recovery.start" + def test_execution_events_exist(self) -> None: assert EXECUTION_TASK_CREATED == "execution.task.created" diff --git a/tests/unit/persistence/sqlite/test_checkpoint_repo.py b/tests/unit/persistence/sqlite/test_checkpoint_repo.py new file mode 100644 index 0000000000..4ba9fcbe26 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_checkpoint_repo.py @@ -0,0 +1,238 @@ +"""Tests for SQLiteCheckpointRepository.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.engine.checkpoint.models import Checkpoint +from ai_company.persistence.sqlite.checkpoint_repo import ( + SQLiteCheckpointRepository, +) + +if TYPE_CHECKING: + import aiosqlite + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_checkpoint( # noqa: PLR0913 + *, + checkpoint_id: str = "cp-001", + execution_id: str = "exec-001", + agent_id: str = "agent-001", + task_id: str = "task-001", + turn_number: int = 1, + context_json: str = '{"state": "running"}', +) -> Checkpoint: + return Checkpoint( + id=checkpoint_id, + execution_id=execution_id, + agent_id=agent_id, + task_id=task_id, + turn_number=turn_number, + context_json=context_json, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestSQLiteCheckpointRepository: + async def test_save_and_get_latest_roundtrip( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + cp = _make_checkpoint(checkpoint_id="cp-rt-001") + await repo.save(cp) + + result = await repo.get_latest(execution_id="exec-001") + assert result is not None + assert result.id == cp.id + assert result.execution_id == cp.execution_id + assert result.agent_id == cp.agent_id + assert result.task_id == cp.task_id + assert result.turn_number == cp.turn_number + assert result.context_json == cp.context_json + assert result.created_at == cp.created_at + + async def test_get_latest_returns_highest_turn_number( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + cp_low = _make_checkpoint( + checkpoint_id="cp-low", + turn_number=1, + ) + cp_high = _make_checkpoint( + checkpoint_id="cp-high", + turn_number=5, + ) + cp_mid = _make_checkpoint( + checkpoint_id="cp-mid", + turn_number=3, + ) + # Insert in non-order to confirm DB ordering + await repo.save(cp_mid) + await repo.save(cp_low) + await repo.save(cp_high) + + result = await repo.get_latest(execution_id="exec-001") + assert result is not None + assert result.id == "cp-high" + assert result.turn_number == 5 + + async def test_get_latest_filter_by_task_id( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + cp_a = _make_checkpoint( + checkpoint_id="cp-a", + task_id="task-alpha", + turn_number=3, + ) + cp_b = _make_checkpoint( + checkpoint_id="cp-b", + task_id="task-beta", + turn_number=5, + ) + await repo.save(cp_a) + await repo.save(cp_b) + + result = await repo.get_latest(task_id="task-alpha") + assert result is not None + assert result.id == "cp-a" + assert result.task_id == "task-alpha" + + async def test_get_latest_filter_by_execution_id( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + cp_a = _make_checkpoint( + checkpoint_id="cp-exec-a", + execution_id="exec-alpha", + turn_number=2, + ) + cp_b = _make_checkpoint( + checkpoint_id="cp-exec-b", + execution_id="exec-beta", + turn_number=4, + ) + await repo.save(cp_a) + await repo.save(cp_b) + + result = await repo.get_latest(execution_id="exec-alpha") + assert result is not None + assert result.id == "cp-exec-a" + assert result.execution_id == "exec-alpha" + + async def test_get_latest_both_filters( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + cp_match = _make_checkpoint( + checkpoint_id="cp-match", + execution_id="exec-m", + task_id="task-m", + turn_number=3, + ) + cp_exec_only = _make_checkpoint( + checkpoint_id="cp-exec-only", + execution_id="exec-m", + task_id="task-other", + turn_number=5, + ) + await repo.save(cp_match) + await repo.save(cp_exec_only) + + result = await repo.get_latest(execution_id="exec-m", task_id="task-m") + assert result is not None + assert result.id == "cp-match" + + async def test_get_latest_returns_none_when_no_match( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + result = await repo.get_latest(execution_id="nonexistent") + assert result is None + + async def test_get_latest_raises_when_no_filter( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + with pytest.raises(ValueError, match="At least one"): + await repo.get_latest() + + async def test_upsert_same_id(self, migrated_db: aiosqlite.Connection) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + cp_v1 = _make_checkpoint( + checkpoint_id="cp-upsert", + context_json='{"version": 1}', + turn_number=1, + ) + await repo.save(cp_v1) + + cp_v2 = _make_checkpoint( + checkpoint_id="cp-upsert", + context_json='{"version": 2}', + turn_number=2, + ) + await repo.save(cp_v2) + + result = await repo.get_latest(execution_id="exec-001") + assert result is not None + assert result.id == "cp-upsert" + assert result.context_json == '{"version": 2}' + assert result.turn_number == 2 + + async def test_delete_by_execution_returns_count( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + for i in range(3): + cp = _make_checkpoint( + checkpoint_id=f"cp-del-{i}", + execution_id="exec-to-delete", + turn_number=i, + ) + await repo.save(cp) + + count = await repo.delete_by_execution("exec-to-delete") + assert count == 3 + + result = await repo.get_latest(execution_id="exec-to-delete") + assert result is None + + async def test_delete_by_execution_returns_zero_when_none_exist( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + count = await repo.delete_by_execution("nonexistent") + assert count == 0 + + async def test_delete_by_execution_does_not_affect_other_executions( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteCheckpointRepository(migrated_db) + cp_keep = _make_checkpoint( + checkpoint_id="cp-keep", + execution_id="exec-keep", + ) + cp_delete = _make_checkpoint( + checkpoint_id="cp-delete", + execution_id="exec-delete", + ) + await repo.save(cp_keep) + await repo.save(cp_delete) + + await repo.delete_by_execution("exec-delete") + + assert await repo.get_latest(execution_id="exec-keep") is not None + assert await repo.get_latest(execution_id="exec-delete") is None diff --git a/tests/unit/persistence/sqlite/test_heartbeat_repo.py b/tests/unit/persistence/sqlite/test_heartbeat_repo.py new file mode 100644 index 0000000000..a601dd6d2e --- /dev/null +++ b/tests/unit/persistence/sqlite/test_heartbeat_repo.py @@ -0,0 +1,194 @@ +"""Tests for SQLiteHeartbeatRepository.""" + +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING + +import pytest + +from ai_company.engine.checkpoint.models import Heartbeat +from ai_company.persistence.sqlite.heartbeat_repo import ( + SQLiteHeartbeatRepository, +) + +if TYPE_CHECKING: + import aiosqlite + +pytestmark = pytest.mark.timeout(30) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_heartbeat( + *, + execution_id: str = "exec-001", + agent_id: str = "agent-001", + task_id: str = "task-001", + last_heartbeat_at: datetime | None = None, +) -> Heartbeat: + return Heartbeat( + execution_id=execution_id, + agent_id=agent_id, + task_id=task_id, + last_heartbeat_at=last_heartbeat_at or datetime.now(UTC), + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestSQLiteHeartbeatRepository: + async def test_save_and_get_roundtrip( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + hb = _make_heartbeat(execution_id="exec-hb-001") + await repo.save(hb) + + result = await repo.get("exec-hb-001") + assert result is not None + assert result.execution_id == hb.execution_id + assert result.agent_id == hb.agent_id + assert result.task_id == hb.task_id + assert result.last_heartbeat_at == hb.last_heartbeat_at + + async def test_get_returns_none_for_missing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + result = await repo.get("nonexistent") + assert result is None + + async def test_upsert_updates_existing( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + now = datetime.now(UTC) + later = now + timedelta(minutes=5) + + hb_original = _make_heartbeat( + execution_id="exec-upsert", + last_heartbeat_at=now, + ) + await repo.save(hb_original) + + hb_updated = _make_heartbeat( + execution_id="exec-upsert", + last_heartbeat_at=later, + ) + await repo.save(hb_updated) + + result = await repo.get("exec-upsert") + assert result is not None + assert result.last_heartbeat_at == later + + async def test_get_stale_returns_old_heartbeats( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + now = datetime.now(UTC) + old = now - timedelta(hours=1) + very_old = now - timedelta(hours=2) + + hb_fresh = _make_heartbeat( + execution_id="exec-fresh", + last_heartbeat_at=now, + ) + hb_stale = _make_heartbeat( + execution_id="exec-stale", + last_heartbeat_at=old, + ) + hb_very_stale = _make_heartbeat( + execution_id="exec-very-stale", + last_heartbeat_at=very_old, + ) + await repo.save(hb_fresh) + await repo.save(hb_stale) + await repo.save(hb_very_stale) + + threshold = now - timedelta(minutes=30) + stale = await repo.get_stale(threshold) + + stale_ids = {h.execution_id for h in stale} + assert "exec-stale" in stale_ids + assert "exec-very-stale" in stale_ids + assert "exec-fresh" not in stale_ids + + async def test_get_stale_returns_empty_when_none_stale( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + now = datetime.now(UTC) + hb = _make_heartbeat( + execution_id="exec-fresh", + last_heartbeat_at=now, + ) + await repo.save(hb) + + very_old_threshold = now - timedelta(hours=1) + stale = await repo.get_stale(very_old_threshold) + + # The heartbeat is newer than the threshold, so not stale + assert len(stale) == 0 + + async def test_get_stale_ordered_by_timestamp( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + now = datetime.now(UTC) + t1 = now - timedelta(hours=3) + t2 = now - timedelta(hours=2) + + hb1 = _make_heartbeat( + execution_id="exec-oldest", + last_heartbeat_at=t1, + ) + hb2 = _make_heartbeat( + execution_id="exec-older", + last_heartbeat_at=t2, + ) + # Save in reverse order to verify DB ordering + await repo.save(hb2) + await repo.save(hb1) + + threshold = now - timedelta(hours=1) + stale = await repo.get_stale(threshold) + + assert len(stale) == 2 + assert stale[0].execution_id == "exec-oldest" + assert stale[1].execution_id == "exec-older" + + async def test_delete_returns_true_when_found( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + hb = _make_heartbeat(execution_id="exec-del") + await repo.save(hb) + + assert await repo.delete("exec-del") is True + assert await repo.get("exec-del") is None + + async def test_delete_returns_false_when_not_found( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + assert await repo.delete("nonexistent") is False + + async def test_delete_does_not_affect_other_heartbeats( + self, migrated_db: aiosqlite.Connection + ) -> None: + repo = SQLiteHeartbeatRepository(migrated_db) + hb_keep = _make_heartbeat(execution_id="exec-keep") + hb_delete = _make_heartbeat(execution_id="exec-delete") + await repo.save(hb_keep) + await repo.save(hb_delete) + + await repo.delete("exec-delete") + + assert await repo.get("exec-keep") is not None + assert await repo.get("exec-delete") is None diff --git a/tests/unit/persistence/sqlite/test_migrations_v6.py b/tests/unit/persistence/sqlite/test_migrations_v6.py new file mode 100644 index 0000000000..29084cc2e1 --- /dev/null +++ b/tests/unit/persistence/sqlite/test_migrations_v6.py @@ -0,0 +1,135 @@ +"""Tests for V6 migration (checkpoints and heartbeats tables).""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.persistence.sqlite.migrations import ( + SCHEMA_VERSION, + run_migrations, +) + +if TYPE_CHECKING: + import aiosqlite + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestSchemaVersion: + def test_schema_version_is_six(self) -> None: + assert SCHEMA_VERSION == 6 + + +@pytest.mark.unit +class TestV6MigrationCheckpointsTable: + """V6 migration creates the checkpoints table.""" + + async def test_creates_checkpoints_table( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='checkpoints'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_checkpoints_table_columns( + self, memory_db: aiosqlite.Connection + ) -> None: + """Verify the checkpoints table has the expected columns.""" + await run_migrations(memory_db) + cursor = await memory_db.execute("PRAGMA table_info(checkpoints)") + columns = {row[1] for row in await cursor.fetchall()} + expected = { + "id", + "execution_id", + "agent_id", + "task_id", + "turn_number", + "context_json", + "created_at", + } + assert expected == columns + + +@pytest.mark.unit +class TestV6MigrationHeartbeatsTable: + """V6 migration creates the heartbeats table.""" + + async def test_creates_heartbeats_table( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='heartbeats'" + ) + row = await cursor.fetchone() + assert row is not None + + async def test_heartbeats_table_columns( + self, memory_db: aiosqlite.Connection + ) -> None: + """Verify the heartbeats table has the expected columns.""" + await run_migrations(memory_db) + cursor = await memory_db.execute("PRAGMA table_info(heartbeats)") + columns = {row[1] for row in await cursor.fetchall()} + expected = { + "execution_id", + "agent_id", + "task_id", + "last_heartbeat_at", + } + assert expected == columns + + +@pytest.mark.unit +class TestV6MigrationIndexes: + """V6 migration creates the expected indexes.""" + + async def test_creates_checkpoint_indexes( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name LIKE 'idx_cp_%' ORDER BY name" + ) + indexes = {row[0] for row in await cursor.fetchall()} + expected = { + "idx_cp_execution_id", + "idx_cp_task_id", + "idx_cp_exec_turn", + } + assert expected.issubset(indexes) + + async def test_creates_heartbeat_index( + self, memory_db: aiosqlite.Connection + ) -> None: + await run_migrations(memory_db) + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='index' " + "AND name LIKE 'idx_hb_%' ORDER BY name" + ) + indexes = {row[0] for row in await cursor.fetchall()} + assert "idx_hb_last_heartbeat" in indexes + + +@pytest.mark.unit +class TestV6MigrationIdempotent: + """Running migrations twice is safe.""" + + async def test_idempotent(self, memory_db: aiosqlite.Connection) -> None: + await run_migrations(memory_db) + # Second run should not fail + await run_migrations(memory_db) + + # Tables should still be there + cursor = await memory_db.execute( + "SELECT name FROM sqlite_master WHERE type='table' " + "AND name IN ('checkpoints', 'heartbeats') ORDER BY name" + ) + tables = [row[0] for row in await cursor.fetchall()] + assert "checkpoints" in tables + assert "heartbeats" in tables diff --git a/tests/unit/persistence/test_migrations_v2.py b/tests/unit/persistence/test_migrations_v2.py index 9e80beaf52..ce0355edf7 100644 --- a/tests/unit/persistence/test_migrations_v2.py +++ b/tests/unit/persistence/test_migrations_v2.py @@ -28,9 +28,6 @@ async def memory_db() -> AsyncGenerator[aiosqlite.Connection]: @pytest.mark.unit class TestSchemaMigrations: - async def test_schema_version_is_five(self) -> None: - assert SCHEMA_VERSION == 5 - async def test_fresh_db_creates_all_v2_tables( self, memory_db: aiosqlite.Connection ) -> None: @@ -58,7 +55,7 @@ async def test_v1_to_v2_migration(self, memory_db: aiosqlite.Connection) -> None assert await get_user_version(memory_db) == 1 await run_migrations(memory_db) - assert await get_user_version(memory_db) == 5 + assert await get_user_version(memory_db) == SCHEMA_VERSION cursor = await memory_db.execute( "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" diff --git a/tests/unit/persistence/test_protocol.py b/tests/unit/persistence/test_protocol.py index 7b6db34524..0e87d47109 100644 --- a/tests/unit/persistence/test_protocol.py +++ b/tests/unit/persistence/test_protocol.py @@ -14,7 +14,9 @@ from ai_company.persistence.repositories import ( ApiKeyRepository, AuditRepository, + CheckpointRepository, CostRecordRepository, + HeartbeatRepository, MessageRepository, ParkedContextRepository, TaskRepository, @@ -29,6 +31,7 @@ from ai_company.communication.message import Message from ai_company.core.enums import ApprovalRiskLevel, TaskStatus from ai_company.core.task import Task + from ai_company.engine.checkpoint.models import Checkpoint, Heartbeat from ai_company.hr.models import AgentLifecycleEvent from ai_company.hr.performance.models import ( CollaborationMetricRecord, @@ -70,7 +73,12 @@ async def query( ) -> tuple[CostRecord, ...]: return () - async def aggregate(self, *, agent_id: str | None = None) -> float: + async def aggregate( + self, + *, + agent_id: str | None = None, + task_id: str | None = None, + ) -> float: return 0.0 @@ -203,6 +211,39 @@ async def delete(self, key_id: str) -> bool: return False +class _FakeCheckpointRepository: + async def save(self, checkpoint: Checkpoint) -> None: + pass + + async def get_latest( + self, + *, + execution_id: str | None = None, + task_id: str | None = None, + ) -> Checkpoint | None: + return None + + async def delete_by_execution(self, execution_id: str) -> int: + return 0 + + +class _FakeHeartbeatRepository: + async def save(self, heartbeat: Heartbeat) -> None: + pass + + async def get(self, execution_id: str) -> Heartbeat | None: + return None + + async def get_stale( + self, + threshold: AwareDatetime, + ) -> tuple[Heartbeat, ...]: + return () + + async def delete(self, execution_id: str) -> bool: + return False + + class _FakeBackend: async def connect(self) -> None: pass @@ -264,6 +305,14 @@ def users(self) -> _FakeUserRepository: def api_keys(self) -> _FakeApiKeyRepository: return _FakeApiKeyRepository() + @property + def checkpoints(self) -> _FakeCheckpointRepository: + return _FakeCheckpointRepository() + + @property + def heartbeats(self) -> _FakeHeartbeatRepository: + return _FakeHeartbeatRepository() + async def get_setting(self, key: str) -> str | None: return None @@ -315,3 +364,9 @@ def test_fake_user_repo_is_user_repository(self) -> None: def test_fake_api_key_repo_is_api_key_repository(self) -> None: assert isinstance(_FakeApiKeyRepository(), ApiKeyRepository) + + def test_fake_checkpoint_repo_is_checkpoint_repository(self) -> None: + assert isinstance(_FakeCheckpointRepository(), CheckpointRepository) + + def test_fake_heartbeat_repo_is_heartbeat_repository(self) -> None: + assert isinstance(_FakeHeartbeatRepository(), HeartbeatRepository)