Skip to content

Commit 60f16b9

Browse files
author
Murat Kaan Meral
committed
fix: Add integ tests
1 parent ca59221 commit 60f16b9

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

tests_integ/test_multiagent_graph.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, AsyncIterator
2+
13
import pytest
24

35
from strands import Agent, tool
@@ -9,6 +11,7 @@
911
BeforeModelCallEvent,
1012
MessageAddedEvent,
1113
)
14+
from strands.multiagent.base import MultiAgentBase, MultiAgentResult, NodeResult, Status
1215
from strands.multiagent.graph import GraphBuilder
1316
from strands.types.content import ContentBlock
1417
from tests.fixtures.mock_hook_provider import MockHookProvider
@@ -218,3 +221,182 @@ async def test_graph_execution_with_image(image_analysis_agent, summary_agent, y
218221

219222
assert hook_provider.extract_for(image_analysis_agent).event_types_received == expected_hook_events
220223
assert hook_provider.extract_for(summary_agent).event_types_received == expected_hook_events
224+
225+
226+
class CustomStreamingNode(MultiAgentBase):
227+
"""Custom node that wraps an agent and adds custom streaming events."""
228+
229+
def __init__(self, agent: Agent, name: str):
230+
self.agent = agent
231+
self.name = name
232+
233+
async def invoke_async(
234+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
235+
) -> MultiAgentResult:
236+
result = await self.agent.invoke_async(task, **kwargs)
237+
node_result = NodeResult(result=result, status=Status.COMPLETED)
238+
return MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result})
239+
240+
async def stream_async(
241+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
242+
) -> AsyncIterator[dict[str, Any]]:
243+
yield {"custom_event": "start", "node": self.name}
244+
result = await self.agent.invoke_async(task, **kwargs)
245+
yield {"custom_event": "agent_complete", "node": self.name}
246+
node_result = NodeResult(result=result, status=Status.COMPLETED)
247+
yield {"result": MultiAgentResult(status=Status.COMPLETED, results={self.name: node_result})}
248+
249+
250+
@pytest.mark.asyncio
251+
async def test_graph_streaming_with_agents():
252+
"""Test that Graph properly streams events from agent nodes."""
253+
math_agent = Agent(
254+
name="math",
255+
model="us.amazon.nova-pro-v1:0",
256+
system_prompt="You are a math assistant.",
257+
tools=[calculate_sum],
258+
)
259+
summary_agent = Agent(
260+
name="summary",
261+
model="us.amazon.nova-lite-v1:0",
262+
system_prompt="You are a summary assistant.",
263+
)
264+
265+
builder = GraphBuilder()
266+
builder.add_node(math_agent, "math")
267+
builder.add_node(summary_agent, "summary")
268+
builder.add_edge("math", "summary")
269+
builder.set_entry_point("math")
270+
graph = builder.build()
271+
272+
# Collect events
273+
events = []
274+
async for event in graph.stream_async("Calculate 5 + 3 and summarize the result"):
275+
events.append(event)
276+
277+
# Count event categories
278+
node_start_events = [e for e in events if e.get("multi_agent_node_start")]
279+
node_stream_events = [e for e in events if e.get("multi_agent_node_stream")]
280+
node_complete_events = [e for e in events if e.get("multi_agent_node_complete")]
281+
result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e]
282+
283+
# Verify we got multiple events of each type
284+
assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}"
285+
assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}"
286+
assert len(node_complete_events) >= 2, f"Expected at least 2 node_complete events, got {len(node_complete_events)}"
287+
assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}"
288+
289+
# Verify we have events for both nodes
290+
math_events = [e for e in events if e.get("node_id") == "math"]
291+
summary_events = [e for e in events if e.get("node_id") == "summary"]
292+
assert len(math_events) > 0, "Expected events from math node"
293+
assert len(summary_events) > 0, "Expected events from summary node"
294+
295+
296+
@pytest.mark.asyncio
297+
async def test_graph_streaming_with_custom_node():
298+
"""Test that Graph properly streams events from custom MultiAgentBase nodes."""
299+
math_agent = Agent(
300+
name="math",
301+
model="us.amazon.nova-pro-v1:0",
302+
system_prompt="You are a math assistant.",
303+
tools=[calculate_sum],
304+
)
305+
summary_agent = Agent(
306+
name="summary",
307+
model="us.amazon.nova-lite-v1:0",
308+
system_prompt="You are a summary assistant.",
309+
)
310+
311+
# Create a custom node
312+
custom_node = CustomStreamingNode(summary_agent, "custom_summary")
313+
314+
builder = GraphBuilder()
315+
builder.add_node(math_agent, "math")
316+
builder.add_node(custom_node, "custom_summary")
317+
builder.add_edge("math", "custom_summary")
318+
builder.set_entry_point("math")
319+
graph = builder.build()
320+
321+
# Collect events
322+
events = []
323+
async for event in graph.stream_async("Calculate 5 + 3 and summarize the result"):
324+
events.append(event)
325+
326+
# Count event categories
327+
node_start_events = [e for e in events if e.get("multi_agent_node_start")]
328+
node_stream_events = [e for e in events if e.get("multi_agent_node_stream")]
329+
custom_events = [e for e in events if e.get("custom_event")]
330+
result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e]
331+
332+
# Verify we got multiple events of each type
333+
assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}"
334+
assert len(node_stream_events) > 5, f"Expected many node_stream events, got {len(node_stream_events)}"
335+
assert len(custom_events) >= 2, f"Expected at least 2 custom events (start, complete), got {len(custom_events)}"
336+
assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}"
337+
338+
# Verify custom events are properly structured
339+
custom_start = [e for e in custom_events if e.get("custom_event") == "start"]
340+
custom_complete = [e for e in custom_events if e.get("custom_event") == "agent_complete"]
341+
342+
assert len(custom_start) >= 1, "Expected at least 1 custom start event"
343+
assert len(custom_complete) >= 1, "Expected at least 1 custom complete event"
344+
345+
346+
@pytest.mark.asyncio
347+
async def test_nested_graph_streaming():
348+
"""Test that nested graphs properly propagate streaming events."""
349+
math_agent = Agent(
350+
name="math",
351+
model="us.amazon.nova-pro-v1:0",
352+
system_prompt="You are a math assistant.",
353+
tools=[calculate_sum],
354+
)
355+
analysis_agent = Agent(
356+
name="analysis",
357+
model="us.amazon.nova-lite-v1:0",
358+
system_prompt="You are an analysis assistant.",
359+
)
360+
361+
# Create nested graph
362+
nested_builder = GraphBuilder()
363+
nested_builder.add_node(math_agent, "calculator")
364+
nested_builder.add_node(analysis_agent, "analyzer")
365+
nested_builder.add_edge("calculator", "analyzer")
366+
nested_builder.set_entry_point("calculator")
367+
nested_graph = nested_builder.build()
368+
369+
# Create outer graph with nested graph
370+
summary_agent = Agent(
371+
name="summary",
372+
model="us.amazon.nova-lite-v1:0",
373+
system_prompt="You are a summary assistant.",
374+
)
375+
376+
outer_builder = GraphBuilder()
377+
outer_builder.add_node(nested_graph, "computation")
378+
outer_builder.add_node(summary_agent, "summary")
379+
outer_builder.add_edge("computation", "summary")
380+
outer_builder.set_entry_point("computation")
381+
outer_graph = outer_builder.build()
382+
383+
# Collect events
384+
events = []
385+
async for event in outer_graph.stream_async("Calculate 7 + 8 and provide a summary"):
386+
events.append(event)
387+
388+
# Count event categories
389+
node_start_events = [e for e in events if e.get("multi_agent_node_start")]
390+
node_stream_events = [e for e in events if e.get("multi_agent_node_stream")]
391+
result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e]
392+
393+
# Verify we got multiple events
394+
assert len(node_start_events) >= 2, f"Expected at least 2 node_start events, got {len(node_start_events)}"
395+
assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}"
396+
assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}"
397+
398+
# Verify we have events from nested nodes
399+
computation_events = [e for e in events if e.get("node_id") == "computation"]
400+
summary_events = [e for e in events if e.get("node_id") == "summary"]
401+
assert len(computation_events) > 0, "Expected events from computation (nested graph) node"
402+
assert len(summary_events) > 0, "Expected events from summary node"

