Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 138 additions & 45 deletions src/fastmcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import secrets
import uuid
import weakref
from contextlib import AsyncExitStack, asynccontextmanager
from collections.abc import Coroutine
from contextlib import AsyncExitStack, asynccontextmanager, suppress
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar, cast, overload
Expand Down Expand Up @@ -94,6 +95,7 @@
logger = get_logger(__name__)

T = TypeVar("T", bound="ClientTransport")
ResultT = TypeVar("ResultT")


@dataclass
Expand Down Expand Up @@ -655,6 +657,69 @@ async def _session_runner(self):
# Ensure ready event is set even if context manager entry fails
self._session_state.ready_event.set()

async def _await_with_session_monitoring(
self, coro: Coroutine[Any, Any, ResultT]
) -> ResultT:
"""Await a coroutine while monitoring the session task for errors.

When using HTTP transports, server errors (4xx/5xx) are raised in the
background session task, not in the coroutine waiting for a response.
This causes the client to hang indefinitely since the response never
arrives. This method monitors the session task and propagates any
exceptions that occur, preventing the client from hanging.

Args:
coro: The coroutine to await (typically a session method call)

Returns:
The result of the coroutine

Raises:
The exception from the session task if it fails, or RuntimeError
if the session task completes unexpectedly without an exception.
"""
session_task = self._session_state.session_task

# If no session task, just await directly
if session_task is None:
return await coro

# If session task already failed, raise immediately
if session_task.done():
exc = session_task.exception()
if exc:
raise exc
raise RuntimeError("Session task completed unexpectedly")

# Create task for our call
call_task = asyncio.create_task(coro)

try:
done, _ = await asyncio.wait(
{call_task, session_task},
return_when=asyncio.FIRST_COMPLETED,
)

if session_task in done:
# Session task completed (likely errored) before our call finished
call_task.cancel()
with anyio.CancelScope(shield=True), suppress(asyncio.CancelledError):
await call_task

# Raise the session task exception
exc = session_task.exception()
if exc:
raise exc
raise RuntimeError("Session task completed unexpectedly")

# Our call completed first - get the result
return call_task.result()
except asyncio.CancelledError:
call_task.cancel()
with anyio.CancelScope(shield=True), suppress(asyncio.CancelledError):
await call_task
raise

def _handle_task_status_notification(
self, notification: TaskStatusNotification
) -> None:
Expand Down Expand Up @@ -685,7 +750,7 @@ async def close(self):

async def ping(self) -> bool:
"""Send a ping request."""
result = await self.session.send_ping()
result = await self._await_with_session_monitoring(self.session.send_ping())
return isinstance(result, mcp.types.EmptyResult)

async def cancel(
Expand Down Expand Up @@ -719,7 +784,7 @@ async def progress(

async def set_logging_level(self, level: mcp.types.LoggingLevel) -> None:
"""Send a logging/setLevel request."""
await self.session.set_logging_level(level)
await self._await_with_session_monitoring(self.session.set_logging_level(level))

async def send_roots_list_changed(self) -> None:
"""Send a roots/list_changed notification."""
Expand All @@ -740,7 +805,9 @@ async def list_resources_mcp(self) -> mcp.types.ListResourcesResult:
"""
logger.debug(f"[{self.name}] called list_resources")

result = await self.session.list_resources()
result = await self._await_with_session_monitoring(
self.session.list_resources()
)
return result

async def list_resources(self) -> list[mcp.types.Resource]:
Expand Down Expand Up @@ -771,7 +838,9 @@ async def list_resource_templates_mcp(
"""
logger.debug(f"[{self.name}] called list_resource_templates")

result = await self.session.list_resource_templates()
result = await self._await_with_session_monitoring(
self.session.list_resource_templates()
)
return result

async def list_resource_templates(
Expand Down Expand Up @@ -822,12 +891,16 @@ async def read_resource_mcp(
else None, # SEP-1686: task as direct param (spec-compliant)
)
)
result = await self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=mcp.types.ReadResourceResult,
result = await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=mcp.types.ReadResourceResult,
)
)
else:
result = await self.session.read_resource(uri)
result = await self._await_with_session_monitoring(
self.session.read_resource(uri)
)
return result

@overload
Expand Down Expand Up @@ -921,9 +994,11 @@ async def _read_resource_as_task(
TaskResponseUnion = RootModel[
mcp.types.CreateTaskResult | mcp.types.ReadResourceResult
]
wrapped_result = await self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=TaskResponseUnion, # type: ignore[arg-type]
wrapped_result = await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=TaskResponseUnion, # type: ignore[arg-type]
)
)
raw_result = wrapped_result.root

