Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def my_tool(param1: str, param2: int = 42) -> dict:
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
overload,
)
Expand Down Expand Up @@ -103,6 +105,35 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
doc_str = inspect.getdoc(func) or ""
self.doc = docstring_parser.parse(doc_str)

def _contains_tool_context(tp: Any) -> bool:
"""Return True if the annotation `tp` (possibly Union/Optional) includes ToolContext."""
if tp is None:
return False
origin = get_origin(tp)
if origin is Union:
return any(_contains_tool_context(a) for a in get_args(tp))
# Handle direct ToolContext type
return tp is ToolContext

for param in self.signature.parameters.values():
# Prefer resolved type hints (handles forward refs); fall back to annotation
ann = self.type_hints.get(param.name, param.annotation)
if ann is inspect._empty:
continue

if _contains_tool_context(ann):
# If decorator didn't opt-in to context injection, complain
if self._context_param is None:
raise TypeError(
f"Parameter '{param.name}' is of type 'ToolContext' but '@tool(context=True)' is missing."
)
# If decorator specified a different param name, complain
if param.name != self._context_param:
raise TypeError(
f"Parameter '{param.name}' is of type 'ToolContext' but has the wrong name. "
f"It should be named '{self._context_param}'."
)

# Get parameter descriptions from parsed docstring
self.param_descriptions = {
param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params
Expand Down
41 changes: 41 additions & 0 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,3 +1363,44 @@ async def async_generator() -> AsyncGenerator:
]

assert act_results == exp_results


def test_tool_with_mismatched_tool_context_param_name_raises_error():
"""Verify that a TypeError is raised for a mismatched tool_context parameter name."""
with pytest.raises(TypeError) as excinfo:

@strands.tool(context=True)
def my_tool(context: ToolContext):
pass

assert (
"Parameter 'context' is of type 'ToolContext' but has the wrong name. It should be named 'tool_context'."
in str(excinfo.value)
)


def test_tool_with_tool_context_but_no_context_flag_raises_error():
"""Verify that a TypeError is raised if ToolContext is used without context=True."""
with pytest.raises(TypeError) as excinfo:

@strands.tool
def my_tool(tool_context: ToolContext):
pass

assert "Parameter 'tool_context' is of type 'ToolContext' but '@tool(context=True)' is missing." in str(
excinfo.value
)


def test_tool_with_tool_context_named_custom_context_raises_error_if_mismatched():
"""Verify that a TypeError is raised when context param name doesn't match the decorator value."""
with pytest.raises(TypeError) as excinfo:

@strands.tool(context="my_context")
def my_tool(tool_context: ToolContext):
pass

assert (
"Parameter 'tool_context' is of type 'ToolContext' but has the wrong name. It should be named 'my_context'."
in str(excinfo.value)
)