Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/strands/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from .decorator import tool
from .structured_output import convert_pydantic_to_tool_spec
from .thread_pool_executor import ThreadPoolExecutorWrapper
from .tools import FunctionTool, InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec
from .tools import InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec

__all__ = [
"tool",
"FunctionTool",
"PythonAgentTool",
"InvalidToolUseNameException",
"normalize_schema",
Expand Down
131 changes: 1 addition & 130 deletions src/strands/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
Python module-based tools, as well as utilities for validating tool uses and normalizing tool schemas.
"""

import inspect
import logging
import re
from typing import Any, Callable, Dict, Optional, cast

from typing_extensions import Unpack
from typing import Any, Callable, Dict

from ..types.tools import AgentTool, ToolResult, ToolSpec, ToolUse

Expand Down Expand Up @@ -144,132 +141,6 @@ def normalize_tool_spec(tool_spec: ToolSpec) -> ToolSpec:
return normalized


class FunctionTool(AgentTool):
"""Tool implementation for function-based tools created with @tool.

This class adapts Python functions decorated with @tool to the AgentTool interface.
"""

def __new__(cls, *args: Any, **kwargs: Any) -> Any:
"""Compatability shim to allow callers to continue working after the introduction of DecoratedFunctionTool."""
if isinstance(args[0], AgentTool):
return args[0]

return super().__new__(cls)

def __init__(self, func: Callable[[ToolUse, Unpack[Any]], ToolResult], tool_name: Optional[str] = None) -> None:
"""Initialize a function-based tool.

Args:
func: The decorated function.
tool_name: Optional tool name (defaults to function name).

Raises:
ValueError: If func is not decorated with @tool.
"""
super().__init__()

self._func = func

# Get TOOL_SPEC from the decorated function
if hasattr(func, "TOOL_SPEC") and isinstance(func.TOOL_SPEC, dict):
self._tool_spec = cast(ToolSpec, func.TOOL_SPEC)
# Use name from tool spec if available, otherwise use function name or passed tool_name
name = self._tool_spec.get("name", tool_name or func.__name__)
if isinstance(name, str):
self._name = name
else:
raise ValueError(f"Tool name must be a string, got {type(name)}")
else:
raise ValueError(f"Function {func.__name__} is not decorated with @tool")

@property
def tool_name(self) -> str:
"""Get the name of the tool.

Returns:
The name of the tool.
"""
return self._name

@property
def tool_spec(self) -> ToolSpec:
"""Get the tool specification for this function-based tool.

Returns:
The tool specification.
"""
return self._tool_spec

@property
def tool_type(self) -> str:
"""Get the type of the tool.

Returns:
The string "function" indicating this is a function-based tool.
"""
return "function"

@property
def supports_hot_reload(self) -> bool:
"""Check if this tool supports automatic reloading when modified.

Returns:
Always true for function-based tools.
"""
return True

def invoke(self, tool: ToolUse, *args: Any, **kwargs: Any) -> ToolResult:
"""Execute the function with the given tool use request.

Args:
tool: The tool use request containing the tool name, ID, and input parameters.
*args: Additional positional arguments to pass to the function.
**kwargs: Additional keyword arguments to pass to the function.

Returns:
A ToolResult containing the status and content from the function execution.
"""
# Make sure to pass through all kwargs, including 'agent' if provided
try:
# Check if the function accepts agent as a keyword argument
sig = inspect.signature(self._func)
if "agent" in sig.parameters:
# Pass agent if function accepts it
return self._func(tool, **kwargs)
else:
# Skip passing agent if function doesn't accept it
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "agent"}
return self._func(tool, **filtered_kwargs)
except Exception as e:
return {
"toolUseId": tool.get("toolUseId", "unknown"),
"status": "error",
"content": [{"text": f"Error executing function: {str(e)}"}],
}

@property
def original_function(self) -> Callable:
"""Get the original function (without wrapper).

Returns:
Undecorated function.
"""
if hasattr(self._func, "original_function"):
return cast(Callable, self._func.original_function)
return self._func

def get_display_properties(self) -> dict[str, str]:
"""Get properties to display in UI representations.

Returns:
Function properties (e.g., function name).
"""
properties = super().get_display_properties()
properties["Function"] = self.original_function.__name__
return properties


class PythonAgentTool(AgentTool):
"""Tool implementation for Python-based tools.

Expand Down
6 changes: 2 additions & 4 deletions tests-integ/test_model_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import strands
from strands import Agent, tool
from strands.models import BedrockModel

if "OPENAI_API_KEY" not in os.environ:
pytest.skip(allow_module_level=True, reason="OPENAI_API_KEY environment variable missing")
Expand Down Expand Up @@ -67,9 +68,6 @@ class Weather(BaseModel):
assert result.weather == "sunny"


@pytest.skip(
reason="OpenAI provider cannot use tools that return images - https://github.com/strands-agents/sdk-python/issues/320"
)
def test_tool_returning_images(model, test_image_path):
@tool
def tool_with_image_return():
Expand All @@ -88,7 +86,7 @@ def tool_with_image_return():
],
}

