diff --git a/src/fastmcp/client/client.py b/src/fastmcp/client/client.py index a0f12139c5..4b8488280f 100644 --- a/src/fastmcp/client/client.py +++ b/src/fastmcp/client/client.py @@ -6,7 +6,8 @@ import secrets import uuid import weakref -from contextlib import AsyncExitStack, asynccontextmanager +from collections.abc import Coroutine +from contextlib import AsyncExitStack, asynccontextmanager, suppress from dataclasses import dataclass, field from pathlib import Path from typing import Any, Generic, Literal, TypeVar, cast, overload @@ -94,6 +95,7 @@ logger = get_logger(__name__) T = TypeVar("T", bound="ClientTransport") +ResultT = TypeVar("ResultT") @dataclass @@ -655,6 +657,69 @@ async def _session_runner(self): # Ensure ready event is set even if context manager entry fails self._session_state.ready_event.set() + async def _await_with_session_monitoring( + self, coro: Coroutine[Any, Any, ResultT] + ) -> ResultT: + """Await a coroutine while monitoring the session task for errors. + + When using HTTP transports, server errors (4xx/5xx) are raised in the + background session task, not in the coroutine waiting for a response. + This causes the client to hang indefinitely since the response never + arrives. This method monitors the session task and propagates any + exceptions that occur, preventing the client from hanging. + + Args: + coro: The coroutine to await (typically a session method call) + + Returns: + The result of the coroutine + + Raises: + The exception from the session task if it fails, or RuntimeError + if the session task completes unexpectedly without an exception. + """ + session_task = self._session_state.session_task + + # If no session task, just await directly + if session_task is None: + return await coro + + # If session task already failed, raise immediately + if session_task.done(): + exc = session_task.exception() + if exc: + raise exc + raise RuntimeError("Session task completed unexpectedly") + + # Create task for our call + call_task = asyncio.create_task(coro) + + try: + done, _ = await asyncio.wait( + {call_task, session_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + if session_task in done: + # Session task completed (likely errored) before our call finished + call_task.cancel() + with anyio.CancelScope(shield=True), suppress(asyncio.CancelledError): + await call_task + + # Raise the session task exception + exc = session_task.exception() + if exc: + raise exc + raise RuntimeError("Session task completed unexpectedly") + + # Our call completed first - get the result + return call_task.result() + except asyncio.CancelledError: + call_task.cancel() + with anyio.CancelScope(shield=True), suppress(asyncio.CancelledError): + await call_task + raise + def _handle_task_status_notification( self, notification: TaskStatusNotification ) -> None: @@ -685,7 +750,7 @@ async def close(self): async def ping(self) -> bool: """Send a ping request.""" - result = await self.session.send_ping() + result = await self._await_with_session_monitoring(self.session.send_ping()) return isinstance(result, mcp.types.EmptyResult) async def cancel( @@ -719,7 +784,7 @@ async def progress( async def set_logging_level(self, level: mcp.types.LoggingLevel) -> None: """Send a logging/setLevel request.""" - await self.session.set_logging_level(level) + await self._await_with_session_monitoring(self.session.set_logging_level(level)) async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" @@ -740,7 +805,9 @@ async def list_resources_mcp(self) -> mcp.types.ListResourcesResult: """ logger.debug(f"[{self.name}] called list_resources") - result = await self.session.list_resources() + result = await self._await_with_session_monitoring( + self.session.list_resources() + ) return result async def list_resources(self) -> list[mcp.types.Resource]: @@ -771,7 +838,9 @@ async def list_resource_templates_mcp( """ logger.debug(f"[{self.name}] called list_resource_templates") - result = await self.session.list_resource_templates() + result = await self._await_with_session_monitoring( + self.session.list_resource_templates() + ) return result async def list_resource_templates( @@ -822,12 +891,16 @@ async def read_resource_mcp( else None, # SEP-1686: task as direct param (spec-compliant) ) ) - result = await self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=mcp.types.ReadResourceResult, + result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=mcp.types.ReadResourceResult, + ) ) else: - result = await self.session.read_resource(uri) + result = await self._await_with_session_monitoring( + self.session.read_resource(uri) + ) return result @overload @@ -921,9 +994,11 @@ async def _read_resource_as_task( TaskResponseUnion = RootModel[ mcp.types.CreateTaskResult | mcp.types.ReadResourceResult ] - wrapped_result = await self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=TaskResponseUnion, # type: ignore[arg-type] + wrapped_result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=TaskResponseUnion, # type: ignore[arg-type] + ) ) raw_result = wrapped_result.root @@ -974,7 +1049,7 @@ async def list_prompts_mcp(self) -> mcp.types.ListPromptsResult: """ logger.debug(f"[{self.name}] called list_prompts") - result = await self.session.list_prompts() + result = await self._await_with_session_monitoring(self.session.list_prompts()) return result async def list_prompts(self) -> list[mcp.types.Prompt]: @@ -1039,13 +1114,15 @@ async def get_prompt_mcp( else None, # SEP-1686: task as direct param (spec-compliant) ) ) - result = await self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=mcp.types.GetPromptResult, + result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=mcp.types.GetPromptResult, + ) ) else: - result = await self.session.get_prompt( - name=name, arguments=serialized_arguments + result = await self._await_with_session_monitoring( + self.session.get_prompt(name=name, arguments=serialized_arguments) ) return result @@ -1146,9 +1223,11 @@ async def _get_prompt_as_task( TaskResponseUnion = RootModel[ mcp.types.CreateTaskResult | mcp.types.GetPromptResult ] - wrapped_result = await self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=TaskResponseUnion, # type: ignore[arg-type] + wrapped_result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=TaskResponseUnion, # type: ignore[arg-type] + ) ) raw_result = wrapped_result.root @@ -1195,8 +1274,10 @@ async def complete_mcp( """ logger.debug(f"[{self.name}] called complete: {ref}") - result = await self.session.complete( - ref=ref, argument=argument, context_arguments=context_arguments + result = await self._await_with_session_monitoring( + self.session.complete( + ref=ref, argument=argument, context_arguments=context_arguments + ) ) return result @@ -1241,7 +1322,7 @@ async def list_tools_mcp(self) -> mcp.types.ListToolsResult: """ logger.debug(f"[{self.name}] called list_tools") - result = await self.session.list_tools() + result = await self._await_with_session_monitoring(self.session.list_tools()) return result async def list_tools(self) -> list[mcp.types.Tool]: @@ -1296,12 +1377,14 @@ async def call_tool_mcp( if isinstance(timeout, int | float): timeout = datetime.timedelta(seconds=float(timeout)) - result = await self.session.call_tool( - name=name, - arguments=arguments, - read_timeout_seconds=timeout, # ty: ignore[invalid-argument-type] - progress_callback=progress_handler or self._progress_handler, - meta=meta, + result = await self._await_with_session_monitoring( + self.session.call_tool( + name=name, + arguments=arguments, + read_timeout_seconds=timeout, # ty: ignore[invalid-argument-type] + progress_callback=progress_handler or self._progress_handler, + meta=meta, + ) ) return result @@ -1476,9 +1559,11 @@ async def _call_tool_as_task( TaskResponseUnion = RootModel[ mcp.types.CreateTaskResult | mcp.types.CallToolResult ] - wrapped_result = await self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=TaskResponseUnion, # type: ignore[arg-type] + wrapped_result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=TaskResponseUnion, # type: ignore[arg-type] + ) ) raw_result = wrapped_result.root @@ -1516,9 +1601,11 @@ async def get_task_status(self, task_id: str) -> GetTaskResult: McpError: If the request results in a TimeoutError | JSONRPCError """ request = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id)) - return await self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=GetTaskResult, # type: ignore[arg-type] + return await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=GetTaskResult, # type: ignore[arg-type] + ) ) async def get_task_result(self, task_id: str) -> Any: @@ -1541,9 +1628,11 @@ async def get_task_result(self, task_id: str) -> Any: params=GetTaskPayloadRequestParams(taskId=task_id) ) # Return raw result - Task classes handle type-specific parsing - result = await self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=GetTaskPayloadResult, # type: ignore[arg-type] + result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=GetTaskPayloadResult, # type: ignore[arg-type] + ) ) # Return as dict for compatibility with Task class parsing return result.model_dump(exclude_none=True, by_alias=True) @@ -1575,9 +1664,11 @@ async def list_tasks( # Send protocol request params = PaginatedRequestParams(cursor=cursor, limit=limit) # type: ignore[call-arg] # Optional field in MCP SDK request = ListTasksRequest(params=params) - server_response = await self.session.send_request( - request=request, # type: ignore[invalid-argument-type] - result_type=mcp.types.ListTasksResult, + server_response = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[invalid-argument-type] + result_type=mcp.types.ListTasksResult, + ) ) # If server returned tasks, use those @@ -1613,9 +1704,11 @@ async def cancel_task(self, task_id: str) -> mcp.types.CancelTaskResult: McpError: If the request results in a TimeoutError | JSONRPCError """ request = CancelTaskRequest(params=CancelTaskRequestParams(taskId=task_id)) - return await self.session.send_request( - request=request, # type: ignore[invalid-argument-type] - result_type=mcp.types.CancelTaskResult, + return await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[invalid-argument-type] + result_type=mcp.types.CancelTaskResult, + ) ) @classmethod diff --git a/tests/client/test_client.py b/tests/client/test_client.py index bb1f53da59..b0676561a0 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1307,3 +1307,133 @@ async def test_manual_initialize_can_call_tools(self, fastmcp_server): # Should be able to call tools after manual initialization result = await client.call_tool("greet", {"name": "World"}) assert "Hello, World!" in str(result.content) + + +class TestSessionTaskErrorPropagation: + """Tests for ensuring session task errors propagate to client calls. + + Regression tests for https://github.com/jlowin/fastmcp/issues/2595 + where the client would hang indefinitely when the session task failed + (e.g., due to HTTP 4xx/5xx errors) instead of raising an exception. + """ + + async def test_session_task_error_propagates_to_call(self, fastmcp_server): + """Test that errors in session task propagate to pending client calls. + + When the session task fails (e.g., due to HTTP errors), pending + client operations should immediately receive the exception rather + than hanging indefinitely. + """ + client = Client(fastmcp_server) + + async with client: + original_task = client._session_state.session_task + assert original_task is not None + + async def never_complete(): + """A coroutine that will never complete normally.""" + await asyncio.sleep(1000) + + async def failing_session(): + """Simulates a session task that raises an error.""" + raise ValueError("Simulated HTTP error") + + # Replace session_task with one that will fail + client._session_state.session_task = asyncio.create_task(failing_session()) + + # The monitoring should detect the session task failure + with pytest.raises(ValueError, match="Simulated HTTP error"): + await client._await_with_session_monitoring(never_complete()) + + # Restore original task for cleanup + client._session_state.session_task = original_task + + async def test_session_task_already_done_with_error(self, fastmcp_server): + """Test that if session task is already done with error, calls fail immediately.""" + client = Client(fastmcp_server) + + async with client: + original_task = client._session_state.session_task + + async def raise_error(): + raise ValueError("Session failed") + + # Replace session_task with one that has already failed + failed_task = asyncio.create_task(raise_error()) + try: + await failed_task + except ValueError: + pass # Expected + client._session_state.session_task = failed_task + + # New calls should fail immediately with the original error + async def simple_coro(): + return "should not reach" + + with pytest.raises(ValueError, match="Session failed"): + await client._await_with_session_monitoring(simple_coro()) + + # Restore original task for cleanup + client._session_state.session_task = original_task + + async def test_session_task_already_done_no_error_raises_runtime_error( + self, fastmcp_server + ): + """Test that if session task completes without error, raises RuntimeError.""" + client = Client(fastmcp_server) + + async with client: + original_task = client._session_state.session_task + + # Create a task that completes normally (unexpected for session task) + completed_task = asyncio.create_task(asyncio.sleep(0)) + await completed_task + client._session_state.session_task = completed_task + + async def simple_coro(): + return "should not reach" + + with pytest.raises( + RuntimeError, match="Session task completed unexpectedly" + ): + await client._await_with_session_monitoring(simple_coro()) + + # Restore original task for cleanup + client._session_state.session_task = original_task + + async def test_normal_operation_unaffected(self, fastmcp_server): + """Test that normal operation is unaffected by the monitoring.""" + client = Client(fastmcp_server) + + async with client: + # These should all work normally + tools = await client.list_tools() + assert len(tools) > 0 + + result = await client.call_tool("greet", {"name": "Test"}) + assert "Hello, Test!" in str(result.content) + + resources = await client.list_resources() + assert len(resources) > 0 + + prompts = await client.list_prompts() + assert len(prompts) > 0 + + async def test_no_session_task_falls_back_to_direct_await(self, fastmcp_server): + """Test that when no session task exists, it falls back to direct await.""" + client = Client(fastmcp_server) + + async with client: + # Temporarily remove session_task to test fallback + original_task = client._session_state.session_task + client._session_state.session_task = None + + # Should work via direct await + async def simple_coro(): + return "success" + + result = await client._await_with_session_monitoring(simple_coro()) + assert result == "success" + + # Restore for cleanup + client._session_state.session_task = original_task