diff --git a/docs/clients/sampling.mdx b/docs/clients/sampling.mdx
index 1f8d29a83b..dba91b7a91 100644
--- a/docs/clients/sampling.mdx
+++ b/docs/clients/sampling.mdx
@@ -185,7 +185,7 @@ For full-featured sampling with tool support, use the built-in OpenAI handler. I
 
 ```python
 from fastmcp import Client
-from fastmcp.experimental.sampling.handlers.openai import OpenAISamplingHandler
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
 
 client = Client(
     "my_mcp_server.py",
@@ -212,5 +212,5 @@ Tool execution happens on the server side. The client's role is to pass tools to
 
 
 
 
-To implement a custom sampling handler, see the [OpenAISamplingHandler source code](https://github.com/jlowin/fastmcp/blob/main/src/fastmcp/experimental/sampling/handlers/openai.py) as a reference.
+To implement a custom sampling handler, see the [OpenAISamplingHandler source code](https://github.com/jlowin/fastmcp/blob/main/src/fastmcp/client/sampling/handlers/openai.py) as a reference.
\ No newline at end of file
diff --git a/docs/servers/sampling.mdx b/docs/servers/sampling.mdx
index ae31b3aa61..4e19b9521b 100644
--- a/docs/servers/sampling.mdx
+++ b/docs/servers/sampling.mdx
@@ -461,9 +461,7 @@ FastMCP provides an OpenAI-compatible handler that works with OpenAI's API and c
 import os
 from openai import OpenAI
 from fastmcp import FastMCP
-from fastmcp.experimental.sampling.handlers.openai import (
-    OpenAISamplingHandler,
-)
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
 
 server = FastMCP(
     name="My Server",
diff --git a/examples/advanced_sampling/client_sampling_test.py b/examples/advanced_sampling/client_sampling_test.py
index 102c5ba95e..31c3b5312a 100644
--- a/examples/advanced_sampling/client_sampling_test.py
+++ b/examples/advanced_sampling/client_sampling_test.py
@@ -21,7 +21,7 @@
 from pydantic import BaseModel
 
 from fastmcp import Client, Context, FastMCP
-from fastmcp.experimental.sampling.handlers.openai import OpenAISamplingHandler
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
 
 # Create the MCP server
 mcp = FastMCP("Sampling Test Server")
diff --git a/examples/advanced_sampling/structured_output.py b/examples/advanced_sampling/structured_output.py
index 6a27568178..9fe31112b7 100644
--- a/examples/advanced_sampling/structured_output.py
+++ b/examples/advanced_sampling/structured_output.py
@@ -18,7 +18,7 @@
 from pydantic import BaseModel
 
 from fastmcp import Client, Context, FastMCP
-from fastmcp.experimental.sampling.handlers.openai import OpenAISamplingHandler
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
 
 
 # Define a structured output model
diff --git a/examples/advanced_sampling/tool_use.py b/examples/advanced_sampling/tool_use.py
index bb38246988..3c6998dd53 100644
--- a/examples/advanced_sampling/tool_use.py
+++ b/examples/advanced_sampling/tool_use.py
@@ -18,7 +18,7 @@
 from pydantic import BaseModel, Field
 
 from fastmcp import Client, Context, FastMCP
-from fastmcp.experimental.sampling.handlers.openai import OpenAISamplingHandler
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
 
 
 # Define tools (available to the LLM during sampling)
diff --git a/examples/sampling_fallback.py b/examples/sampling_fallback.py
index 31a8551317..edabd2d9ee 100644
--- a/examples/sampling_fallback.py
+++ b/examples/sampling_fallback.py
@@ -9,7 +9,7 @@
 
 from openai import OpenAI
 
 from fastmcp import FastMCP
-from fastmcp.experimental.sampling.handlers.openai import OpenAISamplingHandler
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
 from fastmcp.server.context import Context
diff --git a/src/fastmcp/client/client.py b/src/fastmcp/client/client.py
index 2a68a5b3af..182fd1c957 100644
--- a/src/fastmcp/client/client.py
+++ b/src/fastmcp/client/client.py
@@ -47,7 +47,6 @@
     create_roots_callback,
 )
 from fastmcp.client.sampling import (
-    ClientSamplingHandler,
     SamplingHandler,
     create_sampling_callback,
 )
@@ -82,7 +81,6 @@
 
 __all__ = [
     "Client",
-    "ClientSamplingHandler",
    "ElicitationHandler",
     "LogHandler",
     "MessageHandler",
@@ -248,7 +246,7 @@ def __init__(
         ),
         name: str | None = None,
         roots: RootsList | RootsHandler | None = None,
-        sampling_handler: ClientSamplingHandler | None = None,
+        sampling_handler: SamplingHandler | None = None,
         sampling_capabilities: mcp.types.SamplingCapability | None = None,
         elicitation_handler: ElicitationHandler | None = None,
         log_handler: LogHandler | None = None,
@@ -368,7 +366,7 @@ def set_roots(self, roots: RootsList | RootsHandler) -> None:
 
     def set_sampling_callback(
         self,
-        sampling_callback: ClientSamplingHandler,
+        sampling_callback: SamplingHandler,
         sampling_capabilities: mcp.types.SamplingCapability | None = None,
     ) -> None:
         """Set the sampling callback for the client."""
diff --git a/src/fastmcp/client/sampling.py b/src/fastmcp/client/sampling.py
deleted file mode 100644
index cf7dad77a2..0000000000
--- a/src/fastmcp/client/sampling.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import inspect
-from collections.abc import Awaitable, Callable
-from typing import TypeAlias
-
-import mcp.types
-from mcp import CreateMessageResult
-from mcp.client.session import ClientSession, SamplingFnT
-from mcp.shared.context import LifespanContextT, RequestContext
-from mcp.types import CreateMessageRequestParams as SamplingParams
-from mcp.types import SamplingMessage
-
-from fastmcp.server.sampling.handler import ServerSamplingHandler
-
-__all__ = ["SamplingHandler", "SamplingMessage", "SamplingParams"]
-
-
-ClientSamplingHandler: TypeAlias = Callable[
-    [
-        list[SamplingMessage],
-        SamplingParams,
-        RequestContext[ClientSession, LifespanContextT],
-    ],
-    str | CreateMessageResult | Awaitable[str | CreateMessageResult],
-]
-
-SamplingHandler: TypeAlias = (
-    ClientSamplingHandler[LifespanContextT] | ServerSamplingHandler[LifespanContextT]
-)
-
-
-def create_sampling_callback(
-    sampling_handler: ClientSamplingHandler[LifespanContextT],
-) -> SamplingFnT:
-    async def _sampling_handler(
-        context: RequestContext[ClientSession, LifespanContextT],
-        params: SamplingParams,
-    ) -> CreateMessageResult | mcp.types.ErrorData:
-        try:
-            result = sampling_handler(params.messages, params, context)
-            if inspect.isawaitable(result):
-                result = await result
-
-            if isinstance(result, str):
-                result = CreateMessageResult(
-                    role="assistant",
-                    model="fastmcp-client",
-                    content=mcp.types.TextContent(type="text", text=result),
-                )
-            return result
-        except Exception as e:
-            return mcp.types.ErrorData(
-                code=mcp.types.INTERNAL_ERROR,
-                message=str(e),
-            )
-
-    return _sampling_handler
diff --git a/src/fastmcp/client/sampling/__init__.py b/src/fastmcp/client/sampling/__init__.py
new file mode 100644
index 0000000000..1cdb9ba1df
--- /dev/null
+++ b/src/fastmcp/client/sampling/__init__.py
@@ -0,0 +1,69 @@
+import inspect
+from collections.abc import Awaitable, Callable
+from typing import TypeAlias, TypeVar
+
+import mcp.types
+from mcp import ClientSession, CreateMessageResult
+from mcp.client.session import SamplingFnT
+from mcp.server.session import ServerSession
+from mcp.shared.context import LifespanContextT, RequestContext
+from mcp.types import CreateMessageRequestParams as SamplingParams
+from mcp.types import CreateMessageResultWithTools, SamplingMessage
+
+# Result type that handlers can return
+SamplingHandlerResult: TypeAlias = (
+    str | CreateMessageResult | CreateMessageResultWithTools
+)
+
+# Session type for sampling handlers - works with both client and server sessions
+SessionT = TypeVar("SessionT", ClientSession, ServerSession)
+
+# Unified sampling handler type that works for both clients and servers.
+# Handlers receive messages and parameters from the MCP sampling flow
+# and return LLM responses.
+SamplingHandler: TypeAlias = Callable[
+    [
+        list[SamplingMessage],
+        SamplingParams,
+        RequestContext[SessionT, LifespanContextT],
+    ],
+    SamplingHandlerResult | Awaitable[SamplingHandlerResult],
+]
+
+
+__all__ = [
+    "RequestContext",
+    "SamplingHandler",
+    "SamplingHandlerResult",
+    "SamplingMessage",
+    "SamplingParams",
+    "create_sampling_callback",
+]
+
+
+def create_sampling_callback(
+    sampling_handler: SamplingHandler,
+) -> SamplingFnT:
+    async def _sampling_handler(
+        context,
+        params: SamplingParams,
+    ) -> CreateMessageResult | CreateMessageResultWithTools | mcp.types.ErrorData:
+        try:
+            result = sampling_handler(params.messages, params, context)
+            if inspect.isawaitable(result):
+                result = await result
+
+            if isinstance(result, str):
+                result = CreateMessageResult(
+                    role="assistant",
+                    model="fastmcp-client",
+                    content=mcp.types.TextContent(type="text", text=result),
+                )
+            return result
+        except Exception as e:
+            return mcp.types.ErrorData(
+                code=mcp.types.INTERNAL_ERROR,
+                message=str(e),
+            )
+
+    return _sampling_handler
diff --git a/src/fastmcp/client/sampling/handlers/__init__.py b/src/fastmcp/client/sampling/handlers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/fastmcp/client/sampling/handlers/openai.py b/src/fastmcp/client/sampling/handlers/openai.py
new file mode 100644
index 0000000000..2114fe9736
--- /dev/null
+++ b/src/fastmcp/client/sampling/handlers/openai.py
@@ -0,0 +1,399 @@
+"""OpenAI sampling handler for FastMCP."""
+
+import json
+from collections.abc import Iterator, Sequence
+from typing import Any, get_args
+
+from mcp import ClientSession, ServerSession
+from mcp.shared.context import LifespanContextT, RequestContext
+from mcp.types import CreateMessageRequestParams as SamplingParams
+from mcp.types import (
+    CreateMessageResult,
+    CreateMessageResultWithTools,
+    ModelPreferences,
+    SamplingMessage,
+    StopReason,
+    TextContent,
+    Tool,
+    ToolChoice,
+    ToolResultContent,
+    ToolUseContent,
+)
+
+try:
+    from openai import NOT_GIVEN, AsyncOpenAI, NotGiven
+    from openai.types.chat import (
+        ChatCompletion,
+        ChatCompletionAssistantMessageParam,
+        ChatCompletionMessageParam,
+        ChatCompletionMessageToolCallParam,
+        ChatCompletionSystemMessageParam,
+        ChatCompletionToolChoiceOptionParam,
+        ChatCompletionToolMessageParam,
+        ChatCompletionToolParam,
+        ChatCompletionUserMessageParam,
+    )
+    from openai.types.shared.chat_model import ChatModel
+    from openai.types.shared_params import FunctionDefinition
+except ImportError as e:
+    raise ImportError(
+        "The `openai` package is not installed. "
+        "Please install `fastmcp[openai]` or add `openai` to your dependencies manually."
+    ) from e
+
+
+class OpenAISamplingHandler:
+    """Sampling handler that uses the OpenAI API."""
+
+    def __init__(
+        self,
+        default_model: ChatModel,
+        client: AsyncOpenAI | None = None,
+    ) -> None:
+        self.client: AsyncOpenAI = client or AsyncOpenAI()
+        self.default_model: ChatModel = default_model
+
+    async def __call__(
+        self,
+        messages: list[SamplingMessage],
+        params: SamplingParams,
+        context: RequestContext[ServerSession, LifespanContextT]
+        | RequestContext[ClientSession, LifespanContextT],
+    ) -> CreateMessageResult | CreateMessageResultWithTools:
+        openai_messages: list[ChatCompletionMessageParam] = (
+            self._convert_to_openai_messages(
+                system_prompt=params.systemPrompt,
+                messages=messages,
+            )
+        )
+
+        model: ChatModel = self._select_model_from_preferences(params.modelPreferences)
+
+        # Convert MCP tools to OpenAI format
+        openai_tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
+        if params.tools:
+            openai_tools = self._convert_tools_to_openai(params.tools)
+
+        # Convert tool_choice to OpenAI format
+        openai_tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
+        if params.toolChoice:
+            openai_tool_choice = self._convert_tool_choice_to_openai(params.toolChoice)
+
+        response = await self.client.chat.completions.create(
+            model=model,
+            messages=openai_messages,
+            temperature=(
+                params.temperature if params.temperature is not None else NOT_GIVEN
+            ),
+            max_tokens=params.maxTokens,
+            stop=params.stopSequences if params.stopSequences else NOT_GIVEN,
+            tools=openai_tools,
+            tool_choice=openai_tool_choice,
+        )
+
+        # Return appropriate result type based on whether tools were provided
+        if params.tools:
+            return self._chat_completion_to_result_with_tools(response)
+        return self._chat_completion_to_create_message_result(response)
+
+    @staticmethod
+    def _iter_models_from_preferences(
+        model_preferences: ModelPreferences | str | list[str] | None,
+    ) -> Iterator[str]:
+        if model_preferences is None:
+            return
+
+        if isinstance(model_preferences, str) and model_preferences in get_args(
+            ChatModel
+        ):
+            yield model_preferences
+
+        elif isinstance(model_preferences, list):
+            yield from model_preferences
+
+        elif isinstance(model_preferences, ModelPreferences):
+            if not (hints := model_preferences.hints):
+                return
+
+            for hint in hints:
+                if not (name := hint.name):
+                    continue
+
+                yield name
+
+    @staticmethod
+    def _convert_to_openai_messages(
+        system_prompt: str | None, messages: Sequence[SamplingMessage]
+    ) -> list[ChatCompletionMessageParam]:
+        openai_messages: list[ChatCompletionMessageParam] = []
+
+        if system_prompt:
+            openai_messages.append(
+                ChatCompletionSystemMessageParam(
+                    role="system",
+                    content=system_prompt,
+                )
+            )
+
+        for message in messages:
+            content = message.content
+
+            # Handle list content (from CreateMessageResultWithTools)
+            if isinstance(content, list):
+                # Collect tool calls and text from the list
+                tool_calls: list[ChatCompletionMessageToolCallParam] = []
+                text_parts: list[str] = []
+                # Collect tool results separately to maintain correct ordering
+                tool_messages: list[ChatCompletionToolMessageParam] = []
+
+                for item in content:
+                    if isinstance(item, ToolUseContent):
+                        tool_calls.append(
+                            ChatCompletionMessageToolCallParam(
+                                id=item.id,
+                                type="function",
+                                function={
+                                    "name": item.name,
+                                    "arguments": json.dumps(item.input),
+                                },
+                            )
+                        )
+                    elif isinstance(item, TextContent):
+                        text_parts.append(item.text)
+                    elif isinstance(item, ToolResultContent):
+                        # Collect tool results (added after assistant message)
+                        content_text = ""
"" + if item.content: + result_texts = [] + for sub_item in item.content: + if isinstance(sub_item, TextContent): + result_texts.append(sub_item.text) + content_text = "\n".join(result_texts) + tool_messages.append( + ChatCompletionToolMessageParam( + role="tool", + tool_call_id=item.toolUseId, + content=content_text, + ) + ) + + # Add assistant message with tool calls if present + # OpenAI requires: assistant (with tool_calls) -> tool messages + if tool_calls or text_parts: + msg_content = "\n".join(text_parts) if text_parts else None + if tool_calls: + openai_messages.append( + ChatCompletionAssistantMessageParam( + role="assistant", + content=msg_content, + tool_calls=tool_calls, + ) + ) + # Add tool messages AFTER assistant message + openai_messages.extend(tool_messages) + elif msg_content: + if message.role == "user": + openai_messages.append( + ChatCompletionUserMessageParam( + role="user", + content=msg_content, + ) + ) + else: + openai_messages.append( + ChatCompletionAssistantMessageParam( + role="assistant", + content=msg_content, + ) + ) + elif tool_messages: + # Tool results only (assistant message was in previous message) + openai_messages.extend(tool_messages) + continue + + # Handle ToolUseContent (assistant's tool calls) + if isinstance(content, ToolUseContent): + openai_messages.append( + ChatCompletionAssistantMessageParam( + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCallParam( + id=content.id, + type="function", + function={ + "name": content.name, + "arguments": json.dumps(content.input), + }, + ) + ], + ) + ) + continue + + # Handle ToolResultContent (user's tool results) + if isinstance(content, ToolResultContent): + # Extract text parts from the content list + result_texts: list[str] = [] + if content.content: + for item in content.content: + if isinstance(item, TextContent): + result_texts.append(item.text) + openai_messages.append( + ChatCompletionToolMessageParam( + role="tool", + tool_call_id=content.toolUseId, + content="\n".join(result_texts), + ) + ) + continue + + # Handle TextContent + if isinstance(content, TextContent): + if message.role == "user": + openai_messages.append( + ChatCompletionUserMessageParam( + role="user", + content=content.text, + ) + ) + else: + openai_messages.append( + ChatCompletionAssistantMessageParam( + role="assistant", + content=content.text, + ) + ) + continue + + raise ValueError(f"Unsupported content type: {type(content)}") + + return openai_messages + + @staticmethod + def _chat_completion_to_create_message_result( + chat_completion: ChatCompletion, + ) -> CreateMessageResult: + if len(chat_completion.choices) == 0: + raise ValueError("No response for completion") + + first_choice = chat_completion.choices[0] + + if content := first_choice.message.content: + return CreateMessageResult( + content=TextContent(type="text", text=content), + role="assistant", + model=chat_completion.model, + ) + + raise ValueError("No content in response from completion") + + def _select_model_from_preferences( + self, model_preferences: ModelPreferences | str | list[str] | None + ) -> ChatModel: + for model_option in self._iter_models_from_preferences(model_preferences): + if model_option in get_args(ChatModel): + chosen_model: ChatModel = model_option # type: ignore[assignment] + return chosen_model + + return self.default_model + + @staticmethod + def _convert_tools_to_openai(tools: list[Tool]) -> list[ChatCompletionToolParam]: + """Convert MCP tools to OpenAI tool format.""" + openai_tools: list[ChatCompletionToolParam] = [] + 
for tool in tools: + # Build parameters dict, ensuring required fields + parameters: dict[str, Any] = dict(tool.inputSchema) + if "type" not in parameters: + parameters["type"] = "object" + + openai_tools.append( + ChatCompletionToolParam( + type="function", + function=FunctionDefinition( + name=tool.name, + description=tool.description or "", + parameters=parameters, + ), + ) + ) + return openai_tools + + @staticmethod + def _convert_tool_choice_to_openai( + tool_choice: ToolChoice, + ) -> ChatCompletionToolChoiceOptionParam: + """Convert MCP tool_choice to OpenAI format.""" + if tool_choice.mode == "auto": + return "auto" + elif tool_choice.mode == "required": + return "required" + elif tool_choice.mode == "none": + return "none" + else: + raise ValueError(f"Unsupported tool_choice mode: {tool_choice.mode!r}") + + @staticmethod + def _chat_completion_to_result_with_tools( + chat_completion: ChatCompletion, + ) -> CreateMessageResultWithTools: + """Convert OpenAI response to CreateMessageResultWithTools.""" + if len(chat_completion.choices) == 0: + raise ValueError("No response for completion") + + first_choice = chat_completion.choices[0] + message = first_choice.message + + # Determine stop reason + stop_reason: StopReason + if first_choice.finish_reason == "tool_calls": + stop_reason = "toolUse" + elif first_choice.finish_reason == "stop": + stop_reason = "endTurn" + elif first_choice.finish_reason == "length": + stop_reason = "maxTokens" + else: + stop_reason = "endTurn" + + # Build content list + content: list[TextContent | ToolUseContent] = [] + + # Add text content if present + if message.content: + content.append(TextContent(type="text", text=message.content)) + + # Add tool calls if present + if message.tool_calls: + for tool_call in message.tool_calls: + # Skip non-function tool calls + if not hasattr(tool_call, "function"): + continue + func = tool_call.function # type: ignore[union-attr] + # Parse the arguments JSON string + try: + arguments = json.loads(func.arguments) # type: ignore[union-attr] + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid JSON in tool arguments for " + f"'{func.name}': {func.arguments}" # type: ignore[union-attr] + ) from e + + content.append( + ToolUseContent( + type="tool_use", + id=tool_call.id, + name=func.name, # type: ignore[union-attr] + input=arguments, + ) + ) + + # Must have at least some content + if not content: + raise ValueError("No content in response from completion") + + return CreateMessageResultWithTools( + content=content, # type: ignore[arg-type] + role="assistant", + model=chat_completion.model, + stopReason=stop_reason, + ) diff --git a/src/fastmcp/experimental/sampling/handlers/__init__.py b/src/fastmcp/experimental/sampling/handlers/__init__.py index e69de29bb2..627dfd0116 100644 --- a/src/fastmcp/experimental/sampling/handlers/__init__.py +++ b/src/fastmcp/experimental/sampling/handlers/__init__.py @@ -0,0 +1,5 @@ +# Re-export for backwards compatibility +# The canonical location is now fastmcp.client.sampling.handlers +from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler + +__all__ = ["OpenAISamplingHandler"] diff --git a/src/fastmcp/experimental/sampling/handlers/base.py b/src/fastmcp/experimental/sampling/handlers/base.py deleted file mode 100644 index 0b12b99106..0000000000 --- a/src/fastmcp/experimental/sampling/handlers/base.py +++ /dev/null @@ -1,21 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Awaitable - -from mcp import ClientSession, 
-from mcp.server.session import ServerSession
-from mcp.shared.context import LifespanContextT, RequestContext
-from mcp.types import CreateMessageRequestParams as SamplingParams
-from mcp.types import (
-    SamplingMessage,
-)
-
-
-class BaseLLMSamplingHandler(ABC):
-    @abstractmethod
-    def __call__(
-        self,
-        messages: list[SamplingMessage],
-        params: SamplingParams,
-        context: RequestContext[ServerSession, LifespanContextT]
-        | RequestContext[ClientSession, LifespanContextT],
-    ) -> str | CreateMessageResult | Awaitable[str | CreateMessageResult]: ...
diff --git a/src/fastmcp/experimental/sampling/handlers/openai.py b/src/fastmcp/experimental/sampling/handlers/openai.py
index 3d327682a7..b466f7a772 100644
--- a/src/fastmcp/experimental/sampling/handlers/openai.py
+++ b/src/fastmcp/experimental/sampling/handlers/openai.py
@@ -1,417 +1,5 @@
-import json
-from collections.abc import Iterator, Sequence
-from typing import Any, get_args
+# Re-export for backwards compatibility
+# The canonical location is now fastmcp.client.sampling.handlers.openai
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
-
-from mcp import ClientSession, ServerSession
-from mcp.shared.context import LifespanContextT, RequestContext
-from mcp.types import CreateMessageRequestParams as SamplingParams
-from mcp.types import (
-    CreateMessageResult,
-    CreateMessageResultWithTools,
-    ModelPreferences,
-    SamplingMessage,
-    StopReason,
-    TextContent,
-    Tool,
-    ToolChoice,
-    ToolResultContent,
-    ToolUseContent,
-)
-
-try:
-    from openai import NOT_GIVEN, AsyncOpenAI, NotGiven
-    from openai.types.chat import (
-        ChatCompletion,
-        ChatCompletionAssistantMessageParam,
-        ChatCompletionMessageParam,
-        ChatCompletionMessageToolCallParam,
-        ChatCompletionSystemMessageParam,
-        ChatCompletionToolChoiceOptionParam,
-        ChatCompletionToolMessageParam,
-        ChatCompletionToolParam,
-        ChatCompletionUserMessageParam,
-    )
-    from openai.types.shared.chat_model import ChatModel
-    from openai.types.shared_params import FunctionDefinition
-except ImportError as e:
-    raise ImportError(
-        "The `openai` package is not installed. Please install `fastmcp[openai]` or add `openai` to your dependencies manually."
-    ) from e
-
-from typing_extensions import override
-
-from fastmcp.experimental.sampling.handlers.base import BaseLLMSamplingHandler
-
-
-class OpenAISamplingHandler(BaseLLMSamplingHandler):
-    def __init__(
-        self,
-        default_model: ChatModel,
-        client: AsyncOpenAI | None = None,
-    ) -> None:
-        self.client: AsyncOpenAI = client or AsyncOpenAI()
-        self.default_model: ChatModel = default_model
-
-    @override
-    async def __call__(
-        self,
-        messages: list[SamplingMessage],
-        params: SamplingParams,
-        context: RequestContext[ServerSession, LifespanContextT]
-        | RequestContext[ClientSession, LifespanContextT],
-    ) -> CreateMessageResult | CreateMessageResultWithTools:
-        openai_messages: list[ChatCompletionMessageParam] = (
-            self._convert_to_openai_messages(
-                system_prompt=params.systemPrompt,
-                messages=messages,
-            )
-        )
-
-        model: ChatModel = self._select_model_from_preferences(params.modelPreferences)
-
-        # Convert MCP tools to OpenAI format
-        openai_tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
-        if params.tools:
-            openai_tools = self._convert_tools_to_openai(params.tools)
-
-        # Convert tool_choice to OpenAI format
-        openai_tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
-        if params.toolChoice:
-            openai_tool_choice = self._convert_tool_choice_to_openai(params.toolChoice)
-
-        response = await self.client.chat.completions.create(
-            model=model,
-            messages=openai_messages,
-            temperature=(
-                params.temperature if params.temperature is not None else NOT_GIVEN
-            ),
-            max_tokens=params.maxTokens,
-            stop=params.stopSequences if params.stopSequences else NOT_GIVEN,
-            tools=openai_tools,
-            tool_choice=openai_tool_choice,
-        )
-
-        # Return appropriate result type based on whether tools were provided
-        if params.tools:
-            return self._chat_completion_to_result_with_tools(response)
-        return self._chat_completion_to_create_message_result(response)
-
-    @staticmethod
-    def _iter_models_from_preferences(
-        model_preferences: ModelPreferences | str | list[str] | None,
-    ) -> Iterator[str]:
-        if model_preferences is None:
-            return
-
-        if isinstance(model_preferences, str) and model_preferences in get_args(
-            ChatModel
-        ):
-            yield model_preferences
-
-        if isinstance(model_preferences, list):
-            yield from model_preferences
-
-        if isinstance(model_preferences, ModelPreferences):
-            if not (hints := model_preferences.hints):
-                return
-
-            for hint in hints:
-                if not (name := hint.name):
-                    continue
-
-                yield name
-
-    @staticmethod
-    def _convert_to_openai_messages(
-        system_prompt: str | None, messages: Sequence[SamplingMessage]
-    ) -> list[ChatCompletionMessageParam]:
-        openai_messages: list[ChatCompletionMessageParam] = []
-
-        if system_prompt:
-            openai_messages.append(
-                ChatCompletionSystemMessageParam(
-                    role="system",
-                    content=system_prompt,
-                )
-            )
-
-        if isinstance(messages, str):
-            openai_messages.append(
-                ChatCompletionUserMessageParam(
-                    role="user",
-                    content=messages,
-                )
-            )
-
-        if isinstance(messages, list):
-            for message in messages:
-                if isinstance(message, str):
-                    openai_messages.append(
-                        ChatCompletionUserMessageParam(
-                            role="user",
-                            content=message,
-                        )
-                    )
-                    continue
-
-                content = message.content
-
-                # Handle list content (from CreateMessageResultWithTools)
-                if isinstance(content, list):
-                    # Collect tool calls and text from the list
-                    tool_calls: list[ChatCompletionMessageToolCallParam] = []
-                    text_parts: list[str] = []
-                    # Collect tool results separately to maintain correct ordering
-                    tool_messages: list[ChatCompletionToolMessageParam] = []
-
-                    for item in content:
-                        if isinstance(item, ToolUseContent):
-                            tool_calls.append(
-                                ChatCompletionMessageToolCallParam(
-                                    id=item.id,
-                                    type="function",
-                                    function={
-                                        "name": item.name,
-                                        "arguments": json.dumps(item.input),
-                                    },
-                                )
-                            )
-                        elif isinstance(item, TextContent):
-                            text_parts.append(item.text)
-                        elif isinstance(item, ToolResultContent):
-                            # Collect tool results (added after assistant message)
-                            content_text = ""
-                            if item.content:
-                                result_texts = []
-                                for sub_item in item.content:
-                                    if isinstance(sub_item, TextContent):
-                                        result_texts.append(sub_item.text)
-                                content_text = "\n".join(result_texts)
-                            tool_messages.append(
-                                ChatCompletionToolMessageParam(
-                                    role="tool",
-                                    tool_call_id=item.toolUseId,
-                                    content=content_text,
-                                )
-                            )
-
-                    # Add assistant message with tool calls if present
-                    # OpenAI requires: assistant (with tool_calls) -> tool messages
-                    if tool_calls or text_parts:
-                        msg_content = "\n".join(text_parts) if text_parts else None
-                        if tool_calls:
-                            openai_messages.append(
-                                ChatCompletionAssistantMessageParam(
-                                    role="assistant",
-                                    content=msg_content,
-                                    tool_calls=tool_calls,
-                                )
-                            )
-                            # Add tool messages AFTER assistant message
-                            openai_messages.extend(tool_messages)
-                        elif msg_content:
-                            if message.role == "user":
-                                openai_messages.append(
-                                    ChatCompletionUserMessageParam(
-                                        role="user",
-                                        content=msg_content,
-                                    )
-                                )
-                            else:
-                                openai_messages.append(
-                                    ChatCompletionAssistantMessageParam(
-                                        role="assistant",
-                                        content=msg_content,
-                                    )
-                                )
-                    elif tool_messages:
-                        # Tool results only (assistant message was in previous message)
-                        openai_messages.extend(tool_messages)
-                    continue
-
-                # Handle ToolUseContent (assistant's tool calls)
-                if isinstance(content, ToolUseContent):
-                    openai_messages.append(
-                        ChatCompletionAssistantMessageParam(
-                            role="assistant",
-                            tool_calls=[
-                                ChatCompletionMessageToolCallParam(
-                                    id=content.id,
-                                    type="function",
-                                    function={
-                                        "name": content.name,
-                                        "arguments": json.dumps(content.input),
-                                    },
-                                )
-                            ],
-                        )
-                    )
-                    continue
-
-                # Handle ToolResultContent (user's tool results)
-                if isinstance(content, ToolResultContent):
-                    # Extract text parts from the content list
-                    result_texts: list[str] = []
-                    if content.content:
-                        for item in content.content:
-                            if isinstance(item, TextContent):
-                                result_texts.append(item.text)
-                    openai_messages.append(
-                        ChatCompletionToolMessageParam(
-                            role="tool",
-                            tool_call_id=content.toolUseId,
-                            content="\n".join(result_texts),
-                        )
-                    )
-                    continue
-
-                # Handle TextContent
-                if isinstance(content, TextContent):
-                    if message.role == "user":
-                        openai_messages.append(
-                            ChatCompletionUserMessageParam(
-                                role="user",
-                                content=content.text,
-                            )
-                        )
-                    else:
-                        openai_messages.append(
-                            ChatCompletionAssistantMessageParam(
-                                role="assistant",
-                                content=content.text,
-                            )
-                        )
-                    continue
-
-                raise ValueError(f"Unsupported content type: {type(content)}")
-
-        return openai_messages
-
-    @staticmethod
-    def _chat_completion_to_create_message_result(
-        chat_completion: ChatCompletion,
-    ) -> CreateMessageResult:
-        if len(chat_completion.choices) == 0:
-            raise ValueError("No response for completion")
-
-        first_choice = chat_completion.choices[0]
-
-        if content := first_choice.message.content:
-            return CreateMessageResult(
-                content=TextContent(type="text", text=content),
-                role="assistant",
-                model=chat_completion.model,
-            )
-
-        raise ValueError("No content in response from completion")
-
-    def _select_model_from_preferences(
-        self, model_preferences: ModelPreferences | str | list[str] | None
-    ) -> ChatModel:
-        for model_option in self._iter_models_from_preferences(model_preferences):
-            if model_option in get_args(ChatModel):
-                chosen_model: ChatModel = model_option  # type: ignore[assignment]
-                return chosen_model
-
-        return self.default_model
-
-    @staticmethod
-    def _convert_tools_to_openai(tools: list[Tool]) -> list[ChatCompletionToolParam]:
-        """Convert MCP tools to OpenAI tool format."""
-        openai_tools: list[ChatCompletionToolParam] = []
-        for tool in tools:
-            # Build parameters dict, ensuring required fields
-            parameters: dict[str, Any] = dict(tool.inputSchema)
-            if "type" not in parameters:
-                parameters["type"] = "object"
-
-            openai_tools.append(
-                ChatCompletionToolParam(
-                    type="function",
-                    function=FunctionDefinition(
-                        name=tool.name,
-                        description=tool.description or "",
-                        parameters=parameters,
-                    ),
-                )
-            )
-        return openai_tools
-
-    @staticmethod
-    def _convert_tool_choice_to_openai(
-        tool_choice: ToolChoice,
-    ) -> ChatCompletionToolChoiceOptionParam:
-        """Convert MCP tool_choice to OpenAI format."""
-        if tool_choice.mode == "auto":
-            return "auto"
-        elif tool_choice.mode == "required":
-            return "required"
-        elif tool_choice.mode == "none":
-            return "none"
-        else:
-            raise ValueError(f"Unsupported tool_choice mode: {tool_choice.mode!r}")
-
-    @staticmethod
-    def _chat_completion_to_result_with_tools(
-        chat_completion: ChatCompletion,
-    ) -> CreateMessageResultWithTools:
-        """Convert OpenAI response to CreateMessageResultWithTools."""
-        if len(chat_completion.choices) == 0:
-            raise ValueError("No response for completion")
-
-        first_choice = chat_completion.choices[0]
-        message = first_choice.message
-
-        # Determine stop reason
-        stop_reason: StopReason
-        if first_choice.finish_reason == "tool_calls":
-            stop_reason = "toolUse"
-        elif first_choice.finish_reason == "stop":
-            stop_reason = "endTurn"
-        elif first_choice.finish_reason == "length":
-            stop_reason = "maxTokens"
-        else:
-            stop_reason = "endTurn"
-
-        # Build content list
-        content: list[TextContent | ToolUseContent] = []
-
-        # Add text content if present
-        if message.content:
-            content.append(TextContent(type="text", text=message.content))
-
-        # Add tool calls if present
-        if message.tool_calls:
-            for tool_call in message.tool_calls:
-                # Skip non-function tool calls
-                if not hasattr(tool_call, "function"):
-                    continue
-                func = tool_call.function  # type: ignore[union-attr]
-                # Parse the arguments JSON string
-                try:
-                    arguments = json.loads(func.arguments)  # type: ignore[union-attr]
-                except json.JSONDecodeError as e:
-                    raise ValueError(
-                        f"Invalid JSON in tool arguments for "
-                        f"'{func.name}': {func.arguments}"  # type: ignore[union-attr]
-                    ) from e
-
-                content.append(
-                    ToolUseContent(
-                        type="tool_use",
-                        id=tool_call.id,
-                        name=func.name,  # type: ignore[union-attr]
-                        input=arguments,
-                    )
-                )
-
-        # Must have at least some content
-        if not content:
-            raise ValueError("No content in response from completion")
-
-        return CreateMessageResultWithTools(
-            content=content,  # type: ignore[arg-type]
-            role="assistant",
-            model=chat_completion.model,
-            stopReason=stop_reason,
-        )
+
+__all__ = ["OpenAISamplingHandler"]
diff --git a/src/fastmcp/server/sampling/__init__.py b/src/fastmcp/server/sampling/__init__.py
index afb903a9ed..392326d35c 100644
--- a/src/fastmcp/server/sampling/__init__.py
+++ b/src/fastmcp/server/sampling/__init__.py
@@ -1,6 +1,5 @@
 """Sampling module for FastMCP servers."""
 
-from fastmcp.server.sampling.handler import ServerSamplingHandler
 from fastmcp.server.sampling.run import SampleStep, SamplingResult
 from fastmcp.server.sampling.sampling_tool import SamplingTool
 
@@ -8,5 +7,4 @@
     "SampleStep",
     "SamplingResult",
     "SamplingTool",
-    "ServerSamplingHandler",
 ]
diff --git a/src/fastmcp/server/sampling/handler.py b/src/fastmcp/server/sampling/handler.py
deleted file mode 100644
index fe462c3b1e..0000000000
--- a/src/fastmcp/server/sampling/handler.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from collections.abc import Awaitable, Callable
-from typing import TypeAlias
-
-from mcp import CreateMessageResult
-from mcp.server.session import ServerSession
-from mcp.shared.context import LifespanContextT, RequestContext
-from mcp.types import CreateMessageRequestParams as SamplingParams
-from mcp.types import CreateMessageResultWithTools, SamplingMessage
-
-# Result type that handlers can return
-SamplingHandlerResult: TypeAlias = (
-    str | CreateMessageResult | CreateMessageResultWithTools
-)
-
-ServerSamplingHandler: TypeAlias = Callable[
-    [
-        list[SamplingMessage],
-        SamplingParams,
-        RequestContext[ServerSession, LifespanContextT],
-    ],
-    SamplingHandlerResult | Awaitable[SamplingHandlerResult],
-]
diff --git a/src/fastmcp/server/server.py b/src/fastmcp/server/server.py
index 5e26d5be6f..eb0263d3fe 100644
--- a/src/fastmcp/server/server.py
+++ b/src/fastmcp/server/server.py
@@ -94,12 +94,12 @@
 if TYPE_CHECKING:
     from fastmcp.client import Client
     from fastmcp.client.client import FastMCP1Server
+    from fastmcp.client.sampling import SamplingHandler
     from fastmcp.client.transports import ClientTransport, ClientTransportT
     from fastmcp.server.openapi import ComponentFn as OpenAPIComponentFn
     from fastmcp.server.openapi import FastMCPOpenAPI, RouteMap
     from fastmcp.server.openapi import RouteMapFn as OpenAPIRouteMapFn
     from fastmcp.server.proxy import FastMCPProxy
-    from fastmcp.server.sampling.handler import ServerSamplingHandler
     from fastmcp.tools.tool import ToolResultSerializerType
 
 logger = get_logger(__name__)
@@ -208,7 +208,7 @@ def __init__(
         streamable_http_path: str | None = None,
         json_response: bool | None = None,
         stateless_http: bool | None = None,
-        sampling_handler: ServerSamplingHandler[LifespanResultT] | None = None,
+        sampling_handler: SamplingHandler | None = None,
         sampling_handler_behavior: Literal["always", "fallback"] | None = None,
     ):
         # Resolve server default for background task support
@@ -288,9 +288,7 @@ def __init__(
         # Set up MCP protocol handlers
         self._setup_handlers()
 
-        self.sampling_handler: ServerSamplingHandler[LifespanResultT] | None = (
-            sampling_handler
-        )
+        self.sampling_handler: SamplingHandler | None = sampling_handler
         self.sampling_handler_behavior: Literal["always", "fallback"] = (
             sampling_handler_behavior or "fallback"
         )
diff --git a/tests/client/sampling/__init__.py b/tests/client/sampling/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/client/sampling/handlers/__init__.py b/tests/client/sampling/handlers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/experimental/sampling/test_openai_handler.py b/tests/client/sampling/handlers/test_openai_handler.py
similarity index 97%
rename from tests/experimental/sampling/test_openai_handler.py
rename to tests/client/sampling/handlers/test_openai_handler.py
index 0c5025cde8..d5b263a81e 100644
--- a/tests/experimental/sampling/test_openai_handler.py
+++ b/tests/client/sampling/handlers/test_openai_handler.py
@@ -18,7 +18,7 @@
 )
 from openai.types.chat.chat_completion import Choice
 
-from fastmcp.experimental.sampling.handlers.openai import OpenAISamplingHandler
+from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler
 
 
 def test_convert_sampling_messages_to_openai_messages():
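
For reference, a minimal usage sketch of the relocated handler (not part of the diff above): the canonical import path is now `fastmcp.client.sampling.handlers.openai`, while the old `fastmcp.experimental` path keeps resolving through the re-export shims. The server filename follows the docs example in this diff; the model name is an illustrative assumption, and `AsyncOpenAI()` reads `OPENAI_API_KEY` from the environment.

```python
# Sketch only. Assumptions: a server file named my_mcp_server.py,
# OPENAI_API_KEY set in the environment, "gpt-4o-mini" as an example model.
from fastmcp import Client
from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler

# The old location still works via the backwards-compatibility re-export:
# from fastmcp.experimental.sampling.handlers.openai import OpenAISamplingHandler

client = Client(
    "my_mcp_server.py",
    sampling_handler=OpenAISamplingHandler(default_model="gpt-4o-mini"),
)
```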