Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions src/strands/multiagent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,35 @@ class MultiAgentBase(ABC):
"""

@abstractmethod
async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
"""Invoke asynchronously."""
async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
"""Invoke asynchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Additional keyword arguments passed to underlying agents.
"""
raise NotImplementedError("invoke_async not implemented")

def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> MultiAgentResult:
"""Invoke synchronously."""
def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
"""Invoke synchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Additional keyword arguments passed to underlying agents.
"""
if invocation_state is None:
invocation_state = {}

def execute() -> MultiAgentResult:
return asyncio.run(self.invoke_async(task, **kwargs))
return asyncio.run(self.invoke_async(task, invocation_state, **kwargs))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
Expand Down
50 changes: 37 additions & 13 deletions src/strands/multiagent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,18 +385,42 @@ def __init__(
self.state = GraphState()
self.tracer = get_tracer()

def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
"""Invoke the graph synchronously."""
def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> GraphResult:
"""Invoke the graph synchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
if invocation_state is None:
invocation_state = {}

def execute() -> GraphResult:
return asyncio.run(self.invoke_async(task))
return asyncio.run(self.invoke_async(task, invocation_state))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()

async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> GraphResult:
"""Invoke the graph asynchronously."""
async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> GraphResult:
"""Invoke the graph asynchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues - a new empty dict
is created if None is provided.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
if invocation_state is None:
invocation_state = {}

logger.debug("task=<%s> | starting graph execution", task)

# Initialize state
Expand All @@ -420,7 +444,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> G
self.node_timeout or "None",
)

await self._execute_graph()
await self._execute_graph(invocation_state)

# Set final status based on execution results
if self.state.failed_nodes:
Expand Down Expand Up @@ -450,7 +474,7 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
# Validate Agent-specific constraints for each node
_validate_node_executor(node.executor)

async def _execute_graph(self) -> None:
async def _execute_graph(self, invocation_state: dict[str, Any]) -> None:
"""Unified execution flow with conditional routing."""
ready_nodes = list(self.entry_points)

Expand All @@ -469,7 +493,7 @@ async def _execute_graph(self) -> None:
ready_nodes.clear()

# Execute current batch of ready nodes concurrently
tasks = [asyncio.create_task(self._execute_node(node)) for node in current_batch]
tasks = [asyncio.create_task(self._execute_node(node, invocation_state)) for node in current_batch]

for task in tasks:
await task
Expand Down Expand Up @@ -506,7 +530,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[
)
return False

async def _execute_node(self, node: GraphNode) -> None:
async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> None:
"""Execute a single node with error handling and timeout protection."""
# Reset the node's state if reset_on_revisit is enabled and it's being revisited
if self.reset_on_revisit and node in self.state.completed_nodes:
Expand All @@ -529,11 +553,11 @@ async def _execute_node(self, node: GraphNode) -> None:
if isinstance(node.executor, MultiAgentBase):
if self.node_timeout is not None:
multi_agent_result = await asyncio.wait_for(
node.executor.invoke_async(node_input),
node.executor.invoke_async(node_input, invocation_state),
timeout=self.node_timeout,
)
else:
multi_agent_result = await node.executor.invoke_async(node_input)
multi_agent_result = await node.executor.invoke_async(node_input, invocation_state)

# Create NodeResult with MultiAgentResult directly
node_result = NodeResult(
Expand All @@ -548,11 +572,11 @@ async def _execute_node(self, node: GraphNode) -> None:
elif isinstance(node.executor, Agent):
if self.node_timeout is not None:
agent_response = await asyncio.wait_for(
node.executor.invoke_async(node_input),
node.executor.invoke_async(node_input, **invocation_state),
timeout=self.node_timeout,
)
else:
agent_response = await node.executor.invoke_async(node_input)
agent_response = await node.executor.invoke_async(node_input, **invocation_state)

# Extract metrics from agent response
usage = Usage(inputTokens=0, outputTokens=0, totalTokens=0)
Expand Down
47 changes: 37 additions & 10 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,18 +237,42 @@ def __init__(
self._setup_swarm(nodes)
self._inject_swarm_tools()

def __call__(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult:
"""Invoke the swarm synchronously."""
def __call__(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> SwarmResult:
"""Invoke the swarm synchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
if invocation_state is None:
invocation_state = {}

def execute() -> SwarmResult:
return asyncio.run(self.invoke_async(task))
return asyncio.run(self.invoke_async(task, invocation_state))

with ThreadPoolExecutor() as executor:
future = executor.submit(execute)
return future.result()

async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> SwarmResult:
"""Invoke the swarm asynchronously."""
async def invoke_async(
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> SwarmResult:
"""Invoke the swarm asynchronously.

Args:
task: The task to execute
invocation_state: Additional state/context passed to underlying agents.
Defaults to None to avoid mutable default argument issues - a new empty dict
is created if None is provided.
**kwargs: Keyword arguments allowing backward compatible future changes.
"""
if invocation_state is None:
invocation_state = {}

logger.debug("starting swarm execution")

# Initialize swarm state with configuration
Expand All @@ -272,7 +296,7 @@ async def invoke_async(self, task: str | list[ContentBlock], **kwargs: Any) -> S
self.execution_timeout,
)

