diff --git a/CLAUDE.md b/CLAUDE.md index d98c44cd8a..38e95ab841 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -22,6 +22,7 @@ - Every implementation plan must be **presented to the user** for accept/deny before coding starts - At **every phase** of planning and implementation, be critical — actively look for ways to improve the design in the spirit of what we're building (robustness, correctness, simplicity, future-proofing where it's free) - Surface improvements as suggestions, not silent changes — user decides +- **Prioritize issues by dependency order**, not priority labels — unblocked dependencies come first ## Quick Commands @@ -112,6 +113,8 @@ src/ai_company/ - **Enforced by**: commitizen (commit-msg hook) - **Branches**: `/` from main - **Pre-commit hooks**: trailing-whitespace, end-of-file-fixer, check-yaml, check-toml, check-json, check-merge-conflict, check-added-large-files, no-commit-to-branch (main), ruff check+format, gitleaks +- **GitHub issue queries**: use `gh issue list` via Bash (not MCP tools) — MCP `list_issues` returns `null` for milestone data +- **PR issue references**: preserve existing `Closes #NNN` references — never remove unless explicitly asked ## Post-Implementation (MANDATORY) diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 3d3f5d690b..eb875b351d 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -799,6 +799,20 @@ The agent execution loop defines how an agent processes a task from start to fin > **MVP: ReAct only (Loop 1).** Plan-and-Execute and Hybrid are M4+. Auto-selection is M4+. +#### ExecutionLoop Protocol + +All loop implementations satisfy the `ExecutionLoop` runtime-checkable protocol (defined in `engine/loop_protocol.py`): + +- **`get_loop_type() -> str`** — returns a unique identifier (e.g. `"react"`) +- **`execute(...) 
-> ExecutionResult`** — runs the loop to completion, accepting `AgentContext`, `CompletionProvider`, optional `ToolInvoker`, optional `BudgetChecker`, and optional `CompletionConfig` + +Supporting models: + +- **`TerminationReason`** — enum: `COMPLETED`, `MAX_TURNS`, `BUDGET_EXHAUSTED`, `ERROR` +- **`TurnRecord`** — frozen per-turn stats (tokens, cost, tool calls, finish reason) +- **`ExecutionResult`** — frozen outcome with final context, termination reason, turn records, and optional error message (required when reason is `ERROR`) +- **`BudgetChecker`** — callback type `Callable[[AgentContext], bool]` invoked before each LLM call + #### Loop 1: ReAct (Default for Simple Tasks) A single interleaved loop: the agent reasons about the current state, selects an action (tool call or response), observes the result, and repeats until done or `max_turns` is reached. @@ -814,7 +828,7 @@ A single interleaved loop: the agent reasons about the current state, selects an │ └─────────────────────────┘ │ │ │ │ Terminate when: task complete, max │ -│ turns, budget exhausted, or blocked │ +│ turns, budget exhausted, or error │ └──────────────────────────────────────────┘ ``` @@ -2106,7 +2120,9 @@ ai-company/ │ │ ├── prompt_template.py # System prompt Jinja2 templates │ │ ├── task_execution.py # TaskExecution + StatusTransition │ │ ├── context.py # AgentContext + AgentContextSnapshot -│ │ ├── agent_engine.py # Agent execution loop (M3) +│ │ ├── loop_protocol.py # ExecutionLoop protocol + result models +│ │ ├── react_loop.py # ReAct loop implementation +│ │ ├── agent_engine.py # Agent execution engine (M3) │ │ ├── task_engine.py # Task routing & scheduling (M3-M4) │ │ ├── workflow_engine.py # Workflow orchestration (M4) │ │ ├── meeting_engine.py # Meeting coordination (M4) diff --git a/src/ai_company/engine/__init__.py b/src/ai_company/engine/__init__.py index eae31a64e8..25df804937 100644 --- a/src/ai_company/engine/__init__.py +++ b/src/ai_company/engine/__init__.py @@ -1,7 +1,7 @@ 
"""Agent execution engine. Re-exports the public API for system prompt construction, -runtime execution state, and engine errors. +runtime execution state, execution loops, and engine errors. """ from ai_company.engine.context import ( @@ -10,17 +10,27 @@ AgentContextSnapshot, ) from ai_company.engine.errors import ( + BudgetExhaustedError, EngineError, ExecutionStateError, + LoopExecutionError, MaxTurnsExceededError, PromptBuildError, ) +from ai_company.engine.loop_protocol import ( + BudgetChecker, + ExecutionLoop, + ExecutionResult, + TerminationReason, + TurnRecord, +) from ai_company.engine.prompt import ( DefaultTokenEstimator, PromptTokenEstimator, SystemPrompt, build_system_prompt, ) +from ai_company.engine.react_loop import ReactLoop from ai_company.engine.task_execution import StatusTransition, TaskExecution from ai_company.providers.models import ZERO_TOKEN_USAGE, add_token_usage @@ -29,15 +39,23 @@ "ZERO_TOKEN_USAGE", "AgentContext", "AgentContextSnapshot", + "BudgetChecker", + "BudgetExhaustedError", "DefaultTokenEstimator", "EngineError", + "ExecutionLoop", + "ExecutionResult", "ExecutionStateError", + "LoopExecutionError", "MaxTurnsExceededError", "PromptBuildError", "PromptTokenEstimator", + "ReactLoop", "StatusTransition", "SystemPrompt", "TaskExecution", + "TerminationReason", + "TurnRecord", "add_token_usage", "build_system_prompt", ] diff --git a/src/ai_company/engine/errors.py b/src/ai_company/engine/errors.py index aa69661ea8..00058ab891 100644 --- a/src/ai_company/engine/errors.py +++ b/src/ai_company/engine/errors.py @@ -19,3 +19,21 @@ class MaxTurnsExceededError(EngineError): Enforced by ``AgentContext.with_turn_completed`` when the hard turn limit has been reached. """ + + +class BudgetExhaustedError(EngineError): + """Budget exhaustion signal for the engine layer. + + The execution loop returns ``TerminationReason.BUDGET_EXHAUSTED`` + internally. 
This exception is available for the engine layer above + the loop to convert that result into a raised error when appropriate. + """ + + +class LoopExecutionError(EngineError): + """Non-recoverable execution loop error for the engine layer. + + The execution loop returns ``TerminationReason.ERROR`` internally. + This exception is available for the engine layer above the loop to + convert that result into a raised error when appropriate. + """ diff --git a/src/ai_company/engine/loop_protocol.py b/src/ai_company/engine/loop_protocol.py new file mode 100644 index 0000000000..4a682aa587 --- /dev/null +++ b/src/ai_company/engine/loop_protocol.py @@ -0,0 +1,161 @@ +"""Execution loop protocol and supporting models. + +Defines the ``ExecutionLoop`` protocol that the agent engine calls to +run a task, along with ``ExecutionResult``, ``TurnRecord``, +``TerminationReason``, and the ``BudgetChecker`` type alias. +""" + +from collections.abc import Callable +from enum import StrEnum +from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable + +from pydantic import BaseModel, ConfigDict, Field, computed_field, model_validator + +from ai_company.core.types import NotBlankStr # noqa: TC001 +from ai_company.engine.context import AgentContext +from ai_company.providers.enums import FinishReason # noqa: TC001 + +if TYPE_CHECKING: + from ai_company.providers.models import CompletionConfig + from ai_company.providers.protocol import CompletionProvider + from ai_company.tools.invoker import ToolInvoker + + +class TerminationReason(StrEnum): + """Why the execution loop terminated.""" + + COMPLETED = "completed" + MAX_TURNS = "max_turns" + BUDGET_EXHAUSTED = "budget_exhausted" + ERROR = "error" + + +class TurnRecord(BaseModel): + """Per-turn metadata recorded during execution. + + Attributes: + turn_number: 1-indexed turn number. + input_tokens: Input tokens consumed this turn. + output_tokens: Output tokens generated this turn. 
class ExecutionResult(BaseModel):
    """Immutable outcome of one execution-loop run.

    Attributes:
        context: Agent context as it stood when the loop stopped.
        termination_reason: The reason the loop stopped.
        turns: One ``TurnRecord`` per recorded turn.
        total_tool_calls: Number of tool calls summed over ``turns``
            (computed, not stored).
        error_message: Human-readable failure description. Must be set
            when ``termination_reason`` is ``ERROR`` and must be
            ``None`` for every other reason.
        metadata: Free-form dict reserved for future loop types.
            ``frozen=True`` blocks reassigning the field but does not
            stop in-place mutation of the dict itself.
    """

    model_config = ConfigDict(frozen=True)

    context: AgentContext = Field(description="Final agent context")
    termination_reason: TerminationReason = Field(
        description="Why the loop stopped",
    )
    turns: tuple[TurnRecord, ...] = Field(
        default=(),
        description="Per-turn metadata",
    )
    error_message: str | None = Field(
        default=None,
        description="Error description (when reason is ERROR)",
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Forward-compatible metadata for future loop types",
    )

    @computed_field(  # type: ignore[prop-decorator]
        description="Total tool calls across all turns",
    )
    @property
    def total_tool_calls(self) -> int:
        """Count the tool calls recorded across every turn."""
        count = 0
        for record in self.turns:
            count += len(record.tool_calls_made)
        return count

    @model_validator(mode="after")
    def _validate_error_message(self) -> Self:
        """Enforce that ``error_message`` is present exactly for ERROR results."""
        is_error = self.termination_reason == TerminationReason.ERROR
        has_message = self.error_message is not None
        if is_error and not has_message:
            msg = "error_message is required when termination_reason is ERROR"
            raise ValueError(msg)
        if has_message and not is_error:
            msg = "error_message must be None when termination_reason is not ERROR"
            raise ValueError(msg)
        return self
