1- from unittest .mock import AsyncMock , Mock
1+ from unittest .mock import AsyncMock , MagicMock , Mock
22
33import pytest
44
@@ -25,8 +25,15 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
2525 state = {},
2626 metrics = metrics ,
2727 )
28+
2829 agent .return_value = mock_result
2930 agent .__call__ = Mock (return_value = mock_result )
31+
32+ async def mock_stream_async (* args , ** kwargs ):
33+ yield {"result" : mock_result }
34+
35+ agent .stream_async = MagicMock (side_effect = mock_stream_async )
36+
3037 return agent
3138
3239
@@ -51,7 +58,8 @@ def create_mock_multi_agent(name, response_text="Multi-agent response"):
5158 execution_count = 1 ,
5259 execution_time = 150 ,
5360 )
54- multi_agent .execute = AsyncMock (return_value = mock_result )
61+ multi_agent .execute_async = AsyncMock (return_value = mock_result )
62+ multi_agent .execute = Mock (return_value = mock_result )
5563 return multi_agent
5664
5765
@@ -158,7 +166,7 @@ async def test_graph_execution(mock_graph, mock_agents, string_content_agent):
158166 start_node = mock_graph .nodes ["start_agent" ]
159167 assert conditional_edge .should_traverse (GraphState (completed_nodes = {start_node }))
160168
161- result = await mock_graph .execute ("Test comprehensive execution" )
169+ result = await mock_graph .execute_async ("Test comprehensive execution" )
162170
163171 # Verify execution results
164172 assert result .status == Status .COMPLETED
@@ -169,14 +177,14 @@ async def test_graph_execution(mock_graph, mock_agents, string_content_agent):
169177 assert result .execution_order [0 ].node_id == "start_agent"
170178
171179 # Verify agent calls
172- mock_agents ["start_agent" ].assert_called_once ()
173- mock_agents ["multi_agent" ].execute .assert_called_once ()
174- mock_agents ["conditional_agent" ].assert_called_once ()
175- mock_agents ["final_agent" ].assert_called_once ()
176- mock_agents ["no_metrics_agent" ].assert_called_once ()
177- mock_agents ["partial_metrics_agent" ].assert_called_once ()
178- string_content_agent .assert_called_once ()
179- mock_agents ["blocked_agent" ].assert_not_called ()
180+ mock_agents ["start_agent" ].stream_async . assert_called_once ()
181+ mock_agents ["multi_agent" ].execute_async .assert_called_once ()
182+ mock_agents ["conditional_agent" ].stream_async . assert_called_once ()
183+ mock_agents ["final_agent" ].stream_async . assert_called_once ()
184+ mock_agents ["no_metrics_agent" ].stream_async . assert_called_once ()
185+ mock_agents ["partial_metrics_agent" ].stream_async . assert_called_once ()
186+ string_content_agent .stream_async . assert_called_once ()
187+ mock_agents ["blocked_agent" ].stream_async . assert_not_called ()
180188
181189 # Verify metrics aggregation
182190 assert result .accumulated_usage ["totalTokens" ] > 0
@@ -219,7 +227,7 @@ class UnsupportedExecutor:
219227 graph = builder .build ()
220228
221229 with pytest .raises (ValueError , match = "Node 'unsupported_node' of type.*is not supported" ):
222- await graph .execute ("test task" )
230+ await graph .execute_async ("test task" )
223231
224232
225233@pytest .mark .asyncio
@@ -228,9 +236,15 @@ async def test_graph_execution_with_failures():
228236 failing_agent = Mock (spec = Agent )
229237 failing_agent .name = "failing_agent"
230238 failing_agent .id = "fail_node"
231- failing_agent .side_effect = Exception ("Simulated failure" )
232239 failing_agent .__call__ = Mock (side_effect = Exception ("Simulated failure" ))
233240
241+ # Create a proper failing async generator for stream_async
242+ async def mock_stream_failure (* args , ** kwargs ):
243+ raise Exception ("Simulated failure" )
244+ yield # This will never be reached
245+
246+ failing_agent .stream_async = mock_stream_failure
247+
234248 success_agent = create_mock_agent ("success_agent" , "Success" )
235249
236250 builder = GraphBuilder ()
@@ -242,7 +256,7 @@ async def test_graph_execution_with_failures():
242256 graph = builder .build ()
243257
244258 with pytest .raises (Exception , match = "Simulated failure" ):
245- await graph .execute ("Test error handling" )
259+ await graph .execute_async ("Test error handling" )
246260
247261 assert graph .state .status == Status .FAILED
248262 assert any (node .node_id == "fail_node" for node in graph .state .failed_nodes )
@@ -259,10 +273,10 @@ async def test_graph_edge_cases():
259273 builder .add_node (entry_agent , "entry_only" )
260274 graph = builder .build ()
261275
262- result = await graph .execute ("Original task" )
276+ result = await graph .execute_async ("Original task" )
263277
264278 # Verify entry node was called with original task
265- entry_agent .assert_called_once_with ("Original task" )
279+ entry_agent .stream_async . assert_called_once_with ("Original task" )
266280 assert result .status == Status .COMPLETED
267281
268282
@@ -415,6 +429,37 @@ def test_condition(state):
415429 assert len (node .dependencies ) == 0
416430
417431
432+ def test_graph_synchronous_execution (mock_agents ):
433+ """Test synchronous graph execution using execute method."""
434+ builder = GraphBuilder ()
435+ builder .add_node (mock_agents ["start_agent" ], "start_agent" )
436+ builder .add_node (mock_agents ["final_agent" ], "final_agent" )
437+ builder .add_edge ("start_agent" , "final_agent" )
438+ builder .set_entry_point ("start_agent" )
439+
440+ graph = builder .build ()
441+
442+ # Test synchronous execution
443+ result = graph .execute ("Test synchronous execution" )
444+
445+ # Verify execution results
446+ assert result .status == Status .COMPLETED
447+ assert result .total_nodes == 2
448+ assert result .completed_nodes == 2
449+ assert result .failed_nodes == 0
450+ assert len (result .execution_order ) == 2
451+ assert result .execution_order [0 ].node_id == "start_agent"
452+ assert result .execution_order [1 ].node_id == "final_agent"
453+
454+ # Verify agent calls
455+ mock_agents ["start_agent" ].stream_async .assert_called_once ()
456+ mock_agents ["final_agent" ].stream_async .assert_called_once ()
457+
458+ # Verify return type is GraphResult
459+ assert isinstance (result , GraphResult )
460+ assert isinstance (result , MultiAgentResult )
461+
462+
418463def test_graph_result_str_representation ():
419464 """Test GraphResult string representation."""
420465 mock_agent = create_mock_agent ("test_agent" )
0 commit comments