Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESIGN_SPEC.md
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,7 @@ call_analytics:

### 11.1.1 Tool Execution Model

When the LLM requests multiple tool calls in a single turn, `ToolInvoker.invoke_all` currently executes them **sequentially**. Migration to `asyncio.TaskGroup` for parallel structured concurrency is planned (see §15.5). Recoverable errors are captured as `ToolResult(is_error=True)` without aborting remaining invocations; non-recoverable errors (`MemoryError`, `RecursionError`) propagate immediately and abort the sequence.
When the LLM requests multiple tool calls in a single turn, `ToolInvoker.invoke_all` executes them **concurrently** using `asyncio.TaskGroup`. An optional `max_concurrency` parameter (default unbounded) limits parallelism via `asyncio.Semaphore`. Recoverable errors are captured as `ToolResult(is_error=True)` without aborting sibling invocations; non-recoverable errors (`MemoryError`, `RecursionError`) are collected and re-raised after all tasks complete (bare exception for one, `ExceptionGroup` for multiple).

`BaseTool.parameters_schema` deep-copies the caller-supplied schema at construction and wraps it in `MappingProxyType` for read-only enforcement; the property returns a deep copy on access to prevent mutation of internal state. `ToolInvoker` deep-copies arguments at the tool execution boundary before passing them to `tool.execute()`. `MappingProxyType` wrapping is also used in `ToolRegistry` for its internal collections.

Expand Down Expand Up @@ -2156,7 +2156,7 @@ ai-company/
│ ├── tools/ # Tool/capability system
│ │ ├── base.py # BaseTool ABC, ToolExecutionResult
│ │ ├── registry.py # Immutable tool registry (MappingProxyType)
│ │ ├── invoker.py # Tool invocation (sequential execution)
│ │ ├── invoker.py # Tool invocation (concurrent via TaskGroup)
│ │ ├── errors.py # Tool error hierarchy
│ │ ├── examples/ # Example tool implementations
│ │ │ └── echo.py # Echo tool (for testing)
Expand Down Expand Up @@ -2245,7 +2245,7 @@ These conventions were established during the M0–M2+ review cycle. **Adopted**
| **String validation** | Adopted | `NotBlankStr` type from `core.types` for all identifiers | Eliminates per-model `@model_validator` boilerplate for whitespace checks. All identifier/name fields use `NotBlankStr`; optional identifiers use `NotBlankStr \| None`; tuple fields use `tuple[NotBlankStr, ...]` for per-element validation. |
| **Shared field groups** | Planned | Extract common field sets into base models (e.g. `_SpendingTotals`) | Prevents field duplication across spending summary models. Not yet implemented — each model independently defines fields. |
| **Event constants** | Adopted (per-domain) | Per-domain submodules under `events/` package (e.g. `events.provider`, `events.budget`). Import directly: `from ai_company.observability.events.<domain> import CONSTANT` | Split by domain for discoverability, co-location with domain logic, and reduced merge conflicts as constants grow. `__init__.py` serves as package marker with usage documentation; no re-exports. |
| **Parallel tool execution** | Planned | `asyncio.TaskGroup` in `ToolInvoker.invoke_all` | Structured concurrency with proper cancellation semantics. Currently sequential; migration planned for M3 when the agent engine needs concurrent tool calls. |
| **Parallel tool execution** | Adopted (M2.5) | `asyncio.TaskGroup` in `ToolInvoker.invoke_all` with optional `max_concurrency` semaphore | Structured concurrency with proper cancellation semantics. Fatal errors collected via guarded wrapper and re-raised after all tasks complete. |
| **Tool sandboxing** | Planned (M3) | Layered `SandboxBackend` protocol: `SubprocessSandbox` for low-risk tools (file, git), `DockerSandbox` for high-risk tools (code_runner, terminal, web, database). `K8sSandbox` planned for future container deployments. | Risk-proportionate isolation. Docker optional — only needed for code execution and network-sensitive tools. Pluggable protocol enables seamless migration to K8s per-agent pods in Phase 3-4. See §11.1.2. |
| **Crash recovery** | Planned (M3) | Pluggable `RecoveryStrategy` protocol. M3: `FailAndReassignStrategy` (catch at engine boundary, log snapshot, mark FAILED, reassign). M4/M5: `CheckpointStrategy` (persist `AgentContext` per turn, resume from last checkpoint). | Immutable `model_copy` pattern makes checkpoint serialization trivial to add later. Fail-and-reassign is sufficient for short MVP tasks. See §6.6. |
| **Agent behavior testing** | Planned (M3) | Scripted `FakeProvider` for unit tests (deterministic turn sequences); behavioral outcome assertions for integration tests (task completed, tools called, cost within budget). | Leverages existing `FakeProvider` and `CompletionResponseFactory` fixtures. Precise engine testing without brittle response-matching at integration level. |
Expand Down
2 changes: 2 additions & 0 deletions src/ai_company/observability/events/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
TOOL_INVOKE_VALIDATION_UNEXPECTED: Final[str] = "tool.invoke.validation_unexpected"
TOOL_BASE_INVALID_NAME: Final[str] = "tool.base.invalid_name"
TOOL_REGISTRY_CONTAINS_TYPE_ERROR: Final[str] = "tool.registry.contains_type_error"
TOOL_INVOKE_ALL_START: Final[str] = "tool.invoke_all.start"
TOOL_INVOKE_ALL_COMPLETE: Final[str] = "tool.invoke_all.complete"
106 changes: 94 additions & 12 deletions src/ai_company/tools/invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

