diff --git a/src/fastmcp/client/client.py b/src/fastmcp/client/client.py index 39cd459fb0..99d6231eb9 100644 --- a/src/fastmcp/client/client.py +++ b/src/fastmcp/client/client.py @@ -4,7 +4,6 @@ import copy import datetime import secrets -import uuid import weakref from collections.abc import Coroutine from contextlib import AsyncExitStack, asynccontextmanager, suppress @@ -15,23 +14,10 @@ import anyio import httpx import mcp.types -import pydantic_core from exceptiongroup import catch from mcp import ClientSession, McpError -from mcp.types import ( - CancelTaskRequest, - CancelTaskRequestParams, - GetTaskPayloadRequest, - GetTaskPayloadRequestParams, - GetTaskPayloadResult, - GetTaskRequest, - GetTaskRequestParams, - GetTaskResult, - ListTasksRequest, - PaginatedRequestParams, - TaskStatusNotification, -) -from pydantic import AnyUrl, RootModel +from mcp.types import GetTaskResult, TaskStatusNotification +from pydantic import AnyUrl import fastmcp from fastmcp.client.elicitation import ElicitationHandler, create_elicitation_callback @@ -42,6 +28,8 @@ ) from fastmcp.client.messages import MessageHandler, MessageHandlerT from fastmcp.client.progress import ProgressHandler, default_progress_handler +from fastmcp.client.prompts import ClientPromptsMixin +from fastmcp.client.resources import ClientResourcesMixin from fastmcp.client.roots import ( RootsHandler, RootsList, @@ -51,21 +39,22 @@ SamplingHandler, create_sampling_callback, ) +from fastmcp.client.task_management import ClientTaskManagementMixin from fastmcp.client.tasks import ( PromptTask, ResourceTask, TaskNotificationHandler, ToolTask, ) -from fastmcp.client.telemetry import client_span -from fastmcp.exceptions import ToolError +from fastmcp.client.tools_client import ClientToolsMixin from fastmcp.mcp_config import MCPConfig from fastmcp.server import FastMCP -from fastmcp.telemetry import inject_trace_context from fastmcp.utilities.exceptions import get_catch_handlers -from fastmcp.utilities.json_schema_type import json_schema_to_type from fastmcp.utilities.logging import get_logger -from fastmcp.utilities.types import get_cached_typeadapter +from fastmcp.utilities.timeout import ( + normalize_timeout_to_seconds, + normalize_timeout_to_timedelta, +) from .transports import ( ClientTransport, @@ -128,7 +117,13 @@ class CallToolResult: is_error: bool = False -class Client(Generic[ClientTransportT]): +class Client( + Generic[ClientTransportT], + ClientResourcesMixin, + ClientPromptsMixin, + ClientToolsMixin, + ClientTaskManagementMixin, +): """ MCP client that delegates connection management to a Transport instance. @@ -277,19 +272,12 @@ def __init__( self._progress_handler = progress_handler # Convert timeout to timedelta if needed - if isinstance(timeout, int | float): - timeout = datetime.timedelta(seconds=float(timeout)) + timeout = normalize_timeout_to_timedelta(timeout) - # handle init handshake timeout + # handle init handshake timeout (0 means disabled) if init_timeout is None: init_timeout = fastmcp.settings.client_init_timeout - if isinstance(init_timeout, datetime.timedelta): - init_timeout = init_timeout.total_seconds() - elif not init_timeout: - init_timeout = None - else: - init_timeout = float(init_timeout) - self._init_timeout = init_timeout + self._init_timeout = normalize_timeout_to_seconds(init_timeout) self.auto_initialize = auto_initialize @@ -482,12 +470,8 @@ async def initialize( if timeout is None: timeout = self._init_timeout - - # Convert timeout if needed - if isinstance(timeout, datetime.timedelta): - timeout = timeout.total_seconds() - elif timeout is not None: - timeout = float(timeout) + else: + timeout = normalize_timeout_to_seconds(timeout) try: with anyio.fail_after(timeout): @@ -794,579 +778,6 @@ async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.session.send_roots_list_changed() - # --- Resources --- - - async def list_resources_mcp( - self, *, cursor: str | None = None - ) -> mcp.types.ListResourcesResult: - """Send a resources/list request and return the complete MCP protocol result. - - Args: - cursor: Optional pagination cursor from a previous request's nextCursor. - - Returns: - mcp.types.ListResourcesResult: The complete response object from the protocol, - containing the list of resources and any additional metadata. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - logger.debug(f"[{self.name}] called list_resources") - - result = await self._await_with_session_monitoring( - self.session.list_resources(cursor=cursor) - ) - return result - - async def list_resources(self) -> list[mcp.types.Resource]: - """Retrieve all resources available on the server. - - This method automatically fetches all pages if the server paginates results, - returning the complete list. For manual pagination control (e.g., to handle - large result sets incrementally), use list_resources_mcp() with the cursor parameter. - - Returns: - list[mcp.types.Resource]: A list of all Resource objects. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - all_resources: list[mcp.types.Resource] = [] - cursor: str | None = None - - while True: - result = await self.list_resources_mcp(cursor=cursor) - all_resources.extend(result.resources) - if result.nextCursor is None: - break - cursor = result.nextCursor - - return all_resources - - async def list_resource_templates_mcp( - self, *, cursor: str | None = None - ) -> mcp.types.ListResourceTemplatesResult: - """Send a resources/listResourceTemplates request and return the complete MCP protocol result. - - Args: - cursor: Optional pagination cursor from a previous request's nextCursor. - - Returns: - mcp.types.ListResourceTemplatesResult: The complete response object from the protocol, - containing the list of resource templates and any additional metadata. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - logger.debug(f"[{self.name}] called list_resource_templates") - - result = await self._await_with_session_monitoring( - self.session.list_resource_templates(cursor=cursor) - ) - return result - - async def list_resource_templates(self) -> list[mcp.types.ResourceTemplate]: - """Retrieve all resource templates available on the server. - - This method automatically fetches all pages if the server paginates results, - returning the complete list. For manual pagination control (e.g., to handle - large result sets incrementally), use list_resource_templates_mcp() with the - cursor parameter. - - Returns: - list[mcp.types.ResourceTemplate]: A list of all ResourceTemplate objects. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - all_templates: list[mcp.types.ResourceTemplate] = [] - cursor: str | None = None - - while True: - result = await self.list_resource_templates_mcp(cursor=cursor) - all_templates.extend(result.resourceTemplates) - if result.nextCursor is None: - break - cursor = result.nextCursor - - return all_templates - - async def read_resource_mcp( - self, uri: AnyUrl | str, meta: dict[str, Any] | None = None - ) -> mcp.types.ReadResourceResult: - """Send a resources/read request and return the complete MCP protocol result. - - Args: - uri (AnyUrl | str): The URI of the resource to read. Can be a string or an AnyUrl object. - meta (dict[str, Any] | None, optional): Request metadata (e.g., for SEP-1686 tasks). Defaults to None. - - Returns: - mcp.types.ReadResourceResult: The complete response object from the protocol, - containing the resource contents and any additional metadata. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - uri_str = str(uri) - with client_span( - f"resources/read {uri_str}", - "resources/read", - uri_str, - session_id=self.transport.get_session_id(), - resource_uri=uri_str, - ): - logger.debug(f"[{self.name}] called read_resource: {uri}") - - if isinstance(uri, str): - uri = AnyUrl(uri) # Ensure AnyUrl - - # Inject trace context into meta for propagation to server - propagated_meta = inject_trace_context(meta) - - # If meta provided, use send_request for SEP-1686 task support - if propagated_meta: - task_dict = propagated_meta.get("modelcontextprotocol.io/task") - request = mcp.types.ReadResourceRequest( - params=mcp.types.ReadResourceRequestParams( - uri=uri, - task=mcp.types.TaskMetadata(**task_dict) if task_dict else None, - _meta=propagated_meta, # type: ignore[unknown-argument] # pydantic alias - ) - ) - result = await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=mcp.types.ReadResourceResult, - ) - ) - else: - result = await self._await_with_session_monitoring( - self.session.read_resource(uri) - ) - return result - - @overload - async def read_resource( - self, - uri: AnyUrl | str, - *, - version: str | None = None, - meta: dict[str, Any] | None = None, - task: Literal[False] = False, - ) -> list[mcp.types.TextResourceContents | mcp.types.BlobResourceContents]: ... - - @overload - async def read_resource( - self, - uri: AnyUrl | str, - *, - version: str | None = None, - meta: dict[str, Any] | None = None, - task: Literal[True], - task_id: str | None = None, - ttl: int = 60000, - ) -> ResourceTask: ... - - async def read_resource( - self, - uri: AnyUrl | str, - *, - version: str | None = None, - meta: dict[str, Any] | None = None, - task: bool = False, - task_id: str | None = None, - ttl: int = 60000, - ) -> ( - list[mcp.types.TextResourceContents | mcp.types.BlobResourceContents] - | ResourceTask - ): - """Read the contents of a resource or resolved template. - - Args: - uri (AnyUrl | str): The URI of the resource to read. Can be a string or an AnyUrl object. - version (str | None): Specific version to read. If None, reads highest version. - meta (dict[str, Any] | None): Optional request-level metadata. - task (bool): If True, execute as background task (SEP-1686). Defaults to False. - task_id (str | None): Optional client-provided task ID (auto-generated if not provided). - ttl (int): Time to keep results available in milliseconds (default 60s). - - Returns: - list[mcp.types.TextResourceContents | mcp.types.BlobResourceContents] | ResourceTask: - A list of content objects if task=False, or a ResourceTask object if task=True. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - # Merge version into request-level meta (not arguments) - request_meta = dict(meta) if meta else {} - if version is not None: - request_meta["fastmcp"] = { - **request_meta.get("fastmcp", {}), - "version": version, - } - - if task: - return await self._read_resource_as_task( - uri, task_id, ttl, meta=request_meta or None - ) - - if isinstance(uri, str): - try: - uri = AnyUrl(uri) # Ensure AnyUrl - except Exception as e: - raise ValueError( - f"Provided resource URI is invalid: {str(uri)!r}" - ) from e - result = await self.read_resource_mcp(uri, meta=request_meta or None) - return result.contents - - async def _read_resource_as_task( - self, - uri: AnyUrl | str, - task_id: str | None = None, - ttl: int = 60000, - meta: dict[str, Any] | None = None, - ) -> ResourceTask: - """Read a resource for background execution (SEP-1686). - - Returns a ResourceTask object that handles both background and immediate execution. - - Args: - uri: Resource URI to read - task_id: Optional client-provided task ID (ignored, for backward compatibility) - ttl: Time to keep results available in milliseconds (default 60s) - meta: Optional metadata to pass with the request (e.g., version info) - - Returns: - ResourceTask: Future-like object for accessing task status and results - """ - # Per SEP-1686 final spec: client sends only ttl, server generates taskId - if isinstance(uri, str): - uri = AnyUrl(uri) - - request = mcp.types.ReadResourceRequest( - params=mcp.types.ReadResourceRequestParams( - uri=uri, - task=mcp.types.TaskMetadata(ttl=ttl), - _meta=meta, # type: ignore[unknown-argument] # pydantic alias - ) - ) - - # Server returns CreateTaskResult (task accepted) or ReadResourceResult (graceful degradation) - TaskResponseUnion = RootModel[ - mcp.types.CreateTaskResult | mcp.types.ReadResourceResult - ] - wrapped_result = await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=TaskResponseUnion, - ) - ) - raw_result = wrapped_result.root - - if isinstance(raw_result, mcp.types.CreateTaskResult): - # Task was accepted - extract task info from CreateTaskResult - server_task_id = raw_result.task.taskId - self._submitted_task_ids.add(server_task_id) - - task_obj = ResourceTask( - self, server_task_id, uri=str(uri), immediate_result=None - ) - self._task_registry[server_task_id] = weakref.ref(task_obj) - return task_obj - else: - # Graceful degradation - server returned ReadResourceResult - synthetic_task_id = task_id or str(uuid.uuid4()) - return ResourceTask( - self, - synthetic_task_id, - uri=str(uri), - immediate_result=raw_result.contents, - ) - - # async def subscribe_resource(self, uri: AnyUrl | str) -> None: - # """Send a resources/subscribe request.""" - # if isinstance(uri, str): - # uri = AnyUrl(uri) - # await self.session.subscribe_resource(uri) - - # async def unsubscribe_resource(self, uri: AnyUrl | str) -> None: - # """Send a resources/unsubscribe request.""" - # if isinstance(uri, str): - # uri = AnyUrl(uri) - # await self.session.unsubscribe_resource(uri) - - # --- Prompts --- - - async def list_prompts_mcp( - self, *, cursor: str | None = None - ) -> mcp.types.ListPromptsResult: - """Send a prompts/list request and return the complete MCP protocol result. - - Args: - cursor: Optional pagination cursor from a previous request's nextCursor. - - Returns: - mcp.types.ListPromptsResult: The complete response object from the protocol, - containing the list of prompts and any additional metadata. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - logger.debug(f"[{self.name}] called list_prompts") - - result = await self._await_with_session_monitoring( - self.session.list_prompts(cursor=cursor) - ) - return result - - async def list_prompts(self) -> list[mcp.types.Prompt]: - """Retrieve all prompts available on the server. - - This method automatically fetches all pages if the server paginates results, - returning the complete list. For manual pagination control (e.g., to handle - large result sets incrementally), use list_prompts_mcp() with the cursor parameter. - - Returns: - list[mcp.types.Prompt]: A list of all Prompt objects. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - all_prompts: list[mcp.types.Prompt] = [] - cursor: str | None = None - - while True: - result = await self.list_prompts_mcp(cursor=cursor) - all_prompts.extend(result.prompts) - if result.nextCursor is None: - break - cursor = result.nextCursor - - return all_prompts - - # --- Prompt --- - async def get_prompt_mcp( - self, - name: str, - arguments: dict[str, Any] | None = None, - meta: dict[str, Any] | None = None, - ) -> mcp.types.GetPromptResult: - """Send a prompts/get request and return the complete MCP protocol result. - - Args: - name (str): The name of the prompt to retrieve. - arguments (dict[str, Any] | None, optional): Arguments to pass to the prompt. Defaults to None. - meta (dict[str, Any] | None, optional): Request metadata (e.g., for SEP-1686 tasks). Defaults to None. - - Returns: - mcp.types.GetPromptResult: The complete response object from the protocol, - containing the prompt messages and any additional metadata. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - with client_span( - f"prompts/get {name}", - "prompts/get", - name, - session_id=self.transport.get_session_id(), - ): - logger.debug(f"[{self.name}] called get_prompt: {name}") - - # Serialize arguments for MCP protocol - convert non-string values to JSON - serialized_arguments: dict[str, str] | None = None - if arguments: - serialized_arguments = {} - for key, value in arguments.items(): - if isinstance(value, str): - serialized_arguments[key] = value - else: - # Use pydantic_core.to_json for consistent serialization - serialized_arguments[key] = pydantic_core.to_json(value).decode( - "utf-8" - ) - - # Inject trace context into meta for propagation to server - propagated_meta = inject_trace_context(meta) - - # If meta provided, use send_request for SEP-1686 task support - if propagated_meta: - task_dict = propagated_meta.get("modelcontextprotocol.io/task") - request = mcp.types.GetPromptRequest( - params=mcp.types.GetPromptRequestParams( - name=name, - arguments=serialized_arguments, - task=mcp.types.TaskMetadata(**task_dict) if task_dict else None, - _meta=propagated_meta, # type: ignore[unknown-argument] # pydantic alias - ) - ) - result = await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=mcp.types.GetPromptResult, - ) - ) - else: - result = await self._await_with_session_monitoring( - self.session.get_prompt(name=name, arguments=serialized_arguments) - ) - return result - - @overload - async def get_prompt( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - version: str | None = None, - meta: dict[str, Any] | None = None, - task: Literal[False] = False, - ) -> mcp.types.GetPromptResult: ... - - @overload - async def get_prompt( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - version: str | None = None, - meta: dict[str, Any] | None = None, - task: Literal[True], - task_id: str | None = None, - ttl: int = 60000, - ) -> PromptTask: ... - - async def get_prompt( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - version: str | None = None, - meta: dict[str, Any] | None = None, - task: bool = False, - task_id: str | None = None, - ttl: int = 60000, - ) -> mcp.types.GetPromptResult | PromptTask: - """Retrieve a rendered prompt message list from the server. - - Args: - name (str): The name of the prompt to retrieve. - arguments (dict[str, Any] | None, optional): Arguments to pass to the prompt. Defaults to None. - version (str | None, optional): Specific prompt version to get. If None, gets highest version. - meta (dict[str, Any] | None): Optional request-level metadata. - task (bool): If True, execute as background task (SEP-1686). Defaults to False. - task_id (str | None): Optional client-provided task ID (auto-generated if not provided). - ttl (int): Time to keep results available in milliseconds (default 60s). - - Returns: - mcp.types.GetPromptResult | PromptTask: The complete response object if task=False, - or a PromptTask object if task=True. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - # Merge version into request-level meta (not arguments) - request_meta = dict(meta) if meta else {} - if version is not None: - request_meta["fastmcp"] = { - **request_meta.get("fastmcp", {}), - "version": version, - } - - if task: - return await self._get_prompt_as_task( - name, arguments, task_id, ttl, meta=request_meta or None - ) - - result = await self.get_prompt_mcp( - name=name, arguments=arguments, meta=request_meta or None - ) - return result - - async def _get_prompt_as_task( - self, - name: str, - arguments: dict[str, Any] | None = None, - task_id: str | None = None, - ttl: int = 60000, - meta: dict[str, Any] | None = None, - ) -> PromptTask: - """Get a prompt for background execution (SEP-1686). - - Returns a PromptTask object that handles both background and immediate execution. - - Args: - name: Prompt name to get - arguments: Prompt arguments - task_id: Optional client-provided task ID (ignored, for backward compatibility) - ttl: Time to keep results available in milliseconds (default 60s) - meta: Optional request metadata (e.g., version info) - - Returns: - PromptTask: Future-like object for accessing task status and results - """ - # Per SEP-1686 final spec: client sends only ttl, server generates taskId - # Serialize arguments for MCP protocol - serialized_arguments: dict[str, str] | None = None - if arguments: - serialized_arguments = {} - for key, value in arguments.items(): - if isinstance(value, str): - serialized_arguments[key] = value - else: - serialized_arguments[key] = pydantic_core.to_json(value).decode( - "utf-8" - ) - - request = mcp.types.GetPromptRequest( - params=mcp.types.GetPromptRequestParams( - name=name, - arguments=serialized_arguments, - task=mcp.types.TaskMetadata(ttl=ttl), - _meta=meta, # type: ignore[unknown-argument] # pydantic alias - ) - ) - - # Server returns CreateTaskResult (task accepted) or GetPromptResult (graceful degradation) - TaskResponseUnion = RootModel[ - mcp.types.CreateTaskResult | mcp.types.GetPromptResult - ] - wrapped_result = await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=TaskResponseUnion, - ) - ) - raw_result = wrapped_result.root - - if isinstance(raw_result, mcp.types.CreateTaskResult): - # Task was accepted - extract task info from CreateTaskResult - server_task_id = raw_result.task.taskId - self._submitted_task_ids.add(server_task_id) - - task_obj = PromptTask( - self, server_task_id, prompt_name=name, immediate_result=None - ) - self._task_registry[server_task_id] = weakref.ref(task_obj) - return task_obj - else: - # Graceful degradation - server returned GetPromptResult - synthetic_task_id = task_id or str(uuid.uuid4()) - return PromptTask( - self, synthetic_task_id, prompt_name=name, immediate_result=raw_result - ) - # --- Completion --- async def complete_mcp( @@ -1426,456 +837,6 @@ async def complete( ) return result.completion - # --- Tools --- - - async def list_tools_mcp( - self, *, cursor: str | None = None - ) -> mcp.types.ListToolsResult: - """Send a tools/list request and return the complete MCP protocol result. - - Args: - cursor: Optional pagination cursor from a previous request's nextCursor. - - Returns: - mcp.types.ListToolsResult: The complete response object from the protocol, - containing the list of tools and any additional metadata. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - logger.debug(f"[{self.name}] called list_tools") - - result = await self._await_with_session_monitoring( - self.session.list_tools(cursor=cursor) - ) - return result - - async def list_tools(self) -> list[mcp.types.Tool]: - """Retrieve all tools available on the server. - - This method automatically fetches all pages if the server paginates results, - returning the complete list. For manual pagination control (e.g., to handle - large result sets incrementally), use list_tools_mcp() with the cursor parameter. - - Returns: - list[mcp.types.Tool]: A list of all Tool objects. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the request results in a TimeoutError | JSONRPCError - """ - all_tools: list[mcp.types.Tool] = [] - cursor: str | None = None - - while True: - result = await self.list_tools_mcp(cursor=cursor) - all_tools.extend(result.tools) - if result.nextCursor is None: - break - cursor = result.nextCursor - - return all_tools - - # --- Call Tool --- - - async def call_tool_mcp( - self, - name: str, - arguments: dict[str, Any], - progress_handler: ProgressHandler | None = None, - timeout: datetime.timedelta | float | int | None = None, - meta: dict[str, Any] | None = None, - ) -> mcp.types.CallToolResult: - """Send a tools/call request and return the complete MCP protocol result. - - This method returns the raw CallToolResult object, which includes an isError flag - and other metadata. It does not raise an exception if the tool call results in an error. - - Args: - name (str): The name of the tool to call. - arguments (dict[str, Any]): Arguments to pass to the tool. - timeout (datetime.timedelta | float | int | None, optional): The timeout for the tool call. Defaults to None. - progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None. - meta (dict[str, Any] | None, optional): Additional metadata to include with the request. - This is useful for passing contextual information (like user IDs, trace IDs, or preferences) - that shouldn't be tool arguments but may influence server-side processing. The server - can access this via `context.request_context.meta`. Defaults to None. - - Returns: - mcp.types.CallToolResult: The complete response object from the protocol, - containing the tool result and any additional metadata. - - Raises: - RuntimeError: If called while the client is not connected. - McpError: If the tool call requests results in a TimeoutError | JSONRPCError - """ - with client_span( - f"tools/call {name}", - "tools/call", - name, - session_id=self.transport.get_session_id(), - ): - logger.debug(f"[{self.name}] called call_tool: {name}") - - # Convert timeout to timedelta if needed - if isinstance(timeout, int | float): - timeout = datetime.timedelta(seconds=float(timeout)) - - # Inject trace context into meta for propagation to server - propagated_meta = inject_trace_context(meta) - - result = await self._await_with_session_monitoring( - self.session.call_tool( - name=name, - arguments=arguments, - read_timeout_seconds=timeout, - progress_callback=progress_handler or self._progress_handler, - meta=propagated_meta if propagated_meta else None, - ) - ) - return result - - async def _parse_call_tool_result( - self, name: str, result: mcp.types.CallToolResult, raise_on_error: bool = False - ) -> CallToolResult: - """Parse an mcp.types.CallToolResult into our CallToolResult dataclass. - - Args: - name: Tool name (for schema lookup) - result: Raw MCP protocol result - raise_on_error: Whether to raise ToolError on errors - - Returns: - CallToolResult: Parsed result with structured data - """ - data = None - if result.isError and raise_on_error: - msg = cast(mcp.types.TextContent, result.content[0]).text - raise ToolError(msg) - elif result.structuredContent: - try: - if name not in self.session._tool_output_schemas: - await self.session.list_tools() - if name in self.session._tool_output_schemas: - output_schema = self.session._tool_output_schemas.get(name) - if output_schema: - if output_schema.get("x-fastmcp-wrap-result"): - output_schema = output_schema.get("properties", {}).get( - "result" - ) - structured_content = result.structuredContent.get("result") - else: - structured_content = result.structuredContent - output_type = json_schema_to_type(output_schema) - type_adapter = get_cached_typeadapter(output_type) - data = type_adapter.validate_python(structured_content) - else: - data = result.structuredContent - except Exception as e: - logger.error(f"[{self.name}] Error parsing structured content: {e}") - - return CallToolResult( - content=result.content, - structured_content=result.structuredContent, - meta=result.meta, - data=data, - is_error=result.isError, - ) - - @overload - async def call_tool( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - version: str | None = None, - timeout: datetime.timedelta | float | int | None = None, - progress_handler: ProgressHandler | None = None, - raise_on_error: bool = True, - meta: dict[str, Any] | None = None, - task: Literal[False] = False, - ) -> CallToolResult: ... - - @overload - async def call_tool( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - version: str | None = None, - timeout: datetime.timedelta | float | int | None = None, - progress_handler: ProgressHandler | None = None, - raise_on_error: bool = True, - meta: dict[str, Any] | None = None, - task: Literal[True], - task_id: str | None = None, - ttl: int = 60000, - ) -> ToolTask: ... - - async def call_tool( - self, - name: str, - arguments: dict[str, Any] | None = None, - *, - version: str | None = None, - timeout: datetime.timedelta | float | int | None = None, - progress_handler: ProgressHandler | None = None, - raise_on_error: bool = True, - meta: dict[str, Any] | None = None, - task: bool = False, - task_id: str | None = None, - ttl: int = 60000, - ) -> CallToolResult | ToolTask: - """Call a tool on the server. - - Unlike call_tool_mcp, this method raises a ToolError if the tool call results in an error. - - Args: - name (str): The name of the tool to call. - arguments (dict[str, Any] | None, optional): Arguments to pass to the tool. Defaults to None. - version (str | None, optional): Specific tool version to call. If None, calls highest version. - timeout (datetime.timedelta | float | int | None, optional): The timeout for the tool call. Defaults to None. - progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None. - raise_on_error (bool, optional): Whether to raise an exception if the tool call results in an error. Defaults to True. - meta (dict[str, Any] | None, optional): Additional metadata to include with the request. - This is useful for passing contextual information (like user IDs, trace IDs, or preferences) - that shouldn't be tool arguments but may influence server-side processing. The server - can access this via `context.request_context.meta`. Defaults to None. - task (bool): If True, execute as background task (SEP-1686). Defaults to False. - task_id (str | None): Optional client-provided task ID (auto-generated if not provided). - ttl (int): Time to keep results available in milliseconds (default 60s). - - Returns: - CallToolResult | ToolTask: The content returned by the tool if task=False, - or a ToolTask object if task=True. If the tool returns structured - outputs, they are returned as a dataclass (if an output schema - is available) or a dictionary; otherwise, a list of content - blocks is returned. Note: to receive both structured and - unstructured outputs, use call_tool_mcp instead and access the - raw result object. - - Raises: - ToolError: If the tool call results in an error. - McpError: If the tool call request results in a TimeoutError | JSONRPCError - RuntimeError: If called while the client is not connected. - """ - # Merge version into request-level meta (not arguments) - request_meta = dict(meta) if meta else {} - if version is not None: - request_meta["fastmcp"] = { - **request_meta.get("fastmcp", {}), - "version": version, - } - - if task: - return await self._call_tool_as_task( - name, arguments, task_id, ttl, meta=request_meta or None - ) - - result = await self.call_tool_mcp( - name=name, - arguments=arguments or {}, - timeout=timeout, - progress_handler=progress_handler, - meta=request_meta or None, - ) - return await self._parse_call_tool_result( - name, result, raise_on_error=raise_on_error - ) - - async def _call_tool_as_task( - self, - name: str, - arguments: dict[str, Any] | None = None, - task_id: str | None = None, - ttl: int = 60000, - meta: dict[str, Any] | None = None, - ) -> ToolTask: - """Call a tool for background execution (SEP-1686). - - Returns a ToolTask object that handles both background and immediate execution. - If the server accepts background execution, ToolTask will poll for results. - If the server declines (graceful degradation), ToolTask wraps the immediate result. - - Args: - name: Tool name to call - arguments: Tool arguments - task_id: Optional client-provided task ID (ignored, for backward compatibility) - ttl: Time to keep results available in milliseconds (default 60s) - meta: Optional request metadata (e.g., version info) - - Returns: - ToolTask: Future-like object for accessing task status and results - """ - # Per SEP-1686 final spec: client sends only ttl, server generates taskId - # Build request with task metadata - request = mcp.types.CallToolRequest( - params=mcp.types.CallToolRequestParams( - name=name, - arguments=arguments or {}, - task=mcp.types.TaskMetadata(ttl=ttl), - _meta=meta, # type: ignore[unknown-argument] # pydantic alias - ) - ) - - # Server returns CreateTaskResult (task accepted) or CallToolResult (graceful degradation) - # Use RootModel with Union to handle both response types (SDK calls model_validate) - TaskResponseUnion = RootModel[ - mcp.types.CreateTaskResult | mcp.types.CallToolResult - ] - wrapped_result = await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=TaskResponseUnion, - ) - ) - raw_result = wrapped_result.root - - if isinstance(raw_result, mcp.types.CreateTaskResult): - # Task was accepted - extract task info from CreateTaskResult - server_task_id = raw_result.task.taskId - self._submitted_task_ids.add(server_task_id) - - task_obj = ToolTask( - self, server_task_id, tool_name=name, immediate_result=None - ) - self._task_registry[server_task_id] = weakref.ref(task_obj) - return task_obj - else: - # Graceful degradation - server returned CallToolResult - parsed_result = await self._parse_call_tool_result(name, raw_result) - synthetic_task_id = task_id or str(uuid.uuid4()) - return ToolTask( - self, synthetic_task_id, tool_name=name, immediate_result=parsed_result - ) - - async def get_task_status(self, task_id: str) -> GetTaskResult: - """Query the status of a background task. - - Sends a 'tasks/get' MCP protocol request over the existing transport. - - Args: - task_id: The task ID returned from call_tool_as_task - - Returns: - GetTaskResult: Status information including taskId, status, pollInterval, etc. - - Raises: - RuntimeError: If client not connected - McpError: If the request results in a TimeoutError | JSONRPCError - """ - request = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id)) - return await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=GetTaskResult, - ) - ) - - async def get_task_result(self, task_id: str) -> Any: - """Retrieve the raw result of a completed background task. - - Sends a 'tasks/result' MCP protocol request over the existing transport. - Returns the raw result - callers should parse it appropriately. - - Args: - task_id: The task ID returned from call_tool_as_task - - Returns: - Any: The raw result (could be tool, prompt, or resource result) - - Raises: - RuntimeError: If client not connected, task not found, or task failed - McpError: If the request results in a TimeoutError | JSONRPCError - """ - request = GetTaskPayloadRequest( - params=GetTaskPayloadRequestParams(taskId=task_id) - ) - # Return raw result - Task classes handle type-specific parsing - result = await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[arg-type] - result_type=GetTaskPayloadResult, - ) - ) - # Return as dict for compatibility with Task class parsing - return result.model_dump(exclude_none=True, by_alias=True) - - async def list_tasks( - self, - cursor: str | None = None, - limit: int = 50, - ) -> dict[str, Any]: - """List background tasks. - - Sends a 'tasks/list' MCP protocol request to the server. If the server - returns an empty list (indicating client-side tracking), falls back to - querying status for locally tracked task IDs. - - Args: - cursor: Optional pagination cursor - limit: Maximum number of tasks to return (default 50) - - Returns: - dict: Response with structure: - - tasks: List of task status dicts with taskId, status, etc. - - nextCursor: Optional cursor for next page - - Raises: - RuntimeError: If client not connected - McpError: If the request results in a TimeoutError | JSONRPCError - """ - # Send protocol request - params = PaginatedRequestParams(cursor=cursor, limit=limit) # type: ignore[call-arg] # Optional field in MCP SDK - request = ListTasksRequest(params=params) - server_response = await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[invalid-argument-type] - result_type=mcp.types.ListTasksResult, - ) - ) - - # If server returned tasks, use those - if server_response.tasks: - return server_response.model_dump(by_alias=True) - - # Server returned empty - fall back to client-side tracking - tasks = [] - for task_id in list(self._submitted_task_ids)[:limit]: - try: - status = await self.get_task_status(task_id) - tasks.append(status.model_dump(by_alias=True)) - except Exception: - # Task may have expired or been deleted, skip it - continue - - return {"tasks": tasks, "nextCursor": None} - - async def cancel_task(self, task_id: str) -> mcp.types.CancelTaskResult: - """Cancel a task, transitioning it to cancelled state. - - Sends a 'tasks/cancel' MCP protocol request. Task will halt execution - and transition to cancelled state. - - Args: - task_id: The task ID to cancel - - Returns: - CancelTaskResult: The task status showing cancelled state - - Raises: - RuntimeError: If task doesn't exist - McpError: If the request results in a TimeoutError | JSONRPCError - """ - request = CancelTaskRequest(params=CancelTaskRequestParams(taskId=task_id)) - return await self._await_with_session_monitoring( - self.session.send_request( - request=request, # type: ignore[invalid-argument-type] - result_type=mcp.types.CancelTaskResult, - ) - ) - @classmethod def generate_name(cls, name: str | None = None) -> str: class_name = cls.__name__ diff --git a/src/fastmcp/client/prompts.py b/src/fastmcp/client/prompts.py new file mode 100644 index 0000000000..e7c8df80de --- /dev/null +++ b/src/fastmcp/client/prompts.py @@ -0,0 +1,295 @@ +"""Prompt-related methods for FastMCP Client.""" + +from __future__ import annotations + +import uuid +import weakref +from typing import TYPE_CHECKING, Any, Literal, overload + +import mcp.types +import pydantic_core +from pydantic import RootModel + +if TYPE_CHECKING: + from fastmcp.client.client import Client + +from fastmcp.client.tasks import PromptTask +from fastmcp.client.telemetry import client_span +from fastmcp.telemetry import inject_trace_context +from fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + +# Type alias for task response union (SEP-1686 graceful degradation) +PromptTaskResponseUnion = RootModel[ + mcp.types.CreateTaskResult | mcp.types.GetPromptResult +] + + +class ClientPromptsMixin: + """Mixin providing prompt-related methods for Client.""" + + # --- Prompts --- + + async def list_prompts_mcp( + self: Client, *, cursor: str | None = None + ) -> mcp.types.ListPromptsResult: + """Send a prompts/list request and return the complete MCP protocol result. + + Args: + cursor: Optional pagination cursor from a previous request's nextCursor. + + Returns: + mcp.types.ListPromptsResult: The complete response object from the protocol, + containing the list of prompts and any additional metadata. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + logger.debug(f"[{self.name}] called list_prompts") + + result = await self._await_with_session_monitoring( + self.session.list_prompts(cursor=cursor) + ) + return result + + async def list_prompts(self: Client) -> list[mcp.types.Prompt]: + """Retrieve all prompts available on the server. + + This method automatically fetches all pages if the server paginates results, + returning the complete list. For manual pagination control (e.g., to handle + large result sets incrementally), use list_prompts_mcp() with the cursor parameter. + + Returns: + list[mcp.types.Prompt]: A list of all Prompt objects. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + all_prompts: list[mcp.types.Prompt] = [] + cursor: str | None = None + + while True: + result = await self.list_prompts_mcp(cursor=cursor) + all_prompts.extend(result.prompts) + if result.nextCursor is None: + break + cursor = result.nextCursor + + return all_prompts + + # --- Prompt --- + async def get_prompt_mcp( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + meta: dict[str, Any] | None = None, + ) -> mcp.types.GetPromptResult: + """Send a prompts/get request and return the complete MCP protocol result. + + Args: + name (str): The name of the prompt to retrieve. + arguments (dict[str, Any] | None, optional): Arguments to pass to the prompt. Defaults to None. + meta (dict[str, Any] | None, optional): Request metadata (e.g., for SEP-1686 tasks). Defaults to None. + + Returns: + mcp.types.GetPromptResult: The complete response object from the protocol, + containing the prompt messages and any additional metadata. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + with client_span( + f"prompts/get {name}", + "prompts/get", + name, + session_id=self.transport.get_session_id(), + ): + logger.debug(f"[{self.name}] called get_prompt: {name}") + + # Serialize arguments for MCP protocol - convert non-string values to JSON + serialized_arguments: dict[str, str] | None = None + if arguments: + serialized_arguments = {} + for key, value in arguments.items(): + if isinstance(value, str): + serialized_arguments[key] = value + else: + # Use pydantic_core.to_json for consistent serialization + serialized_arguments[key] = pydantic_core.to_json(value).decode( + "utf-8" + ) + + # Inject trace context into meta for propagation to server + propagated_meta = inject_trace_context(meta) + + # If meta provided, use send_request for SEP-1686 task support + if propagated_meta: + task_dict = propagated_meta.get("modelcontextprotocol.io/task") + request = mcp.types.GetPromptRequest( + params=mcp.types.GetPromptRequestParams( + name=name, + arguments=serialized_arguments, + task=mcp.types.TaskMetadata(**task_dict) if task_dict else None, + _meta=propagated_meta, # type: ignore[unknown-argument] # pydantic alias + ) + ) + result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=mcp.types.GetPromptResult, + ) + ) + else: + result = await self._await_with_session_monitoring( + self.session.get_prompt(name=name, arguments=serialized_arguments) + ) + return result + + @overload + async def get_prompt( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + *, + version: str | None = None, + meta: dict[str, Any] | None = None, + task: Literal[False] = False, + ) -> mcp.types.GetPromptResult: ... + + @overload + async def get_prompt( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + *, + version: str | None = None, + meta: dict[str, Any] | None = None, + task: Literal[True], + task_id: str | None = None, + ttl: int = 60000, + ) -> PromptTask: ... + + async def get_prompt( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + *, + version: str | None = None, + meta: dict[str, Any] | None = None, + task: bool = False, + task_id: str | None = None, + ttl: int = 60000, + ) -> mcp.types.GetPromptResult | PromptTask: + """Retrieve a rendered prompt message list from the server. + + Args: + name (str): The name of the prompt to retrieve. + arguments (dict[str, Any] | None, optional): Arguments to pass to the prompt. Defaults to None. + version (str | None, optional): Specific prompt version to get. If None, gets highest version. + meta (dict[str, Any] | None): Optional request-level metadata. + task (bool): If True, execute as background task (SEP-1686). Defaults to False. + task_id (str | None): Optional client-provided task ID (auto-generated if not provided). + ttl (int): Time to keep results available in milliseconds (default 60s). + + Returns: + mcp.types.GetPromptResult | PromptTask: The complete response object if task=False, + or a PromptTask object if task=True. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + # Merge version into request-level meta (not arguments) + request_meta = dict(meta) if meta else {} + if version is not None: + request_meta["fastmcp"] = { + **request_meta.get("fastmcp", {}), + "version": version, + } + + if task: + return await self._get_prompt_as_task( + name, arguments, task_id, ttl, meta=request_meta or None + ) + + result = await self.get_prompt_mcp( + name=name, arguments=arguments, meta=request_meta or None + ) + return result + + async def _get_prompt_as_task( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + task_id: str | None = None, + ttl: int = 60000, + meta: dict[str, Any] | None = None, + ) -> PromptTask: + """Get a prompt for background execution (SEP-1686). + + Returns a PromptTask object that handles both background and immediate execution. + + Args: + name: Prompt name to get + arguments: Prompt arguments + task_id: Optional client-provided task ID (ignored, for backward compatibility) + ttl: Time to keep results available in milliseconds (default 60s) + meta: Optional request metadata (e.g., version info) + + Returns: + PromptTask: Future-like object for accessing task status and results + """ + # Per SEP-1686 final spec: client sends only ttl, server generates taskId + # Inject trace context into meta for propagation to server + propagated_meta = inject_trace_context(meta) + + # Serialize arguments for MCP protocol + serialized_arguments: dict[str, str] | None = None + if arguments: + serialized_arguments = {} + for key, value in arguments.items(): + if isinstance(value, str): + serialized_arguments[key] = value + else: + serialized_arguments[key] = pydantic_core.to_json(value).decode( + "utf-8" + ) + + request = mcp.types.GetPromptRequest( + params=mcp.types.GetPromptRequestParams( + name=name, + arguments=serialized_arguments, + task=mcp.types.TaskMetadata(ttl=ttl), + _meta=propagated_meta, # type: ignore[unknown-argument] # pydantic alias + ) + ) + + # Server returns CreateTaskResult (task accepted) or GetPromptResult (graceful degradation) + wrapped_result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=PromptTaskResponseUnion, + ) + ) + raw_result = wrapped_result.root + + if isinstance(raw_result, mcp.types.CreateTaskResult): + # Task was accepted - extract task info from CreateTaskResult + server_task_id = raw_result.task.taskId + self._submitted_task_ids.add(server_task_id) + + task_obj = PromptTask( + self, server_task_id, prompt_name=name, immediate_result=None + ) + self._task_registry[server_task_id] = weakref.ref(task_obj) + return task_obj + else: + # Graceful degradation - server returned GetPromptResult + synthetic_task_id = task_id or str(uuid.uuid4()) + return PromptTask( + self, synthetic_task_id, prompt_name=name, immediate_result=raw_result + ) diff --git a/src/fastmcp/client/resources.py b/src/fastmcp/client/resources.py new file mode 100644 index 0000000000..86de910d9b --- /dev/null +++ b/src/fastmcp/client/resources.py @@ -0,0 +1,325 @@ +"""Resource-related methods for FastMCP Client.""" + +from __future__ import annotations + +import uuid +import weakref +from typing import TYPE_CHECKING, Any, Literal, overload + +import mcp.types +from pydantic import AnyUrl, RootModel + +if TYPE_CHECKING: + from fastmcp.client.client import Client + +from fastmcp.client.tasks import ResourceTask +from fastmcp.client.telemetry import client_span +from fastmcp.telemetry import inject_trace_context +from fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + +# Type alias for task response union (SEP-1686 graceful degradation) +ResourceTaskResponseUnion = RootModel[ + mcp.types.CreateTaskResult | mcp.types.ReadResourceResult +] + + +class ClientResourcesMixin: + """Mixin providing resource-related methods for Client.""" + + # --- Resources --- + + async def list_resources_mcp( + self: Client, *, cursor: str | None = None + ) -> mcp.types.ListResourcesResult: + """Send a resources/list request and return the complete MCP protocol result. + + Args: + cursor: Optional pagination cursor from a previous request's nextCursor. + + Returns: + mcp.types.ListResourcesResult: The complete response object from the protocol, + containing the list of resources and any additional metadata. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + logger.debug(f"[{self.name}] called list_resources") + + result = await self._await_with_session_monitoring( + self.session.list_resources(cursor=cursor) + ) + return result + + async def list_resources(self: Client) -> list[mcp.types.Resource]: + """Retrieve all resources available on the server. + + This method automatically fetches all pages if the server paginates results, + returning the complete list. For manual pagination control (e.g., to handle + large result sets incrementally), use list_resources_mcp() with the cursor parameter. + + Returns: + list[mcp.types.Resource]: A list of all Resource objects. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + all_resources: list[mcp.types.Resource] = [] + cursor: str | None = None + + while True: + result = await self.list_resources_mcp(cursor=cursor) + all_resources.extend(result.resources) + if result.nextCursor is None: + break + cursor = result.nextCursor + + return all_resources + + async def list_resource_templates_mcp( + self: Client, *, cursor: str | None = None + ) -> mcp.types.ListResourceTemplatesResult: + """Send a resources/listResourceTemplates request and return the complete MCP protocol result. + + Args: + cursor: Optional pagination cursor from a previous request's nextCursor. + + Returns: + mcp.types.ListResourceTemplatesResult: The complete response object from the protocol, + containing the list of resource templates and any additional metadata. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + logger.debug(f"[{self.name}] called list_resource_templates") + + result = await self._await_with_session_monitoring( + self.session.list_resource_templates(cursor=cursor) + ) + return result + + async def list_resource_templates(self: Client) -> list[mcp.types.ResourceTemplate]: + """Retrieve all resource templates available on the server. + + This method automatically fetches all pages if the server paginates results, + returning the complete list. For manual pagination control (e.g., to handle + large result sets incrementally), use list_resource_templates_mcp() with the + cursor parameter. + + Returns: + list[mcp.types.ResourceTemplate]: A list of all ResourceTemplate objects. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + all_templates: list[mcp.types.ResourceTemplate] = [] + cursor: str | None = None + + while True: + result = await self.list_resource_templates_mcp(cursor=cursor) + all_templates.extend(result.resourceTemplates) + if result.nextCursor is None: + break + cursor = result.nextCursor + + return all_templates + + async def read_resource_mcp( + self: Client, uri: AnyUrl | str, meta: dict[str, Any] | None = None + ) -> mcp.types.ReadResourceResult: + """Send a resources/read request and return the complete MCP protocol result. + + Args: + uri (AnyUrl | str): The URI of the resource to read. Can be a string or an AnyUrl object. + meta (dict[str, Any] | None, optional): Request metadata (e.g., for SEP-1686 tasks). Defaults to None. + + Returns: + mcp.types.ReadResourceResult: The complete response object from the protocol, + containing the resource contents and any additional metadata. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + uri_str = str(uri) + with client_span( + f"resources/read {uri_str}", + "resources/read", + uri_str, + session_id=self.transport.get_session_id(), + resource_uri=uri_str, + ): + logger.debug(f"[{self.name}] called read_resource: {uri}") + + if isinstance(uri, str): + uri = AnyUrl(uri) # Ensure AnyUrl + + # Inject trace context into meta for propagation to server + propagated_meta = inject_trace_context(meta) + + # If meta provided, use send_request for SEP-1686 task support + if propagated_meta: + task_dict = propagated_meta.get("modelcontextprotocol.io/task") + request = mcp.types.ReadResourceRequest( + params=mcp.types.ReadResourceRequestParams( + uri=uri, + task=mcp.types.TaskMetadata(**task_dict) if task_dict else None, + _meta=propagated_meta, # type: ignore[unknown-argument] # pydantic alias + ) + ) + result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=mcp.types.ReadResourceResult, + ) + ) + else: + result = await self._await_with_session_monitoring( + self.session.read_resource(uri) + ) + return result + + @overload + async def read_resource( + self: Client, + uri: AnyUrl | str, + *, + version: str | None = None, + meta: dict[str, Any] | None = None, + task: Literal[False] = False, + ) -> list[mcp.types.TextResourceContents | mcp.types.BlobResourceContents]: ... + + @overload + async def read_resource( + self: Client, + uri: AnyUrl | str, + *, + version: str | None = None, + meta: dict[str, Any] | None = None, + task: Literal[True], + task_id: str | None = None, + ttl: int = 60000, + ) -> ResourceTask: ... + + async def read_resource( + self: Client, + uri: AnyUrl | str, + *, + version: str | None = None, + meta: dict[str, Any] | None = None, + task: bool = False, + task_id: str | None = None, + ttl: int = 60000, + ) -> ( + list[mcp.types.TextResourceContents | mcp.types.BlobResourceContents] + | ResourceTask + ): + """Read the contents of a resource or resolved template. + + Args: + uri (AnyUrl | str): The URI of the resource to read. Can be a string or an AnyUrl object. + version (str | None): Specific version to read. If None, reads highest version. + meta (dict[str, Any] | None): Optional request-level metadata. + task (bool): If True, execute as background task (SEP-1686). Defaults to False. + task_id (str | None): Optional client-provided task ID (auto-generated if not provided). + ttl (int): Time to keep results available in milliseconds (default 60s). + + Returns: + list[mcp.types.TextResourceContents | mcp.types.BlobResourceContents] | ResourceTask: + A list of content objects if task=False, or a ResourceTask object if task=True. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + # Merge version into request-level meta (not arguments) + request_meta = dict(meta) if meta else {} + if version is not None: + request_meta["fastmcp"] = { + **request_meta.get("fastmcp", {}), + "version": version, + } + + if task: + return await self._read_resource_as_task( + uri, task_id, ttl, meta=request_meta or None + ) + + if isinstance(uri, str): + try: + uri = AnyUrl(uri) # Ensure AnyUrl + except Exception as e: + raise ValueError( + f"Provided resource URI is invalid: {str(uri)!r}" + ) from e + result = await self.read_resource_mcp(uri, meta=request_meta or None) + return result.contents + + async def _read_resource_as_task( + self: Client, + uri: AnyUrl | str, + task_id: str | None = None, + ttl: int = 60000, + meta: dict[str, Any] | None = None, + ) -> ResourceTask: + """Read a resource for background execution (SEP-1686). + + Returns a ResourceTask object that handles both background and immediate execution. + + Args: + uri: Resource URI to read + task_id: Optional client-provided task ID (ignored, for backward compatibility) + ttl: Time to keep results available in milliseconds (default 60s) + meta: Optional metadata to pass with the request (e.g., version info) + + Returns: + ResourceTask: Future-like object for accessing task status and results + """ + # Per SEP-1686 final spec: client sends only ttl, server generates taskId + # Inject trace context into meta for propagation to server + propagated_meta = inject_trace_context(meta) + + if isinstance(uri, str): + uri = AnyUrl(uri) + + request = mcp.types.ReadResourceRequest( + params=mcp.types.ReadResourceRequestParams( + uri=uri, + task=mcp.types.TaskMetadata(ttl=ttl), + _meta=propagated_meta, # type: ignore[unknown-argument] # pydantic alias + ) + ) + + # Server returns CreateTaskResult (task accepted) or ReadResourceResult (graceful degradation) + wrapped_result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=ResourceTaskResponseUnion, + ) + ) + raw_result = wrapped_result.root + + if isinstance(raw_result, mcp.types.CreateTaskResult): + # Task was accepted - extract task info from CreateTaskResult + server_task_id = raw_result.task.taskId + self._submitted_task_ids.add(server_task_id) + + task_obj = ResourceTask( + self, server_task_id, uri=str(uri), immediate_result=None + ) + self._task_registry[server_task_id] = weakref.ref(task_obj) + return task_obj + else: + # Graceful degradation - server returned ReadResourceResult + synthetic_task_id = task_id or str(uuid.uuid4()) + return ResourceTask( + self, + synthetic_task_id, + uri=str(uri), + immediate_result=raw_result.contents, + ) diff --git a/src/fastmcp/client/task_management.py b/src/fastmcp/client/task_management.py new file mode 100644 index 0000000000..4c192e55e9 --- /dev/null +++ b/src/fastmcp/client/task_management.py @@ -0,0 +1,157 @@ +"""Task management methods for FastMCP Client.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import mcp.types +from mcp import McpError + +if TYPE_CHECKING: + from fastmcp.client.client import Client +from mcp.types import ( + CancelTaskRequest, + CancelTaskRequestParams, + GetTaskPayloadRequest, + GetTaskPayloadRequestParams, + GetTaskPayloadResult, + GetTaskRequest, + GetTaskRequestParams, + GetTaskResult, + ListTasksRequest, + PaginatedRequestParams, +) + +from fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + + +class ClientTaskManagementMixin: + """Mixin providing task management methods for Client.""" + + async def get_task_status(self: Client, task_id: str) -> GetTaskResult: + """Query the status of a background task. + + Sends a 'tasks/get' MCP protocol request over the existing transport. + + Args: + task_id: The task ID returned from call_tool_as_task + + Returns: + GetTaskResult: Status information including taskId, status, pollInterval, etc. + + Raises: + RuntimeError: If client not connected + McpError: If the request results in a TimeoutError | JSONRPCError + """ + request = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id)) + return await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=GetTaskResult, + ) + ) + + async def get_task_result(self: Client, task_id: str) -> Any: + """Retrieve the raw result of a completed background task. + + Sends a 'tasks/result' MCP protocol request over the existing transport. + Returns the raw result - callers should parse it appropriately. + + Args: + task_id: The task ID returned from call_tool_as_task + + Returns: + Any: The raw result (could be tool, prompt, or resource result) + + Raises: + RuntimeError: If client not connected, task not found, or task failed + McpError: If the request results in a TimeoutError | JSONRPCError + """ + request = GetTaskPayloadRequest( + params=GetTaskPayloadRequestParams(taskId=task_id) + ) + # Return raw result - Task classes handle type-specific parsing + result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=GetTaskPayloadResult, + ) + ) + # Return as dict for compatibility with Task class parsing + return result.model_dump(exclude_none=True, by_alias=True) + + async def list_tasks( + self: Client, + cursor: str | None = None, + limit: int = 50, + ) -> dict[str, Any]: + """List background tasks. + + Sends a 'tasks/list' MCP protocol request to the server. If the server + returns an empty list (indicating client-side tracking), falls back to + querying status for locally tracked task IDs. + + Args: + cursor: Optional pagination cursor + limit: Maximum number of tasks to return (default 50) + + Returns: + dict: Response with structure: + - tasks: List of task status dicts with taskId, status, etc. + - nextCursor: Optional cursor for next page + + Raises: + RuntimeError: If client not connected + McpError: If the request results in a TimeoutError | JSONRPCError + """ + # Send protocol request + params = PaginatedRequestParams(cursor=cursor, limit=limit) # type: ignore[call-arg] # Optional field in MCP SDK + request = ListTasksRequest(params=params) + server_response = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[invalid-argument-type] + result_type=mcp.types.ListTasksResult, + ) + ) + + # If server returned tasks, use those + if server_response.tasks: + return server_response.model_dump(by_alias=True) + + # Server returned empty - fall back to client-side tracking + tasks = [] + for task_id in list(self._submitted_task_ids)[:limit]: + try: + status = await self.get_task_status(task_id) + tasks.append(status.model_dump(by_alias=True)) + except McpError: + # Task may have expired or been deleted, skip it + continue + + return {"tasks": tasks, "nextCursor": None} + + async def cancel_task(self: Client, task_id: str) -> mcp.types.CancelTaskResult: + """Cancel a task, transitioning it to cancelled state. + + Sends a 'tasks/cancel' MCP protocol request. Task will halt execution + and transition to cancelled state. + + Args: + task_id: The task ID to cancel + + Returns: + CancelTaskResult: The task status showing cancelled state + + Raises: + RuntimeError: If task doesn't exist + McpError: If the request results in a TimeoutError | JSONRPCError + """ + request = CancelTaskRequest(params=CancelTaskRequestParams(taskId=task_id)) + return await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[invalid-argument-type] + result_type=mcp.types.CancelTaskResult, + ) + ) diff --git a/src/fastmcp/client/tools_client.py b/src/fastmcp/client/tools_client.py new file mode 100644 index 0000000000..557af20be8 --- /dev/null +++ b/src/fastmcp/client/tools_client.py @@ -0,0 +1,397 @@ +"""Tool-related methods for FastMCP Client.""" + +from __future__ import annotations + +import uuid +import weakref +from typing import TYPE_CHECKING, Any, Literal, overload + +import mcp.types +from pydantic import RootModel + +if TYPE_CHECKING: + import datetime + + from fastmcp.client.client import CallToolResult, Client +from fastmcp.client.progress import ProgressHandler +from fastmcp.client.tasks import ToolTask +from fastmcp.client.telemetry import client_span +from fastmcp.exceptions import ToolError +from fastmcp.telemetry import inject_trace_context +from fastmcp.utilities.json_schema_type import json_schema_to_type +from fastmcp.utilities.logging import get_logger +from fastmcp.utilities.timeout import normalize_timeout_to_timedelta +from fastmcp.utilities.types import get_cached_typeadapter + +logger = get_logger(__name__) + +# Type alias for task response union (SEP-1686 graceful degradation) +ToolTaskResponseUnion = RootModel[mcp.types.CreateTaskResult | mcp.types.CallToolResult] + + +class ClientToolsMixin: + """Mixin providing tool-related methods for Client.""" + + # --- Tools --- + + async def list_tools_mcp( + self: Client, *, cursor: str | None = None + ) -> mcp.types.ListToolsResult: + """Send a tools/list request and return the complete MCP protocol result. + + Args: + cursor: Optional pagination cursor from a previous request's nextCursor. + + Returns: + mcp.types.ListToolsResult: The complete response object from the protocol, + containing the list of tools and any additional metadata. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + logger.debug(f"[{self.name}] called list_tools") + + result = await self._await_with_session_monitoring( + self.session.list_tools(cursor=cursor) + ) + return result + + async def list_tools(self: Client) -> list[mcp.types.Tool]: + """Retrieve all tools available on the server. + + This method automatically fetches all pages if the server paginates results, + returning the complete list. For manual pagination control (e.g., to handle + large result sets incrementally), use list_tools_mcp() with the cursor parameter. + + Returns: + list[mcp.types.Tool]: A list of all Tool objects. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the request results in a TimeoutError | JSONRPCError + """ + all_tools: list[mcp.types.Tool] = [] + cursor: str | None = None + + while True: + result = await self.list_tools_mcp(cursor=cursor) + all_tools.extend(result.tools) + if result.nextCursor is None: + break + cursor = result.nextCursor + + return all_tools + + # --- Call Tool --- + + async def call_tool_mcp( + self: Client, + name: str, + arguments: dict[str, Any], + progress_handler: ProgressHandler | None = None, + timeout: datetime.timedelta | float | int | None = None, + meta: dict[str, Any] | None = None, + ) -> mcp.types.CallToolResult: + """Send a tools/call request and return the complete MCP protocol result. + + This method returns the raw CallToolResult object, which includes an isError flag + and other metadata. It does not raise an exception if the tool call results in an error. + + Args: + name (str): The name of the tool to call. + arguments (dict[str, Any]): Arguments to pass to the tool. + timeout (datetime.timedelta | float | int | None, optional): The timeout for the tool call. Defaults to None. + progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None. + meta (dict[str, Any] | None, optional): Additional metadata to include with the request. + This is useful for passing contextual information (like user IDs, trace IDs, or preferences) + that shouldn't be tool arguments but may influence server-side processing. The server + can access this via `context.request_context.meta`. Defaults to None. + + Returns: + mcp.types.CallToolResult: The complete response object from the protocol, + containing the tool result and any additional metadata. + + Raises: + RuntimeError: If called while the client is not connected. + McpError: If the tool call requests results in a TimeoutError | JSONRPCError + """ + with client_span( + f"tools/call {name}", + "tools/call", + name, + session_id=self.transport.get_session_id(), + ): + logger.debug(f"[{self.name}] called call_tool: {name}") + + # Inject trace context into meta for propagation to server + propagated_meta = inject_trace_context(meta) + + result = await self._await_with_session_monitoring( + self.session.call_tool( + name=name, + arguments=arguments, + read_timeout_seconds=normalize_timeout_to_timedelta(timeout), + progress_callback=progress_handler or self._progress_handler, + meta=propagated_meta if propagated_meta else None, + ) + ) + return result + + async def _parse_call_tool_result( + self: Client, + name: str, + result: mcp.types.CallToolResult, + raise_on_error: bool = False, + ) -> CallToolResult: + """Parse an mcp.types.CallToolResult into our CallToolResult dataclass. + + Args: + name: Tool name (for schema lookup) + result: Raw MCP protocol result + raise_on_error: Whether to raise ToolError on errors + + Returns: + CallToolResult: Parsed result with structured data + """ + + return await _parse_call_tool_result( + name=name, + result=result, + tool_output_schemas=self.session._tool_output_schemas, + list_tools_fn=self.session.list_tools, + client_name=self.name, + raise_on_error=raise_on_error, + ) + + @overload + async def call_tool( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + *, + version: str | None = None, + timeout: datetime.timedelta | float | int | None = None, + progress_handler: ProgressHandler | None = None, + raise_on_error: bool = True, + meta: dict[str, Any] | None = None, + task: Literal[False] = False, + ) -> CallToolResult: ... + + @overload + async def call_tool( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + *, + version: str | None = None, + timeout: datetime.timedelta | float | int | None = None, + progress_handler: ProgressHandler | None = None, + raise_on_error: bool = True, + meta: dict[str, Any] | None = None, + task: Literal[True], + task_id: str | None = None, + ttl: int = 60000, + ) -> ToolTask: ... + + async def call_tool( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + *, + version: str | None = None, + timeout: datetime.timedelta | float | int | None = None, + progress_handler: ProgressHandler | None = None, + raise_on_error: bool = True, + meta: dict[str, Any] | None = None, + task: bool = False, + task_id: str | None = None, + ttl: int = 60000, + ) -> CallToolResult | ToolTask: + """Call a tool on the server. + + Unlike call_tool_mcp, this method raises a ToolError if the tool call results in an error. + + Args: + name (str): The name of the tool to call. + arguments (dict[str, Any] | None, optional): Arguments to pass to the tool. Defaults to None. + version (str | None, optional): Specific tool version to call. If None, calls highest version. + timeout (datetime.timedelta | float | int | None, optional): The timeout for the tool call. Defaults to None. + progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None. + raise_on_error (bool, optional): Whether to raise an exception if the tool call results in an error. Defaults to True. + meta (dict[str, Any] | None, optional): Additional metadata to include with the request. + This is useful for passing contextual information (like user IDs, trace IDs, or preferences) + that shouldn't be tool arguments but may influence server-side processing. The server + can access this via `context.request_context.meta`. Defaults to None. + task (bool): If True, execute as background task (SEP-1686). Defaults to False. + task_id (str | None): Optional client-provided task ID (auto-generated if not provided). + ttl (int): Time to keep results available in milliseconds (default 60s). + + Returns: + CallToolResult | ToolTask: The content returned by the tool if task=False, + or a ToolTask object if task=True. If the tool returns structured + outputs, they are returned as a dataclass (if an output schema + is available) or a dictionary; otherwise, a list of content + blocks is returned. Note: to receive both structured and + unstructured outputs, use call_tool_mcp instead and access the + raw result object. + + Raises: + ToolError: If the tool call results in an error. + McpError: If the tool call request results in a TimeoutError | JSONRPCError + RuntimeError: If called while the client is not connected. + """ + # Merge version into request-level meta (not arguments) + request_meta = dict(meta) if meta else {} + if version is not None: + request_meta["fastmcp"] = { + **request_meta.get("fastmcp", {}), + "version": version, + } + + if task: + return await self._call_tool_as_task( + name, arguments, task_id, ttl, meta=request_meta or None + ) + + result = await self.call_tool_mcp( + name=name, + arguments=arguments or {}, + timeout=timeout, + progress_handler=progress_handler, + meta=request_meta or None, + ) + return await self._parse_call_tool_result( + name, result, raise_on_error=raise_on_error + ) + + async def _call_tool_as_task( + self: Client, + name: str, + arguments: dict[str, Any] | None = None, + task_id: str | None = None, + ttl: int = 60000, + meta: dict[str, Any] | None = None, + ) -> ToolTask: + """Call a tool for background execution (SEP-1686). + + Returns a ToolTask object that handles both background and immediate execution. + If the server accepts background execution, ToolTask will poll for results. + If the server declines (graceful degradation), ToolTask wraps the immediate result. + + Args: + name: Tool name to call + arguments: Tool arguments + task_id: Optional client-provided task ID (ignored, for backward compatibility) + ttl: Time to keep results available in milliseconds (default 60s) + meta: Optional request metadata (e.g., version info) + + Returns: + ToolTask: Future-like object for accessing task status and results + """ + # Per SEP-1686 final spec: client sends only ttl, server generates taskId + # Inject trace context into meta for propagation to server + propagated_meta = inject_trace_context(meta) + + # Build request with task metadata + request = mcp.types.CallToolRequest( + params=mcp.types.CallToolRequestParams( + name=name, + arguments=arguments or {}, + task=mcp.types.TaskMetadata(ttl=ttl), + _meta=propagated_meta, # type: ignore[unknown-argument] # pydantic alias + ) + ) + + # Server returns CreateTaskResult (task accepted) or CallToolResult (graceful degradation) + # Use RootModel with Union to handle both response types (SDK calls model_validate) + wrapped_result = await self._await_with_session_monitoring( + self.session.send_request( + request=request, # type: ignore[arg-type] + result_type=ToolTaskResponseUnion, + ) + ) + raw_result = wrapped_result.root + + if isinstance(raw_result, mcp.types.CreateTaskResult): + # Task was accepted - extract task info from CreateTaskResult + server_task_id = raw_result.task.taskId + self._submitted_task_ids.add(server_task_id) + + task_obj = ToolTask( + self, server_task_id, tool_name=name, immediate_result=None + ) + self._task_registry[server_task_id] = weakref.ref(task_obj) + return task_obj + else: + # Graceful degradation - server returned CallToolResult + parsed_result = await self._parse_call_tool_result(name, raw_result) + synthetic_task_id = task_id or str(uuid.uuid4()) + return ToolTask( + self, + synthetic_task_id, + tool_name=name, + immediate_result=parsed_result, + ) + + +async def _parse_call_tool_result( + name: str, + result: mcp.types.CallToolResult, + tool_output_schemas: dict[str, dict[str, Any] | None], + list_tools_fn: Any, # Callable[[], Awaitable[None]] + client_name: str | None = None, + raise_on_error: bool = False, +) -> CallToolResult: + """Parse an mcp.types.CallToolResult into our CallToolResult dataclass. + + Args: + name: Tool name (for schema lookup) + result: Raw MCP protocol result + tool_output_schemas: Dictionary mapping tool names to their output schemas + list_tools_fn: Async function to refresh tool schemas if needed + client_name: Optional client name for logging + raise_on_error: Whether to raise ToolError on errors + + Returns: + CallToolResult: Parsed result with structured data + """ + from typing import cast + + from fastmcp.client.client import CallToolResult + + data = None + if result.isError and raise_on_error: + msg = cast(mcp.types.TextContent, result.content[0]).text + raise ToolError(msg) + elif result.structuredContent: + try: + if name not in tool_output_schemas: + await list_tools_fn() + if name in tool_output_schemas: + output_schema = tool_output_schemas.get(name) + if output_schema: + if output_schema.get("x-fastmcp-wrap-result"): + output_schema = output_schema.get("properties", {}).get( + "result" + ) + structured_content = result.structuredContent.get("result") + else: + structured_content = result.structuredContent + output_type = json_schema_to_type(output_schema) + type_adapter = get_cached_typeadapter(output_type) + data = type_adapter.validate_python(structured_content) + else: + data = result.structuredContent + except Exception as e: + logger.error( + f"[{client_name or 'client'}] Error parsing structured content: {e}" + ) + + return CallToolResult( + content=result.content, + structured_content=result.structuredContent, + meta=result.meta, + data=data, + is_error=result.isError, + ) diff --git a/src/fastmcp/client/transports/__init__.py b/src/fastmcp/client/transports/__init__.py index b3e9b7c383..010a7cb7ca 100644 --- a/src/fastmcp/client/transports/__init__.py +++ b/src/fastmcp/client/transports/__init__.py @@ -7,8 +7,9 @@ SessionKwargs, ) from fastmcp.client.transports.config import MCPConfigTransport -from fastmcp.client.transports.http import SSETransport, StreamableHttpTransport +from fastmcp.client.transports.http import StreamableHttpTransport from fastmcp.client.transports.inference import infer_transport +from fastmcp.client.transports.sse import SSETransport from fastmcp.client.transports.memory import FastMCPTransport from fastmcp.client.transports.stdio import ( FastMCPStdioTransport, diff --git a/src/fastmcp/client/transports/http.py b/src/fastmcp/client/transports/http.py index 69baaeb465..89ad8fc621 100644 --- a/src/fastmcp/client/transports/http.py +++ b/src/fastmcp/client/transports/http.py @@ -1,11 +1,14 @@ +"""Streamable HTTP transport for FastMCP Client.""" + +from __future__ import annotations + import contextlib import datetime from collections.abc import AsyncIterator, Callable -from typing import Any, Literal, cast +from typing import Literal, cast import httpx from mcp import ClientSession -from mcp.client.sse import sse_client from mcp.client.streamable_http import streamable_http_client from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from pydantic import AnyUrl @@ -16,76 +19,7 @@ from fastmcp.client.auth.oauth import OAuth from fastmcp.client.transports.base import ClientTransport, SessionKwargs from fastmcp.server.dependencies import get_http_headers - - -class SSETransport(ClientTransport): - """Transport implementation that connects to an MCP server via Server-Sent Events.""" - - def __init__( - self, - url: str | AnyUrl, - headers: dict[str, str] | None = None, - auth: httpx.Auth | Literal["oauth"] | str | None = None, - sse_read_timeout: datetime.timedelta | float | int | None = None, - httpx_client_factory: McpHttpClientFactory | None = None, - ): - if isinstance(url, AnyUrl): - url = str(url) - if not isinstance(url, str) or not url.startswith("http"): - raise ValueError("Invalid HTTP/S URL provided for SSE.") - - # Don't modify the URL path - respect the exact URL provided by the user - # Some servers are strict about trailing slashes (e.g., PayPal MCP) - - self.url: str = url - self.headers = headers or {} - self.httpx_client_factory = httpx_client_factory - self._set_auth(auth) - - if isinstance(sse_read_timeout, int | float): - sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout)) - self.sse_read_timeout = sse_read_timeout - - def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): - if auth == "oauth": - auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) - elif isinstance(auth, str): - auth = BearerAuth(auth) - self.auth = auth - - @contextlib.asynccontextmanager - async def connect_session( - self, **session_kwargs: Unpack[SessionKwargs] - ) -> AsyncIterator[ClientSession]: - client_kwargs: dict[str, Any] = {} - - # load headers from an active HTTP request, if available. This will only be true - # if the client is used in a FastMCP Proxy, in which case the MCP client headers - # need to be forwarded to the remote server. - client_kwargs["headers"] = get_http_headers() | self.headers - - # sse_read_timeout has a default value set, so we can't pass None without overriding it - # instead we simply leave the kwarg out if it's not provided - if self.sse_read_timeout is not None: - client_kwargs["sse_read_timeout"] = self.sse_read_timeout.total_seconds() - if session_kwargs.get("read_timeout_seconds") is not None: - read_timeout_seconds = cast( - datetime.timedelta, session_kwargs.get("read_timeout_seconds") - ) - client_kwargs["timeout"] = read_timeout_seconds.total_seconds() - - if self.httpx_client_factory is not None: - client_kwargs["httpx_client_factory"] = self.httpx_client_factory - - async with sse_client(self.url, auth=self.auth, **client_kwargs) as transport: - read_stream, write_stream = transport - async with ClientSession( - read_stream, write_stream, **session_kwargs - ) as session: - yield session - - def __repr__(self) -> str: - return f"" +from fastmcp.utilities.timeout import normalize_timeout_to_timedelta class StreamableHttpTransport(ClientTransport): @@ -137,9 +71,7 @@ def __init__( DeprecationWarning, stacklevel=2, ) - if isinstance(sse_read_timeout, int | float): - sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout)) - self.sse_read_timeout = sse_read_timeout + self.sse_read_timeout = normalize_timeout_to_timedelta(sse_read_timeout) self._get_session_id_cb: Callable[[], str | None] | None = None diff --git a/src/fastmcp/client/transports/inference.py b/src/fastmcp/client/transports/inference.py index de26e37d0f..438995c239 100644 --- a/src/fastmcp/client/transports/inference.py +++ b/src/fastmcp/client/transports/inference.py @@ -6,8 +6,9 @@ from fastmcp.client.transports.base import ClientTransport, ClientTransportT from fastmcp.client.transports.config import MCPConfigTransport -from fastmcp.client.transports.http import SSETransport, StreamableHttpTransport +from fastmcp.client.transports.http import StreamableHttpTransport from fastmcp.client.transports.memory import FastMCPTransport +from fastmcp.client.transports.sse import SSETransport from fastmcp.client.transports.stdio import NodeStdioTransport, PythonStdioTransport from fastmcp.mcp_config import MCPConfig, infer_transport_type_from_url from fastmcp.server.server import FastMCP diff --git a/src/fastmcp/client/transports/sse.py b/src/fastmcp/client/transports/sse.py new file mode 100644 index 0000000000..ec932e6d2d --- /dev/null +++ b/src/fastmcp/client/transports/sse.py @@ -0,0 +1,89 @@ +"""Server-Sent Events (SSE) transport for FastMCP Client.""" + +from __future__ import annotations + +import contextlib +import datetime +from collections.abc import AsyncIterator +from typing import Any, Literal, cast + +import httpx +from mcp import ClientSession +from mcp.client.sse import sse_client +from mcp.shared._httpx_utils import McpHttpClientFactory +from pydantic import AnyUrl +from typing_extensions import Unpack + +from fastmcp.client.auth.bearer import BearerAuth +from fastmcp.client.auth.oauth import OAuth +from fastmcp.client.transports.base import ClientTransport, SessionKwargs +from fastmcp.server.dependencies import get_http_headers +from fastmcp.utilities.timeout import normalize_timeout_to_timedelta + + +class SSETransport(ClientTransport): + """Transport implementation that connects to an MCP server via Server-Sent Events.""" + + def __init__( + self, + url: str | AnyUrl, + headers: dict[str, str] | None = None, + auth: httpx.Auth | Literal["oauth"] | str | None = None, + sse_read_timeout: datetime.timedelta | float | int | None = None, + httpx_client_factory: McpHttpClientFactory | None = None, + ): + if isinstance(url, AnyUrl): + url = str(url) + if not isinstance(url, str) or not url.startswith("http"): + raise ValueError("Invalid HTTP/S URL provided for SSE.") + + # Don't modify the URL path - respect the exact URL provided by the user + # Some servers are strict about trailing slashes (e.g., PayPal MCP) + + self.url: str = url + self.headers = headers or {} + self.httpx_client_factory = httpx_client_factory + self._set_auth(auth) + + self.sse_read_timeout = normalize_timeout_to_timedelta(sse_read_timeout) + + def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): + if auth == "oauth": + auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) + elif isinstance(auth, str): + auth = BearerAuth(auth) + self.auth = auth + + @contextlib.asynccontextmanager + async def connect_session( + self, **session_kwargs: Unpack[SessionKwargs] + ) -> AsyncIterator[ClientSession]: + client_kwargs: dict[str, Any] = {} + + # load headers from an active HTTP request, if available. This will only be true + # if the client is used in a FastMCP Proxy, in which case the MCP client headers + # need to be forwarded to the remote server. + client_kwargs["headers"] = get_http_headers() | self.headers + + # sse_read_timeout has a default value set, so we can't pass None without overriding it + # instead we simply leave the kwarg out if it's not provided + if self.sse_read_timeout is not None: + client_kwargs["sse_read_timeout"] = self.sse_read_timeout.total_seconds() + if session_kwargs.get("read_timeout_seconds") is not None: + read_timeout_seconds = cast( + datetime.timedelta, session_kwargs.get("read_timeout_seconds") + ) + client_kwargs["timeout"] = read_timeout_seconds.total_seconds() + + if self.httpx_client_factory is not None: + client_kwargs["httpx_client_factory"] = self.httpx_client_factory + + async with sse_client(self.url, auth=self.auth, **client_kwargs) as transport: + read_stream, write_stream = transport + async with ClientSession( + read_stream, write_stream, **session_kwargs + ) as session: + yield session + + def __repr__(self) -> str: + return f"" diff --git a/src/fastmcp/utilities/timeout.py b/src/fastmcp/utilities/timeout.py new file mode 100644 index 0000000000..b129978070 --- /dev/null +++ b/src/fastmcp/utilities/timeout.py @@ -0,0 +1,47 @@ +"""Timeout normalization utilities.""" + +from __future__ import annotations + +import datetime + + +def normalize_timeout_to_timedelta( + value: int | float | datetime.timedelta | None, +) -> datetime.timedelta | None: + """Normalize a timeout value to a timedelta. + + Args: + value: Timeout value as int/float (seconds), timedelta, or None + + Returns: + timedelta if value provided, None otherwise + """ + if value is None: + return None + if isinstance(value, datetime.timedelta): + return value + if isinstance(value, int | float): + return datetime.timedelta(seconds=float(value)) + raise TypeError(f"Invalid timeout type: {type(value)}") + + +def normalize_timeout_to_seconds( + value: int | float | datetime.timedelta | None, +) -> float | None: + """Normalize a timeout value to seconds (float). + + Args: + value: Timeout value as int/float (seconds), timedelta, or None. + Zero values are treated as "disabled" and return None. + + Returns: + float seconds if value provided and non-zero, None otherwise + """ + if value is None: + return None + if isinstance(value, datetime.timedelta): + seconds = value.total_seconds() + return None if seconds == 0 else seconds + if isinstance(value, int | float): + return None if value == 0 else float(value) + raise TypeError(f"Invalid timeout type: {type(value)}") diff --git a/tests/client/tasks/test_client_prompt_tasks.py b/tests/client/tasks/test_client_prompt_tasks.py index e1154e95f6..fd8445c408 100644 --- a/tests/client/tasks/test_client_prompt_tasks.py +++ b/tests/client/tasks/test_client_prompt_tasks.py @@ -8,6 +8,7 @@ from fastmcp import FastMCP from fastmcp.client import Client +from fastmcp.client.tasks import PromptTask @pytest.fixture @@ -33,8 +34,6 @@ async def test_get_prompt_as_task_returns_prompt_task(prompt_server): async with Client(prompt_server) as client: task = await client.get_prompt("analysis_prompt", {"topic": "AI"}, task=True) - from fastmcp.client.client import PromptTask - assert isinstance(task, PromptTask) assert isinstance(task.task_id, str) diff --git a/tests/client/tasks/test_client_resource_tasks.py b/tests/client/tasks/test_client_resource_tasks.py index 20d59cd3d1..be7c4ddd3b 100644 --- a/tests/client/tasks/test_client_resource_tasks.py +++ b/tests/client/tasks/test_client_resource_tasks.py @@ -8,6 +8,7 @@ from fastmcp import FastMCP from fastmcp.client import Client +from fastmcp.client.tasks import ResourceTask @pytest.fixture @@ -33,8 +34,6 @@ async def test_read_resource_as_task_returns_resource_task(resource_server): async with Client(resource_server) as client: task = await client.read_resource("file://document.txt", task=True) - from fastmcp.client.client import ResourceTask - assert isinstance(task, ResourceTask) assert isinstance(task.task_id, str) diff --git a/tests/client/tasks/test_client_tool_tasks.py b/tests/client/tasks/test_client_tool_tasks.py index 33dc1ce9a9..e5f3c81745 100644 --- a/tests/client/tasks/test_client_tool_tasks.py +++ b/tests/client/tasks/test_client_tool_tasks.py @@ -9,6 +9,7 @@ from fastmcp import FastMCP from fastmcp.client import Client +from fastmcp.client.tasks import ToolTask @pytest.fixture @@ -34,8 +35,6 @@ async def test_call_tool_as_task_returns_tool_task(tool_task_server): async with Client(tool_task_server) as client: task = await client.call_tool("echo", {"message": "hello"}, task=True) - from fastmcp.client.client import ToolTask - assert isinstance(task, ToolTask) assert isinstance(task.task_id, str) assert len(task.task_id) > 0 diff --git a/tests/server/tasks/test_task_prompts.py b/tests/server/tasks/test_task_prompts.py index 105de14bdf..1a02bec83b 100644 --- a/tests/server/tasks/test_task_prompts.py +++ b/tests/server/tasks/test_task_prompts.py @@ -8,6 +8,7 @@ from fastmcp import FastMCP from fastmcp.client import Client +from fastmcp.client.tasks import PromptTask @pytest.fixture @@ -45,8 +46,6 @@ async def test_prompt_with_task_metadata_returns_immediately(prompt_server): task = await client.get_prompt("background_prompt", {"topic": "AI"}, task=True) # Should return a PromptTask object immediately - from fastmcp.client.client import PromptTask - assert isinstance(task, PromptTask) assert isinstance(task.task_id, str) assert len(task.task_id) > 0 diff --git a/tests/server/tasks/test_task_resources.py b/tests/server/tasks/test_task_resources.py index df02caece7..ed240eb061 100644 --- a/tests/server/tasks/test_task_resources.py +++ b/tests/server/tasks/test_task_resources.py @@ -8,6 +8,7 @@ from fastmcp import FastMCP from fastmcp.client import Client +from fastmcp.client.tasks import ResourceTask @pytest.fixture @@ -50,8 +51,6 @@ async def test_resource_with_task_metadata_returns_immediately(resource_server): task = await client.read_resource("file://large.txt", task=True) # Should return a ResourceTask object immediately - from fastmcp.client.client import ResourceTask - assert isinstance(task, ResourceTask) assert isinstance(task.task_id, str) assert len(task.task_id) > 0 diff --git a/tests/server/tasks/test_task_tools.py b/tests/server/tasks/test_task_tools.py index ef4d96a59b..530f6b0e6f 100644 --- a/tests/server/tasks/test_task_tools.py +++ b/tests/server/tasks/test_task_tools.py @@ -11,6 +11,7 @@ from fastmcp import FastMCP from fastmcp.client import Client +from fastmcp.client.tasks import ToolTask @pytest.fixture @@ -49,8 +50,6 @@ async def test_tool_with_task_metadata_returns_immediately(tool_server): assert task assert not task.returned_immediately - from fastmcp.client.client import ToolTask - assert isinstance(task, ToolTask) assert isinstance(task.task_id, str) assert len(task.task_id) > 0