Skip to content

Commit 7f58ce9

Browse files
dbschmigelskiUnshureaditya270520
authored
feat(multiagent): allow callers of swarm and graph to pass kwargs to executors (#816)
* feat(multiagent): allow callers of swarm and graph to pass kwargs to executors --------- Co-authored-by: Nick Clegg <[email protected]> Co-authored-by: Aditya Bhushan Sharma <[email protected]>
1 parent 001aa93 commit 7f58ce9

File tree

6 files changed

+181
-29
lines changed

6 files changed

+181
-29
lines changed

src/strands/multiagent/base.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,35 @@ class MultiAgentBase(ABC):
8484
"""
8585

8686
@abstractmethod
87-
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
88-
"""Invoke asynchronously."""
87+
async def invoke_async(
88+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
89+
) -> MultiAgentResult:
90+
"""Invoke asynchronously.
91+
92+
Args:
93+
task: The task to execute
94+
invocation_state: Additional state/context passed to underlying agents.
95+
Defaults to None to avoid mutable default argument issues.
96+
**kwargs: Additional keyword arguments passed to underlying agents.
97+
"""
8998
raise NotImplementedError("invoke_async not implemented")
9099

91-
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
92-
"""Invoke synchronously."""
100+
def __call__(
101+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
102+
) -> MultiAgentResult:
103+
"""Invoke synchronously.
104+
105+
Args:
106+
task: The task to execute
107+
invocation_state: Additional state/context passed to underlying agents.
108+
Defaults to None to avoid mutable default argument issues.
109+
**kwargs: Additional keyword arguments passed to underlying agents.
110+
"""
111+
if invocation_state is None:
112+
invocation_state = {}
93113

94114
def execute() -> MultiAgentResult:
95-
return asyncio.run(self.invoke_async(task, **kwargs))
115+
return asyncio.run(self.invoke_async(task, invocation_state, **kwargs))
96116

97117
with ThreadPoolExecutor() as executor:
98118
future = executor.submit(execute)

src/strands/multiagent/graph.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -385,18 +385,42 @@ def __init__(
385385
self.state = GraphState()
386386
self.tracer = get_tracer()
387387

388-
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
389-
"""Invoke the graph synchronously."""
388+
def __call__(
389+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
390+
) -> GraphResult:
391+
"""Invoke the graph synchronously.
392+
393+
Args:
394+
task: The task to execute
395+
invocation_state: Additional state/context passed to underlying agents.
396+
Defaults to None to avoid mutable default argument issues.
397+
**kwargs: Keyword arguments allowing backward compatible future changes.
398+
"""
399+
if invocation_state is None:
400+
invocation_state = {}
390401

391402
def execute() -> GraphResult:
392-
return asyncio.run(self.invoke_async(task))
403+
return asyncio.run(self.invoke_async(task, invocation_state))
393404

394405
with ThreadPoolExecutor() as executor:
395406
future = executor.submit(execute)
396407
return future.result()
397408

398-
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
399-
"""Invoke the graph asynchronously."""
409+
async def invoke_async(
410+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
411+
) -> GraphResult:
412+
"""Invoke the graph asynchronously.
413+
414+
Args:
415+
task: The task to execute
416+
invocation_state: Additional state/context passed to underlying agents.
417+
Defaults to None to avoid mutable default argument issues - a new empty dict
418+
is created if None is provided.
419+
**kwargs: Keyword arguments allowing backward compatible future changes.
420+
"""
421+
if invocation_state is None:
422+
invocation_state = {}
423+
400424
logger.debug("task=<%s> | starting graph execution", task)
401425

402426
# Initialize state
@@ -420,7 +444,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G
420444
self.node_timeout or "None",
421445
)
422446

423-
await self._execute_graph()
447+
await self._execute_graph(invocation_state)
424448

425449
# Set final status based on execution results
426450
if self.state.failed_nodes:
@@ -450,7 +474,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
450474
# Validate Agent-specific constraints for each node
451475
_validate_node_executor(node.executor)
452476

453-
async def _execute_graph(self) -> None:
477+
async def _execute_graph(self, invocation_state: dict[str, Any]) -> None:
454478
"""Unified execution flow with conditional routing."""
455479
ready_nodes = list(self.entry_points)
456480

@@ -469,7 +493,7 @@ async def _execute_graph(self) -> None:
469493
ready_nodes.clear()
470494

471495
# Execute current batch of ready nodes concurrently
472-
tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch]
496+
tasks = [asyncio.create_task(self._execute_node(node, invocation_state)) for node in current_batch]
473497

474498
for task in tasks:
475499
await task
@@ -506,7 +530,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
506530
)
507531
return False
508532

509-
async def _execute_node(self, node: GraphNode) -> None:
533+
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None:
510534
"""Execute a single node with error handling and timeout protection."""
511535
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
512536
if self.reset_on_revisit and node in self.state.completed_nodes:
@@ -529,11 +553,11 @@ async def _execute_node(self, node: GraphNode) -> None:
529553
if isinstance(node.executor, MultiAgentBase):
530554
if self.node_timeout is not None:
531555
multi_agent_result = await asyncio.wait_for(
532-
node.executor.invoke_async(node_input),
556+
node.executor.invoke_async(node_input, invocation_state),
533557
timeout=self.node_timeout,
534558
)
535559
else:
536-
multi_agent_result = await node.executor.invoke_async(node_input)
560+
multi_agent_result = await node.executor.invoke_async(node_input, invocation_state)
537561

538562
# Create NodeResult with MultiAgentResult directly
539563
node_result = NodeResult(
@@ -548,11 +572,11 @@ async def _execute_node(self, node: GraphNode) -> None:
548572
elif isinstance(node.executor, Agent):
549573
if self.node_timeout is not None:
550574
agent_response = await asyncio.wait_for(
551-
node.executor.invoke_async(node_input),
575+
node.executor.invoke_async(node_input, **invocation_state),
552576
timeout=self.node_timeout,
553577
)
554578
else:
555-
agent_response = await node.executor.invoke_async(node_input)
579+
agent_response = await node.executor.invoke_async(node_input, **invocation_state)
556580

557581
# Extract metrics from agent response
558582
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)

src/strands/multiagent/swarm.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,18 +237,42 @@ def __init__(
237237
self._setup_swarm(nodes)
238238
self._inject_swarm_tools()
239239

240-
def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult:
241-
"""Invoke the swarm synchronously."""
240+
def __call__(
241+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
242+
) -> SwarmResult:
243+
"""Invoke the swarm synchronously.
244+
245+
Args:
246+
task: The task to execute
247+
invocation_state: Additional state/context passed to underlying agents.
248+
Defaults to None to avoid mutable default argument issues.
249+
**kwargs: Keyword arguments allowing backward compatible future changes.
250+
"""
251+
if invocation_state is None:
252+
invocation_state = {}
242253

243254
def execute() -> SwarmResult:
244-
return asyncio.run(self.invoke_async(task))
255+
return asyncio.run(self.invoke_async(task, invocation_state))
245256

246257
with ThreadPoolExecutor() as executor:
247258
future = executor.submit(execute)
248259
return future.result()
249260

250-
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult:
251-
"""Invoke the swarm asynchronously."""
261+
async def invoke_async(
262+
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
263+
) -> SwarmResult:
264+
"""Invoke the swarm asynchronously.
265+
266+
Args:
267+
task: The task to execute
268+
invocation_state: Additional state/context passed to underlying agents.
269+
Defaults to None to avoid mutable default argument issues - a new empty dict
270+
is created if None is provided.
271+
**kwargs: Keyword arguments allowing backward compatible future changes.
272+
"""
273+
if invocation_state is None:
274+
invocation_state = {}
275+
252276
logger.debug("starting swarm execution")
253277

254278
# Initialize swarm state with configuration
@@ -272,7 +296,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S
272296
self.execution_timeout,
273297
)
274298

275-
await self._execute_swarm()
299+
await self._execute_swarm(invocation_state)
276300
except Exception:
277301
logger.exception("swarm execution failed")
278302
self.state.completion_status = Status.FAILED
@@ -483,7 +507,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
483507

484508
return context_text
485509

486-
async def _execute_swarm(self) -> None:
510+
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
487511
"""Shared execution logic used by execute_async."""
488512
try:
489513
# Main execution loop
@@ -522,7 +546,7 @@ async def _execute_swarm(self) -> None:
522546
# TODO: Implement cancellation token to stop _execute_node from continuing
523547
try:
524548
await asyncio.wait_for(
525-
self._execute_node(current_node, self.state.task),
549+
self._execute_node(current_node, self.state.task, invocation_state),
526550
timeout=self.node_timeout,
527551
)
528552

@@ -563,7 +587,9 @@ async def _execute_swarm(self) -> None:
563587
f"{elapsed_time:.2f}",
564588
)
565589

566-
async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -> AgentResult:
590+
async def _execute_node(
591+
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
592+
) -> AgentResult:
567593
"""Execute swarm node."""
568594
start_time = time.time()
569595
node_name = node.node_id
@@ -583,7 +609,8 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -
583609
# Execute node
584610
result = None
585611
node.reset_executor_state()
586-
result = await node.executor.invoke_async(node_input)
612+
# Unpacking since this is the agent class. Other executors should not unpack
613+
result = await node.executor.invoke_async(node_input, **invocation_state)
587614

588615
execution_time = round((time.time() - start_time) * 1000)
589616

tests/strands/multiagent/test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __init__(self):
155155
self.received_task = None
156156
self.received_kwargs = None
157157

158-
async def invoke_async(self, task, **kwargs):
158+
async def invoke_async(self, task, invocation_state, **kwargs):
159159
self.invoke_async_called = True
160160
self.received_task = task
161161
self.received_kwargs = kwargs

tests/strands/multiagent/test_graph.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,3 +1285,55 @@ def multi_loop_condition(state: GraphState) -> bool:
12851285
assert result.status == Status.COMPLETED
12861286
assert len(result.execution_order) >= 2
12871287
assert multi_agent.invoke_async.call_count >= 2
1288+
1289+
1290+
@pytest.mark.asyncio
1291+
async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span):
1292+
"""Test that kwargs are passed through to underlying Agent nodes."""
1293+
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
1294+
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)
1295+
1296+
builder = GraphBuilder()
1297+
builder.add_node(kwargs_agent, "kwargs_node")
1298+
graph = builder.build()
1299+
1300+
test_invocation_state = {"custom_param": "test_value", "another_param": 42}
1301+
result = await graph.invoke_async("Test kwargs passing", test_invocation_state)
1302+
1303+
kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state)
1304+
assert result.status == Status.COMPLETED
1305+
1306+
1307+
@pytest.mark.asyncio
1308+
async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span):
1309+
"""Test that kwargs are passed through to underlying MultiAgentBase nodes."""
1310+
kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs")
1311+
kwargs_multiagent.invoke_async = Mock(side_effect=kwargs_multiagent.invoke_async)
1312+
1313+
builder = GraphBuilder()
1314+
builder.add_node(kwargs_multiagent, "multiagent_node")
1315+
graph = builder.build()
1316+
1317+
test_invocation_state = {"custom_param": "test_value", "another_param": 42}
1318+
result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state)
1319+
1320+
kwargs_multiagent.invoke_async.assert_called_once_with(
1321+
[{"text": "Test kwargs passing to multiagent"}], test_invocation_state
1322+
)
1323+
assert result.status == Status.COMPLETED
1324+
1325+
1326+
def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
1327+
"""Test that kwargs are passed through to underlying nodes in sync execution."""
1328+
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
1329+
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)
1330+
1331+
builder = GraphBuilder()
1332+
builder.add_node(kwargs_agent, "kwargs_node")
1333+
graph = builder.build()
1334+
1335+
test_invocation_state = {"custom_param": "test_value", "another_param": 42}
1336+
result = graph("Test kwargs passing sync", test_invocation_state)
1337+
1338+
kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state)
1339+
assert result.status == Status.COMPLETED

tests/strands/multiagent/test_swarm.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,32 @@ def test_swarm_validate_unsupported_features():
469469

470470
with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"):
471471
Swarm([agent_with_session])
472+
473+
474+
@pytest.mark.asyncio
475+
async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span):
476+
"""Test that kwargs are passed through to underlying agents."""
477+
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
478+
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)
479+
480+
swarm = Swarm(nodes=[kwargs_agent])
481+
482+
test_kwargs = {"custom_param": "test_value", "another_param": 42}
483+
result = await swarm.invoke_async("Test kwargs passing", test_kwargs)
484+
485+
assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
486+
assert result.status == Status.COMPLETED
487+
488+
489+
def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
490+
"""Test that kwargs are passed through to underlying agents in sync execution."""
491+
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
492+
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)
493+
494+
swarm = Swarm(nodes=[kwargs_agent])
495+
496+
test_kwargs = {"custom_param": "test_value", "another_param": 42}
497+
result = swarm("Test kwargs passing sync", test_kwargs)
498+
499+
assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
500+
assert result.status == Status.COMPLETED

0 commit comments

Comments
 (0)