diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py index d88152beb5..0f86516448 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py @@ -3,59 +3,85 @@ """Tool handling helpers.""" import logging -from typing import Any +from typing import TYPE_CHECKING, Any -from agent_framework import BaseChatClient, ChatAgent +from agent_framework import BaseChatClient + +if TYPE_CHECKING: + from agent_framework import AgentProtocol logger = logging.getLogger(__name__) -def collect_server_tools(agent: Any) -> list[Any]: - """Collect server tools from ChatAgent or duck-typed agent.""" - if isinstance(agent, ChatAgent): - tools_from_agent = agent.default_options.get("tools") - server_tools = list(tools_from_agent) if tools_from_agent else [] - logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") - for tool in server_tools: - tool_name = getattr(tool, "name", "unknown") - approval_mode = getattr(tool, "approval_mode", None) - logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}") - return server_tools - - try: - default_options_attr = getattr(agent, "default_options", None) - if default_options_attr is not None: - if isinstance(default_options_attr, dict): - return default_options_attr.get("tools") or [] - return getattr(default_options_attr, "tools", None) or [] - except AttributeError: +def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]: + """Extract functions from connected MCP tools. + + Args: + mcp_tools: List of MCP tool instances. + + Returns: + List of functions from connected MCP tools. + """ + functions: list[Any] = [] + for mcp_tool in mcp_tools: + if getattr(mcp_tool, "is_connected", False) and hasattr(mcp_tool, "functions"): + functions.extend(mcp_tool.functions) + return functions + + +def collect_server_tools(agent: "AgentProtocol") -> list[Any]: + """Collect server tools from an agent. + + This includes both regular tools from default_options and MCP tools. + MCP tools are stored separately for lifecycle management but their + functions need to be included for tool execution during approval flows. + + Args: + agent: Agent instance to collect tools from. Works with ChatAgent + or any agent with default_options and optional mcp_tools attributes. + + Returns: + List of tools including both regular tools and connected MCP tool functions. + """ + # Get tools from default_options + default_options = getattr(agent, "default_options", None) + if default_options is None: return [] - return [] + tools_from_agent = default_options.get("tools") if isinstance(default_options, dict) else None + server_tools = list(tools_from_agent) if tools_from_agent else [] + + # Include functions from connected MCP tools (only available on ChatAgent) + mcp_tools = getattr(agent, "mcp_tools", None) + if mcp_tools: + server_tools.extend(_collect_mcp_tool_functions(mcp_tools)) + + logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") + for tool in server_tools: + tool_name = getattr(tool, "name", "unknown") + approval_mode = getattr(tool, "approval_mode", None) + logger.info(f"[TOOLS] - {tool_name}: approval_mode={approval_mode}") + return server_tools -def register_additional_client_tools(agent: Any, client_tools: list[Any] | None) -> None: - """Register client tools as additional declaration-only tools to avoid server execution.""" + +def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[Any] | None) -> None: + """Register client tools as additional declaration-only tools to avoid server execution. + + Args: + agent: Agent instance to register tools on. Works with ChatAgent + or any agent with a chat_client attribute. + client_tools: List of client tools to register. + """ if not client_tools: return - if isinstance(agent, ChatAgent): - chat_client = agent.chat_client - if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None: - chat_client.function_invocation_configuration.additional_tools = client_tools - logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)") + chat_client = getattr(agent, "chat_client", None) + if chat_client is None: return - try: - chat_client_attr = getattr(agent, "chat_client", None) - if chat_client_attr is not None: - fic = getattr(chat_client_attr, "function_invocation_configuration", None) - if fic is not None: - fic.additional_tools = client_tools # type: ignore[attr-defined] - logger.debug( - f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)" - ) - except AttributeError: - return + if isinstance(chat_client, BaseChatClient) and chat_client.function_invocation_configuration is not None: + chat_client.function_invocation_configuration.additional_tools = client_tools + logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)") def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None: diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index dc9420f15a..3e08f4d061 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -3,10 +3,17 @@ """Tests for AG-UI orchestrators.""" from collections.abc import AsyncGenerator -from types import SimpleNamespace from typing import Any +from unittest.mock import MagicMock -from agent_framework import AgentResponseUpdate, FunctionInvocationConfiguration, TextContent, ai_function +from agent_framework import ( + AgentResponseUpdate, + BaseChatClient, + ChatAgent, + FunctionInvocationConfiguration, + TextContent, + ai_function, +) from agent_framework_ag_ui._agent import AgentConfig from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext @@ -18,56 +25,53 @@ def server_tool() -> str: return "server" -class DummyAgent: - """Minimal agent stub to capture run_stream parameters.""" - - def __init__(self) -> None: - self.default_options: dict[str, Any] = {"tools": [server_tool], "response_format": None} - self.tools = [server_tool] - self.chat_client = SimpleNamespace( - function_invocation_configuration=FunctionInvocationConfiguration(), - ) - self.seen_tools: list[Any] | None = None +def _create_mock_chat_agent( + tools: list[Any] | None = None, + response_format: Any = None, + capture_tools: list[Any] | None = None, + capture_messages: list[Any] | None = None, +) -> ChatAgent: + """Create a ChatAgent with mocked chat client for testing. + + Args: + tools: Tools to configure on the agent. + response_format: Response format to configure. + capture_tools: If provided, tools passed to run_stream will be appended here. + capture_messages: If provided, messages passed to run_stream will be appended here. + """ + mock_chat_client = MagicMock(spec=BaseChatClient) + mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration() + + agent = ChatAgent( + chat_client=mock_chat_client, + tools=tools or [server_tool], + response_format=response_format, + ) - async def run_stream( - self, + # Create a mock run_stream that captures parameters and yields a simple response + async def mock_run_stream( messages: list[Any], *, - thread: Any, + thread: Any = None, tools: list[Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[AgentResponseUpdate, None]: - self.seen_tools = tools + if capture_tools is not None and tools is not None: + capture_tools.extend(tools) + if capture_messages is not None: + capture_messages.extend(messages) yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant") + # Patch the run_stream method + agent.run_stream = mock_run_stream # type: ignore[method-assign] -class RecordingAgent: - """Agent stub that captures messages passed to run_stream.""" - - def __init__(self) -> None: - self.chat_options = SimpleNamespace(tools=[], response_format=None) - self.tools: list[Any] = [] - self.chat_client = SimpleNamespace( - function_invocation_configuration=FunctionInvocationConfiguration(), - ) - self.seen_messages: list[Any] | None = None - - async def run_stream( - self, - messages: list[Any], - *, - thread: Any, - tools: list[Any] | None = None, - **kwargs: Any, - ) -> AsyncGenerator[AgentResponseUpdate, None]: - self.seen_messages = messages - yield AgentResponseUpdate(contents=[TextContent(text="ok")], role="assistant") + return agent async def test_default_orchestrator_merges_client_tools() -> None: """Client tool declarations are merged with server tools before running agent.""" - - agent = DummyAgent() + captured_tools: list[Any] = [] + agent = _create_mock_chat_agent(tools=[server_tool], capture_tools=captured_tools) orchestrator = DefaultOrchestrator() input_data = { @@ -100,8 +104,8 @@ async def test_default_orchestrator_merges_client_tools() -> None: async for event in orchestrator.run(context): events.append(event) - assert agent.seen_tools is not None - tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools] + assert len(captured_tools) > 0 + tool_names = [getattr(tool, "name", "?") for tool in captured_tools] assert "server_tool" in tool_names assert "get_weather" in tool_names assert agent.chat_client.function_invocation_configuration.additional_tools @@ -109,8 +113,7 @@ async def test_default_orchestrator_merges_client_tools() -> None: async def test_default_orchestrator_with_camel_case_ids() -> None: """Client tool is able to extract camelCase IDs.""" - - agent = DummyAgent() + agent = _create_mock_chat_agent() orchestrator = DefaultOrchestrator() input_data = { @@ -143,8 +146,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None: async def test_default_orchestrator_with_snake_case_ids() -> None: """Client tool is able to extract snake_case IDs.""" - - agent = DummyAgent() + agent = _create_mock_chat_agent() orchestrator = DefaultOrchestrator() input_data = { @@ -177,8 +179,8 @@ async def test_default_orchestrator_with_snake_case_ids() -> None: async def test_state_context_injected_when_tool_call_state_mismatch() -> None: """State context should be injected when current state differs from tool call args.""" - - agent = RecordingAgent() + captured_messages: list[Any] = [] + agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages) orchestrator = DefaultOrchestrator() tool_recipe = {"title": "Salad", "special_preferences": []} @@ -215,9 +217,9 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None: async for _event in orchestrator.run(context): pass - assert agent.seen_messages is not None + assert len(captured_messages) > 0 state_messages = [] - for msg in agent.seen_messages: + for msg in captured_messages: role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) if role_value != "system": continue @@ -230,8 +232,8 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None: async def test_state_context_not_injected_when_tool_call_matches_state() -> None: """State context should be skipped when tool call args match current state.""" - - agent = RecordingAgent() + captured_messages: list[Any] = [] + agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages) orchestrator = DefaultOrchestrator() input_data = { @@ -264,9 +266,9 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None async for _event in orchestrator.run(context): pass - assert agent.seen_messages is not None + assert len(captured_messages) > 0 state_messages = [] - for msg in agent.seen_messages: + for msg in captured_messages: role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) if role_value != "system": continue diff --git a/python/packages/ag-ui/tests/test_tooling.py b/python/packages/ag-ui/tests/test_tooling.py index b802d654c6..23d82dda90 100644 --- a/python/packages/ag-ui/tests/test_tooling.py +++ b/python/packages/ag-ui/tests/test_tooling.py @@ -1,8 +1,14 @@ # Copyright (c) Microsoft. All rights reserved. -from types import SimpleNamespace +from unittest.mock import MagicMock -from agent_framework_ag_ui._orchestration._tooling import merge_tools, register_additional_client_tools +from agent_framework import ChatAgent, ai_function + +from agent_framework_ag_ui._orchestration._tooling import ( + collect_server_tools, + merge_tools, + register_additional_client_tools, +) class DummyTool: @@ -11,6 +17,30 @@ def __init__(self, name: str) -> None: self.declaration_only = True +class MockMCPTool: + """Mock MCP tool that simulates connected MCP tool with functions.""" + + def __init__(self, functions: list[DummyTool], is_connected: bool = True) -> None: + self.functions = functions + self.is_connected = is_connected + + +@ai_function +def regular_tool() -> str: + """Regular tool for testing.""" + return "result" + + +def _create_chat_agent_with_tool(tool_name: str = "regular_tool") -> ChatAgent: + """Create a ChatAgent with a mocked chat client and a simple tool. + + Note: tool_name parameter is kept for API compatibility but the tool + will always be named 'regular_tool' since ai_function uses the function name. + """ + mock_chat_client = MagicMock() + return ChatAgent(chat_client=mock_chat_client, tools=[regular_tool]) + + def test_merge_tools_filters_duplicates() -> None: server = [DummyTool("a"), DummyTool("b")] client = [DummyTool("b"), DummyTool("c")] @@ -23,14 +53,79 @@ def test_merge_tools_filters_duplicates() -> None: def test_register_additional_client_tools_assigns_when_configured() -> None: - class Fic: - def __init__(self) -> None: - self.additional_tools = None + """register_additional_client_tools should set additional_tools on the chat client.""" + from agent_framework import BaseChatClient, FunctionInvocationConfiguration + + mock_chat_client = MagicMock(spec=BaseChatClient) + mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration() - holder = SimpleNamespace(function_invocation_configuration=Fic()) - agent = SimpleNamespace(chat_client=holder) + agent = ChatAgent(chat_client=mock_chat_client) tools = [DummyTool("x")] register_additional_client_tools(agent, tools) - assert holder.function_invocation_configuration.additional_tools == tools + assert mock_chat_client.function_invocation_configuration.additional_tools == tools + + +def test_collect_server_tools_includes_mcp_tools_when_connected() -> None: + """MCP tool functions should be included when the MCP tool is connected.""" + mcp_function1 = DummyTool("mcp_function_1") + mcp_function2 = DummyTool("mcp_function_2") + mock_mcp = MockMCPTool([mcp_function1, mcp_function2], is_connected=True) + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert "mcp_function_1" in names + assert "mcp_function_2" in names + assert len(tools) == 3 + + +def test_collect_server_tools_excludes_mcp_tools_when_not_connected() -> None: + """MCP tool functions should be excluded when the MCP tool is not connected.""" + mcp_function = DummyTool("mcp_function") + mock_mcp = MockMCPTool([mcp_function], is_connected=False) + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert "mcp_function" not in names + assert len(tools) == 1 + + +def test_collect_server_tools_works_with_no_mcp_tools() -> None: + """collect_server_tools should work when there are no MCP tools.""" + agent = _create_chat_agent_with_tool("regular_tool") + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert len(tools) == 1 + + +def test_collect_server_tools_with_mcp_tools_via_public_property() -> None: + """collect_server_tools should access MCP tools via the public mcp_tools property.""" + mcp_function = DummyTool("mcp_function") + mock_mcp = MockMCPTool([mcp_function], is_connected=True) + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + # Verify the public property works + assert agent.mcp_tools == [mock_mcp] + + tools = collect_server_tools(agent) + + names = [getattr(t, "name", None) for t in tools] + assert "regular_tool" in names + assert "mcp_function" in names + assert len(tools) == 2 diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a2a7e71a64..628ac7fb17 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -678,7 +678,7 @@ def __init__( normalized_tools: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = ( # type:ignore[reportUnknownVariableType] [] if tools_ is None else tools_ if isinstance(tools_, list) else [tools_] # type: ignore[list-item] ) - self._local_mcp_tools = [tool for tool in normalized_tools if isinstance(tool, MCPTool)] + self.mcp_tools: list[MCPTool] = [tool for tool in normalized_tools if isinstance(tool, MCPTool)] agent_tools = [tool for tool in normalized_tools if not isinstance(tool, MCPTool)] # Build chat options dict @@ -720,7 +720,7 @@ async def __aenter__(self) -> "Self": Returns: The ChatAgent instance. """ - for context_manager in chain([self.chat_client], self._local_mcp_tools): + for context_manager in chain([self.chat_client], self.mcp_tools): if isinstance(context_manager, AbstractAsyncContextManager): await self._async_exit_stack.enter_async_context(context_manager) return self @@ -817,7 +817,7 @@ async def run( else: final_tools.append(tool) # type: ignore - for mcp_server in self._local_mcp_tools: + for mcp_server in self.mcp_tools: if not mcp_server.is_connected: await self._async_exit_stack.enter_async_context(mcp_server) final_tools.extend(mcp_server.functions) @@ -944,7 +944,7 @@ async def run_stream( else: final_tools.append(tool) - for mcp_server in self._local_mcp_tools: + for mcp_server in self.mcp_tools: if not mcp_server.is_connected: await self._async_exit_stack.enter_async_context(mcp_server) final_tools.extend(mcp_server.functions) diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 9e9e115bd4..59e158cb0e 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -275,13 +275,13 @@ def _clone_chat_agent(self, agent: ChatAgent) -> ChatAgent: middleware = list(agent.middleware or []) # Reconstruct the original tools list by combining regular tools with MCP tools. - # ChatAgent.__init__ separates MCP tools into _local_mcp_tools during initialization, + # ChatAgent.__init__ separates MCP tools during initialization, # so we need to recombine them here to pass the complete tools list to the constructor. # This makes sure MCP tools are preserved when cloning agents for handoff workflows. tools_from_options = options.get("tools") all_tools = list(tools_from_options) if tools_from_options else [] - if agent._local_mcp_tools: # type: ignore - all_tools.extend(agent._local_mcp_tools) # type: ignore + if agent.mcp_tools: + all_tools.extend(agent.mcp_tools) logit_bias = options.get("logit_bias") metadata = options.get("metadata") diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index e63dd014fe..585036bef9 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -114,10 +114,10 @@ async def _ensure_mcp_connections(self, agent: Any) -> None: Args: agent: Agent object that may have MCP tools """ - if not hasattr(agent, "_local_mcp_tools"): + if not hasattr(agent, "mcp_tools"): return - for mcp_tool in agent._local_mcp_tools: + for mcp_tool in agent.mcp_tools: if not getattr(mcp_tool, "is_connected", False): continue diff --git a/python/packages/devui/agent_framework_devui/_server.py b/python/packages/devui/agent_framework_devui/_server.py index 146db9b33d..6393f23b4a 100644 --- a/python/packages/devui/agent_framework_devui/_server.py +++ b/python/packages/devui/agent_framework_devui/_server.py @@ -248,9 +248,9 @@ async def _cleanup_entities(self) -> None: except Exception as e: logger.warning(f"Error closing credential for {entity_info.id}: {e}") - # Close MCP tools (framework tracks them in _local_mcp_tools) - if entity_obj and hasattr(entity_obj, "_local_mcp_tools"): - for mcp_tool in entity_obj._local_mcp_tools: + # Close MCP tools (framework tracks them in mcp_tools) + if entity_obj and hasattr(entity_obj, "mcp_tools"): + for mcp_tool in entity_obj.mcp_tools: if hasattr(mcp_tool, "close") and callable(mcp_tool.close): try: if inspect.iscoroutinefunction(mcp_tool.close):