class ReactLoop:
    """Execution loop implementing the ReAct (reason / act / observe) cycle.

    Every iteration consults the optional budget checker, asks the
    provider for a completion, records the turn, and then either
    finishes (no tool calls), fails (error-style finish reason, provider
    or tool exception, missing invoker), or runs the requested tools and
    appends their results to the conversation. Iteration ends when the
    context reports that no turns remain.
    """

    def get_loop_type(self) -> str:
        """Return the identifier of this loop implementation."""
        return "react"

    async def execute(
        self,
        *,
        context: AgentContext,
        provider: CompletionProvider,
        tool_invoker: ToolInvoker | None = None,
        budget_checker: BudgetChecker | None = None,
        completion_config: CompletionConfig | None = None,
    ) -> ExecutionResult:
        """Drive the ReAct loop until one of its stop conditions is met.

        Expected failure modes (budget exhaustion, error finish reasons,
        exceptions from the provider, the tool invoker, or the budget
        checker, and tool calls without an invoker) come back as an
        ``ExecutionResult`` carrying the matching ``TerminationReason``.
        ``MemoryError`` and ``RecursionError`` are never captured; they
        propagate to the caller.

        Args:
            context: Starting agent context, including the conversation.
            provider: LLM completion provider to call each turn.
            tool_invoker: Executes tool calls; when ``None`` no tool
                definitions are sent and any requested tool call ends
                the run with an error result.
            budget_checker: Optional callback queried before each
                provider call; a truthy return ends the run.
            completion_config: Optional config override. When not
                given, temperature and max_tokens are taken from the
                identity's model config.

        Returns:
            The final context, termination reason, and turn records.
        """
        logger.info(
            EXECUTION_LOOP_START,
            execution_id=context.execution_id,
            loop_type=self.get_loop_type(),
            max_turns=context.max_turns,
        )
        model_id = context.identity.model.model_id
        config = completion_config
        if not config:
            # No override supplied: derive the config from the identity.
            config = CompletionConfig(
                temperature=context.identity.model.temperature,
                max_tokens=context.identity.model.max_tokens,
            )
        tool_defs = _get_tool_definitions(tool_invoker)
        history: list[TurnRecord] = []
        current = context

        while current.has_turns_remaining:
            stop = self._check_budget(current, budget_checker, history)
            if stop is not None:
                return stop

            turn_no = current.turn_count + 1
            outcome = await self._call_provider(
                current,
                provider,
                model_id,
                tool_defs,
                config,
                turn_no,
                history,
            )
            if isinstance(outcome, ExecutionResult):
                # Provider raised; the helper already built the error result.
                return outcome
            response = outcome

            # Record the turn before the error check so that a failing
            # turn still shows up in the result's turn records.
            history.append(_make_turn_record(turn_no, response))

            failure = self._check_response_errors(
                current,
                response,
                turn_no,
                history,
            )
            if failure is not None:
                return failure

            current = current.with_turn_completed(
                response.usage,
                _response_to_message(response),
            )
            logger.info(
                EXECUTION_LOOP_TURN_COMPLETE,
                execution_id=current.execution_id,
                turn=turn_no,
                finish_reason=response.finish_reason.value,
                tool_call_count=len(response.tool_calls),
            )

            if not response.tool_calls:
                return self._handle_completion(current, response, history)

            after_tools = await self._execute_tool_calls(
                current,
                tool_invoker,
                response,
                turn_no,
                history,
            )
            if isinstance(after_tools, ExecutionResult):
                return after_tools
            current = after_tools

        logger.info(
            EXECUTION_LOOP_TERMINATED,
            execution_id=current.execution_id,
            reason=TerminationReason.MAX_TURNS.value,
            turns=len(history),
        )
        return _build_result(current, TerminationReason.MAX_TURNS, history)

    def _check_budget(
        self,
        ctx: AgentContext,
        budget_checker: BudgetChecker | None,
        turns: list[TurnRecord],
    ) -> ExecutionResult | None:
        """Query the budget checker and translate its answer.

        Returns ``None`` when there is no checker or the budget is still
        available, a ``BUDGET_EXHAUSTED`` result when the checker returns
        a truthy value, and an ``ERROR`` result when the checker raises.
        """
        if budget_checker is None:
            return None

        try:
            is_exhausted = budget_checker(ctx)
        except (MemoryError, RecursionError):
            # Non-recoverable conditions are never folded into a result.
            raise
        except Exception as exc:
            error_msg = f"Budget checker failed: {type(exc).__name__}: {exc}"
            logger.exception(
                EXECUTION_LOOP_ERROR,
                execution_id=ctx.execution_id,
                turn=ctx.turn_count,
                error=error_msg,
            )
            return _build_result(
                ctx,
                TerminationReason.ERROR,
                turns,
                error_message=error_msg,
            )

        if not is_exhausted:
            return None

        logger.warning(
            EXECUTION_LOOP_BUDGET_EXHAUSTED,
            execution_id=ctx.execution_id,
            turn=ctx.turn_count,
        )
        return _build_result(ctx, TerminationReason.BUDGET_EXHAUSTED, turns)

    async def _call_provider(  # noqa: PLR0913
        self,
        ctx: AgentContext,
        provider: CompletionProvider,
        model_id: str,
        tool_defs: list[ToolDefinition] | None,
        config: CompletionConfig,
        turn_number: int,
        turns: list[TurnRecord],
    ) -> CompletionResponse | ExecutionResult:
        """Request one completion from the provider.

        Returns the provider's response on success. Any exception other
        than ``MemoryError`` / ``RecursionError`` is logged and turned
        into an ``ERROR`` result built from ``ctx`` and ``turns``.
        """
        logger.debug(
            EXECUTION_LOOP_TURN_START,
            execution_id=ctx.execution_id,
            turn=turn_number,
        )
        try:
            response = await provider.complete(
                # Pass a fresh list built from the context's conversation.
                messages=list(ctx.conversation),
                model=model_id,
                tools=tool_defs,
                config=config,
            )
        except (MemoryError, RecursionError):
            raise
        except Exception as exc:
            error_msg = (
                f"Provider error on turn {turn_number}: {type(exc).__name__}: {exc}"
            )
            logger.exception(
                EXECUTION_LOOP_ERROR,
                execution_id=ctx.execution_id,
                turn=turn_number,
                error=error_msg,
            )
            return _build_result(
                ctx,
                TerminationReason.ERROR,
                turns,
                error_message=error_msg,
            )
        else:
            return response

    def _check_response_errors(
        self,
        ctx: AgentContext,
        response: CompletionResponse,
        turn_number: int,
        turns: list[TurnRecord],
    ) -> ExecutionResult | None:
        """Turn a CONTENT_FILTER or ERROR finish reason into an error result.

        Returns ``None`` for every other finish reason. For a failing
        response the returned context is a copy of ``ctx`` with the turn
        count incremented and the response's token usage added to the
        accumulated cost, so totals include the failing turn.
        """
        failing_reasons = (FinishReason.CONTENT_FILTER, FinishReason.ERROR)
        if response.finish_reason in failing_reasons:
            error_msg = (
                f"LLM returned {response.finish_reason.value} on turn {turn_number}"
            )
            logger.error(
                EXECUTION_LOOP_ERROR,
                execution_id=ctx.execution_id,
                turn=turn_number,
                error=error_msg,
            )
            new_cost = add_token_usage(ctx.accumulated_cost, response.usage)
            failed_ctx = ctx.model_copy(
                update={
                    "turn_count": ctx.turn_count + 1,
                    "accumulated_cost": new_cost,
                },
            )
            return _build_result(
                failed_ctx,
                TerminationReason.ERROR,
                turns,
                error_message=error_msg,
            )
        return None

    def _handle_completion(
        self,
        ctx: AgentContext,
        response: CompletionResponse,
        turns: list[TurnRecord],
    ) -> ExecutionResult:
        """Finish the run for a response that carries no tool calls.

        A ``TOOL_USE`` finish reason without tool calls is reported as
        an ``ERROR`` result. Every other finish reason yields
        ``COMPLETED``; ``MAX_TOKENS`` is additionally logged at warning
        level with ``truncated=True``.
        """
        reason = response.finish_reason

        if reason == FinishReason.TOOL_USE:
            error_msg = (
                f"Provider returned TOOL_USE with no tool calls on turn {ctx.turn_count}"
            )
            logger.error(
                EXECUTION_LOOP_ERROR,
                execution_id=ctx.execution_id,
                turn=ctx.turn_count,
                error=error_msg,
            )
            return _build_result(
                ctx,
                TerminationReason.ERROR,
                turns,
                error_message=error_msg,
            )

        if reason == FinishReason.MAX_TOKENS:
            # Output was cut off by the token limit; still a completion.
            logger.warning(
                EXECUTION_LOOP_TERMINATED,
                execution_id=ctx.execution_id,
                reason=TerminationReason.COMPLETED.value,
                turns=len(turns),
                truncated=True,
            )
        else:
            logger.info(
                EXECUTION_LOOP_TERMINATED,
                execution_id=ctx.execution_id,
                reason=TerminationReason.COMPLETED.value,
                turns=len(turns),
            )
        return _build_result(ctx, TerminationReason.COMPLETED, turns)

    async def _execute_tool_calls(
        self,
        ctx: AgentContext,
        tool_invoker: ToolInvoker | None,
        response: CompletionResponse,
        turn_number: int,
        turns: list[TurnRecord],
    ) -> AgentContext | ExecutionResult:
        """Run the response's tool calls and append their results.

        Returns the context extended with one TOOL message per result.
        Returns an ``ERROR`` result instead when no invoker is available
        or when ``invoke_all`` raises anything other than
        ``MemoryError`` / ``RecursionError``.
        """
        requested = response.tool_calls

        if tool_invoker is None:
            error_msg = (
                f"LLM requested {len(requested)} tool "
                "call(s) but no tool invoker is available"
            )
            logger.error(
                EXECUTION_LOOP_ERROR,
                execution_id=ctx.execution_id,
                turn=turn_number,
                error=error_msg,
            )
            return _build_result(
                ctx,
                TerminationReason.ERROR,
                turns,
                error_message=error_msg,
            )

        tool_names = [call.name for call in requested]
        logger.info(
            EXECUTION_LOOP_TOOL_CALLS,
            execution_id=ctx.execution_id,
            turn=turn_number,
            tools=tool_names,
        )

        try:
            results = await tool_invoker.invoke_all(requested)
        except (MemoryError, RecursionError):
            raise
        except Exception as exc:
            error_msg = (
                f"Tool execution failed on turn {turn_number}: "
                f"{type(exc).__name__}: {exc}"
            )
            logger.exception(
                EXECUTION_LOOP_ERROR,
                execution_id=ctx.execution_id,
                turn=turn_number,
                error=error_msg,
                tools=tool_names,
            )
            return _build_result(
                ctx,
                TerminationReason.ERROR,
                turns,
                error_message=error_msg,
            )

        # The context is immutable: each with_message returns a new one.
        updated = ctx
        for tool_result in results:
            updated = updated.with_message(
                ChatMessage(role=MessageRole.TOOL, tool_result=tool_result),
            )
        return updated
