Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ def inject_special_parameters(
invocation_state: Context for the tool invocation, including agent state.
"""
if self._context_param and self._context_param in self.signature.parameters:
tool_context = ToolContext(tool_use=tool_use, agent=invocation_state["agent"])
tool_context = ToolContext(
tool_use=tool_use, agent=invocation_state["agent"], invocation_state=invocation_state
)
validated_input[self._context_param] = tool_context

# Inject agent if requested (backward compatibility)
Expand Down
3 changes: 3 additions & 0 deletions src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class ToolContext:
tool_use: The complete ToolUse object containing tool invocation details.
agent: The Agent instance executing this tool, providing access to conversation history,
model configuration, and other agent state.
invocation_state: Keyword arguments passed to agent invocation methods (agent(), agent.invoke_async(), etc.).
Provides access to invocation-specific context and parameters.

Note:
This class is intended to be instantiated by the SDK. Direct construction by users
Expand All @@ -140,6 +142,7 @@ class ToolContext:

tool_use: ToolUse
agent: "Agent"
invocation_state: dict[str, Any]


ToolChoice = Union[
Expand Down
12 changes: 10 additions & 2 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def complex_schema_tool(union_param: Union[List[int], Dict[str, Any], str, None]
assert "NoneType: None" in result["content"][0]["text"]


async def _run_context_injection_test(context_tool: AgentTool):
async def _run_context_injection_test(context_tool: AgentTool, additional_context=None):
"""Common test logic for context injection tests."""
tool: AgentTool = context_tool
generator = tool.stream(
Expand All @@ -1052,6 +1052,7 @@ async def _run_context_injection_test(context_tool: AgentTool):
},
invocation_state={
"agent": Agent(name="test_agent"),
**(additional_context or {}),
},
)
tool_results = [value async for value in generator]
Expand Down Expand Up @@ -1081,6 +1082,8 @@ def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
tool_name = tool_context.tool_use["name"]
agent_from_tool_context = tool_context.agent

assert tool_context.invocation_state["new_value"] == 13

return {
"status": "success",
"content": [
Expand All @@ -1090,7 +1093,12 @@ def context_tool(message: str, agent: Agent, tool_context: ToolContext) -> dict:
],
}

await _run_context_injection_test(context_tool)
await _run_context_injection_test(
context_tool,
{
"new_value": 13,
},
)


@pytest.mark.asyncio
Expand Down
Loading