tests_integ/test_multiagent_swarm.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,41 @@ async def test_swarm_execution_with_image(researcher_agent, analyst_agent, write
134134

135135
# Verify agent history - at least one agent should have been used
136136
assert len(result.node_history) > 0
137+
138+
139+
@pytest.mark.asyncio
140+
async def test_swarm_streaming():
141+
"""Test that Swarm properly streams events during execution."""
142+
researcher = Agent(
143+
name="researcher",
144+
model="us.amazon.nova-pro-v1:0",
145+
system_prompt="You are a researcher. When you need calculations, hand off to the analyst.",
146+
)
147+
analyst = Agent(
148+
name="analyst",
149+
model="us.amazon.nova-pro-v1:0",
150+
system_prompt="You are an analyst. Use tools to perform calculations.",
151+
tools=[calculate],
152+
)
153+
154+
swarm = Swarm([researcher, analyst])
155+
156+
# Collect events
157+
events = []
158+
async for event in swarm.stream_async("Calculate 10 + 5 and explain the result"):
159+
events.append(event)
160+
161+
# Count event categories
162+
node_start_events = [e for e in events if e.get("multi_agent_node_start")]
163+
node_stream_events = [e for e in events if e.get("multi_agent_node_stream")]
164+
result_events = [e for e in events if "result" in e and "multi_agent_node_start" not in e]
165+
166+
# Verify we got multiple events of each type
167+
assert len(node_start_events) >= 1, f"Expected at least 1 node_start event, got {len(node_start_events)}"
168+
assert len(node_stream_events) > 10, f"Expected many node_stream events, got {len(node_stream_events)}"
169+
assert len(result_events) >= 1, f"Expected at least 1 result event, got {len(result_events)}"
170+
171+
# Verify we have events from at least one agent
172+
researcher_events = [e for e in events if e.get("node_id") == "researcher"]
173+
analyst_events = [e for e in events if e.get("node_id") == "analyst"]
174+
assert len(researcher_events) > 0 or len(analyst_events) > 0, "Expected events from at least one agent"

0 commit comments

Comments
 (0)