-
Notifications
You must be signed in to change notification settings - Fork 1
feat: parallel tool execution in ToolInvoker.invoke_all #137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,21 +2,24 @@ | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Bridges LLM ``ToolCall`` objects with concrete ``BaseTool.execute`` | ||||||||||||||||||||||||||||||||||||
| methods. Recoverable errors are returned as ``ToolResult(is_error=True)``; | ||||||||||||||||||||||||||||||||||||
| non-recoverable errors (``MemoryError``, ``RecursionError``) and | ||||||||||||||||||||||||||||||||||||
| ``BaseException`` subclasses (``KeyboardInterrupt``, ``SystemExit``, | ||||||||||||||||||||||||||||||||||||
| ``asyncio.CancelledError``) propagate after logging. | ||||||||||||||||||||||||||||||||||||
| non-recoverable errors (``MemoryError``, ``RecursionError``) are logged and | ||||||||||||||||||||||||||||||||||||
| re-raised. ``BaseException`` subclasses (``KeyboardInterrupt``, | ||||||||||||||||||||||||||||||||||||
| ``SystemExit``, ``asyncio.CancelledError``) propagate uncaught. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||||||||||||||
| import copy | ||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING | ||||||||||||||||||||||||||||||||||||
| from contextlib import nullcontext | ||||||||||||||||||||||||||||||||||||
| from typing import TYPE_CHECKING, Never | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| import jsonschema | ||||||||||||||||||||||||||||||||||||
| from referencing import Registry as JsonSchemaRegistry | ||||||||||||||||||||||||||||||||||||
| from referencing import Resource | ||||||||||||||||||||||||||||||||||||
| from referencing.exceptions import NoSuchResource | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| from ai_company.observability import get_logger | ||||||||||||||||||||||||||||||||||||
| from ai_company.observability.events.tool import ( | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_ALL_COMPLETE, | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_ALL_START, | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_DEEPCOPY_ERROR, | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_EXECUTION_ERROR, | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_NON_RECOVERABLE, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -41,7 +44,7 @@ | |||||||||||||||||||||||||||||||||||
| logger = get_logger(__name__) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def _no_remote_retrieve(uri: str) -> Resource: | ||||||||||||||||||||||||||||||||||||
| def _no_remote_retrieve(uri: str) -> Never: | ||||||||||||||||||||||||||||||||||||
| """Block remote ``$ref`` resolution to prevent SSRF.""" | ||||||||||||||||||||||||||||||||||||
| raise NoSuchResource(uri) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -64,9 +67,13 @@ class ToolInvoker: | |||||||||||||||||||||||||||||||||||
| invoker = ToolInvoker(registry) | ||||||||||||||||||||||||||||||||||||
| result = await invoker.invoke(tool_call) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Invoke multiple tool calls sequentially:: | ||||||||||||||||||||||||||||||||||||
| Invoke multiple tool calls concurrently:: | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| results = await invoker.invoke_all(tool_calls) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Limit concurrency:: | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| results = await invoker.invoke_all(tool_calls, max_concurrency=3) | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def __init__(self, registry: ToolRegistry) -> None: | ||||||||||||||||||||||||||||||||||||
|
|
@@ -172,7 +179,7 @@ def _schema_error_result( | |||||||||||||||||||||||||||||||||||
| error_msg: str, | ||||||||||||||||||||||||||||||||||||
| ) -> ToolResult: | ||||||||||||||||||||||||||||||||||||
| """Build an error result for an invalid tool schema.""" | ||||||||||||||||||||||||||||||||||||
| logger.exception( | ||||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_SCHEMA_ERROR, | ||||||||||||||||||||||||||||||||||||
| tool_call_id=tool_call.id, | ||||||||||||||||||||||||||||||||||||
| tool_name=tool_call.name, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -318,19 +325,94 @@ def _build_result( | |||||||||||||||||||||||||||||||||||
| is_error=result.is_error, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| async def _run_guarded( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| index: int, | ||||||||||||||||||||||||||||||||||||
| tool_call: ToolCall, | ||||||||||||||||||||||||||||||||||||
| results: dict[int, ToolResult], | ||||||||||||||||||||||||||||||||||||
| fatal_errors: list[Exception], | ||||||||||||||||||||||||||||||||||||
| semaphore: asyncio.Semaphore | None, | ||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||
| """Execute a single tool call, storing fatal errors instead of raising. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| This wrapper ensures that ``MemoryError`` / ``RecursionError`` do not | ||||||||||||||||||||||||||||||||||||
| cancel sibling tasks inside a ``TaskGroup``. ``BaseException`` | ||||||||||||||||||||||||||||||||||||
| subclasses (``KeyboardInterrupt``, ``CancelledError``) are not | ||||||||||||||||||||||||||||||||||||
| intercepted and will cancel the group normally. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||
| ctx = semaphore if semaphore is not None else nullcontext() | ||||||||||||||||||||||||||||||||||||
| async with ctx: | ||||||||||||||||||||||||||||||||||||
| results[index] = await self.invoke(tool_call) | ||||||||||||||||||||||||||||||||||||
| except (MemoryError, RecursionError) as exc: | ||||||||||||||||||||||||||||||||||||
| fatal_errors.append(exc) | ||||||||||||||||||||||||||||||||||||
|
greptile-apps[bot] marked this conversation as resolved.
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| async def invoke_all( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| tool_calls: Iterable[ToolCall], | ||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||
| max_concurrency: int | None = None, | ||||||||||||||||||||||||||||||||||||
| ) -> tuple[ToolResult, ...]: | ||||||||||||||||||||||||||||||||||||
| """Execute multiple tool calls sequentially. | ||||||||||||||||||||||||||||||||||||
| """Execute multiple tool calls concurrently. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Calls continue through recoverable failures; non-recoverable | ||||||||||||||||||||||||||||||||||||
| errors propagate immediately. | ||||||||||||||||||||||||||||||||||||
| errors (``MemoryError``, ``RecursionError``) are collected and | ||||||||||||||||||||||||||||||||||||
| re-raised after all tasks complete. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||
| tool_calls: Tool calls to execute in order. | ||||||||||||||||||||||||||||||||||||
| tool_calls: Tool calls to execute. | ||||||||||||||||||||||||||||||||||||
| max_concurrency: Maximum number of concurrent invocations. | ||||||||||||||||||||||||||||||||||||
| ``None`` (default) means unbounded. Must be ``>= 1`` | ||||||||||||||||||||||||||||||||||||
| if provided. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||
| Tuple of results in the same order as the input. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||
| ValueError: If ``max_concurrency`` is less than 1. | ||||||||||||||||||||||||||||||||||||
| MemoryError: Re-raised if it was the sole fatal error. | ||||||||||||||||||||||||||||||||||||
| RecursionError: Re-raised if it was the sole fatal error. | ||||||||||||||||||||||||||||||||||||
| ExceptionGroup: If multiple fatal errors occurred. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| return tuple([await self.invoke(call) for call in tool_calls]) | ||||||||||||||||||||||||||||||||||||
| if max_concurrency is not None and max_concurrency < 1: | ||||||||||||||||||||||||||||||||||||
| msg = f"max_concurrency must be >= 1, got {max_concurrency}" | ||||||||||||||||||||||||||||||||||||
| raise ValueError(msg) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| calls = list(tool_calls) | ||||||||||||||||||||||||||||||||||||
| if not calls: | ||||||||||||||||||||||||||||||||||||
| return () | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_ALL_START, | ||||||||||||||||||||||||||||||||||||
| count=len(calls), | ||||||||||||||||||||||||||||||||||||
| max_concurrency=max_concurrency, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # SAFETY: Both ``results`` and ``fatal_errors`` are mutated by | ||||||||||||||||||||||||||||||||||||
| # concurrent tasks. This is safe because asyncio runs tasks on | ||||||||||||||||||||||||||||||||||||
| # a single thread — dict assignment and list.append() never race. | ||||||||||||||||||||||||||||||||||||
| results: dict[int, ToolResult] = {} | ||||||||||||||||||||||||||||||||||||
| fatal_errors: list[Exception] = [] | ||||||||||||||||||||||||||||||||||||
| semaphore = ( | ||||||||||||||||||||||||||||||||||||
| asyncio.Semaphore(max_concurrency) if max_concurrency is not None else None | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| async with asyncio.TaskGroup() as tg: | ||||||||||||||||||||||||||||||||||||
| for idx, call in enumerate(calls): | ||||||||||||||||||||||||||||||||||||
| tg.create_task( | ||||||||||||||||||||||||||||||||||||
| self._run_guarded(idx, call, results, fatal_errors, semaphore), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||
| TOOL_INVOKE_ALL_COMPLETE, | ||||||||||||||||||||||||||||||||||||
| count=len(calls), | ||||||||||||||||||||||||||||||||||||
| fatal_count=len(fatal_errors), | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+406
to
+410
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Consider using a conditional log level when
Suggested change
This ensures that monitoring systems observing Prompt To Fix With AIThis is a comment left during a code review.
Path: src/ai_company/tools/invoker.py
Line: 406-410
Comment:
**`info`-level completion log emitted even when fatal errors will be raised**
`TOOL_INVOKE_ALL_COMPLETE` is logged at `info` regardless of `fatal_count`. Any log aggregator or dashboard that counts `TOOL_INVOKE_ALL_COMPLETE` at `info` as a success metric will miscount — the event is emitted, and then an exception is immediately raised, making the operation an effective failure from the caller's perspective.
Consider using a conditional log level when `fatal_errors` is non-empty:
```suggestion
if fatal_errors:
logger.warning(
TOOL_INVOKE_ALL_COMPLETE,
count=len(calls),
fatal_count=len(fatal_errors),
)
else:
logger.info(
TOOL_INVOKE_ALL_COMPLETE,
count=len(calls),
fatal_count=0,
)
```
This ensures that monitoring systems observing `TOOL_INVOKE_ALL_COMPLETE` at `info` level accurately reflects clean completions only.
How can I resolve this? If you propose a fix, please make it concise. |
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| if fatal_errors: | ||||||||||||||||||||||||||||||||||||
| if len(fatal_errors) == 1: | ||||||||||||||||||||||||||||||||||||
| raise fatal_errors[0] | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+406
to
+414
|
||||||||||||||||||||||||||||||||||||
| msg = "multiple non-recoverable tool errors" | ||||||||||||||||||||||||||||||||||||
| raise ExceptionGroup(msg, fatal_errors) | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+412
to
+416
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Asymmetric exception type makes caller handling error-prone When exactly one fatal error is collected, try:
results = await invoker.invoke_all(calls)
except (MemoryError, RecursionError):
# single-error path
...
except ExceptionGroup as eg:
# multi-error path
...If a caller only writes the first A consistent approach — always wrapping in Prompt To Fix With AIThis is a comment left during a code review.
Path: src/ai_company/tools/invoker.py
Line: 412-416
Comment:
**Asymmetric exception type makes caller handling error-prone**
When exactly one fatal error is collected, `invoke_all` re-raises a bare `MemoryError` or `RecursionError`. When two or more occur, it raises an `ExceptionGroup`. This means every caller that needs to handle fatal errors from `invoke_all` must cover two structurally different exception types:
```python
try:
results = await invoker.invoke_all(calls)
except (MemoryError, RecursionError):
# single-error path
...
except ExceptionGroup as eg:
# multi-error path
...
```
If a caller only writes the first `except` clause (which is the obvious, natural thing to do), it silently misses the `ExceptionGroup` case and the errors go unhandled.
A consistent approach — always wrapping in `ExceptionGroup`, even for a single error — would let callers use a single `except ExceptionGroup` branch and introspect the contained exceptions uniformly. The PR description's own "Robust error handling" bullet even describes the asymmetry as a feature, but the caller-side complexity it creates is real. At minimum the docstring should warn that callers must handle **both** exception shapes.
How can I resolve this? If you propose a fix, please make it concise. |
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| return tuple(results[i] for i in range(len(calls))) | ||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| """Unit test fixtures for the tool system.""" | ||
|
|
||
| import asyncio | ||
| from typing import Any | ||
|
|
||
| import pytest | ||
|
|
@@ -297,7 +298,9 @@ def sample_tool_call() -> ToolCall: | |
|
|
||
| @pytest.fixture | ||
| def extended_invoker() -> ToolInvoker: | ||
| """Invoker with additional edge-case tools for advanced tests.""" | ||
| """Invoker with echo, recursion, invalid-schema, empty-error, | ||
| remote-ref, and mutating tools for edge-case tests. | ||
| """ | ||
| tools = [ | ||
| _EchoTestTool(), | ||
| _RecursionTool(), | ||
|
|
@@ -307,3 +310,93 @@ def extended_invoker() -> ToolInvoker: | |
| _MutatingTool(), | ||
| ] | ||
| return ToolInvoker(ToolRegistry(tools)) | ||
|
|
||
|
|
||
| # ── Concurrency test tools ─────────────────────────────────────── | ||
|
|
||
|
|
||
| class _DelayTool(BaseTool): | ||
| """Sleeps for ``delay`` seconds, then returns ``value``.""" | ||
|
|
||
| def __init__(self) -> None: | ||
| super().__init__( | ||
| name="delay", | ||
| description="Sleeps then returns value", | ||
| parameters_schema={ | ||
| "type": "object", | ||
| "properties": { | ||
| "delay": {"type": "number"}, | ||
| "value": {"type": "string"}, | ||
| }, | ||
| "required": ["delay", "value"], | ||
| "additionalProperties": False, | ||
| }, | ||
| ) | ||
|
|
||
| async def execute( | ||
| self, | ||
| *, | ||
| arguments: dict[str, Any], | ||
| ) -> ToolExecutionResult: | ||
| await asyncio.sleep(arguments["delay"]) | ||
| return ToolExecutionResult(content=arguments["value"]) | ||
|
|
||
|
|
||
| class _ConcurrencyTrackingTool(BaseTool): | ||
| """Tracks peak concurrent executions via a lock-guarded counter.""" | ||
|
|
||
| def __init__(self) -> None: | ||
| super().__init__( | ||
| name="tracking", | ||
| description="Tracks concurrency", | ||
| parameters_schema={ | ||
| "type": "object", | ||
| "properties": { | ||
| "duration": {"type": "number"}, | ||
| }, | ||
| "required": ["duration"], | ||
| "additionalProperties": False, | ||
| }, | ||
| ) | ||
| self._lock = asyncio.Lock() | ||
| self._current = 0 | ||
| self._peak = 0 | ||
|
|
||
| @property | ||
| def peak(self) -> int: | ||
| """Return the peak concurrent execution count.""" | ||
| return self._peak | ||
|
|
||
| async def execute( | ||
| self, | ||
| *, | ||
| arguments: dict[str, Any], | ||
| ) -> ToolExecutionResult: | ||
| async with self._lock: | ||
| self._current += 1 | ||
| self._peak = max(self._peak, self._current) | ||
| await asyncio.sleep(arguments["duration"]) | ||
| async with self._lock: | ||
| self._current -= 1 | ||
| return ToolExecutionResult(content=str(self._peak)) | ||
|
Comment on lines
+374
to
+381
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Inside await asyncio.sleep(arguments["duration"]) # other tasks still running
async with self._lock:
self._current -= 1
return ToolExecutionResult(content=str(self._peak)) # stale peakThe tests don't assert on Prompt To Fix With AIThis is a comment left during a code review.
Path: tests/unit/tools/conftest.py
Line: 374-381
Comment:
**`_ConcurrencyTrackingTool.peak` value returned as content may be stale**
Inside `execute`, the content returned is `str(self._peak)` captured at the moment the individual task finishes, not after all concurrent tasks have exited. Because sibling tasks are still running when a given task releases the lock, the `_peak` value written to `content` may be lower than the true final peak:
```python
await asyncio.sleep(arguments["duration"]) # other tasks still running
async with self._lock:
self._current -= 1
return ToolExecutionResult(content=str(self._peak)) # stale peak
```
The tests don't assert on `result.content` for the tracking tool (they correctly read `concurrency_tracking_tool.peak` after all invocations), so this doesn't affect test correctness today. However, if `content` is ever used as evidence of peak concurrency elsewhere, it will silently give wrong numbers. Consider capturing peak from the shared state after all tasks complete, or removing it from the result content altogether since the fixture exposes `.peak` directly.
How can I resolve this? If you propose a fix, please make it concise. |
||
|
|
||
|
|
||
| @pytest.fixture | ||
| def concurrency_tracking_tool() -> _ConcurrencyTrackingTool: | ||
| """Standalone tracking tool for direct peak inspection.""" | ||
| return _ConcurrencyTrackingTool() | ||
|
Comment on lines
+384
to
+387
|
||
|
|
||
|
|
||
| @pytest.fixture | ||
| def concurrency_invoker( | ||
| concurrency_tracking_tool: _ConcurrencyTrackingTool, | ||
| ) -> ToolInvoker: | ||
| """Invoker with echo, failing, delay, tracking, and recursion tools.""" | ||
| tools: list[BaseTool] = [ | ||
| _EchoTestTool(), | ||
| _FailingTool(), | ||
| _DelayTool(), | ||
| concurrency_tracking_tool, | ||
| _RecursionTool(), | ||
| ] | ||
| return ToolInvoker(ToolRegistry(tools)) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In
_run_guarded,MemoryErrorandRecursionErrorare caught and silently collected, butself.invoke()already catches these same exceptions, logs them withlogger.exception, and re-raises. This means the_run_guardedcatch will see them after they've already been logged with a full traceback vialogger.exceptionin_execute_tool/_validate_params. Thelogger.warninghere then logs a second time at a lower severity. This is intentional dual-logging (once at the origin, once at the orchestration level), but the orchestration log atwarninglevel is misleading — a non-recoverable error is more severe than a warning. Consider usinglogger.errorto be consistent with the severity of the event.