agent = Agent(model=model, tools=[tool_with_image_return])
agent = Agent(model=BedrockModel(), tools=[tool_with_image_return])
# NOTE - this currently fails with: "Invalid 'messages[3]'. Image URLs are only allowed for messages with role
# 'user', but this message with role 'tool' contains an image URL."
# See https://github.com/strands-agents/sdk-python/issues/320 for additional details
Expand Down
12 changes: 4 additions & 8 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,8 @@ def tool_imported():

@pytest.fixture
def tool(tool_decorated, tool_registry):
function_tool = strands.tools.tools.FunctionTool(tool_decorated, tool_name="tool_decorated")
tool_registry.register_tool(function_tool)

return function_tool
tool_registry.register_tool(tool_decorated)
return tool_decorated


@pytest.fixture
Expand Down Expand Up @@ -156,8 +154,7 @@ def agent(
# Only register the tool directly if tools wasn't parameterized
if not hasattr(request, "param") or request.param is None:
# Create a new function tool directly from the decorated function
function_tool = strands.tools.tools.FunctionTool(tool_decorated, tool_name="tool_decorated")
agent.tool_registry.register_tool(function_tool)
agent.tool_registry.register_tool(tool_decorated)

return agent

Expand Down Expand Up @@ -810,8 +807,7 @@ def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint):
def function(system_prompt: str) -> str:
return system_prompt

tool = strands.tools.tools.FunctionTool(function)
agent.tool_registry.register_tool(tool)
agent.tool_registry.register_tool(function)

mock_randint.return_value = 1

Expand Down
5 changes: 2 additions & 3 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,9 @@ def tool(tool_registry):
def tool_for_testing(random_string: str) -> str:
return random_string

function_tool = strands.tools.tools.FunctionTool(tool_for_testing)
tool_registry.register_tool(function_tool)
tool_registry.register_tool(tool_for_testing)

return function_tool
return tool_for_testing


@pytest.fixture
Expand Down
16 changes: 1 addition & 15 deletions tests/strands/handlers/test_tool_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,11 @@ def tool_use_identity(tool_registry):
def identity(a: int) -> int:
return a

identity_tool = strands.tools.tools.FunctionTool(identity)
tool_registry.register_tool(identity_tool)
tool_registry.register_tool(identity)

return {"toolUseId": "identity", "name": "identity", "input": {"a": 1}}


@pytest.fixture
def tool_use_error(tool_registry):
def error():
return

error.TOOL_SPEC = {"invalid": True}

error_tool = strands.tools.tools.FunctionTool(error)
tool_registry.register_tool(error_tool)

return {"toolUseId": "error", "name": "error", "input": {}}


def test_process(tool_handler, tool_use_identity):
tru_result = tool_handler.process(
tool_use_identity,
Expand Down
47 changes: 5 additions & 42 deletions tests/strands/tools/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import strands
from strands.tools.tools import (
FunctionTool,
InvalidToolUseNameException,
PythonAgentTool,
normalize_schema,
Expand Down Expand Up @@ -408,13 +407,9 @@ def identity(a: int) -> int:


@pytest.fixture
def tool_function(function):
return strands.tools.tool(function)


@pytest.fixture
def tool(tool_function):
return FunctionTool(tool_function, tool_name="identity")
def tool(function):
tool_function = strands.tools.tool(function)
return tool_function


def test__init__invalid_name():
Expand Down Expand Up @@ -476,9 +471,7 @@ def test_original_function_not_decorated():
def identity(a: int):
return a

identity.TOOL_SPEC = {}

tool = FunctionTool(identity, tool_name="identity")
tool = strands.tool(func=identity, name="identity")

tru_name = tool.original_function.__name__
exp_name = "identity"
Expand Down Expand Up @@ -509,39 +502,9 @@ def test_invoke_with_agent():
def identity(a: int, agent: dict = None):
return a, agent

tool = FunctionTool(identity, tool_name="identity")
# FunctionTool is a pass through for AgentTool instances until we remove it in a future release (#258)
assert tool == identity

exp_output = {"toolUseId": "unknown", "status": "success", "content": [{"text": "(2, {'state': 1})"}]}

tru_output = tool.invoke({"input": {"a": 2}}, agent={"state": 1})

assert tru_output == exp_output


def test_invoke_exception():
def identity(a: int):
return a

identity.TOOL_SPEC = {}

tool = FunctionTool(identity, tool_name="identity")

tru_output = tool.invoke({}, invalid=1)
exp_output = {
"toolUseId": "unknown",
"status": "error",
"content": [
{
"text": (
"Error executing function: "
"test_invoke_exception.<locals>.identity() "
"got an unexpected keyword argument 'invalid'"
)
}
],
}
tru_output = identity.invoke({"input": {"a": 2}}, agent={"state": 1})

assert tru_output == exp_output

Expand Down