diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 35ccb1d58a..b2942de2a0 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -255,7 +255,7 @@ def __init__( self._exit_stack = AsyncExitStack() self._lifecycle_lock = asyncio.Lock() self._lifecycle_request_lock = asyncio.Lock() - self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None + self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None self._lifecycle_owner_task: asyncio.Task[None] | None = None self.session = session self.request_timeout = request_timeout @@ -265,6 +265,11 @@ def __init__( self.is_connected: bool = False self._tools_loaded: bool = False self._prompts_loaded: bool = False + self._server_capabilities: types.ServerCapabilities | None = None + self._supports_tools: bool = True + self._supports_prompts: bool = True + self._supports_logging: bool | None = None + self._ping_available: bool = True self._pending_reload_tasks: set[asyncio.Task[None]] = set() def __str__(self) -> str: @@ -566,11 +571,11 @@ async def _run_lifecycle_owner(self) -> None: stop_error: BaseException | None = None try: while True: - action, reset, future = await queue.get() + action, reset, load_configured, future = await queue.get() try: if action == "connect": - await self._connect_on_owner(reset=reset) + await self._connect_on_owner(reset=reset, load_configured=load_configured) elif action == "close": await self._close_on_owner() else: @@ -595,7 +600,7 @@ async def _run_lifecycle_owner(self) -> None: finally: while True: try: - _, _, future = queue.get_nowait() + _, _, _, future = queue.get_nowait() except asyncio.QueueEmpty: break if not future.done(): @@ -608,12 +613,18 @@ def _is_lifecycle_owner_task(self) -> bool: owner_task = self._lifecycle_owner_task return owner_task is not None and asyncio.current_task() is owner_task - async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> None: + async def _run_on_lifecycle_owner( + self, + action: str, + *, + reset: bool = False, + load_configured: bool = True, + ) -> None: await self._ensure_lifecycle_owner() if self._is_lifecycle_owner_task(): if action == "connect": - await self._connect_on_owner(reset=reset) + await self._connect_on_owner(reset=reset, load_configured=load_configured) elif action == "close": await self._close_on_owner() else: @@ -625,7 +636,7 @@ async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> raise RuntimeError("MCP lifecycle owner is not available.") future = asyncio.get_running_loop().create_future() - await queue.put((action, reset, future)) + await queue.put((action, reset, load_configured, future)) await future async def _safe_close_exit_stack(self) -> None: @@ -656,6 +667,32 @@ async def _close_and_check_cancelled(self, ex: BaseException) -> bool: await self._safe_close_exit_stack() return _should_propagate_cancelled_error(ex) + def _reset_session_state(self) -> None: + self._server_capabilities = None + self._supports_tools = True + self._supports_prompts = True + self._supports_logging = None + self._ping_available = True + + def _set_server_capabilities(self, capabilities: types.ServerCapabilities | None) -> None: + self._server_capabilities = capabilities + if capabilities is None: + self._supports_tools = False + self._supports_prompts = False + self._supports_logging = False + return + + self._supports_tools = getattr(capabilities, "tools", None) is not None + self._supports_prompts = getattr(capabilities, "prompts", None) is not None + self._supports_logging = getattr(capabilities, "logging", None) is not None + + async def _reconnect_without_loading(self) -> None: + if self._is_lifecycle_owner_task(): + await self._connect_on_owner(reset=True, load_configured=False) + return + + await self._run_on_lifecycle_owner("connect", reset=True, load_configured=False) + async def connect(self, *, reset: bool = False) -> None: if self._is_lifecycle_owner_task(): await self._connect_on_owner(reset=reset) @@ -664,7 +701,7 @@ async def connect(self, *, reset: bool = False) -> None: async with self._lifecycle_request_lock: await self._run_on_lifecycle_owner("connect", reset=reset) - async def _connect_on_owner(self, *, reset: bool = False) -> None: + async def _connect_on_owner(self, *, reset: bool = False, load_configured: bool = True) -> None: """Connect to the MCP server. Establishes a connection to the MCP server, initializes the session, @@ -672,6 +709,7 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None: Keyword Args: reset: If True, forces a reconnection even if already connected. + load_configured: If True, loads tools and prompts according to the constructor flags. Raises: ToolException: If connection or session initialization fails. @@ -680,6 +718,7 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None: await self._safe_close_exit_stack() self.session = None self.is_connected = False + self._reset_session_state() self._exit_stack = AsyncExitStack() if not self.session: try: @@ -741,7 +780,8 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None: inner_exception=ex if isinstance(ex, Exception) else None, ) from ex try: - await session.initialize() + initialize_result = await session.initialize() + self._set_server_capabilities(getattr(initialize_result, "capabilities", None)) except (Exception, asyncio.CancelledError) as ex: if await self._close_and_check_cancelled(ex): raise @@ -759,17 +799,22 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None: self.session = session elif self.session._request_id == 0: # type: ignore[attr-defined] # If the session is not initialized, we need to reinitialize it - await self.session.initialize() + initialize_result = await self.session.initialize() + self._set_server_capabilities(getattr(initialize_result, "capabilities", None)) + elif self._server_capabilities is None: + self._set_server_capabilities(getattr(self.session, "_server_capabilities", None)) logger.debug("Connected to MCP server: %s", self.session) self.is_connected = True - if self.load_tools_flag: - await self.load_tools() + if load_configured and self.load_tools_flag: + if self._supports_tools: + await self.load_tools() self._tools_loaded = True - if self.load_prompts_flag: - await self.load_prompts() + if load_configured and self.load_prompts_flag: + if self._supports_prompts: + await self.load_prompts() self._prompts_loaded = True - if logger.level != logging.NOTSET: + if logger.level != logging.NOTSET and self._supports_logging is not False: try: level_name = cast( Any, next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level) @@ -973,17 +1018,49 @@ async def load_prompts(self) -> None: Raises: ToolExecutionException: If the MCP server is not connected. """ + from anyio import ClosedResourceError from mcp import types + if not self._supports_prompts: + logger.debug("Skipping MCP prompt loading because the server did not advertise prompts support.") + return + # Track existing function names to prevent duplicates existing_names = {func.name for func in self._functions} params: types.PaginatedRequestParams | None = None while True: - # Ensure connection is still valid before each page request - await self._ensure_connected() + prompt_list: types.ListPromptsResult | None = None + for attempt in range(2): + try: + # Ensure connection is still valid before each page request + await self._ensure_connected() + if not self._supports_prompts: + logger.debug( + "Skipping MCP prompt loading because the server did not advertise prompts support." + ) + return + prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr] + break + except ClosedResourceError as cl_ex: + if attempt == 0: + logger.info("MCP connection closed unexpectedly while loading prompts. Reconnecting...") + try: + await self._reconnect_without_loading() + except Exception as reconn_ex: + raise ToolExecutionException( + "Failed to reconnect to MCP server.", + inner_exception=reconn_ex, + ) from reconn_ex + continue + logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex) + raise ToolExecutionException( + "Failed to load prompts - connection lost.", + inner_exception=cl_ex, + ) from cl_ex - prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr] + if prompt_list is None: + raise ToolExecutionException("Failed to load prompts.") for prompt in prompt_list.prompts: normalized_name = _normalize_mcp_name(prompt.name) @@ -1010,7 +1087,7 @@ async def load_prompts(self) -> None: existing_names.add(local_name) # Check if there are more pages - if not prompt_list or not prompt_list.nextCursor: + if not prompt_list.nextCursor: break params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor) @@ -1023,18 +1100,48 @@ async def load_tools(self) -> None: Raises: ToolExecutionException: If the MCP server is not connected. """ + from anyio import ClosedResourceError from mcp import types + if not self._supports_tools: + logger.debug("Skipping MCP tool loading because the server did not advertise tools support.") + return + # Track existing function names to prevent duplicates existing_names = {func.name for func in self._functions} self._tool_call_meta_by_name.clear() params: types.PaginatedRequestParams | None = None while True: - # Ensure connection is still valid before each page request - await self._ensure_connected() + tool_list: types.ListToolsResult | None = None + for attempt in range(2): + try: + # Ensure connection is still valid before each page request + await self._ensure_connected() + if not self._supports_tools: + logger.debug("Skipping MCP tool loading because the server did not advertise tools support.") + return + tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr] + break + except ClosedResourceError as cl_ex: + if attempt == 0: + logger.info("MCP connection closed unexpectedly while loading tools. Reconnecting...") + try: + await self._reconnect_without_loading() + except Exception as reconn_ex: + raise ToolExecutionException( + "Failed to reconnect to MCP server.", + inner_exception=reconn_ex, + ) from reconn_ex + continue + logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex) + raise ToolExecutionException( + "Failed to load tools - connection lost.", + inner_exception=cl_ex, + ) from cl_ex - tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr] + if tool_list is None: + raise ToolExecutionException("Failed to load tools.") for tool in tool_list.tools: if tool.meta is not None: @@ -1083,7 +1190,7 @@ async def _call_tool_with_runtime_kwargs( existing_names.add(local_name) # Check if there are more pages - if not tool_list or not tool_list.nextCursor: + if not tool_list.nextCursor: break params = types.PaginatedRequestParams(cursor=tool_list.nextCursor) @@ -1100,6 +1207,7 @@ async def _close_on_owner(self) -> None: self._exit_stack = AsyncExitStack() self.session = None self.is_connected = False + self._reset_session_state() async def close(self) -> None: """Disconnect from the MCP server. @@ -1131,12 +1239,30 @@ async def _ensure_connected(self) -> None: Raises: ToolExecutionException: If reconnection fails. """ + from mcp.shared.exceptions import McpError + + if not self._ping_available: + return + try: await self.session.send_ping() # type: ignore[union-attr] + except McpError as mcp_exc: + if mcp_exc.error.code == -32601: + self._ping_available = False + logger.debug("Skipping future MCP pings because the server does not support ping.") + return + logger.info("MCP connection invalid or closed. Reconnecting...") + try: + await self._reconnect_without_loading() + except Exception as ex: + raise ToolExecutionException( + "Failed to establish MCP connection.", + inner_exception=ex, + ) from ex except Exception: logger.info("MCP connection invalid or closed. Reconnecting...") try: - await self.connect(reset=True) + await self._reconnect_without_loading() except Exception as ex: raise ToolExecutionException( "Failed to establish MCP connection.", diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 0fc5867d79..817aab2fff 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -4031,14 +4031,102 @@ async def test_connect_reinitializes_existing_session_and_loads_tools_and_prompt assert tool._prompts_loaded is True +async def test_connect_skips_tools_and_prompts_when_server_does_not_advertise_capabilities() -> None: + tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) + tool.is_connected = True + tool.session = Mock() + tool.session._request_id = 0 + tool.session.initialize = AsyncMock( + return_value=types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ServerCapabilities(), + serverInfo=types.Implementation(name="test", version="1.0"), + ) + ) + tool.session.list_tools = AsyncMock() + tool.session.list_prompts = AsyncMock() + tool.session.set_logging_level = AsyncMock() + + with patch.object(logger, "level", logging.INFO): + await tool._connect_on_owner() + + tool.session.initialize.assert_awaited_once() + tool.session.list_tools.assert_not_called() + tool.session.list_prompts.assert_not_called() + tool.session.set_logging_level.assert_not_called() + assert tool.is_connected is True + assert tool._supports_tools is False + assert tool._supports_prompts is False + assert tool._supports_logging is False + assert tool._tools_loaded is True + assert tool._prompts_loaded is True + + +async def test_connect_treats_missing_capabilities_as_unsupported() -> None: + tool = MCPTool(name="test_tool", load_tools=True, load_prompts=True) + tool.is_connected = True + tool.session = Mock() + tool.session._request_id = 0 + tool.session.initialize = AsyncMock(return_value=Mock(capabilities=None)) + tool.session.list_tools = AsyncMock() + tool.session.list_prompts = AsyncMock() + + with patch.object(logger, "level", logging.NOTSET): + await tool._connect_on_owner() + + tool.session.list_tools.assert_not_called() + tool.session.list_prompts.assert_not_called() + assert tool._supports_tools is False + assert tool._supports_prompts is False + assert tool._supports_logging is False + + +async def test_connect_sets_logging_level_when_server_advertises_logging() -> None: + tool = MCPTool(name="test_tool", load_tools=False, load_prompts=False) + tool.is_connected = True + tool.session = Mock() + tool.session._request_id = 0 + tool.session.initialize = AsyncMock( + return_value=types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ServerCapabilities(logging=types.LoggingCapability()), + serverInfo=types.Implementation(name="test", version="1.0"), + ) + ) + tool.session.set_logging_level = AsyncMock() + + with patch.object(logger, "level", logging.INFO): + await tool._connect_on_owner() + + tool.session.set_logging_level.assert_awaited_once_with("info") + assert tool._supports_logging is True + + +async def test_ensure_connected_skips_future_pings_when_ping_is_not_available() -> None: + tool = MCPTool(name="test_tool") + tool.session = Mock( + send_ping=AsyncMock( + side_effect=McpError(types.ErrorData(code=-32601, message="Method 'ping' is not available.")) + ) + ) + + with patch.object(tool, "_reconnect_without_loading", AsyncMock()) as mock_reconnect: + await tool._ensure_connected() + await tool._ensure_connected() + + tool.session.send_ping.assert_awaited_once() + mock_reconnect.assert_not_awaited() + assert tool._ping_available is False + + async def test_ensure_connected_reconnects_on_failed_ping() -> None: tool = MCPTool(name="test_tool") tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed"))) - with patch.object(tool, "connect", AsyncMock()) as mock_connect: + with patch.object(tool, "_reconnect_without_loading", AsyncMock()) as mock_reconnect: await tool._ensure_connected() - mock_connect.assert_awaited_once_with(reset=True) + mock_reconnect.assert_awaited_once_with() async def test_ensure_connected_wraps_reconnect_failure() -> None: @@ -4046,12 +4134,70 @@ async def test_ensure_connected_wraps_reconnect_failure() -> None: tool.session = Mock(send_ping=AsyncMock(side_effect=RuntimeError("closed"))) with ( - patch.object(tool, "connect", AsyncMock(side_effect=RuntimeError("still closed"))), + patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=RuntimeError("still closed"))), pytest.raises(ToolExecutionException, match="Failed to establish MCP connection"), ): await tool._ensure_connected() +async def test_load_tools_reconnects_on_closed_resource_when_ping_is_unavailable() -> None: + from anyio import ClosedResourceError + + tool = MCPTool(name="test_tool", load_tools=True) + tool._ping_available = False + + first_session = Mock() + first_session.list_tools = AsyncMock(side_effect=ClosedResourceError()) + tool.session = first_session + + page = Mock() + page.tools = [] + page.nextCursor = None + + second_session = Mock() + second_session.list_tools = AsyncMock(return_value=page) + + async def reconnect() -> None: + tool.session = second_session + tool._supports_tools = True + + with patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=reconnect)) as mock_reconnect: + await tool.load_tools() + + first_session.list_tools.assert_awaited_once() + mock_reconnect.assert_awaited_once_with() + second_session.list_tools.assert_awaited_once() + + +async def test_load_prompts_reconnects_on_closed_resource_when_ping_is_unavailable() -> None: + from anyio import ClosedResourceError + + tool = MCPTool(name="test_tool", load_prompts=True) + tool._ping_available = False + + first_session = Mock() + first_session.list_prompts = AsyncMock(side_effect=ClosedResourceError()) + tool.session = first_session + + page = Mock() + page.prompts = [] + page.nextCursor = None + + second_session = Mock() + second_session.list_prompts = AsyncMock(return_value=page) + + async def reconnect() -> None: + tool.session = second_session + tool._supports_prompts = True + + with patch.object(tool, "_reconnect_without_loading", AsyncMock(side_effect=reconnect)) as mock_reconnect: + await tool.load_prompts() + + first_session.list_prompts.assert_awaited_once() + mock_reconnect.assert_awaited_once_with() + second_session.list_prompts.assert_awaited_once() + + async def test_mcp_tool_filters_framework_kwargs(): """Test that call_tool filters out framework-specific kwargs before calling MCP session.