Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 66 additions & 40 deletions python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,85 @@
"""Tool handling helpers."""

import logging
from typing import Any
from typing import TYPE_CHECKING, Any

from agent_framework import BaseChatClient, ChatAgent
from agent_framework import BaseChatClient

if TYPE_CHECKING:
from agent_framework import AgentProtocol

logger = logging.getLogger(__name__)


def collect_server_tools(agent: Any) -> list[Any]:
"""Collect server tools from ChatAgent or duck-typed agent."""
if isinstance(agent, ChatAgent):
tools_from_agent = agent.default_options.get("tools")
server_tools = list(tools_from_agent) if tools_from_agent else []
logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
for tool in server_tools:
tool_name = getattr(tool, "name", "unknown")
approval_mode = getattr(tool, "approval_mode", None)
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
return server_tools

try:
default_options_attr = getattr(agent, "default_options", None)
if default_options_attr is not None:
if isinstance(default_options_attr, dict):
return default_options_attr.get("tools") or []
return getattr(default_options_attr, "tools", None) or []
except AttributeError:
def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]:
"""Extract functions from connected MCP tools.

Args:
mcp_tools: List of MCP tool instances.

Returns:
List of functions from connected MCP tools.
"""
functions: list[Any] = []
for mcp_tool in mcp_tools:
if getattr(mcp_tool, "is_connected", False) and hasattr(mcp_tool, "functions"):
functions.extend(mcp_tool.functions)
return functions


def collect_server_tools(agent: "AgentProtocol") -> list[Any]:
"""Collect server tools from an agent.

This includes both regular tools from default_options and MCP tools.
MCP tools are stored separately for lifecycle management but their
functions need to be included for tool execution during approval flows.

Args:
agent: Agent instance to collect tools from. Works with ChatAgent
or any agent with default_options and optional mcp_tools attributes.

Returns:
List of tools including both regular tools and connected MCP tool functions.
"""
# Get tools from default_options
default_options = getattr(agent, "default_options", None)
if default_options is None:
return []
return []

tools_from_agent = default_options.get("tools") if isinstance(default_options, dict) else None
server_tools = list(tools_from_agent) if tools_from_agent else []

# Include functions from connected MCP tools (only available on ChatAgent)
mcp_tools = getattr(agent, "mcp_tools", None)
if mcp_tools:
server_tools.extend(_collect_mcp_tool_functions(mcp_tools))

logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools")
for tool in server_tools:
tool_name = getattr(tool, "name", "unknown")
approval_mode = getattr(tool, "approval_mode", None)
logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}")
return server_tools

def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None:
"""Register client tools as additional declaration-only tools to avoid server execution."""

def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[Any] | None) -> None:
"""Register client tools as additional declaration-only tools to avoid server execution.

Args:
agent: Agent instance to register tools on. Works with ChatAgent
or any agent with a chat_client attribute.
client_tools: List of client tools to register.
"""
if not client_tools:
return

if isinstance(agent, ChatAgent):
chat_client = agent.chat_client
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
chat_client.function_invocation_configuration.additional_tools = client_tools
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")
chat_client = getattr(agent, "chat_client", None)
if chat_client is None:
return

try:
chat_client_attr = getattr(agent, "chat_client", None)
if chat_client_attr is not None:
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
if fic is not None:
fic.additional_tools = client_tools # type: ignore[attr-defined]
logger.debug(
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
)
except AttributeError:
return
if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None:
chat_client.function_invocation_configuration.additional_tools = client_tools
logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)")


def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None:
Expand Down
108 changes: 55 additions & 53 deletions python/packages/ag-ui/tests/test_orchestrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@
"""Tests for AG-UI orchestrators."""

from collections.abc import AsyncGenerator
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock

from agent_framework import AgentResponseUpdate, FunctionInvocationConfiguration, TextContent, ai_function
from agent_framework import (
AgentResponseUpdate,
BaseChatClient,
ChatAgent,
FunctionInvocationConfiguration,
TextContent,
ai_function,
)

from agent_framework_ag_ui._agent import AgentConfig
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext
Expand All @@ -18,56 +25,53 @@ def server_tool() -> str:
return "server"


class DummyAgent:
"""Minimal agent stub to capture run_stream parameters."""