Bridges LLM ``ToolCall`` objects with concrete ``BaseTool.execute``
methods. Recoverable errors are returned as ``ToolResult(is_error=True)``;
non-recoverable errors (``MemoryError``, ``RecursionError``) and
``BaseException`` subclasses (``KeyboardInterrupt``, ``SystemExit``,
``asyncio.CancelledError``) propagate after logging.
non-recoverable errors (``MemoryError``, ``RecursionError``) are logged and
re-raised. ``BaseException`` subclasses (``KeyboardInterrupt``,
``SystemExit``, ``asyncio.CancelledError``) propagate uncaught.
"""

import asyncio
import copy
from typing import TYPE_CHECKING
from contextlib import nullcontext
from typing import TYPE_CHECKING, Never

import jsonschema
from referencing import Registry as JsonSchemaRegistry
from referencing import Resource
from referencing.exceptions import NoSuchResource

from ai_company.observability import get_logger
from ai_company.observability.events.tool import (
TOOL_INVOKE_ALL_COMPLETE,
TOOL_INVOKE_ALL_START,
TOOL_INVOKE_DEEPCOPY_ERROR,
TOOL_INVOKE_EXECUTION_ERROR,
TOOL_INVOKE_NON_RECOVERABLE,
Expand All @@ -41,7 +44,7 @@
logger = get_logger(__name__)


def _no_remote_retrieve(uri: str) -> Resource:
def _no_remote_retrieve(uri: str) -> Never:
"""Block remote ``$ref`` resolution to prevent SSRF."""
raise NoSuchResource(uri)

Expand All @@ -64,9 +67,13 @@ class ToolInvoker:
invoker = ToolInvoker(registry)
result = await invoker.invoke(tool_call)

Invoke multiple tool calls sequentially::
Invoke multiple tool calls concurrently::

results = await invoker.invoke_all(tool_calls)

Limit concurrency::

results = await invoker.invoke_all(tool_calls, max_concurrency=3)
"""

