6 changes: 2 additions & 4 deletions src/strands_tools/load_tool.py
@@ -79,7 +79,6 @@ def load_tool(path: str, name: str, agent=None) -> Dict[str, Any]:

Tool Loading Process:
-------------------
- First, checks if dynamic loading is permitted (hot_reload_tools=True)
- Expands the path to handle user paths with tilde (~)
- Validates that the file exists at the specified path
- Uses the tool_registry's load_tool_from_filepath method to:
@@ -175,7 +174,6 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult:

Notes:
- The tool loading can be disabled via STRANDS_DISABLE_LOAD_TOOL=true environment variable
- The Agent instance must have hot_reload_tools=True to enable dynamic loading
- Python files in the cwd()/tools/ directory are automatically hot reloaded without
requiring explicit calls to load_tool
- When using the load_tool function, ensure your tool files have proper docstrings as they are
@@ -187,8 +185,8 @@ def my_custom_tool(tool: ToolUse, **kwargs: Any) -> ToolResult:
current_agent = agent

try:
# Check if dynamic tool loading is disabled via environment variable or agent.hot_reload_tools.
if not current_agent.hot_reload_tools or os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true":
# Check if dynamic tool loading is disabled via environment variable.
if os.environ.get("STRANDS_DISABLE_LOAD_TOOL", "").lower() == "true":
logger.warning("Dynamic tool loading is disabled via STRANDS_DISABLE_LOAD_TOOL=true")
return {"status": "error", "content": [{"text": "⚠️ Dynamic tool loading is disabled in production mode."}]}

37 changes: 10 additions & 27 deletions src/strands_tools/mem0_memory.py
@@ -141,19 +141,10 @@
"required": ["action"],
"allOf": [
{
"if": {
"properties": {
"action": {"enum": ["store", "list", "retrieve"]}
}
},
"then": {
"oneOf": [
{"required": ["user_id"]},
{"required": ["agent_id"]}
]
}
"if": {"properties": {"action": {"enum": ["store", "list", "retrieve"]}}},
"then": {"oneOf": [{"required": ["user_id"]}, {"required": ["agent_id"]}]},
}
]
],
}
},
}
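The compacted `allOf` block reads more easily with a concrete check. A minimal sketch of the constraint it encodes, assuming the standard `jsonschema` package and a schema trimmed to the fragment above rather than the tool's full input spec:

```python
# Sketch of the conditional requirement: store/list/retrieve actions must
# carry user_id or agent_id. The schema is reconstructed from the diff
# fragment, not copied from the full tool specification.
from jsonschema import ValidationError, validate

schema = {
    "type": "object",
    "required": ["action"],
    "allOf": [
        {
            "if": {"properties": {"action": {"enum": ["store", "list", "retrieve"]}}},
            "then": {"oneOf": [{"required": ["user_id"]}, {"required": ["agent_id"]}]},
        }
    ],
}

validate({"action": "store", "user_id": "user-1"}, schema)  # passes

try:
    validate({"action": "store"}, schema)  # neither user_id nor agent_id supplied
except ValidationError as err:
    print(err.message)  # reports the unmet oneOf requirement
```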
@@ -536,7 +527,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")]
content=[ToolResultContent(text=f"Successfully stored {len(results.get('results', []))} memories")],
)

elif action == "get":
@@ -547,9 +538,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
panel = format_get_response(memory)
console.print(panel)
return ToolResult(
toolUseId=tool_use_id,
status="success",
content=[ToolResultContent(text=json.dumps(memory, indent=2))]
toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(memory, indent=2))]
)

elif action == "list":
@@ -559,7 +548,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))]
content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))],
)

elif action == "retrieve":
@@ -576,7 +565,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))]
content=[ToolResultContent(text=json.dumps(memories.get("results", []), indent=2))],
)

elif action == "delete":
@@ -589,7 +578,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
return ToolResult(
toolUseId=tool_use_id,
status="success",
content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")]
content=[ToolResultContent(text=f"Memory {tool_input['memory_id']} deleted successfully")],
)

elif action == "history":
@@ -600,9 +589,7 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
panel = format_history_response(history)
console.print(panel)
return ToolResult(
toolUseId=tool_use_id,
status="success",
content=[ToolResultContent(text=json.dumps(history, indent=2))]
toolUseId=tool_use_id, status="success", content=[ToolResultContent(text=json.dumps(history, indent=2))]
)

else:
@@ -615,8 +602,4 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
border_style="red",
)
console.print(error_panel)
return ToolResult(
toolUseId=tool_use_id,
status="error",
content=[ToolResultContent(text=f"Error: {str(e)}")]
)
return ToolResult(toolUseId=tool_use_id, status="error", content=[ToolResultContent(text=f"Error: {str(e)}")])
2 changes: 2 additions & 0 deletions src/strands_tools/slack.py
@@ -351,12 +351,14 @@ def _process_message(self, event):
return

tools = list(self.agent.tool_registry.registry.values())
trace_attributes = self.agent.trace_attributes

agent = Agent(
messages=[],
system_prompt=f"{self.agent.system_prompt}\n{SLACK_SYSTEM_PROMPT}",
tools=tools,
callback_handler=None,
trace_attributes=trace_attributes,
)

channel_id = event.get("channel")
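The two added lines forward the listener's trace attributes into the per-message agent; the same propagation appears in use_llm.py and think.py below. A short sketch of the intent, with invented attribute values:

```python
# Illustrative only: trace attributes configured on the long-lived parent
# agent are copied onto the short-lived Agent built for each Slack message,
# so telemetry from both carries the same tags. Values are made up.
from strands import Agent

parent = Agent(tools=[], trace_attributes={"service.name": "slack-bot", "deployment.env": "dev"})

per_message_agent = Agent(
    messages=[],
    system_prompt=parent.system_prompt or "",
    tools=list(parent.tool_registry.registry.values()),
    callback_handler=None,
    trace_attributes=parent.trace_attributes,  # the newly propagated field
)
```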
66 changes: 39 additions & 27 deletions src/strands_tools/think.py
@@ -9,9 +9,9 @@
Usage with Strands Agent:
```python
from strands import Agent
from strands_tools import think
from strands_tools import think, stop

agent = Agent(tools=[think])
agent = Agent(tools=[think, stop])

# Basic usage with default system prompt
result = agent.tool.think(
@@ -32,13 +32,15 @@
See the think function docstring for more details on configuration options and parameters.
"""

import logging
import traceback
import uuid
from typing import Any, Dict

from strands import tool
from strands import Agent, tool
from strands.telemetry.metrics import metrics_to_string

from strands_tools.use_llm import use_llm
logger = logging.getLogger(__name__)


class ThoughtProcessor:
@@ -77,36 +79,46 @@ def process_cycle(
) -> str:
"""Process a single thinking cycle."""

logger.debug(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...")
print(f"🧠 Thinking Cycle {cycle}/{total_cycles}: Processing cycle...")

# Create cycle-specific prompt
prompt = self.create_thinking_prompt(thought, cycle, total_cycles)

# Use LLM for processing
result = use_llm(
{
"name": "use_llm",
"toolUseId": self.tool_use_id,
"input": {
"system_prompt": custom_system_prompt,
"prompt": prompt,
},
},
**kwargs,
)

# Extract and return response
cycle_response = ""
if result.get("status") == "success":
for content in result.get("content", []):
if content.get("text"):
cycle_response += content["text"] + "\n"

return cycle_response.strip()
# Display input prompt
logger.debug(f"\n--- Input Prompt ---\n{prompt}\n")

# Get tools from parent agent if available
tools = []
trace_attributes = {}
parent_agent = kwargs.get("agent")
if parent_agent:
tools = list(parent_agent.tool_registry.registry.values())
trace_attributes = parent_agent.trace_attributes

# Initialize the new Agent with provided parameters
agent = Agent(messages=[], tools=tools, system_prompt=custom_system_prompt, trace_attributes=trace_attributes)

# Run the agent with the provided prompt
result = agent(prompt)

# Extract response
assistant_response = str(result)

# Display assistant response
logger.debug(f"\n--- Assistant Response ---\n{assistant_response.strip()}\n")

# Print metrics if available
if result.metrics:
metrics = result.metrics
metrics_text = metrics_to_string(metrics)
logger.debug(metrics_text)

return assistant_response.strip()


@tool
def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) -> Dict[str, Any]:
def think(thought: str, cycle_count: int, system_prompt: str, agent: Any) -> Dict[str, Any]:
"""
Recursive thinking tool for sophisticated thought generation, learning, and self-reflection.

@@ -172,7 +184,7 @@ def think(thought: str, cycle_count: int, system_prompt: str, **kwargs: Any) ->
custom_system_prompt = (
"You are an expert analytical thinker. Process the thought deeply and provide clear insights."
)

kwargs = {"agent": agent}
# Create thought processor instance with the available context
processor = ThoughtProcessor(kwargs)
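Combined with the docstring update at the top of the file, a hedged end-to-end sketch of the new flow; identifiers follow the diff, and the trace attribute value is invented:

```python
# Illustrative usage: each thinking cycle now runs through a nested Agent
# that inherits the parent's registered tools and trace attributes, rather
# than delegating to use_llm. The trace attribute value is a placeholder.
from strands import Agent
from strands_tools import think, stop

agent = Agent(tools=[think, stop], trace_attributes={"session.id": "demo-123"})

result = agent.tool.think(
    thought="What are the long-term implications of quantum computing?",
    cycle_count=2,
    system_prompt="You are an expert analytical thinker.",
)
```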

9 changes: 4 additions & 5 deletions src/strands_tools/use_llm.py
@@ -117,9 +117,12 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult:
tool_system_prompt = tool_input.get("system_prompt")

tools = []
trace_attributes = {}

parent_agent = kwargs.get("agent")
if parent_agent:
tools = list(parent_agent.tool_registry.registry.values())
trace_attributes = parent_agent.trace_attributes

# Display input prompt
logger.debug(f"\n--- Input Prompt ---\n{prompt}\n")
@@ -128,11 +131,7 @@ def use_llm(tool: ToolUse, **kwargs: Any) -> ToolResult:
logger.debug("🔄 Creating new LLM instance...")

# Initialize the new Agent with provided parameters
agent = Agent(
messages=[],
tools=tools,
system_prompt=tool_system_prompt,
)
agent = Agent(messages=[], tools=tools, system_prompt=tool_system_prompt, trace_attributes=trace_attributes)
# Run the agent with the provided prompt
result = agent(prompt)
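From the caller's side, a hedged sketch of what the added trace_attributes plumbing changes; the attribute value and prompts are placeholders:

```python
# Illustrative call path: when use_llm runs with a parent agent in kwargs
# (as it does when invoked through agent.tool), the nested Agent now inherits
# that parent's tools and trace attributes instead of empty defaults.
from strands import Agent
from strands_tools import use_llm

parent = Agent(tools=[use_llm], trace_attributes={"session.id": "abc-123"})

result = parent.tool.use_llm(
    prompt="Summarize the main risks in this plan.",
    system_prompt="You are a concise risk analyst.",
)
# The inner Agent is constructed with trace_attributes={"session.id": "abc-123"}.
```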

50 changes: 31 additions & 19 deletions tests/test_think.py
@@ -2,8 +2,9 @@
Tests for the think tool using the Agent interface.
"""

from unittest.mock import patch
from unittest.mock import MagicMock, patch

from strands.agent import AgentResult
from strands_tools import think
from strands_tools.think import ThoughtProcessor

@@ -28,29 +29,34 @@ def test_think_tool_direct():
},
}

# Mock use_llm function since we don't want to actually call the LLM
with patch("strands_tools.think.use_llm") as mock_use_llm:
# Setup mock response
mock_use_llm.return_value = {
"status": "success",
"content": [{"text": "This is a mock analysis of quantum computing."}],
}
# Mock Agent class since we don't want to actually call the LLM
with patch("strands_tools.think.Agent") as mock_agent_class:
# Setup mock agent and response
mock_agent = mock_agent_class.return_value
mock_result = AgentResult(
message={"content": [{"text": "This is a mock analysis of quantum computing."}]},
stop_reason="end_turn",
metrics=None,
state=MagicMock(),
)
mock_agent.return_value = mock_result

# Call the think function directly
tool_input = tool_use.get("input", {})
result = think.think(
thought=tool_input.get("thought"),
cycle_count=tool_input.get("cycle_count"),
system_prompt=tool_input.get("system_prompt"),
agent=None,
)

# Verify the result has the expected structure
assert result["status"] == "success"
assert "Cycle 1/2" in result["content"][0]["text"]
assert "Cycle 2/2" in result["content"][0]["text"]

# Verify use_llm was called twice (once for each cycle)
assert mock_use_llm.call_count == 2
# Verify Agent was called twice (once for each cycle)
assert mock_agent.call_count == 2


def test_think_one_cycle():
Expand All @@ -65,22 +71,27 @@ def test_think_one_cycle():
},
}

with patch("strands_tools.think.use_llm") as mock_use_llm:
mock_use_llm.return_value = {
"status": "success",
"content": [{"text": "Analysis for single cycle."}],
}
with patch("strands_tools.think.Agent") as mock_agent_class:
mock_agent = mock_agent_class.return_value
mock_result = AgentResult(
message={"content": [{"text": "Analysis for single cycle."}]},
stop_reason="end_turn",
metrics=None,
state=MagicMock(),
)
mock_agent.return_value = mock_result

tool_input = tool_use.get("input", {})
result = think.think(
thought=tool_input.get("thought"),
cycle_count=tool_input.get("cycle_count"),
system_prompt=tool_input.get("system_prompt"),
agent=None,
)

assert result["status"] == "success"
assert "Cycle 1/1" in result["content"][0]["text"]
assert mock_use_llm.call_count == 1
assert mock_agent.call_count == 1


def test_think_error_handling():
Expand All @@ -95,15 +106,16 @@ def test_think_error_handling():
},
}

with patch("strands_tools.think.use_llm") as mock_use_llm:
# Make use_llm raise an exception
mock_use_llm.side_effect = Exception("Test error")
with patch("strands_tools.think.Agent") as mock_agent_class:
# Make Agent raise an exception
mock_agent_class.side_effect = Exception("Test error")

tool_input = tool_use.get("input", {})
result = think.think(
thought=tool_input.get("thought"),
cycle_count=tool_input.get("cycle_count"),
system_prompt=tool_input.get("system_prompt"),
agent=None,
)

assert result["status"] == "error"
8 changes: 6 additions & 2 deletions tests/test_use_llm.py
@@ -59,7 +59,9 @@ def test_use_llm_tool_direct(mock_agent_response):
assert "This is a test response from the LLM" in str(result)

# Verify the Agent was created with the correct parameters
MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a helpful test assistant")
MockAgent.assert_called_once_with(
messages=[], tools=[], system_prompt="You are a helpful test assistant", trace_attributes={}
)


def test_use_llm_with_custom_system_prompt(mock_agent_response):
Expand All @@ -82,7 +84,9 @@ def test_use_llm_with_custom_system_prompt(mock_agent_response):
result = use_llm.use_llm(tool=tool_use)

# Verify agent was created with correct system prompt
MockAgent.assert_called_once_with(messages=[], tools=[], system_prompt="You are a specialized test assistant")
MockAgent.assert_called_once_with(
messages=[], tools=[], system_prompt="You are a specialized test assistant", trace_attributes={}
)

assert result["status"] == "success"
assert "Custom response" in result["content"][0]["text"]