diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml index 0e013471..f071044f 100644 --- a/.github/workflows/test-lint-pr.yml +++ b/.github/workflows/test-lint-pr.yml @@ -48,8 +48,7 @@ jobs: return true; unit-test: - name: Run Tests on Python ${{ matrix.python-version }} - runs-on: ubuntu-latest + name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} needs: check-approval permissions: contents: read @@ -57,8 +56,39 @@ jobs: if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' strategy: matrix: - python-version: [ "3.10", "3.11", "3.12", "3.13" ] + include: + # Linux + - os: ubuntu-latest + os-name: linux + python-version: "3.10" + - os: ubuntu-latest + os-name: linux + python-version: "3.11" + - os: ubuntu-latest + os-name: linux + python-version: "3.12" + - os: ubuntu-latest + os-name: linux + python-version: "3.13" + # Windows + - os: windows-latest + os-name: windows + python-version: "3.10" + - os: windows-latest + os-name: windows + python-version: "3.11" + - os: windows-latest + os-name: windows + python-version: "3.12" + - os: windows-latest + os-name: windows + python-version: "3.13" + # MacOS - latest only; not enough runners for MacOS + - os: macos-latest + os-name: macos + python-version: "3.13" fail-fast: false + runs-on: ${{ matrix.os }} steps: - name: Checkout code uses: actions/checkout@v4 @@ -78,7 +108,7 @@ jobs: continue-on-error: false lint: - name: Run Lint + name: Lint runs-on: ubuntu-latest needs: check-approval permissions: diff --git a/pyproject.toml b/pyproject.toml index 3fd0bd23..c44f0094 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents-tools" -version = "0.1.1" +version = "0.1.2" description = "A collection of specialized tools for Strands Agents" readme = "README.md" requires-python = ">=3.10" @@ -40,6 +40,8 @@ dependencies = [ "slack_bolt>=1.23.0,<2.0.0", "mem0ai>=0.1.99,<1.0.0", "opensearch-py>=2.8.0,<3.0.0", + # Note: Always want the latest tzdata + "tzdata ; platform_system == 'Windows'", ] [tool.hatch.build.targets.wheel] diff --git a/src/strands_tools/environment.py b/src/strands_tools/environment.py index 7d0499c1..ea0562b7 100644 --- a/src/strands_tools/environment.py +++ b/src/strands_tools/environment.py @@ -71,8 +71,7 @@ from rich.text import Text from strands.types.tools import ToolResult, ToolResultContent, ToolUse -from strands_tools.utils import console_util -from strands_tools.utils.user_input import get_user_input +from strands_tools.utils import console_util, user_input TOOL_SPEC = { "name": "environment", @@ -584,7 +583,7 @@ def environment(tool: ToolUse, **kwargs: Any) -> ToolResult: ) # Ask for confirmation - confirm = get_user_input( + confirm = user_input.get_user_input( "\nDo you want to proceed with setting this environment variable? [y/*]" ) # For tests, 'y' should be recognized even with extra spaces or newlines @@ -706,7 +705,7 @@ def environment(tool: ToolUse, **kwargs: Any) -> ToolResult: ) # Ask for confirmation - confirm = get_user_input( + confirm = user_input.get_user_input( "\nDo you want to proceed with deleting this environment variable? [y/*]" ) # For tests, 'y' should be recognized even with extra spaces or newlines diff --git a/src/strands_tools/load_tool.py b/src/strands_tools/load_tool.py index fd72f695..0bf52207 100644 --- a/src/strands_tools/load_tool.py +++ b/src/strands_tools/load_tool.py @@ -79,7 +79,6 @@ def load_tool(path: str, name: str, agent=None) -> Dict[str, Any]: Tool Loading Process: ------------------- - - First, checks if dynamic loading is permitted (hot_reload_tools=True) - Expands the path to handle user paths with tilde (~) - Validates that the file exists at the specified path - Uses the tool_registry's load_tool_from_filepath method to: @@ -175,7 +174,6 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult: Notes: - The tool loading can be disabled via STRANDS_DISABLE_LOAD_TOOL=true environment variable - - The Agent instance must have hot_reload_tools=True to enable dynamic loading - Python files in the cwd()/tools/ directory are automatically hot reloaded without requiring explicit calls to load_tool - When using the load_tool function, ensure your tool files have proper docstrings as they are @@ -187,8 +185,8 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult: current_agent = agent try: - # Check if dynamic tool loading is disabled via environment variable or agent.hot_reload_tools. - if not current_agent.hot_reload_tools or os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true": + # Check if dynamic tool loading is disabled via environment variable. + if os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true": logger.warning("Dynamic tool loading is disabled via STRANDS_DISABLE_LOAD_TOOL=true") return {"status": "error", "content": [{"text": "⚠️ Dynamic tool loading is disabled in production mode."}]} diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py index 41d94b9d..212d85cf 100644 --- a/src/strands_tools/mem0_memory.py +++ b/src/strands_tools/mem0_memory.py @@ -141,19 +141,10 @@ "required": ["action"], "allOf": [ { - "if": { - "properties": { - "action": {"enum": ["store", "list", "retrieve"]} - } - }, - "then": { - "oneOf": [ - {"required": ["user_id"]}, - {"required": ["agent_id"]} - ] - } + "if": {"properties": {"action": {"enum": ["store", "list", "retrieve"]}}}, + "then": {"oneOf": [{"required": ["user_id"]}, {"required": ["agent_id"]}]}, } - ] + ], } }, } @@ -536,7 +527,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")] + content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")], ) elif action == "get": @@ -547,9 +538,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: panel = format_get_response(memory) console.print(panel) return ToolResult( - toolUseId=tool_use_id, - status="success", - content=[ToolResultContent(text=json.dumps(memory, indent=2))] + toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(memory, indent=2))] ) elif action == "list": @@ -559,7 +548,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))] + content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))], ) elif action == "retrieve": @@ -576,7 +565,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))] + content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))], ) elif action == "delete": @@ -589,7 +578,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: return ToolResult( toolUseId=tool_use_id, status="success", - content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")] + content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")], ) elif action == "history": @@ -600,9 +589,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: panel = format_history_response(history) console.print(panel) return ToolResult( - toolUseId=tool_use_id, - status="success", - content=[ToolResultContent(text=json.dumps(history, indent=2))] + toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(history, indent=2))] ) else: @@ -615,8 +602,4 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult: border_style="red", ) console.print(error_panel) - return ToolResult( - toolUseId=tool_use_id, - status="error", - content=[ToolResultContent(text=f"Error: {str(e)}")] - ) + return ToolResult(toolUseId=tool_use_id, status="error", content=[ToolResultContent(text=f"Error: {str(e)}")]) diff --git a/src/strands_tools/memory.py b/src/strands_tools/memory.py index b4989932..585906c7 100644 --- a/src/strands_tools/memory.py +++ b/src/strands_tools/memory.py @@ -42,7 +42,7 @@ agent = Agent(tools=[memory]) # Store content in Knowledge Base -agent.memory( +agent.tool.memory( action="store", content="Important information to remember", title="Meeting Notes", @@ -50,7 +50,7 @@ ) # Retrieve content using semantic search -agent.memory( +agent.tool.memory( action="retrieve", query="meeting information", min_score=0.7, @@ -58,7 +58,7 @@ ) # List all documents -agent.memory( +agent.tool.memory( action="list", max_results=50, STRANDS_KNOWLEDGE_BASE_ID="my1234kb" diff --git a/src/strands_tools/slack.py b/src/strands_tools/slack.py index 5dccb562..2f3cc879 100644 --- a/src/strands_tools/slack.py +++ b/src/strands_tools/slack.py @@ -351,12 +351,14 @@ def _process_message(self, event): return tools = list(self.agent.tool_registry.registry.values()) + trace_attributes = self.agent.trace_attributes agent = Agent( messages=[], system_prompt=f"{self.agent.system_prompt}\n{SLACK_SYSTEM_PROMPT}", tools=tools, callback_handler=None, + trace_attributes=trace_attributes, ) channel_id = event.get("channel") diff --git a/src/strands_tools/think.py b/src/strands_tools/think.py index a1c210e3..ec3ab5a5 100644 --- a/src/strands_tools/think.py +++ b/src/strands_tools/think.py @@ -9,9 +9,9 @@ Usage with Strands Agent: ```python from strands import Agent -from strands_tools import think +from strands_tools import think, stop -agent = Agent(tools=[think]) +agent = Agent(tools=[think, stop]) # Basic usage with default system prompt result = agent.tool.think( @@ -32,13 +32,15 @@ See the think function docstring for more details on configuration options and parameters. """ +import logging import traceback import uuid from typing import Any, Dict -from strands import tool +from strands import Agent, tool +from strands.telemetry.metrics import metrics_to_string -from strands_tools.use_llm import use_llm +logger = logging.getLogger(__name__) class ThoughtProcessor: @@ -77,36 +79,46 @@ def process_cycle( ) -> str: """Process a single thinking cycle.""" + logger.debug(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...") print(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...") # Create cycle-specific prompt prompt = self.create_thinking_prompt(thought, cycle, total_cycles) - # Use LLM for processing - result = use_llm( - { - "name": "use_llm", - "toolUseId": self.tool_use_id, - "input": { - "system_prompt": custom_system_prompt, - "prompt": prompt, - }, - }, - **kwargs, - ) - - # Extract and return response - cycle_response = "" - if result.get("status") == "success": - for content in result.get("content", []): - if content.get("text"): - cycle_response += content["text"] + "\n" - - return cycle_response.strip() + # Display input prompt + logger.debug(f"\n--- Input Prompt ---\n{prompt}\n") + + # Get tools from parent agent if available + tools = [] + trace_attributes = {} + parent_agent = kwargs.get("agent") + if parent_agent: + tools = list(parent_agent.tool_registry.registry.values()) + trace_attributes = parent_agent.trace_attributes + + # Initialize the new Agent with provided parameters + agent = Agent(messages=[], tools=tools, system_prompt=custom_system_prompt, trace_attributes=trace_attributes) + + # Run the agent with the provided prompt + result = agent(prompt) + + # Extract response + assistant_response = str(result) + + # Display assistant response + logger.debug(f"\n--- Assistant Response ---\n{assistant_response.strip()}\n") + + # Print metrics if available + if result.metrics: + metrics = result.metrics + metrics_text = metrics_to_string(metrics) + logger.debug(metrics_text) + + return assistant_response.strip() @tool -def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) -> Dict[str, Any]: +def think(thought: str, cycle_count: int, system_prompt: str, agent: Any) -> Dict[str, Any]: """ Recursive thinking tool for sophisticated thought generation, learning, and self-reflection. @@ -172,7 +184,7 @@ def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) -> custom_system_prompt = ( "You are an expert analytical thinker. Process the thought deeply and provide clear insights." ) - + kwargs = {"agent": agent} # Create thought processor instance with the available context processor = ThoughtProcessor(kwargs) diff --git a/src/strands_tools/use_llm.py b/src/strands_tools/use_llm.py index a294efad..0807b052 100644 --- a/src/strands_tools/use_llm.py +++ b/src/strands_tools/use_llm.py @@ -117,9 +117,12 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult: tool_system_prompt = tool_input.get("system_prompt") tools = [] + trace_attributes = {} + parent_agent = kwargs.get("agent") if parent_agent: tools = list(parent_agent.tool_registry.registry.values()) + trace_attributes = parent_agent.trace_attributes # Display input prompt logger.debug(f"\n--- Input Prompt ---\n{prompt}\n") @@ -128,11 +131,7 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult: logger.debug("🔄 Creating new LLM instance...") # Initialize the new Agent with provided parameters - agent = Agent( - messages=[], - tools=tools, - system_prompt=tool_system_prompt, - ) + agent = Agent(messages=[], tools=tools, system_prompt=tool_system_prompt, trace_attributes=trace_attributes) # Run the agent with the provided prompt result = agent(prompt) diff --git a/src/strands_tools/utils/user_input.py b/src/strands_tools/utils/user_input.py index c1e895cd..16eebb0e 100644 --- a/src/strands_tools/utils/user_input.py +++ b/src/strands_tools/utils/user_input.py @@ -8,10 +8,13 @@ from prompt_toolkit import HTML, PromptSession from prompt_toolkit.patch_stdout import patch_stdout -session: PromptSession = PromptSession() +# Lazy initialize to avoid import errors for tests on windows without a terminal +session: PromptSession | None = None async def get_user_input_async(prompt: str, default: str = "n") -> str: + global session + """ Asynchronously get user input with prompt_toolkit's features (history, arrow keys, styling, etc). @@ -25,6 +28,9 @@ async def get_user_input_async(prompt: str, default: str = "n") -> str: try: with patch_stdout(raw=True): + if session is None: + session = PromptSession() + response = await session.prompt_async(HTML(f"{prompt} ")) if not response: diff --git a/tests/test_environment.py b/tests/test_environment.py index 72d9ce74..921276e3 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -3,28 +3,33 @@ """ import os +from unittest import mock import pytest from strands import Agent from strands_tools import environment +from strands_tools.utils import user_input @pytest.fixture def agent(): """Create an agent with the environment tool loaded.""" - return Agent(tools=[environment]) + return Agent(tools=[environment], load_tools_from_directory=False) -@pytest.fixture -def test_env_var(): - """Create and clean up a test environment variable.""" - var_name = "TEST_ENV_VAR" - var_value = "test_value" - os.environ[var_name] = var_value - yield var_name, var_value - # Clean up: remove test variable if it exists - if var_name in os.environ: - del os.environ[var_name] +@pytest.fixture(autouse=True) +def get_user_input(): + with mock.patch.object(user_input, "get_user_input") as mocked_user_input: + # By default all tests will return deny + mocked_user_input.return_value = "n" + yield mocked_user_input + + +@pytest.fixture(autouse=True) +def os_environment(): + mock_env = {} + with mock.patch.object(os, "environ", mock_env): + yield mock_env def extract_result_text(result): @@ -42,18 +47,24 @@ def test_direct_list_action(agent): assert len(extract_result_text(result)) > 0 -def test_direct_list_with_prefix(agent, test_env_var): +def test_direct_list_with_prefix(agent, os_environment): """Test listing environment variables with a specific prefix.""" - var_name, _ = test_env_var + var_name = "TEST_ENV_VAR" + var_value = "test_value" + os_environment[var_name] = var_value + result = agent.tool.environment(action="list", prefix=var_name[:4]) assert result["status"] == "success" # Verify our test variable is in the result assert var_name in extract_result_text(result) -def test_direct_get_existing_var(agent, test_env_var): +def test_direct_get_existing_var(agent, os_environment): """Test getting an existing environment variable.""" - var_name, var_value = test_env_var + var_name = "TEST_ENV_VAR" + var_value = "test_value" + os_environment[var_name] = var_value + result = agent.tool.environment(action="get", name=var_name) assert result["status"] == "success" assert var_name in extract_result_text(result) @@ -67,10 +78,9 @@ def test_direct_get_nonexistent_var(agent): assert "not found" in extract_result_text(result) -def test_direct_set_protected_var(agent, monkeypatch): +def test_direct_set_protected_var(agent, os_environment): """Test attempting to set a protected environment variable.""" - # Mock get_user_input to always return 'y' for confirmation - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y") + os_environment["PATH"] = "/original/path" # Try to modify PATH which is in PROTECTED_VARS result = agent.tool.environment(action="set", name="PATH", value="/bad/path") @@ -79,101 +89,91 @@ def test_direct_set_protected_var(agent, monkeypatch): assert os.environ["PATH"] != "/bad/path" -def test_direct_set_var_cancelled(agent, monkeypatch): - """Test cancelling setting an environment variable.""" - # Mock get_user_input to return 'n' to cancel - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "n") +def test_direct_set_var_allowed(get_user_input, agent): + """Test attempting to set a protected environment variable and allowing it.""" + get_user_input.return_value = "y" var_name = "CANCELLED_VAR" var_value = "cancelled_value" - # Clean up in case the variable exists - if var_name in os.environ: - del os.environ[var_name] + result = agent.tool.environment(action="set", name=var_name, value=var_value) + assert result["status"] == "success" + assert var_name in os.environ + assert get_user_input.call_count == 1 + + +def test_direct_set_var_cancelled(agent): + var_name = "CANCELLED_VAR" + var_value = "cancelled_value" - try: - result = agent.tool.environment(action="set", name=var_name, value=var_value) - assert result["status"] == "error" - assert "cancelled" in extract_result_text(result).lower() - # Verify variable was not set - assert var_name not in os.environ - finally: - # Clean up - if var_name in os.environ: - del os.environ[var_name] + result = agent.tool.environment(action="set", name=var_name, value=var_value) + assert result["status"] == "error" + assert "cancelled" in extract_result_text(result).lower() + # Verify variable was not set + assert var_name not in os.environ -def test_direct_delete_nonexistent_var(agent, monkeypatch): +def test_direct_delete_nonexistent_var(agent): """Test attempting to delete a non-existent variable.""" - # Mock get_user_input to always return 'y' for confirmation - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y") - var_name = "NONEXISTENT_VAR_FOR_DELETE_TEST" - # Make sure the variable doesn't exist - if var_name in os.environ: - del os.environ[var_name] - result = agent.tool.environment(action="delete", name=var_name) assert result["status"] == "error" assert "not found" in extract_result_text(result).lower() -def test_direct_delete_protected_var(agent, monkeypatch): +def test_direct_delete_protected_var(agent, os_environment): """Test attempting to delete a protected environment variable.""" - # Mock get_user_input to always return 'y' for confirmation - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y") - # Try to delete PATH which is in PROTECTED_VARS - original_path = os.environ.get("PATH", "") - try: - result = agent.tool.environment(action="delete", name="PATH") - assert result["status"] == "error" - # Verify PATH still exists - assert "PATH" in os.environ - finally: - # Restore PATH if somehow it got deleted - if "PATH" not in os.environ: - os.environ["PATH"] = original_path - - -def test_direct_delete_var_cancelled(agent, monkeypatch): + unchanging_value = "/original/path" + os_environment["PATH"] = unchanging_value + + result = agent.tool.environment(action="delete", name="PATH") + assert result["status"] == "error" + # Verify PATH still exists + assert os_environment["PATH"] == unchanging_value + + +def test_direct_delete_var_cancelled(agent, os_environment): """Test cancelling deletion of an environment variable.""" - # Mock get_user_input to return 'n' to cancel - monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "n") + var_name = "CANCEL_DELETE_VAR" + var_value = "cancel_delete_value" + + # Set up the variable + os_environment[var_name] = var_value - # Ensure DEV mode is disabled to force confirmation - current_dev = os.environ.get("DEV", None) - if current_dev: - os.environ.pop("DEV") + result = agent.tool.environment(action="delete", name=var_name) + assert result["status"] == "error" + assert "cancelled" in extract_result_text(result).lower() + # Verify variable still exists + assert var_name in os.environ + assert os_environment[var_name] == var_value + + +def test_direct_delete_var_allowed(agent, get_user_input, os_environment): + """Test allowing deletion of an environment variable.""" + get_user_input.return_value = "y" var_name = "CANCEL_DELETE_VAR" var_value = "cancel_delete_value" # Set up the variable - os.environ[var_name] = var_value - - try: - result = agent.tool.environment(action="delete", name=var_name) - assert result["status"] == "error" - assert "cancelled" in extract_result_text(result).lower() - # Verify variable still exists - assert var_name in os.environ - assert os.environ[var_name] == var_value - finally: - # Clean up - if var_name in os.environ: - del os.environ[var_name] - # Restore DEV mode if it was set - if current_dev: - os.environ["DEV"] = current_dev - if var_name in os.environ: - del os.environ[var_name] - - -def test_direct_validate_existing_var(agent, test_env_var): + os_environment[var_name] = var_value + + result = agent.tool.environment(action="delete", name=var_name) + assert result["status"] == "success" + assert "deleted environment variable" in extract_result_text(result).lower() + assert os_environment.get(var_name) is None + + +def test_direct_validate_existing_var(agent, os_environment): """Test validating an existing environment variable.""" - var_name, _ = test_env_var + var_name = "CANCEL_DELETE_VAR" + var_value = "cancel_delete_value" + + # Set up the variable + os_environment[var_name] = var_value + result = agent.tool.environment(action="validate", name=var_name) assert result["status"] == "success" assert "valid" in extract_result_text(result).lower() @@ -201,76 +201,52 @@ def test_direct_missing_parameters(agent): assert result["status"] == "error" -def test_environment_dev_mode_delete(agent): +def test_environment_dev_mode_delete(agent, os_environment): """Test the environment tool in DEV mode with delete action.""" # Set DEV mode - original_dev = os.environ.get("DEV") - os.environ["DEV"] = "true" + os_environment["DEV"] = "true" var_name = "DEV_MODE_DELETE_VAR" var_value = "dev_mode_delete_value" - try: - # Set up the variable - os.environ[var_name] = var_value - - result = agent.tool.environment(action="delete", name=var_name) - assert result["status"] == "success" - assert var_name not in os.environ - finally: - # Clean up - if var_name in os.environ: - del os.environ[var_name] + # Set up the variable + os_environment[var_name] = var_value - # Restore original DEV value - if original_dev is None: - if "DEV" in os.environ: - del os.environ["DEV"] - else: - os.environ["DEV"] = original_dev + result = agent.tool.environment(action="delete", name=var_name) + assert result["status"] == "success" + assert var_name not in os_environment -def test_environment_dev_mode_protected_var(agent, monkeypatch): +def test_environment_dev_mode_protected_var(agent, os_environment): """Test that protected variables are still protected in DEV mode.""" # Set DEV mode - original_dev = os.environ.get("DEV") - os.environ["DEV"] = "true" - - try: - # Try to modify PATH which is protected - result = agent.tool.environment(action="set", name="PATH", value="/bad/path") - assert result["status"] == "error" - # Verify PATH was not changed - assert os.environ["PATH"] != "/bad/path" - finally: - # Restore original DEV value - if original_dev is None: - if "DEV" in os.environ: - del os.environ["DEV"] - else: - os.environ["DEV"] = original_dev - - -def test_environment_masked_values(agent, test_env_var): + os_environment["DEV"] = True + + unchanging_value = "/original/path" + os_environment["PATH"] = unchanging_value + + # Try to modify PATH which is protected + result = agent.tool.environment(action="set", name="PATH", value="/bad/path") + assert result["status"] == "error" + # Verify PATH was not changed + assert os_environment != unchanging_value + + +def test_environment_masked_values(agent, os_environment): """Test that sensitive values are masked in output.""" # Create a sensitive looking variable sensitive_name = "TEST_TOKEN_SECRET" sensitive_value = "abcd1234efgh5678" - os.environ[sensitive_name] = sensitive_value - - try: - # Test with masking enabled (default) - result = agent.tool.environment(action="get", name=sensitive_name) - assert result["status"] == "success" - # The full value should not appear in the output - assert sensitive_value not in extract_result_text(result) - - # Test with masking disabled - result = agent.tool.environment(action="get", name=sensitive_name, masked=False) - assert result["status"] == "success" - # Now the full value should appear - assert sensitive_value in extract_result_text(result) - finally: - # Clean up - if sensitive_name in os.environ: - del os.environ[sensitive_name] + os_environment[sensitive_name] = sensitive_value + + # Test with masking enabled (default) + result = agent.tool.environment(action="get", name=sensitive_name) + assert result["status"] == "success" + # The full value should not appear in the output + assert sensitive_value not in extract_result_text(result) + + # Test with masking disabled + result = agent.tool.environment(action="get", name=sensitive_name, masked=False) + assert result["status"] == "success" + # Now the full value should appear + assert sensitive_value in extract_result_text(result) diff --git a/tests/test_file_write.py b/tests/test_file_write.py index e72a4a7d..8a7ce5cf 100644 --- a/tests/test_file_write.py +++ b/tests/test_file_write.py @@ -148,8 +148,6 @@ def test_file_write_error_handling(mock_user_input, temp_file): # This should fail on most systems if os.name == "posix": # Unix/Linux/Mac invalid_path = "/root/test_no_permission.txt" - elif os.name == "nt": # Windows - invalid_path = "C:\\Windows\\System32\\config\\nopermission.txt" else: # Fallback - create a path that's too long invalid_path = os.path.join(temp_file, "a" * 1000 + ".txt") diff --git a/tests/test_python_repl.py b/tests/test_python_repl.py index 672e14cc..6eddd1f1 100644 --- a/tests/test_python_repl.py +++ b/tests/test_python_repl.py @@ -12,6 +12,10 @@ import dill import pytest from strands import Agent + +if os.name == "nt": + pytest.skip("skipping on windows until issue #17 is resolved", allow_module_level=True) + from strands_tools import python_repl diff --git a/tests/test_shell.py b/tests/test_shell.py index d270f93f..4e6dba8c 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -5,12 +5,17 @@ import os import signal import sys -import termios from unittest.mock import MagicMock, patch import pytest from rich.panel import Panel from strands import Agent + +if os.name == "nt": + pytest.skip("skipping on windows until issue #17 is resolved", allow_module_level=True) + +import termios + from strands_tools import shell diff --git a/tests/test_think.py b/tests/test_think.py index a33e25e9..77882b06 100644 --- a/tests/test_think.py +++ b/tests/test_think.py @@ -2,8 +2,9 @@ Tests for the think tool using the Agent interface. """ -from unittest.mock import patch +from unittest.mock import MagicMock, patch +from strands.agent import AgentResult from strands_tools import think from strands_tools.think import ThoughtProcessor @@ -28,13 +29,17 @@ def test_think_tool_direct(): }, } - # Mock use_llm function since we don't want to actually call the LLM - with patch("strands_tools.think.use_llm") as mock_use_llm: - # Setup mock response - mock_use_llm.return_value = { - "status": "success", - "content": [{"text": "This is a mock analysis of quantum computing."}], - } + # Mock Agent class since we don't want to actually call the LLM + with patch("strands_tools.think.Agent") as mock_agent_class: + # Setup mock agent and response + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "This is a mock analysis of quantum computing."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result # Call the think function directly tool_input = tool_use.get("input", {}) @@ -42,6 +47,7 @@ def test_think_tool_direct(): thought=tool_input.get("thought"), cycle_count=tool_input.get("cycle_count"), system_prompt=tool_input.get("system_prompt"), + agent=None, ) # Verify the result has the expected structure @@ -49,8 +55,8 @@ def test_think_tool_direct(): assert "Cycle 1/2" in result["content"][0]["text"] assert "Cycle 2/2" in result["content"][0]["text"] - # Verify use_llm was called twice (once for each cycle) - assert mock_use_llm.call_count == 2 + # Verify Agent was called twice (once for each cycle) + assert mock_agent.call_count == 2 def test_think_one_cycle(): @@ -65,22 +71,27 @@ def test_think_one_cycle(): }, } - with patch("strands_tools.think.use_llm") as mock_use_llm: - mock_use_llm.return_value = { - "status": "success", - "content": [{"text": "Analysis for single cycle."}], - } + with patch("strands_tools.think.Agent") as mock_agent_class: + mock_agent = mock_agent_class.return_value + mock_result = AgentResult( + message={"content": [{"text": "Analysis for single cycle."}]}, + stop_reason="end_turn", + metrics=None, + state=MagicMock(), + ) + mock_agent.return_value = mock_result tool_input = tool_use.get("input", {}) result = think.think( thought=tool_input.get("thought"), cycle_count=tool_input.get("cycle_count"), system_prompt=tool_input.get("system_prompt"), + agent=None, ) assert result["status"] == "success" assert "Cycle 1/1" in result["content"][0]["text"] - assert mock_use_llm.call_count == 1 + assert mock_agent.call_count == 1 def test_think_error_handling(): @@ -95,15 +106,16 @@ def test_think_error_handling(): }, } - with patch("strands_tools.think.use_llm") as mock_use_llm: - # Make use_llm raise an exception - mock_use_llm.side_effect = Exception("Test error") + with patch("strands_tools.think.Agent") as mock_agent_class: + # Make Agent raise an exception + mock_agent_class.side_effect = Exception("Test error") tool_input = tool_use.get("input", {}) result = think.think( thought=tool_input.get("thought"), cycle_count=tool_input.get("cycle_count"), system_prompt=tool_input.get("system_prompt"), + agent=None, ) assert result["status"] == "error" diff --git a/tests/test_use_llm.py b/tests/test_use_llm.py index c978c781..ccb6eb8c 100644 --- a/tests/test_use_llm.py +++ b/tests/test_use_llm.py @@ -59,7 +59,9 @@ def test_use_llm_tool_direct(mock_agent_response): assert "This is a test response from the LLM" in str(result) # Verify the Agent was created with the correct parameters - MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a helpful test assistant") + MockAgent.assert_called_once_with( + messages=[], tools=[], system_prompt="You are a helpful test assistant", trace_attributes={} + ) def test_use_llm_with_custom_system_prompt(mock_agent_response): @@ -82,7 +84,9 @@ def test_use_llm_with_custom_system_prompt(mock_agent_response): result = use_llm.use_llm(tool=tool_use) # Verify agent was created with correct system prompt - MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a specialized test assistant") + MockAgent.assert_called_once_with( + messages=[], tools=[], system_prompt="You are a specialized test assistant", trace_attributes={} + ) assert result["status"] == "success" assert "Custom response" in result["content"][0]["text"]