def __init__(self, registry: ToolRegistry) -> None:
Expand Down Expand Up @@ -172,7 +179,7 @@ def _schema_error_result(
error_msg: str,
) -> ToolResult:
"""Build an error result for an invalid tool schema."""
logger.exception(
logger.error(
TOOL_INVOKE_SCHEMA_ERROR,
tool_call_id=tool_call.id,
tool_name=tool_call.name,
Expand Down Expand Up @@ -318,19 +325,94 @@ def _build_result(
is_error=result.is_error,
)

async def _run_guarded(
self,
index: int,
tool_call: ToolCall,
results: dict[int, ToolResult],
fatal_errors: list[Exception],
semaphore: asyncio.Semaphore | None,
) -> None:
"""Execute a single tool call, storing fatal errors instead of raising.

This wrapper ensures that ``MemoryError`` / ``RecursionError`` do not
cancel sibling tasks inside a ``TaskGroup``. ``BaseException``
subclasses (``KeyboardInterrupt``, ``CancelledError``) are not
intercepted and will cancel the group normally.
"""
try:
ctx = semaphore if semaphore is not None else nullcontext()
async with ctx:
results[index] = await self.invoke(tool_call)
except (MemoryError, RecursionError) as exc:
fatal_errors.append(exc)
Comment on lines +347 to +348

Copilot AI Mar 6, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In _run_guarded, MemoryError and RecursionError are caught and silently collected, but self.invoke() already catches these same exceptions, logs them with logger.exception, and re-raises. This means the _run_guarded catch will see them after they've already been logged with a full traceback via logger.exception in _execute_tool / _validate_params. The logger.warning here then logs a second time at a lower severity. This is intentional dual-logging (once at the origin, once at the orchestration level), but the orchestration log at warning level is misleading — a non-recoverable error is more severe than a warning. Consider using logger.error to be consistent with the severity of the event.

Copilot uses AI. Check for mistakes.
Comment thread
greptile-apps[bot] marked this conversation as resolved.

async def invoke_all(
self,
tool_calls: Iterable[ToolCall],
*,
max_concurrency: int | None = None,
) -> tuple[ToolResult, ...]:
"""Execute multiple tool calls sequentially.
"""Execute multiple tool calls concurrently.

Calls continue through recoverable failures; non-recoverable
errors propagate immediately.
errors (``MemoryError``, ``RecursionError``) are collected and
re-raised after all tasks complete.

Args:
tool_calls: Tool calls to execute in order.
tool_calls: Tool calls to execute.
max_concurrency: Maximum number of concurrent invocations.
``None`` (default) means unbounded. Must be ``>= 1``
if provided.

Returns:
Tuple of results in the same order as the input.

Raises:
ValueError: If ``max_concurrency`` is less than 1.
MemoryError: Re-raised if it was the sole fatal error.
RecursionError: Re-raised if it was the sole fatal error.
ExceptionGroup: If multiple fatal errors occurred.
"""
return tuple([await self.invoke(call) for call in tool_calls])
if max_concurrency is not None and max_concurrency < 1:
msg = f"max_concurrency must be >= 1, got {max_concurrency}"
raise ValueError(msg)

calls = list(tool_calls)
if not calls:
return ()

logger.info(
TOOL_INVOKE_ALL_START,
count=len(calls),
max_concurrency=max_concurrency,
)

# SAFETY: Both ``results`` and ``fatal_errors`` are mutated by
# concurrent tasks. This is safe because asyncio runs tasks on
# a single thread — dict assignment and list.append() never race.
results: dict[int, ToolResult] = {}
fatal_errors: list[Exception] = []
semaphore = (
asyncio.Semaphore(max_concurrency) if max_concurrency is not None else None
)

async with asyncio.TaskGroup() as tg:
for idx, call in enumerate(calls):
tg.create_task(
self._run_guarded(idx, call, results, fatal_errors, semaphore),
)

logger.info(
TOOL_INVOKE_ALL_COMPLETE,
count=len(calls),
fatal_count=len(fatal_errors),
)
Comment on lines +406 to +410

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

info-level completion log emitted even when fatal errors will be raised

TOOL_INVOKE_ALL_COMPLETE is logged at info regardless of fatal_count. Any log aggregator or dashboard that counts TOOL_INVOKE_ALL_COMPLETE at info as a success metric will miscount — the event is emitted, and then an exception is immediately raised, making the operation an effective failure from the caller's perspective.

Consider using a conditional log level when fatal_errors is non-empty:

Suggested change
logger.info(
TOOL_INVOKE_ALL_COMPLETE,
count=len(calls),
fatal_count=len(fatal_errors),
)
if fatal_errors:
logger.warning(
TOOL_INVOKE_ALL_COMPLETE,
count=len(calls),
fatal_count=len(fatal_errors),
)
else:
logger.info(
TOOL_INVOKE_ALL_COMPLETE,
count=len(calls),
fatal_count=0,
)

This ensures that monitoring systems observing TOOL_INVOKE_ALL_COMPLETE at info level accurately reflects clean completions only.

Prompt To Fix With AI
This is a comment left during a code review.
Path: src/ai_company/tools/invoker.py
Line: 406-410

Comment:
**`info`-level completion log emitted even when fatal errors will be raised**

`TOOL_INVOKE_ALL_COMPLETE` is logged at `info` regardless of `fatal_count`. Any log aggregator or dashboard that counts `TOOL_INVOKE_ALL_COMPLETE` at `info` as a success metric will miscount — the event is emitted, and then an exception is immediately raised, making the operation an effective failure from the caller's perspective.

Consider using a conditional log level when `fatal_errors` is non-empty:

```suggestion
        if fatal_errors:
            logger.warning(
                TOOL_INVOKE_ALL_COMPLETE,
                count=len(calls),
                fatal_count=len(fatal_errors),
            )
        else:
            logger.info(
                TOOL_INVOKE_ALL_COMPLETE,
                count=len(calls),
                fatal_count=0,
            )
```

This ensures that monitoring systems observing `TOOL_INVOKE_ALL_COMPLETE` at `info` level accurately reflects clean completions only.

How can I resolve this? If you propose a fix, please make it concise.


if fatal_errors:
if len(fatal_errors) == 1:
raise fatal_errors[0]
Comment on lines +406 to +414

Copilot AI Mar 6, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When fatal_errors is non-empty, the TOOL_INVOKE_ALL_COMPLETE log at info level fires right before raising. This is fine for observability, but note that when fatal errors occur, results may be incomplete (missing entries for fatal slots). If a caller ever catches the raised exception and somehow tries to inspect partial results, they won't have access. This is acceptable given the current design, but consider documenting that no partial results are available on fatal error.

Copilot uses AI. Check for mistakes.
msg = "multiple non-recoverable tool errors"
raise ExceptionGroup(msg, fatal_errors)
Comment on lines +412 to +416

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Asymmetric exception type makes caller handling error-prone

When exactly one fatal error is collected, invoke_all re-raises a bare MemoryError or RecursionError. When two or more occur, it raises an ExceptionGroup. This means every caller that needs to handle fatal errors from invoke_all must cover two structurally different exception types:

try:
    results = await invoker.invoke_all(calls)
except (MemoryError, RecursionError):
    # single-error path
    ...
except ExceptionGroup as eg:
    # multi-error path  
    ...

If a caller only writes the first except clause (which is the obvious, natural thing to do), it silently misses the ExceptionGroup case and the errors go unhandled.

A consistent approach — always wrapping in ExceptionGroup, even for a single error — would let callers use a single except ExceptionGroup branch and introspect the contained exceptions uniformly. The PR description's own "Robust error handling" bullet even describes the asymmetry as a feature, but the caller-side complexity it creates is real. At minimum the docstring should warn that callers must handle both exception shapes.

Prompt To Fix With AI
This is a comment left during a code review.
Path: src/ai_company/tools/invoker.py
Line: 412-416

Comment:
**Asymmetric exception type makes caller handling error-prone**

When exactly one fatal error is collected, `invoke_all` re-raises a bare `MemoryError` or `RecursionError`. When two or more occur, it raises an `ExceptionGroup`. This means every caller that needs to handle fatal errors from `invoke_all` must cover two structurally different exception types:

```python
try:
    results = await invoker.invoke_all(calls)
except (MemoryError, RecursionError):
    # single-error path
    ...
except ExceptionGroup as eg:
    # multi-error path  
    ...
```

If a caller only writes the first `except` clause (which is the obvious, natural thing to do), it silently misses the `ExceptionGroup` case and the errors go unhandled.

A consistent approach — always wrapping in `ExceptionGroup`, even for a single error — would let callers use a single `except ExceptionGroup` branch and introspect the contained exceptions uniformly. The PR description's own "Robust error handling" bullet even describes the asymmetry as a feature, but the caller-side complexity it creates is real. At minimum the docstring should warn that callers must handle **both** exception shapes.

How can I resolve this? If you propose a fix, please make it concise.


return tuple(results[i] for i in range(len(calls)))
95 changes: 94 additions & 1 deletion tests/unit/tools/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Unit test fixtures for the tool system."""

import asyncio
from typing import Any

import pytest
Expand Down Expand Up @@ -297,7 +298,9 @@ def sample_tool_call() -> ToolCall:

@pytest.fixture
def extended_invoker() -> ToolInvoker:
"""Invoker with additional edge-case tools for advanced tests."""
"""Invoker with echo, recursion, invalid-schema, empty-error,
remote-ref, and mutating tools for edge-case tests.
"""
tools = [
_EchoTestTool(),
_RecursionTool(),
Expand All @@ -307,3 +310,93 @@ def extended_invoker() -> ToolInvoker:
_MutatingTool(),
]
return ToolInvoker(ToolRegistry(tools))


# ── Concurrency test tools ───────────────────────────────────────


class _DelayTool(BaseTool):
"""Sleeps for ``delay`` seconds, then returns ``value``."""

def __init__(self) -> None:
super().__init__(
name="delay",
description="Sleeps then returns value",
parameters_schema={
"type": "object",
"properties": {
"delay": {"type": "number"},
"value": {"type": "string"},
},
"required": ["delay", "value"],
"additionalProperties": False,
},
)

async def execute(
self,
*,
arguments: dict[str, Any],
) -> ToolExecutionResult:
await asyncio.sleep(arguments["delay"])
return ToolExecutionResult(content=arguments["value"])


class _ConcurrencyTrackingTool(BaseTool):
"""Tracks peak concurrent executions via a lock-guarded counter."""

def __init__(self) -> None:
super().__init__(
name="tracking",
description="Tracks concurrency",
parameters_schema={
"type": "object",
"properties": {
"duration": {"type": "number"},
},
"required": ["duration"],
"additionalProperties": False,
},
)
self._lock = asyncio.Lock()
self._current = 0
self._peak = 0

@property
def peak(self) -> int:
"""Return the peak concurrent execution count."""
return self._peak

async def execute(
self,
*,
arguments: dict[str, Any],
) -> ToolExecutionResult:
async with self._lock:
self._current += 1
self._peak = max(self._peak, self._current)
await asyncio.sleep(arguments["duration"])
async with self._lock:
self._current -= 1
return ToolExecutionResult(content=str(self._peak))
Comment on lines +374 to +381

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_ConcurrencyTrackingTool.peak value returned as content may be stale

Inside execute, the content returned is str(self._peak) captured at the moment the individual task finishes, not after all concurrent tasks have exited. Because sibling tasks are still running when a given task releases the lock, the _peak value written to content may be lower than the true final peak:

await asyncio.sleep(arguments["duration"])      # other tasks still running
async with self._lock:
    self._current -= 1
return ToolExecutionResult(content=str(self._peak))   # stale peak

The tests don't assert on result.content for the tracking tool (they correctly read concurrency_tracking_tool.peak after all invocations), so this doesn't affect test correctness today. However, if content is ever used as evidence of peak concurrency elsewhere, it will silently give wrong numbers. Consider capturing peak from the shared state after all tasks complete, or removing it from the result content altogether since the fixture exposes .peak directly.

Prompt To Fix With AI
This is a comment left during a code review.
Path: tests/unit/tools/conftest.py
Line: 374-381

Comment:
**`_ConcurrencyTrackingTool.peak` value returned as content may be stale**

Inside `execute`, the content returned is `str(self._peak)` captured at the moment the individual task finishes, not after all concurrent tasks have exited. Because sibling tasks are still running when a given task releases the lock, the `_peak` value written to `content` may be lower than the true final peak:

```python
await asyncio.sleep(arguments["duration"])      # other tasks still running
async with self._lock:
    self._current -= 1
return ToolExecutionResult(content=str(self._peak))   # stale peak
```

The tests don't assert on `result.content` for the tracking tool (they correctly read `concurrency_tracking_tool.peak` after all invocations), so this doesn't affect test correctness today. However, if `content` is ever used as evidence of peak concurrency elsewhere, it will silently give wrong numbers. Consider capturing peak from the shared state after all tasks complete, or removing it from the result content altogether since the fixture exposes `.peak` directly.

How can I resolve this? If you propose a fix, please make it concise.



@pytest.fixture
def concurrency_tracking_tool() -> _ConcurrencyTrackingTool:
"""Standalone tracking tool for direct peak inspection."""
return _ConcurrencyTrackingTool()
Comment on lines +384 to +387

Copilot AI Mar 6, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The concurrency_tracking_tool fixture creates a new instance, and the concurrency_invoker fixture receives it as a parameter and includes it in the tool list. However, the _ConcurrencyTrackingTool creates its asyncio.Lock() at __init__ time. If tests run under different event loops (e.g., with pytest-asyncio creating a new loop per test), the lock will be bound to the wrong loop in Python 3.9 (though this is fine in 3.10+ where locks are not bound to a loop). More importantly, the reset() method is never called between tests — since the fixture has no explicit scope and defaults to function scope, a new instance is created per test, so this is actually fine. The reset() method appears to be dead code.

Copilot uses AI. Check for mistakes.


@pytest.fixture
def concurrency_invoker(
concurrency_tracking_tool: _ConcurrencyTrackingTool,
) -> ToolInvoker:
"""Invoker with echo, failing, delay, tracking, and recursion tools."""
tools: list[BaseTool] = [
_EchoTestTool(),
_FailingTool(),
_DelayTool(),
concurrency_tracking_tool,
_RecursionTool(),
]
return ToolInvoker(ToolRegistry(tools))
Loading