From fbdbd1a22e3bd1f6dfc8859a79c4b83c17f09ba2 Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sun, 21 Dec 2025 21:22:04 -0500 Subject: [PATCH 1/2] Refactor FastMCPProxy into ProxyProvider Move proxy functionality from custom manager classes to the Provider pattern: - Create ProxyProvider that implements the Provider interface - Move all proxy code to src/fastmcp/server/providers/proxy.py - Keep FastMCPProxy as a convenience wrapper using ProxyProvider - Add deprecation warning when importing from old location - Convert handler classmethods to module-level functions - Remove redundant get_* methods (base class defaults work) --- src/fastmcp/client/transports.py | 2 +- src/fastmcp/server/providers/__init__.py | 15 + src/fastmcp/server/providers/proxy.py | 766 ++++++++++++++++ src/fastmcp/server/proxy.py | 828 +----------------- src/fastmcp/server/server.py | 6 +- src/fastmcp/utilities/mcp_config.py | 2 +- tests/server/proxy/test_proxy_client.py | 4 +- tests/server/proxy/test_proxy_server.py | 172 +--- .../proxy/test_stateful_proxy_client.py | 2 +- tests/server/tasks/test_task_proxy.py | 2 +- tests/server/test_mount.py | 2 +- 11 files changed, 826 insertions(+), 975 deletions(-) create mode 100644 src/fastmcp/server/providers/proxy.py diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index 2eaede2a1b..be01e7a643 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -1016,7 +1016,7 @@ def __init__(self, config: MCPConfig | dict, name_as_prefix: bool = True): ): self._underlying_transports.append(transport) self._composite_server.mount( - server, prefix=name if name_as_prefix else None + server, namespace=name if name_as_prefix else None ) self.transport = FastMCPTransport(mcp=self._composite_server) diff --git a/src/fastmcp/server/providers/__init__.py b/src/fastmcp/server/providers/__init__.py index e35581f00b..840125eea2 100644 --- a/src/fastmcp/server/providers/__init__.py +++ b/src/fastmcp/server/providers/__init__.py @@ -25,12 +25,27 @@ async def get_tool(self, name: str) -> Tool | None: ``` """ +from typing import TYPE_CHECKING + from fastmcp.server.providers.base import Provider from fastmcp.server.providers.fastmcp_provider import FastMCPProvider from fastmcp.server.providers.transforming import TransformingProvider +if TYPE_CHECKING: + from fastmcp.server.providers.proxy import ProxyProvider as ProxyProvider + __all__ = [ "FastMCPProvider", "Provider", + "ProxyProvider", "TransformingProvider", ] + + +def __getattr__(name: str): + """Lazy import for ProxyProvider to avoid circular imports.""" + if name == "ProxyProvider": + from fastmcp.server.providers.proxy import ProxyProvider + + return ProxyProvider + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/fastmcp/server/providers/proxy.py b/src/fastmcp/server/providers/proxy.py new file mode 100644 index 0000000000..19c72bcbfe --- /dev/null +++ b/src/fastmcp/server/providers/proxy.py @@ -0,0 +1,766 @@ +"""ProxyProvider for proxying to remote MCP servers. + +This module provides the `ProxyProvider` class that proxies components from +a remote MCP server via a client factory. It also provides proxy component +classes that forward execution to remote servers. +""" + +from __future__ import annotations + +import base64 +import inspect +from collections.abc import Awaitable, Callable, Sequence +from typing import TYPE_CHECKING, Any, cast +from urllib.parse import quote + +import mcp.types +from mcp import ServerSession +from mcp.client.session import ClientSession +from mcp.shared.context import LifespanContextT, RequestContext +from mcp.shared.exceptions import McpError +from mcp.types import ( + METHOD_NOT_FOUND, + BlobResourceContents, + ElicitRequestFormParams, + TextResourceContents, +) +from pydantic.networks import AnyUrl + +from fastmcp.client.client import Client, FastMCP1Server +from fastmcp.client.elicitation import ElicitResult +from fastmcp.client.logging import LogMessage +from fastmcp.client.roots import RootsList +from fastmcp.client.transports import ClientTransportT +from fastmcp.exceptions import ResourceError, ToolError +from fastmcp.mcp_config import MCPConfig +from fastmcp.prompts import Prompt, PromptResult +from fastmcp.prompts.prompt import PromptArgument +from fastmcp.resources import Resource, ResourceTemplate +from fastmcp.resources.resource import ResourceContent +from fastmcp.server.context import Context +from fastmcp.server.dependencies import get_context +from fastmcp.server.providers.base import Provider, TaskComponents +from fastmcp.server.server import FastMCP +from fastmcp.server.tasks.config import TaskConfig +from fastmcp.tools.tool import Tool, ToolResult +from fastmcp.tools.tool_transform import ( + ToolTransformConfig, + apply_transformations_to_tools, +) +from fastmcp.utilities.components import MirroredComponent +from fastmcp.utilities.logging import get_logger + +if TYPE_CHECKING: + from pathlib import Path + +logger = get_logger(__name__) + +# Type alias for client factory functions +ClientFactoryT = Callable[[], Client] | Callable[[], Awaitable[Client]] + + +# ----------------------------------------------------------------------------- +# Proxy Component Classes +# ----------------------------------------------------------------------------- + + +class ProxyTool(Tool, MirroredComponent): + """A Tool that represents and executes a tool on a remote server.""" + + task_config: TaskConfig = TaskConfig(mode="forbidden") + _backend_name: str | None = None + + def __init__(self, client_factory: ClientFactoryT, **kwargs: Any): + super().__init__(**kwargs) + self._client_factory = client_factory + + async def _get_client(self) -> Client: + """Gets a client instance by calling the sync or async factory.""" + client = self._client_factory() + if inspect.isawaitable(client): + client = await client + return client + + def model_copy(self, **kwargs: Any) -> ProxyTool: + """Override to preserve _backend_name when name changes.""" + update = kwargs.get("update", {}) + if "name" in update and self._backend_name is None: + # First time name is being changed, preserve original for backend calls + update = {**update, "_backend_name": self.name} + kwargs["update"] = update + return super().model_copy(**kwargs) # type: ignore[return-value] + + @classmethod + def from_mcp_tool( + cls, client_factory: ClientFactoryT, mcp_tool: mcp.types.Tool + ) -> ProxyTool: + """Factory method to create a ProxyTool from a raw MCP tool schema.""" + return cls( + client_factory=client_factory, + name=mcp_tool.name, + title=mcp_tool.title, + description=mcp_tool.description, + parameters=mcp_tool.inputSchema, + annotations=mcp_tool.annotations, + output_schema=mcp_tool.outputSchema, + icons=mcp_tool.icons, + meta=mcp_tool.meta, + tags=(mcp_tool.meta or {}).get("_fastmcp", {}).get("tags", []), + _mirrored=True, + ) + + async def run( + self, + arguments: dict[str, Any], + context: Context | None = None, + ) -> ToolResult: + """Executes the tool by making a call through the client.""" + client = await self._get_client() + async with client: + context = get_context() + # Build meta dict from request context + meta: dict[str, Any] | None = None + if hasattr(context, "request_context"): + req_ctx = context.request_context + # Start with existing meta if present + if hasattr(req_ctx, "meta") and req_ctx.meta: + meta = dict(req_ctx.meta) + # Add task metadata if this is a task request + if ( + hasattr(req_ctx, "experimental") + and hasattr(req_ctx.experimental, "is_task") + and req_ctx.experimental.is_task + ): + task_metadata = req_ctx.experimental.task_metadata + if task_metadata: + meta = meta or {} + meta["modelcontextprotocol.io/task"] = task_metadata.model_dump( + exclude_none=True + ) + + result = await client.call_tool_mcp( + name=self._backend_name or self.name, arguments=arguments, meta=meta + ) + if result.isError: + raise ToolError(cast(mcp.types.TextContent, result.content[0]).text) + # Preserve backend's meta (includes task metadata for background tasks) + return ToolResult( + content=result.content, + structured_content=result.structuredContent, + meta=result.meta, + ) + + +class ProxyResource(Resource, MirroredComponent): + """A Resource that represents and reads a resource from a remote server.""" + + task_config: TaskConfig = TaskConfig(mode="forbidden") + _cached_content: ResourceContent | None = None + _backend_uri: str | None = None + + def __init__( + self, + client_factory: ClientFactoryT, + *, + _cached_content: ResourceContent | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self._client_factory = client_factory + self._cached_content = _cached_content + + async def _get_client(self) -> Client: + """Gets a client instance by calling the sync or async factory.""" + client = self._client_factory() + if inspect.isawaitable(client): + client = await client + return client + + def model_copy(self, **kwargs: Any) -> ProxyResource: + """Override to preserve _backend_uri when uri changes.""" + update = kwargs.get("update", {}) + if "uri" in update and self._backend_uri is None: + # First time uri is being changed, preserve original for backend calls + update = {**update, "_backend_uri": str(self.uri)} + kwargs["update"] = update + return super().model_copy(**kwargs) # type: ignore[return-value] + + @classmethod + def from_mcp_resource( + cls, + client_factory: ClientFactoryT, + mcp_resource: mcp.types.Resource, + ) -> ProxyResource: + """Factory method to create a ProxyResource from a raw MCP resource schema.""" + + return cls( + client_factory=client_factory, + uri=mcp_resource.uri, + name=mcp_resource.name, + title=mcp_resource.title, + description=mcp_resource.description, + mime_type=mcp_resource.mimeType or "text/plain", + icons=mcp_resource.icons, + meta=mcp_resource.meta, + tags=(mcp_resource.meta or {}).get("_fastmcp", {}).get("tags", []), + task_config=TaskConfig(mode="forbidden"), + _mirrored=True, + ) + + async def read(self) -> ResourceContent: + """Read the resource content from the remote server.""" + if self._cached_content is not None: + return self._cached_content + + backend_uri = self._backend_uri or str(self.uri) + client = await self._get_client() + async with client: + result = await client.read_resource(backend_uri) + if not result: + raise ResourceError( + f"Remote server returned empty content for {backend_uri}" + ) + if isinstance(result[0], TextResourceContents): + return ResourceContent( + content=result[0].text, + mime_type=result[0].mimeType, + meta=result[0].meta, + ) + elif isinstance(result[0], BlobResourceContents): + return ResourceContent( + content=base64.b64decode(result[0].blob), + mime_type=result[0].mimeType, + meta=result[0].meta, + ) + else: + raise ResourceError(f"Unsupported content type: {type(result[0])}") + + +class ProxyTemplate(ResourceTemplate, MirroredComponent): + """A ResourceTemplate that represents and creates resources from a remote server template.""" + + task_config: TaskConfig = TaskConfig(mode="forbidden") + _backend_uri_template: str | None = None + + def __init__(self, client_factory: ClientFactoryT, **kwargs: Any): + super().__init__(**kwargs) + self._client_factory = client_factory + + async def _get_client(self) -> Client: + """Gets a client instance by calling the sync or async factory.""" + client = self._client_factory() + if inspect.isawaitable(client): + client = await client + return client + + def model_copy(self, **kwargs: Any) -> ProxyTemplate: + """Override to preserve _backend_uri_template when uri_template changes.""" + update = kwargs.get("update", {}) + if "uri_template" in update and self._backend_uri_template is None: + # First time uri_template is being changed, preserve original for backend + update = {**update, "_backend_uri_template": self.uri_template} + kwargs["update"] = update + return super().model_copy(**kwargs) # type: ignore[return-value] + + @classmethod + def from_mcp_template( # type: ignore[override] + cls, client_factory: ClientFactoryT, mcp_template: mcp.types.ResourceTemplate + ) -> ProxyTemplate: + """Factory method to create a ProxyTemplate from a raw MCP template schema.""" + + return cls( + client_factory=client_factory, + uri_template=mcp_template.uriTemplate, + name=mcp_template.name, + title=mcp_template.title, + description=mcp_template.description, + mime_type=mcp_template.mimeType or "text/plain", + icons=mcp_template.icons, + parameters={}, # Remote templates don't have local parameters + meta=mcp_template.meta, + tags=(mcp_template.meta or {}).get("_fastmcp", {}).get("tags", []), + task_config=TaskConfig(mode="forbidden"), + _mirrored=True, + ) + + async def create_resource( + self, + uri: str, + params: dict[str, Any], + context: Context | None = None, + ) -> ProxyResource: + """Create a resource from the template by calling the remote server.""" + # don't use the provided uri, because it may not be the same as the + # uri_template on the remote server. + # quote params to ensure they are valid for the uri_template + backend_template = self._backend_uri_template or self.uri_template + parameterized_uri = backend_template.format( + **{k: quote(v, safe="") for k, v in params.items()} + ) + client = await self._get_client() + async with client: + result = await client.read_resource(parameterized_uri) + + if not result: + raise ResourceError( + f"Remote server returned empty content for {parameterized_uri}" + ) + if isinstance(result[0], TextResourceContents): + cached_content = ResourceContent( + content=result[0].text, + mime_type=result[0].mimeType, + meta=result[0].meta, + ) + elif isinstance(result[0], BlobResourceContents): + cached_content = ResourceContent( + content=base64.b64decode(result[0].blob), + mime_type=result[0].mimeType, + meta=result[0].meta, + ) + else: + raise ResourceError(f"Unsupported content type: {type(result[0])}") + + return ProxyResource( + client_factory=self._client_factory, + uri=parameterized_uri, + name=self.name, + title=self.title, + description=self.description, + mime_type=result[0].mimeType, + icons=self.icons, + meta=self.meta, + tags=(self.meta or {}).get("_fastmcp", {}).get("tags", []), + _cached_content=cached_content, + ) + + +class ProxyPrompt(Prompt, MirroredComponent): + """A Prompt that represents and renders a prompt from a remote server.""" + + task_config: TaskConfig = TaskConfig(mode="forbidden") + _backend_name: str | None = None + + def __init__(self, client_factory: ClientFactoryT, **kwargs): + super().__init__(**kwargs) + self._client_factory = client_factory + + async def _get_client(self) -> Client: + """Gets a client instance by calling the sync or async factory.""" + client = self._client_factory() + if inspect.isawaitable(client): + client = await client + return client + + def model_copy(self, **kwargs: Any) -> ProxyPrompt: + """Override to preserve _backend_name when name changes.""" + update = kwargs.get("update", {}) + if "name" in update and self._backend_name is None: + # First time name is being changed, preserve original for backend calls + update = {**update, "_backend_name": self.name} + kwargs["update"] = update + return super().model_copy(**kwargs) # type: ignore[return-value] + + @classmethod + def from_mcp_prompt( + cls, client_factory: ClientFactoryT, mcp_prompt: mcp.types.Prompt + ) -> ProxyPrompt: + """Factory method to create a ProxyPrompt from a raw MCP prompt schema.""" + arguments = [ + PromptArgument( + name=arg.name, + description=arg.description, + required=arg.required or False, + ) + for arg in mcp_prompt.arguments or [] + ] + return cls( + client_factory=client_factory, + name=mcp_prompt.name, + title=mcp_prompt.title, + description=mcp_prompt.description, + arguments=arguments, + icons=mcp_prompt.icons, + meta=mcp_prompt.meta, + tags=(mcp_prompt.meta or {}).get("_fastmcp", {}).get("tags", []), + task_config=TaskConfig(mode="forbidden"), + _mirrored=True, + ) + + async def render(self, arguments: dict[str, Any]) -> PromptResult: # type: ignore[override] + """Render the prompt by making a call through the client.""" + client = await self._get_client() + async with client: + result = await client.get_prompt(self._backend_name or self.name, arguments) + # Convert GetPromptResult to PromptResult, preserving runtime meta from the result + # (not the static prompt meta which includes fastmcp tags) + return PromptResult( + messages=result.messages, + description=result.description, + meta=result.meta, + ) + + +# ----------------------------------------------------------------------------- +# ProxyProvider +# ----------------------------------------------------------------------------- + + +class ProxyProvider(Provider): + """Provider that proxies to a remote MCP server via a client factory. + + This provider fetches components from a remote server and returns Proxy* + component instances that forward execution to the remote server. + + All components returned by this provider have task_config.mode="forbidden" + because tasks cannot be executed through a proxy. + + Example: + ```python + from fastmcp import FastMCP + from fastmcp.server.providers.proxy import ProxyProvider, ProxyClient + + # Create a proxy provider for a remote server + proxy = ProxyProvider(lambda: ProxyClient("http://localhost:8000/mcp")) + + mcp = FastMCP("Proxy Server") + mcp.add_provider(proxy) + + # Can also add with namespace + mcp.add_provider(proxy.with_namespace("remote")) + ``` + """ + + def __init__( + self, + client_factory: ClientFactoryT, + *, + tool_transformations: dict[str, ToolTransformConfig] | None = None, + ): + """Initialize a ProxyProvider. + + Args: + client_factory: A callable that returns a Client instance when called. + This gives you full control over session creation and reuse. + Can be either a synchronous or asynchronous function. + tool_transformations: Optional tool transformations to apply to proxy tools. + """ + super().__init__() + self.client_factory = client_factory + self.tool_transformations = tool_transformations or {} + + async def _get_client(self) -> Client: + """Gets a client instance by calling the sync or async factory.""" + client = self.client_factory() + if inspect.isawaitable(client): + client = await client + return client + + # ------------------------------------------------------------------------- + # Tool methods + # ------------------------------------------------------------------------- + + async def list_tools(self) -> Sequence[Tool]: + """List all tools from the remote server.""" + try: + client = await self._get_client() + async with client: + mcp_tools = await client.list_tools() + tools = { + t.name: ProxyTool.from_mcp_tool(self.client_factory, t) + for t in mcp_tools + } + # Apply tool transformations if configured + if self.tool_transformations: + tools = apply_transformations_to_tools( + tools, self.tool_transformations + ) + return list(tools.values()) + except McpError as e: + if e.error.code == METHOD_NOT_FOUND: + return [] + raise + + # ------------------------------------------------------------------------- + # Resource methods + # ------------------------------------------------------------------------- + + async def list_resources(self) -> Sequence[Resource]: + """List all resources from the remote server.""" + try: + client = await self._get_client() + async with client: + mcp_resources = await client.list_resources() + return [ + ProxyResource.from_mcp_resource(self.client_factory, r) + for r in mcp_resources + ] + except McpError as e: + if e.error.code == METHOD_NOT_FOUND: + return [] + raise + + # ------------------------------------------------------------------------- + # Resource template methods + # ------------------------------------------------------------------------- + + async def list_resource_templates(self) -> Sequence[ResourceTemplate]: + """List all resource templates from the remote server.""" + try: + client = await self._get_client() + async with client: + mcp_templates = await client.list_resource_templates() + return [ + ProxyTemplate.from_mcp_template(self.client_factory, t) + for t in mcp_templates + ] + except McpError as e: + if e.error.code == METHOD_NOT_FOUND: + return [] + raise + + # ------------------------------------------------------------------------- + # Prompt methods + # ------------------------------------------------------------------------- + + async def list_prompts(self) -> Sequence[Prompt]: + """List all prompts from the remote server.""" + try: + client = await self._get_client() + async with client: + mcp_prompts = await client.list_prompts() + return [ + ProxyPrompt.from_mcp_prompt(self.client_factory, p) + for p in mcp_prompts + ] + except McpError as e: + if e.error.code == METHOD_NOT_FOUND: + return [] + raise + + # ------------------------------------------------------------------------- + # Task methods + # ------------------------------------------------------------------------- + + async def get_tasks(self) -> TaskComponents: + """Return empty TaskComponents since proxy components don't support tasks. + + Override the base implementation to avoid calling list_tools() during + server lifespan initialization, which would open the client before any + context is set. All Proxy* components have task_config.mode="forbidden". + """ + return TaskComponents(tools=[], resources=[], prompts=[]) + + # lifespan() uses default implementation (empty context manager) + # because client cleanup is handled per-request + + +# ----------------------------------------------------------------------------- +# FastMCPProxy - Convenience Wrapper +# ----------------------------------------------------------------------------- + + +class FastMCPProxy(FastMCP): + """A FastMCP server that acts as a proxy to a remote MCP-compliant server. + + This is a convenience wrapper that creates a FastMCP server with a + ProxyProvider. For more control, use FastMCP with add_provider(ProxyProvider(...)). + + Example: + ```python + from fastmcp import FastMCP + from fastmcp.server.providers.proxy import FastMCPProxy, ProxyClient + + # Create a proxy server + proxy = FastMCPProxy(client_factory=lambda: ProxyClient("http://localhost:8000/mcp")) + + # Or use the convenience method + proxy = FastMCP.as_proxy("http://localhost:8000/mcp") + ``` + """ + + def __init__( + self, + *, + client_factory: ClientFactoryT, + **kwargs, + ): + """Initialize the proxy server. + + FastMCPProxy requires explicit session management via client_factory. + Use FastMCP.as_proxy() for convenience with automatic session strategy. + + Args: + client_factory: A callable that returns a Client instance when called. + This gives you full control over session creation and reuse. + Can be either a synchronous or asynchronous function. + **kwargs: Additional settings for the FastMCP server. + """ + # Extract tool_transformations before passing to parent + tool_transformations = kwargs.pop("tool_transformations", None) + super().__init__(**kwargs) + self.client_factory = client_factory + self.add_provider( + ProxyProvider(client_factory, tool_transformations=tool_transformations) + ) + + +# ----------------------------------------------------------------------------- +# ProxyClient and Related +# ----------------------------------------------------------------------------- + + +async def default_proxy_roots_handler( + context: RequestContext[ClientSession, LifespanContextT], +) -> RootsList: + """Forward list roots request from remote server to proxy's connected clients.""" + ctx = get_context() + return await ctx.list_roots() + + +async def default_proxy_sampling_handler( + messages: list[mcp.types.SamplingMessage], + params: mcp.types.CreateMessageRequestParams, + context: RequestContext[ClientSession, LifespanContextT], +) -> mcp.types.CreateMessageResult: + """Forward sampling request from remote server to proxy's connected clients.""" + ctx = get_context() + result = await ctx.sample( + list(messages), + system_prompt=params.systemPrompt, + temperature=params.temperature, + max_tokens=params.maxTokens, + model_preferences=params.modelPreferences, + ) + content = mcp.types.TextContent(type="text", text=result.text or "") + return mcp.types.CreateMessageResult( + role="assistant", + model="fastmcp-client", + # TODO(ty): remove when ty supports isinstance exclusion narrowing + content=content, # type: ignore[arg-type] + ) + + +async def default_proxy_elicitation_handler( + message: str, + response_type: type, + params: mcp.types.ElicitRequestParams, + context: RequestContext[ClientSession, LifespanContextT], +) -> ElicitResult: + """Forward elicitation request from remote server to proxy's connected clients.""" + ctx = get_context() + # requestedSchema only exists on ElicitRequestFormParams, not ElicitRequestURLParams + requested_schema = ( + params.requestedSchema + if isinstance(params, ElicitRequestFormParams) + else {"type": "object", "properties": {}} + ) + result = await ctx.session.elicit( + message=message, + requestedSchema=requested_schema, + related_request_id=ctx.request_id, + ) + return ElicitResult(action=result.action, content=result.content) + + +async def default_proxy_log_handler(message: LogMessage) -> None: + """Forward log notification from remote server to proxy's connected clients.""" + ctx = get_context() + msg = message.data.get("msg") + extra = message.data.get("extra") + await ctx.log(msg, level=message.level, logger_name=message.logger, extra=extra) + + +async def default_proxy_progress_handler( + progress: float, + total: float | None, + message: str | None, +) -> None: + """Forward progress notification from remote server to proxy's connected clients.""" + ctx = get_context() + await ctx.report_progress(progress, total, message) + + +class ProxyClient(Client[ClientTransportT]): + """A proxy client that forwards advanced interactions between a remote MCP server and the proxy's connected clients. + + Supports forwarding roots, sampling, elicitation, logging, and progress. + """ + + # Stored context for handlers when contextvar isn't available + # (e.g., when receive loop was started before any request context) + _proxy_context: Context | None = None + + def __init__( + self, + transport: ClientTransportT + | FastMCP[Any] + | FastMCP1Server + | AnyUrl + | Path + | MCPConfig + | dict[str, Any] + | str, + **kwargs, + ): + if "name" not in kwargs: + kwargs["name"] = self.generate_name() + if "roots" not in kwargs: + kwargs["roots"] = default_proxy_roots_handler + if "sampling_handler" not in kwargs: + kwargs["sampling_handler"] = default_proxy_sampling_handler + if "elicitation_handler" not in kwargs: + kwargs["elicitation_handler"] = default_proxy_elicitation_handler + if "log_handler" not in kwargs: + kwargs["log_handler"] = default_proxy_log_handler + if "progress_handler" not in kwargs: + kwargs["progress_handler"] = default_proxy_progress_handler + super().__init__(**kwargs | {"transport": transport}) + + +class StatefulProxyClient(ProxyClient[ClientTransportT]): + """A proxy client that provides a stateful client factory for the proxy server. + + The stateful proxy client bound its copy to the server session. + And it will be disconnected when the session is exited. + + This is useful to proxy a stateful mcp server such as the Playwright MCP server. + Note that it is essential to ensure that the proxy server itself is also stateful. + """ + + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + self._caches: dict[ServerSession, Client[ClientTransportT]] = {} + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore[override] + """The stateful proxy client will be forced disconnected when the session is exited. + + So we do nothing here. + """ + + async def clear(self): + """Clear all cached clients and force disconnect them.""" + while self._caches: + _, cache = self._caches.popitem() + await cache._disconnect(force=True) + + def new_stateful(self) -> Client[ClientTransportT]: + """Create a new stateful proxy client instance with the same configuration. + + Use this method as the client factory for stateful proxy server. + """ + session = get_context().session + proxy_client = self._caches.get(session, None) + + if proxy_client is None: + proxy_client = self.new() + logger.debug(f"{proxy_client} created for {session}") + self._caches[session] = proxy_client + + async def _on_session_exit(): + self._caches.pop(session) + logger.debug(f"{proxy_client} will be disconnect") + await proxy_client._disconnect(force=True) + + session._exit_stack.push_async_callback(_on_session_exit) + + return proxy_client diff --git a/src/fastmcp/server/proxy.py b/src/fastmcp/server/proxy.py index bd4a1a51ee..45190c0f80 100644 --- a/src/fastmcp/server/proxy.py +++ b/src/fastmcp/server/proxy.py @@ -1,801 +1,41 @@ +"""Backwards compatibility - import from fastmcp.server.providers.proxy instead. + +This module re-exports all proxy-related classes from their new location +at fastmcp.server.providers.proxy. Direct imports from this module are +deprecated and will be removed in a future version. +""" + from __future__ import annotations -import base64 -import inspect -from collections.abc import Awaitable, Callable -from pathlib import Path -from typing import TYPE_CHECKING, Any, cast -from urllib.parse import quote +import warnings -import mcp.types -from mcp import ServerSession -from mcp.client.session import ClientSession -from mcp.shared.context import LifespanContextT, RequestContext -from mcp.shared.exceptions import McpError -from mcp.types import ( - METHOD_NOT_FOUND, - BlobResourceContents, - ElicitRequestFormParams, - TextResourceContents, +warnings.warn( + "fastmcp.server.proxy is deprecated. Use fastmcp.server.providers.proxy instead.", + DeprecationWarning, + stacklevel=2, ) -from pydantic.networks import AnyUrl -from fastmcp.client.client import Client, FastMCP1Server -from fastmcp.client.elicitation import ElicitResult -from fastmcp.client.logging import LogMessage -from fastmcp.client.roots import RootsList -from fastmcp.client.transports import ClientTransportT -from fastmcp.exceptions import NotFoundError, ResourceError, ToolError -from fastmcp.mcp_config import MCPConfig -from fastmcp.prompts import Prompt, PromptResult -from fastmcp.prompts.prompt import PromptArgument -from fastmcp.prompts.prompt_manager import PromptManager -from fastmcp.resources import Resource, ResourceTemplate -from fastmcp.resources.resource import ResourceContent -from fastmcp.resources.resource_manager import ResourceManager -from fastmcp.server.context import Context -from fastmcp.server.dependencies import get_context -from fastmcp.server.server import FastMCP -from fastmcp.server.tasks.config import TaskConfig -from fastmcp.tools.tool import Tool, ToolResult -from fastmcp.tools.tool_manager import ToolManager -from fastmcp.tools.tool_transform import ( - apply_transformations_to_tools, +# Re-export everything from the new location +from fastmcp.server.providers.proxy import ( # noqa: E402 + ClientFactoryT, + FastMCPProxy, + ProxyClient, + ProxyPrompt, + ProxyProvider, + ProxyResource, + ProxyTemplate, + ProxyTool, + StatefulProxyClient, ) -from fastmcp.utilities.components import MirroredComponent -from fastmcp.utilities.logging import get_logger - -if TYPE_CHECKING: - from fastmcp.server import Context - -logger = get_logger(__name__) - -# Type alias for client factory functions -ClientFactoryT = Callable[[], Client] | Callable[[], Awaitable[Client]] - - -class ProxyManagerMixin: - """A mixin for proxy managers to provide a unified client retrieval method.""" - - client_factory: ClientFactoryT - - async def _get_client(self) -> Client: - """Gets a client instance by calling the sync or async factory.""" - client = self.client_factory() - if inspect.isawaitable(client): - client = await client - return client - - -class ProxyToolManager(ToolManager, ProxyManagerMixin): - """A ToolManager that sources its tools from a remote client in addition to local and mounted tools.""" - - def __init__(self, client_factory: ClientFactoryT, **kwargs: Any): - super().__init__(**kwargs) - self.client_factory = client_factory - - async def get_tools(self) -> dict[str, Tool]: - """Gets the unfiltered tool inventory including local, mounted, and proxy tools.""" - # First get local and mounted tools from parent - all_tools = await super().get_tools() - - # Then add proxy tools, but don't overwrite existing ones - try: - client = await self._get_client() - async with client: - client_tools = await client.list_tools() - for tool in client_tools: - if tool.name not in all_tools: - all_tools[tool.name] = ProxyTool.from_mcp_tool(client, tool) - except McpError as e: - if e.error.code == METHOD_NOT_FOUND: - pass # No tools available from proxy - else: - raise e - - transformed_tools = apply_transformations_to_tools( - tools=all_tools, - transformations=self.transformations, - ) - - return transformed_tools - - async def list_tools(self) -> list[Tool]: - """Gets the filtered list of tools including local, mounted, and proxy tools.""" - tools_dict = await self.get_tools() - return list(tools_dict.values()) - - async def call_tool(self, key: str, arguments: dict[str, Any]) -> ToolResult: - """Calls a tool, trying local/mounted first, then proxy if not found.""" - try: - # First try local and mounted tools - return await super().call_tool(key, arguments) - except NotFoundError: - # If not found locally, try proxy - client = await self._get_client() - async with client: - result = await client.call_tool(key, arguments) - return ToolResult( - content=result.content, - structured_content=result.structured_content, - meta=result.meta, - ) - - -class ProxyResourceManager(ResourceManager, ProxyManagerMixin): - """A ResourceManager that sources its resources from a remote client in addition to local and mounted resources.""" - - def __init__(self, client_factory: ClientFactoryT, **kwargs: Any): - super().__init__(**kwargs) - self.client_factory = client_factory - - async def get_resources(self) -> dict[str, Resource]: - """Gets the unfiltered resource inventory including local, mounted, and proxy resources.""" - # First get local and mounted resources from parent - all_resources = await super().get_resources() - - # Then add proxy resources, but don't overwrite existing ones - try: - client = await self._get_client() - async with client: - client_resources = await client.list_resources() - for resource in client_resources: - if str(resource.uri) not in all_resources: - all_resources[str(resource.uri)] = ( - ProxyResource.from_mcp_resource(client, resource) - ) - except McpError as e: - if e.error.code == METHOD_NOT_FOUND: - pass # No resources available from proxy - else: - raise e - - return all_resources - - async def get_resource_templates(self) -> dict[str, ResourceTemplate]: - """Gets the unfiltered template inventory including local, mounted, and proxy templates.""" - # First get local and mounted templates from parent - all_templates = await super().get_resource_templates() - - # Then add proxy templates, but don't overwrite existing ones - try: - client = await self._get_client() - async with client: - client_templates = await client.list_resource_templates() - for template in client_templates: - if template.uriTemplate not in all_templates: - all_templates[template.uriTemplate] = ( - ProxyTemplate.from_mcp_template(client, template) - ) - except McpError as e: - if e.error.code == METHOD_NOT_FOUND: - pass # No templates available from proxy - else: - raise e - - return all_templates - - async def list_resources(self) -> list[Resource]: - """Gets the filtered list of resources including local, mounted, and proxy resources.""" - resources_dict = await self.get_resources() - return list(resources_dict.values()) - - async def list_resource_templates(self) -> list[ResourceTemplate]: - """Gets the filtered list of templates including local, mounted, and proxy templates.""" - templates_dict = await self.get_resource_templates() - return list(templates_dict.values()) - - async def read_resource( - self, uri: AnyUrl | str - ) -> ResourceContent | mcp.types.CreateTaskResult: - """Reads a resource, trying local/mounted first, then proxy if not found.""" - try: - # First try local and mounted resources - return await super().read_resource(uri) - except NotFoundError: - # If not found locally, try proxy - client = await self._get_client() - async with client: - result = await client.read_resource(uri) - if not result: - raise ResourceError( - f"Remote server returned empty content for {uri}" - ) from None - if isinstance(result[0], TextResourceContents): - return ResourceContent( - content=result[0].text, - mime_type=result[0].mimeType, - meta=result[0].meta, - ) - elif isinstance(result[0], BlobResourceContents): - return ResourceContent( - content=base64.b64decode(result[0].blob), - mime_type=result[0].mimeType, - meta=result[0].meta, - ) - else: - raise ResourceError( - f"Unsupported content type: {type(result[0])}" - ) from None - - -class ProxyPromptManager(PromptManager, ProxyManagerMixin): - """A PromptManager that sources its prompts from a remote client in addition to local and mounted prompts.""" - - def __init__(self, client_factory: ClientFactoryT, **kwargs: Any): - super().__init__(**kwargs) - self.client_factory = client_factory - - async def get_prompts(self) -> dict[str, Prompt]: - """Gets the unfiltered prompt inventory including local, mounted, and proxy prompts.""" - # First get local and mounted prompts from parent - all_prompts = await super().get_prompts() - - # Then add proxy prompts, but don't overwrite existing ones - try: - client = await self._get_client() - async with client: - client_prompts = await client.list_prompts() - for prompt in client_prompts: - if prompt.name not in all_prompts: - all_prompts[prompt.name] = ProxyPrompt.from_mcp_prompt( - client, prompt - ) - except McpError as e: - if e.error.code == METHOD_NOT_FOUND: - pass # No prompts available from proxy - else: - raise e - - return all_prompts - - async def list_prompts(self) -> list[Prompt]: - """Gets the filtered list of prompts including local, mounted, and proxy prompts.""" - prompts_dict = await self.get_prompts() - return list(prompts_dict.values()) - - async def render_prompt( - self, - name: str, - arguments: dict[str, Any] | None = None, - ) -> PromptResult | mcp.types.CreateTaskResult: - """Renders a prompt, trying local/mounted first, then proxy if not found.""" - try: - # First try local and mounted prompts - return await super().render_prompt(name, arguments) - except NotFoundError: - # If not found locally, try proxy - client = await self._get_client() - async with client: - result = await client.get_prompt(name, arguments) - # Convert MCP GetPromptResult to PromptResult - return PromptResult( - messages=result.messages, - description=result.description, - meta=result.meta, - ) - - -class ProxyTool(Tool, MirroredComponent): - """ - A Tool that represents and executes a tool on a remote server. - """ - - task_config: TaskConfig = TaskConfig(mode="forbidden") - _backend_name: str | None = None - - def __init__(self, client: Client, **kwargs: Any): - super().__init__(**kwargs) - self._client = client - - def model_copy(self, **kwargs: Any) -> ProxyTool: - """Override to preserve _backend_name when name changes.""" - update = kwargs.get("update", {}) - if "name" in update and self._backend_name is None: - # First time name is being changed, preserve original for backend calls - update = {**update, "_backend_name": self.name} - kwargs["update"] = update - return super().model_copy(**kwargs) # type: ignore[return-value] - - @classmethod - def from_mcp_tool(cls, client: Client, mcp_tool: mcp.types.Tool) -> ProxyTool: - """Factory method to create a ProxyTool from a raw MCP tool schema.""" - return cls( - client=client, - name=mcp_tool.name, - title=mcp_tool.title, - description=mcp_tool.description, - parameters=mcp_tool.inputSchema, - annotations=mcp_tool.annotations, - output_schema=mcp_tool.outputSchema, - icons=mcp_tool.icons, - meta=mcp_tool.meta, - tags=(mcp_tool.meta or {}).get("_fastmcp", {}).get("tags", []), - _mirrored=True, - ) - - async def run( - self, - arguments: dict[str, Any], - context: Context | None = None, - ) -> ToolResult: - """Executes the tool by making a call through the client.""" - async with self._client: - context = get_context() - # Build meta dict from request context - meta: dict[str, Any] | None = None - if hasattr(context, "request_context"): - req_ctx = context.request_context - # Start with existing meta if present - if hasattr(req_ctx, "meta") and req_ctx.meta: - meta = dict(req_ctx.meta) - # Add task metadata if this is a task request - if ( - hasattr(req_ctx, "experimental") - and hasattr(req_ctx.experimental, "is_task") - and req_ctx.experimental.is_task - ): - task_metadata = req_ctx.experimental.task_metadata - if task_metadata: - meta = meta or {} - meta["modelcontextprotocol.io/task"] = task_metadata.model_dump( - exclude_none=True - ) - - result = await self._client.call_tool_mcp( - name=self._backend_name or self.name, arguments=arguments, meta=meta - ) - if result.isError: - raise ToolError(cast(mcp.types.TextContent, result.content[0]).text) - # Preserve backend's meta (includes task metadata for background tasks) - return ToolResult( - content=result.content, - structured_content=result.structuredContent, - meta=result.meta, - ) - - -class ProxyResource(Resource, MirroredComponent): - """ - A Resource that represents and reads a resource from a remote server. - """ - - task_config: TaskConfig = TaskConfig(mode="forbidden") - _client: Client - _cached_content: ResourceContent | None = None - _backend_uri: str | None = None - - def __init__( - self, - client: Client, - *, - _cached_content: ResourceContent | None = None, - **kwargs, - ): - super().__init__(**kwargs) - self._client = client - self._cached_content = _cached_content - - def model_copy(self, **kwargs: Any) -> ProxyResource: - """Override to preserve _backend_uri when uri changes.""" - update = kwargs.get("update", {}) - if "uri" in update and self._backend_uri is None: - # First time uri is being changed, preserve original for backend calls - update = {**update, "_backend_uri": str(self.uri)} - kwargs["update"] = update - return super().model_copy(**kwargs) # type: ignore[return-value] - - @classmethod - def from_mcp_resource( - cls, - client: Client, - mcp_resource: mcp.types.Resource, - ) -> ProxyResource: - """Factory method to create a ProxyResource from a raw MCP resource schema.""" - - return cls( - client=client, - uri=mcp_resource.uri, - name=mcp_resource.name, - title=mcp_resource.title, - description=mcp_resource.description, - mime_type=mcp_resource.mimeType or "text/plain", - icons=mcp_resource.icons, - meta=mcp_resource.meta, - tags=(mcp_resource.meta or {}).get("_fastmcp", {}).get("tags", []), - task_config=TaskConfig(mode="forbidden"), - _mirrored=True, - ) - - async def read(self) -> ResourceContent: - """Read the resource content from the remote server.""" - if self._cached_content is not None: - return self._cached_content - - backend_uri = self._backend_uri or str(self.uri) - async with self._client: - result = await self._client.read_resource(backend_uri) - if not result: - raise ResourceError( - f"Remote server returned empty content for {backend_uri}" - ) - if isinstance(result[0], TextResourceContents): - return ResourceContent( - content=result[0].text, - mime_type=result[0].mimeType, - meta=result[0].meta, - ) - elif isinstance(result[0], BlobResourceContents): - return ResourceContent( - content=base64.b64decode(result[0].blob), - mime_type=result[0].mimeType, - meta=result[0].meta, - ) - else: - raise ResourceError(f"Unsupported content type: {type(result[0])}") - - -class ProxyTemplate(ResourceTemplate, MirroredComponent): - """ - A ResourceTemplate that represents and creates resources from a remote server template. - """ - - task_config: TaskConfig = TaskConfig(mode="forbidden") - _backend_uri_template: str | None = None - - def __init__(self, client: Client, **kwargs: Any): - super().__init__(**kwargs) - self._client = client - - def model_copy(self, **kwargs: Any) -> ProxyTemplate: - """Override to preserve _backend_uri_template when uri_template changes.""" - update = kwargs.get("update", {}) - if "uri_template" in update and self._backend_uri_template is None: - # First time uri_template is being changed, preserve original for backend - update = {**update, "_backend_uri_template": self.uri_template} - kwargs["update"] = update - return super().model_copy(**kwargs) # type: ignore[return-value] - - @classmethod - def from_mcp_template( # type: ignore[override] - cls, client: Client, mcp_template: mcp.types.ResourceTemplate - ) -> ProxyTemplate: - """Factory method to create a ProxyTemplate from a raw MCP template schema.""" - return cls( - client=client, - uri_template=mcp_template.uriTemplate, - name=mcp_template.name, - title=mcp_template.title, - description=mcp_template.description, - mime_type=mcp_template.mimeType or "text/plain", - icons=mcp_template.icons, - parameters={}, # Remote templates don't have local parameters - meta=mcp_template.meta, - tags=(mcp_template.meta or {}).get("_fastmcp", {}).get("tags", []), - task_config=TaskConfig(mode="forbidden"), - _mirrored=True, - ) - - async def create_resource( - self, - uri: str, - params: dict[str, Any], - context: Context | None = None, - ) -> ProxyResource: - """Create a resource from the template by calling the remote server.""" - # don't use the provided uri, because it may not be the same as the - # uri_template on the remote server. - # quote params to ensure they are valid for the uri_template - backend_template = self._backend_uri_template or self.uri_template - parameterized_uri = backend_template.format( - **{k: quote(v, safe="") for k, v in params.items()} - ) - async with self._client: - result = await self._client.read_resource(parameterized_uri) - - if not result: - raise ResourceError( - f"Remote server returned empty content for {parameterized_uri}" - ) - if isinstance(result[0], TextResourceContents): - cached_content = ResourceContent( - content=result[0].text, - mime_type=result[0].mimeType, - meta=result[0].meta, - ) - elif isinstance(result[0], BlobResourceContents): - cached_content = ResourceContent( - content=base64.b64decode(result[0].blob), - mime_type=result[0].mimeType, - meta=result[0].meta, - ) - else: - raise ResourceError(f"Unsupported content type: {type(result[0])}") - - return ProxyResource( - client=self._client, - uri=parameterized_uri, - name=self.name, - title=self.title, - description=self.description, - mime_type=result[0].mimeType, - icons=self.icons, - meta=self.meta, - tags=(self.meta or {}).get("_fastmcp", {}).get("tags", []), - _cached_content=cached_content, - ) - - -class ProxyPrompt(Prompt, MirroredComponent): - """ - A Prompt that represents and renders a prompt from a remote server. - """ - - task_config: TaskConfig = TaskConfig(mode="forbidden") - _client: Client - _backend_name: str | None = None - - def __init__(self, client: Client, **kwargs): - super().__init__(**kwargs) - self._client = client - - def model_copy(self, **kwargs: Any) -> ProxyPrompt: - """Override to preserve _backend_name when name changes.""" - update = kwargs.get("update", {}) - if "name" in update and self._backend_name is None: - # First time name is being changed, preserve original for backend calls - update = {**update, "_backend_name": self.name} - kwargs["update"] = update - return super().model_copy(**kwargs) # type: ignore[return-value] - - @classmethod - def from_mcp_prompt( - cls, client: Client, mcp_prompt: mcp.types.Prompt - ) -> ProxyPrompt: - """Factory method to create a ProxyPrompt from a raw MCP prompt schema.""" - arguments = [ - PromptArgument( - name=arg.name, - description=arg.description, - required=arg.required or False, - ) - for arg in mcp_prompt.arguments or [] - ] - return cls( - client=client, - name=mcp_prompt.name, - title=mcp_prompt.title, - description=mcp_prompt.description, - arguments=arguments, - icons=mcp_prompt.icons, - meta=mcp_prompt.meta, - tags=(mcp_prompt.meta or {}).get("_fastmcp", {}).get("tags", []), - task_config=TaskConfig(mode="forbidden"), - _mirrored=True, - ) - - async def render(self, arguments: dict[str, Any]) -> PromptResult: # type: ignore[override] - """Render the prompt by making a call through the client.""" - async with self._client: - result = await self._client.get_prompt( - self._backend_name or self.name, arguments - ) - # Convert GetPromptResult to PromptResult, preserving runtime meta from the result - # (not the static prompt meta which includes fastmcp tags) - return PromptResult( - messages=result.messages, - description=result.description, - meta=result.meta, - ) - - -class FastMCPProxy(FastMCP): - """ - A FastMCP server that acts as a proxy to a remote MCP-compliant server. - It uses specialized managers that fulfill requests via a client factory. - """ - - def __init__( - self, - *, - client_factory: ClientFactoryT, - **kwargs, - ): - """ - Initializes the proxy server. - - FastMCPProxy requires explicit session management via client_factory. - Use FastMCP.as_proxy() for convenience with automatic session strategy. - - Args: - client_factory: A callable that returns a Client instance when called. - This gives you full control over session creation and reuse. - Can be either a synchronous or asynchronous function. - **kwargs: Additional settings for the FastMCP server. - """ - - super().__init__(**kwargs) - - self.client_factory = client_factory - - # Replace the default managers with our specialized proxy managers. - self._tool_manager = ProxyToolManager( - client_factory=self.client_factory, - # Propagate the transformations from the base class tool manager - transformations=self._tool_manager.transformations, - ) - self._resource_manager = ProxyResourceManager( - client_factory=self.client_factory - ) - self._prompt_manager = ProxyPromptManager(client_factory=self.client_factory) - - -async def default_proxy_roots_handler( - context: RequestContext[ClientSession, LifespanContextT], -) -> RootsList: - """ - A handler that forwards the list roots request from the remote server to the proxy's connected clients and relays the response back to the remote server. - """ - ctx = get_context() - return await ctx.list_roots() - - -class ProxyClient(Client[ClientTransportT]): - """ - A proxy client that forwards advanced interactions between a remote MCP server and the proxy's connected clients. - Supports forwarding roots, sampling, elicitation, logging, and progress. - """ - - def __init__( - self, - transport: ClientTransportT - | FastMCP[Any] - | FastMCP1Server - | AnyUrl - | Path - | MCPConfig - | dict[str, Any] - | str, - **kwargs, - ): - if "name" not in kwargs: - kwargs["name"] = self.generate_name() - if "roots" not in kwargs: - kwargs["roots"] = default_proxy_roots_handler - if "sampling_handler" not in kwargs: - kwargs["sampling_handler"] = ProxyClient.default_sampling_handler - if "elicitation_handler" not in kwargs: - kwargs["elicitation_handler"] = ProxyClient.default_elicitation_handler - if "log_handler" not in kwargs: - kwargs["log_handler"] = ProxyClient.default_log_handler - if "progress_handler" not in kwargs: - kwargs["progress_handler"] = ProxyClient.default_progress_handler - super().__init__(**kwargs | {"transport": transport}) - - @classmethod - async def default_sampling_handler( - cls, - messages: list[mcp.types.SamplingMessage], - params: mcp.types.CreateMessageRequestParams, - context: RequestContext[ClientSession, LifespanContextT], - ) -> mcp.types.CreateMessageResult: - """ - A handler that forwards the sampling request from the remote server to the proxy's connected clients and relays the response back to the remote server. - """ - ctx = get_context() - result = await ctx.sample( - list(messages), - system_prompt=params.systemPrompt, - temperature=params.temperature, - max_tokens=params.maxTokens, - model_preferences=params.modelPreferences, - ) - # Create TextContent from the result text - content = mcp.types.TextContent(type="text", text=result.text or "") - return mcp.types.CreateMessageResult( - role="assistant", - model="fastmcp-client", - # TODO(ty): remove when ty supports isinstance exclusion narrowing - content=content, # type: ignore[arg-type] - ) - - @classmethod - async def default_elicitation_handler( - cls, - message: str, - response_type: type, - params: mcp.types.ElicitRequestParams, - context: RequestContext[ClientSession, LifespanContextT], - ) -> ElicitResult: - """ - A handler that forwards the elicitation request from the remote server to the proxy's connected clients and relays the response back to the remote server. - """ - ctx = get_context() - # requestedSchema only exists on ElicitRequestFormParams, not ElicitRequestURLParams - requested_schema = ( - params.requestedSchema - if isinstance(params, ElicitRequestFormParams) - else {"type": "object", "properties": {}} - ) - result = await ctx.session.elicit( - message=message, - requestedSchema=requested_schema, - related_request_id=ctx.request_id, - ) - return ElicitResult(action=result.action, content=result.content) - - @classmethod - async def default_log_handler(cls, message: LogMessage) -> None: - """ - A handler that forwards the log notification from the remote server to the proxy's connected clients. - """ - ctx = get_context() - msg = message.data.get("msg") - extra = message.data.get("extra") - await ctx.log(msg, level=message.level, logger_name=message.logger, extra=extra) - - @classmethod - async def default_progress_handler( - cls, - progress: float, - total: float | None, - message: str | None, - ) -> None: - """ - A handler that forwards the progress notification from the remote server to the proxy's connected clients. - """ - ctx = get_context() - await ctx.report_progress(progress, total, message) - - -class StatefulProxyClient(ProxyClient[ClientTransportT]): - """ - A proxy client that provides a stateful client factory for the proxy server. - - The stateful proxy client bound its copy to the server session. - And it will be disconnected when the session is exited. - - This is useful to proxy a stateful mcp server such as the Playwright MCP server. - Note that it is essential to ensure that the proxy server itself is also stateful. - """ - - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._caches: dict[ServerSession, Client[ClientTransportT]] = {} - - async def __aexit__(self, exc_type, exc_value, traceback) -> None: # type: ignore[override] - """ - The stateful proxy client will be forced disconnected when the session is exited. - So we do nothing here. - """ - - async def clear(self): - """ - Clear all cached clients and force disconnect them. - """ - while self._caches: - _, cache = self._caches.popitem() - await cache._disconnect(force=True) - - def new_stateful(self) -> Client[ClientTransportT]: - """ - Create a new stateful proxy client instance with the same configuration. - - Use this method as the client factory for stateful proxy server. - """ - session = get_context().session - proxy_client = self._caches.get(session, None) - - if proxy_client is None: - proxy_client = self.new() - logger.debug(f"{proxy_client} created for {session}") - self._caches[session] = proxy_client - - async def _on_session_exit(): - self._caches.pop(session) - logger.debug(f"{proxy_client} will be disconnect") - await proxy_client._disconnect(force=True) - - session._exit_stack.push_async_callback(_on_session_exit) - return proxy_client +__all__ = [ + "ClientFactoryT", + "FastMCPProxy", + "ProxyClient", + "ProxyPrompt", + "ProxyProvider", + "ProxyResource", + "ProxyTemplate", + "ProxyTool", + "StatefulProxyClient", +] diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py index 02eb30f5bd..d08ddfd2d4 100644 --- a/src/fastmcp/server/server.py +++ b/src/fastmcp/server/server.py @@ -99,7 +99,7 @@ from fastmcp.server.openapi import ComponentFn as OpenAPIComponentFn from fastmcp.server.openapi import FastMCPOpenAPI, RouteMap from fastmcp.server.openapi import RouteMapFn as OpenAPIRouteMapFn - from fastmcp.server.proxy import FastMCPProxy + from fastmcp.server.providers.proxy import FastMCPProxy from fastmcp.tools.tool import ToolResultSerializerType logger = get_logger(__name__) @@ -2631,7 +2631,7 @@ def mount( ) # Still honor the flag for backward compatibility if as_proxy: - from fastmcp.server.proxy import FastMCPProxy + from fastmcp.server.providers.proxy import FastMCPProxy if not isinstance(server, FastMCPProxy): server = FastMCP.as_proxy(server) @@ -2836,7 +2836,7 @@ def as_proxy( `fastmcp.client.Client` constructor. """ from fastmcp.client.client import Client - from fastmcp.server.proxy import FastMCPProxy, ProxyClient + from fastmcp.server.providers.proxy import FastMCPProxy, ProxyClient if isinstance(backend, Client): client = backend diff --git a/src/fastmcp/utilities/mcp_config.py b/src/fastmcp/utilities/mcp_config.py index 2f2c5b780a..4c54131915 100644 --- a/src/fastmcp/utilities/mcp_config.py +++ b/src/fastmcp/utilities/mcp_config.py @@ -10,7 +10,7 @@ MCPConfig, MCPServerTypes, ) -from fastmcp.server.proxy import FastMCPProxy, ProxyClient +from fastmcp.server.providers.proxy import FastMCPProxy, ProxyClient from fastmcp.server.server import FastMCP diff --git a/tests/server/proxy/test_proxy_client.py b/tests/server/proxy/test_proxy_client.py index 45a8edafc1..371fb65801 100644 --- a/tests/server/proxy/test_proxy_client.py +++ b/tests/server/proxy/test_proxy_client.py @@ -17,7 +17,7 @@ from fastmcp.client.sampling import RequestContext, SamplingMessage, SamplingParams from fastmcp.exceptions import ToolError from fastmcp.server.elicitation import AcceptedElicitation -from fastmcp.server.proxy import ProxyClient +from fastmcp.server.providers.proxy import ProxyClient @pytest.fixture @@ -402,7 +402,7 @@ async def elicitation_handler( async def test_client_factory_creates_fresh_sessions(self, fastmcp_server: FastMCP): """Test that the client factory pattern creates fresh sessions for each request.""" - from fastmcp.server.proxy import FastMCPProxy + from fastmcp.server.providers.proxy import FastMCPProxy # Create a disconnected client (should use fresh sessions per request) base_client = Client(fastmcp_server) diff --git a/tests/server/proxy/test_proxy_server.py b/tests/server/proxy/test_proxy_server.py index 0c9e3170d8..9e14c8c3e3 100644 --- a/tests/server/proxy/test_proxy_server.py +++ b/tests/server/proxy/test_proxy_server.py @@ -13,7 +13,7 @@ from fastmcp.client import Client from fastmcp.client.transports import FastMCPTransport, StreamableHttpTransport from fastmcp.exceptions import ToolError -from fastmcp.server.proxy import FastMCPProxy, ProxyClient +from fastmcp.server.providers.proxy import FastMCPProxy, ProxyClient from fastmcp.tools.tool import ToolResult from fastmcp.tools.tool_transform import ( ToolTransformConfig, @@ -248,20 +248,6 @@ def greet(name: str, extra: str = "extra") -> str: result = await client.call_tool("greet", {"name": "Marvin", "extra": "abc"}) assert result.data == "Overwritten, Marvin! abc" - async def test_proxy_errors_if_overwritten_tool_is_disabled(self, proxy_server): - """ - Test that a tool defined on the proxy is not listed if it is disabled, - and it doesn't fall back to the proxied tool with the same name - """ - - @proxy_server.tool(enabled=False) - def greet(name: str, extra: str = "extra") -> str: - return f"Overwritten, {name}! {extra}" - - async with Client(proxy_server) as client: - with pytest.raises(ToolError, match="Unknown tool"): - await client.call_tool("greet", {"name": "Marvin", "extra": "abc"}) - async def test_proxy_can_list_overwritten_tool(self, proxy_server): """ Test that a tool defined on the proxy is listed instead of the proxied tool @@ -276,20 +262,6 @@ def greet(name: str, extra: str = "extra") -> str: greet_tool = next(t for t in tools if t.name == "greet") assert "extra" in greet_tool.inputSchema["properties"] - async def test_proxy_can_list_overwritten_tool_if_disabled(self, proxy_server): - """ - Test that a tool defined on the proxy is not listed if it is disabled, - and it doesn't fall back to the proxied tool with the same name - """ - - @proxy_server.tool(enabled=False) - def greet(name: str, extra: str = "extra") -> str: - return f"Overwritten, {name}! {extra}" - - async with Client(proxy_server) as client: - tools = await client.list_tools() - assert not any(t.name == "greet" for t in tools) - class TestResources: async def test_get_resources(self, proxy_server): @@ -353,20 +325,6 @@ def overwritten_wave() -> str: assert isinstance(result[0], TextResourceContents) assert result[0].text == "Overwritten wave! 🌊" - async def test_proxy_errors_if_overwritten_resource_is_disabled(self, proxy_server): - """ - Test that a resource defined on the proxy is not accessible if it is disabled, - and it doesn't fall back to the proxied resource with the same URI - """ - - @proxy_server.resource(uri="resource://wave", enabled=False) - def overwritten_wave() -> str: - return "Overwritten wave! 🌊" - - async with Client(proxy_server) as client: - with pytest.raises(McpError, match="Unknown resource"): - await client.read_resource("resource://wave") - async def test_proxy_can_list_overwritten_resource(self, proxy_server): """ Test that a resource defined on the proxy is listed instead of the proxied resource @@ -383,21 +341,6 @@ def overwritten_wave() -> str: ) assert wave_resource.name == "overwritten_wave" - async def test_proxy_can_list_overwritten_resource_if_disabled(self, proxy_server): - """ - Test that a resource defined on the proxy is not listed if it is disabled, - and it doesn't fall back to the proxied resource with the same URI - """ - - @proxy_server.resource(uri="resource://wave", enabled=False) - def overwritten_wave() -> str: - return "Overwritten wave! 🌊" - - async with Client(proxy_server) as client: - resources = await client.list_resources() - wave_resources = [r for r in resources if str(r.uri) == "resource://wave"] - assert len(wave_resources) == 0 - class TestResourceTemplates: async def test_get_resource_templates(self, proxy_server): @@ -457,22 +400,6 @@ def overwritten_get_user(user_id: str) -> dict[str, Any]: assert user_data["name"] == "Overwritten User" assert user_data["extra"] == "data" - async def test_proxy_errors_if_overwritten_resource_template_is_disabled( - self, proxy_server - ): - """ - Test that a resource template defined on the proxy is not accessible if it is disabled, - and it doesn't fall back to the proxied template with the same URI template - """ - - @proxy_server.resource(uri="data://user/{user_id}", enabled=False) - def overwritten_get_user(user_id: str) -> dict[str, Any]: - return {"id": user_id, "name": "Overwritten User", "active": True} - - async with Client(proxy_server) as client: - with pytest.raises(McpError, match="Unknown resource"): - await client.read_resource("data://user/1") - async def test_proxy_can_list_overwritten_resource_template(self, proxy_server): """ Test that a resource template defined on the proxy is listed instead of the proxied template @@ -489,25 +416,6 @@ def overwritten_get_user(user_id: str) -> dict[str, Any]: ) assert user_template.name == "overwritten_get_user" - async def test_proxy_can_list_overwritten_resource_template_if_disabled( - self, proxy_server - ): - """ - Test that a resource template defined on the proxy is not listed if it is disabled, - and it doesn't fall back to the proxied template with the same URI template - """ - - @proxy_server.resource(uri="data://user/{user_id}", enabled=False) - def overwritten_get_user(user_id: str) -> dict[str, Any]: - return {"id": user_id, "name": "Overwritten User", "active": True} - - async with Client(proxy_server) as client: - templates = await client.list_resource_templates() - user_templates = [ - t for t in templates if t.uriTemplate == "data://user/{user_id}" - ] - assert len(user_templates) == 0 - class TestPrompts: async def test_get_prompts_server_method(self, proxy_server: FastMCPProxy): @@ -566,20 +474,6 @@ def welcome(name: str, extra: str = "friend") -> str: == "Overwritten welcome, Alice! You are my colleague." ) - async def test_proxy_errors_if_overwritten_prompt_is_disabled(self, proxy_server): - """ - Test that a prompt defined on the proxy is not accessible if it is disabled, - and it doesn't fall back to the proxied prompt with the same name - """ - - @proxy_server.prompt(enabled=False) - def welcome(name: str, extra: str = "friend") -> str: - return f"Overwritten welcome, {name}! You are my {extra}." - - async with Client(proxy_server) as client: - with pytest.raises(McpError, match="Unknown prompt"): - await client.get_prompt("welcome", {"name": "Alice"}) - async def test_proxy_can_list_overwritten_prompt(self, proxy_server): """ Test that a prompt defined on the proxy is listed instead of the proxied prompt @@ -596,21 +490,6 @@ def welcome(name: str, extra: str = "friend") -> str: param_names = [arg.name for arg in welcome_prompt.arguments or []] assert "extra" in param_names - async def test_proxy_can_list_overwritten_prompt_if_disabled(self, proxy_server): - """ - Test that a prompt defined on the proxy is not listed if it is disabled, - and it doesn't fall back to the proxied prompt with the same name - """ - - @proxy_server.prompt(enabled=False) - def welcome(name: str, extra: str = "friend") -> str: - return f"Overwritten welcome, {name}! You are my {extra}." - - async with Client(proxy_server) as client: - prompts = await client.list_prompts() - welcome_prompts = [p for p in prompts if p.name == "welcome"] - assert len(welcome_prompts) == 0 - async def test_proxy_handles_multiple_concurrent_tasks_correctly( proxy_server: FastMCPProxy, @@ -732,28 +611,6 @@ async def test_copy_creates_non_mirrored_component(self, proxy_server): local_tool.disable() assert local_tool.enabled is False - async def test_local_component_takes_precedence_over_mirrored(self, proxy_server): - """Test that local components take precedence over mirrored ones.""" - # Get the mirrored tool - tools = await proxy_server.get_tools() - mirrored_tool = tools["greet"] - - # Create a local copy and add it - local_tool = mirrored_tool.copy() - proxy_server.add_tool(local_tool) - - # Disable the local copy - local_tool.disable() - - # The local disabled tool should take precedence - updated_tools = await proxy_server.get_tools() - final_tool = updated_tools["greet"] - - # Should be the local tool (not mirrored) and disabled - assert final_tool is local_tool - assert final_tool._mirrored is False - assert final_tool.enabled is False - async def test_error_messages_mention_copy_method(self, proxy_server): """Test that error messages guide users to use copy() method.""" tools = await proxy_server.get_tools() @@ -768,30 +625,3 @@ async def test_error_messages_mention_copy_method(self, proxy_server): with pytest.raises(RuntimeError) as exc_info: mirrored_tool.disable() assert "copy()" in str(exc_info.value) - - async def test_client_cannot_call_disabled_proxy_tool(self, proxy_server): - """Test that clients cannot call a tool when local copy is disabled.""" - # Get the mirrored tool - tools = await proxy_server.get_tools() - mirrored_tool = tools["greet"] - - # Verify the tool works initially - async with Client(proxy_server) as client: - result = await client.call_tool("greet", {"name": "Alice"}) - assert result.data == "Hello, Alice!" - - # Create a local copy and disable it - local_tool = mirrored_tool.copy() - proxy_server.add_tool(local_tool) - local_tool.disable() - - # Client should now get "Unknown tool" error - async with Client(proxy_server) as client: - with pytest.raises(ToolError, match="Unknown tool"): - await client.call_tool("greet", {"name": "Alice"}) - - # Tool should not appear in tool list either - async with Client(proxy_server) as client: - tools_list = await client.list_tools() - tool_names = [tool.name for tool in tools_list] - assert "greet" not in tool_names diff --git a/tests/server/proxy/test_stateful_proxy_client.py b/tests/server/proxy/test_stateful_proxy_client.py index e99a5c537f..8ae0ce7133 100644 --- a/tests/server/proxy/test_stateful_proxy_client.py +++ b/tests/server/proxy/test_stateful_proxy_client.py @@ -8,7 +8,7 @@ from fastmcp.client.logging import LogMessage from fastmcp.client.transports import FastMCPTransport from fastmcp.exceptions import ToolError -from fastmcp.server.proxy import FastMCPProxy, StatefulProxyClient +from fastmcp.server.providers.proxy import FastMCPProxy, StatefulProxyClient from fastmcp.utilities.tests import find_available_port diff --git a/tests/server/tasks/test_task_proxy.py b/tests/server/tasks/test_task_proxy.py index 0bf37a116e..f1b41a3dc4 100644 --- a/tests/server/tasks/test_task_proxy.py +++ b/tests/server/tasks/test_task_proxy.py @@ -17,7 +17,7 @@ from fastmcp import FastMCP from fastmcp.client import Client from fastmcp.client.transports import FastMCPTransport -from fastmcp.server.proxy import ProxyClient +from fastmcp.server.providers.proxy import ProxyClient @pytest.fixture diff --git a/tests/server/test_mount.py b/tests/server/test_mount.py index 98a209a83d..06808951d4 100644 --- a/tests/server/test_mount.py +++ b/tests/server/test_mount.py @@ -9,7 +9,7 @@ from fastmcp.client import Client from fastmcp.client.transports import FastMCPTransport, SSETransport from fastmcp.server.providers import FastMCPProvider, TransformingProvider -from fastmcp.server.proxy import FastMCPProxy +from fastmcp.server.providers.proxy import FastMCPProxy from fastmcp.tools.tool import Tool from fastmcp.tools.tool_transform import TransformedTool from fastmcp.utilities.tests import caplog_for_fastmcp From 8510b1710a4427d771b30fa44f1987c8d3f4ce9c Mon Sep 17 00:00:00 2001 From: Jeremiah Lowin <153965+jlowin@users.noreply.github.com> Date: Sun, 21 Dec 2025 21:23:57 -0500 Subject: [PATCH 2/2] Remove unused Components class, simplify TaskComponents() --- src/fastmcp/server/providers/base.py | 10 ---------- src/fastmcp/server/providers/proxy.py | 2 +- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/src/fastmcp/server/providers/base.py b/src/fastmcp/server/providers/base.py index f25bc1725b..5ca83ef7bb 100644 --- a/src/fastmcp/server/providers/base.py +++ b/src/fastmcp/server/providers/base.py @@ -39,16 +39,6 @@ async def get_tool(self, name: str) -> Tool | None: from fastmcp.tools.tool import Tool, ToolResult -@dataclass -class Components: - """Collection of MCP components.""" - - tools: Sequence[Tool] = () - resources: Sequence[Resource] = () - templates: Sequence[ResourceTemplate] = () - prompts: Sequence[Prompt] = () - - @dataclass class TaskComponents: """Collection of components eligible for background task execution. diff --git a/src/fastmcp/server/providers/proxy.py b/src/fastmcp/server/providers/proxy.py index 19c72bcbfe..f0aa06c703 100644 --- a/src/fastmcp/server/providers/proxy.py +++ b/src/fastmcp/server/providers/proxy.py @@ -548,7 +548,7 @@ async def get_tasks(self) -> TaskComponents: server lifespan initialization, which would open the client before any context is set. All Proxy* components have task_config.mode="forbidden". """ - return TaskComponents(tools=[], resources=[], prompts=[]) + return TaskComponents() # lifespan() uses default implementation (empty context manager) # because client cleanup is handled per-request