From ef16f73d752eab5086d05e7da94facd479246e13 Mon Sep 17 00:00:00 2001 From: Cagatay Cali Date: Sun, 18 May 2025 20:40:37 -0400 Subject: [PATCH 1/6] chore: bump version from 0.1.0 to 0.1.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6c143de3..3fd0bd23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents-tools" -version = "0.1.0" +version = "0.1.1" description = "A collection of specialized tools for Strands Agents" readme = "README.md" requires-python = ">=3.10" From ad98a625c071086fc99f5ac3451e2828a0565dff Mon Sep 17 00:00:00 2001 From: Cagatay Cali Date: Mon, 19 May 2025 00:12:04 -0400 Subject: [PATCH 2/6] fix: improve tool interoperability and remove dependencies This commit enhances multiple tool components to better work together: - feat(think): inherit parent agent's traces and tools to maintain context - fix(load_tool): remove unnecessary hot_reload_tools dependency check - fix(use_llm): properly pass trace_attributes from parent agent to new instances - style(mem0_memory): improve code formatting and readability - test: update tests to match new implementation patterns --- src/strands_tools/load_tool.py | 6 +-- src/strands_tools/mem0_memory.py | 37 +++++------------- src/strands_tools/think.py | 66 +++++++++++++++++++------------- src/strands_tools/use_llm.py | 9 ++--- tests/test_think.py | 50 +++++++++++++++--------- tests/test_use_llm.py | 8 +++- 6 files changed, 92 insertions(+), 84 deletions(-) diff --git a/src/strands_tools/load_tool.py b/src/strands_tools/load_tool.py index fd72f695..0bf52207 100644 --- a/src/strands_tools/load_tool.py +++ b/src/strands_tools/load_tool.py @@ -79,7 +79,6 @@ def load_tool(path: str, name: str, agent=None) -> Dict[str, Any]: Tool Loading Process: ------------------- - - First, checks if dynamic loading is permitted (hot_reload_tools=True) - Expands the path to handle user paths with tilde (~) - Validates that the file exists at the specified path - Uses the tool_registry's load_tool_from_filepath method to: @@ -175,7 +174,6 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult: Notes: - The tool loading can be disabled via STRANDS_DISABLE_LOAD_TOOL=true environment variable - - The Agent instance must have hot_reload_tools=True to enable dynamic loading - Python files in the cwd()/tools/ directory are automatically hot reloaded without requiring explicit calls to load_tool - When using the load_tool function, ensure your tool files have proper docstrings as they are @@ -187,8 +185,8 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult: current_agent = agent try: - # Check if dynamic tool loading is disabled via environment variable or agent.hot_reload_tools. - if not current_agent.hot_reload_tools or os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true": + # Check if dynamic tool loading is disabled via environment variable. + if os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true": logger.warning("Dynamic tool loading is disabled via STRANDS_DISABLE_LOAD_TOOL=true") return {"status": "error", "content": [{"text": "⚠️ Dynamic tool loading is disabled in production mode."}]} diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py index 41d94b9d..212d85cf 100644 --- a/src/strands_tools/mem0_memory.py +++ b/src/strands_tools/mem0_memory.py @@ -141,19 +141,10 @@ "required": ["action"], "allOf": [ { - "if": { - "properties": { - "action": {"enum": ["store", "list", "retrieve"]} - } - }, - "then": { - "oneOf": [ - {"required": ["user_id"]}, - {"required": ["agent_id"]} - ] - } + "if": {"properties": {"action": {"enum": ["store", "list", "retrieve"]}}}, + "then": {"oneOf": [{"required": ["user_id"]}, {"required": ["agent_id"]}]}, } - ] + ], } }, } @@ -536,7 +527,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")] + content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")], ) elif action == "get": @@ -547,9 +538,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: panel = format_get_response(memory) console.print(panel) return ToolResult( - toolUseId=tool_use_id, - status="success", - content=[ToolResultContent(text=json.dumps(memory, indent=2))] + toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(memory, indent=2))] ) elif action == "list": @@ -559,7 +548,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))] + content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))], ) elif action == "retrieve": @@ -576,7 +565,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))] + content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))], ) elif action == "delete": @@ -589,7 +578,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")] + content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")], ) elif action == "history": @@ -600,9 +589,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: panel = format_history_response(history) console.print(panel) return ToolResult( - toolUseId=tool_use_id, - status="success", - content=[ToolResultContent(text=json.dumps(history, indent=2))] + toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(history, indent=2))] ) else: @@ -615,8 +602,4 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: border_style="red", ) console.print(error_panel) - return ToolResult( - toolUseId=tool_use_id, - status="error", - content=[ToolResultContent(text=f"Error: {str(e)}")] - ) + return ToolResult(toolUseId=tool_use_id, status="error", content=[ToolResultContent(text=f"Error: {str(e)}")]) diff --git a/src/strands_tools/think.py b/src/strands_tools/think.py index a1c210e3..ec3ab5a5 100644 --- a/src/strands_tools/think.py +++ b/src/strands_tools/think.py @@ -9,9 +9,9 @@ Usage with Strands Agent: ```python from strands import Agent -from strands_tools import think +from strands_tools import think, stop -agent = Agent(tools=[think]) +agent = Agent(tools=[think, stop]) # Basic usage with default system prompt result = agent.tool.think( @@ -32,13 +32,15 @@ See the think function docstring for more details on configuration options and parameters. """ +import logging import traceback import uuid from typing import Any, Dict -from strands import tool +from strands import Agent, tool +from strands.telemetry.metrics import metrics_to_string -from strands_tools.use_llm import use_llm +logger = logging.getLogger(__name__) class ThoughtProcessor: @@ -77,36 +79,46 @@ def process_cycle( ) -> str: """Process a single thinking cycle.""" + logger.debug(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...") print(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...") # Create cycle-specific prompt prompt = self.create_thinking_prompt(thought, cycle, total_cycles) - # Use LLM for processing - result = use_llm( - { - "name": "use_llm", - "toolUseId": self.tool_use_id, - "input": { - "system_prompt": custom_system_prompt, - "prompt": prompt, - }, - }, - **kwargs, - ) - - # Extract and return response - cycle_response = "" - if result.get("status") == "success": - for content in result.get("content", []): - if content.get("text"): - cycle_response += content["text"] + "\n" - - return cycle_response.strip() + # Display input prompt + logger.debug(f"\n--- Input Prompt ---\n{prompt}\n") + + # Get tools from parent agent if available + tools = [] + trace_attributes = {} + parent_agent = kwargs.get("agent") + if parent_agent: + tools = list(parent_agent.tool_registry.registry.values()) + trace_attributes = parent_agent.trace_attributes + + # Initialize the new Agent with provided parameters + agent = Agent(messages=[], tools=tools, system_prompt=custom_system_prompt, trace_attributes=trace_attributes) + + # Run the agent with the provided prompt + result = agent(prompt) + + # Extract response + assistant_response = str(result) + + # Display assistant response + logger.debug(f"\n--- Assistant Response ---\n{assistant_response.strip()}\n") + + # Print metrics if available + if result.metrics: + metrics = result.metrics + metrics_text = metrics_to_string(metrics) + logger.debug(metrics_text) + + return assistant_response.strip() @tool -def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) -> Dict[str, Any]: +def think(thought: str, cycle_count: int, system_prompt: str, agent: Any) -> Dict[str, Any]: """ Recursive thinking tool for sophisticated thought generation, learning, and self-reflection. @@ -172,7 +184,7 @@ def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) -> custom_system_prompt = ( "You are an expert analytical thinker. Process the thought deeply and provide clear insights." ) - + kwargs = {"agent": agent} # Create thought processor instance with the available context processor = ThoughtProcessor(kwargs) diff --git a/src/strands_tools/use_llm.py b/src/strands_tools/use_llm.py index a294efad..0807b052 100644 --- a/src/strands_tools/use_llm.py +++ b/src/strands_tools/use_llm.py @@ -117,9 +117,12 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult: tool_system_prompt = tool_input.get("system_prompt") tools = [] + trace_attributes = {} + parent_agent = kwargs.get("agent") if parent_agent: tools = list(parent_agent.tool_registry.registry.values()) + trace_attributes = parent_agent.trace_attributes # Display input prompt logger.debug(f"\n--- Input Prompt ---\n{prompt}\n") @@ -128,11 +131,7 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult: logger.debug("🔄 Creating new LLM instance...") # Initialize the new Agent with provided parameters - agent = Agent( - messages=[], - tools=tools, - system_prompt=tool_system_prompt, - ) + agent = Agent(messages=[], tools=tools, system_prompt=tool_system_prompt, trace_attributes=trace_attributes) # Run the agent with the provided prompt result = agent(prompt) diff --git a/tests/test_think.py b/tests/test_think.py index a33e25e9..77882b06 100644 --- a/tests/test_think.py +++ b/tests/test_think.py @@ -2,8 +2,9 @@ Tests for the think tool using the Agent interface. """ -from unittest.mock import patch +from unittest.mock import MagicMock, patch +from strands.agent import AgentResult from strands_tools import think from strands_tools.think import ThoughtProcessor @@ -28,13 +29,17 @@ def test_think_tool_direct(): }, } - # Mock use_llm function since we don't want to actually call the LLM - with patch("strands_tools.think.use_llm") as mock_use_llm: - # Setup mock response - mock_use_llm.return_value = { - "status": "success", - "content": [{"text": "This is a mock analysis of quantum computing."}], - } + # Mock Agent class since we don't want to actually call the LLM + with patch("strands_tools.think.Agent") as mock_agent_class: + # Setup mock agent and response + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "This is a mock analysis of quantum computing."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result # Call the think function directly tool_input = tool_use.get("input", {}) @@ -42,6 +47,7 @@ def test_think_tool_direct(): thought=tool_input.get("thought"), cycle_count=tool_input.get("cycle_count"), system_prompt=tool_input.get("system_prompt"), + agent=None, ) # Verify the result has the expected structure @@ -49,8 +55,8 @@ def test_think_tool_direct(): assert "Cycle 1/2" in result["content"][0]["text"] assert "Cycle 2/2" in result["content"][0]["text"] - # Verify use_llm was called twice (once for each cycle) - assert mock_use_llm.call_count == 2 + # Verify Agent was called twice (once for each cycle) + assert mock_agent.call_count == 2 def test_think_one_cycle(): @@ -65,22 +71,27 @@ def test_think_one_cycle(): }, } - with patch("strands_tools.think.use_llm") as mock_use_llm: - mock_use_llm.return_value = { - "status": "success", - "content": [{"text": "Analysis for single cycle."}], - } + with patch("strands_tools.think.Agent") as mock_agent_class: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Analysis for single cycle."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result tool_input = tool_use.get("input", {}) result = think.think( thought=tool_input.get("thought"), cycle_count=tool_input.get("cycle_count"), system_prompt=tool_input.get("system_prompt"), + agent=None, ) assert result["status"] == "success" assert "Cycle 1/1" in result["content"][0]["text"] - assert mock_use_llm.call_count == 1 + assert mock_agent.call_count == 1 def test_think_error_handling(): @@ -95,15 +106,16 @@ def test_think_error_handling(): }, } - with patch("strands_tools.think.use_llm") as mock_use_llm: - # Make use_llm raise an exception - mock_use_llm.side_effect = Exception("Test error") + with patch("strands_tools.think.Agent") as mock_agent_class: + # Make Agent raise an exception + mock_agent_class.side_effect = Exception("Test error") tool_input = tool_use.get("input", {}) result = think.think( thought=tool_input.get("thought"), cycle_count=tool_input.get("cycle_count"), system_prompt=tool_input.get("system_prompt"), + agent=None, ) assert result["status"] == "error" diff --git a/tests/test_use_llm.py b/tests/test_use_llm.py index c978c781..ccb6eb8c 100644 --- a/tests/test_use_llm.py +++ b/tests/test_use_llm.py @@ -59,7 +59,9 @@ def test_use_llm_tool_direct(mock_agent_response): assert "This is a test response from the LLM" in str(result) # Verify the Agent was created with the correct parameters - MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a helpful test assistant") + MockAgent.assert_called_once_with( + messages=[], tools=[], system_prompt="You are a helpful test assistant", trace_attributes={} + ) def test_use_llm_with_custom_system_prompt(mock_agent_response): @@ -82,7 +84,9 @@ def test_use_llm_with_custom_system_prompt(mock_agent_response): result = use_llm.use_llm(tool=tool_use) # Verify agent was created with correct system prompt - MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a specialized test assistant") + MockAgent.assert_called_once_with( + messages=[], tools=[], system_prompt="You are a specialized test assistant", trace_attributes={} + ) assert result["status"] == "success" assert "Custom response" in result["content"][0]["text"] From 6ec271365ae7bc41872ecef6c6589da7cc537916 Mon Sep 17 00:00:00 2001 From: Cagatay Cali Date: Mon, 19 May 2025 00:51:29 -0400 Subject: [PATCH 3/6] fix(slack): add missing trace_attributes to Agent initialization --- src/strands_tools/slack.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/strands_tools/slack.py b/src/strands_tools/slack.py index 5dccb562..2f3cc879 100644 --- a/src/strands_tools/slack.py +++ b/src/strands_tools/slack.py @@ -351,12 +351,14 @@ def _process_message(self, event): return tools = list(self.agent.tool_registry.registry.values()) + trace_attributes = self.agent.trace_attributes agent = Agent( messages=[], system_prompt=f"{self.agent.system_prompt}\n{SLACK_SYSTEM_PROMPT}", tools=tools, callback_handler=None, + trace_attributes=trace_attributes, ) channel_id = event.get("channel") From 2d0a00983e3f8b424468e0170935c3052413f1ad Mon Sep 17 00:00:00 2001 From: Cagatay Cali Date: Tue, 20 May 2025 08:48:19 -0400 Subject: [PATCH 4/6] docs(memory): update examples to use agent.tool.memory syntax --- src/strands_tools/memory.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands_tools/memory.py b/src/strands_tools/memory.py index b4989932..585906c7 100644 --- a/src/strands_tools/memory.py +++ b/src/strands_tools/memory.py @@ -42,7 +42,7 @@ agent = Agent(tools=[memory]) # Store content in Knowledge Base -agent.memory( +agent.tool.memory( action="store", content="Important information to remember", title="Meeting Notes", @@ -50,7 +50,7 @@ ) # Retrieve content using semantic search -agent.memory( +agent.tool.memory( action="retrieve", query="meeting information", min_score=0.7, @@ -58,7 +58,7 @@ ) # List all documents -agent.memory( +agent.tool.memory( action="list", max_results=50, STRANDS_KNOWLEDGE_BASE_ID="my1234kb" From 116c2ae0ce77d7bdbb85d0996d9120e3775f72a7 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Tue, 20 May 2025 17:23:18 -0400 Subject: [PATCH 5/6] fix: Rewrite environment tests to properly mock dependencies (#19) Mock os.environ for test_environment by using a fixture, eliminating the need to worry about the real os environment. Properly mock the get_user_input function by using a fixture as well and having environment import user_input as a module rather than importing the function directly. This is the more an important change as previously the user input wasn't being mocked in the tests - all tests were passing as the code paths didn't actually need "y". Co-authored-by: Mackenzie Zastrow --- src/strands_tools/environment.py | 7 +- tests/test_environment.py | 270 ++++++++++++++----------------- 2 files changed, 125 insertions(+), 152 deletions(-) diff --git a/src/strands_tools/environment.py b/src/strands_tools/environment.py index 7d0499c1..ea0562b7 100644 --- a/src/strands_tools/environment.py +++ b/src/strands_tools/environment.py @@ -71,8 +71,7 @@ from rich.text import Text from strands.types.tools import ToolResult, ToolResultContent, ToolUse -from strands_tools.utils import console_util -from strands_tools.utils.user_input import get_user_input +from strands_tools.utils import console_util, user_input TOOL_SPEC = { "name": "environment", @@ -584,7 +583,7 @@ def environment(tool: ToolUse, **kwargs: Any) -> ToolResult: ) # Ask for confirmation - confirm = get_user_input( + confirm = user_input.get_user_input( "\nDo you want to proceed with setting this environment variable? [y/*]" ) # For tests, 'y' should be recognized even with extra spaces or newlines @@ -706,7 +705,7 @@ def environment(tool: ToolUse, **kwargs: Any) -> ToolResult: ) # Ask for confirmation - confirm = get_user_input( + confirm = user_input.get_user_input( "\nDo you want to proceed with deleting this environment variable? [y/*]" ) # For tests, 'y' should be recognized even with extra spaces or newlines diff --git a/tests/test_environment.py b/tests/test_environment.py index 911ef75d..921276e3 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -3,28 +3,33 @@ """ import os +from unittest import mock import pytest from strands import Agent from strands_tools import environment +from strands_tools.utils import user_input @pytest.fixture def agent(): """Create an agent with the environment tool loaded.""" - return Agent(tools=[environment]) + return Agent(tools=[environment], load_tools_from_directory=False) -@pytest.fixture -def test_env_var(): - """Create and clean up a test environment variable.""" - var_name = "TEST_ENV_VAR" - var_value = "test_value" - os.environ[var_name] = var_value - yield var_name, var_value - # Clean up: remove test variable if it exists - if var_name in os.environ: - del os.environ[var_name] +@pytest.fixture(autouse=True) +def get_user_input(): + with mock.patch.object(user_input, "get_user_input") as mocked_user_input: + # By default all tests will return deny + mocked_user_input.return_value = "n" + yield mocked_user_input + + +@pytest.fixture(autouse=True) +def os_environment(): + mock_env = {} + with mock.patch.object(os, "environ", mock_env): + yield mock_env def extract_result_text(result): @@ -42,18 +47,24 @@ def test_direct_list_action(agent): assert len(extract_result_text(result)) > 0 -def test_direct_list_with_prefix(agent, test_env_var): +def test_direct_list_with_prefix(agent, os_environment): """Test listing environment variables with a specific prefix.""" - var_name, _ = test_env_var + var_name = "TEST_ENV_VAR" + var_value = "test_value" + os_environment[var_name] = var_value + result = agent.tool.environment(action="list", prefix=var_name[:4]) assert result["status"] == "success" # Verify our test variable is in the result assert var_name in extract_result_text(result) -def test_direct_get_existing_var(agent, test_env_var): +def test_direct_get_existing_var(agent, os_environment): """Test getting an existing environment variable.""" - var_name, var_value = test_env_var + var_name = "TEST_ENV_VAR" + var_value = "test_value" + os_environment[var_name] = var_value + result = agent.tool.environment(action="get", name=var_name) assert result["status"] == "success" assert var_name in extract_result_text(result) @@ -67,10 +78,9 @@ def test_direct_get_nonexistent_var(agent): assert "not found" in extract_result_text(result) -def test_direct_set_protected_var(agent, monkeypatch): +def test_direct_set_protected_var(agent, os_environment): """Test attempting to set a protected environment variable.""" - # Mock get_user_input to always return 'y' for confirmation - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y") + os_environment["PATH"] = "/original/path" # Try to modify PATH which is in PROTECTED_VARS result = agent.tool.environment(action="set", name="PATH", value="/bad/path") @@ -79,103 +89,91 @@ def test_direct_set_protected_var(agent, monkeypatch): assert os.environ["PATH"] != "/bad/path" -@pytest.mark.skipif(os.name == "nt", reason="Need to mock console output for Windows (see #17)") -def test_direct_set_var_cancelled(agent, monkeypatch): - """Test cancelling setting an environment variable.""" - # Mock get_user_input to return 'n' to cancel - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "n") +def test_direct_set_var_allowed(get_user_input, agent): + """Test attempting to set a protected environment variable and allowing it.""" + get_user_input.return_value = "y" var_name = "CANCELLED_VAR" var_value = "cancelled_value" - # Clean up in case the variable exists - if var_name in os.environ: - del os.environ[var_name] + result = agent.tool.environment(action="set", name=var_name, value=var_value) + assert result["status"] == "success" + assert var_name in os.environ + assert get_user_input.call_count == 1 + + +def test_direct_set_var_cancelled(agent): + var_name = "CANCELLED_VAR" + var_value = "cancelled_value" - try: - result = agent.tool.environment(action="set", name=var_name, value=var_value) - assert result["status"] == "error" - assert "cancelled" in extract_result_text(result).lower() - # Verify variable was not set - assert var_name not in os.environ - finally: - # Clean up - if var_name in os.environ: - del os.environ[var_name] + result = agent.tool.environment(action="set", name=var_name, value=var_value) + assert result["status"] == "error" + assert "cancelled" in extract_result_text(result).lower() + # Verify variable was not set + assert var_name not in os.environ -def test_direct_delete_nonexistent_var(agent, monkeypatch): +def test_direct_delete_nonexistent_var(agent): """Test attempting to delete a non-existent variable.""" - # Mock get_user_input to always return 'y' for confirmation - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y") - var_name = "NONEXISTENT_VAR_FOR_DELETE_TEST" - # Make sure the variable doesn't exist - if var_name in os.environ: - del os.environ[var_name] - result = agent.tool.environment(action="delete", name=var_name) assert result["status"] == "error" assert "not found" in extract_result_text(result).lower() -def test_direct_delete_protected_var(agent, monkeypatch): +def test_direct_delete_protected_var(agent, os_environment): """Test attempting to delete a protected environment variable.""" - # Mock get_user_input to always return 'y' for confirmation - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y") - # Try to delete PATH which is in PROTECTED_VARS - original_path = os.environ.get("PATH", "") - try: - result = agent.tool.environment(action="delete", name="PATH") - assert result["status"] == "error" - # Verify PATH still exists - assert "PATH" in os.environ - finally: - # Restore PATH if somehow it got deleted - if "PATH" not in os.environ: - os.environ["PATH"] = original_path - - -@pytest.mark.skipif(os.name == "nt", reason="Need to mock console output for Windows (see #17)") -def test_direct_delete_var_cancelled(agent, monkeypatch): + unchanging_value = "/original/path" + os_environment["PATH"] = unchanging_value + + result = agent.tool.environment(action="delete", name="PATH") + assert result["status"] == "error" + # Verify PATH still exists + assert os_environment["PATH"] == unchanging_value + + +def test_direct_delete_var_cancelled(agent, os_environment): """Test cancelling deletion of an environment variable.""" - # Mock get_user_input to return 'n' to cancel - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "n") + var_name = "CANCEL_DELETE_VAR" + var_value = "cancel_delete_value" + + # Set up the variable + os_environment[var_name] = var_value - # Ensure DEV mode is disabled to force confirmation - current_dev = os.environ.get("DEV", None) - if current_dev: - os.environ.pop("DEV") + result = agent.tool.environment(action="delete", name=var_name) + assert result["status"] == "error" + assert "cancelled" in extract_result_text(result).lower() + # Verify variable still exists + assert var_name in os.environ + assert os_environment[var_name] == var_value + + +def test_direct_delete_var_allowed(agent, get_user_input, os_environment): + """Test allowing deletion of an environment variable.""" + get_user_input.return_value = "y" var_name = "CANCEL_DELETE_VAR" var_value = "cancel_delete_value" # Set up the variable - os.environ[var_name] = var_value - - try: - result = agent.tool.environment(action="delete", name=var_name) - assert result["status"] == "error" - assert "cancelled" in extract_result_text(result).lower() - # Verify variable still exists - assert var_name in os.environ - assert os.environ[var_name] == var_value - finally: - # Clean up - if var_name in os.environ: - del os.environ[var_name] - # Restore DEV mode if it was set - if current_dev: - os.environ["DEV"] = current_dev - if var_name in os.environ: - del os.environ[var_name] - - -def test_direct_validate_existing_var(agent, test_env_var): + os_environment[var_name] = var_value + + result = agent.tool.environment(action="delete", name=var_name) + assert result["status"] == "success" + assert "deleted environment variable" in extract_result_text(result).lower() + assert os_environment.get(var_name) is None + + +def test_direct_validate_existing_var(agent, os_environment): """Test validating an existing environment variable.""" - var_name, _ = test_env_var + var_name = "CANCEL_DELETE_VAR" + var_value = "cancel_delete_value" + + # Set up the variable + os_environment[var_name] = var_value + result = agent.tool.environment(action="validate", name=var_name) assert result["status"] == "success" assert "valid" in extract_result_text(result).lower() @@ -203,76 +201,52 @@ def test_direct_missing_parameters(agent): assert result["status"] == "error" -def test_environment_dev_mode_delete(agent): +def test_environment_dev_mode_delete(agent, os_environment): """Test the environment tool in DEV mode with delete action.""" # Set DEV mode - original_dev = os.environ.get("DEV") - os.environ["DEV"] = "true" + os_environment["DEV"] = "true" var_name = "DEV_MODE_DELETE_VAR" var_value = "dev_mode_delete_value" - try: - # Set up the variable - os.environ[var_name] = var_value - - result = agent.tool.environment(action="delete", name=var_name) - assert result["status"] == "success" - assert var_name not in os.environ - finally: - # Clean up - if var_name in os.environ: - del os.environ[var_name] + # Set up the variable + os_environment[var_name] = var_value - # Restore original DEV value - if original_dev is None: - if "DEV" in os.environ: - del os.environ["DEV"] - else: - os.environ["DEV"] = original_dev + result = agent.tool.environment(action="delete", name=var_name) + assert result["status"] == "success" + assert var_name not in os_environment -def test_environment_dev_mode_protected_var(agent, monkeypatch): +def test_environment_dev_mode_protected_var(agent, os_environment): """Test that protected variables are still protected in DEV mode.""" # Set DEV mode - original_dev = os.environ.get("DEV") - os.environ["DEV"] = "true" - - try: - # Try to modify PATH which is protected - result = agent.tool.environment(action="set", name="PATH", value="/bad/path") - assert result["status"] == "error" - # Verify PATH was not changed - assert os.environ["PATH"] != "/bad/path" - finally: - # Restore original DEV value - if original_dev is None: - if "DEV" in os.environ: - del os.environ["DEV"] - else: - os.environ["DEV"] = original_dev - - -def test_environment_masked_values(agent, test_env_var): + os_environment["DEV"] = True + + unchanging_value = "/original/path" + os_environment["PATH"] = unchanging_value + + # Try to modify PATH which is protected + result = agent.tool.environment(action="set", name="PATH", value="/bad/path") + assert result["status"] == "error" + # Verify PATH was not changed + assert os_environment != unchanging_value + + +def test_environment_masked_values(agent, os_environment): """Test that sensitive values are masked in output.""" # Create a sensitive looking variable sensitive_name = "TEST_TOKEN_SECRET" sensitive_value = "abcd1234efgh5678" - os.environ[sensitive_name] = sensitive_value - - try: - # Test with masking enabled (default) - result = agent.tool.environment(action="get", name=sensitive_name) - assert result["status"] == "success" - # The full value should not appear in the output - assert sensitive_value not in extract_result_text(result) - - # Test with masking disabled - result = agent.tool.environment(action="get", name=sensitive_name, masked=False) - assert result["status"] == "success" - # Now the full value should appear - assert sensitive_value in extract_result_text(result) - finally: - # Clean up - if sensitive_name in os.environ: - del os.environ[sensitive_name] + os_environment[sensitive_name] = sensitive_value + + # Test with masking enabled (default) + result = agent.tool.environment(action="get", name=sensitive_name) + assert result["status"] == "success" + # The full value should not appear in the output + assert sensitive_value not in extract_result_text(result) + + # Test with masking disabled + result = agent.tool.environment(action="get", name=sensitive_name, masked=False) + assert result["status"] == "success" + # Now the full value should appear + assert sensitive_value in extract_result_text(result) From 2f72c2ab6736f947f1f8371985f4a1aeca68c7f0 Mon Sep 17 00:00:00 2001 From: Strands Agent Date: Thu, 5 Jun 2025 15:51:45 -0400 Subject: [PATCH 6/6] feat: add tool filtering to use_llm and think meta-tools - Add 'tools' parameter to use_llm and think functions for parent tool filtering - Implement tool registry filtering to prevent infinite recursion - Allow parent agents to specify subset of tools for child agents - Add comprehensive tests for tool filtering scenarios - Update documentation with tool filtering examples - Fix infinite recursion issue when use_llm calls use_llm or think calls think This enhancement allows better control over meta-tool capabilities and prevents recursive loops while maintaining backward compatibility. --- src/strands_tools/memory.py | 2 +- src/strands_tools/think.py | 38 ++++- src/strands_tools/use_llm.py | 45 +++++- tests/test_think.py | 266 +++++++++++++++++++++++++++++++++++ tests/test_use_llm.py | 198 ++++++++++++++++++++++++++ 5 files changed, 539 insertions(+), 10 deletions(-) diff --git a/src/strands_tools/memory.py b/src/strands_tools/memory.py index 5ae5770c..f57e6d95 100644 --- a/src/strands_tools/memory.py +++ b/src/strands_tools/memory.py @@ -583,7 +583,7 @@ def memory( next_token: Token for pagination in 'list' or 'retrieve' action (optional). query: The search query for semantic search (required for 'retrieve' action). min_score: Minimum relevance score threshold (0.0-1.0) for 'retrieve' action. Default is 0.4. - region_name: Optional AWS region name. If not provided, will use the AWS_REGION env variable. + region_name: Optional AWS region name. If not provided, will use the AWS_REGION env variable. If AWS_REGION is not specified, it will default to us-west-2. Returns: diff --git a/src/strands_tools/think.py b/src/strands_tools/think.py index 47e3b798..97964016 100644 --- a/src/strands_tools/think.py +++ b/src/strands_tools/think.py @@ -13,19 +13,28 @@ agent = Agent(tools=[think, stop]) -# Basic usage with default system prompt +# Basic usage with default system prompt (inherits all parent tools) result = agent.tool.think( thought="How might we improve renewable energy storage solutions?", cycle_count=3, system_prompt="You are an expert energy systems analyst." ) -# Advanced usage with custom system prompt +# Usage with specific tools filtered from parent agent +result = agent.tool.think( + thought="Calculate energy efficiency and analyze the data", + cycle_count=3, + system_prompt="You are an expert energy systems analyst.", + tools=["calculator", "file_read", "python_repl"] +) + +# Usage with mixed tool filtering from parent agent result = agent.tool.think( thought="Analyze the implications of quantum computing on cryptography.", cycle_count=5, system_prompt="You are a specialist in quantum computing and cryptography. Analyze this topic deeply, - considering both technical and practical aspects." + considering both technical and practical aspects.", + tools=["retrieve", "calculator", "http_request"] ) ``` @@ -79,6 +88,7 @@ def process_cycle( cycle: int, total_cycles: int, custom_system_prompt: str, + specified_tools=None, **kwargs: Any, ) -> str: """Process a single thinking cycle.""" @@ -97,9 +107,21 @@ def process_cycle( trace_attributes = {} parent_agent = kwargs.get("agent") if parent_agent: - tools = list(parent_agent.tool_registry.registry.values()) trace_attributes = parent_agent.trace_attributes + # If specific tools are provided, filter parent tools; otherwise inherit all tools from parent + if specified_tools is not None: + # Filter parent agent tools to only include specified tool names + filtered_tools = [] + for tool_name in specified_tools: + if tool_name in parent_agent.tool_registry.registry: + filtered_tools.append(parent_agent.tool_registry.registry[tool_name]) + else: + logger.warning(f"Tool '{tool_name}' not found in parent agent's tool registry") + tools = filtered_tools + else: + tools = list(parent_agent.tool_registry.registry.values()) + # Initialize the new Agent with provided parameters agent = Agent(messages=[], tools=tools, system_prompt=custom_system_prompt, trace_attributes=trace_attributes) @@ -122,7 +144,7 @@ def process_cycle( @tool -def think(thought: str, cycle_count: int, system_prompt: str, agent: Any) -> Dict[str, Any]: +def think(thought: str, cycle_count: int, system_prompt: str, tools: list = None, agent: Any = None) -> Dict[str, Any]: """ Recursive thinking tool for sophisticated thought generation, learning, and self-reflection. @@ -162,7 +184,10 @@ def think(thought: str, cycle_count: int, system_prompt: str, agent: Any) -> Dic provide a good balance of depth and efficiency. system_prompt: Custom system prompt to use for the LLM thinking process. This should specify the expertise domain and thinking approach for processing the thought. - **kwargs: Additional keyword arguments passed to the underlying LLM processing. + tools: List of tool names to make available to the nested agent. Tool names must + exist in the parent agent's tool registry. Examples: ["calculator", "file_read", "retrieve"] + If not provided, inherits all tools from the parent agent. + agent: The parent agent (automatically passed by Strands framework) Returns: Dict containing status and response content in the format: @@ -210,6 +235,7 @@ def think(thought: str, cycle_count: int, system_prompt: str, agent: Any) -> Dic cycle, cycle_count, custom_system_prompt, + specified_tools=tools, **cycle_kwargs, ) diff --git a/src/strands_tools/use_llm.py b/src/strands_tools/use_llm.py index 897dfda1..f2edb553 100644 --- a/src/strands_tools/use_llm.py +++ b/src/strands_tools/use_llm.py @@ -15,12 +15,26 @@ agent = Agent(tools=[use_llm]) -# Basic usage with just a prompt and system prompt +# Basic usage with just a prompt and system prompt (inherits all parent tools) result = agent.tool.use_llm( prompt="Tell me about the advantages of tool-building in AI agents", system_prompt="You are a helpful AI assistant specializing in AI development concepts." ) +# Usage with specific tools filtered from parent agent +result = agent.tool.use_llm( + prompt="Calculate 2 + 2 and retrieve some information", + system_prompt="You are a helpful assistant.", + tools=["calculator", "retrieve"] +) + +# Usage with mixed tool filtering from parent agent +result = agent.tool.use_llm( + prompt="Analyze this data file", + system_prompt="You are a data analyst.", + tools=["file_read", "calculator", "python_repl"] +) + # The response is available in the returned object print(result["content"][0]["text"]) # Prints the response text ``` @@ -52,6 +66,13 @@ "type": "string", "description": "System prompt for the new event loop", }, + "tools": { + "type": "array", + "description": "List of tool names to make available to the nested agent" + + "Tool names must exist in the parent agent's tool registry." + + "If not provided, inherits all tools from parent agent.", + "items": {"type": "string"}, + }, }, "required": ["prompt", "system_prompt"], } @@ -92,7 +113,11 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult: Args: tool (ToolUse): Tool use object containing the following: - prompt (str): The prompt to process with the new agent instance - - system_prompt (str, optional): Custom system prompt for the agent + - system_prompt (str): Custom system prompt for the agent + - tools (List[str], optional): List of tool names to make available to the nested agent. + Tool names must exist in the parent agent's tool registry. + Examples: ["calculator", "file_read", "retrieve"] + If not provided, inherits all tools from the parent agent. **kwargs (Any): Additional keyword arguments Returns: @@ -116,6 +141,7 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult: prompt = tool_input["prompt"] tool_system_prompt = tool_input.get("system_prompt") + specified_tools = tool_input.get("tools") tools = [] trace_attributes = {} @@ -123,9 +149,22 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult: extra_kwargs = {} parent_agent = kwargs.get("agent") if parent_agent: - tools = list(parent_agent.tool_registry.registry.values()) trace_attributes = parent_agent.trace_attributes extra_kwargs["callback_handler"] = parent_agent.callback_handler + + # If specific tools are provided, filter parent tools; otherwise inherit all tools from parent + if specified_tools is not None: + # Filter parent agent tools to only include specified tool names + filtered_tools = [] + for tool_name in specified_tools: + if tool_name in parent_agent.tool_registry.registry: + filtered_tools.append(parent_agent.tool_registry.registry[tool_name]) + else: + logger.warning(f"Tool '{tool_name}' not found in parent agent's tool registry") + tools = filtered_tools + else: + tools = list(parent_agent.tool_registry.registry.values()) + if "callback_handler" in kwargs: extra_kwargs["callback_handler"] = kwargs["callback_handler"] diff --git a/tests/test_think.py b/tests/test_think.py index e6cfd138..edcf6c54 100644 --- a/tests/test_think.py +++ b/tests/test_think.py @@ -132,3 +132,269 @@ def test_thought_processor(): assert "Test thought" in prompt assert "Current Cycle: 1/3" in prompt assert "DO NOT call the think tool again" in prompt + + +def test_think_with_tool_filtering(): + """Test think tool with specific tools filtering from parent agent.""" + # Create mock tools for the parent agent + mock_calculator_tool = MagicMock(name="calculator_tool") + mock_file_read_tool = MagicMock(name="file_read_tool") + mock_other_tool = MagicMock(name="other_tool") + + # Create a mock parent agent with multiple tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry = { + "calculator": mock_calculator_tool, + "file_read": mock_file_read_tool, + "other_tool": mock_other_tool, + } + mock_parent_agent.trace_attributes = {} + + with patch("strands_tools.think.Agent") as mock_agent_class: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Analysis with filtered tools."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result + + # Call think with tool filtering + result = think.think( + thought="Test thought with tool filtering", + cycle_count=1, + system_prompt="You are an expert analytical thinker.", + tools=["calculator", "file_read"], + agent=mock_parent_agent, + ) + + # Verify the Agent was created with only the specified tools + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args.kwargs + + # Should only include calculator and file_read tools, not other_tool + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 2 + assert mock_calculator_tool in passed_tools + assert mock_file_read_tool in passed_tools + assert mock_other_tool not in passed_tools + + # Verify the result + assert result["status"] == "success" + assert "Cycle 1/1" in result["content"][0]["text"] + + +def test_think_with_nonexistent_tool_filtering(): + """Test think tool with tool filtering that includes non-existent tools.""" + # Create mock tools for the parent agent (missing nonexistent_tool) + mock_calculator_tool = MagicMock(name="calculator_tool") + mock_file_read_tool = MagicMock(name="file_read_tool") + + # Create a mock parent agent with limited tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry = {"calculator": mock_calculator_tool, "file_read": mock_file_read_tool} + mock_parent_agent.trace_attributes = {} + + with patch("strands_tools.think.Agent") as mock_agent_class, patch("strands_tools.think.logger") as mock_logger: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Analysis with missing tool."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result + + # Call think with tool filtering including non-existent tool + result = think.think( + thought="Test thought with non-existent tool", + cycle_count=1, + system_prompt="You are an expert analytical thinker.", + tools=["calculator", "nonexistent_tool", "file_read"], + agent=mock_parent_agent, + ) + + # Verify warning was logged for non-existent tool + mock_logger.warning.assert_called_once_with("Tool 'nonexistent_tool' not found in parent agent's tool registry") + + # Verify the Agent was created with only the existing tools + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args.kwargs + + # Should only include existing tools (calculator and file_read) + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 2 + assert mock_calculator_tool in passed_tools + assert mock_file_read_tool in passed_tools + + # Verify the result + assert result["status"] == "success" + + +def test_think_with_empty_tool_filtering(): + """Test think tool with empty tools list (should result in no tools).""" + # Create a mock parent agent with tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry = {"calculator": MagicMock(), "file_read": MagicMock()} + mock_parent_agent.trace_attributes = {} + + with patch("strands_tools.think.Agent") as mock_agent_class: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Analysis with no tools."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result + + # Call think with empty tools list + result = think.think( + thought="Test thought with no tools", + cycle_count=1, + system_prompt="You are an expert analytical thinker.", + tools=[], + agent=mock_parent_agent, + ) + + # Verify the Agent was created with no tools + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args.kwargs + + # Should be empty tools list + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 0 + + # Verify the result + assert result["status"] == "success" + + +def test_think_without_tool_filtering_inherits_all(): + """Test think tool without tools parameter inherits all parent tools.""" + # Create mock tools for the parent agent + mock_calculator_tool = MagicMock(name="calculator_tool") + mock_file_read_tool = MagicMock(name="file_read_tool") + mock_other_tool = MagicMock(name="other_tool") + + # Create a mock parent agent with multiple tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry.values.return_value = [ + mock_calculator_tool, + mock_file_read_tool, + mock_other_tool, + ] + mock_parent_agent.trace_attributes = {} + + with patch("strands_tools.think.Agent") as mock_agent_class: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Analysis with all tools."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result + + # Call think without tools parameter (should inherit all) + result = think.think( + thought="Test thought inheriting all tools", + cycle_count=1, + system_prompt="You are an expert analytical thinker.", + agent=mock_parent_agent, + ) + + # Verify the Agent was created with all parent tools + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args.kwargs + + # Should include all tools from parent + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 3 + assert mock_calculator_tool in passed_tools + assert mock_file_read_tool in passed_tools + assert mock_other_tool in passed_tools + + # Verify the result + assert result["status"] == "success" + + +def test_think_tool_filtering_with_multiple_cycles(): + """Test think tool with tool filtering across multiple cycles.""" + # Create mock tools for the parent agent + mock_calculator_tool = MagicMock(name="calculator_tool") + mock_file_read_tool = MagicMock(name="file_read_tool") + + # Create a mock parent agent with tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry = {"calculator": mock_calculator_tool, "file_read": mock_file_read_tool} + mock_parent_agent.trace_attributes = {} + + with patch("strands_tools.think.Agent") as mock_agent_class: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Multi-cycle analysis with filtered tools."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result + + # Call think with tool filtering and multiple cycles + result = think.think( + thought="Multi-cycle thought with tool filtering", + cycle_count=3, + system_prompt="You are an expert analytical thinker.", + tools=["calculator"], + agent=mock_parent_agent, + ) + + # Verify the Agent was created with filtered tools for each cycle + assert mock_agent_class.call_count == 3 # One for each cycle + + # Check that all calls used the filtered tools + for call in mock_agent_class.call_args_list: + call_kwargs = call.kwargs + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 1 + assert mock_calculator_tool in passed_tools + assert mock_file_read_tool not in passed_tools + + # Verify the result contains all cycles + assert result["status"] == "success" + assert "Cycle 1/3" in result["content"][0]["text"] + assert "Cycle 2/3" in result["content"][0]["text"] + assert "Cycle 3/3" in result["content"][0]["text"] + + +def test_think_no_parent_agent_with_tools_parameter(): + """Test think tool with tools parameter but no parent agent (should use empty tools).""" + with patch("strands_tools.think.Agent") as mock_agent_class: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Analysis without parent agent."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result + + # Call think with tools parameter but no parent agent + result = think.think( + thought="Test thought without parent agent", + cycle_count=1, + system_prompt="You are an expert analytical thinker.", + tools=["calculator", "file_read"], + agent=None, + ) + + # Verify the Agent was created with empty tools (since no parent to filter from) + mock_agent_class.assert_called_once() + call_kwargs = mock_agent_class.call_args.kwargs + + # Should be empty tools list since no parent agent + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 0 + + # Verify the result + assert result["status"] == "success" diff --git a/tests/test_use_llm.py b/tests/test_use_llm.py index b4f2a628..a50d82fb 100644 --- a/tests/test_use_llm.py +++ b/tests/test_use_llm.py @@ -264,3 +264,201 @@ def test_use_llm_with_explicit_callback(): # Verify the result assert result["status"] == "success" assert "Test response with explicit callback" in result["content"][0]["text"] + + +def test_use_llm_with_tool_filtering(): + """Test use_llm with specific tools filtering from parent agent.""" + tool_use = { + "toolUseId": "test-tool-filtering", + "input": { + "prompt": "Test with tool filtering", + "system_prompt": "Test system prompt", + "tools": ["calculator", "file_read"], + }, + } + + # Create mock tools for the parent agent + mock_calculator_tool = MagicMock(name="calculator_tool") + mock_file_read_tool = MagicMock(name="file_read_tool") + mock_other_tool = MagicMock(name="other_tool") + + # Create a mock parent agent with multiple tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry = { + "calculator": mock_calculator_tool, + "file_read": mock_file_read_tool, + "other_tool": mock_other_tool, + } + mock_parent_agent.trace_attributes = {} + mock_parent_agent.callback_handler = MagicMock() + + # Create a mock response + mock_response = MagicMock() + mock_response.metrics = None + mock_response.__str__.return_value = "Test response with filtered tools" + + with patch("strands_tools.use_llm.Agent") as MockAgent: + # Configure the mock agent + mock_instance = MockAgent.return_value + mock_instance.return_value = mock_response + + # Call use_llm with tool filtering + result = use_llm.use_llm(tool=tool_use, agent=mock_parent_agent) + + # Verify the Agent was created with only the specified tools + MockAgent.assert_called_once() + call_kwargs = MockAgent.call_args.kwargs + + # Should only include calculator and file_read tools, not other_tool + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 2 + assert mock_calculator_tool in passed_tools + assert mock_file_read_tool in passed_tools + assert mock_other_tool not in passed_tools + + # Verify the result + assert result["status"] == "success" + assert "Test response with filtered tools" in result["content"][0]["text"] + + +def test_use_llm_with_nonexistent_tool_filtering(): + """Test use_llm with tool filtering that includes non-existent tools.""" + tool_use = { + "toolUseId": "test-nonexistent-tool", + "input": { + "prompt": "Test with non-existent tool", + "system_prompt": "Test system prompt", + "tools": ["calculator", "nonexistent_tool", "file_read"], + }, + } + + # Create mock tools for the parent agent (missing nonexistent_tool) + mock_calculator_tool = MagicMock(name="calculator_tool") + mock_file_read_tool = MagicMock(name="file_read_tool") + + # Create a mock parent agent with limited tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry = {"calculator": mock_calculator_tool, "file_read": mock_file_read_tool} + mock_parent_agent.trace_attributes = {} + mock_parent_agent.callback_handler = MagicMock() + + # Create a mock response + mock_response = MagicMock() + mock_response.metrics = None + mock_response.__str__.return_value = "Test response with missing tool" + + with patch("strands_tools.use_llm.Agent") as MockAgent, patch("strands_tools.use_llm.logger") as mock_logger: + # Configure the mock agent + mock_instance = MockAgent.return_value + mock_instance.return_value = mock_response + + # Call use_llm with tool filtering including non-existent tool + result = use_llm.use_llm(tool=tool_use, agent=mock_parent_agent) + + # Verify warning was logged for non-existent tool + mock_logger.warning.assert_called_once_with("Tool 'nonexistent_tool' not found in parent agent's tool registry") + + # Verify the Agent was created with only the existing tools + MockAgent.assert_called_once() + call_kwargs = MockAgent.call_args.kwargs + + # Should only include existing tools (calculator and file_read) + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 2 + assert mock_calculator_tool in passed_tools + assert mock_file_read_tool in passed_tools + + # Verify the result + assert result["status"] == "success" + + +def test_use_llm_with_empty_tool_filtering(): + """Test use_llm with empty tools list (should result in no tools).""" + tool_use = { + "toolUseId": "test-empty-tools", + "input": {"prompt": "Test with empty tools list", "system_prompt": "Test system prompt", "tools": []}, + } + + # Create a mock parent agent with tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry = {"calculator": MagicMock(), "file_read": MagicMock()} + mock_parent_agent.trace_attributes = {} + mock_parent_agent.callback_handler = MagicMock() + + # Create a mock response + mock_response = MagicMock() + mock_response.metrics = None + mock_response.__str__.return_value = "Test response with no tools" + + with patch("strands_tools.use_llm.Agent") as MockAgent: + # Configure the mock agent + mock_instance = MockAgent.return_value + mock_instance.return_value = mock_response + + # Call use_llm with empty tools list + result = use_llm.use_llm(tool=tool_use, agent=mock_parent_agent) + + # Verify the Agent was created with no tools + MockAgent.assert_called_once() + call_kwargs = MockAgent.call_args.kwargs + + # Should be empty tools list + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 0 + + # Verify the result + assert result["status"] == "success" + + +def test_use_llm_without_tool_filtering_inherits_all(): + """Test use_llm without tools parameter inherits all parent tools.""" + tool_use = { + "toolUseId": "test-inherit-all-tools", + "input": { + "prompt": "Test inheriting all tools", + "system_prompt": "Test system prompt", + # No tools parameter - should inherit all + }, + } + + # Create mock tools for the parent agent + mock_calculator_tool = MagicMock(name="calculator_tool") + mock_file_read_tool = MagicMock(name="file_read_tool") + mock_other_tool = MagicMock(name="other_tool") + + # Create a mock parent agent with multiple tools + mock_parent_agent = MagicMock() + mock_parent_agent.tool_registry.registry.values.return_value = [ + mock_calculator_tool, + mock_file_read_tool, + mock_other_tool, + ] + mock_parent_agent.trace_attributes = {} + mock_parent_agent.callback_handler = MagicMock() + + # Create a mock response + mock_response = MagicMock() + mock_response.metrics = None + mock_response.__str__.return_value = "Test response with all tools" + + with patch("strands_tools.use_llm.Agent") as MockAgent: + # Configure the mock agent + mock_instance = MockAgent.return_value + mock_instance.return_value = mock_response + + # Call use_llm without tools parameter + result = use_llm.use_llm(tool=tool_use, agent=mock_parent_agent) + + # Verify the Agent was created with all parent tools + MockAgent.assert_called_once() + call_kwargs = MockAgent.call_args.kwargs + + # Should include all tools from parent + passed_tools = call_kwargs["tools"] + assert len(passed_tools) == 3 + assert mock_calculator_tool in passed_tools + assert mock_file_read_tool in passed_tools + assert mock_other_tool in passed_tools + + # Verify the result + assert result["status"] == "success"