Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/strands/multiagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
"""

from .base import MultiAgentBase, MultiAgentResult
from .function_node import FunctionNode
from .graph import GraphBuilder, GraphResult
from .swarm import Swarm, SwarmResult

__all__ = [
"FunctionNode",
"GraphBuilder",
"GraphResult",
"MultiAgentBase",
Expand Down
142 changes: 142 additions & 0 deletions src/strands/multiagent/function_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""FunctionNode implementation for executing deterministic Python functions as graph nodes.

This module provides the FunctionNode class that extends MultiAgentBase to execute
regular Python functions while maintaining compatibility with the existing graph
execution framework, proper error handling, metrics collection, and result formatting.
"""

import logging
import time
from typing import Any, Callable, Union

from opentelemetry import trace as trace_api

from ..agent import AgentResult
from ..telemetry import get_tracer
from ..telemetry.metrics import EventLoopMetrics
from ..types.content import ContentBlock, Message
from ..types.event_loop import Metrics, Usage
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status

logger = logging.getLogger(__name__)


FunctionNodeCallable = Callable[[Union[str, list[ContentBlock]], dict[str, Any] | None], str]


class FunctionNode(MultiAgentBase):
"""Execute deterministic Python functions as graph nodes.

FunctionNode wraps any callable Python function and executes it within the
established multiagent framework, handling input conversion, error management,
metrics collection, and result formatting automatically.

Args:
func: The callable function to wrap and execute
name: Required name for the node
"""

def __init__(self, func: FunctionNodeCallable, name: str):
"""Initialize FunctionNode with a callable function and required name.

Args:
func: The callable function to wrap and execute
name: Required name for the node
"""
self.func = func
self.name = name
self.tracer = get_tracer()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO for later probably: streaming callables 😅 #961

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point let me take a look for the streaming


async def invoke_async(
self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any
) -> MultiAgentResult:
"""Execute the wrapped function and return formatted results.

Args:
task: The task input (string or ContentBlock list) to pass to the function
invocation_state: Additional state/context (preserved for interface compatibility)
**kwargs: Additional keyword arguments (preserved for future extensibility)

Returns:
MultiAgentResult containing the function execution results and metadata
"""
if invocation_state is None:
invocation_state = {}

logger.debug("task=<%s> | starting function node execution", task)
logger.debug("function_name=<%s> | executing function", self.name)

span = self.tracer.start_multiagent_span(task, "function_node")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if we should be using this to record a function node:

def start_multiagent_span(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@Unshure Unshure Oct 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, I flagged this because we are setting a field in this function named gen_ai.agent.name, and this is almost always not an agent. We should probably revisit how we do this in the other multi-agent cases, but not a blocker for this pr

"gen_ai.agent.name": instance,

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh sorry misunderstood the callout.

Looking at it now, honeslty use of gen_ai.agent.name seems inappropriate for graph and swarm too at the top level. It seems like these should all be using gen_ai.operation.name.

So I guess we can

  1. Keep gen_ai.agent.name for consistency
  2. Remove entirely until we need to add back
  3. introduce a start_multiagent_function_node_span

I'm leaning towards 1 for consistency, seems like you are too if I'm not mistaken

Copy link
Member

@Unshure Unshure Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets not block this pr, im fine with keeping it consistent for now. But lets create an issue to track the proper fix of this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with trace_api.use_span(span, end_on_exit=True):
try:
start_time = time.time()
# Execute the wrapped function with proper parameters
function_result = self.func(task, invocation_state, **kwargs)
# Calculate execution time
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds
logger.debug(
"function_result=<%s>, execution_time=<%dms> | function executed successfully",
function_result,
execution_time,
)

# Convert function result to Message
message = Message(role="assistant", content=[ContentBlock(text=str(function_result))])
agent_result = AgentResult(
stop_reason="end_turn", # "Normal completion of the response" - function executed successfully
message=message,
metrics=EventLoopMetrics(),
state={},
)

# Create NodeResult for this function execution
node_result = NodeResult(
result=agent_result, # type is AgentResult
execution_time=execution_time,
status=Status.COMPLETED,
execution_count=1,
)

# Create MultiAgentResult with the NodeResult
multi_agent_result = MultiAgentResult(
status=Status.COMPLETED,
results={self.name: node_result},
execution_count=1,
execution_time=execution_time,
)

logger.debug(
"function_name=<%s>, execution_time=<%dms> | function node completed successfully",
self.name,
execution_time,
)

return multi_agent_result

except Exception as e:
# Calculate execution time even for failed executions
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds

logger.error("function_name=<%s>, error=<%s> | function node failed", self.name, e)

# Create failed NodeResult with exception
node_result = NodeResult(
result=e,
execution_time=execution_time,
status=Status.FAILED,
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
accumulated_metrics=Metrics(latencyMs=execution_time),
execution_count=1,
)

# Create failed MultiAgentResult
multi_agent_result = MultiAgentResult(
status=Status.FAILED,
results={self.name: node_result},
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0),
accumulated_metrics=Metrics(latencyMs=execution_time),
execution_count=1,
execution_time=execution_time,
)

