Skip to content

Commit e5499db

Browse files
committed
fix: only include parameters that defined in tool spec
1 parent e00f3fd commit e5499db

File tree

2 files changed

+64
-98
lines changed

2 files changed

+64
-98
lines changed

src/strands/agent/agent.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -617,8 +617,11 @@ def _record_tool_execution(
617617
tool_result: The result returned by the tool.
618618
user_message_override: Optional custom message to include.
619619
"""
620+
# Filter tool input parameters to only include those defined in tool spec
621+
filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"])
622+
620623
# Create user message describing the tool call
621-
input_parameters = json.dumps(tool["input"], default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
624+
input_parameters = json.dumps(filtered_input, default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")
622625

623626
user_msg_content: list[ContentBlock] = [
624627
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
@@ -628,9 +631,12 @@ def _record_tool_execution(
628631
if user_message_override:
629632
user_msg_content.insert(0, {"text": f"{user_message_override}\n"})
630633

631-
sanitized_input = json.loads(input_parameters)
632-
sanitized_tool = tool.copy()
633-
sanitized_tool["input"] = sanitized_input
634+
# Create filtered tool use for message history
635+
filtered_tool: ToolUse = {
636+
"toolUseId": tool["toolUseId"],
637+
"name": tool["name"],
638+
"input": filtered_input,
639+
}
634640

635641
# Create the message sequence
636642
user_msg: Message = {
@@ -639,7 +645,7 @@ def _record_tool_execution(
639645
}
640646
tool_use_msg: Message = {
641647
"role": "assistant",
642-
"content": [{"toolUse": sanitized_tool}],
648+
"content": [{"toolUse": filtered_tool}],
643649
}
644650
tool_result_msg: Message = {
645651
"role": "user",
@@ -696,6 +702,25 @@ def _end_agent_trace_span(
696702

697703
self.tracer.end_agent_span(**trace_attributes)
698704

705+
def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]:
706+
"""Filter input parameters to only include those defined in the tool specification.
707+
708+
Args:
709+
tool_name: Name of the tool to get specification for
710+
input_params: Original input parameters
711+
712+
Returns:
713+
Filtered parameters containing only those defined in tool spec
714+
"""
715+
all_tools_config = self.tool_registry.get_all_tools_config()
716+
tool_spec = all_tools_config.get(tool_name)
717+
718+
if not tool_spec or "inputSchema" not in tool_spec:
719+
return input_params.copy()
720+
721+
properties = tool_spec["inputSchema"]["json"]["properties"]
722+
return {k: v for k, v in input_params.items() if k in properties}
723+
699724
def _append_message(self, message: Message) -> None:
700725
"""Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent."""
701726
self.messages.append(message)

tests/strands/agent/test_agent.py

Lines changed: 34 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,99 +1687,7 @@ def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint):
16871687
tool_call_text = user_message["content"][1]["text"]
16881688
assert "agent.tool.tool_decorated direct tool call." in tool_call_text
16891689
assert '"random_string": "test_value"' in tool_call_text
1690-
assert '"non_serializable_agent": "<<non-serializable: Agent>>"' in tool_call_text
1691-
1692-
1693-
def test_agent_tool_multiple_non_serializable_types(agent, mock_randint):
1694-
"""Test filtering of various non-serializable object types."""
1695-
mock_randint.return_value = 123
1696-
1697-
# Create various non-serializable objects
1698-
class CustomClass:
1699-
def __init__(self, value):
1700-
self.value = value
1701-
1702-
non_serializable_objects = {
1703-
"agent": Agent(),
1704-
"custom_object": CustomClass("test"),
1705-
"function": lambda x: x,
1706-
"set_object": {1, 2, 3},
1707-
"complex_number": 3 + 4j,
1708-
"serializable_string": "this_should_remain",
1709-
"serializable_number": 42,
1710-
"serializable_list": [1, 2, 3],
1711-
"serializable_dict": {"key": "value"},
1712-
}
1713-
1714-
# This should not crash
1715-
result = agent.tool.tool_decorated(random_string="test_filtering", **non_serializable_objects)
1716-
1717-
# Verify tool executed successfully
1718-
expected_result = {
1719-
"content": [{"text": "test_filtering"}],
1720-
"status": "success",
1721-
"toolUseId": "tooluse_tool_decorated_123",
1722-
}
1723-
assert result == expected_result
1724-
1725-
# Check the recorded message for proper parameter filtering
1726-
assert len(agent.messages) > 0
1727-
user_message = agent.messages[0]
1728-
tool_call_text = user_message["content"][0]["text"]
1729-
1730-
# Verify serializable objects remain unchanged
1731-
assert '"serializable_string": "this_should_remain"' in tool_call_text
1732-
assert '"serializable_number": 42' in tool_call_text
1733-
assert '"serializable_list": [1, 2, 3]' in tool_call_text
1734-
assert '"serializable_dict": {"key": "value"}' in tool_call_text
1735-
1736-
# Verify non-serializable objects are replaced with descriptive strings
1737-
assert '"agent": "<<non-serializable: Agent>>"' in tool_call_text
1738-
assert (
1739-
'"custom_object": "<<non-serializable: test_agent_tool_multiple_non_serializable_types.<locals>.CustomClass>>"'
1740-
in tool_call_text
1741-
)
1742-
assert '"function": "<<non-serializable: function>>"' in tool_call_text
1743-
assert '"set_object": "<<non-serializable: set>>"' in tool_call_text
1744-
assert '"complex_number": "<<non-serializable: complex>>"' in tool_call_text
1745-
1746-
1747-
def test_agent_tool_serialization_edge_cases(agent, mock_randint):
1748-
"""Test edge cases in parameter serialization filtering."""
1749-
mock_randint.return_value = 999
1750-
1751-
# Test with None values, empty containers, and nested structures
1752-
edge_case_params = {
1753-
"none_value": None,
1754-
"empty_list": [],
1755-
"empty_dict": {},
1756-
"nested_list_with_non_serializable": [1, 2, Agent()], # This should be filtered out
1757-
"nested_dict_serializable": {"nested": {"key": "value"}}, # This should remain
1758-
}
1759-
1760-
result = agent.tool.tool_decorated(random_string="edge_cases", **edge_case_params)
1761-
1762-
# Verify successful execution
1763-
expected_result = {
1764-
"content": [{"text": "edge_cases"}],
1765-
"status": "success",
1766-
"toolUseId": "tooluse_tool_decorated_999",
1767-
}
1768-
assert result == expected_result
1769-
1770-
# Check parameter filtering in recorded message
1771-
assert len(agent.messages) > 0
1772-
user_message = agent.messages[0]
1773-
tool_call_text = user_message["content"][0]["text"]
1774-
1775-
# Verify serializable values remain
1776-
assert '"none_value": null' in tool_call_text
1777-
assert '"empty_list": []' in tool_call_text
1778-
assert '"empty_dict": {}' in tool_call_text
1779-
assert '"nested_dict_serializable": {"nested": {"key": "value"}}' in tool_call_text
1780-
1781-
# Verify non-serializable nested structure is replaced
1782-
assert '"nested_list_with_non_serializable": [1, 2, "<<non-serializable: Agent>>"]' in tool_call_text
1690+
assert '"non_serializable_agent": "<<non-serializable: Agent>>"' not in tool_call_text
17831691

17841692

17851693
def test_agent_tool_no_non_serializable_parameters(agent, mock_randint):
@@ -1831,3 +1739,36 @@ def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent
18311739

18321740
# Verify no messages were recorded
18331741
assert len(agent.messages) == 0
1742+
1743+
1744+
def test_agent_tool_call_parameter_filtering_integration(mock_randint):
1745+
"""Test that tool calls properly filter parameters in message recording."""
1746+
mock_randint.return_value = 42
1747+
1748+
@strands.tool
1749+
def test_tool(action: str) -> str:
1750+
"""Test tool with single parameter."""
1751+
return action
1752+
1753+
agent = Agent(tools=[test_tool])
1754+
1755+
# Call tool with extra non-spec parameters
1756+
result = agent.tool.test_tool(
1757+
action="test_value",
1758+
agent=agent, # Should be filtered out
1759+
extra_param="filtered", # Should be filtered out
1760+
)
1761+
1762+
# Verify tool executed successfully
1763+
assert result["status"] == "success"
1764+
assert result["content"] == [{"text": "test_value"}]
1765+
1766+
# Check that only spec parameters are recorded in message history
1767+
assert len(agent.messages) > 0
1768+
user_message = agent.messages[0]
1769+
tool_call_text = user_message["content"][0]["text"]
1770+
1771+
# Should only contain the 'action' parameter
1772+
assert '"action": "test_value"' in tool_call_text
1773+
assert '"agent"' not in tool_call_text
1774+
assert '"extra_param"' not in tool_call_text

0 commit comments

Comments
 (0)