Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 150 additions & 24 deletions python/packages/core/agent_framework/_mcp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -255,7 +255,7 @@ def __init__(
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, asyncio.Future[None]]] | None = None
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
self.request_timeout = request_timeout
Expand All @@ -265,6 +265,11 @@ def __init__(
self.is_connected: bool = False
self._tools_loaded: bool = False
self._prompts_loaded: bool = False
self._server_capabilities: types.ServerCapabilities | None = None
self._supports_tools: bool = True
self._supports_prompts: bool = True
self._supports_logging: bool | None = None
self._ping_available: bool = True
self._pending_reload_tasks: set[asyncio.Task[None]] = set()

def __str__(self) -> str:
Expand Down Expand Up @@ -566,11 +571,11 @@ async def _run_lifecycle_owner(self) -> None:
stop_error: BaseException | None = None
try:
while True:
action, reset, future = await queue.get()
action, reset, load_configured, future = await queue.get()

try:
if action == "connect":
await self._connect_on_owner(reset=reset)
await self._connect_on_owner(reset=reset, load_configured=load_configured)
elif action == "close":
await self._close_on_owner()
else:
Expand All @@ -595,7 +600,7 @@ async def _run_lifecycle_owner(self) -> None:
finally:
while True:
try:
_, _, future = queue.get_nowait()
_, _, _, future = queue.get_nowait()
except asyncio.QueueEmpty:
break
if not future.done():
Expand All @@ -608,12 +613,18 @@ def _is_lifecycle_owner_task(self) -> bool:
owner_task = self._lifecycle_owner_task
return owner_task is not None and asyncio.current_task() is owner_task

async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) -> None:
async def _run_on_lifecycle_owner(
self,
action: str,
*,
reset: bool = False,
load_configured: bool = True,
) -> None:
await self._ensure_lifecycle_owner()

if self._is_lifecycle_owner_task():
if action == "connect":
await self._connect_on_owner(reset=reset)
await self._connect_on_owner(reset=reset, load_configured=load_configured)
elif action == "close":
await self._close_on_owner()
else:
Expand All @@ -625,7 +636,7 @@ async def _run_on_lifecycle_owner(self, action: str, *, reset: bool = False) ->
raise RuntimeError("MCP lifecycle owner is not available.")

future = asyncio.get_running_loop().create_future()
await queue.put((action, reset, future))
await queue.put((action, reset, load_configured, future))
await future

async def _safe_close_exit_stack(self) -> None:
Expand Down Expand Up @@ -656,6 +667,32 @@ async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
await self._safe_close_exit_stack()
return _should_propagate_cancelled_error(ex)

def _reset_session_state(self) -> None:
    """Return the per-session capability and ping flags to their initial values.

    The values match the defaults assigned in ``__init__``: no recorded
    server capabilities, tools and prompts assumed supported, logging
    support unknown (``None``), and ping assumed available.
    """
    # Nothing is known about the server until capabilities are recorded again.
    self._server_capabilities = self._supports_logging = None
    # Optimistic defaults; they are narrowed once capabilities are set.
    self._supports_tools = self._supports_prompts = self._ping_available = True

def _set_server_capabilities(self, capabilities: types.ServerCapabilities | None) -> None:
    """Record the server's capabilities and derive the feature-support flags.

    A feature (``tools``, ``prompts``, ``logging``) counts as supported only
    when the matching attribute exists on ``capabilities`` and is not
    ``None``. Passing ``capabilities=None`` therefore marks all three
    features as unsupported.

    Args:
        capabilities: The capabilities object to store, or ``None`` when no
            capabilities are available.
    """
    self._server_capabilities = capabilities
    # ``getattr`` with a ``None`` default handles both a ``None`` argument and
    # a capabilities object that lacks the attribute, so no separate branch
    # is needed for the "no capabilities" case.
    self._supports_tools = getattr(capabilities, "tools", None) is not None
    self._supports_prompts = getattr(capabilities, "prompts", None) is not None
    self._supports_logging = getattr(capabilities, "logging", None) is not None

async def _reconnect_without_loading(self) -> None:
    """Force a fresh connection while skipping the configured tool/prompt load.

    When the caller is already the lifecycle owner task, the reconnect runs
    inline. Otherwise the request is forwarded to the owner through
    ``_run_on_lifecycle_owner``. Both paths use ``reset=True`` and
    ``load_configured=False``.
    """
    if not self._is_lifecycle_owner_task():
        # Not on the owner task: hand the reconnect over to it.
        await self._run_on_lifecycle_owner("connect", reset=True, load_configured=False)
    else:
        # Already on the owner task: reconnect directly.
        await self._connect_on_owner(reset=True, load_configured=False)

async def connect(self, *, reset: bool = False) -> None:
if self._is_lifecycle_owner_task():
await self._connect_on_owner(reset=reset)
Expand All @@ -664,14 +701,15 @@ async def connect(self, *, reset: bool = False) -> None:
async with self._lifecycle_request_lock:
await self._run_on_lifecycle_owner("connect", reset=reset)