def __init__(self) -> None:
self.default_options: dict[str, Any] = {"tools": [server_tool], "response_format": None}
self.tools = [server_tool]
self.chat_client = SimpleNamespace(
function_invocation_configuration=FunctionInvocationConfiguration(),
)
self.seen_tools: list[Any] | None = None
def _create_mock_chat_agent(
tools: list[Any] | None = None,
response_format: Any = None,
capture_tools: list[Any] | None = None,
capture_messages: list[Any] | None = None,
) -> ChatAgent:
"""Create a ChatAgent with mocked chat client for testing.

Args:
tools: Tools to configure on the agent.
response_format: Response format to configure.
capture_tools: If provided, tools passed to run_stream will be appended here.
capture_messages: If provided, messages passed to run_stream will be appended here.
"""
mock_chat_client = MagicMock(spec=BaseChatClient)
mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration()

agent = ChatAgent(
chat_client=mock_chat_client,
tools=tools or [server_tool],
response_format=response_format,
)

async def run_stream(
self,
# Create a mock run_stream that captures parameters and yields a simple response
async def mock_run_stream(
messages: list[Any],
*,
thread: Any,
thread: Any = None,
tools: list[Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[AgentResponseUpdate, None]:
self.seen_tools = tools
if capture_tools is not None and tools is not None:
capture_tools.extend(tools)
if capture_messages is not None:
capture_messages.extend(messages)
yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant")

# Patch the run_stream method
agent.run_stream = mock_run_stream # type: ignore[method-assign]

class RecordingAgent:
"""Agent stub that captures messages passed to run_stream."""

def __init__(self) -> None:
self.chat_options = SimpleNamespace(tools=[], response_format=None)
self.tools: list[Any] = []
self.chat_client = SimpleNamespace(
function_invocation_configuration=FunctionInvocationConfiguration(),
)
self.seen_messages: list[Any] | None = None

async def run_stream(
self,
messages: list[Any],
*,
thread: Any,
tools: list[Any] | None = None,
**kwargs: Any,
) -> AsyncGenerator[AgentResponseUpdate, None]:
self.seen_messages = messages
yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
return agent


async def test_default_orchestrator_merges_client_tools() -> None:
"""Client tool declarations are merged with server tools before running agent."""

agent = DummyAgent()
captured_tools: list[Any] = []
agent = _create_mock_chat_agent(tools=[server_tool], capture_tools=captured_tools)
orchestrator = DefaultOrchestrator()

input_data = {
Expand Down Expand Up @@ -100,17 +104,16 @@ async def test_default_orchestrator_merges_client_tools() -> None:
async for event in orchestrator.run(context):
events.append(event)

assert agent.seen_tools is not None
tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools]
assert len(captured_tools) > 0
tool_names = [getattr(tool, "name", "?") for tool in captured_tools]
assert "server_tool" in tool_names
assert "get_weather" in tool_names
assert agent.chat_client.function_invocation_configuration.additional_tools


async def test_default_orchestrator_with_camel_case_ids() -> None:
"""Client tool is able to extract camelCase IDs."""

agent = DummyAgent()
agent = _create_mock_chat_agent()
orchestrator = DefaultOrchestrator()

input_data = {
Expand Down Expand Up @@ -143,8 +146,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None:

async def test_default_orchestrator_with_snake_case_ids() -> None:
"""Client tool is able to extract snake_case IDs."""

agent = DummyAgent()
agent = _create_mock_chat_agent()
orchestrator = DefaultOrchestrator()

input_data = {
Expand Down Expand Up @@ -177,8 +179,8 @@ async def test_default_orchestrator_with_snake_case_ids() -> None:

async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
"""State context should be injected when current state differs from tool call args."""

agent = RecordingAgent()
captured_messages: list[Any] = []
agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages)
orchestrator = DefaultOrchestrator()

tool_recipe = {"title": "Salad", "special_preferences": []}
Expand Down Expand Up @@ -215,9 +217,9 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
async for _event in orchestrator.run(context):
pass

assert agent.seen_messages is not None
assert len(captured_messages) > 0
state_messages = []
for msg in agent.seen_messages:
for msg in captured_messages:
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value != "system":
continue
Expand All @@ -230,8 +232,8 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:

async def test_state_context_not_injected_when_tool_call_matches_state() -> None:
"""State context should be skipped when tool call args match current state."""

agent = RecordingAgent()
captured_messages: list[Any] = []
agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages)
orchestrator = DefaultOrchestrator()

input_data = {
Expand Down Expand Up @@ -264,9 +266,9 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None
async for _event in orchestrator.run(context):
pass

assert agent.seen_messages is not None
assert len(captured_messages) > 0
state_messages = []
for msg in agent.seen_messages:
for msg in captured_messages:
role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
if role_value != "system":
continue
Expand Down
Loading
Loading