Skip to content

Commit 993004a

Browse files
committed
add graph.execute and graph.execute_async functions
1 parent 9e2b41e commit 993004a

File tree

5 files changed

+96
-24
lines changed

5 files changed

+96
-24
lines changed

src/strands/multiagent/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ class MultiAgentBase(ABC):
7575

7676
@abstractmethod
7777
# TODO: for task - multi-modal input (Message), list of messages
78-
async def execute(self, task: str) -> MultiAgentResult:
79-
"""Execute task."""
78+
async def execute_async(self, task: str) -> MultiAgentResult:
79+
"""Execute task asynchronously."""
80+
raise NotImplementedError("execute_async not implemented")
81+
82+
@abstractmethod
83+
# TODO: for task - multi-modal input (Message), list of messages
84+
def execute(self, task: str) -> MultiAgentResult:
85+
"""Execute task synchronously."""
8086
raise NotImplementedError("execute not implemented")

src/strands/multiagent/graph.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
- Supports nested graphs (Graph as a node in another Graph)
1515
"""
1616

17+
import asyncio
1718
import logging
1819
import time
20+
from concurrent.futures import ThreadPoolExecutor
1921
from dataclasses import dataclass, field
2022
from typing import Any, Callable, Tuple
2123

@@ -229,8 +231,18 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
229231
self.entry_points = entry_points
230232
self.state = GraphState()
231233

232-
async def execute(self, task: str) -> GraphResult:
233-
"""Execute the graph."""
234+
def execute(self, task: str) -> GraphResult:
235+
"""Execute task synchronously."""
236+
237+
def execute() -> GraphResult:
238+
return asyncio.run(self.execute_async(task))
239+
240+
with ThreadPoolExecutor() as executor:
241+
future = executor.submit(execute)
242+
return future.result()
243+
244+
async def execute_async(self, task: str) -> GraphResult:
245+
"""Execute the graph asynchronously."""
234246
logger.debug("task=<%s> | starting graph execution", task)
235247

236248
# Initialize state
@@ -319,7 +331,7 @@ async def _execute_node(self, node: "GraphNode") -> None:
319331

320332
# Execute based on node type and create unified NodeResult
321333
if isinstance(node.executor, MultiAgentBase):
322-
multi_agent_result = await node.executor.execute(node_input)
334+
multi_agent_result = await node.executor.execute_async(node_input)
323335

324336
# Create NodeResult with MultiAgentResult directly
325337
node_result = NodeResult(
@@ -332,7 +344,13 @@ async def _execute_node(self, node: "GraphNode") -> None:
332344
)
333345

334346
elif isinstance(node.executor, Agent):
335-
agent_response = node.executor(node_input)
347+
agent_response = None # Initialize with None to handle case where no result is yielded
348+
async for event in node.executor.stream_async(node_input):
349+
if "result" in event:
350+
agent_response = event["result"]
351+
352+
if not agent_response:
353+
raise ValueError(f"Node '{node.node_id}' did not return a result")
336354

337355
# Extract metrics from agent response
338356
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)

tests-integ/test_multiagent_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def proceed_to_second_summary(state):
108108
"Calculate 15 + 27 and 8 * 6, analyze both results, perform additional calculations, validate everything, "
109109
"and provide a comprehensive summary"
110110
)
111-
result = await graph.execute(task)
111+
result = await graph.execute_async(task)
112112

113113
# Verify results
114114
assert result.status.value == "completed"

tests/strands/multiagent/test_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,10 @@ class IncompleteMultiAgent(MultiAgentBase):
138138

139139
# Test that complete implementations can be instantiated
140140
class CompleteMultiAgent(MultiAgentBase):
141-
async def execute(self, task: str) -> MultiAgentResult:
141+
async def execute_async(self, task: str) -> MultiAgentResult:
142+
return MultiAgentResult(results={})
143+
144+
def execute(self, task: str) -> MultiAgentResult:
142145
return MultiAgentResult(results={})
143146

144147
# Should not raise an exception

tests/strands/multiagent/test_graph.py

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import AsyncMock, Mock
1+
from unittest.mock import AsyncMock, MagicMock, Mock
22

33
import pytest
44

@@ -25,8 +25,15 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
2525
state={},
2626
metrics=metrics,
2727
)
28+
2829
agent.return_value = mock_result
2930
agent.__call__ = Mock(return_value=mock_result)
31+
32+
async def mock_stream_async(*args, **kwargs):
33+
yield {"result": mock_result}
34+
35+
agent.stream_async = MagicMock(side_effect=mock_stream_async)
36+
3037
return agent
3138

3239

@@ -51,7 +58,8 @@ def create_mock_multi_agent(name, response_text="Multi-agent response"):
5158
execution_count=1,
5259
execution_time=150,
5360
)
54-
multi_agent.execute = AsyncMock(return_value=mock_result)
61+
multi_agent.execute_async = AsyncMock(return_value=mock_result)
62+
multi_agent.execute = Mock(return_value=mock_result)
5563
return multi_agent
5664

5765

@@ -158,7 +166,7 @@ async def test_graph_execution(mock_graph, mock_agents, string_content_agent):
158166
start_node = mock_graph.nodes["start_agent"]
159167
assert conditional_edge.should_traverse(GraphState(completed_nodes={start_node}))
160168

161-
result = await mock_graph.execute("Test comprehensive execution")
169+
result = await mock_graph.execute_async("Test comprehensive execution")
162170

163171
# Verify execution results
164172
assert result.status == Status.COMPLETED
@@ -169,14 +177,14 @@ async def test_graph_execution(mock_graph, mock_agents, string_content_agent):
169177
assert result.execution_order[0].node_id == "start_agent"
170178

171179
# Verify agent calls
172-
mock_agents["start_agent"].assert_called_once()
173-
mock_agents["multi_agent"].execute.assert_called_once()
174-
mock_agents["conditional_agent"].assert_called_once()
175-
mock_agents["final_agent"].assert_called_once()
176-
mock_agents["no_metrics_agent"].assert_called_once()
177-
mock_agents["partial_metrics_agent"].assert_called_once()
178-
string_content_agent.assert_called_once()
179-
mock_agents["blocked_agent"].assert_not_called()
180+
mock_agents["start_agent"].stream_async.assert_called_once()
181+
mock_agents["multi_agent"].execute_async.assert_called_once()
182+
mock_agents["conditional_agent"].stream_async.assert_called_once()
183+
mock_agents["final_agent"].stream_async.assert_called_once()
184+
mock_agents["no_metrics_agent"].stream_async.assert_called_once()
185+
mock_agents["partial_metrics_agent"].stream_async.assert_called_once()
186+
string_content_agent.stream_async.assert_called_once()
187+
mock_agents["blocked_agent"].stream_async.assert_not_called()
180188

181189
# Verify metrics aggregation
182190
assert result.accumulated_usage["totalTokens"] > 0
@@ -219,7 +227,7 @@ class UnsupportedExecutor:
219227
graph = builder.build()
220228

221229
with pytest.raises(ValueError, match="Node 'unsupported_node' of type.*is not supported"):
222-
await graph.execute("test task")
230+
await graph.execute_async("test task")
223231

224232

225233
@pytest.mark.asyncio
@@ -228,9 +236,15 @@ async def test_graph_execution_with_failures():
228236
failing_agent = Mock(spec=Agent)
229237
failing_agent.name = "failing_agent"
230238
failing_agent.id = "fail_node"
231-
failing_agent.side_effect = Exception("Simulated failure")
232239
failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure"))
233240

241+
# Create a proper failing async generator for stream_async
242+
async def mock_stream_failure(*args, **kwargs):
243+
raise Exception("Simulated failure")
244+
yield # This will never be reached
245+
246+
failing_agent.stream_async = mock_stream_failure
247+
234248
success_agent = create_mock_agent("success_agent", "Success")
235249

236250
builder = GraphBuilder()
@@ -242,7 +256,7 @@ async def test_graph_execution_with_failures():
242256
graph = builder.build()
243257

244258
with pytest.raises(Exception, match="Simulated failure"):
245-
await graph.execute("Test error handling")
259+
await graph.execute_async("Test error handling")
246260

247261
assert graph.state.status == Status.FAILED
248262
assert any(node.node_id == "fail_node" for node in graph.state.failed_nodes)
@@ -259,10 +273,10 @@ async def test_graph_edge_cases():
259273
builder.add_node(entry_agent, "entry_only")
260274
graph = builder.build()
261275

262-
result = await graph.execute("Original task")
276+
result = await graph.execute_async("Original task")
263277

264278
# Verify entry node was called with original task
265-
entry_agent.assert_called_once_with("Original task")
279+
entry_agent.stream_async.assert_called_once_with("Original task")
266280
assert result.status == Status.COMPLETED
267281

268282

@@ -415,6 +429,37 @@ def test_condition(state):
415429
assert len(node.dependencies) == 0
416430

417431

432+
def test_graph_synchronous_execution(mock_agents):
433+
"""Test synchronous graph execution using execute method."""
434+
builder = GraphBuilder()
435+
builder.add_node(mock_agents["start_agent"], "start_agent")
436+
builder.add_node(mock_agents["final_agent"], "final_agent")
437+
builder.add_edge("start_agent", "final_agent")
438+
builder.set_entry_point("start_agent")
439+
440+
graph = builder.build()
441+
442+
# Test synchronous execution
443+
result = graph.execute("Test synchronous execution")
444+
445+
# Verify execution results
446+
assert result.status == Status.COMPLETED
447+
assert result.total_nodes == 2
448+
assert result.completed_nodes == 2
449+
assert result.failed_nodes == 0
450+
assert len(result.execution_order) == 2
451+
assert result.execution_order[0].node_id == "start_agent"
452+
assert result.execution_order[1].node_id == "final_agent"
453+
454+
# Verify agent calls
455+
mock_agents["start_agent"].stream_async.assert_called_once()
456+
mock_agents["final_agent"].stream_async.assert_called_once()
457+
458+
# Verify return type is GraphResult
459+
assert isinstance(result, GraphResult)
460+
assert isinstance(result, MultiAgentResult)
461+
462+
418463
def test_graph_result_str_representation():
419464
"""Test GraphResult string representation."""
420465
mock_agent = create_mock_agent("test_agent")

0 commit comments

Comments
 (0)