async def _connect_on_owner(self, *, reset: bool = False) -> None:
async def _connect_on_owner(self, *, reset: bool = False, load_configured: bool = True) -> None:
"""Connect to the MCP server.

Establishes a connection to the MCP server, initializes the session,
and loads tools and prompts if configured to do so.

Keyword Args:
reset: If True, forces a reconnection even if already connected.
load_configured: If True, loads tools and prompts according to the constructor flags.

Raises:
ToolException: If connection or session initialization fails.
Expand All @@ -680,6 +718,7 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None:
await self._safe_close_exit_stack()
self.session = None
self.is_connected = False
self._reset_session_state()
self._exit_stack = AsyncExitStack()
if not self.session:
try:
Expand Down Expand Up @@ -741,7 +780,8 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None:
inner_exception=ex if isinstance(ex, Exception) else None,
) from ex
try:
await session.initialize()
initialize_result = await session.initialize()
self._set_server_capabilities(getattr(initialize_result, "capabilities", None))
except (Exception, asyncio.CancelledError) as ex:
if await self._close_and_check_cancelled(ex):
raise
Expand All @@ -759,17 +799,22 @@ async def _connect_on_owner(self, *, reset: bool = False) -> None:
self.session = session
elif self.session._request_id == 0: # type: ignore[attr-defined]
# If the session is not initialized, we need to reinitialize it
await self.session.initialize()
initialize_result = await self.session.initialize()
self._set_server_capabilities(getattr(initialize_result, "capabilities", None))
elif self._server_capabilities is None:
self._set_server_capabilities(getattr(self.session, "_server_capabilities", None))
logger.debug("Connected to MCP server: %s", self.session)
self.is_connected = True
if self.load_tools_flag:
await self.load_tools()
if load_configured and self.load_tools_flag:
if self._supports_tools:
await self.load_tools()
self._tools_loaded = True
if self.load_prompts_flag:
await self.load_prompts()
if load_configured and self.load_prompts_flag:
if self._supports_prompts:
await self.load_prompts()
self._prompts_loaded = True

if logger.level != logging.NOTSET:
if logger.level != logging.NOTSET and self._supports_logging is not False:
try:
level_name = cast(
Any, next(level for level, value in LOG_LEVEL_MAPPING.items() if value == logger.level)
Expand Down Expand Up @@ -973,17 +1018,49 @@ async def load_prompts(self) -> None:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
from anyio import ClosedResourceError
from mcp import types

if not self._supports_prompts:
logger.debug("Skipping MCP prompt loading because the server did not advertise prompts support.")
return

# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}

params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()
prompt_list: types.ListPromptsResult | None = None
for attempt in range(2):
try:
# Ensure connection is still valid before each page request
await self._ensure_connected()
if not self._supports_prompts:
logger.debug(
"Skipping MCP prompt loading because the server did not advertise prompts support."
)
return
prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
break
except ClosedResourceError as cl_ex:
if attempt == 0:
logger.info("MCP connection closed unexpectedly while loading prompts. Reconnecting...")
try:
await self._reconnect_without_loading()
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
continue
logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex)
raise ToolExecutionException(
"Failed to load prompts - connection lost.",
inner_exception=cl_ex,
) from cl_ex

prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr]
if prompt_list is None:
raise ToolExecutionException("Failed to load prompts.")

for prompt in prompt_list.prompts:
normalized_name = _normalize_mcp_name(prompt.name)
Expand All @@ -1010,7 +1087,7 @@ async def load_prompts(self) -> None:
existing_names.add(local_name)

# Check if there are more pages
if not prompt_list or not prompt_list.nextCursor:
if not prompt_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=prompt_list.nextCursor)

Expand All @@ -1023,18 +1100,48 @@ async def load_tools(self) -> None:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
from anyio import ClosedResourceError
from mcp import types

if not self._supports_tools:
logger.debug("Skipping MCP tool loading because the server did not advertise tools support.")
return

# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
self._tool_call_meta_by_name.clear()

params: types.PaginatedRequestParams | None = None
while True:
# Ensure connection is still valid before each page request
await self._ensure_connected()
tool_list: types.ListToolsResult | None = None
for attempt in range(2):
try:
# Ensure connection is still valid before each page request
await self._ensure_connected()
if not self._supports_tools:
logger.debug("Skipping MCP tool loading because the server did not advertise tools support.")
return
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
break
except ClosedResourceError as cl_ex:
if attempt == 0:
logger.info("MCP connection closed unexpectedly while loading tools. Reconnecting...")
try:
await self._reconnect_without_loading()
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
continue
logger.error("MCP connection closed unexpectedly after reconnection: %s", cl_ex)
raise ToolExecutionException(
"Failed to load tools - connection lost.",
inner_exception=cl_ex,
) from cl_ex

tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
if tool_list is None:
raise ToolExecutionException("Failed to load tools.")

for tool in tool_list.tools:
if tool.meta is not None:
Expand Down Expand Up @@ -1083,7 +1190,7 @@ async def _call_tool_with_runtime_kwargs(
existing_names.add(local_name)

# Check if there are more pages
if not tool_list or not tool_list.nextCursor:
if not tool_list.nextCursor:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)

Expand All @@ -1100,6 +1207,7 @@ async def _close_on_owner(self) -> None:
self._exit_stack = AsyncExitStack()
self.session = None
self.is_connected = False
self._reset_session_state()

async def close(self) -> None:
"""Disconnect from the MCP server.
Expand Down Expand Up @@ -1131,12 +1239,30 @@ async def _ensure_connected(self) -> None:
Raises:
ToolExecutionException: If reconnection fails.
"""
from mcp.shared.exceptions import McpError

if not self._ping_available:
return

try:
await self.session.send_ping() # type: ignore[union-attr]
except McpError as mcp_exc:
if mcp_exc.error.code == -32601:
self._ping_available = False
logger.debug("Skipping future MCP pings because the server does not support ping.")
return
logger.info("MCP connection invalid or closed. Reconnecting...")
try:
await self._reconnect_without_loading()
except Exception as ex:
raise ToolExecutionException(
"Failed to establish MCP connection.",
inner_exception=ex,
) from ex
except Exception:
logger.info("MCP connection invalid or closed. Reconnecting...")
try:
await self.connect(reset=True)
await self._reconnect_without_loading()
except Exception as ex:
raise ToolExecutionException(
"Failed to establish MCP connection.",
Expand Down
Loading
Loading