-
Notifications
You must be signed in to change notification settings - Fork 445
feat(multiagent): add FunctionNode to improve DX for deterministic cases #991
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,157 @@ | ||||||
"""FunctionNode implementation for executing deterministic Python functions as graph nodes. | ||||||
|
||||||
This module provides the FunctionNode class that extends MultiAgentBase to execute | ||||||
regular Python functions while maintaining compatibility with the existing graph | ||||||
execution framework, proper error handling, metrics collection, and result formatting. | ||||||
""" | ||||||
|
||||||
import logging | ||||||
import time | ||||||
from typing import Any, Protocol, Union | ||||||
|
||||||
from opentelemetry import trace as trace_api | ||||||
|
||||||
from ..agent import AgentResult | ||||||
from ..telemetry import get_tracer | ||||||
from ..telemetry.metrics import EventLoopMetrics | ||||||
from ..types.content import ContentBlock, Message | ||||||
from ..types.event_loop import Metrics, Usage | ||||||
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status | ||||||
|
||||||
logger = logging.getLogger(__name__) | ||||||
|
||||||
|
||||||
class FunctionNodeCallable(Protocol): | ||||||
"""Protocol for functions that can be executed within FunctionNode.""" | ||||||
|
||||||
def __call__( | ||||||
self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any | ||||||
) -> str | list[ContentBlock] | Message: | ||||||
"""Execute deterministic logic within the multiagent system.""" | ||||||
... | ||||||
|
||||||
|
||||||
class FunctionNode(MultiAgentBase): | ||||||
"""Execute deterministic Python functions as graph nodes. | ||||||
|
||||||
FunctionNode wraps any callable Python function and executes it within the | ||||||
established multiagent framework, handling input conversion, error management, | ||||||
metrics collection, and result formatting automatically. | ||||||
|
||||||
Args: | ||||||
func: The callable function to wrap and execute | ||||||
name: Required name for the node | ||||||
""" | ||||||
|
||||||
def __init__(self, func: FunctionNodeCallable, name: str): | ||||||
"""Initialize FunctionNode with a callable function and required name. | ||||||
|
||||||
Args: | ||||||
func: The callable function to wrap and execute | ||||||
name: Required name for the node | ||||||
""" | ||||||
self.func = func | ||||||
self.name = name | ||||||
self.tracer = get_tracer() | ||||||
|
||||||
async def invoke_async( | ||||||
self, task: Union[str, list[ContentBlock]], invocation_state: dict[str, Any] | None = None, **kwargs: Any | ||||||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
) -> MultiAgentResult: | ||||||
"""Execute the wrapped function and return formatted results. | ||||||
|
||||||
Args: | ||||||
task: The task input (string or ContentBlock list) to pass to the function | ||||||
invocation_state: Additional state/context (preserved for interface compatibility) | ||||||
**kwargs: Additional keyword arguments (preserved for future extensibility) | ||||||
|
||||||
Returns: | ||||||
MultiAgentResult containing the function execution results and metadata | ||||||
""" | ||||||
if invocation_state is None: | ||||||
invocation_state = {} | ||||||
|
||||||
logger.debug("task=<%s> | starting function node execution", task) | ||||||
logger.debug("function_name=<%s> | executing function", self.name) | ||||||
|
||||||
span = self.tracer.start_multiagent_span(task, "function_node") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if we should be using this to record a function node: sdk-python/src/strands/telemetry/tracer.py Line 622 in 776fd93
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? We use it in both swarm and graph with the name of the node? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Huh, I flagged this because we are setting a field in this function named sdk-python/src/strands/telemetry/tracer.py Line 662 in 26862e4
|
||||||
with trace_api.use_span(span, end_on_exit=True): | ||||||
try: | ||||||
start_time = time.time() | ||||||
# Execute the wrapped function with proper parameters | ||||||
function_result = self.func(task, invocation_state, **kwargs) | ||||||
# Calculate execution time | ||||||
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds | ||||||
logger.debug( | ||||||
"function_result=<%s>, execution_time=<%dms> | function executed successfully", | ||||||
function_result, | ||||||
execution_time, | ||||||
) | ||||||
|
||||||
# Convert function result to Message based on type | ||||||
if isinstance(function_result, dict) and "role" in function_result and "content" in function_result: | ||||||
# Already a Message | ||||||
message = function_result | ||||||
elif isinstance(function_result, list): | ||||||
# List of ContentBlocks | ||||||
message = Message(role="assistant", content=function_result) | ||||||
else: | ||||||
# String or other type - convert to string | ||||||
message = Message(role="assistant", content=[ContentBlock(text=str(function_result))]) | ||||||
agent_result = AgentResult( | ||||||
stop_reason="end_turn", # "Normal completion of the response" - function executed successfully | ||||||
message=message, | ||||||
metrics=EventLoopMetrics(), | ||||||
state={}, | ||||||
) | ||||||
|
||||||
# Create NodeResult for this function execution | ||||||
node_result = NodeResult( | ||||||
result=agent_result, # type is AgentResult | ||||||
execution_time=execution_time, | ||||||
status=Status.COMPLETED, | ||||||
execution_count=1, | ||||||
) | ||||||
|
||||||
# Create MultiAgentResult with the NodeResult | ||||||
multi_agent_result = MultiAgentResult( | ||||||
status=Status.COMPLETED, | ||||||
results={self.name: node_result}, | ||||||
execution_count=1, | ||||||
execution_time=execution_time, | ||||||
) | ||||||
|
||||||
logger.debug( | ||||||
"function_name=<%s>, execution_time=<%dms> | function node completed successfully", | ||||||
self.name, | ||||||
execution_time, | ||||||
) | ||||||
|
||||||
return multi_agent_result | ||||||
|
||||||
except Exception as e: | ||||||
# Calculate execution time even for failed executions | ||||||
execution_time = int((time.time() - start_time) * 1000) # Convert to milliseconds | ||||||
|
||||||
logger.error("function_name=<%s>, error=<%s> | function node failed", self.name, e) | ||||||
|
||||||
# Create failed NodeResult with exception | ||||||
node_result = NodeResult( | ||||||
result=e, | ||||||
execution_time=execution_time, | ||||||
status=Status.FAILED, | ||||||
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), | ||||||
accumulated_metrics=Metrics(latencyMs=execution_time), | ||||||
execution_count=1, | ||||||
) | ||||||
|
||||||
# Create failed MultiAgentResult | ||||||
multi_agent_result = MultiAgentResult( | ||||||
status=Status.FAILED, | ||||||
results={self.name: node_result}, | ||||||
accumulated_usage=Usage(inputTokens=0, outputTokens=0, totalTokens=0), | ||||||
accumulated_metrics=Metrics(latencyMs=execution_time), | ||||||
execution_count=1, | ||||||
execution_time=execution_time, | ||||||
) | ||||||
|
||||||
return multi_agent_result |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
"""Unit tests for FunctionNode implementation.""" | ||
|
||
from unittest.mock import Mock, patch | ||
|
||
import pytest | ||
|
||
from strands.multiagent.base import Status | ||
from strands.multiagent.function_node import FunctionNode | ||
from strands.types.content import ContentBlock, Message | ||
|
||
|
||
@pytest.fixture | ||
def mock_tracer(): | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Create a mock tracer for testing.""" | ||
tracer = Mock() | ||
span = Mock() | ||
span.__enter__ = Mock(return_value=span) | ||
span.__exit__ = Mock(return_value=None) | ||
tracer.start_multiagent_span.return_value = span | ||
return tracer | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_string_input_success(mock_tracer): | ||
"""Test successful function execution with string input.""" | ||
|
||
def test_function(task, invocation_state=None, **kwargs): | ||
return f"Processed: {task}" | ||
|
||
node = FunctionNode(test_function, "string_test") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test input") | ||
|
||
assert result.status == Status.COMPLETED | ||
assert "string_test" in result.results | ||
assert result.results["string_test"].status == Status.COMPLETED | ||
assert result.accumulated_usage["inputTokens"] == 0 | ||
assert result.accumulated_usage["outputTokens"] == 0 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_content_block_input_success(mock_tracer): | ||
"""Test successful function execution with ContentBlock input.""" | ||
|
||
def test_function(task, invocation_state=None, **kwargs): | ||
return "ContentBlock processed" | ||
|
||
node = FunctionNode(test_function, "content_block_test") | ||
content_blocks = [ContentBlock(text="First block"), ContentBlock(text="Second block")] | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async(content_blocks) | ||
|
||
assert result.status == Status.COMPLETED | ||
assert "content_block_test" in result.results | ||
node_result = result.results["content_block_test"] | ||
assert node_result.status == Status.COMPLETED | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_with_kwargs(mock_tracer): | ||
"""Test function execution with additional kwargs.""" | ||
|
||
def test_function(task, invocation_state=None, **kwargs): | ||
extra_param = kwargs.get("extra_param", "none") | ||
return f"Extra: {extra_param}" | ||
|
||
node = FunctionNode(test_function, "kwargs_test") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test", None, extra_param="test_value") | ||
|
||
assert result.status == Status.COMPLETED | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_invoke_async_function_exception(mock_tracer): | ||
"""Test proper exception handling when function raises an error.""" | ||
|
||
def failing_function(task, invocation_state=None, **kwargs): | ||
raise ValueError("Test exception") | ||
|
||
node = FunctionNode(failing_function, "exception_test") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test input") | ||
|
||
assert result.status == Status.FAILED | ||
assert "exception_test" in result.results | ||
node_result = result.results["exception_test"] | ||
assert node_result.status == Status.FAILED | ||
assert isinstance(node_result.result, ValueError) | ||
assert str(node_result.result) == "Test exception" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_function_returns_string(mock_tracer): | ||
"""Test function returning string.""" | ||
|
||
def string_function(task, invocation_state=None, **kwargs): | ||
return "Hello World" | ||
|
||
node = FunctionNode(string_function, "string_node") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test") | ||
|
||
agent_result = result.results["string_node"].result | ||
assert agent_result.message["content"][0]["text"] == "Hello World" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_function_returns_content_blocks(mock_tracer): | ||
"""Test function returning list of ContentBlocks.""" | ||
|
||
def content_block_function(task, invocation_state=None, **kwargs): | ||
return [ContentBlock(text="Block 1"), ContentBlock(text="Block 2")] | ||
|
||
node = FunctionNode(content_block_function, "content_node") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test") | ||
|
||
agent_result = result.results["content_node"].result | ||
assert len(agent_result.message["content"]) == 2 | ||
assert agent_result.message["content"][0]["text"] == "Block 1" | ||
assert agent_result.message["content"][1]["text"] == "Block 2" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_function_returns_message(mock_tracer): | ||
"""Test function returning Message.""" | ||
|
||
def message_function(task, invocation_state=None, **kwargs): | ||
return Message(role="user", content=[ContentBlock(text="Custom message")]) | ||
|
||
node = FunctionNode(message_function, "message_node") | ||
|
||
with patch.object(node, "tracer", mock_tracer): | ||
result = await node.invoke_async("test") | ||
|
||
agent_result = result.results["message_node"].result | ||
assert agent_result.message["role"] == "user" | ||
assert agent_result.message["content"][0]["text"] == "Custom message" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Integration tests for FunctionNode with multiagent systems.""" | ||
|
||
import pytest | ||
|
||
from strands import Agent | ||
from strands.multiagent.base import Status | ||
from strands.multiagent.function_node import FunctionNode | ||
from strands.multiagent.graph import GraphBuilder | ||
|
||
# Global variable to test function execution | ||
test_global_var = None | ||
|
||
|
||
def set_global_var(task, invocation_state=None, **kwargs): | ||
"""Simple function that sets a global variable.""" | ||
global test_global_var | ||
test_global_var = f"Function executed with: {task}" | ||
return "Global variable set" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_agent_with_function_node(): | ||
"""Test graph with agent and function node.""" | ||
global test_global_var | ||
test_global_var = None | ||
|
||
# Create nodes | ||
agent = Agent() | ||
function_node = FunctionNode(set_global_var, "setter") | ||
dbschmigelski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Build graph | ||
builder = GraphBuilder() | ||
builder.add_node(agent, "agent") | ||
builder.add_node(function_node, "setter") | ||
builder.add_edge("agent", "setter") | ||
builder.set_entry_point("agent") | ||
graph = builder.build() | ||
|
||
# Execute | ||
result = await graph.invoke_async("Say hello") | ||
|
||
# Verify function was called | ||
assert "Function executed with:" in test_global_var | ||
assert result.status == Status.COMPLETED |
Uh oh!
There was an error while loading. Please reload this page.