diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml
index 0e013471..f071044f 100644
--- a/.github/workflows/test-lint-pr.yml
+++ b/.github/workflows/test-lint-pr.yml
@@ -48,8 +48,7 @@ jobs:
return true;
unit-test:
- name: Run Tests on Python ${{ matrix.python-version }}
- runs-on: ubuntu-latest
+ name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }}
needs: check-approval
permissions:
contents: read
@@ -57,8 +56,39 @@ jobs:
if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true'
strategy:
matrix:
- python-version: [ "3.10", "3.11", "3.12", "3.13" ]
+ include:
+ # Linux
+ - os: ubuntu-latest
+ os-name: linux
+ python-version: "3.10"
+ - os: ubuntu-latest
+ os-name: linux
+ python-version: "3.11"
+ - os: ubuntu-latest
+ os-name: linux
+ python-version: "3.12"
+ - os: ubuntu-latest
+ os-name: linux
+ python-version: "3.13"
+ # Windows
+ - os: windows-latest
+ os-name: windows
+ python-version: "3.10"
+ - os: windows-latest
+ os-name: windows
+ python-version: "3.11"
+ - os: windows-latest
+ os-name: windows
+ python-version: "3.12"
+ - os: windows-latest
+ os-name: windows
+ python-version: "3.13"
+ # MacOS - latest only; not enough runners for MacOS
+ - os: macos-latest
+ os-name: macos
+ python-version: "3.13"
fail-fast: false
+ runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -78,7 +108,7 @@ jobs:
continue-on-error: false
lint:
- name: Run Lint
+ name: Lint
runs-on: ubuntu-latest
needs: check-approval
permissions:
diff --git a/pyproject.toml b/pyproject.toml
index 3fd0bd23..c44f0094 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "strands-agents-tools"
-version = "0.1.1"
+version = "0.1.2"
description = "A collection of specialized tools for Strands Agents"
readme = "README.md"
requires-python = ">=3.10"
@@ -40,6 +40,8 @@ dependencies = [
"slack_bolt>=1.23.0,<2.0.0",
"mem0ai>=0.1.99,<1.0.0",
"opensearch-py>=2.8.0,<3.0.0",
+ # Note: Always want the latest tzdata
+ "tzdata ; platform_system == 'Windows'",
]
[tool.hatch.build.targets.wheel]
diff --git a/src/strands_tools/environment.py b/src/strands_tools/environment.py
index 7d0499c1..ea0562b7 100644
--- a/src/strands_tools/environment.py
+++ b/src/strands_tools/environment.py
@@ -71,8 +71,7 @@
from rich.text import Text
from strands.types.tools import ToolResult, ToolResultContent, ToolUse
-from strands_tools.utils import console_util
-from strands_tools.utils.user_input import get_user_input
+from strands_tools.utils import console_util, user_input
TOOL_SPEC = {
"name": "environment",
@@ -584,7 +583,7 @@ def environment(tool: ToolUse, **kwargs: Any) -> ToolResult:
)
# Ask for confirmation
- confirm = get_user_input(
+ confirm = user_input.get_user_input(
"\nDo you want to proceed with setting this environment variable? [y/*]"
)
# For tests, 'y' should be recognized even with extra spaces or newlines
@@ -706,7 +705,7 @@ def environment(tool: ToolUse, **kwargs: Any) -> ToolResult:
)
# Ask for confirmation
- confirm = get_user_input(
+ confirm = user_input.get_user_input(
"\nDo you want to proceed with deleting this environment variable? [y/*]"
)
# For tests, 'y' should be recognized even with extra spaces or newlines
diff --git a/src/strands_tools/load_tool.py b/src/strands_tools/load_tool.py
index fd72f695..0bf52207 100644
--- a/src/strands_tools/load_tool.py
+++ b/src/strands_tools/load_tool.py
@@ -79,7 +79,6 @@ def load_tool(path: str, name: str, agent=None) -> Dict[str, Any]:
Tool Loading Process:
-------------------
- - First, checks if dynamic loading is permitted (hot_reload_tools=True)
- Expands the path to handle user paths with tilde (~)
- Validates that the file exists at the specified path
- Uses the tool_registry's load_tool_from_filepath method to:
@@ -175,7 +174,6 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult:
Notes:
- The tool loading can be disabled via STRANDS_DISABLE_LOAD_TOOL=true environment variable
- - The Agent instance must have hot_reload_tools=True to enable dynamic loading
- Python files in the cwd()/tools/ directory are automatically hot reloaded without
requiring explicit calls to load_tool
- When using the load_tool function, ensure your tool files have proper docstrings as they are
@@ -187,8 +185,8 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult:
current_agent = agent
try:
- # Check if dynamic tool loading is disabled via environment variable or agent.hot_reload_tools.
- if not current_agent.hot_reload_tools or os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true":
+ # Check if dynamic tool loading is disabled via environment variable.
+ if os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true":
logger.warning("Dynamic tool loading is disabled via STRANDS_DISABLE_LOAD_TOOL=true")
return {"status": "error", "content": [{"text": "⚠️ Dynamic tool loading is disabled in production mode."}]}
diff --git a/src/strands_tools/mem0_memory.py b/src/strands_tools/mem0_memory.py
index 41d94b9d..212d85cf 100644
--- a/src/strands_tools/mem0_memory.py
+++ b/src/strands_tools/mem0_memory.py
@@ -141,19 +141,10 @@
"required": ["action"],
"allOf": [
{
- "if": {
- "properties": {
- "action": {"enum": ["store", "list", "retrieve"]}
- }
- },
- "then": {
- "oneOf": [
- {"required": ["user_id"]},
- {"required": ["agent_id"]}
- ]
- }
+ "if": {"properties": {"action": {"enum": ["store", "list", "retrieve"]}}},
+ "then": {"oneOf": [{"required": ["user_id"]}, {"required": ["agent_id"]}]},
}
- ]
+ ],
}
},
}
@@ -536,7 +527,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
- content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")]
+ content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")],
)
elif action == "get":
@@ -547,9 +538,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
panel = format_get_response(memory)
console.print(panel)
return ToolResult(
- toolUseId=tool_use_id,
- status="success",
- content=[ToolResultContent(text=json.dumps(memory, indent=2))]
+ toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(memory, indent=2))]
)
elif action == "list":
@@ -559,7 +548,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
- content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))]
+ content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))],
)
elif action == "retrieve":
@@ -576,7 +565,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
- content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))]
+ content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))],
)
elif action == "delete":
@@ -589,7 +578,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
- content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")]
+ content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")],
)
elif action == "history":
@@ -600,9 +589,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
panel = format_history_response(history)
console.print(panel)
return ToolResult(
- toolUseId=tool_use_id,
- status="success",
- content=[ToolResultContent(text=json.dumps(history, indent=2))]
+ toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(history, indent=2))]
)
else:
@@ -615,8 +602,4 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
border_style="red",
)
console.print(error_panel)
- return ToolResult(
- toolUseId=tool_use_id,
- status="error",
- content=[ToolResultContent(text=f"Error: {str(e)}")]
- )
+ return ToolResult(toolUseId=tool_use_id, status="error", content=[ToolResultContent(text=f"Error: {str(e)}")])
diff --git a/src/strands_tools/memory.py b/src/strands_tools/memory.py
index b4989932..585906c7 100644
--- a/src/strands_tools/memory.py
+++ b/src/strands_tools/memory.py
@@ -42,7 +42,7 @@
agent = Agent(tools=[memory])
# Store content in Knowledge Base
-agent.memory(
+agent.tool.memory(
action="store",
content="Important information to remember",
title="Meeting Notes",
@@ -50,7 +50,7 @@
)
# Retrieve content using semantic search
-agent.memory(
+agent.tool.memory(
action="retrieve",
query="meeting information",
min_score=0.7,
@@ -58,7 +58,7 @@
)
# List all documents
-agent.memory(
+agent.tool.memory(
action="list",
max_results=50,
STRANDS_KNOWLEDGE_BASE_ID="my1234kb"
diff --git a/src/strands_tools/slack.py b/src/strands_tools/slack.py
index 5dccb562..2f3cc879 100644
--- a/src/strands_tools/slack.py
+++ b/src/strands_tools/slack.py
@@ -351,12 +351,14 @@ def _process_message(self, event):
return
tools = list(self.agent.tool_registry.registry.values())
+ trace_attributes = self.agent.trace_attributes
agent = Agent(
messages=[],
system_prompt=f"{self.agent.system_prompt}\n{SLACK_SYSTEM_PROMPT}",
tools=tools,
callback_handler=None,
+ trace_attributes=trace_attributes,
)
channel_id = event.get("channel")
diff --git a/src/strands_tools/think.py b/src/strands_tools/think.py
index a1c210e3..ec3ab5a5 100644
--- a/src/strands_tools/think.py
+++ b/src/strands_tools/think.py
@@ -9,9 +9,9 @@
Usage with Strands Agent:
```python
from strands import Agent
-from strands_tools import think
+from strands_tools import think, stop
-agent = Agent(tools=[think])
+agent = Agent(tools=[think, stop])
# Basic usage with default system prompt
result = agent.tool.think(
@@ -32,13 +32,15 @@
See the think function docstring for more details on configuration options and parameters.
"""
+import logging
import traceback
import uuid
from typing import Any, Dict
-from strands import tool
+from strands import Agent, tool
+from strands.telemetry.metrics import metrics_to_string
-from strands_tools.use_llm import use_llm
+logger = logging.getLogger(__name__)
class ThoughtProcessor:
@@ -77,36 +79,46 @@ def process_cycle(
) -> str:
"""Process a single thinking cycle."""
+ logger.debug(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...")
print(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...")
# Create cycle-specific prompt
prompt = self.create_thinking_prompt(thought, cycle, total_cycles)
- # Use LLM for processing
- result = use_llm(
- {
- "name": "use_llm",
- "toolUseId": self.tool_use_id,
- "input": {
- "system_prompt": custom_system_prompt,
- "prompt": prompt,
- },
- },
- **kwargs,
- )
-
- # Extract and return response
- cycle_response = ""
- if result.get("status") == "success":
- for content in result.get("content", []):
- if content.get("text"):
- cycle_response += content["text"] + "\n"
-
- return cycle_response.strip()
+ # Display input prompt
+ logger.debug(f"\n--- Input Prompt ---\n{prompt}\n")
+
+ # Get tools from parent agent if available
+ tools = []
+ trace_attributes = {}
+ parent_agent = kwargs.get("agent")
+ if parent_agent:
+ tools = list(parent_agent.tool_registry.registry.values())
+ trace_attributes = parent_agent.trace_attributes
+
+ # Initialize the new Agent with provided parameters
+ agent = Agent(messages=[], tools=tools, system_prompt=custom_system_prompt, trace_attributes=trace_attributes)
+
+ # Run the agent with the provided prompt
+ result = agent(prompt)
+
+ # Extract response
+ assistant_response = str(result)
+
+ # Display assistant response
+ logger.debug(f"\n--- Assistant Response ---\n{assistant_response.strip()}\n")
+
+ # Print metrics if available
+ if result.metrics:
+ metrics = result.metrics
+ metrics_text = metrics_to_string(metrics)
+ logger.debug(metrics_text)
+
+ return assistant_response.strip()
@tool
-def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) -> Dict[str, Any]:
+def think(thought: str, cycle_count: int, system_prompt: str, agent: Any) -> Dict[str, Any]:
"""
Recursive thinking tool for sophisticated thought generation, learning, and self-reflection.
@@ -172,7 +184,7 @@ def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) ->
custom_system_prompt = (
"You are an expert analytical thinker. Process the thought deeply and provide clear insights."
)
-
+ kwargs = {"agent": agent}
# Create thought processor instance with the available context
processor = ThoughtProcessor(kwargs)
diff --git a/src/strands_tools/use_llm.py b/src/strands_tools/use_llm.py
index a294efad..0807b052 100644
--- a/src/strands_tools/use_llm.py
+++ b/src/strands_tools/use_llm.py
@@ -117,9 +117,12 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult:
tool_system_prompt = tool_input.get("system_prompt")
tools = []
+ trace_attributes = {}
+
parent_agent = kwargs.get("agent")
if parent_agent:
tools = list(parent_agent.tool_registry.registry.values())
+ trace_attributes = parent_agent.trace_attributes
# Display input prompt
logger.debug(f"\n--- Input Prompt ---\n{prompt}\n")
@@ -128,11 +131,7 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult:
logger.debug("🔄 Creating new LLM instance...")
# Initialize the new Agent with provided parameters
- agent = Agent(
- messages=[],
- tools=tools,
- system_prompt=tool_system_prompt,
- )
+ agent = Agent(messages=[], tools=tools, system_prompt=tool_system_prompt, trace_attributes=trace_attributes)
# Run the agent with the provided prompt
result = agent(prompt)
diff --git a/src/strands_tools/utils/user_input.py b/src/strands_tools/utils/user_input.py
index c1e895cd..16eebb0e 100644
--- a/src/strands_tools/utils/user_input.py
+++ b/src/strands_tools/utils/user_input.py
@@ -8,10 +8,13 @@
from prompt_toolkit import HTML, PromptSession
from prompt_toolkit.patch_stdout import patch_stdout
-session: PromptSession = PromptSession()
+# Lazy initialize to avoid import errors for tests on windows without a terminal
+session: PromptSession | None = None
async def get_user_input_async(prompt: str, default: str = "n") -> str:
+ global session
+
"""
Asynchronously get user input with prompt_toolkit's features (history, arrow keys, styling, etc).
@@ -25,6 +28,9 @@ async def get_user_input_async(prompt: str, default: str = "n") -> str:
try:
with patch_stdout(raw=True):
+ if session is None:
+ session = PromptSession()
+
response = await session.prompt_async(HTML(f"{prompt} "))
if not response:
diff --git a/tests/test_environment.py b/tests/test_environment.py
index 72d9ce74..921276e3 100644
--- a/tests/test_environment.py
+++ b/tests/test_environment.py
@@ -3,28 +3,33 @@
"""
import os
+from unittest import mock
import pytest
from strands import Agent
from strands_tools import environment
+from strands_tools.utils import user_input
@pytest.fixture
def agent():
"""Create an agent with the environment tool loaded."""
- return Agent(tools=[environment])
+ return Agent(tools=[environment], load_tools_from_directory=False)
-@pytest.fixture
-def test_env_var():
- """Create and clean up a test environment variable."""
- var_name = "TEST_ENV_VAR"
- var_value = "test_value"
- os.environ[var_name] = var_value
- yield var_name, var_value
- # Clean up: remove test variable if it exists
- if var_name in os.environ:
- del os.environ[var_name]
+@pytest.fixture(autouse=True)
+def get_user_input():
+ with mock.patch.object(user_input, "get_user_input") as mocked_user_input:
+ # By default all tests will return deny
+ mocked_user_input.return_value = "n"
+ yield mocked_user_input
+
+
+@pytest.fixture(autouse=True)
+def os_environment():
+ mock_env = {}
+ with mock.patch.object(os, "environ", mock_env):
+ yield mock_env
def extract_result_text(result):
@@ -42,18 +47,24 @@ def test_direct_list_action(agent):
assert len(extract_result_text(result)) > 0
-def test_direct_list_with_prefix(agent, test_env_var):
+def test_direct_list_with_prefix(agent, os_environment):
"""Test listing environment variables with a specific prefix."""
- var_name, _ = test_env_var
+ var_name = "TEST_ENV_VAR"
+ var_value = "test_value"
+ os_environment[var_name] = var_value
+
result = agent.tool.environment(action="list", prefix=var_name[:4])
assert result["status"] == "success"
# Verify our test variable is in the result
assert var_name in extract_result_text(result)
-def test_direct_get_existing_var(agent, test_env_var):
+def test_direct_get_existing_var(agent, os_environment):
"""Test getting an existing environment variable."""
- var_name, var_value = test_env_var
+ var_name = "TEST_ENV_VAR"
+ var_value = "test_value"
+ os_environment[var_name] = var_value
+
result = agent.tool.environment(action="get", name=var_name)
assert result["status"] == "success"
assert var_name in extract_result_text(result)
@@ -67,10 +78,9 @@ def test_direct_get_nonexistent_var(agent):
assert "not found" in extract_result_text(result)
-def test_direct_set_protected_var(agent, monkeypatch):
+def test_direct_set_protected_var(agent, os_environment):
"""Test attempting to set a protected environment variable."""
- # Mock get_user_input to always return 'y' for confirmation
- monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y")
+ os_environment["PATH"] = "/original/path"
# Try to modify PATH which is in PROTECTED_VARS
result = agent.tool.environment(action="set", name="PATH", value="/bad/path")
@@ -79,101 +89,91 @@ def test_direct_set_protected_var(agent, monkeypatch):
assert os.environ["PATH"] != "/bad/path"
-def test_direct_set_var_cancelled(agent, monkeypatch):
- """Test cancelling setting an environment variable."""
- # Mock get_user_input to return 'n' to cancel
- monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "n")
+def test_direct_set_var_allowed(get_user_input, agent):
+ """Test attempting to set a protected environment variable and allowing it."""
+ get_user_input.return_value = "y"
var_name = "CANCELLED_VAR"
var_value = "cancelled_value"
- # Clean up in case the variable exists
- if var_name in os.environ:
- del os.environ[var_name]
+ result = agent.tool.environment(action="set", name=var_name, value=var_value)
+ assert result["status"] == "success"
+ assert var_name in os.environ
+ assert get_user_input.call_count == 1
+
+
+def test_direct_set_var_cancelled(agent):
+ var_name = "CANCELLED_VAR"
+ var_value = "cancelled_value"
- try:
- result = agent.tool.environment(action="set", name=var_name, value=var_value)
- assert result["status"] == "error"
- assert "cancelled" in extract_result_text(result).lower()
- # Verify variable was not set
- assert var_name not in os.environ
- finally:
- # Clean up
- if var_name in os.environ:
- del os.environ[var_name]
+ result = agent.tool.environment(action="set", name=var_name, value=var_value)
+ assert result["status"] == "error"
+ assert "cancelled" in extract_result_text(result).lower()
+ # Verify variable was not set
+ assert var_name not in os.environ
-def test_direct_delete_nonexistent_var(agent, monkeypatch):
+def test_direct_delete_nonexistent_var(agent):
"""Test attempting to delete a non-existent variable."""
- # Mock get_user_input to always return 'y' for confirmation
- monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y")
-
var_name = "NONEXISTENT_VAR_FOR_DELETE_TEST"
- # Make sure the variable doesn't exist
- if var_name in os.environ:
- del os.environ[var_name]
-
result = agent.tool.environment(action="delete", name=var_name)
assert result["status"] == "error"
assert "not found" in extract_result_text(result).lower()
-def test_direct_delete_protected_var(agent, monkeypatch):
+def test_direct_delete_protected_var(agent, os_environment):
"""Test attempting to delete a protected environment variable."""
- # Mock get_user_input to always return 'y' for confirmation
- monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "y")
-
# Try to delete PATH which is in PROTECTED_VARS
- original_path = os.environ.get("PATH", "")
- try:
- result = agent.tool.environment(action="delete", name="PATH")
- assert result["status"] == "error"
- # Verify PATH still exists
- assert "PATH" in os.environ
- finally:
- # Restore PATH if somehow it got deleted
- if "PATH" not in os.environ:
- os.environ["PATH"] = original_path
-
-
-def test_direct_delete_var_cancelled(agent, monkeypatch):
+ unchanging_value = "/original/path"
+ os_environment["PATH"] = unchanging_value
+
+ result = agent.tool.environment(action="delete", name="PATH")
+ assert result["status"] == "error"
+ # Verify PATH still exists
+ assert os_environment["PATH"] == unchanging_value
+
+
+def test_direct_delete_var_cancelled(agent, os_environment):
"""Test cancelling deletion of an environment variable."""
- # Mock get_user_input to return 'n' to cancel
- monkeypatch.setattr("strands_tools.utils.user_input.get_user_input", lambda _: "n")
+ var_name = "CANCEL_DELETE_VAR"
+ var_value = "cancel_delete_value"
+
+ # Set up the variable
+ os_environment[var_name] = var_value
- # Ensure DEV mode is disabled to force confirmation
- current_dev = os.environ.get("DEV", None)
- if current_dev:
- os.environ.pop("DEV")
+ result = agent.tool.environment(action="delete", name=var_name)
+ assert result["status"] == "error"
+ assert "cancelled" in extract_result_text(result).lower()
+ # Verify variable still exists
+ assert var_name in os.environ
+ assert os_environment[var_name] == var_value
+
+
+def test_direct_delete_var_allowed(agent, get_user_input, os_environment):
+ """Test allowing deletion of an environment variable."""
+ get_user_input.return_value = "y"
var_name = "CANCEL_DELETE_VAR"
var_value = "cancel_delete_value"
# Set up the variable
- os.environ[var_name] = var_value
-
- try:
- result = agent.tool.environment(action="delete", name=var_name)
- assert result["status"] == "error"
- assert "cancelled" in extract_result_text(result).lower()
- # Verify variable still exists
- assert var_name in os.environ
- assert os.environ[var_name] == var_value
- finally:
- # Clean up
- if var_name in os.environ:
- del os.environ[var_name]
- # Restore DEV mode if it was set
- if current_dev:
- os.environ["DEV"] = current_dev
- if var_name in os.environ:
- del os.environ[var_name]
-
-
-def test_direct_validate_existing_var(agent, test_env_var):
+ os_environment[var_name] = var_value
+
+ result = agent.tool.environment(action="delete", name=var_name)
+ assert result["status"] == "success"
+ assert "deleted environment variable" in extract_result_text(result).lower()
+ assert os_environment.get(var_name) is None
+
+
+def test_direct_validate_existing_var(agent, os_environment):
"""Test validating an existing environment variable."""
- var_name, _ = test_env_var
+ var_name = "CANCEL_DELETE_VAR"
+ var_value = "cancel_delete_value"
+
+ # Set up the variable
+ os_environment[var_name] = var_value
+
result = agent.tool.environment(action="validate", name=var_name)
assert result["status"] == "success"
assert "valid" in extract_result_text(result).lower()
@@ -201,76 +201,52 @@ def test_direct_missing_parameters(agent):
assert result["status"] == "error"
-def test_environment_dev_mode_delete(agent):
+def test_environment_dev_mode_delete(agent, os_environment):
"""Test the environment tool in DEV mode with delete action."""
# Set DEV mode
- original_dev = os.environ.get("DEV")
- os.environ["DEV"] = "true"
+ os_environment["DEV"] = "true"
var_name = "DEV_MODE_DELETE_VAR"
var_value = "dev_mode_delete_value"
- try:
- # Set up the variable
- os.environ[var_name] = var_value
-
- result = agent.tool.environment(action="delete", name=var_name)
- assert result["status"] == "success"
- assert var_name not in os.environ
- finally:
- # Clean up
- if var_name in os.environ:
- del os.environ[var_name]
+ # Set up the variable
+ os_environment[var_name] = var_value
- # Restore original DEV value
- if original_dev is None:
- if "DEV" in os.environ:
- del os.environ["DEV"]
- else:
- os.environ["DEV"] = original_dev
+ result = agent.tool.environment(action="delete", name=var_name)
+ assert result["status"] == "success"
+ assert var_name not in os_environment
-def test_environment_dev_mode_protected_var(agent, monkeypatch):
+def test_environment_dev_mode_protected_var(agent, os_environment):
"""Test that protected variables are still protected in DEV mode."""
# Set DEV mode
- original_dev = os.environ.get("DEV")
- os.environ["DEV"] = "true"
-
- try:
- # Try to modify PATH which is protected
- result = agent.tool.environment(action="set", name="PATH", value="/bad/path")
- assert result["status"] == "error"
- # Verify PATH was not changed
- assert os.environ["PATH"] != "/bad/path"
- finally:
- # Restore original DEV value
- if original_dev is None:
- if "DEV" in os.environ:
- del os.environ["DEV"]
- else:
- os.environ["DEV"] = original_dev
-
-
-def test_environment_masked_values(agent, test_env_var):
+ os_environment["DEV"] = True
+
+ unchanging_value = "/original/path"
+ os_environment["PATH"] = unchanging_value
+
+ # Try to modify PATH which is protected
+ result = agent.tool.environment(action="set", name="PATH", value="/bad/path")
+ assert result["status"] == "error"
+ # Verify PATH was not changed
+ assert os_environment != unchanging_value
+
+
+def test_environment_masked_values(agent, os_environment):
"""Test that sensitive values are masked in output."""
# Create a sensitive looking variable
sensitive_name = "TEST_TOKEN_SECRET"
sensitive_value = "abcd1234efgh5678"
- os.environ[sensitive_name] = sensitive_value
-
- try:
- # Test with masking enabled (default)
- result = agent.tool.environment(action="get", name=sensitive_name)
- assert result["status"] == "success"
- # The full value should not appear in the output
- assert sensitive_value not in extract_result_text(result)
-
- # Test with masking disabled
- result = agent.tool.environment(action="get", name=sensitive_name, masked=False)
- assert result["status"] == "success"
- # Now the full value should appear
- assert sensitive_value in extract_result_text(result)
- finally:
- # Clean up
- if sensitive_name in os.environ:
- del os.environ[sensitive_name]
+ os_environment[sensitive_name] = sensitive_value
+
+ # Test with masking enabled (default)
+ result = agent.tool.environment(action="get", name=sensitive_name)
+ assert result["status"] == "success"
+ # The full value should not appear in the output
+ assert sensitive_value not in extract_result_text(result)
+
+ # Test with masking disabled
+ result = agent.tool.environment(action="get", name=sensitive_name, masked=False)
+ assert result["status"] == "success"
+ # Now the full value should appear
+ assert sensitive_value in extract_result_text(result)
diff --git a/tests/test_file_write.py b/tests/test_file_write.py
index e72a4a7d..8a7ce5cf 100644
--- a/tests/test_file_write.py
+++ b/tests/test_file_write.py
@@ -148,8 +148,6 @@ def test_file_write_error_handling(mock_user_input, temp_file):
# This should fail on most systems
if os.name == "posix": # Unix/Linux/Mac
invalid_path = "/root/test_no_permission.txt"
- elif os.name == "nt": # Windows
- invalid_path = "C:\\Windows\\System32\\config\\nopermission.txt"
else:
# Fallback - create a path that's too long
invalid_path = os.path.join(temp_file, "a" * 1000 + ".txt")
diff --git a/tests/test_python_repl.py b/tests/test_python_repl.py
index 672e14cc..6eddd1f1 100644
--- a/tests/test_python_repl.py
+++ b/tests/test_python_repl.py
@@ -12,6 +12,10 @@
import dill
import pytest
from strands import Agent
+
+if os.name == "nt":
+ pytest.skip("skipping on windows until issue #17 is resolved", allow_module_level=True)
+
from strands_tools import python_repl
diff --git a/tests/test_shell.py b/tests/test_shell.py
index d270f93f..4e6dba8c 100644
--- a/tests/test_shell.py
+++ b/tests/test_shell.py
@@ -5,12 +5,17 @@
import os
import signal
import sys
-import termios
from unittest.mock import MagicMock, patch
import pytest
from rich.panel import Panel
from strands import Agent
+
+if os.name == "nt":
+ pytest.skip("skipping on windows until issue #17 is resolved", allow_module_level=True)
+
+import termios
+
from strands_tools import shell
diff --git a/tests/test_think.py b/tests/test_think.py
index a33e25e9..77882b06 100644
--- a/tests/test_think.py
+++ b/tests/test_think.py
@@ -2,8 +2,9 @@
Tests for the think tool using the Agent interface.
"""
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
+from strands.agent import AgentResult
from strands_tools import think
from strands_tools.think import ThoughtProcessor
@@ -28,13 +29,17 @@ def test_think_tool_direct():
},
}
- # Mock use_llm function since we don't want to actually call the LLM
- with patch("strands_tools.think.use_llm") as mock_use_llm:
- # Setup mock response
- mock_use_llm.return_value = {
- "status": "success",
- "content": [{"text": "This is a mock analysis of quantum computing."}],
- }
+ # Mock Agent class since we don't want to actually call the LLM
+ with patch("strands_tools.think.Agent") as mock_agent_class:
+ # Setup mock agent and response
+ mock_agent = mock_agent_class.return_value
+ mock_result = AgentResult(
+ message={"content": [{"text": "This is a mock analysis of quantum computing."}]},
+ stop_reason="end_turn",
+ metrics=None,
+ state=MagicMock(),
+ )
+ mock_agent.return_value = mock_result
# Call the think function directly
tool_input = tool_use.get("input", {})
@@ -42,6 +47,7 @@ def test_think_tool_direct():
thought=tool_input.get("thought"),
cycle_count=tool_input.get("cycle_count"),
system_prompt=tool_input.get("system_prompt"),
+ agent=None,
)
# Verify the result has the expected structure
@@ -49,8 +55,8 @@ def test_think_tool_direct():
assert "Cycle 1/2" in result["content"][0]["text"]
assert "Cycle 2/2" in result["content"][0]["text"]
- # Verify use_llm was called twice (once for each cycle)
- assert mock_use_llm.call_count == 2
+ # Verify Agent was called twice (once for each cycle)
+ assert mock_agent.call_count == 2
def test_think_one_cycle():
@@ -65,22 +71,27 @@ def test_think_one_cycle():
},
}
- with patch("strands_tools.think.use_llm") as mock_use_llm:
- mock_use_llm.return_value = {
- "status": "success",
- "content": [{"text": "Analysis for single cycle."}],
- }
+ with patch("strands_tools.think.Agent") as mock_agent_class:
+ mock_agent = mock_agent_class.return_value
+ mock_result = AgentResult(
+ message={"content": [{"text": "Analysis for single cycle."}]},
+ stop_reason="end_turn",
+ metrics=None,
+ state=MagicMock(),
+ )
+ mock_agent.return_value = mock_result
tool_input = tool_use.get("input", {})
result = think.think(
thought=tool_input.get("thought"),
cycle_count=tool_input.get("cycle_count"),
system_prompt=tool_input.get("system_prompt"),
+ agent=None,
)
assert result["status"] == "success"
assert "Cycle 1/1" in result["content"][0]["text"]
- assert mock_use_llm.call_count == 1
+ assert mock_agent.call_count == 1
def test_think_error_handling():
@@ -95,15 +106,16 @@ def test_think_error_handling():
},
}
- with patch("strands_tools.think.use_llm") as mock_use_llm:
- # Make use_llm raise an exception
- mock_use_llm.side_effect = Exception("Test error")
+ with patch("strands_tools.think.Agent") as mock_agent_class:
+ # Make Agent raise an exception
+ mock_agent_class.side_effect = Exception("Test error")
tool_input = tool_use.get("input", {})
result = think.think(
thought=tool_input.get("thought"),
cycle_count=tool_input.get("cycle_count"),
system_prompt=tool_input.get("system_prompt"),
+ agent=None,
)
assert result["status"] == "error"
diff --git a/tests/test_use_llm.py b/tests/test_use_llm.py
index c978c781..ccb6eb8c 100644
--- a/tests/test_use_llm.py
+++ b/tests/test_use_llm.py
@@ -59,7 +59,9 @@ def test_use_llm_tool_direct(mock_agent_response):
assert "This is a test response from the LLM" in str(result)
# Verify the Agent was created with the correct parameters
- MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a helpful test assistant")
+ MockAgent.assert_called_once_with(
+ messages=[], tools=[], system_prompt="You are a helpful test assistant", trace_attributes={}
+ )
def test_use_llm_with_custom_system_prompt(mock_agent_response):
@@ -82,7 +84,9 @@ def test_use_llm_with_custom_system_prompt(mock_agent_response):
result = use_llm.use_llm(tool=tool_use)
# Verify agent was created with correct system prompt
- MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a specialized test assistant")
+ MockAgent.assert_called_once_with(
+ messages=[], tools=[], system_prompt="You are a specialized test assistant", trace_attributes={}
+ )
assert result["status"] == "success"
assert "Custom response" in result["content"][0]["text"]