Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 1 addition & 84 deletions src/strands/tools/loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Tool loading utilities."""

import importlib
import inspect
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, cast
from typing import cast

from ..types.tools import AgentTool
from .decorator import DecoratedFunctionTool
Expand All @@ -15,88 +14,6 @@
logger = logging.getLogger(__name__)


def load_function_tool(func: Any) -> Optional[DecoratedFunctionTool]:
"""Load a function as a tool if it's decorated with @tool.

Args:
func: The function to load.

Returns:
FunctionTool if successful, None otherwise.
"""
logger.warning(
"issue=<%s> | load_function_tool will be removed in a future version",
"https://github.com/strands-agents/sdk-python/pull/258",
)

if isinstance(func, DecoratedFunctionTool):
return func
else:
return None


def scan_module_for_tools(module: Any) -> List[DecoratedFunctionTool]:
"""Scan a module for function-based tools.

Args:
module: The module to scan.

Returns:
List of FunctionTool instances found in the module.
"""
tools = []

for name, obj in inspect.getmembers(module):
if isinstance(obj, DecoratedFunctionTool):
# Create a function tool with correct name
try:
tools.append(obj)
except Exception as e:
logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e)

return tools


def scan_directory_for_tools(directory: Path) -> Dict[str, DecoratedFunctionTool]:
"""Scan a directory for Python modules containing function-based tools.

Args:
directory: The directory to scan.

Returns:
Dictionary mapping tool names to FunctionTool instances.
"""
tools: Dict[str, DecoratedFunctionTool] = {}

if not directory.exists() or not directory.is_dir():
return tools

for file_path in directory.glob("*.py"):
if file_path.name.startswith("_"):
continue

try:
# Dynamically import the module
module_name = file_path.stem
spec = importlib.util.spec_from_file_location(module_name, file_path)
if not spec or not spec.loader:
continue

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Find tools in the module
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, DecoratedFunctionTool):
tools[attr.tool_name] = attr

except Exception as e:
logger.warning("tool_path=<%s> | failed to load tools under path | %s", file_path, e)

return tools


class ToolLoader:
"""Handles loading of tools from different sources."""

Expand Down
31 changes: 27 additions & 4 deletions src/strands/tools/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

from typing_extensions import TypedDict, cast

from strands.tools.decorator import DecoratedFunctionTool

from ..types.tools import AgentTool, Tool, ToolChoice, ToolChoiceAuto, ToolConfig, ToolSpec
from .loader import scan_module_for_tools
from .tools import PythonAgentTool, normalize_schema, normalize_tool_spec

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -84,7 +85,7 @@ def process_tools(self, tools: List[Any]) -> List[str]:
self.load_tool_from_filepath(tool_name=tool_name, tool_path=module_path)
tool_names.append(tool_name)
else:
function_tools = scan_module_for_tools(tool)
function_tools = self._scan_module_for_tools(tool)
for function_tool in function_tools:
self.register_tool(function_tool)
tool_names.append(function_tool.tool_name)
Expand Down Expand Up @@ -313,7 +314,7 @@ def reload_tool(self, tool_name: str) -> None:

# Look for function-based tools first
try:
function_tools = scan_module_for_tools(module)
function_tools = self._scan_module_for_tools(module)

if function_tools:
for function_tool in function_tools:
Expand Down Expand Up @@ -400,7 +401,7 @@ def initialize_tools(self, load_tools_from_directory: bool = True) -> None:
if tool_path.suffix == ".py":
# Check for decorated function tools first
try:
function_tools = scan_module_for_tools(module)
function_tools = self._scan_module_for_tools(module)

if function_tools:
for function_tool in function_tools:
Expand Down Expand Up @@ -592,3 +593,25 @@ def _update_tool_config(self, tool_config: Dict[str, Any], new_tool: NewToolDict
else:
tool_config["tools"].append(new_tool_entry)
logger.debug("tool_name=<%s> | added new tool", new_tool_name)

def _scan_module_for_tools(self, module: Any) -> List[AgentTool]:
"""Scan a module for function-based tools.

Args:
module: The module to scan.