Expand Down Expand Up @@ -974,7 +1049,7 @@ async def list_prompts_mcp(self) -> mcp.types.ListPromptsResult:
"""
logger.debug(f"[{self.name}] called list_prompts")

result = await self.session.list_prompts()
result = await self._await_with_session_monitoring(self.session.list_prompts())
return result

async def list_prompts(self) -> list[mcp.types.Prompt]:
Expand Down Expand Up @@ -1039,13 +1114,15 @@ async def get_prompt_mcp(
else None, # SEP-1686: task as direct param (spec-compliant)
)
)
result = await self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=mcp.types.GetPromptResult,
result = await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=mcp.types.GetPromptResult,
)
)
else:
result = await self.session.get_prompt(
name=name, arguments=serialized_arguments
result = await self._await_with_session_monitoring(
self.session.get_prompt(name=name, arguments=serialized_arguments)
)
return result

Expand Down Expand Up @@ -1146,9 +1223,11 @@ async def _get_prompt_as_task(
TaskResponseUnion = RootModel[
mcp.types.CreateTaskResult | mcp.types.GetPromptResult
]
wrapped_result = await self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=TaskResponseUnion, # type: ignore[arg-type]
wrapped_result = await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=TaskResponseUnion, # type: ignore[arg-type]
)
)
raw_result = wrapped_result.root

Expand Down Expand Up @@ -1195,8 +1274,10 @@ async def complete_mcp(
"""
logger.debug(f"[{self.name}] called complete: {ref}")

result = await self.session.complete(
ref=ref, argument=argument, context_arguments=context_arguments
result = await self._await_with_session_monitoring(
self.session.complete(
ref=ref, argument=argument, context_arguments=context_arguments
)
)
return result

Expand Down Expand Up @@ -1241,7 +1322,7 @@ async def list_tools_mcp(self) -> mcp.types.ListToolsResult:
"""
logger.debug(f"[{self.name}] called list_tools")

result = await self.session.list_tools()
result = await self._await_with_session_monitoring(self.session.list_tools())
return result

async def list_tools(self) -> list[mcp.types.Tool]:
Expand Down Expand Up @@ -1296,12 +1377,14 @@ async def call_tool_mcp(
if isinstance(timeout, int | float):
timeout = datetime.timedelta(seconds=float(timeout))

result = await self.session.call_tool(
name=name,
arguments=arguments,
read_timeout_seconds=timeout, # ty: ignore[invalid-argument-type]
progress_callback=progress_handler or self._progress_handler,
meta=meta,
result = await self._await_with_session_monitoring(
self.session.call_tool(
name=name,
arguments=arguments,
read_timeout_seconds=timeout, # ty: ignore[invalid-argument-type]
progress_callback=progress_handler or self._progress_handler,
meta=meta,
)
)
return result

Expand Down Expand Up @@ -1476,9 +1559,11 @@ async def _call_tool_as_task(
TaskResponseUnion = RootModel[
mcp.types.CreateTaskResult | mcp.types.CallToolResult
]
wrapped_result = await self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=TaskResponseUnion, # type: ignore[arg-type]
wrapped_result = await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=TaskResponseUnion, # type: ignore[arg-type]
)
)
raw_result = wrapped_result.root

Expand Down Expand Up @@ -1516,9 +1601,11 @@ async def get_task_status(self, task_id: str) -> GetTaskResult:
McpError: If the request results in a TimeoutError | JSONRPCError
"""
request = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))
return await self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=GetTaskResult, # type: ignore[arg-type]
return await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=GetTaskResult, # type: ignore[arg-type]
)
)

async def get_task_result(self, task_id: str) -> Any:
Expand All @@ -1541,9 +1628,11 @@ async def get_task_result(self, task_id: str) -> Any:
params=GetTaskPayloadRequestParams(taskId=task_id)
)
# Return raw result - Task classes handle type-specific parsing
result = await self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=GetTaskPayloadResult, # type: ignore[arg-type]
result = await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[arg-type]
result_type=GetTaskPayloadResult, # type: ignore[arg-type]
)
)
# Return as dict for compatibility with Task class parsing
return result.model_dump(exclude_none=True, by_alias=True)
Expand Down Expand Up @@ -1575,9 +1664,11 @@ async def list_tasks(
# Send protocol request
params = PaginatedRequestParams(cursor=cursor, limit=limit) # type: ignore[call-arg] # Optional field in MCP SDK
request = ListTasksRequest(params=params)
server_response = await self.session.send_request(
request=request, # type: ignore[invalid-argument-type]
result_type=mcp.types.ListTasksResult,
server_response = await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[invalid-argument-type]
result_type=mcp.types.ListTasksResult,
)
)

# If server returned tasks, use those
Expand Down Expand Up @@ -1613,9 +1704,11 @@ async def cancel_task(self, task_id: str) -> mcp.types.CancelTaskResult:
McpError: If the request results in a TimeoutError | JSONRPCError
"""
request = CancelTaskRequest(params=CancelTaskRequestParams(taskId=task_id))
return await self.session.send_request(
request=request, # type: ignore[invalid-argument-type]
result_type=mcp.types.CancelTaskResult,
return await self._await_with_session_monitoring(
self.session.send_request(
request=request, # type: ignore[invalid-argument-type]
result_type=mcp.types.CancelTaskResult,
)
)

@classmethod
Expand Down
Loading
Loading