Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions src/strands/types/_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,25 +275,20 @@ def is_callback_event(self) -> bool:
class ToolStreamEvent(TypedEvent):
"""Event emitted when a tool yields sub-events as part of tool execution."""

def __init__(self, tool_use: ToolUse, tool_sub_event: Any) -> None:
def __init__(self, tool_use: ToolUse, tool_stream_data: Any) -> None:
"""Initialize with tool streaming data.

Args:
tool_use: The tool invocation producing the stream
tool_sub_event: The yielded event from the tool execution
tool_stream_data: The yielded event from the tool execution
"""
super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_event": tool_sub_event})
super().__init__({"tool_stream_tool_use": tool_use, "tool_stream_data": tool_stream_data})

@property
def tool_use_id(self) -> str:
"""The toolUseId associated with this stream."""
return cast(str, cast(ToolUse, self.get("tool_stream_tool_use")).get("toolUseId"))

@property
@override
def is_callback_event(self) -> bool:
return False


class ModelMessageEvent(TypedEvent):
"""Event emitted when the model invocation has completed.
Expand Down
18 changes: 9 additions & 9 deletions tests/strands/agent/hooks/test_agent_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,18 +260,18 @@ async def test_stream_e2e_success(alist):
"role": "assistant",
}
},
{
"tool_stream_data": {"tool_streaming": True},
"tool_stream_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
},
{
"tool_stream_data": "Final result",
"tool_stream_tool_use": {"input": {}, "name": "streaming_tool", "toolUseId": "12345"},
},
{
"message": {
"content": [
{
"toolResult": {
# TODO update this text when we get tool streaming implemented; right now this
# TODO is of the form '<async_generator object streaming_tool at 0x107d18a00>'
"content": [{"text": ANY}],
"status": "success",
"toolUseId": "12345",
}
},
{"toolResult": {"content": [{"text": "Final result"}], "status": "success", "toolUseId": "12345"}}
],
"role": "user",
}
Expand Down
145 changes: 143 additions & 2 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
"""

from asyncio import Queue
from typing import Any, Dict, Optional, Union
from typing import Any, AsyncGenerator, Dict, Optional, Union
from unittest.mock import MagicMock

import pytest

import strands
from strands import Agent
from strands.types._events import ToolResultEvent
from strands.types._events import ToolResultEvent, ToolStreamEvent
from strands.types.tools import AgentTool, ToolContext, ToolUse


Expand Down Expand Up @@ -1222,3 +1222,144 @@ def context_tool(message: str, agent: Agent, tool_context: str) -> str:
"toolUseId": "test-id-2",
}
)


@pytest.mark.asyncio
async def test_tool_async_generator():
"""Test that async generators yield results appropriately."""

@strands.tool(context=False)
async def async_generator() -> AsyncGenerator:
"""Tool that expects tool_context as a regular string parameter."""
yield 0
yield "Value 1"
yield {"nested": "value"}
yield {
"status": "success",
"content": [{"text": "Looks like tool result"}],
"toolUseId": "test-id-2",
}
yield "final result"

tool: AgentTool = async_generator
tool_use: ToolUse = {
"toolUseId": "test-id-2",
"name": "context_tool",
"input": {"message": "some_message", "tool_context": "my_custom_context_string"},
}
generator = tool.stream(
tool_use=tool_use,
invocation_state={
"agent": Agent(name="test_agent"),
},
)
act_results = [value async for value in generator]
exp_results = [
ToolStreamEvent(tool_use, 0),
ToolStreamEvent(tool_use, "Value 1"),
ToolStreamEvent(tool_use, {"nested": "value"}),
ToolStreamEvent(
tool_use,
{
"status": "success",
"content": [{"text": "Looks like tool result"}],
"toolUseId": "test-id-2",
},
),
ToolStreamEvent(tool_use, "final result"),
ToolResultEvent(
{
"status": "success",
"content": [{"text": "final result"}],
"toolUseId": "test-id-2",
}
),
]

assert act_results == exp_results


@pytest.mark.asyncio
async def test_tool_async_generator_exceptions_result_in_error():
"""Test that async generators handle exceptions."""

@strands.tool(context=False)
async def async_generator() -> AsyncGenerator:
"""Tool that expects tool_context as a regular string parameter."""
yield 13
raise ValueError("It's an error!")

tool: AgentTool = async_generator
tool_use: ToolUse = {
"toolUseId": "test-id-2",
"name": "context_tool",
"input": {"message": "some_message", "tool_context": "my_custom_context_string"},
}
generator = tool.stream(
tool_use=tool_use,
invocation_state={
"agent": Agent(name="test_agent"),
},
)
act_results = [value async for value in generator]
exp_results = [
ToolStreamEvent(tool_use, 13),
ToolResultEvent(
{
"status": "error",
"content": [{"text": "Error: It's an error!"}],
"toolUseId": "test-id-2",
}
),
]

assert act_results == exp_results


@pytest.mark.asyncio
async def test_tool_async_generator_yield_object_result():
"""Test that async generators handle exceptions."""

@strands.tool(context=False)
async def async_generator() -> AsyncGenerator:
"""Tool that expects tool_context as a regular string parameter."""
yield 13
yield {
"status": "success",
"content": [{"text": "final result"}],
"toolUseId": "test-id-2",
}

tool: AgentTool = async_generator
tool_use: ToolUse = {
"toolUseId": "test-id-2",
"name": "context_tool",
"input": {"message": "some_message", "tool_context": "my_custom_context_string"},
}
generator = tool.stream(
tool_use=tool_use,
invocation_state={
"agent": Agent(name="test_agent"),
},
)
act_results = [value async for value in generator]
exp_results = [
ToolStreamEvent(tool_use, 13),
ToolStreamEvent(
tool_use,
{
"status": "success",
"content": [{"text": "final result"}],
"toolUseId": "test-id-2",
},
),
ToolResultEvent(
{
"status": "success",
"content": [{"text": "final result"}],
"toolUseId": "test-id-2",
}
),
]

assert act_results == exp_results
Loading