Returns:
List of FunctionTool instances found in the module.
"""
tools: List[AgentTool] = []

for name, obj in inspect.getmembers(module):
if isinstance(obj, DecoratedFunctionTool):
# Create a function tool with correct name
try:
# Cast as AgentTool for mypy
tools.append(cast(AgentTool, obj))
except Exception as e:
logger.warning("tool_name=<%s> | failed to create function tool | %s", name, e)

return tools
124 changes: 0 additions & 124 deletions tests/strands/tools/test_loader.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,14 @@
import os
import pathlib
import re
import textwrap
import unittest.mock

import pytest

import strands
from strands.tools.decorator import DecoratedFunctionTool
from strands.tools.loader import ToolLoader
from strands.tools.tools import PythonAgentTool


def test_load_function_tool():
@strands.tools.tool
def tool_function(a):
return a

tool = strands.tools.loader.load_function_tool(tool_function)

assert isinstance(tool, DecoratedFunctionTool)


def test_load_function_tool_no_function():
tool = strands.tools.loader.load_function_tool("no_function")

assert tool is None


def test_load_function_tool_no_spec():
def tool_function(a):
return a

tool = strands.tools.loader.load_function_tool(tool_function)

assert tool is None


def test_load_function_tool_invalid():
def tool_function(a):
return a

tool_function.TOOL_SPEC = "invalid"

tool = strands.tools.loader.load_function_tool(tool_function)

assert tool is None


def test_scan_module_for_tools():
@strands.tools.tool
def tool_function_1(a):
return a

@strands.tools.tool
def tool_function_2(b):
return b

def tool_function_3(c):
return c

def tool_function_4(d):
return d

tool_function_4.tool_spec = "invalid"

mock_module = unittest.mock.MagicMock()
mock_module.tool_function_1 = tool_function_1
mock_module.tool_function_2 = tool_function_2
mock_module.tool_function_3 = tool_function_3
mock_module.tool_function_4 = tool_function_4

tools = strands.tools.loader.scan_module_for_tools(mock_module)

assert len(tools) == 2
assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)


def test_scan_directory_for_tools(tmp_path):
tool_definition_1 = textwrap.dedent("""
import strands

@strands.tools.tool
def tool_function_1(a):
return a
""")
tool_definition_2 = textwrap.dedent("""
import strands

@strands.tools.tool
def tool_function_2(b):
return b
""")
tool_definition_3 = textwrap.dedent("""
def tool_function_3(c):
return c
""")
tool_definition_4 = textwrap.dedent("""
def tool_function_4(d):
return d
""")
tool_definition_5 = ""
tool_definition_6 = "**invalid**"

tool_path_1 = tmp_path / "tool_1.py"
tool_path_2 = tmp_path / "tool_2.py"
tool_path_3 = tmp_path / "tool_3.py"
tool_path_4 = tmp_path / "tool_4.py"
tool_path_5 = tmp_path / "_tool_5.py"
tool_path_6 = tmp_path / "tool_6.py"

tool_path_1.write_text(tool_definition_1)
tool_path_2.write_text(tool_definition_2)
tool_path_3.write_text(tool_definition_3)
tool_path_4.write_text(tool_definition_4)
tool_path_5.write_text(tool_definition_5)
tool_path_6.write_text(tool_definition_6)

tools = strands.tools.loader.scan_directory_for_tools(tmp_path)

tru_tool_names = sorted(tools.keys())
exp_tool_names = ["tool_function_1", "tool_function_2"]

assert tru_tool_names == exp_tool_names
assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools.values())


def test_scan_directory_for_tools_does_not_exist():
tru_tools = strands.tools.loader.scan_directory_for_tools(pathlib.Path("does_not_exist"))
exp_tools = {}

assert tru_tools == exp_tools


@pytest.fixture
def tool_path(request, tmp_path, monkeypatch):
definition = request.param
Expand Down
32 changes: 32 additions & 0 deletions tests/strands/tools/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from strands.tools import PythonAgentTool
from strands.tools.decorator import DecoratedFunctionTool, tool
from strands.tools.registry import ToolRegistry


Expand Down Expand Up @@ -43,3 +44,34 @@ def test_register_tool_with_similar_name_raises():
str(err.value) == "Tool name 'tool_like_this' already exists as 'tool-like-this'. "
"Cannot add a duplicate tool which differs by a '-' or '_'"
)


def test_scan_module_for_tools():
@tool
def tool_function_1(a):
return a

@tool
def tool_function_2(b):
return b

def tool_function_3(c):
return c

def tool_function_4(d):
return d

tool_function_4.tool_spec = "invalid"

mock_module = MagicMock()
mock_module.tool_function_1 = tool_function_1
mock_module.tool_function_2 = tool_function_2
mock_module.tool_function_3 = tool_function_3
mock_module.tool_function_4 = tool_function_4

tool_registry = ToolRegistry()

tools = tool_registry._scan_module_for_tools(mock_module)

assert len(tools) == 2
assert all(isinstance(tool, DecoratedFunctionTool) for tool in tools)