Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ def tcall() -> ToolResult:

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
self._agent._record_tool_execution(
tool_use, tool_result, user_message_override, self._agent.messages
)
self._agent._record_tool_execution(tool_use, tool_result, user_message_override)

# Apply window management
self._agent.conversation_manager.apply_management(self._agent)
Expand Down Expand Up @@ -602,7 +600,6 @@ def _record_tool_execution(
tool: ToolUse,
tool_result: ToolResult,
user_message_override: Optional[str],
messages: Messages,
) -> None:
"""Record a tool execution in the message history.

Expand All @@ -617,11 +614,12 @@ def _record_tool_execution(
tool: The tool call information.
tool_result: The result returned by the tool.
user_message_override: Optional custom message to include.
messages: The message history to append to.
"""
# Create user message describing the tool call
input_parameters = json.dumps(tool["input"], default=lambda o: f"<<non-serializable: {type(o).__qualname__}>>")

user_msg_content: list[ContentBlock] = [
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
{"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")}
]

# Add override message if provided
Expand All @@ -643,7 +641,7 @@ def _record_tool_execution(
}
assistant_msg: Message = {
"role": "assistant",
"content": [{"text": f"agent.{tool['name']} was called"}],
"content": [{"text": f"agent.tool.{tool['name']} was called."}],
}

# Add to message history
Expand Down
184 changes: 184 additions & 0 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,3 +1576,187 @@ def test_agent_with_session_and_conversation_manager():
assert agent.messages == agent_2.messages
# Asser the conversation manager was initialized properly
assert agent.conversation_manager.removed_message_count == agent_2.conversation_manager.removed_message_count


def test_agent_tool_non_serializable_parameter_filtering(agent, mock_randint):
"""Test that non-serializable objects in tool parameters are properly filtered during tool call recording."""
mock_randint.return_value = 42

# Create a non-serializable object (Agent instance)
another_agent = Agent()

# This should not crash even though we're passing non-serializable objects
result = agent.tool.tool_decorated(
random_string="test_value",
non_serializable_agent=another_agent, # This would previously cause JSON serialization error
user_message_override="Testing non-serializable parameter filtering",
)

# Verify the tool executed successfully
expected_result = {
"content": [{"text": "test_value"}],
"status": "success",
"toolUseId": "tooluse_tool_decorated_42",
}
assert result == expected_result

# The key test: this should not crash during execution
# Check that we have messages recorded (exact count may vary)
assert len(agent.messages) > 0

# Check user message with filtered parameters - this is the main test for the bug fix
user_message = agent.messages[0]
assert user_message["role"] == "user"
assert len(user_message["content"]) == 2

# Check override message
assert user_message["content"][0]["text"] == "Testing non-serializable parameter filtering\n"

# Check tool call description with filtered parameters - this is where JSON serialization would fail
tool_call_text = user_message["content"][1]["text"]
assert "agent.tool.tool_decorated direct tool call." in tool_call_text
assert '"random_string": "test_value"' in tool_call_text
assert '"non_serializable_agent": "<<non-serializable: Agent>>"' in tool_call_text


def test_agent_tool_multiple_non_serializable_types(agent, mock_randint):
"""Test filtering of various non-serializable object types."""
mock_randint.return_value = 123

# Create various non-serializable objects
class CustomClass:
def __init__(self, value):
self.value = value

non_serializable_objects = {
"agent": Agent(),
"custom_object": CustomClass("test"),
"function": lambda x: x,
"set_object": {1, 2, 3},
"complex_number": 3 + 4j,
"serializable_string": "this_should_remain",
"serializable_number": 42,
"serializable_list": [1, 2, 3],
"serializable_dict": {"key": "value"},
}

# This should not crash
result = agent.tool.tool_decorated(random_string="test_filtering", **non_serializable_objects)

# Verify tool executed successfully
expected_result = {
"content": [{"text": "test_filtering"}],
"status": "success",
"toolUseId": "tooluse_tool_decorated_123",
}
assert result == expected_result