return multi_agent_result
94 changes: 94 additions & 0 deletions tests/strands/multiagent/test_function_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Unit tests for FunctionNode implementation."""

from unittest.mock import Mock, patch

import pytest

from strands.multiagent.base import Status
from strands.multiagent.function_node import FunctionNode
from strands.types.content import ContentBlock


@pytest.fixture
def mock_tracer():
"""Create a mock tracer for testing."""
tracer = Mock()
span = Mock()
span.__enter__ = Mock(return_value=span)
span.__exit__ = Mock(return_value=None)
tracer.start_multiagent_span.return_value = span
return tracer


@pytest.mark.asyncio
async def test_invoke_async_string_input_success(mock_tracer):
"""Test successful function execution with string input."""

def test_function(task, invocation_state=None, **kwargs):
return f"Processed: {task}"

node = FunctionNode(test_function, "string_test")

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async("test input")

assert result.status == Status.COMPLETED
assert "string_test" in result.results
assert result.results["string_test"].status == Status.COMPLETED
assert result.accumulated_usage["inputTokens"] == 0
assert result.accumulated_usage["outputTokens"] == 0


@pytest.mark.asyncio
async def test_invoke_async_content_block_input_success(mock_tracer):
"""Test successful function execution with ContentBlock input."""

def test_function(task, invocation_state=None, **kwargs):
return "ContentBlock processed"

node = FunctionNode(test_function, "content_block_test")
content_blocks = [ContentBlock(text="First block"), ContentBlock(text="Second block")]

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async(content_blocks)

assert result.status == Status.COMPLETED
assert "content_block_test" in result.results
node_result = result.results["content_block_test"]
assert node_result.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_invoke_async_with_kwargs(mock_tracer):
"""Test function execution with additional kwargs."""

def test_function(task, invocation_state=None, **kwargs):
extra_param = kwargs.get("extra_param", "none")
return f"Extra: {extra_param}"

node = FunctionNode(test_function, "kwargs_test")

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async("test", None, extra_param="test_value")

assert result.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_invoke_async_function_exception(mock_tracer):
"""Test proper exception handling when function raises an error."""

def failing_function(task, invocation_state=None, **kwargs):
raise ValueError("Test exception")

node = FunctionNode(failing_function, "exception_test")

with patch.object(node, "tracer", mock_tracer):
result = await node.invoke_async("test input")

assert result.status == Status.FAILED
assert "exception_test" in result.results
node_result = result.results["exception_test"]
assert node_result.status == Status.FAILED
assert isinstance(node_result.result, ValueError)
assert str(node_result.result) == "Test exception"
44 changes: 44 additions & 0 deletions tests_integ/test_multiagent_function_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Integration tests for FunctionNode with multiagent systems."""

import pytest

from strands import Agent
from strands.multiagent.base import Status
from strands.multiagent.function_node import FunctionNode
from strands.multiagent.graph import GraphBuilder

# Global variable to test function execution
test_global_var = None


def set_global_var(task, invocation_state=None, **kwargs):
"""Simple function that sets a global variable."""
global test_global_var
test_global_var = f"Function executed with: {task}"
return "Global variable set"


@pytest.mark.asyncio
async def test_agent_with_function_node():
"""Test graph with agent and function node."""
global test_global_var
test_global_var = None

# Create nodes
agent = Agent()
function_node = FunctionNode(set_global_var, "setter")

# Build graph
builder = GraphBuilder()
builder.add_node(agent, "agent")
builder.add_node(function_node, "setter")
builder.add_edge("agent", "setter")
builder.set_entry_point("agent")
graph = builder.build()

# Execute
result = await graph.invoke_async("Say hello")

# Verify function was called
assert "Function executed with:" in test_global_var
assert result.status == Status.COMPLETED
Loading