diff --git a/docs/python-sdk/fastmcp-server-providers-proxy.mdx b/docs/python-sdk/fastmcp-server-providers-proxy.mdx index 803adb9735..50c105bcf2 100644 --- a/docs/python-sdk/fastmcp-server-providers-proxy.mdx +++ b/docs/python-sdk/fastmcp-server-providers-proxy.mdx @@ -15,7 +15,7 @@ classes that forward execution to remote servers. ## Functions -### `default_proxy_roots_handler` +### `default_proxy_roots_handler` ```python default_proxy_roots_handler(context: RequestContext[ClientSession, LifespanContextT]) -> RootsList @@ -25,7 +25,7 @@ default_proxy_roots_handler(context: RequestContext[ClientSession, LifespanConte Forward list roots request from remote server to proxy's connected clients. -### `default_proxy_sampling_handler` +### `default_proxy_sampling_handler` ```python default_proxy_sampling_handler(messages: list[mcp.types.SamplingMessage], params: mcp.types.CreateMessageRequestParams, context: RequestContext[ClientSession, LifespanContextT]) -> mcp.types.CreateMessageResult @@ -35,7 +35,7 @@ default_proxy_sampling_handler(messages: list[mcp.types.SamplingMessage], params Forward sampling request from remote server to proxy's connected clients. -### `default_proxy_elicitation_handler` +### `default_proxy_elicitation_handler` ```python default_proxy_elicitation_handler(message: str, response_type: type, params: mcp.types.ElicitRequestParams, context: RequestContext[ClientSession, LifespanContextT]) -> ElicitResult @@ -45,7 +45,7 @@ default_proxy_elicitation_handler(message: str, response_type: type, params: mcp Forward elicitation request from remote server to proxy's connected clients. -### `default_proxy_log_handler` +### `default_proxy_log_handler` ```python default_proxy_log_handler(message: LogMessage) -> None @@ -55,7 +55,7 @@ default_proxy_log_handler(message: LogMessage) -> None Forward log notification from remote server to proxy's connected clients. -### `default_proxy_progress_handler` +### `default_proxy_progress_handler` ```python default_proxy_progress_handler(progress: float, total: float | None, message: str | None) -> None @@ -67,7 +67,7 @@ Forward progress notification from remote server to proxy's connected clients. ## Classes -### `ProxyTool` +### `ProxyTool` A Tool that represents and executes a tool on a remote server. @@ -75,7 +75,7 @@ A Tool that represents and executes a tool on a remote server. **Methods:** -#### `model_copy` +#### `model_copy` ```python model_copy(self, **kwargs: Any) -> ProxyTool @@ -84,7 +84,7 @@ model_copy(self, **kwargs: Any) -> ProxyTool Override to preserve _backend_name when name changes. -#### `from_mcp_tool` +#### `from_mcp_tool` ```python from_mcp_tool(cls, client_factory: ClientFactoryT, mcp_tool: mcp.types.Tool) -> ProxyTool @@ -93,7 +93,7 @@ from_mcp_tool(cls, client_factory: ClientFactoryT, mcp_tool: mcp.types.Tool) -> Factory method to create a ProxyTool from a raw MCP tool schema. -#### `run` +#### `run` ```python run(self, arguments: dict[str, Any], context: Context | None = None) -> ToolResult @@ -102,13 +102,13 @@ run(self, arguments: dict[str, Any], context: Context | None = None) -> ToolResu Executes the tool by making a call through the client. -#### `get_span_attributes` +#### `get_span_attributes` ```python get_span_attributes(self) -> dict[str, Any] ``` -### `ProxyResource` +### `ProxyResource` A Resource that represents and reads a resource from a remote server. @@ -116,7 +116,7 @@ A Resource that represents and reads a resource from a remote server. **Methods:** -#### `model_copy` +#### `model_copy` ```python model_copy(self, **kwargs: Any) -> ProxyResource @@ -125,7 +125,7 @@ model_copy(self, **kwargs: Any) -> ProxyResource Override to preserve _backend_uri when uri changes. -#### `from_mcp_resource` +#### `from_mcp_resource` ```python from_mcp_resource(cls, client_factory: ClientFactoryT, mcp_resource: mcp.types.Resource) -> ProxyResource @@ -134,7 +134,7 @@ from_mcp_resource(cls, client_factory: ClientFactoryT, mcp_resource: mcp.types.R Factory method to create a ProxyResource from a raw MCP resource schema. -#### `read` +#### `read` ```python read(self) -> ResourceResult @@ -143,13 +143,13 @@ read(self) -> ResourceResult Read the resource content from the remote server. -#### `get_span_attributes` +#### `get_span_attributes` ```python get_span_attributes(self) -> dict[str, Any] ``` -### `ProxyTemplate` +### `ProxyTemplate` A ResourceTemplate that represents and creates resources from a remote server template. @@ -157,7 +157,7 @@ A ResourceTemplate that represents and creates resources from a remote server te **Methods:** -#### `model_copy` +#### `model_copy` ```python model_copy(self, **kwargs: Any) -> ProxyTemplate @@ -166,7 +166,7 @@ model_copy(self, **kwargs: Any) -> ProxyTemplate Override to preserve _backend_uri_template when uri_template changes. -#### `from_mcp_template` +#### `from_mcp_template` ```python from_mcp_template(cls, client_factory: ClientFactoryT, mcp_template: mcp.types.ResourceTemplate) -> ProxyTemplate @@ -175,7 +175,7 @@ from_mcp_template(cls, client_factory: ClientFactoryT, mcp_template: mcp.types.R Factory method to create a ProxyTemplate from a raw MCP template schema. -#### `create_resource` +#### `create_resource` ```python create_resource(self, uri: str, params: dict[str, Any], context: Context | None = None) -> ProxyResource @@ -184,13 +184,13 @@ create_resource(self, uri: str, params: dict[str, Any], context: Context | None Create a resource from the template by calling the remote server. -#### `get_span_attributes` +#### `get_span_attributes` ```python get_span_attributes(self) -> dict[str, Any] ``` -### `ProxyPrompt` +### `ProxyPrompt` A Prompt that represents and renders a prompt from a remote server. @@ -198,7 +198,7 @@ A Prompt that represents and renders a prompt from a remote server. **Methods:** -#### `model_copy` +#### `model_copy` ```python model_copy(self, **kwargs: Any) -> ProxyPrompt @@ -207,7 +207,7 @@ model_copy(self, **kwargs: Any) -> ProxyPrompt Override to preserve _backend_name when name changes. -#### `from_mcp_prompt` +#### `from_mcp_prompt` ```python from_mcp_prompt(cls, client_factory: ClientFactoryT, mcp_prompt: mcp.types.Prompt) -> ProxyPrompt @@ -216,7 +216,7 @@ from_mcp_prompt(cls, client_factory: ClientFactoryT, mcp_prompt: mcp.types.Promp Factory method to create a ProxyPrompt from a raw MCP prompt schema. -#### `render` +#### `render` ```python render(self, arguments: dict[str, Any]) -> PromptResult @@ -225,13 +225,13 @@ render(self, arguments: dict[str, Any]) -> PromptResult Render the prompt by making a call through the client. -#### `get_span_attributes` +#### `get_span_attributes` ```python get_span_attributes(self) -> dict[str, Any] ``` -### `ProxyProvider` +### `ProxyProvider` Provider that proxies to a remote MCP server via a client factory. @@ -245,7 +245,7 @@ because tasks cannot be executed through a proxy. **Methods:** -#### `get_tasks` +#### `get_tasks` ```python get_tasks(self) -> Sequence[FastMCPComponent] @@ -258,7 +258,7 @@ server lifespan initialization, which would open the client before any context is set. All Proxy* components have task_config.mode="forbidden". -### `FastMCPProxy` +### `FastMCPProxy` A FastMCP server that acts as a proxy to a remote MCP-compliant server. @@ -267,7 +267,7 @@ This is a convenience wrapper that creates a FastMCP server with a ProxyProvider. For more control, use FastMCP with add_provider(ProxyProvider(...)). -### `ProxyClient` +### `ProxyClient` A proxy client that forwards advanced interactions between a remote MCP server and the proxy's connected clients. @@ -275,7 +275,7 @@ A proxy client that forwards advanced interactions between a remote MCP server a Supports forwarding roots, sampling, elicitation, logging, and progress. -### `StatefulProxyClient` +### `StatefulProxyClient` A proxy client that provides a stateful client factory for the proxy server. @@ -286,10 +286,17 @@ And it will be disconnected when the session is exited. This is useful to proxy a stateful mcp server such as the Playwright MCP server. Note that it is essential to ensure that the proxy server itself is also stateful. +Because session reuse means the receive-loop task inherits a stale +``request_ctx`` ContextVar snapshot, the default proxy handlers are +replaced with versions that restore the ContextVar before forwarding. +``ProxyTool.run`` stashes the current ``RequestContext`` in +``_proxy_rc_ref`` before each backend call, and the handlers consult +it to detect (and correct) staleness. + **Methods:** -#### `clear` +#### `clear` ```python clear(self) @@ -298,7 +305,7 @@ clear(self) Clear all cached clients and force disconnect them. -#### `new_stateful` +#### `new_stateful` ```python new_stateful(self) -> Client[ClientTransportT] diff --git a/src/fastmcp/server/providers/proxy.py b/src/fastmcp/server/providers/proxy.py index d063373961..0b7ce0096e 100644 --- a/src/fastmcp/server/providers/proxy.py +++ b/src/fastmcp/server/providers/proxy.py @@ -16,6 +16,7 @@ import mcp.types from mcp import ServerSession from mcp.client.session import ClientSession +from mcp.server.lowlevel.server import request_ctx from mcp.shared.context import LifespanContextT, RequestContext from mcp.shared.exceptions import McpError from mcp.types import ( @@ -121,6 +122,12 @@ async def run( client = await self._get_client() async with client: ctx = context or get_context() + # StatefulProxyClient reuses sessions across requests, so + # its receive-loop task has stale ContextVars from the first + # request. Stash the current RequestContext in the shared + # ref so handlers can restore it before forwarding. + if isinstance(client, StatefulProxyClient): + cast(list[Any], client._proxy_rc_ref)[0] = ctx.request_context # Build meta dict from request context meta: dict[str, Any] | None = None if hasattr(ctx, "request_context"): @@ -781,16 +788,50 @@ async def default_proxy_progress_handler( await ctx.report_progress(progress, total, message) +def _restore_request_context( + rc_ref: list[Any], +) -> None: + """Set the ``request_ctx`` ContextVar from a stashed RequestContext. + + Called at the start of proxy handler invocations in + ``StatefulProxyClient`` to fix stale ContextVars in the receive-loop + task. Only overrides when the ContextVar is genuinely stale (same + session, different request_id) to avoid corrupting the concurrent + case where multiple sessions share the same ref via ``copy.copy``. + """ + rc = rc_ref[0] + if rc is None: + return + try: + current_rc = request_ctx.get() + except LookupError: + request_ctx.set(rc) + return + if current_rc.session is rc.session and current_rc.request_id != rc.request_id: + request_ctx.set(rc) + + +def _make_restoring_handler(handler: Callable, rc_ref: list[Any]) -> Callable: + """Wrap a proxy handler to restore request_ctx before delegating. + + The wrapper is a plain ``async def`` so it passes + ``inspect.isfunction()`` checks in handler registration paths + (e.g., ``create_roots_callback``). + """ + + async def wrapper(*args: Any, **kwargs: Any) -> Any: + _restore_request_context(rc_ref) + return await handler(*args, **kwargs) + + return wrapper + + class ProxyClient(Client[ClientTransportT]): """A proxy client that forwards advanced interactions between a remote MCP server and the proxy's connected clients. Supports forwarding roots, sampling, elicitation, logging, and progress. """ - # Stored context for handlers when contextvar isn't available - # (e.g., when receive loop was started before any request context) - _proxy_context: Context | None = None - def __init__( self, transport: ClientTransportT @@ -826,9 +867,39 @@ class StatefulProxyClient(ProxyClient[ClientTransportT]): This is useful to proxy a stateful mcp server such as the Playwright MCP server. Note that it is essential to ensure that the proxy server itself is also stateful. + + Because session reuse means the receive-loop task inherits a stale + ``request_ctx`` ContextVar snapshot, the default proxy handlers are + replaced with versions that restore the ContextVar before forwarding. + ``ProxyTool.run`` stashes the current ``RequestContext`` in + ``_proxy_rc_ref`` before each backend call, and the handlers consult + it to detect (and correct) staleness. """ + # Mutable list shared across copies (Client.new() uses copy.copy, + # which preserves references to mutable containers). ProxyTool.run + # writes [0] before each backend call; handlers read it to detect + # stale ContextVars and restore the correct request_ctx. + # + # We store the concrete RequestContext (not fastmcp's Context) because + # Context properties are themselves ContextVar-dependent and resolve + # in the caller's async context — which is stale in the receive loop. + _proxy_rc_ref: list[Any] + def __init__(self, *args: Any, **kwargs: Any): + # Install context-restoring handler wrappers BEFORE super().__init__ + # registers them with the Client's session kwargs. + self._proxy_rc_ref = [None] + for key, default_fn in ( + ("roots", default_proxy_roots_handler), + ("sampling_handler", default_proxy_sampling_handler), + ("elicitation_handler", default_proxy_elicitation_handler), + ("log_handler", default_proxy_log_handler), + ("progress_handler", default_proxy_progress_handler), + ): + if key not in kwargs: + kwargs[key] = _make_restoring_handler(default_fn, self._proxy_rc_ref) + super().__init__(*args, **kwargs) self._caches: dict[ServerSession, Client[ClientTransportT]] = {} diff --git a/tests/server/providers/proxy/test_stateful_proxy_client.py b/tests/server/providers/proxy/test_stateful_proxy_client.py index 8ae0ce7133..79a280e724 100644 --- a/tests/server/providers/proxy/test_stateful_proxy_client.py +++ b/tests/server/providers/proxy/test_stateful_proxy_client.py @@ -1,15 +1,18 @@ import asyncio +from dataclasses import dataclass import pytest from anyio import create_task_group from mcp.types import LoggingLevel from fastmcp import Client, Context, FastMCP +from fastmcp.client.elicitation import ElicitResult from fastmcp.client.logging import LogMessage from fastmcp.client.transports import FastMCPTransport from fastmcp.exceptions import ToolError +from fastmcp.server.elicitation import AcceptedElicitation from fastmcp.server.providers.proxy import FastMCPProxy, StatefulProxyClient -from fastmcp.utilities.tests import find_available_port +from fastmcp.utilities.tests import find_available_port, run_server_async @pytest.fixture @@ -145,3 +148,54 @@ def tool_b() -> str: result_b = await client.call_tool("b_tool_b", {}) assert result_a.data == "a" assert result_b.data == "b" + + @pytest.mark.timeout(10) + async def test_stateful_proxy_elicitation_over_http(self): + """Elicitation through a stateful proxy over HTTP must not hang. + + When StatefulProxyClient reuses a session, the receive-loop task + inherits a stale request_ctx ContextVar from the first request. + The streamable-HTTP transport uses related_request_id to route + server-initiated messages (like elicitation) back to the correct + HTTP response stream. A stale request_id routes to a closed + stream, causing the elicitation to hang forever. + + This test runs the proxy over HTTP (not in-process) so the + transport's related_request_id routing is exercised. + """ + + @dataclass + class Person: + name: str + + backend = FastMCP("backend") + + @backend.tool + async def ask_name(ctx: Context) -> str: + result = await ctx.elicit("What is your name?", response_type=Person) + if isinstance(result, AcceptedElicitation): + assert isinstance(result.data, Person) + return f"Hello, {result.data.name}!" + return "declined" + + stateful_client = StatefulProxyClient(backend) + proxy = FastMCPProxy( + client_factory=stateful_client.new_stateful, + name="proxy", + ) + + async def elicitation_handler(message, response_type, params, ctx): + return ElicitResult(action="accept", content=response_type(name="Alice")) + + # Run the proxy over HTTP so the transport uses + # related_request_id routing for server-initiated messages. + async with run_server_async(proxy) as proxy_url: + async with Client( + proxy_url, elicitation_handler=elicitation_handler + ) as client: + result1 = await client.call_tool("ask_name", {}) + assert result1.data == "Hello, Alice!" + # Second call reuses the stateful session — this is the + # one that would hang without the fix. + result2 = await client.call_tool("ask_name", {}) + assert result2.data == "Hello, Alice!"