diff --git a/src/strands/types/_events.py b/src/strands/types/_events.py index 1a7f48d4b..ccdab1846 100644 --- a/src/strands/types/_events.py +++ b/src/strands/types/_events.py @@ -275,24 +275,19 @@ def is_callback_event(self) -> bool: class ToolStreamEvent(TypedEvent): """Event emitted when a tool yields sub-events as part of tool execution.""" - def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None: + def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None: """Initialize with tool streaming data. Args: tool_use: The tool invocation producing the stream - tool_sub_event: The yielded event from the tool execution + tool_stream_data: The yielded event from the tool execution """ - super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event}) + super().__init__({"tool_stream_event": {"tool_use": tool_use, "data": tool_stream_data}}) @property def tool_use_id(self) -> str: """The toolUseId associated with this stream.""" - return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId")) - - @property - @override - def is_callback_event(self) -> bool: - return False + return cast(str, cast(ToolUse, cast(dict, self.get("tool_stream_event")).get("tool_use")).get("toolUseId")) class ModelMessageEvent(TypedEvent): diff --git a/tests/strands/agent/hooks/test_agent_events.py b/tests/strands/agent/hooks/test_agent_events.py index 04b832259..07f55b724 100644 --- a/tests/strands/agent/hooks/test_agent_events.py +++ b/tests/strands/agent/hooks/test_agent_events.py @@ -260,18 +260,22 @@ async def test_stream_e2e_success(alist): "role": "assistant", } }, + { + "tool_stream_event": { + "data": {"tool_streaming": True}, + "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + } + }, + { + "tool_stream_event": { + "data": "Final result", + "tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"}, + } + }, { "message": { "content": [ - { - "toolResult": { - # TODO update this text when we get tool streaming implemented; right now this - # TODO is of the form '' - "content": [{"text": ANY}], - "status": "success", - "toolUseId": "12345", - } - }, + {"toolResult": {"content": [{"text": "Final result"}], "status": "success", "toolUseId": "12345"}} ], "role": "user", } diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index a13c2833e..5b4b5cdda 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,14 +3,14 @@ """ from asyncio import Queue -from typing import Any, Dict, Optional, Union +from typing import Any, AsyncGenerator, Dict, Optional, Union from unittest.mock import MagicMock import pytest import strands from strands import Agent -from strands.types._events import ToolResultEvent +from strands.types._events import ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -1222,3 +1222,144 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str: "toolUseId": "test-id-2", } ) + + +@pytest.mark.asyncio +async def test_tool_async_generator(): + """Test that async generators yield results appropriately.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 0 + yield "Value 1" + yield {"nested": "value"} + yield { + "status": "success", + "content": [{"text": "Looks like tool result"}], + "toolUseId": "test-id-2", + } + yield "final result" + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 0), + ToolStreamEvent(tool_use, "Value 1"), + ToolStreamEvent(tool_use, {"nested": "value"}), + ToolStreamEvent( + tool_use, + { + "status": "success", + "content": [{"text": "Looks like tool result"}], + "toolUseId": "test-id-2", + }, + ), + ToolStreamEvent(tool_use, "final result"), + ToolResultEvent( + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results + + +@pytest.mark.asyncio +async def test_tool_async_generator_exceptions_result_in_error(): + """Test that async generators handle exceptions.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 13 + raise ValueError("It's an error!") + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 13), + ToolResultEvent( + { + "status": "error", + "content": [{"text": "Error: It's an error!"}], + "toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results + + +@pytest.mark.asyncio +async def test_tool_async_generator_yield_object_result(): + """Test that async generators handle exceptions.""" + + @strands.tool(context=False) + async def async_generator() -> AsyncGenerator: + """Tool that expects tool_context as a regular string parameter.""" + yield 13 + yield { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + + tool: AgentTool = async_generator + tool_use: ToolUse = { + "toolUseId": "test-id-2", + "name": "context_tool", + "input": {"message": "some_message", "tool_context": "my_custom_context_string"}, + } + generator = tool.stream( + tool_use=tool_use, + invocation_state={ + "agent": Agent(name="test_agent"), + }, + ) + act_results = [value async for value in generator] + exp_results = [ + ToolStreamEvent(tool_use, 13), + ToolStreamEvent( + tool_use, + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + }, + ), + ToolResultEvent( + { + "status": "success", + "content": [{"text": "final result"}], + "toolUseId": "test-id-2", + } + ), + ] + + assert act_results == exp_results