Skip to content

Commit d4f5571

Browse files
author
Murat Kaan Meral
committed
fix: Fix double execution
1 parent 08141a0 commit d4f5571

File tree

4 files changed

+355
-76
lines changed

4 files changed

+355
-76
lines changed

src/strands/multiagent/graph.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
673673
# Execute with timeout protection and stream events
674674
try:
675675
if isinstance(node.executor, MultiAgentBase):
676-
# For nested multi-agent systems, stream their events
676+
# For nested multi-agent systems, stream their events and collect result
677+
multi_agent_result = None
677678
if self.node_timeout is not None:
678679
# Implement timeout for async generator streaming
679680
async for event in self._stream_with_timeout(
@@ -684,14 +685,22 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
684685
# Forward nested multi-agent events with node context
685686
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
686687
yield wrapped_event.as_dict()
688+
# Capture the final result event
689+
if "result" in event:
690+
multi_agent_result = event["result"]
687691
else:
688692
async for event in node.executor.stream_async(node_input, invocation_state):
689693
# Forward nested multi-agent events with node context
690694
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
691695
yield wrapped_event.as_dict()
696+
# Capture the final result event
697+
if "result" in event:
698+
multi_agent_result = event["result"]
699+
700+
# Use the captured result from streaming (no double execution)
701+
if multi_agent_result is None:
702+
raise ValueError(f"Node '{node.node_id}' did not produce a result event")
692703

693-
# Get the final result for metrics
694-
multi_agent_result = await node.executor.invoke_async(node_input, invocation_state)
695704
node_result = NodeResult(
696705
result=multi_agent_result,
697706
execution_time=multi_agent_result.execution_time,
@@ -702,7 +711,8 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
702711
)
703712

704713
elif isinstance(node.executor, Agent):
705-
# For agents, stream their events
714+
# For agents, stream their events and collect result
715+
agent_response = None
706716
if self.node_timeout is not None:
707717
# Implement timeout for async generator streaming
708718
async for event in self._stream_with_timeout(
@@ -713,14 +723,22 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
713723
# Forward agent events with node context
714724
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
715725
yield wrapped_event.as_dict()
726+
# Capture the final result event
727+
if "result" in event:
728+
agent_response = event["result"]
716729
else:
717730
async for event in node.executor.stream_async(node_input, **invocation_state):
718731
# Forward agent events with node context
719732
wrapped_event = MultiAgentNodeStreamEvent(node.node_id, event)
720733
yield wrapped_event.as_dict()
734+
# Capture the final result event
735+
if "result" in event:
736+
agent_response = event["result"]
737+
738+
# Use the captured result from streaming (no double execution)
739+
if agent_response is None:
740+
raise ValueError(f"Node '{node.node_id}' did not produce a result event")
721741

722-
# Get the final result for metrics
723-
agent_response = await node.executor.invoke_async(node_input, **invocation_state)
724742
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
725743
metrics = Metrics(latencyMs=0)
726744
if hasattr(agent_response, "metrics") and agent_response.metrics:

src/strands/multiagent/swarm.py

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,19 @@ async def stream_async(
350350
finally:
351351
self.state.execution_time = round((time.time() - start_time) * 1000)
352352

353+
async def _stream_with_timeout(
354+
self, async_generator: AsyncIterator[dict[str, Any]], timeout: float, timeout_message: str
355+
) -> AsyncIterator[dict[str, Any]]:
356+
"""Wrap an async generator with timeout functionality."""
357+
while True:
358+
try:
359+
event = await asyncio.wait_for(async_generator.__anext__(), timeout=timeout)
360+
yield event
361+
except StopAsyncIteration:
362+
break
363+
except asyncio.TimeoutError:
364+
raise Exception(timeout_message) from None
365+
353366
def _setup_swarm(self, nodes: list[Agent]) -> None:
354367
"""Initialize swarm configuration."""
355368
# Validate nodes before setup
@@ -610,9 +623,17 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
610623

611624
# Execute node with timeout protection
612625
try:
613-
# For now, execute without timeout for async generators
614-
# TODO: Implement proper timeout for async generators if needed
615-
async for event in self._execute_node(current_node, self.state.task, invocation_state):
626+
# Execute with timeout wrapper for async generator streaming
627+
node_stream = (
628+
self._stream_with_timeout(
629+
self._execute_node(current_node, self.state.task, invocation_state),
630+
self.node_timeout,
631+
f"Node '{current_node.node_id}' execution timed out after {self.node_timeout}s",
632+
)
633+
if self.node_timeout is not None
634+
else self._execute_node(current_node, self.state.task, invocation_state)
635+
)
636+
async for event in node_stream:
616637
yield event
617638

618639
self.state.node_history.append(current_node)
@@ -639,17 +660,16 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato
639660
self.state.completion_status = Status.COMPLETED
640661
break
641662

642-
except asyncio.TimeoutError:
643-
logger.exception(
644-
"node=<%s>, timeout=<%s>s | node execution timed out",
645-
current_node.node_id,
646-
self.node_timeout,
647-
)
648-
self.state.completion_status = Status.FAILED
649-
break
650-
651-
except Exception:
652-
logger.exception("node=<%s> | node execution failed", current_node.node_id)
663+
except Exception as e:
664+
# Check if this is a timeout exception
665+
if "timed out after" in str(e):
666+
logger.exception(
667+
"node=<%s>, timeout=<%s>s | node execution timed out",
668+
current_node.node_id,
669+
self.node_timeout,
670+
)
671+
else:
672+
logger.exception("node=<%s> | node execution failed", current_node.node_id)
653673
self.state.completion_status = Status.FAILED
654674
break
655675

@@ -691,14 +711,20 @@ async def _execute_node(
691711
# Execute node with streaming
692712
node.reset_executor_state()
693713

694-
# Stream agent events with node context
714+
# Stream agent events with node context and capture final result
715+
result = None
695716
async for event in node.executor.stream_async(node_input, **invocation_state):
696717
# Forward agent events with node context
697718
wrapped_event = MultiAgentNodeStreamEvent(node_name, event)
698719
yield wrapped_event.as_dict()
720+
# Capture the final result event
721+
if "result" in event:
722+
result = event["result"]
723+
724+
# Use the captured result from streaming to avoid double execution
725+
if result is None:
726+
raise ValueError(f"Node '{node_name}' did not produce a result event")
699727

700-
# Get the final result for metrics (we need to call invoke_async again for the result)
701-
result = await node.executor.invoke_async(node_input, **invocation_state)
702728
execution_time = round((time.time() - start_time) * 1000)
703729

704730
# Create NodeResult

0 commit comments

Comments
 (0)