diff --git a/pyproject.toml b/pyproject.toml index dadfc3e79b..5151851cd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ classifiers = [ ] dependencies = [ "jinja2==3.1.6", + "jsonschema==4.26.0", "litellm==1.82.0", "pydantic==2.12.5", "pyyaml==6.0.3", @@ -150,6 +151,10 @@ plugins = ["pydantic.mypy"] module = "litellm.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "jsonschema.*" +ignore_missing_imports = true + [tool.pydantic-mypy] init_forbid_extra = true init_typed = true diff --git a/src/ai_company/observability/events.py b/src/ai_company/observability/events.py index a7cef1cdb6..15eeb7fbd7 100644 --- a/src/ai_company/observability/events.py +++ b/src/ai_company/observability/events.py @@ -108,6 +108,23 @@ BUDGET_TIME_RANGE_INVALID: Final[str] = "budget.time_range.invalid" BUDGET_DEPARTMENT_RESOLVE_FAILED: Final[str] = "budget.department.resolve_failed" +# ── Tool lifecycle ──────────────────────────────────────────────── + +TOOL_REGISTRY_BUILT: Final[str] = "tool.registry.built" +TOOL_REGISTRY_DUPLICATE: Final[str] = "tool.registry.duplicate" +TOOL_NOT_FOUND: Final[str] = "tool.not_found" +TOOL_INVOKE_START: Final[str] = "tool.invoke.start" +TOOL_INVOKE_SUCCESS: Final[str] = "tool.invoke.success" +TOOL_INVOKE_TOOL_ERROR: Final[str] = "tool.invoke.tool_error" +TOOL_INVOKE_NOT_FOUND: Final[str] = "tool.invoke.not_found" +TOOL_INVOKE_PARAMETER_ERROR: Final[str] = "tool.invoke.parameter_error" +TOOL_INVOKE_SCHEMA_ERROR: Final[str] = "tool.invoke.schema_error" +TOOL_INVOKE_EXECUTION_ERROR: Final[str] = "tool.invoke.execution_error" +TOOL_INVOKE_NON_RECOVERABLE: Final[str] = "tool.invoke.non_recoverable" +TOOL_INVOKE_VALIDATION_UNEXPECTED: Final[str] = "tool.invoke.validation_unexpected" +TOOL_BASE_INVALID_NAME: Final[str] = "tool.base.invalid_name" +TOOL_REGISTRY_CONTAINS_TYPE_ERROR: Final[str] = "tool.registry.contains_type_error" + # ── Role catalog ────────────────────────────────────────────────── ROLE_LOOKUP_MISS: 
class ToolExecutionResult(BaseModel):
    """Outcome produced by a tool's own business logic.

    ``BaseTool.execute`` returns this type. The model is frozen, so
    fields cannot be reassigned after construction.

    Note:
        Freezing is shallow: ``metadata`` cannot be rebound, but the
        dict it refers to can still be mutated in place. Treat it as
        read-only.

    Attributes:
        content: Textual output of the tool.
        is_error: ``True`` when the tool reports a failure.
        metadata: Optional structured data for programmatic consumers.
    """

    model_config = ConfigDict(frozen=True)

    content: str = Field(description="Tool output")
    is_error: bool = Field(default=False, description="Whether tool errored")
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Optional structured metadata",
    )


class BaseTool(ABC):
    """Common base for every tool.

    Concrete tools supply ``execute``. Name, description and parameter
    schema are fixed at construction time and exposed through read-only
    properties; ``to_definition`` packages them as a ``ToolDefinition``.

    Attributes:
        name: Non-blank tool name.
        description: Human-readable description of the tool.
        parameters_schema: JSON Schema dict for the arguments, or
            ``None`` when the tool declares no schema.
    """

    def __init__(
        self,
        *,
        name: str,
        description: str = "",
        parameters_schema: dict[str, Any] | None = None,
    ) -> None:
        """Store the tool's identity and a private copy of its schema.

        Args:
            name: Non-blank tool name.
            description: Human-readable description.
            parameters_schema: JSON Schema for tool parameters.

        Raises:
            ValueError: If ``name`` is empty or whitespace-only.
        """
        has_visible_name = bool(name and name.strip())
        if not has_visible_name:
            logger.warning(TOOL_BASE_INVALID_NAME, name=repr(name))
            msg = "Tool name must not be empty or whitespace-only"
            raise ValueError(msg)

        self._name = name
        self._description = description
        # Deep-copy so later edits to the caller's dict cannot leak in.
        schema_snapshot: dict[str, Any] | None = None
        if parameters_schema is not None:
            schema_snapshot = copy.deepcopy(parameters_schema)
        self._parameters_schema: dict[str, Any] | None = schema_snapshot

    @property
    def name(self) -> str:
        """Tool name."""
        return self._name

    @property
    def description(self) -> str:
        """Tool description."""
        return self._description

    @property
    def parameters_schema(self) -> dict[str, Any] | None:
        """JSON Schema for the tool's parameters, or ``None`` if unset.

        Every access hands out a fresh deep copy, so callers can never
        mutate the schema held by the tool.
        """
        stored = self._parameters_schema
        if stored is None:
            return None
        return copy.deepcopy(stored)

    def to_definition(self) -> ToolDefinition:
        """Build the ``ToolDefinition`` describing this tool.

        Returns:
            A ``ToolDefinition`` carrying the name, the description and
            the schema; a missing or empty schema becomes ``{}``.
        """
        schema = self.parameters_schema
        if not schema:
            schema = {}
        return ToolDefinition(
            name=self._name,
            description=self._description,
            parameters_schema=schema,
        )

    @abstractmethod
    async def execute(
        self,
        *,
        arguments: dict[str, Any],
    ) -> ToolExecutionResult:
        """Run the tool with the supplied arguments.

        When the call arrives through ``ToolInvoker`` and the tool has a
        JSON Schema, the arguments have already been validated against
        it. Tools without a schema receive the arguments unvalidated.

        Args:
            arguments: Parsed arguments matching the parameters schema.

        Returns:
            A ``ToolExecutionResult`` with the tool output.
        """
