diff --git a/DESIGN_SPEC.md b/DESIGN_SPEC.md index 5af56e73fa..1f30bbc544 100644 --- a/DESIGN_SPEC.md +++ b/DESIGN_SPEC.md @@ -1473,7 +1473,7 @@ call_analytics: ### 11.1.1 Tool Execution Model -When the LLM requests multiple tool calls in a single turn, `ToolInvoker.invoke_all` currently executes them **sequentially**. Migration to `asyncio.TaskGroup` for parallel structured concurrency is planned (see §15.5). Recoverable errors are captured as `ToolResult(is_error=True)` without aborting remaining invocations; non-recoverable errors (`MemoryError`, `RecursionError`) propagate immediately and abort the sequence. +When the LLM requests multiple tool calls in a single turn, `ToolInvoker.invoke_all` executes them **concurrently** using `asyncio.TaskGroup`. An optional `max_concurrency` parameter (default unbounded) limits parallelism via `asyncio.Semaphore`. Recoverable errors are captured as `ToolResult(is_error=True)` without aborting sibling invocations; non-recoverable errors (`MemoryError`, `RecursionError`) are collected and re-raised after all tasks complete (bare exception for one, `ExceptionGroup` for multiple). `BaseTool.parameters_schema` deep-copies the caller-supplied schema at construction and wraps it in `MappingProxyType` for read-only enforcement; the property returns a deep copy on access to prevent mutation of internal state. `ToolInvoker` deep-copies arguments at the tool execution boundary before passing them to `tool.execute()`. `MappingProxyType` wrapping is also used in `ToolRegistry` for its internal collections. 
@@ -2156,7 +2156,7 @@ ai-company/ │ ├── tools/ # Tool/capability system │ │ ├── base.py # BaseTool ABC, ToolExecutionResult │ │ ├── registry.py # Immutable tool registry (MappingProxyType) -│ │ ├── invoker.py # Tool invocation (sequential execution) +│ │ ├── invoker.py # Tool invocation (concurrent via TaskGroup) │ │ ├── errors.py # Tool error hierarchy │ │ ├── examples/ # Example tool implementations │ │ │ └── echo.py # Echo tool (for testing) @@ -2245,7 +2245,7 @@ These conventions were established during the M0–M2+ review cycle. **Adopted** | **String validation** | Adopted | `NotBlankStr` type from `core.types` for all identifiers | Eliminates per-model `@model_validator` boilerplate for whitespace checks. All identifier/name fields use `NotBlankStr`; optional identifiers use `NotBlankStr \| None`; tuple fields use `tuple[NotBlankStr, ...]` for per-element validation. | | **Shared field groups** | Planned | Extract common field sets into base models (e.g. `_SpendingTotals`) | Prevents field duplication across spending summary models. Not yet implemented — each model independently defines fields. | | **Event constants** | Adopted (per-domain) | Per-domain submodules under `events/` package (e.g. `events.provider`, `events.budget`). Import directly: `from ai_company.observability.events.<domain> import CONSTANT` | Split by domain for discoverability, co-location with domain logic, and reduced merge conflicts as constants grow. `__init__.py` serves as package marker with usage documentation; no re-exports. | -| **Parallel tool execution** | Planned | `asyncio.TaskGroup` in `ToolInvoker.invoke_all` | Structured concurrency with proper cancellation semantics. Currently sequential; migration planned for M3 when the agent engine needs concurrent tool calls. | +| **Parallel tool execution** | Adopted (M2.5) | `asyncio.TaskGroup` in `ToolInvoker.invoke_all` with optional `max_concurrency` semaphore | Structured concurrency with proper cancellation semantics.
Fatal errors collected via guarded wrapper and re-raised after all tasks complete. | | **Tool sandboxing** | Planned (M3) | Layered `SandboxBackend` protocol: `SubprocessSandbox` for low-risk tools (file, git), `DockerSandbox` for high-risk tools (code_runner, terminal, web, database). `K8sSandbox` planned for future container deployments. | Risk-proportionate isolation. Docker optional — only needed for code execution and network-sensitive tools. Pluggable protocol enables seamless migration to K8s per-agent pods in Phase 3-4. See §11.1.2. | | **Crash recovery** | Planned (M3) | Pluggable `RecoveryStrategy` protocol. M3: `FailAndReassignStrategy` (catch at engine boundary, log snapshot, mark FAILED, reassign). M4/M5: `CheckpointStrategy` (persist `AgentContext` per turn, resume from last checkpoint). | Immutable `model_copy` pattern makes checkpoint serialization trivial to add later. Fail-and-reassign is sufficient for short MVP tasks. See §6.6. | | **Agent behavior testing** | Planned (M3) | Scripted `FakeProvider` for unit tests (deterministic turn sequences); behavioral outcome assertions for integration tests (task completed, tools called, cost within budget). | Leverages existing `FakeProvider` and `CompletionResponseFactory` fixtures. Precise engine testing without brittle response-matching at integration level. 
| diff --git a/src/ai_company/observability/events/tool.py b/src/ai_company/observability/events/tool.py index 22562d08f8..9a56bfc470 100644 --- a/src/ai_company/observability/events/tool.py +++ b/src/ai_company/observability/events/tool.py @@ -17,3 +17,5 @@ TOOL_INVOKE_VALIDATION_UNEXPECTED: Final[str] = "tool.invoke.validation_unexpected" TOOL_BASE_INVALID_NAME: Final[str] = "tool.base.invalid_name" TOOL_REGISTRY_CONTAINS_TYPE_ERROR: Final[str] = "tool.registry.contains_type_error" +TOOL_INVOKE_ALL_START: Final[str] = "tool.invoke_all.start" +TOOL_INVOKE_ALL_COMPLETE: Final[str] = "tool.invoke_all.complete" diff --git a/src/ai_company/tools/invoker.py b/src/ai_company/tools/invoker.py index ea16cb578d..c9a32fd9a5 100644 --- a/src/ai_company/tools/invoker.py +++ b/src/ai_company/tools/invoker.py @@ -2,21 +2,24 @@ Bridges LLM ``ToolCall`` objects with concrete ``BaseTool.execute`` methods. Recoverable errors are returned as ``ToolResult(is_error=True)``; -non-recoverable errors (``MemoryError``, ``RecursionError``) and -``BaseException`` subclasses (``KeyboardInterrupt``, ``SystemExit``, -``asyncio.CancelledError``) propagate after logging. +non-recoverable errors (``MemoryError``, ``RecursionError``) are logged and +re-raised. ``BaseException`` subclasses (``KeyboardInterrupt``, +``SystemExit``, ``asyncio.CancelledError``) propagate uncaught. 
""" +import asyncio import copy -from typing import TYPE_CHECKING +from contextlib import nullcontext +from typing import TYPE_CHECKING, Never import jsonschema from referencing import Registry as JsonSchemaRegistry -from referencing import Resource from referencing.exceptions import NoSuchResource from ai_company.observability import get_logger from ai_company.observability.events.tool import ( + TOOL_INVOKE_ALL_COMPLETE, + TOOL_INVOKE_ALL_START, TOOL_INVOKE_DEEPCOPY_ERROR, TOOL_INVOKE_EXECUTION_ERROR, TOOL_INVOKE_NON_RECOVERABLE, @@ -41,7 +44,7 @@ logger = get_logger(__name__) -def _no_remote_retrieve(uri: str) -> Resource: +def _no_remote_retrieve(uri: str) -> Never: """Block remote ``$ref`` resolution to prevent SSRF.""" raise NoSuchResource(uri) @@ -64,9 +67,13 @@ class ToolInvoker: invoker = ToolInvoker(registry) result = await invoker.invoke(tool_call) - Invoke multiple tool calls sequentially:: + Invoke multiple tool calls concurrently:: results = await invoker.invoke_all(tool_calls) + + Limit concurrency:: + + results = await invoker.invoke_all(tool_calls, max_concurrency=3) """ def __init__(self, registry: ToolRegistry) -> None: @@ -172,7 +179,7 @@ def _schema_error_result( error_msg: str, ) -> ToolResult: """Build an error result for an invalid tool schema.""" - logger.exception( + logger.error( TOOL_INVOKE_SCHEMA_ERROR, tool_call_id=tool_call.id, tool_name=tool_call.name, @@ -318,19 +325,94 @@ def _build_result( is_error=result.is_error, ) + async def _run_guarded( + self, + index: int, + tool_call: ToolCall, + results: dict[int, ToolResult], + fatal_errors: list[Exception], + semaphore: asyncio.Semaphore | None, + ) -> None: + """Execute a single tool call, storing fatal errors instead of raising. + + This wrapper ensures that ``MemoryError`` / ``RecursionError`` do not + cancel sibling tasks inside a ``TaskGroup``. ``BaseException`` + subclasses (``KeyboardInterrupt``, ``CancelledError``) are not + intercepted and will cancel the group normally. 
+ """ + try: + ctx = semaphore if semaphore is not None else nullcontext() + async with ctx: + results[index] = await self.invoke(tool_call) + except (MemoryError, RecursionError) as exc: + fatal_errors.append(exc) + async def invoke_all( self, tool_calls: Iterable[ToolCall], + *, + max_concurrency: int | None = None, ) -> tuple[ToolResult, ...]: - """Execute multiple tool calls sequentially. + """Execute multiple tool calls concurrently. Calls continue through recoverable failures; non-recoverable - errors propagate immediately. + errors (``MemoryError``, ``RecursionError``) are collected and + re-raised after all tasks complete. Args: - tool_calls: Tool calls to execute in order. + tool_calls: Tool calls to execute. + max_concurrency: Maximum number of concurrent invocations. + ``None`` (default) means unbounded. Must be ``>= 1`` + if provided. Returns: Tuple of results in the same order as the input. + + Raises: + ValueError: If ``max_concurrency`` is less than 1. + MemoryError: Re-raised if it was the sole fatal error. + RecursionError: Re-raised if it was the sole fatal error. + ExceptionGroup: If multiple fatal errors occurred. """ - return tuple([await self.invoke(call) for call in tool_calls]) + if max_concurrency is not None and max_concurrency < 1: + msg = f"max_concurrency must be >= 1, got {max_concurrency}" + raise ValueError(msg) + + calls = list(tool_calls) + if not calls: + return () + + logger.info( + TOOL_INVOKE_ALL_START, + count=len(calls), + max_concurrency=max_concurrency, + ) + + # SAFETY: Both ``results`` and ``fatal_errors`` are mutated by + # concurrent tasks. This is safe because asyncio runs tasks on + # a single thread — dict assignment and list.append() never race. 
+ results: dict[int, ToolResult] = {} + fatal_errors: list[Exception] = [] + semaphore = ( + asyncio.Semaphore(max_concurrency) if max_concurrency is not None else None + ) + + async with asyncio.TaskGroup() as tg: + for idx, call in enumerate(calls): + tg.create_task( + self._run_guarded(idx, call, results, fatal_errors, semaphore), + ) + + logger.info( + TOOL_INVOKE_ALL_COMPLETE, + count=len(calls), + fatal_count=len(fatal_errors), + ) + + if fatal_errors: + if len(fatal_errors) == 1: + raise fatal_errors[0] + msg = "multiple non-recoverable tool errors" + raise ExceptionGroup(msg, fatal_errors) + + return tuple(results[i] for i in range(len(calls))) diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py index a8b32014f1..f8e7df8a8a 100644 --- a/tests/unit/tools/conftest.py +++ b/tests/unit/tools/conftest.py @@ -1,5 +1,6 @@ """Unit test fixtures for the tool system.""" +import asyncio from typing import Any import pytest @@ -297,7 +298,9 @@ def sample_tool_call() -> ToolCall: @pytest.fixture def extended_invoker() -> ToolInvoker: - """Invoker with additional edge-case tools for advanced tests.""" + """Invoker with echo, recursion, invalid-schema, empty-error, + remote-ref, and mutating tools for edge-case tests. 
+ """ tools = [ _EchoTestTool(), _RecursionTool(), @@ -307,3 +310,93 @@ def extended_invoker() -> ToolInvoker: _MutatingTool(), ] return ToolInvoker(ToolRegistry(tools)) + + +# ── Concurrency test tools ─────────────────────────────────────── + + +class _DelayTool(BaseTool): + """Sleeps for ``delay`` seconds, then returns ``value``.""" + + def __init__(self) -> None: + super().__init__( + name="delay", + description="Sleeps then returns value", + parameters_schema={ + "type": "object", + "properties": { + "delay": {"type": "number"}, + "value": {"type": "string"}, + }, + "required": ["delay", "value"], + "additionalProperties": False, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + await asyncio.sleep(arguments["delay"]) + return ToolExecutionResult(content=arguments["value"]) + + +class _ConcurrencyTrackingTool(BaseTool): + """Tracks peak concurrent executions via a lock-guarded counter.""" + + def __init__(self) -> None: + super().__init__( + name="tracking", + description="Tracks concurrency", + parameters_schema={ + "type": "object", + "properties": { + "duration": {"type": "number"}, + }, + "required": ["duration"], + "additionalProperties": False, + }, + ) + self._lock = asyncio.Lock() + self._current = 0 + self._peak = 0 + + @property + def peak(self) -> int: + """Return the peak concurrent execution count.""" + return self._peak + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + async with self._lock: + self._current += 1 + self._peak = max(self._peak, self._current) + await asyncio.sleep(arguments["duration"]) + async with self._lock: + self._current -= 1 + return ToolExecutionResult(content=str(self._peak)) + + +@pytest.fixture +def concurrency_tracking_tool() -> _ConcurrencyTrackingTool: + """Standalone tracking tool for direct peak inspection.""" + return _ConcurrencyTrackingTool() + + +@pytest.fixture +def concurrency_invoker( + 
concurrency_tracking_tool: _ConcurrencyTrackingTool, +) -> ToolInvoker: + """Invoker with echo, failing, delay, tracking, and recursion tools.""" + tools: list[BaseTool] = [ + _EchoTestTool(), + _FailingTool(), + _DelayTool(), + concurrency_tracking_tool, + _RecursionTool(), + ] + return ToolInvoker(ToolRegistry(tools)) diff --git a/tests/unit/tools/test_invoker.py b/tests/unit/tools/test_invoker.py index b1a0bb1446..336e12be4c 100644 --- a/tests/unit/tools/test_invoker.py +++ b/tests/unit/tools/test_invoker.py @@ -1,5 +1,6 @@ """Tests for ToolInvoker.""" +import time from typing import TYPE_CHECKING import pytest @@ -7,8 +8,12 @@ from ai_company.providers.models import ToolCall, ToolResult if TYPE_CHECKING: + from collections.abc import Iterator + from ai_company.tools.invoker import ToolInvoker + from .conftest import _ConcurrencyTrackingTool + pytestmark = pytest.mark.timeout(30) @@ -112,7 +117,7 @@ async def test_extra_params_returns_error( result = await sample_invoker.invoke(call) assert result.is_error is True - async def test_empty_schema_skips_validation( + async def test_no_schema_skips_validation( self, sample_invoker: ToolInvoker, ) -> None: @@ -369,9 +374,9 @@ async def test_deepcopy_failure_returns_error_result( def _fail_on_execute(obj: object, memo: object = None) -> object: nonlocal call_count call_count += 1 - # First deepcopy call is in _validate_params via - # parameters_schema; let it pass. Fail on the second - # call in _execute_tool. + # First deepcopy call is BaseTool.parameters_schema + # (called from _validate_params); let it pass. Fail on + # the second call (argument copying in _execute_tool). 
if call_count > 1: msg = "cannot copy" raise TypeError(msg) @@ -431,3 +436,232 @@ async def test_empty_error_message_fallback( result = await extended_invoker.invoke(call) assert result.is_error is True assert "ValueError (no message)" in result.content + + +@pytest.mark.unit +class TestInvokeAllConcurrency: + """Tests for concurrent execution in invoke_all.""" + + async def test_concurrent_faster_than_sequential( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Three 0.1s delay tools complete in less than 0.3s total.""" + calls = [ + ToolCall( + id=f"d{i}", + name="delay", + arguments={"delay": 0.1, "value": f"v{i}"}, + ) + for i in range(3) + ] + start = time.monotonic() + results = await concurrency_invoker.invoke_all(calls) + elapsed = time.monotonic() - start + assert len(results) == 3 + assert all(r.content == f"v{i}" for i, r in enumerate(results)) + # Sequential would take >= 0.3s; staying under that bound proves parallel execution + assert elapsed < 0.3 + + async def test_concurrent_results_in_input_order( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Results match input order regardless of completion order.""" + calls = [ + ToolCall( + id="slow", + name="delay", + arguments={"delay": 0.1, "value": "first"}, + ), + ToolCall( + id="fast", + name="delay", + arguments={"delay": 0.01, "value": "second"}, + ), + ] + results = await concurrency_invoker.invoke_all(calls) + assert results[0].tool_call_id == "slow" + assert results[0].content == "first" + assert results[1].tool_call_id == "fast" + assert results[1].content == "second" + + async def test_recoverable_error_does_not_cancel_siblings( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """A failing tool doesn't prevent siblings from completing.""" + calls = [ + ToolCall(id="c1", name="echo_test", arguments={"message": "a"}), + ToolCall(id="c2", name="failing", arguments={"input": "x"}), + ToolCall(id="c3", name="echo_test", arguments={"message": "b"}), + ] + results = await
concurrency_invoker.invoke_all(calls) + assert len(results) == 3 + assert results[0].is_error is False + assert results[1].is_error is True + assert results[2].is_error is False + + async def test_single_non_recoverable_raises_bare( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Single fatal error re-raises as bare exception.""" + calls = [ + ToolCall( + id="r1", + name="recursion", + arguments={"input": "boom"}, + ), + ] + with pytest.raises(RecursionError, match="maximum recursion depth"): + await concurrency_invoker.invoke_all(calls) + + async def test_mixed_fatal_and_success_raises_fatal( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Fatal error is raised even when siblings succeed.""" + calls = [ + ToolCall(id="ok1", name="echo_test", arguments={"message": "a"}), + ToolCall( + id="fatal", + name="recursion", + arguments={"input": "boom"}, + ), + ToolCall(id="ok2", name="echo_test", arguments={"message": "b"}), + ] + with pytest.raises(RecursionError, match="maximum recursion depth"): + await concurrency_invoker.invoke_all(calls) + + async def test_multiple_non_recoverable_raises_exception_group( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Multiple fatal errors raise ExceptionGroup.""" + calls = [ + ToolCall( + id="r1", + name="recursion", + arguments={"input": "boom1"}, + ), + ToolCall( + id="r2", + name="recursion", + arguments={"input": "boom2"}, + ), + ] + with pytest.raises(ExceptionGroup) as exc_info: + await concurrency_invoker.invoke_all(calls) + assert len(exc_info.value.exceptions) == 2 + assert all(isinstance(e, RecursionError) for e in exc_info.value.exceptions) + + +@pytest.mark.unit +class TestInvokeAllBounded: + """Tests for max_concurrency parameter.""" + + async def test_max_concurrency_one_sequential( + self, + concurrency_invoker: ToolInvoker, + concurrency_tracking_tool: _ConcurrencyTrackingTool, + ) -> None: + """max_concurrency=1 enforces sequential execution (peak=1).""" + calls = [ + 
ToolCall( + id=f"t{i}", + name="tracking", + arguments={"duration": 0.02}, + ) + for i in range(3) + ] + results = await concurrency_invoker.invoke_all(calls, max_concurrency=1) + assert len(results) == 3 + assert concurrency_tracking_tool.peak == 1 + + async def test_max_concurrency_bounds_parallelism( + self, + concurrency_invoker: ToolInvoker, + concurrency_tracking_tool: _ConcurrencyTrackingTool, + ) -> None: + """With max_concurrency=2, peak never exceeds 2.""" + calls = [ + ToolCall( + id=f"t{i}", + name="tracking", + arguments={"duration": 0.05}, + ) + for i in range(5) + ] + await concurrency_invoker.invoke_all(calls, max_concurrency=2) + assert concurrency_tracking_tool.peak <= 2 + + async def test_max_concurrency_none_unbounded( + self, + concurrency_invoker: ToolInvoker, + concurrency_tracking_tool: _ConcurrencyTrackingTool, + ) -> None: + """Without max_concurrency, parallelism exceeds 1.""" + calls = [ + ToolCall( + id=f"t{i}", + name="tracking", + arguments={"duration": 0.05}, + ) + for i in range(5) + ] + await concurrency_invoker.invoke_all(calls) + assert concurrency_tracking_tool.peak >= 3 + + async def test_max_concurrency_validation( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """max_concurrency=0 and negative values raise ValueError.""" + calls = [ + ToolCall(id="c1", name="echo_test", arguments={"message": "a"}), + ] + with pytest.raises(ValueError, match="max_concurrency"): + await concurrency_invoker.invoke_all(calls, max_concurrency=0) + with pytest.raises(ValueError, match="max_concurrency"): + await concurrency_invoker.invoke_all(calls, max_concurrency=-1) + + +@pytest.mark.unit +class TestInvokeAllEdgeCases: + """Edge case tests for invoke_all.""" + + async def test_single_call( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Single-element input works correctly.""" + calls = [ + ToolCall(id="c1", name="echo_test", arguments={"message": "solo"}), + ] + results = await concurrency_invoker.invoke_all(calls) + 
assert len(results) == 1 + assert results[0].content == "solo" + + async def test_generator_input( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Non-list iterable (generator) works correctly.""" + + def _gen() -> Iterator[ToolCall]: + yield ToolCall(id="g1", name="echo_test", arguments={"message": "gen1"}) + yield ToolCall(id="g2", name="echo_test", arguments={"message": "gen2"}) + + results = await concurrency_invoker.invoke_all(_gen()) + assert len(results) == 2 + assert results[0].content == "gen1" + assert results[1].content == "gen2" + + async def test_empty_with_max_concurrency( + self, + concurrency_invoker: ToolInvoker, + ) -> None: + """Empty input with max_concurrency returns empty tuple.""" + results = await concurrency_invoker.invoke_all([], max_concurrency=3) + assert results == ()