Skip to content

Commit a307f37

Browse files
author
Murat Kaan Meral
committed
refactor(multiagent): improve streaming event handling and documentation
- Update docstrings to match Agent's minimal style (use dict keys instead of class names) - Add isinstance checks for result event detection for type safety - Improve _stream_with_timeout to handle None timeout case - Add MultiAgentResultEvent for consistency with Agent pattern - Yield TypedEvent objects internally, convert to dict at API boundary - All 154 tests passing
1 parent 60f16b9 commit a307f37

File tree

3 files changed

+172
-164
lines changed

3 files changed

+172
-164
lines changed

src/strands/multiagent/graph.py

Lines changed: 107 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
MultiAgentNodeCompleteEvent,
3232
MultiAgentNodeStartEvent,
3333
MultiAgentNodeStreamEvent,
34+
MultiAgentResultEvent,
3435
)
3536
from ..types.content import ContentBlock, Messages
3637
from ..types.event_loop import Metrics, Usage
@@ -443,11 +444,11 @@ async def stream_async(
443444
**kwargs: Keyword arguments allowing backward compatible future changes.
444445
445446
Yields:
446-
Dictionary events containing graph execution information including:
447-
- MultiAgentNodeStartEvent: When a node begins execution
448-
- MultiAgentNodeStreamEvent: Forwarded agent events with node context
449-
- MultiAgentNodeCompleteEvent: When a node completes execution
450-
- Final result event
447+
Dictionary events during graph execution, such as:
448+
- multi_agent_node_start: When a node begins execution
449+
- multi_agent_node_stream: Forwarded agent/multi-agent events with node context
450+
- multi_agent_node_complete: When a node completes execution
451+
- result: Final graph result
451452
"""
452453
if invocation_state is None:
453454
invocation_state = {}
@@ -476,7 +477,7 @@ async def stream_async(
476477
)
477478

478479
async for event in self._execute_graph(invocation_state):
479-
yield event
480+
yield event.as_dict()
480481

481482
# Set final status based on execution results
482483
if self.state.failed_nodes:
@@ -490,7 +491,7 @@ async def stream_async(
490491
result = self._build_result()
491492

492493
# Use the same event format as Agent for consistency
493-
yield {"result": result}
494+
yield MultiAgentResultEvent(result=result).as_dict()
494495

495496
except Exception:
496497
logger.exception("graph execution failed")
@@ -500,17 +501,35 @@ async def stream_async(
500501
self.state.execution_time = round((time.time() - start_time) * 1000)
501502

502503
async def _stream_with_timeout(
503-
self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str
504-
) -> AsyncIterator[dict[str, Any]]:
505-
"""Wrap an async generator with timeout functionality."""
506-
while True:
507-
try:
508-
event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout)
504+
self, async_generator: AsyncIterator[Any], timeout: float | None, timeout_message: str
505+
) -> AsyncIterator[Any]:
506+
"""Wrap an async generator with timeout functionality.
507+
508+
Args:
509+
async_generator: The generator to wrap
510+
timeout: Timeout in seconds, or None for no timeout
511+
timeout_message: Message to include in timeout exception
512+
513+
Yields:
514+
Events from the wrapped generator
515+
516+
Raises:
517+
Exception: If timeout is exceeded (same as original behavior)
518+
"""
519+
if timeout is None:
520+
# No timeout - just pass through
521+
async for event in async_generator:
509522
yield event
510-
except StopAsyncIteration:
511-
break
512-
except asyncio.TimeoutError:
513-
raise Exception(timeout_message) from None
523+
else:
524+
# Apply timeout to each event
525+
while True:
526+
try:
527+
event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout)
528+
yield event
529+
except StopAsyncIteration:
530+
break
531+
except asyncio.TimeoutError:
532+
raise Exception(timeout_message) from None
514533

515534
def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
516535
"""Validate graph nodes for duplicate instances."""
@@ -524,8 +543,8 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
524543
# Validate Agent-specific constraints for each node
525544
_validate_node_executor(node.executor)
526545

527-
async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
528-
"""Execute graph and yield events."""
546+
async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
547+
"""Execute graph and yield TypedEvent objects."""
529548
ready_nodes = list(self.entry_points)
530549

531550
while ready_nodes:
@@ -557,14 +576,14 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
557576

558577
async def _execute_nodes_parallel(
559578
self, nodes: list["GraphNode"], invocation_state: dict[str, Any]
560-
) -> AsyncIterator[dict[str, Any]]:
579+
) -> AsyncIterator[Any]:
561580
"""Execute multiple nodes in parallel and merge their event streams in real-time.
562581
563582
Uses a shared queue where each node's stream runs independently and pushes events
564583
as they occur, enabling true real-time event propagation without round-robin delays.
565584
"""
566585
# Create a shared queue for all events
567-
event_queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue()
586+
event_queue: asyncio.Queue[Any | None] = asyncio.Queue()
568587

569588
# Track active node tasks
570589
active_tasks: set[asyncio.Task[None]] = set()
@@ -637,8 +656,8 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
637656
)
638657
return False
639658

640-
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[dict[str, Any]]:
641-
"""Execute a single node."""
659+
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
660+
"""Execute a single node and yield TypedEvent objects."""
642661
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
643662
if self.reset_on_revisit and node in self.state.completed_nodes:
644663
logger.debug("node_id=<%s> | resetting node state for revisit", node.node_id)
@@ -652,109 +671,79 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
652671
start_event = MultiAgentNodeStartEvent(
653672
node_id=node.node_id, node_type="agent" if isinstance(node.executor, Agent) else "multiagent"
654673
)
655-
yield start_event.as_dict()
674+
yield start_event
656675

657676
start_time = time.time()
658677
try:
659678
# Build node input from satisfied dependencies
660679
node_input = self._build_node_input(node)
661680

662681
# Execute with timeout protection and stream events
663-
try:
664-
if isinstance(node.executor, MultiAgentBase):
665-
# For nested multi-agent systems, stream their events and collect result
666-
multi_agent_result = None
667-
if self.node_timeout is not None:
668-
# Implement timeout for async generator streaming
669-
async for event in self._stream_with_timeout(
670-
node.executor.stream_async(node_input, invocation_state),
671-
self.node_timeout,
672-
f"Node '{node.node_id}' execution timed out after {self.node_timeout}s",
673-
):
674-
# Forward nested multi-agent events with node context
675-
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
676-
yield wrapped_event.as_dict()
677-
# Capture the final result event
678-
if "result" in event:
679-
multi_agent_result = event["result"]
680-
else:
681-
async for event in node.executor.stream_async(node_input, invocation_state):
682-
# Forward nested multi-agent events with node context
683-
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
684-
yield wrapped_event.as_dict()
685-
# Capture the final result event
686-
if "result" in event:
687-
multi_agent_result = event["result"]
688-
689-
# Use the captured result from streaming (no double execution)
690-
if multi_agent_result is None:
691-
raise ValueError(f"Node '{node.node_id}' did not produce a result event")
692-
693-
node_result = NodeResult(
694-
result=multi_agent_result,
695-
execution_time=multi_agent_result.execution_time,
696-
status=Status.COMPLETED,
697-
accumulated_usage=multi_agent_result.accumulated_usage,
698-
accumulated_metrics=multi_agent_result.accumulated_metrics,
699-
execution_count=multi_agent_result.execution_count,
700-
)
701-
702-
elif isinstance(node.executor, Agent):
703-
# For agents, stream their events and collect result
704-
agent_response = None
705-
if self.node_timeout is not None:
706-
# Implement timeout for async generator streaming
707-
async for event in self._stream_with_timeout(
708-
node.executor.stream_async(node_input, **invocation_state),
709-
self.node_timeout,
710-
f"Node '{node.node_id}' execution timed out after {self.node_timeout}s",
711-
):
712-
# Forward agent events with node context
713-
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
714-
yield wrapped_event.as_dict()
715-
# Capture the final result event
716-
if "result" in event:
717-
agent_response = event["result"]
718-
else:
719-
async for event in node.executor.stream_async(node_input, **invocation_state):
720-
# Forward agent events with node context
721-
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
722-
yield wrapped_event.as_dict()
723-
# Capture the final result event
724-
if "result" in event:
725-
agent_response = event["result"]
726-
727-
# Use the captured result from streaming (no double execution)
728-
if agent_response is None:
729-
raise ValueError(f"Node '{node.node_id}' did not produce a result event")
730-
731-
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
732-
metrics = Metrics(latencyMs=0)
733-
if hasattr(agent_response, "metrics") and agent_response.metrics:
734-
if hasattr(agent_response.metrics, "accumulated_usage"):
735-
usage = agent_response.metrics.accumulated_usage
736-
if hasattr(agent_response.metrics, "accumulated_metrics"):
737-
metrics = agent_response.metrics.accumulated_metrics
738-
739-
node_result = NodeResult(
740-
result=agent_response,
741-
execution_time=round((time.time() - start_time) * 1000),
742-
status=Status.COMPLETED,
743-
accumulated_usage=usage,
744-
accumulated_metrics=metrics,
745-
execution_count=1,
746-
)
747-
else:
748-
raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")
682+
if isinstance(node.executor, MultiAgentBase):
683+
# For nested multi-agent systems, stream their events and collect result
684+
multi_agent_result = None
685+
async for event in self._stream_with_timeout(
686+
node.executor.stream_async(node_input, invocation_state),
687+
self.node_timeout,
688+
f"Node '{node.node_id}' execution timed out after {self.node_timeout}s",
689+
):
690+
# Forward nested multi-agent events with node context
691+
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
692+
yield wrapped_event
693+
# Capture the final result event
694+
if isinstance(event, dict) and "result" in event:
695+
multi_agent_result = event["result"]
696+
697+
# Use the captured result from streaming (no double execution)
698+
if multi_agent_result is None:
699+
raise ValueError(f"Node '{node.node_id}' did not produce a result event")
700+
701+
node_result = NodeResult(
702+
result=multi_agent_result,
703+
execution_time=multi_agent_result.execution_time,
704+
status=Status.COMPLETED,
705+
accumulated_usage=multi_agent_result.accumulated_usage,
706+
accumulated_metrics=multi_agent_result.accumulated_metrics,
707+
execution_count=multi_agent_result.execution_count,
708+
)
749709

750-
except asyncio.TimeoutError:
751-
timeout_msg = f"Node '{node.node_id}' execution timed out after {self.node_timeout}s"
752-
logger.exception(
753-
"node=<%s>, timeout=<%s>s | node execution timed out",
754-
node.node_id,
710+
elif isinstance(node.executor, Agent):
711+
# For agents, stream their events and collect result
712+
agent_response = None
713+
async for event in self._stream_with_timeout(
714+
node.executor.stream_async(node_input, **invocation_state),
755715
self.node_timeout,
716+
f"Node '{node.node_id}' execution timed out after {self.node_timeout}s",
717+
):
718+
# Forward agent events with node context
719+
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
720+
yield wrapped_event
721+
# Capture the final result event
722+
if isinstance(event, dict) and "result" in event:
723+
agent_response = event["result"]
724+
725+
# Use the captured result from streaming (no double execution)
726+
if agent_response is None:
727+
raise ValueError(f"Node '{node.node_id}' did not produce a result event")
728+
729+
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
730+
metrics = Metrics(latencyMs=0)
731+
if hasattr(agent_response, "metrics") and agent_response.metrics:
732+
if hasattr(agent_response.metrics, "accumulated_usage"):
733+
usage = agent_response.metrics.accumulated_usage
734+
if hasattr(agent_response.metrics, "accumulated_metrics"):
735+
metrics = agent_response.metrics.accumulated_metrics
736+
737+
node_result = NodeResult(
738+
result=agent_response,
739+
execution_time=round((time.time() - start_time) * 1000),
740+
status=Status.COMPLETED,
741+
accumulated_usage=usage,
742+
accumulated_metrics=metrics,
743+
execution_count=1,
756744
)
757-
raise Exception(timeout_msg) from None
745+
else:
746+
raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")
758747

759748
# Mark as completed
760749
node.execution_status = Status.COMPLETED
@@ -769,7 +758,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
769758

770759
# Emit node complete event
771760
complete_event = MultiAgentNodeCompleteEvent(node_id=node.node_id, execution_time=node.execution_time)
772-
yield complete_event.as_dict()
761+
yield complete_event
773762

774763
logger.debug(
775764
"node_id=<%s>, execution_time=<%dms> | node completed successfully",
@@ -799,7 +788,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
799788

800789
# Still emit complete event even for failures
801790
complete_event = MultiAgentNodeCompleteEvent(node_id=node.node_id, execution_time=execution_time)
802-
yield complete_event.as_dict()
791+
yield complete_event
803792

804793
raise
805794

0 commit comments

Comments
 (0)