@@ -385,18 +385,42 @@ def __init__(
385
385
self .state = GraphState ()
386
386
self .tracer = get_tracer ()
387
387
388
- def __call__ (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
389
- """Invoke the graph synchronously."""
388
+ def __call__ (
389
+ self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
390
+ ) -> GraphResult :
391
+ """Invoke the graph synchronously.
392
+
393
+ Args:
394
+ task: The task to execute
395
+ invocation_state: Additional state/context passed to underlying agents.
396
+ Defaults to None to avoid mutable default argument issues.
397
+ **kwargs: Keyword arguments allowing backward compatible future changes.
398
+ """
399
+ if invocation_state is None :
400
+ invocation_state = {}
390
401
391
402
def execute () -> GraphResult :
392
- return asyncio .run (self .invoke_async (task ))
403
+ return asyncio .run (self .invoke_async (task , invocation_state ))
393
404
394
405
with ThreadPoolExecutor () as executor :
395
406
future = executor .submit (execute )
396
407
return future .result ()
397
408
398
- async def invoke_async (self , task : str | list [ContentBlock ], ** kwargs : Any ) -> GraphResult :
399
- """Invoke the graph asynchronously."""
409
+ async def invoke_async (
410
+ self , task : str | list [ContentBlock ], invocation_state : dict [str , Any ] | None = None , ** kwargs : Any
411
+ ) -> GraphResult :
412
+ """Invoke the graph asynchronously.
413
+
414
+ Args:
415
+ task: The task to execute
416
+ invocation_state: Additional state/context passed to underlying agents.
417
+ Defaults to None to avoid mutable default argument issues - a new empty dict
418
+ is created if None is provided.
419
+ **kwargs: Keyword arguments allowing backward compatible future changes.
420
+ """
421
+ if invocation_state is None :
422
+ invocation_state = {}
423
+
400
424
logger .debug ("task=<%s> | starting graph execution" , task )
401
425
402
426
# Initialize state
@@ -420,7 +444,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G
420
444
self .node_timeout or "None" ,
421
445
)
422
446
423
- await self ._execute_graph ()
447
+ await self ._execute_graph (invocation_state )
424
448
425
449
# Set final status based on execution results
426
450
if self .state .failed_nodes :
@@ -450,7 +474,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
450
474
# Validate Agent-specific constraints for each node
451
475
_validate_node_executor (node .executor )
452
476
453
- async def _execute_graph (self ) -> None :
477
+ async def _execute_graph (self , invocation_state : dict [ str , Any ] ) -> None :
454
478
"""Unified execution flow with conditional routing."""
455
479
ready_nodes = list (self .entry_points )
456
480
@@ -469,7 +493,7 @@ async def _execute_graph(self) -> None:
469
493
ready_nodes .clear ()
470
494
471
495
# Execute current batch of ready nodes concurrently
472
- tasks = [asyncio .create_task (self ._execute_node (node )) for node in current_batch ]
496
+ tasks = [asyncio .create_task (self ._execute_node (node , invocation_state )) for node in current_batch ]
473
497
474
498
for task in tasks :
475
499
await task
@@ -506,7 +530,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
506
530
)
507
531
return False
508
532
509
- async def _execute_node (self , node : GraphNode ) -> None :
533
+ async def _execute_node (self , node : GraphNode , invocation_state : dict [ str , Any ] ) -> None :
510
534
"""Execute a single node with error handling and timeout protection."""
511
535
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
512
536
if self .reset_on_revisit and node in self .state .completed_nodes :
@@ -529,11 +553,11 @@ async def _execute_node(self, node: GraphNode) -> None:
529
553
if isinstance (node .executor , MultiAgentBase ):
530
554
if self .node_timeout is not None :
531
555
multi_agent_result = await asyncio .wait_for (
532
- node .executor .invoke_async (node_input ),
556
+ node .executor .invoke_async (node_input , invocation_state ),
533
557
timeout = self .node_timeout ,
534
558
)
535
559
else :
536
- multi_agent_result = await node .executor .invoke_async (node_input )
560
+ multi_agent_result = await node .executor .invoke_async (node_input , invocation_state )
537
561
538
562
# Create NodeResult with MultiAgentResult directly
539
563
node_result = NodeResult (
@@ -548,11 +572,11 @@ async def _execute_node(self, node: GraphNode) -> None:
548
572
elif isinstance (node .executor , Agent ):
549
573
if self .node_timeout is not None :
550
574
agent_response = await asyncio .wait_for (
551
- node .executor .invoke_async (node_input ),
575
+ node .executor .invoke_async (node_input , ** invocation_state ),
552
576
timeout = self .node_timeout ,
553
577
)
554
578
else :
555
- agent_response = await node .executor .invoke_async (node_input )
579
+ agent_response = await node .executor .invoke_async (node_input , ** invocation_state )
556
580
557
581
# Extract metrics from agent response
558
582
usage = Usage (inputTokens = 0 , outputTokens = 0 , totalTokens = 0 )
0 commit comments