diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 99aa7e372..72109dbef 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -99,6 +99,8 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - self.type_hints = get_type_hints(func) self._context_param = context_param + self._validate_signature() + # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) @@ -111,6 +113,20 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _validate_signature(self) -> None: + """Verify that ToolContext is used correctly in the function signature.""" + for param in self.signature.parameters.values(): + if param.annotation is ToolContext: + if self._context_param is None: + raise ValueError("@tool(context) must be set if passing in ToolContext param") + + if param.name != self._context_param: + raise ValueError( + f"param_name=<{param.name}> | ToolContext param must be named '{self._context_param}'" + ) + # Found the parameter, no need to check further + break + def _create_input_model(self) -> Type[BaseModel]: """Create a Pydantic model from function signature for input validation. diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 5b4b5cdda..658a34052 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -1363,3 +1363,27 @@ async def async_generator() -> AsyncGenerator: ] assert act_results == exp_results + + +def test_function_tool_metadata_validate_signature_default_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'tool_context'"): + + @strands.tool(context=True) + def my_tool(context: ToolContext): + pass + + +def test_function_tool_metadata_validate_signature_custom_context_name_mismatch(): + with pytest.raises(ValueError, match=r"param_name= | ToolContext param must be named 'my_context'"): + + @strands.tool(context="my_context") + def my_tool(tool_context: ToolContext): + pass + + +def test_function_tool_metadata_validate_signature_missing_context_config(): + with pytest.raises(ValueError, match=r"@tool\(context\) must be set if passing in ToolContext param"): + + @strands.tool + def my_tool(tool_context: ToolContext): + pass