+ +All tool errors carry an immutable context mapping for structured +metadata. Unlike provider errors, tool errors have no ``is_retryable`` +flag — retry decisions are made at higher layers. +""" + +from types import MappingProxyType +from typing import Any + + +class ToolError(Exception): + """Base exception for all tool-layer errors. + + Attributes: + message: Human-readable error description. + context: Immutable metadata about the error (tool name, etc.). + """ + + def __init__( + self, + message: str, + *, + context: dict[str, Any] | None = None, + ) -> None: + """Initialize a tool error. + + Args: + message: Human-readable error description. + context: Arbitrary metadata about the error. Stored as an + immutable mapping; defaults to empty if not provided. + """ + self.message = message + self.context: MappingProxyType[str, Any] = MappingProxyType( + dict(context) if context else {}, + ) + super().__init__(message) + + def __str__(self) -> str: + """Format error with optional context metadata.""" + if self.context: + ctx = ", ".join(f"{k}={v!r}" for k, v in self.context.items()) + return f"{self.message} ({ctx})" + return self.message + + +class ToolNotFoundError(ToolError): + """Requested tool is not registered in the registry.""" + + +class ToolParameterError(ToolError): + """Tool parameters failed schema validation.""" + + +class ToolExecutionError(ToolError): + """Tool execution raised an unexpected error.""" diff --git a/src/ai_company/tools/examples/__init__.py b/src/ai_company/tools/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/ai_company/tools/examples/echo.py b/src/ai_company/tools/examples/echo.py new file mode 100644 index 0000000000..6a81255a37 --- /dev/null +++ b/src/ai_company/tools/examples/echo.py @@ -0,0 +1,49 @@ +"""Echo tool — returns the input message unchanged. + +A minimal reference implementation of ``BaseTool`` useful for testing +and as a starting point for new tool implementations. 
+""" + +from typing import Any + +from ai_company.tools.base import BaseTool, ToolExecutionResult + + +class EchoTool(BaseTool): + """Echoes the input message back as the tool result. + + Examples: + Basic usage:: + + tool = EchoTool() + result = await tool.execute(arguments={"message": "hello"}) + assert result.content == "hello" + """ + + def __init__(self) -> None: + """Initialize the echo tool with a fixed schema.""" + super().__init__( + name="echo", + description="Echoes the input message back", + parameters_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + "additionalProperties": False, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + """Return the ``message`` argument as content. + + Args: + arguments: Must contain a ``message`` key with a string value. + + Returns: + A ``ToolExecutionResult`` with the message as content. + """ + return ToolExecutionResult(content=arguments["message"]) diff --git a/src/ai_company/tools/invoker.py b/src/ai_company/tools/invoker.py new file mode 100644 index 0000000000..cfa1324047 --- /dev/null +++ b/src/ai_company/tools/invoker.py @@ -0,0 +1,306 @@ +"""Tool invoker — validates and executes tool calls. + +Bridges LLM ``ToolCall`` objects with concrete ``BaseTool.execute`` +methods. Never propagates exceptions — always returns a ``ToolResult``. + +Note: + ``BaseException`` subclasses (``KeyboardInterrupt``, ``SystemExit``, + ``asyncio.CancelledError``) are NOT caught and will propagate + normally. Non-recoverable errors (``MemoryError``, + ``RecursionError``) are re-raised after logging. 
+""" + +from typing import TYPE_CHECKING + +import jsonschema +from referencing import Registry as JsonSchemaRegistry +from referencing import Resource +from referencing.exceptions import NoSuchResource + +from ai_company.observability import get_logger +from ai_company.observability.events import ( + TOOL_INVOKE_EXECUTION_ERROR, + TOOL_INVOKE_NON_RECOVERABLE, + TOOL_INVOKE_NOT_FOUND, + TOOL_INVOKE_PARAMETER_ERROR, + TOOL_INVOKE_SCHEMA_ERROR, + TOOL_INVOKE_START, + TOOL_INVOKE_SUCCESS, + TOOL_INVOKE_TOOL_ERROR, + TOOL_INVOKE_VALIDATION_UNEXPECTED, +) +from ai_company.providers.models import ToolCall, ToolResult + +from .errors import ToolExecutionError, ToolNotFoundError, ToolParameterError + +if TYPE_CHECKING: + from collections.abc import Iterable + + from .base import BaseTool, ToolExecutionResult + from .registry import ToolRegistry + +logger = get_logger(__name__) + + +def _no_remote_retrieve(uri: str) -> Resource: + """Block remote ``$ref`` resolution to prevent SSRF.""" + raise NoSuchResource(uri) + + +_SAFE_REGISTRY: JsonSchemaRegistry = JsonSchemaRegistry( # type: ignore[call-arg] + retrieve=_no_remote_retrieve, +) + + +class ToolInvoker: + """Validates parameters and executes tool calls against a registry. + + Recoverable errors are returned as ``ToolResult(is_error=True)``. + Non-recoverable errors (``MemoryError``, ``RecursionError``) are + re-raised after logging. + + Examples: + Invoke a single tool call:: + + invoker = ToolInvoker(registry) + result = await invoker.invoke(tool_call) + + Invoke multiple tool calls sequentially:: + + results = await invoker.invoke_all(tool_calls) + """ + + def __init__(self, registry: ToolRegistry) -> None: + """Initialize with a tool registry. + + Args: + registry: Registry to look up tools from. + """ + self._registry = registry + + async def invoke(self, tool_call: ToolCall) -> ToolResult: + """Execute a single tool call. + + Steps: + 1. Look up the tool in the registry. + 2. 
Validate arguments against the tool's JSON Schema (if any). + 3. Call ``tool.execute(arguments=...)``. + 4. Return a ``ToolResult`` with the output. + + Recoverable errors produce ``ToolResult(is_error=True)``. + Non-recoverable errors are re-raised. + + Args: + tool_call: The tool call from the LLM. + + Returns: + A ``ToolResult`` with the tool's output or error message. + """ + logger.info( + TOOL_INVOKE_START, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + ) + + tool_or_error = self._lookup_tool(tool_call) + if isinstance(tool_or_error, ToolResult): + return tool_or_error + + param_error = self._validate_params(tool_or_error, tool_call) + if param_error is not None: + return param_error + + exec_result = await self._execute_tool(tool_or_error, tool_call) + if isinstance(exec_result, ToolResult): + return exec_result + + return self._build_result(tool_call, exec_result) + + def _lookup_tool(self, tool_call: ToolCall) -> BaseTool | ToolResult: + """Look up a tool in the registry, returning an error on miss.""" + try: + return self._registry.get(tool_call.name) + except ToolNotFoundError as exc: + logger.warning( + TOOL_INVOKE_NOT_FOUND, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + ) + return ToolResult( + tool_call_id=tool_call.id, + content=str(exc), + is_error=True, + ) + + def _validate_params( + self, + tool: BaseTool, + tool_call: ToolCall, + ) -> ToolResult | None: + """Validate tool call arguments against JSON Schema. + + Returns ``None`` on success or a ``ToolResult`` on failure. 
+ """ + schema = tool.parameters_schema + if schema is None: + return None + try: + jsonschema.validate( + instance=dict(tool_call.arguments), + schema=schema, + registry=_SAFE_REGISTRY, + ) + except jsonschema.SchemaError as exc: + return self._schema_error_result(tool_call, exc.message) + except jsonschema.ValidationError as exc: + return self._param_error_result(tool_call, exc.message) + except (MemoryError, RecursionError) as exc: + logger.exception( + TOOL_INVOKE_NON_RECOVERABLE, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + error=f"{type(exc).__name__}: {exc}", + ) + raise + except Exception as exc: + error_msg = str(exc) or f"{type(exc).__name__} (no message)" + return self._unexpected_validation_result(tool_call, error_msg) + return None + + def _schema_error_result( + self, + tool_call: ToolCall, + error_msg: str, + ) -> ToolResult: + """Build an error result for an invalid tool schema.""" + logger.exception( + TOOL_INVOKE_SCHEMA_ERROR, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + error=error_msg, + ) + return ToolResult( + tool_call_id=tool_call.id, + content=( + f"Tool {tool_call.name!r} has an invalid parameter schema: {error_msg}" + ), + is_error=True, + ) + + def _param_error_result( + self, + tool_call: ToolCall, + error_msg: str, + ) -> ToolResult: + """Build an error result for failed parameter validation.""" + logger.warning( + TOOL_INVOKE_PARAMETER_ERROR, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + error=error_msg, + ) + param_err = ToolParameterError( + error_msg, + context={"tool": tool_call.name}, + ) + return ToolResult( + tool_call_id=tool_call.id, + content=str(param_err), + is_error=True, + ) + + def _unexpected_validation_result( + self, + tool_call: ToolCall, + error_msg: str, + ) -> ToolResult: + """Build an error result for unexpected validation failures.""" + logger.exception( + TOOL_INVOKE_VALIDATION_UNEXPECTED, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + error=error_msg, + ) + 
return ToolResult( + tool_call_id=tool_call.id, + content=( + f"Tool {tool_call.name!r} parameter validation failed: {error_msg}" + ), + is_error=True, + ) + + async def _execute_tool( + self, + tool: BaseTool, + tool_call: ToolCall, + ) -> ToolExecutionResult | ToolResult: + """Execute the tool, catching errors as ``ToolResult``.""" + try: + return await tool.execute(arguments=dict(tool_call.arguments)) + except (MemoryError, RecursionError) as exc: + logger.exception( + TOOL_INVOKE_NON_RECOVERABLE, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + error=f"{type(exc).__name__}: {exc}", + ) + raise + except Exception as exc: + error_msg = str(exc) or f"{type(exc).__name__} (no message)" + logger.exception( + TOOL_INVOKE_EXECUTION_ERROR, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + error=error_msg, + ) + exec_err = ToolExecutionError( + error_msg, + context={"tool": tool_call.name}, + ) + return ToolResult( + tool_call_id=tool_call.id, + content=str(exec_err), + is_error=True, + ) + + def _build_result( + self, + tool_call: ToolCall, + result: ToolExecutionResult, + ) -> ToolResult: + """Map a successful execution result to a ``ToolResult``.""" + if result.is_error: + logger.warning( + TOOL_INVOKE_TOOL_ERROR, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + content=result.content, + ) + else: + logger.info( + TOOL_INVOKE_SUCCESS, + tool_call_id=tool_call.id, + tool_name=tool_call.name, + ) + return ToolResult( + tool_call_id=tool_call.id, + content=result.content, + is_error=result.is_error, + ) + + async def invoke_all( + self, + tool_calls: Iterable[ToolCall], + ) -> tuple[ToolResult, ...]: + """Execute multiple tool calls sequentially. + + Calls continue through recoverable failures; non-recoverable + errors propagate immediately. + + Args: + tool_calls: Tool calls to execute in order. + + Returns: + Tuple of results in the same order as the input. 
+ """ + return tuple([await self.invoke(call) for call in tool_calls]) diff --git a/src/ai_company/tools/registry.py b/src/ai_company/tools/registry.py new file mode 100644 index 0000000000..37de115f68 --- /dev/null +++ b/src/ai_company/tools/registry.py @@ -0,0 +1,122 @@ +"""Tool registry — maps tool names to ``BaseTool`` instances. + +Immutable after construction. Provides lookup, membership testing, +and conversion to a tuple of ``ToolDefinition`` objects for LLM providers. +""" + +from types import MappingProxyType +from typing import TYPE_CHECKING + +from ai_company.observability import get_logger +from ai_company.observability.events import ( + TOOL_NOT_FOUND, + TOOL_REGISTRY_BUILT, + TOOL_REGISTRY_CONTAINS_TYPE_ERROR, + TOOL_REGISTRY_DUPLICATE, +) + +from .errors import ToolNotFoundError + +if TYPE_CHECKING: + from collections.abc import Iterable + + from ai_company.providers.models import ToolDefinition + + from .base import BaseTool + +logger = get_logger(__name__) + + +class ToolRegistry: + """Immutable registry of named tools. + + Examples: + Build from a list of tools:: + + registry = ToolRegistry([echo_tool, search_tool]) + tool = registry.get("echo") + + Check membership:: + + if "echo" in registry: + ... + """ + + def __init__(self, tools: Iterable[BaseTool]) -> None: + """Initialize with an iterable of tools. + + Args: + tools: Tools to register. Duplicate names raise ``ValueError``. + + Raises: + ValueError: If two tools share the same name. + """ + mapping: dict[str, BaseTool] = {} + for tool in tools: + if tool.name in mapping: + logger.warning( + TOOL_REGISTRY_DUPLICATE, + tool_name=tool.name, + ) + msg = f"Duplicate tool name: {tool.name!r}" + raise ValueError(msg) + mapping[tool.name] = tool + self._tools: MappingProxyType[str, BaseTool] = MappingProxyType(mapping) + logger.info( + TOOL_REGISTRY_BUILT, + tool_count=len(self._tools), + tools=sorted(self._tools), + ) + + def get(self, name: str) -> BaseTool: + """Look up a tool by name. 
+ + Args: + name: Tool name. + + Returns: + The registered tool instance. + + Raises: + ToolNotFoundError: If no tool is registered with that name. + """ + tool = self._tools.get(name) + if tool is None: + available = sorted(self._tools) or ["(none)"] + logger.warning( + TOOL_NOT_FOUND, + tool_name=name, + available=available, + ) + msg = ( + f"Tool {name!r} is not registered. " + f"Available tools: {', '.join(available)}" + ) + raise ToolNotFoundError(msg, context={"tool": name}) + return tool + + def list_tools(self) -> tuple[str, ...]: + """Return sorted tuple of registered tool names.""" + return tuple(sorted(self._tools)) + + def to_definitions(self) -> tuple[ToolDefinition, ...]: + """Return all tool definitions as a sorted tuple, ordered by name. + + Returns: + Sorted tuple of tool definitions for LLM providers. + """ + return tuple(self._tools[name].to_definition() for name in sorted(self._tools)) + + def __contains__(self, name: object) -> bool: + """Check whether a tool name is registered.""" + if not isinstance(name, str): + logger.debug( + TOOL_REGISTRY_CONTAINS_TYPE_ERROR, + name_type=type(name).__name__, + ) + return False + return name in self._tools + + def __len__(self) -> int: + """Return the number of registered tools.""" + return len(self._tools) diff --git a/tests/unit/tools/__init__.py b/tests/unit/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/tools/conftest.py b/tests/unit/tools/conftest.py new file mode 100644 index 0000000000..33a5bfd448 --- /dev/null +++ b/tests/unit/tools/conftest.py @@ -0,0 +1,282 @@ +"""Unit test fixtures for the tool system.""" + +from typing import Any + +import pytest + +from ai_company.providers.models import ToolCall +from ai_company.tools.base import BaseTool, ToolExecutionResult +from ai_company.tools.invoker import ToolInvoker +from ai_company.tools.registry import ToolRegistry + +# ── Concrete test tools (private to tests) ──────────────────────── + + +class 
_EchoTestTool(BaseTool): + """Returns arguments as content.""" + + def __init__(self) -> None: + super().__init__( + name="echo_test", + description="Echoes arguments back", + parameters_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + "additionalProperties": False, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content=arguments.get("message", "")) + + +class _FailingTool(BaseTool): + """Always raises RuntimeError in execute.""" + + def __init__(self) -> None: + super().__init__( + name="failing", + description="Always fails", + parameters_schema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + msg = "tool execution failed" + raise RuntimeError(msg) + + +class _NoSchemaTool(BaseTool): + """Tool with no parameters schema.""" + + def __init__(self) -> None: + super().__init__( + name="no_schema", + description="Accepts anything", + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content="ok") + + +class _StrictSchemaTool(BaseTool): + """Tool with strict schema: requires query + limit, no extras.""" + + def __init__(self) -> None: + super().__init__( + name="strict", + description="Strict parameters", + parameters_schema={ + "type": "object", + "properties": { + "query": {"type": "string"}, + "limit": {"type": "integer"}, + }, + "required": ["query", "limit"], + "additionalProperties": False, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult( + content=f"query={arguments['query']} limit={arguments['limit']}", + ) + + +class _SoftErrorTool(BaseTool): + """Returns is_error=True without raising an exception.""" + + def __init__(self) -> None: + 
super().__init__( + name="soft_error", + description="Reports a soft error", + parameters_schema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content="soft fail", is_error=True) + + +class _RecursionTool(BaseTool): + """Raises RecursionError in execute.""" + + def __init__(self) -> None: + super().__init__( + name="recursion", + description="Raises RecursionError", + parameters_schema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + msg = "maximum recursion depth" + raise RecursionError(msg) + + +class _InvalidSchemaTool(BaseTool): + """Tool with an invalid JSON Schema (properties is not a dict).""" + + def __init__(self) -> None: + super().__init__( + name="invalid_schema", + description="Has invalid schema", + parameters_schema={"type": "object", "properties": "not_a_dict"}, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content="ok") + + +class _EmptyErrorTool(BaseTool): + """Raises exception with empty string message.""" + + def __init__(self) -> None: + super().__init__( + name="empty_error", + description="Raises with empty message", + parameters_schema={ + "type": "object", + "properties": {"input": {"type": "string"}}, + }, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + msg = "" + raise ValueError(msg) + + +class _RemoteRefTool(BaseTool): + """Tool with a remote ``$ref`` in its schema (for SSRF testing).""" + + def __init__(self) -> None: + super().__init__( + name="remote_ref", + description="Has remote ref in schema", + parameters_schema={ + "type": "object", + "properties": { + "data": {"$ref": "http://evil.example.com/schema.json"}, + }, + }, + ) + + 
async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content="ok") + + +# ── Fixtures ────────────────────────────────────────────────────── + + +@pytest.fixture +def echo_test_tool() -> _EchoTestTool: + return _EchoTestTool() + + +@pytest.fixture +def failing_tool() -> _FailingTool: + return _FailingTool() + + +@pytest.fixture +def no_schema_tool() -> _NoSchemaTool: + return _NoSchemaTool() + + +@pytest.fixture +def strict_schema_tool() -> _StrictSchemaTool: + return _StrictSchemaTool() + + +@pytest.fixture +def soft_error_tool() -> _SoftErrorTool: + return _SoftErrorTool() + + +@pytest.fixture +def sample_registry( + echo_test_tool: _EchoTestTool, + failing_tool: _FailingTool, + no_schema_tool: _NoSchemaTool, + strict_schema_tool: _StrictSchemaTool, + soft_error_tool: _SoftErrorTool, +) -> ToolRegistry: + return ToolRegistry( + [ + echo_test_tool, + failing_tool, + no_schema_tool, + strict_schema_tool, + soft_error_tool, + ], + ) + + +@pytest.fixture +def sample_invoker(sample_registry: ToolRegistry) -> ToolInvoker: + return ToolInvoker(sample_registry) + + +@pytest.fixture +def sample_tool_call() -> ToolCall: + return ToolCall( + id="call_001", + name="echo_test", + arguments={"message": "hello"}, + ) + + +@pytest.fixture +def extended_invoker() -> ToolInvoker: + """Invoker with additional edge-case tools for advanced tests.""" + tools = [ + _EchoTestTool(), + _RecursionTool(), + _InvalidSchemaTool(), + _EmptyErrorTool(), + _RemoteRefTool(), + ] + return ToolInvoker(ToolRegistry(tools)) diff --git a/tests/unit/tools/test_base.py b/tests/unit/tools/test_base.py new file mode 100644 index 0000000000..0aec649c3c --- /dev/null +++ b/tests/unit/tools/test_base.py @@ -0,0 +1,145 @@ +"""Tests for BaseTool ABC and ToolExecutionResult.""" + +from typing import Any + +import pytest +from pydantic import ValidationError + +from ai_company.providers.models import ToolDefinition +from ai_company.tools.base 
import BaseTool, ToolExecutionResult + +pytestmark = pytest.mark.timeout(30) + + +# ── ToolExecutionResult ────────────────────────────────────────── + + +@pytest.mark.unit +class TestToolExecutionResult: + """Tests for ToolExecutionResult model.""" + + def test_defaults(self) -> None: + result = ToolExecutionResult(content="output") + assert result.content == "output" + assert result.is_error is False + assert result.metadata == {} + + def test_custom_values(self) -> None: + result = ToolExecutionResult( + content="error output", + is_error=True, + metadata={"code": 42}, + ) + assert result.content == "error output" + assert result.is_error is True + assert result.metadata == {"code": 42} + + def test_frozen(self) -> None: + result = ToolExecutionResult(content="output") + with pytest.raises(ValidationError): + result.content = "modified" # type: ignore[misc] + + +# ── BaseTool ───────────────────────────────────────────────────── + + +class _ConcreteTool(BaseTool): + """Minimal concrete tool for testing BaseTool.""" + + def __init__( + self, + *, + name: str = "test_tool", + description: str = "A test tool", + parameters_schema: dict[str, Any] | None = None, + ) -> None: + super().__init__( + name=name, + description=description, + parameters_schema=parameters_schema, + ) + + async def execute( + self, + *, + arguments: dict[str, Any], + ) -> ToolExecutionResult: + return ToolExecutionResult(content="executed") + + +@pytest.mark.unit +class TestBaseTool: + """Tests for BaseTool ABC.""" + + def test_properties(self) -> None: + schema = {"type": "object", "properties": {"x": {"type": "string"}}} + tool = _ConcreteTool( + name="my_tool", + description="desc", + parameters_schema=schema, + ) + assert tool.name == "my_tool" + assert tool.description == "desc" + assert tool.parameters_schema == schema + + def test_blank_name_rejected(self) -> None: + with pytest.raises(ValueError, match="must not be empty"): + _ConcreteTool(name="") + + def 
test_whitespace_name_rejected(self) -> None: + with pytest.raises(ValueError, match="must not be empty"): + _ConcreteTool(name=" ") + + def test_default_description_empty(self) -> None: + tool = _ConcreteTool(name="t", description="") + assert tool.description == "" + + def test_default_schema_none(self) -> None: + tool = _ConcreteTool(name="t") + assert tool.parameters_schema is None + + def test_schema_deep_copied_on_construction(self) -> None: + props: dict[str, Any] = {"x": {"type": "string"}} + schema: dict[str, Any] = {"type": "object", "properties": props} + tool = _ConcreteTool(name="t", parameters_schema=schema) + props["y"] = {"type": "integer"} + assert tool.parameters_schema is not None + assert "y" not in tool.parameters_schema["properties"] + + def test_schema_property_returns_copy(self) -> None: + schema: dict[str, Any] = { + "type": "object", + "properties": {"x": {"type": "string"}}, + } + tool = _ConcreteTool(name="t", parameters_schema=schema) + returned = tool.parameters_schema + assert returned is not None + returned["injected"] = True + assert tool.parameters_schema is not None + assert "injected" not in tool.parameters_schema + + def test_to_definition(self) -> None: + schema = {"type": "object", "properties": {"x": {"type": "string"}}} + tool = _ConcreteTool( + name="my_tool", + description="desc", + parameters_schema=schema, + ) + defn = tool.to_definition() + assert isinstance(defn, ToolDefinition) + assert defn.name == "my_tool" + assert defn.description == "desc" + assert defn.parameters_schema == schema + + def test_to_definition_no_schema(self) -> None: + tool = _ConcreteTool(name="t") + defn = tool.to_definition() + assert defn.parameters_schema == {} + + def test_execute_is_abstract(self) -> None: + assert getattr(BaseTool.execute, "__isabstractmethod__", False) is True + + async def test_execute_runs(self) -> None: + tool = _ConcreteTool(name="t") + result = await tool.execute(arguments={}) + assert result.content == "executed" diff 
--git a/tests/unit/tools/test_echo.py b/tests/unit/tools/test_echo.py new file mode 100644 index 0000000000..1e1fe7c6f7 --- /dev/null +++ b/tests/unit/tools/test_echo.py @@ -0,0 +1,77 @@ +"""Tests for EchoTool.""" + +import pytest + +from ai_company.providers.models import ToolCall, ToolDefinition +from ai_company.tools.examples.echo import EchoTool +from ai_company.tools.invoker import ToolInvoker +from ai_company.tools.registry import ToolRegistry + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestEchoToolProperties: + """Tests for EchoTool name, description, schema.""" + + def test_name(self) -> None: + tool = EchoTool() + assert tool.name == "echo" + + def test_description(self) -> None: + tool = EchoTool() + assert tool.description == "Echoes the input message back" + + def test_schema(self) -> None: + tool = EchoTool() + schema = tool.parameters_schema + assert schema is not None + assert schema["type"] == "object" + assert "message" in schema["properties"] + assert schema["required"] == ["message"] + assert schema["additionalProperties"] is False + + def test_to_definition(self) -> None: + tool = EchoTool() + defn = tool.to_definition() + assert isinstance(defn, ToolDefinition) + assert defn.name == "echo" + + +@pytest.mark.unit +class TestEchoToolExecution: + """Tests for EchoTool execute method.""" + + async def test_echoes_message(self) -> None: + tool = EchoTool() + result = await tool.execute(arguments={"message": "hello world"}) + assert result.content == "hello world" + + async def test_is_not_error(self) -> None: + tool = EchoTool() + result = await tool.execute(arguments={"message": "test"}) + assert result.is_error is False + + async def test_metadata_empty(self) -> None: + tool = EchoTool() + result = await tool.execute(arguments={"message": "test"}) + assert result.metadata == {} + + +@pytest.mark.unit +class TestEchoToolIntegration: + """Integration test: registry -> invoker -> invoke with ToolCall.""" + + async def 
test_full_pipeline(self) -> None: + tool = EchoTool() + registry = ToolRegistry([tool]) + invoker = ToolInvoker(registry) + call = ToolCall( + id="call_echo_001", + name="echo", + arguments={"message": "integration test"}, + ) + result = await invoker.invoke(call) + assert result.tool_call_id == "call_echo_001" + assert result.content == "integration test" + assert result.is_error is False diff --git a/tests/unit/tools/test_errors.py b/tests/unit/tools/test_errors.py new file mode 100644 index 0000000000..8f72452b6d --- /dev/null +++ b/tests/unit/tools/test_errors.py @@ -0,0 +1,98 @@ +"""Tests for tool error hierarchy.""" + +import pytest + +from ai_company.tools.errors import ( + ToolError, + ToolExecutionError, + ToolNotFoundError, + ToolParameterError, +) + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestToolError: + """Tests for the base ToolError.""" + + def test_message_stored(self) -> None: + err = ToolError("something broke") + assert err.message == "something broke" + + def test_context_defaults_to_empty(self) -> None: + err = ToolError("oops") + assert err.context == {} + + def test_context_stored(self) -> None: + ctx = {"tool": "echo", "detail": "missing arg"} + err = ToolError("oops", context=ctx) + assert err.context == ctx + + def test_str_without_context(self) -> None: + err = ToolError("broken") + assert str(err) == "broken" + + def test_str_with_context(self) -> None: + err = ToolError("broken", context={"key": "val"}) + assert "broken" in str(err) + assert "key='val'" in str(err) + + def test_is_exception(self) -> None: + assert issubclass(ToolError, Exception) + + +@pytest.mark.unit +class TestErrorHierarchy: + """Tests for all typed error subclasses.""" + + def test_all_subclass_tool_error(self) -> None: + subclasses = [ + ToolNotFoundError, + ToolParameterError, + ToolExecutionError, + ] + for cls in subclasses: + assert issubclass(cls, ToolError) + + def test_catchable_as_tool_error(self) -> None: + err = 
ToolNotFoundError("missing") + with pytest.raises(ToolError): + raise err + + def test_catchable_as_exception(self) -> None: + err = ToolParameterError("bad param") + with pytest.raises(Exception, match="bad param"): + raise err + + +@pytest.mark.unit +class TestContextImmutability: + """Tests for context immutability guarantees.""" + + def test_context_is_immutable(self) -> None: + err = ToolError("oops", context={"key": "val"}) + with pytest.raises(TypeError): + err.context["new_key"] = "new_val" # type: ignore[index] + + def test_original_dict_mutation_does_not_affect_error(self) -> None: + ctx = {"tool": "echo"} + err = ToolError("oops", context=ctx) + ctx["extra"] = "injected" + assert "extra" not in err.context + + +@pytest.mark.unit +class TestErrorFormatting: + """Tests for __str__ formatting across error types.""" + + def test_all_errors_include_message_in_str(self) -> None: + for cls in (ToolNotFoundError, ToolParameterError, ToolExecutionError): + err = cls("test msg", context={"tool": "echo"}) + result = str(err) + assert "test msg" in result + assert "tool='echo'" in result + + def test_no_context_just_message(self) -> None: + err = ToolExecutionError("boom") + assert str(err) == "boom" diff --git a/tests/unit/tools/test_invoker.py b/tests/unit/tools/test_invoker.py new file mode 100644 index 0000000000..cc6df2ffd2 --- /dev/null +++ b/tests/unit/tools/test_invoker.py @@ -0,0 +1,336 @@ +"""Tests for ToolInvoker.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.providers.models import ToolCall, ToolResult + +if TYPE_CHECKING: + from ai_company.tools.invoker import ToolInvoker + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestInvokeSuccess: + """Tests for successful tool invocation.""" + + async def test_invoke_returns_tool_result( + self, + sample_invoker: ToolInvoker, + sample_tool_call: ToolCall, + ) -> None: + result = await sample_invoker.invoke(sample_tool_call) + assert isinstance(result, 
ToolResult) + assert result.content == "hello" + assert result.is_error is False + + async def test_tool_call_id_matches( + self, + sample_invoker: ToolInvoker, + sample_tool_call: ToolCall, + ) -> None: + result = await sample_invoker.invoke(sample_tool_call) + assert result.tool_call_id == sample_tool_call.id + + +@pytest.mark.unit +class TestInvokeNotFound: + """Tests for tool-not-found handling.""" + + async def test_not_found_returns_error_result( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall(id="call_x", name="nonexistent", arguments={}) + result = await sample_invoker.invoke(call) + assert result.is_error is True + assert result.tool_call_id == "call_x" + assert "not registered" in result.content + + async def test_not_found_does_not_raise( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall(id="call_x", name="nonexistent", arguments={}) + result = await sample_invoker.invoke(call) + assert isinstance(result, ToolResult) + + +@pytest.mark.unit +class TestInvokeParameterValidation: + """Tests for parameter schema validation.""" + + async def test_valid_params_accepted( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_strict", + name="strict", + arguments={"query": "hello", "limit": 10}, + ) + result = await sample_invoker.invoke(call) + assert result.is_error is False + assert "query=hello" in result.content + + async def test_invalid_params_returns_error( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_bad", + name="strict", + arguments={"query": "hello", "limit": "not_a_number"}, + ) + result = await sample_invoker.invoke(call) + assert result.is_error is True + assert result.tool_call_id == "call_bad" + + async def test_missing_required_params_returns_error( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_missing", + name="strict", + arguments={"query": "hello"}, + ) + result = await sample_invoker.invoke(call) + 
assert result.is_error is True + + async def test_extra_params_returns_error( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_extra", + name="echo_test", + arguments={"message": "hi", "extra": "nope"}, + ) + result = await sample_invoker.invoke(call) + assert result.is_error is True + + async def test_empty_schema_skips_validation( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_noschema", + name="no_schema", + arguments={"anything": "goes"}, + ) + result = await sample_invoker.invoke(call) + assert result.is_error is False + + +@pytest.mark.unit +class TestInvokeSoftError: + """Tests for tool-reported soft errors (is_error=True without exception).""" + + async def test_soft_error_propagated( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_soft", + name="soft_error", + arguments={"input": "test"}, + ) + result = await sample_invoker.invoke(call) + assert result.is_error is True + assert result.content == "soft fail" + assert result.tool_call_id == "call_soft" + + +@pytest.mark.unit +class TestInvokeExecutionError: + """Tests for execution error handling.""" + + async def test_execution_error_caught( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_fail", + name="failing", + arguments={"input": "test"}, + ) + result = await sample_invoker.invoke(call) + assert result.is_error is True + assert result.tool_call_id == "call_fail" + assert "tool execution failed" in result.content + + async def test_execution_error_does_not_propagate( + self, + sample_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_fail2", + name="failing", + arguments={"input": "test"}, + ) + result = await sample_invoker.invoke(call) + assert isinstance(result, ToolResult) + + +@pytest.mark.unit +class TestInvokeAll: + """Tests for invoke_all method.""" + + async def test_invoke_all_empty( + self, + sample_invoker: ToolInvoker, + ) -> None: + results = 
await sample_invoker.invoke_all([]) + assert results == () + + async def test_invoke_all_multiple( + self, + sample_invoker: ToolInvoker, + ) -> None: + calls = [ + ToolCall(id="c1", name="echo_test", arguments={"message": "a"}), + ToolCall(id="c2", name="echo_test", arguments={"message": "b"}), + ] + results = await sample_invoker.invoke_all(calls) + assert len(results) == 2 + assert results[0].content == "a" + assert results[1].content == "b" + + async def test_invoke_all_mixed_success_and_error( + self, + sample_invoker: ToolInvoker, + ) -> None: + calls = [ + ToolCall(id="c1", name="echo_test", arguments={"message": "ok"}), + ToolCall(id="c2", name="failing", arguments={"input": "x"}), + ToolCall(id="c3", name="echo_test", arguments={"message": "also ok"}), + ] + results = await sample_invoker.invoke_all(calls) + assert len(results) == 3 + assert results[0].is_error is False + assert results[1].is_error is True + assert results[2].is_error is False + + async def test_invoke_all_preserves_order( + self, + sample_invoker: ToolInvoker, + ) -> None: + calls = [ + ToolCall(id="c1", name="echo_test", arguments={"message": "first"}), + ToolCall(id="c2", name="echo_test", arguments={"message": "second"}), + ToolCall(id="c3", name="echo_test", arguments={"message": "third"}), + ] + results = await sample_invoker.invoke_all(calls) + assert results[0].tool_call_id == "c1" + assert results[1].tool_call_id == "c2" + assert results[2].tool_call_id == "c3" + + +@pytest.mark.unit +class TestInvokeNonRecoverableErrors: + """Tests for MemoryError/RecursionError re-raise behavior.""" + + async def test_recursion_error_propagates( + self, + extended_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_recursion", + name="recursion", + arguments={"input": "test"}, + ) + with pytest.raises(RecursionError, match="maximum recursion depth"): + await extended_invoker.invoke(call) + + async def test_recursion_error_not_swallowed_as_tool_result( + self, + extended_invoker: 
ToolInvoker, + ) -> None: + call = ToolCall( + id="call_recursion2", + name="recursion", + arguments={"input": "test"}, + ) + with pytest.raises(RecursionError): + await extended_invoker.invoke(call) + + +@pytest.mark.unit +class TestInvokeSchemaError: + """Tests for invalid tool schema (SchemaError) handling.""" + + async def test_invalid_schema_returns_error( + self, + extended_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_bad_schema", + name="invalid_schema", + arguments={"data": "test"}, + ) + result = await extended_invoker.invoke(call) + assert result.is_error is True + assert result.tool_call_id == "call_bad_schema" + + async def test_invalid_schema_does_not_raise( + self, + extended_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_bad_schema2", + name="invalid_schema", + arguments={"data": "test"}, + ) + result = await extended_invoker.invoke(call) + assert isinstance(result, ToolResult) + + +@pytest.mark.unit +class TestInvokeSsrfProtection: + """Tests for SSRF prevention via blocked remote $ref resolution.""" + + async def test_remote_ref_blocked( + self, + extended_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_ssrf", + name="remote_ref", + arguments={"data": "test"}, + ) + result = await extended_invoker.invoke(call) + assert result.is_error is True + assert result.tool_call_id == "call_ssrf" + + async def test_remote_ref_does_not_raise( + self, + extended_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_ssrf2", + name="remote_ref", + arguments={"data": "test"}, + ) + result = await extended_invoker.invoke(call) + assert isinstance(result, ToolResult) + + +@pytest.mark.unit +class TestInvokeEmptyErrorMessage: + """Tests for empty exception message fallback.""" + + async def test_empty_error_message_fallback( + self, + extended_invoker: ToolInvoker, + ) -> None: + call = ToolCall( + id="call_empty_err", + name="empty_error", + arguments={"input": "test"}, + ) + result = await 
extended_invoker.invoke(call) + assert result.is_error is True + assert "ValueError (no message)" in result.content diff --git a/tests/unit/tools/test_registry.py b/tests/unit/tools/test_registry.py new file mode 100644 index 0000000000..b337e4187d --- /dev/null +++ b/tests/unit/tools/test_registry.py @@ -0,0 +1,121 @@ +"""Tests for ToolRegistry.""" + +from typing import TYPE_CHECKING + +import pytest + +from ai_company.providers.models import ToolDefinition +from ai_company.tools.errors import ToolNotFoundError +from ai_company.tools.registry import ToolRegistry + +if TYPE_CHECKING: + from ai_company.tools.base import BaseTool + +pytestmark = pytest.mark.timeout(30) + + +@pytest.mark.unit +class TestToolRegistryEmpty: + """Tests for an empty registry.""" + + def test_empty_registry_len(self) -> None: + registry = ToolRegistry([]) + assert len(registry) == 0 + + def test_empty_registry_list_tools(self) -> None: + registry = ToolRegistry([]) + assert registry.list_tools() == () + + def test_empty_registry_to_definitions(self) -> None: + registry = ToolRegistry([]) + assert registry.to_definitions() == () + + def test_empty_registry_get_raises(self) -> None: + registry = ToolRegistry([]) + with pytest.raises(ToolNotFoundError, match="not registered"): + registry.get("missing") + + def test_empty_registry_get_shows_none_available(self) -> None: + registry = ToolRegistry([]) + with pytest.raises(ToolNotFoundError, match=r"\(none\)") as exc_info: + registry.get("missing") + assert exc_info.value.context["tool"] == "missing" + + +@pytest.mark.unit +class TestToolRegistrySingle: + """Tests for a registry with one tool.""" + + def test_len(self, echo_test_tool: BaseTool) -> None: + registry = ToolRegistry([echo_test_tool]) + assert len(registry) == 1 + + def test_get_success(self, echo_test_tool: BaseTool) -> None: + registry = ToolRegistry([echo_test_tool]) + assert registry.get("echo_test") is echo_test_tool + + def test_contains(self, echo_test_tool: BaseTool) -> None: 
+ registry = ToolRegistry([echo_test_tool]) + assert "echo_test" in registry + + def test_not_contains(self, echo_test_tool: BaseTool) -> None: + registry = ToolRegistry([echo_test_tool]) + assert "missing" not in registry + + +@pytest.mark.unit +class TestToolRegistryMultiple: + """Tests for a registry with multiple tools.""" + + def test_len(self, sample_registry: ToolRegistry) -> None: + assert len(sample_registry) == 5 + + def test_list_tools_sorted(self, sample_registry: ToolRegistry) -> None: + names = sample_registry.list_tools() + assert names == tuple(sorted(names)) + assert len(names) == 5 + + def test_get_each(self, sample_registry: ToolRegistry) -> None: + for name in sample_registry.list_tools(): + tool = sample_registry.get(name) + assert tool.name == name + + def test_get_not_found(self, sample_registry: ToolRegistry) -> None: + with pytest.raises(ToolNotFoundError, match="not registered"): + sample_registry.get("nonexistent") + + def test_get_not_found_context(self, sample_registry: ToolRegistry) -> None: + with pytest.raises(ToolNotFoundError) as exc_info: + sample_registry.get("nonexistent") + assert exc_info.value.context["tool"] == "nonexistent" + + def test_to_definitions(self, sample_registry: ToolRegistry) -> None: + defs = sample_registry.to_definitions() + assert len(defs) == 5 + assert all(isinstance(d, ToolDefinition) for d in defs) + names = [d.name for d in defs] + assert names == sorted(names) + + def test_contains_non_string(self, sample_registry: ToolRegistry) -> None: + assert 42 not in sample_registry + + def test_contains_unhashable_type(self, sample_registry: ToolRegistry) -> None: + assert [1, 2] not in sample_registry + + +@pytest.mark.unit +class TestToolRegistryDuplicate: + """Tests for duplicate tool name rejection.""" + + def test_duplicate_rejected(self, echo_test_tool: BaseTool) -> None: + with pytest.raises(ValueError, match="Duplicate tool name"): + ToolRegistry([echo_test_tool, echo_test_tool]) + + +@pytest.mark.unit 
+class TestToolRegistryImmutability: + """Tests for registry immutability.""" + + def test_tools_mapping_immutable(self, sample_registry: ToolRegistry) -> None: + with pytest.raises(TypeError): + sample_registry._tools["hack"] = None # type: ignore[index] diff --git a/uv.lock b/uv.lock index eb5286ea7c..fdc37ea07e 100644 --- a/uv.lock +++ b/uv.lock @@ -7,6 +7,7 @@ name = "ai-company" source = { editable = "." } dependencies = [ { name = "jinja2" }, + { name = "jsonschema" }, { name = "litellm" }, { name = "pydantic" }, { name = "pyyaml" }, @@ -44,6 +45,7 @@ test = [ [package.metadata] requires-dist = [ { name = "jinja2", specifier = "==3.1.6" }, + { name = "jsonschema", specifier = "==4.26.0" }, { name = "litellm", specifier = "==1.82.0" }, { name = "pydantic", specifier = "==2.12.5" }, { name = "pyyaml", specifier = "==6.0.3" },