def _build_result(
    ctx: AgentContext,
    reason: TerminationReason,
    turns: list[TurnRecord],
    *,
    error_message: str | None = None,
) -> ExecutionResult:
    """Assemble the ``ExecutionResult`` for the current loop state.

    Args:
        ctx: Context to report as the final context.
        reason: Why the loop is stopping.
        turns: Turn records gathered so far; copied into a tuple so the
            result does not share the caller's mutable list.
        error_message: Failure description, passed through unchanged.

    Returns:
        The constructed ``ExecutionResult``.
    """
    frozen_turns = tuple(turns)
    return ExecutionResult(
        termination_reason=reason,
        context=ctx,
        turns=frozen_turns,
        error_message=error_message,
    )
class MockCompletionProvider:
    """Scripted stand-in for ``CompletionProvider`` used in tests.

    Each ``complete()`` call hands back the next response from the list
    given at construction time. Once the list is used up, further calls
    raise ``IndexError``.
    """

    def __init__(self, responses: list[CompletionResponse]) -> None:
        # Copy so that popping never mutates the caller's list.
        self._responses = list(responses)
        self._call_count = 0
        self._recorded_configs: list[CompletionConfig | None] = []

    @property
    def call_count(self) -> int:
        """How many ``complete()`` calls returned a response."""
        return self._call_count

    @property
    def recorded_configs(self) -> list[CompletionConfig | None]:
        """Copy of the ``config`` argument seen by each successful call."""
        return list(self._recorded_configs)

    async def complete(
        self,
        messages: list[ChatMessage],
        model: str,
        *,
        tools: list[ToolDefinition] | None = None,
        config: CompletionConfig | None = None,
    ) -> CompletionResponse:
        """Hand back the next scripted response and record the call."""
        if self._responses:
            self._call_count += 1
            self._recorded_configs.append(config)
            return self._responses.pop(0)
        # Exhausted calls are neither counted nor recorded.
        msg = "MockCompletionProvider: no more responses"
        raise IndexError(msg)

    async def stream(
        self,
        messages: list[ChatMessage],
        model: str,
        *,
        tools: list[ToolDefinition] | None = None,
        config: CompletionConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Streaming is not supported by this double; always raises."""
        msg = "MockCompletionProvider.stream() is not implemented"
        raise NotImplementedError(msg)

    async def get_model_capabilities(self, model: str) -> ModelCapabilities:
        """Return a fixed, minimal capability description for ``model``."""
        return ModelCapabilities(
            model_id=model,
            provider="test-provider",
            supports_tools=True,
            supports_streaming=False,
            max_context_tokens=8192,
            max_output_tokens=4096,
            cost_per_1k_input=0.01,
            cost_per_1k_output=0.03,
        )
@pytest.mark.unit
class TestTerminationReason:
    """Checks the members of the ``TerminationReason`` enum."""

    def test_values(self) -> None:
        """Each member maps to its documented string value."""
        expected = {
            TerminationReason.COMPLETED: "completed",
            TerminationReason.MAX_TURNS: "max_turns",
            TerminationReason.BUDGET_EXHAUSTED: "budget_exhausted",
            TerminationReason.ERROR: "error",
        }
        for member, value in expected.items():
            assert member.value == value

    def test_member_count(self) -> None:
        """The enum has exactly four members."""
        assert len(TerminationReason) == 4
test_creation(self) -> None: + record = TurnRecord( + turn_number=1, + input_tokens=100, + output_tokens=50, + cost_usd=0.01, + tool_calls_made=("search",), + finish_reason=FinishReason.TOOL_USE, + ) + assert record.turn_number == 1 + assert record.input_tokens == 100 + assert record.output_tokens == 50 + assert record.cost_usd == 0.01 + assert record.tool_calls_made == ("search",) + assert record.finish_reason == FinishReason.TOOL_USE + + def test_frozen(self) -> None: + record = TurnRecord( + turn_number=1, + input_tokens=100, + output_tokens=50, + cost_usd=0.01, + finish_reason=FinishReason.STOP, + ) + with pytest.raises(ValidationError): + record.turn_number = 2 # type: ignore[misc] + + def test_defaults(self) -> None: + record = TurnRecord( + turn_number=1, + input_tokens=0, + output_tokens=0, + cost_usd=0.0, + finish_reason=FinishReason.STOP, + ) + assert record.tool_calls_made == () + + def test_total_tokens_computed(self) -> None: + record = TurnRecord( + turn_number=1, + input_tokens=100, + output_tokens=50, + cost_usd=0.01, + finish_reason=FinishReason.STOP, + ) + assert record.total_tokens == 150 + + def test_total_tokens_zero(self) -> None: + record = TurnRecord( + turn_number=1, + input_tokens=0, + output_tokens=0, + cost_usd=0.0, + finish_reason=FinishReason.STOP, + ) + assert record.total_tokens == 0 + + def test_turn_number_zero_rejected(self) -> None: + with pytest.raises(ValidationError): + TurnRecord( + turn_number=0, + input_tokens=10, + output_tokens=5, + cost_usd=0.01, + finish_reason=FinishReason.STOP, + ) + + def test_negative_input_tokens_rejected(self) -> None: + with pytest.raises(ValidationError): + TurnRecord( + turn_number=1, + input_tokens=-1, + output_tokens=5, + cost_usd=0.01, + finish_reason=FinishReason.STOP, + ) + + def test_negative_cost_rejected(self) -> None: + with pytest.raises(ValidationError): + TurnRecord( + turn_number=1, + input_tokens=10, + output_tokens=5, + cost_usd=-0.01, + finish_reason=FinishReason.STOP, + ) + + 
+@pytest.mark.unit +class TestExecutionResult: + """ExecutionResult frozen model.""" + + def test_creation( + self, + sample_agent_context: AgentContext, + ) -> None: + result = ExecutionResult( + context=sample_agent_context, + termination_reason=TerminationReason.COMPLETED, + turns=(), + ) + assert result.termination_reason == TerminationReason.COMPLETED + assert result.total_tool_calls == 0 + assert result.error_message is None + assert result.metadata == {} + + def test_with_error( + self, + sample_agent_context: AgentContext, + ) -> None: + result = ExecutionResult( + context=sample_agent_context, + termination_reason=TerminationReason.ERROR, + turns=(), + error_message="something went wrong", + ) + assert result.error_message == "something went wrong" + + def test_with_metadata( + self, + sample_agent_context: AgentContext, + ) -> None: + result = ExecutionResult( + context=sample_agent_context, + termination_reason=TerminationReason.COMPLETED, + turns=(), + metadata={"plan": "step1"}, + ) + assert result.metadata == {"plan": "step1"} + + def test_frozen( + self, + sample_agent_context: AgentContext, + ) -> None: + result = ExecutionResult( + context=sample_agent_context, + termination_reason=TerminationReason.COMPLETED, + turns=(), + ) + with pytest.raises(ValidationError): + result.termination_reason = TerminationReason.ERROR # type: ignore[misc] + + def test_total_tool_calls_computed( + self, + sample_agent_context: AgentContext, + ) -> None: + turns = ( + TurnRecord( + turn_number=1, + input_tokens=10, + output_tokens=5, + cost_usd=0.001, + tool_calls_made=("search", "read"), + finish_reason=FinishReason.TOOL_USE, + ), + TurnRecord( + turn_number=2, + input_tokens=10, + output_tokens=5, + cost_usd=0.001, + tool_calls_made=("write",), + finish_reason=FinishReason.STOP, + ), + ) + result = ExecutionResult( + context=sample_agent_context, + termination_reason=TerminationReason.COMPLETED, + turns=turns, + ) + assert result.total_tool_calls == 3 + + def 
test_error_message_required_when_error( + self, + sample_agent_context: AgentContext, + ) -> None: + with pytest.raises( + ValidationError, + match="error_message is required", + ): + ExecutionResult( + context=sample_agent_context, + termination_reason=TerminationReason.ERROR, + turns=(), + ) + + def test_error_message_forbidden_when_not_error( + self, + sample_agent_context: AgentContext, + ) -> None: + with pytest.raises( + ValidationError, + match="error_message must be None", + ): + ExecutionResult( + context=sample_agent_context, + termination_reason=TerminationReason.COMPLETED, + turns=(), + error_message="unexpected", + ) + + +@pytest.mark.unit +class TestProtocolConformance: + """ReactLoop satisfies ExecutionLoop protocol.""" + + def test_react_loop_is_execution_loop(self) -> None: + loop = ReactLoop() + assert isinstance(loop, ExecutionLoop) + + def test_react_loop_type(self) -> None: + loop = ReactLoop() + assert loop.get_loop_type() == "react" diff --git a/tests/unit/engine/test_react_loop.py b/tests/unit/engine/test_react_loop.py new file mode 100644 index 0000000000..5c14f2b165 --- /dev/null +++ b/tests/unit/engine/test_react_loop.py @@ -0,0 +1,876 @@ +"""Tests for the ReAct execution loop.""" + +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ai_company.core.agent import AgentIdentity # noqa: TC001 +from ai_company.engine.context import AgentContext +from ai_company.engine.loop_protocol import TerminationReason +from ai_company.engine.react_loop import ReactLoop +from ai_company.providers.enums import FinishReason, MessageRole +from ai_company.providers.models import ( + ChatMessage, + CompletionConfig, + CompletionResponse, + TokenUsage, + ToolCall, +) +from ai_company.tools.base import BaseTool, ToolExecutionResult +from ai_company.tools.invoker import ToolInvoker +from ai_company.tools.registry import ToolRegistry + +if TYPE_CHECKING: + from .conftest import MockCompletionProvider + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _usage(input_tokens: int = 10, output_tokens: int = 5) -> TokenUsage: + return TokenUsage( + input_tokens=input_tokens, + output_tokens=output_tokens, + cost_usd=0.001, + ) + + +def _stop_response(content: str = "Done.") -> CompletionResponse: + return CompletionResponse( + content=content, + finish_reason=FinishReason.STOP, + usage=_usage(), + model="test-model-001", + ) + + +def _tool_use_response( + tool_name: str = "echo", + tool_call_id: str = "tc-1", + arguments: dict[str, Any] | None = None, +) -> CompletionResponse: + return CompletionResponse( + content=None, + tool_calls=( + ToolCall( + id=tool_call_id, + name=tool_name, + arguments=arguments or {}, + ), + ), + finish_reason=FinishReason.TOOL_USE, + usage=_usage(), + model="test-model-001", + ) + + +def _content_filter_response() -> CompletionResponse: + return CompletionResponse( + content=None, + finish_reason=FinishReason.CONTENT_FILTER, + usage=_usage(), + model="test-model-001", + ) + + +def _error_response() -> CompletionResponse: + return CompletionResponse( + content=None, + finish_reason=FinishReason.ERROR, + usage=_usage(), + model="test-model-001", + ) + + +class _StubTool(BaseTool): + """Minimal tool for testing.""" + + def __init__(self, name: str = "echo") -> None: + super().__init__( + name=name, + description="Test echo tool", + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult( + content=f"echoed: {arguments}", + is_error=False, + ) + + +def _make_invoker(*tool_names: str) -> ToolInvoker: + tools = [_StubTool(name=n) for n in tool_names] + return ToolInvoker(ToolRegistry(tools)) + + +def _ctx_with_user_msg(ctx: AgentContext) -> AgentContext: + """Add a user message so the conversation is non-empty.""" + msg = ChatMessage(role=MessageRole.USER, 
content="Do something") + return ctx.with_message(msg) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.unit +class TestReactLoopBasicCompletion: + """LLM returns STOP on turn 1, no tools.""" + + async def test_single_turn_completion( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([_stop_response("All done.")]) + loop = ReactLoop() + + result = await loop.execute( + context=ctx, + provider=provider, + ) + + assert result.termination_reason == TerminationReason.COMPLETED + assert len(result.turns) == 1 + assert result.total_tool_calls == 0 + assert result.error_message is None + assert result.turns[0].turn_number == 1 + assert result.turns[0].finish_reason == FinishReason.STOP + assert result.turns[0].tool_calls_made == () + + async def test_context_has_assistant_message( + self, + sample_agent_context: AgentContext, + mock_provider_factory: type[MockCompletionProvider], + ) -> None: + ctx = _ctx_with_user_msg(sample_agent_context) + provider = mock_provider_factory([_stop_response("Hello!")]) + loop = ReactLoop() + + result = await loop.execute(context=ctx, provider=provider) + + # Conversation should have: user msg + assistant msg + assert len(result.context.conversation) == 2 + last_msg = result.context.conversation[-1] + assert last_msg.role == MessageRole.ASSISTANT + assert last_msg.content == "Hello!" 
@pytest.mark.unit
class TestReactLoopToolCalls:
    """The model asks for tools on early turns, then finishes with STOP."""

    async def test_single_tool_call_then_complete(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A TOOL_USE turn followed by a STOP turn yields two turn records."""
        scripted = [
            _tool_use_response("echo", "tc-1"),
            _stop_response("Done after tool."),
        ]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        assert outcome.termination_reason == TerminationReason.COMPLETED
        assert len(outcome.turns) == 2
        assert outcome.total_tool_calls == 1
        tool_turn = outcome.turns[0]
        closing_turn = outcome.turns[1]
        assert tool_turn.tool_calls_made == ("echo",)
        assert tool_turn.finish_reason == FinishReason.TOOL_USE
        assert closing_turn.finish_reason == FinishReason.STOP

    async def test_multi_turn_tool_calls(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """Three consecutive tool turns and a final STOP give four turns."""
        tool_turns = [
            _tool_use_response("echo", call_id)
            for call_id in ("tc-1", "tc-2", "tc-3")
        ]
        scripted = [*tool_turns, _stop_response("Finally done.")]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        assert outcome.termination_reason == TerminationReason.COMPLETED
        assert len(outcome.turns) == 4
        assert outcome.total_tool_calls == 3

    async def test_tool_results_in_conversation(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """The tool result is appended between the two assistant messages."""
        scripted = [
            _tool_use_response("echo", "tc-1"),
            _stop_response("Done."),
        ]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        # Expected order: user, assistant(tool_use), tool(result), assistant(stop)
        history = outcome.context.conversation
        assert [entry.role for entry in history] == [
            MessageRole.USER,
            MessageRole.ASSISTANT,
            MessageRole.TOOL,
            MessageRole.ASSISTANT,
        ]
        tool_entry = history[2]
        assert tool_entry.tool_result is not None
        assert tool_entry.tool_result.tool_call_id == "tc-1"


@pytest.mark.unit
class TestReactLoopMaxTurns:
    """The loop stops once the context's turn limit is used up."""

    async def test_max_turns_termination(
        self,
        sample_agent_with_personality: AgentIdentity,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """With ``max_turns=2`` and two tool turns, the reason is MAX_TURNS."""
        capped = AgentContext.from_identity(
            sample_agent_with_personality,
            max_turns=2,
        )
        # Every scripted response asks for a tool, so the model never stops.
        scripted = [
            _tool_use_response("echo", "tc-1"),
            _tool_use_response("echo", "tc-2"),
        ]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(capped),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        assert outcome.termination_reason == TerminationReason.MAX_TURNS
        assert len(outcome.turns) == 2
        assert outcome.context.turn_count == 2


@pytest.mark.unit
class TestReactLoopBudgetExhausted:
    """A budget checker returning ``True`` ends the run."""

    async def test_budget_exhausted_before_first_turn(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """An always-exhausted budget means the provider is never called."""

        def always_exhausted(_ctx: AgentContext) -> bool:
            return True

        llm = mock_provider_factory([])

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=llm,
            budget_checker=always_exhausted,
        )

        assert outcome.termination_reason == TerminationReason.BUDGET_EXHAUSTED
        assert not outcome.turns
        assert outcome.total_tool_calls == 0
        assert llm.call_count == 0

    async def test_budget_exhausted_after_first_turn(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A budget that runs out on the second check allows exactly one turn."""
        checked: list[AgentContext] = []

        def budget_check(_ctx: AgentContext) -> bool:
            # Reports exhaustion from the second check on (after the first turn).
            checked.append(_ctx)
            return len(checked) > 1

        scripted = [_tool_use_response("echo", "tc-1")]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
            budget_checker=budget_check,
        )

        assert outcome.termination_reason == TerminationReason.BUDGET_EXHAUSTED
        assert len(outcome.turns) == 1


@pytest.mark.unit
class TestReactLoopNoToolInvoker:
    """The model requests a tool while no invoker was supplied."""

    async def test_error_when_tools_requested_without_invoker(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """The run ends with ERROR and a message naming the missing invoker."""
        scripted = [_tool_use_response("echo", "tc-1")]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=None,
        )

        assert outcome.termination_reason == TerminationReason.ERROR
        failure_text = outcome.error_message
        assert failure_text is not None
        assert "no tool invoker" in failure_text


@pytest.mark.unit
class TestReactLoopErrorResponses:
    """Responses finishing with ERROR or CONTENT_FILTER end the run."""

    async def test_content_filter_terminates_with_error(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A CONTENT_FILTER finish reason is reported in the error message."""
        llm = mock_provider_factory([_content_filter_response()])

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=llm,
        )

        assert outcome.termination_reason == TerminationReason.ERROR
        failure_text = outcome.error_message
        assert failure_text is not None
        assert "content_filter" in failure_text

    async def test_error_finish_reason_terminates(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """An ERROR finish reason ends the run with an error message."""
        llm = mock_provider_factory([_error_response()])

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=llm,
        )

        assert outcome.termination_reason == TerminationReason.ERROR
        failure_text = outcome.error_message
        assert failure_text is not None
        assert "error" in failure_text


@pytest.mark.unit
class TestReactLoopTurnRecords:
    """Per-turn records mirror what each response reported."""

    async def test_turn_record_accuracy(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """Turn number, usage, tool names and finish reason are recorded."""
        scripted = [
            _tool_use_response("echo", "tc-1"),
            _stop_response("Done."),
        ]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        assert len(outcome.turns) == 2

        opening = outcome.turns[0]
        assert opening.turn_number == 1
        assert opening.input_tokens == 10
        assert opening.output_tokens == 5
        assert opening.cost_usd == 0.001
        assert opening.tool_calls_made == ("echo",)
        assert opening.finish_reason == FinishReason.TOOL_USE

        closing = outcome.turns[1]
        assert closing.turn_number == 2
        assert closing.tool_calls_made == ()
        assert closing.finish_reason == FinishReason.STOP

    async def test_total_tool_calls_accumulated(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """Tool calls from every turn are summed into ``total_tool_calls``."""
        scripted = [
            _tool_use_response("echo", "tc-1"),
            _tool_use_response("echo", "tc-2"),
            _stop_response("Done."),
        ]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        assert outcome.total_tool_calls == 2


@pytest.mark.unit
class TestReactLoopContextImmutability:
    """The context handed to ``execute`` is left as it was."""

    async def test_original_context_unchanged(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """Only the context on the result reflects the executed turns."""
        starting = _ctx_with_user_msg(sample_agent_context)
        turns_before = starting.turn_count
        messages_before = len(starting.conversation)
        cost_before = starting.accumulated_cost

        scripted = [
            _tool_use_response("echo", "tc-1"),
            _stop_response("Done."),
        ]

        outcome = await ReactLoop().execute(
            context=starting,
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        # The input context still reports its pre-run values.
        assert starting.turn_count == turns_before
        assert len(starting.conversation) == messages_before
        assert starting.accumulated_cost == cost_before

        # The returned context has moved on.
        assert turns_before < outcome.context.turn_count
        assert messages_before < len(outcome.context.conversation)


@pytest.mark.unit
class TestReactLoopConversationState:
    """The final context carries the whole exchange."""

    async def test_full_conversation_preserved(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """Roles appear in the order the messages were produced."""
        scripted = [
            _tool_use_response("echo", "tc-1"),
            _stop_response("Final answer."),
        ]

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory(scripted),
            tool_invoker=_make_invoker("echo"),
        )

        expected_roles = (
            MessageRole.USER,
            MessageRole.ASSISTANT,  # tool_use turn
            MessageRole.TOOL,  # tool result
            MessageRole.ASSISTANT,  # final response
        )
        actual_roles = tuple(msg.role for msg in outcome.context.conversation)
        assert actual_roles == expected_roles


@pytest.mark.unit
class TestReactLoopCompletionConfig:
    """A completion config passed to ``execute`` reaches the provider."""

    async def test_custom_completion_config(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """The provider records the very config object that was passed in."""
        override = CompletionConfig(temperature=0.1, max_tokens=100)
        llm = mock_provider_factory([_stop_response("Ok.")])

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=llm,
            completion_config=override,
        )

        assert outcome.termination_reason == TerminationReason.COMPLETED
        assert len(llm.recorded_configs) == 1
        assert llm.recorded_configs[0] is override


@pytest.mark.unit
class TestReactLoopProviderException:
    """The provider's ``complete()`` raises instead of returning."""

    async def test_provider_exception_returns_error_result(
        self,
        sample_agent_context: AgentContext,
    ) -> None:
        """A ConnectionError becomes an ERROR result naming the exception."""

        class _FailingProvider:
            async def complete(self, *_args: Any, **_kwargs: Any) -> None:
                detail = "connection refused"
                raise ConnectionError(detail)

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=_FailingProvider(),  # type: ignore[arg-type]
        )

        assert outcome.termination_reason == TerminationReason.ERROR
        failure_text = outcome.error_message
        assert failure_text is not None
        assert "ConnectionError" in failure_text

    async def test_provider_memory_error_propagates(
        self,
        sample_agent_context: AgentContext,
    ) -> None:
        """MemoryError is not converted to a result; it reaches the caller."""

        class _OOMProvider:
            async def complete(self, *_args: Any, **_kwargs: Any) -> None:
                raise MemoryError

        starting = _ctx_with_user_msg(sample_agent_context)
        react = ReactLoop()

        with pytest.raises(MemoryError):
            await react.execute(
                context=starting,
                provider=_OOMProvider(),  # type: ignore[arg-type]
            )


@pytest.mark.unit
class TestReactLoopToolExecutionException:
    """Tool execution errors are captured by ToolInvoker and do not crash the loop."""

    async def test_tool_exception_returns_error_result(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A raising tool does not escape ``execute``; the run ends in ERROR."""

        class _ExplodingTool(BaseTool):
            def __init__(self) -> None:
                super().__init__(
                    name="explode",
                    description="boom",
                )

            async def execute(
                self,
                *,
                arguments: dict[str, Any],
            ) -> ToolExecutionResult:
                detail = "kaboom"
                raise RuntimeError(detail)

        llm = mock_provider_factory([_tool_use_response("explode", "tc-1")])
        tools = ToolInvoker(ToolRegistry([_ExplodingTool()]))

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=llm,
            tool_invoker=tools,
        )

        # Per the original author's note: ToolInvoker.invoke catches the
        # tool error and returns ToolResult(is_error=True), so the loop
        # carries on.  ERROR is reached because the mock provider has no
        # further responses, which raises IndexError on the next call.
        assert outcome.termination_reason == TerminationReason.ERROR


@pytest.mark.unit
class TestReactLoopMaxTokensFinishReason:
    """A MAX_TOKENS finish reason without tool calls."""

    async def test_max_tokens_returns_completed(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A truncated response still counts as a COMPLETED run."""
        truncated = CompletionResponse(
            content="partial output",
            finish_reason=FinishReason.MAX_TOKENS,
            usage=_usage(),
            model="test-model-001",
        )

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory([truncated]),
        )

        assert outcome.termination_reason == TerminationReason.COMPLETED
        assert len(outcome.turns) == 1
        assert outcome.turns[0].finish_reason == FinishReason.MAX_TOKENS


@pytest.mark.unit
class TestReactLoopToolUseEmptyToolCalls:
    """A TOOL_USE finish reason that carries no tool calls."""

    async def test_tool_use_empty_calls_returns_error(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """The inconsistent response ends the run with an ERROR result."""
        inconsistent = CompletionResponse(
            content="I want to use tools",
            tool_calls=(),
            finish_reason=FinishReason.TOOL_USE,
            usage=_usage(),
            model="test-model-001",
        )

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory([inconsistent]),
        )

        assert outcome.termination_reason == TerminationReason.ERROR
        failure_text = outcome.error_message
        assert failure_text is not None
        assert "TOOL_USE" in failure_text


@pytest.mark.unit
class TestReactLoopBudgetCheckerException:
    """The budget checker callback itself raises."""

    async def test_budget_checker_exception_returns_error(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A failing checker is reported as an ERROR result."""

        def bad_checker(_ctx: AgentContext) -> bool:
            detail = "db connection lost"
            raise ConnectionError(detail)

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory([]),
            budget_checker=bad_checker,
        )

        assert outcome.termination_reason == TerminationReason.ERROR
        failure_text = outcome.error_message
        assert failure_text is not None
        assert "Budget checker failed" in failure_text


def _invoker_raising(failure: BaseException | type[BaseException]) -> MagicMock:
    """Build a mock tool invoker whose ``invoke_all`` raises *failure*.

    The mock's ``registry.to_definitions()`` returns an empty tuple.
    """
    stub = MagicMock()
    stub.registry.to_definitions.return_value = ()
    stub.invoke_all = AsyncMock(side_effect=failure)
    return stub


@pytest.mark.unit
class TestReactLoopRecursionErrorPropagation:
    """RecursionError (and MemoryError from ``invoke_all``) reach the caller."""

    async def test_provider_recursion_error_propagates(
        self,
        sample_agent_context: AgentContext,
    ) -> None:
        """RecursionError raised by the provider is re-raised by ``execute``."""

        class _RecursionProvider:
            async def complete(self, *_args: Any, **_kwargs: Any) -> None:
                raise RecursionError

        starting = _ctx_with_user_msg(sample_agent_context)
        react = ReactLoop()

        with pytest.raises(RecursionError):
            await react.execute(
                context=starting,
                provider=_RecursionProvider(),  # type: ignore[arg-type]
            )

    async def test_tool_invoke_all_recursion_error_propagates(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """RecursionError raised by ``invoke_all`` is re-raised by ``execute``."""
        starting = _ctx_with_user_msg(sample_agent_context)
        llm = mock_provider_factory([_tool_use_response("echo", "tc-1")])
        tools = _invoker_raising(RecursionError)
        react = ReactLoop()

        with pytest.raises(RecursionError):
            await react.execute(
                context=starting,
                provider=llm,
                tool_invoker=tools,
            )

    async def test_tool_invoke_all_memory_error_propagates(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """MemoryError raised by ``invoke_all`` is re-raised by ``execute``."""
        starting = _ctx_with_user_msg(sample_agent_context)
        llm = mock_provider_factory([_tool_use_response("echo", "tc-1")])
        tools = _invoker_raising(MemoryError)
        react = ReactLoop()

        with pytest.raises(MemoryError):
            await react.execute(
                context=starting,
                provider=llm,
                tool_invoker=tools,
            )


@pytest.mark.unit
class TestReactLoopInvokeAllException:
    """An ordinary exception from ``invoke_all`` becomes an ERROR result."""

    async def test_invoke_all_exception_returns_error_result(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """The error message names the failure and the exception type."""
        llm = mock_provider_factory([_tool_use_response("echo", "tc-1")])
        tools = _invoker_raising(RuntimeError("TaskGroup crashed"))

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=llm,
            tool_invoker=tools,
        )

        assert outcome.termination_reason == TerminationReason.ERROR
        failure_text = outcome.error_message
        assert failure_text is not None
        assert "Tool execution failed" in failure_text
        assert "RuntimeError" in failure_text


@pytest.mark.unit
class TestReactLoopEmptyToolRegistry:
    """An invoker backed by an empty ToolRegistry."""

    async def test_empty_registry_passes_no_tools(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A plain STOP response still completes with an empty registry."""
        tools = ToolInvoker(ToolRegistry([]))

        outcome = await ReactLoop().execute(
            context=_ctx_with_user_msg(sample_agent_context),
            provider=mock_provider_factory([_stop_response("Done.")]),
            tool_invoker=tools,
        )

        assert outcome.termination_reason == TerminationReason.COMPLETED


@pytest.mark.unit
class TestReactLoopCostAccounting:
    """The failing turn's cost is still recorded on the result context."""

    async def test_content_filter_response_cost_in_context(
        self,
        sample_agent_context: AgentContext,
        mock_provider_factory: type[MockCompletionProvider],
    ) -> None:
        """A content-filtered turn raises accumulated cost and turn count."""
        starting = _ctx_with_user_msg(sample_agent_context)
        llm = mock_provider_factory([_content_filter_response()])

        outcome = await ReactLoop().execute(context=starting, provider=llm)

        assert outcome.termination_reason == TerminationReason.ERROR
        # The failing turn's cost should be in the context
        spent_after = outcome.context.accumulated_cost.cost_usd
        spent_before = starting.accumulated_cost.cost_usd
        assert spent_after > spent_before
        assert outcome.context.turn_count == 1