# Check the recorded message for proper parameter filtering
assert len(agent.messages) > 0
user_message = agent.messages[0]
tool_call_text = user_message["content"][0]["text"]

# Verify serializable objects remain unchanged
assert '"serializable_string": "this_should_remain"' in tool_call_text
assert '"serializable_number": 42' in tool_call_text
assert '"serializable_list": [1, 2, 3]' in tool_call_text
assert '"serializable_dict": {"key": "value"}' in tool_call_text

# Verify non-serializable objects are replaced with descriptive strings
assert '"agent": "<<non-serializable: Agent>>"' in tool_call_text
assert (
'"custom_object": "<<non-serializable: test_agent_tool_multiple_non_serializable_types.<locals>.CustomClass>>"'
in tool_call_text
)
assert '"function": "<<non-serializable: function>>"' in tool_call_text
assert '"set_object": "<<non-serializable: set>>"' in tool_call_text
assert '"complex_number": "<<non-serializable: complex>>"' in tool_call_text


def test_agent_tool_serialization_edge_cases(agent, mock_randint):
"""Test edge cases in parameter serialization filtering."""
mock_randint.return_value = 999

# Test with None values, empty containers, and nested structures
edge_case_params = {
"none_value": None,
"empty_list": [],
"empty_dict": {},
"nested_list_with_non_serializable": [1, 2, Agent()], # This should be filtered out
"nested_dict_serializable": {"nested": {"key": "value"}}, # This should remain
}

result = agent.tool.tool_decorated(random_string="edge_cases", **edge_case_params)

# Verify successful execution
expected_result = {
"content": [{"text": "edge_cases"}],
"status": "success",
"toolUseId": "tooluse_tool_decorated_999",
}
assert result == expected_result

# Check parameter filtering in recorded message
assert len(agent.messages) > 0
user_message = agent.messages[0]
tool_call_text = user_message["content"][0]["text"]

# Verify serializable values remain
assert '"none_value": null' in tool_call_text
assert '"empty_list": []' in tool_call_text
assert '"empty_dict": {}' in tool_call_text
assert '"nested_dict_serializable": {"nested": {"key": "value"}}' in tool_call_text

# Verify non-serializable nested structure is replaced
assert '"nested_list_with_non_serializable": [1, 2, "<<non-serializable: Agent>>"]' in tool_call_text


def test_agent_tool_no_non_serializable_parameters(agent, mock_randint):
"""Test that normal tool calls with only serializable parameters work unchanged."""
mock_randint.return_value = 555

# Call with only serializable parameters
result = agent.tool.tool_decorated(random_string="normal_call", user_message_override="Normal tool call test")

# Verify successful execution
expected_result = {
"content": [{"text": "normal_call"}],
"status": "success",
"toolUseId": "tooluse_tool_decorated_555",
}
assert result == expected_result

# Check message recording works normally
assert len(agent.messages) > 0
user_message = agent.messages[0]
tool_call_text = user_message["content"][1]["text"]

# Verify normal parameter serialization (no filtering needed)
assert "agent.tool.tool_decorated direct tool call." in tool_call_text
assert '"random_string": "normal_call"' in tool_call_text
# Should not contain any "<<non-serializable:" strings
assert "<<non-serializable:" not in tool_call_text


def test_agent_tool_record_direct_tool_call_disabled_with_non_serializable(agent, mock_randint):
"""Test that when record_direct_tool_call is disabled, non-serializable parameters don't cause issues."""
mock_randint.return_value = 777

# Disable tool call recording
agent.record_direct_tool_call = False

# This should work fine even with non-serializable parameters since recording is disabled
result = agent.tool.tool_decorated(
random_string="no_recording", non_serializable_agent=Agent(), user_message_override="This shouldn't be recorded"
)

# Verify successful execution
expected_result = {
"content": [{"text": "no_recording"}],
"status": "success",
"toolUseId": "tooluse_tool_decorated_777",
}
assert result == expected_result

# Verify no messages were recorded
assert len(agent.messages) == 0
Loading