await self._execute_swarm()
await self._execute_swarm(invocation_state)
except Exception:
logger.exception("swarm execution failed")
self.state.completion_status = Status.FAILED
Expand Down Expand Up @@ -483,7 +507,7 @@ def _build_node_input(self, target_node: SwarmNode) -> str:

return context_text

async def _execute_swarm(self) -> None:
async def _execute_swarm(self, invocation_state: dict[str, Any]) -> None:
"""Shared execution logic used by execute_async."""
try:
# Main execution loop
Expand Down Expand Up @@ -522,7 +546,7 @@ async def _execute_swarm(self) -> None:
# TODO: Implement cancellation token to stop _execute_node from continuing
try:
await asyncio.wait_for(
self._execute_node(current_node, self.state.task),
self._execute_node(current_node, self.state.task, invocation_state),
timeout=self.node_timeout,
)

Expand Down Expand Up @@ -563,7 +587,9 @@ async def _execute_swarm(self) -> None:
f"{elapsed_time:.2f}",
)

async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -> AgentResult:
async def _execute_node(
self, node: SwarmNode, task: str | list[ContentBlock], invocation_state: dict[str, Any]
) -> AgentResult:
"""Execute swarm node."""
start_time = time.time()
node_name = node.node_id
Expand All @@ -583,7 +609,8 @@ async def _execute_node(self, node: SwarmNode, task: str | list[ContentBlock]) -
# Execute node
result = None
node.reset_executor_state()
result = await node.executor.invoke_async(node_input)
# Unpacking since this is the agent class. Other executors should not unpack
result = await node.executor.invoke_async(node_input, **invocation_state)

execution_time = round((time.time() - start_time) * 1000)

Expand Down
2 changes: 1 addition & 1 deletion tests/strands/multiagent/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(self):
self.received_task = None
self.received_kwargs = None

async def invoke_async(self, task, **kwargs):
async def invoke_async(self, task, invocation_state, **kwargs):
self.invoke_async_called = True
self.received_task = task
self.received_kwargs = kwargs
Expand Down
52 changes: 52 additions & 0 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,55 @@ def multi_loop_condition(state: GraphState) -> bool:
assert result.status == Status.COMPLETED
assert len(result.execution_order) >= 2
assert multi_agent.invoke_async.call_count >= 2


@pytest.mark.asyncio
async def test_graph_kwargs_passing_agent(mock_strands_tracer, mock_use_span):
"""Test that kwargs are passed through to underlying Agent nodes."""
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)

builder = GraphBuilder()
builder.add_node(kwargs_agent, "kwargs_node")
graph = builder.build()

test_invocation_state = {"custom_param": "test_value", "another_param": 42}
result = await graph.invoke_async("Test kwargs passing", test_invocation_state)

kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing"}], **test_invocation_state)
assert result.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_graph_kwargs_passing_multiagent(mock_strands_tracer, mock_use_span):
"""Test that kwargs are passed through to underlying MultiAgentBase nodes."""
kwargs_multiagent = create_mock_multi_agent("kwargs_multiagent", "MultiAgent response with kwargs")
kwargs_multiagent.invoke_async = Mock(side_effect=kwargs_multiagent.invoke_async)

builder = GraphBuilder()
builder.add_node(kwargs_multiagent, "multiagent_node")
graph = builder.build()

test_invocation_state = {"custom_param": "test_value", "another_param": 42}
result = await graph.invoke_async("Test kwargs passing to multiagent", test_invocation_state)

kwargs_multiagent.invoke_async.assert_called_once_with(
[{"text": "Test kwargs passing to multiagent"}], test_invocation_state
)
assert result.status == Status.COMPLETED


def test_graph_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
"""Test that kwargs are passed through to underlying nodes in sync execution."""
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)

builder = GraphBuilder()
builder.add_node(kwargs_agent, "kwargs_node")
graph = builder.build()

test_invocation_state = {"custom_param": "test_value", "another_param": 42}
result = graph("Test kwargs passing sync", test_invocation_state)

kwargs_agent.invoke_async.assert_called_once_with([{"text": "Test kwargs passing sync"}], **test_invocation_state)
assert result.status == Status.COMPLETED
29 changes: 29 additions & 0 deletions tests/strands/multiagent/test_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,32 @@ def test_swarm_validate_unsupported_features():

with pytest.raises(ValueError, match="Session persistence is not supported for Swarm agents yet"):
Swarm([agent_with_session])


@pytest.mark.asyncio
async def test_swarm_kwargs_passing(mock_strands_tracer, mock_use_span):
"""Test that kwargs are passed through to underlying agents."""
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)

swarm = Swarm(nodes=[kwargs_agent])

test_kwargs = {"custom_param": "test_value", "another_param": 42}
result = await swarm.invoke_async("Test kwargs passing", test_kwargs)

assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
assert result.status == Status.COMPLETED


def test_swarm_kwargs_passing_sync(mock_strands_tracer, mock_use_span):
"""Test that kwargs are passed through to underlying agents in sync execution."""
kwargs_agent = create_mock_agent("kwargs_agent", "Response with kwargs")
kwargs_agent.invoke_async = Mock(side_effect=kwargs_agent.invoke_async)

swarm = Swarm(nodes=[kwargs_agent])

test_kwargs = {"custom_param": "test_value", "another_param": 42}
result = swarm("Test kwargs passing sync", test_kwargs)

assert kwargs_agent.invoke_async.call_args.kwargs == test_kwargs
assert result.status == Status.COMPLETED
Loading