diff --git a/src/fastmcp/client/sampling/handlers/anthropic.py b/src/fastmcp/client/sampling/handlers/anthropic.py index 4bef921b39..176fe64a4d 100644 --- a/src/fastmcp/client/sampling/handlers/anthropic.py +++ b/src/fastmcp/client/sampling/handlers/anthropic.py @@ -3,10 +3,11 @@ from collections.abc import Iterator, Sequence from typing import Any -from mcp.types import CreateMessageRequestParams as SamplingParams from mcp.types import ( + AudioContent, CreateMessageResult, CreateMessageResultWithTools, + ImageContent, ModelPreferences, SamplingMessage, SamplingMessageContentBlock, @@ -17,10 +18,13 @@ ToolResultContent, ToolUseContent, ) +from mcp.types import CreateMessageRequestParams as SamplingParams try: from anthropic import AsyncAnthropic from anthropic.types import ( + Base64ImageSourceParam, + ImageBlockParam, Message, MessageParam, TextBlock, @@ -42,6 +46,28 @@ __all__ = ["AnthropicSamplingHandler"] +# Anthropic supports these image MIME types +_ANTHROPIC_IMAGE_MEDIA_TYPES = frozenset( + {"image/jpeg", "image/png", "image/gif", "image/webp"} +) + + +def _image_content_to_anthropic_block(content: ImageContent) -> ImageBlockParam: + """Convert MCP ImageContent to Anthropic ImageBlockParam.""" + if content.mimeType not in _ANTHROPIC_IMAGE_MEDIA_TYPES: + raise ValueError( + f"Unsupported image MIME type for Anthropic: {content.mimeType!r}. " + f"Supported types: {', '.join(sorted(_ANTHROPIC_IMAGE_MEDIA_TYPES))}" + ) + return ImageBlockParam( + type="image", + source=Base64ImageSourceParam( + type="base64", + media_type=content.mimeType, # type: ignore[arg-type] + data=content.data, + ), + ) + class AnthropicSamplingHandler: """Sampling handler that uses the Anthropic API. @@ -155,7 +181,10 @@ def _convert_to_anthropic_messages( # Handle list content (from CreateMessageResultWithTools) if isinstance(content, list): content_blocks: list[ - TextBlockParam | ToolUseBlockParam | ToolResultBlockParam + TextBlockParam + | ImageBlockParam + | ToolUseBlockParam + | ToolResultBlockParam ] = [] for item in content: @@ -172,6 +201,17 @@ def _convert_to_anthropic_messages( content_blocks.append( TextBlockParam(type="text", text=item.text) ) + elif isinstance(item, ImageContent): + if message.role != "user": + raise ValueError( + "ImageContent is only supported in user messages " + "for Anthropic" + ) + content_blocks.append(_image_content_to_anthropic_block(item)) + elif isinstance(item, AudioContent): + raise ValueError( + "AudioContent is not supported by the Anthropic API" + ) elif isinstance(item, ToolResultContent): # Extract text content from the result result_content: str | list[TextBlockParam] = "" @@ -262,6 +302,24 @@ def _convert_to_anthropic_messages( ) continue + # Handle ImageContent + if isinstance(content, ImageContent): + if message.role != "user": + raise ValueError( + "ImageContent is only supported in user messages for Anthropic" + ) + anthropic_messages.append( + MessageParam( + role="user", + content=[_image_content_to_anthropic_block(content)], + ) + ) + continue + + # Handle AudioContent - not supported by Anthropic + if isinstance(content, AudioContent): + raise ValueError("AudioContent is not supported by the Anthropic API") + raise ValueError(f"Unsupported content type: {type(content)}") return anthropic_messages diff --git a/src/fastmcp/client/sampling/handlers/google_genai.py b/src/fastmcp/client/sampling/handlers/google_genai.py index c072d51316..ad1a3d1e80 100644 --- a/src/fastmcp/client/sampling/handlers/google_genai.py +++ b/src/fastmcp/client/sampling/handlers/google_genai.py @@ -1,11 +1,13 @@ """Google GenAI sampling handler with tool support for FastMCP 3.0.""" +import base64 from collections.abc import Sequence from uuid import uuid4 try: from google.genai import Client as GoogleGenaiClient from google.genai.types import ( + Blob, Candidate, Content, FunctionCall, @@ -197,6 +199,22 @@ def _sampling_content_to_google_genai_part( if isinstance(content, TextContent): return Part(text=content.text) + if isinstance(content, ImageContent): + return Part( + inline_data=Blob( + data=base64.b64decode(content.data), + mime_type=content.mimeType, + ) + ) + + if isinstance(content, AudioContent): + return Part( + inline_data=Blob( + data=base64.b64decode(content.data), + mime_type=content.mimeType, + ) + ) + if isinstance(content, ToolUseContent): # Note: thought_signature bypass is required for manually constructed tool calls. # Google's Gemini 3+ models enforce thought signature validation for function calls. diff --git a/src/fastmcp/client/sampling/handlers/openai.py b/src/fastmcp/client/sampling/handlers/openai.py index 38f5ff356b..bac082786d 100644 --- a/src/fastmcp/client/sampling/handlers/openai.py +++ b/src/fastmcp/client/sampling/handlers/openai.py @@ -6,10 +6,11 @@ from mcp import ClientSession, ServerSession from mcp.shared.context import LifespanContextT, RequestContext -from mcp.types import CreateMessageRequestParams as SamplingParams from mcp.types import ( + AudioContent, CreateMessageResult, CreateMessageResultWithTools, + ImageContent, ModelPreferences, SamplingMessage, StopReason, @@ -19,12 +20,17 @@ ToolResultContent, ToolUseContent, ) +from mcp.types import CreateMessageRequestParams as SamplingParams try: from openai import AsyncOpenAI from openai.types.chat import ( ChatCompletion, ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam, + ChatCompletionContentPartParam, + ChatCompletionContentPartTextParam, ChatCompletionMessageParam, ChatCompletionMessageToolCallParam, ChatCompletionSystemMessageParam, @@ -41,6 +47,50 @@ "Please install `fastmcp[openai]` or add `openai` to your dependencies manually." ) from e +# OpenAI only supports wav and mp3 for input audio +_OPENAI_AUDIO_FORMATS: dict[str, str] = { + "audio/wav": "wav", + "audio/x-wav": "wav", + "audio/mp3": "mp3", + "audio/mpeg": "mp3", +} + +_OPENAI_IMAGE_MEDIA_TYPES: frozenset[str] = frozenset( + {"image/jpeg", "image/png", "image/gif", "image/webp"} +) + + +def _image_content_to_openai_part( + content: ImageContent, +) -> ChatCompletionContentPartImageParam: + """Convert MCP ImageContent to OpenAI image_url content part.""" + if content.mimeType not in _OPENAI_IMAGE_MEDIA_TYPES: + raise ValueError( + f"Unsupported image MIME type for OpenAI: {content.mimeType!r}. " + f"Supported types: {', '.join(sorted(_OPENAI_IMAGE_MEDIA_TYPES))}" + ) + data_url = f"data:{content.mimeType};base64,{content.data}" + return ChatCompletionContentPartImageParam( + type="image_url", + image_url={"url": data_url}, + ) + + +def _audio_content_to_openai_part( + content: AudioContent, +) -> ChatCompletionContentPartInputAudioParam: + """Convert MCP AudioContent to OpenAI input_audio content part.""" + audio_format = _OPENAI_AUDIO_FORMATS.get(content.mimeType) + if audio_format is None: + raise ValueError( + f"Unsupported audio MIME type for OpenAI: {content.mimeType!r}. " + f"Supported types: {', '.join(sorted(_OPENAI_AUDIO_FORMATS))}" + ) + return ChatCompletionContentPartInputAudioParam( + type="input_audio", + input_audio={"data": content.data, "format": audio_format}, + ) + class OpenAISamplingHandler: """Sampling handler that uses the OpenAI API.""" @@ -147,8 +197,9 @@ def _convert_to_openai_messages( # Handle list content (from CreateMessageResultWithTools) if isinstance(content, list): - # Collect tool calls and text from the list + # Collect tool calls, content parts, and text from the list tool_calls: list[ChatCompletionMessageToolCallParam] = [] + content_parts: list[ChatCompletionContentPartParam] = [] text_parts: list[str] = [] # Collect tool results separately to maintain correct ordering tool_messages: list[ChatCompletionToolMessageParam] = [] @@ -167,6 +218,15 @@ def _convert_to_openai_messages( ) elif isinstance(item, TextContent): text_parts.append(item.text) + content_parts.append( + ChatCompletionContentPartTextParam( + type="text", text=item.text + ) + ) + elif isinstance(item, ImageContent): + content_parts.append(_image_content_to_openai_part(item)) + elif isinstance(item, AudioContent): + content_parts.append(_audio_content_to_openai_part(item)) elif isinstance(item, ToolResultContent): # Collect tool results (added after assistant message) content_text = "" @@ -186,33 +246,47 @@ def _convert_to_openai_messages( # Add assistant message with tool calls if present # OpenAI requires: assistant (with tool_calls) -> tool messages - if tool_calls or text_parts: - msg_content = "\n".join(text_parts) if text_parts else None + if tool_calls or content_parts: if tool_calls: + has_multimodal = len(content_parts) > len(text_parts) + if has_multimodal: + raise ValueError( + "ImageContent/AudioContent is only supported " + "in user messages for OpenAI" + ) + text_str = "\n".join(text_parts) or None openai_messages.append( ChatCompletionAssistantMessageParam( role="assistant", - content=msg_content, + content=text_str, tool_calls=tool_calls, ) ) # Add tool messages AFTER assistant message openai_messages.extend(tool_messages) - elif msg_content: + elif content_parts: if message.role == "user": openai_messages.append( ChatCompletionUserMessageParam( role="user", - content=msg_content, + content=content_parts, ) ) else: - openai_messages.append( - ChatCompletionAssistantMessageParam( - role="assistant", - content=msg_content, + has_multimodal = len(content_parts) > len(text_parts) + if has_multimodal: + raise ValueError( + "ImageContent/AudioContent is only supported " + "in user messages for OpenAI" + ) + assistant_text = "\n".join(text_parts) + if assistant_text: + openai_messages.append( + ChatCompletionAssistantMessageParam( + role="assistant", + content=assistant_text, + ) ) - ) elif tool_messages: # Tool results only (assistant message was in previous message) openai_messages.extend(tool_messages) @@ -272,6 +346,34 @@ def _convert_to_openai_messages( ) continue + # Handle ImageContent + if isinstance(content, ImageContent): + if message.role != "user": + raise ValueError( + "ImageContent is only supported in user messages for OpenAI" + ) + openai_messages.append( + ChatCompletionUserMessageParam( + role="user", + content=[_image_content_to_openai_part(content)], + ) + ) + continue + + # Handle AudioContent + if isinstance(content, AudioContent): + if message.role != "user": + raise ValueError( + "AudioContent is only supported in user messages for OpenAI" + ) + openai_messages.append( + ChatCompletionUserMessageParam( + role="user", + content=[_audio_content_to_openai_part(content)], + ) + ) + continue + raise ValueError(f"Unsupported content type: {type(content)}") return openai_messages diff --git a/tests/client/sampling/handlers/test_anthropic_handler.py b/tests/client/sampling/handlers/test_anthropic_handler.py index 57a464adab..5910eb92b4 100644 --- a/tests/client/sampling/handlers/test_anthropic_handler.py +++ b/tests/client/sampling/handlers/test_anthropic_handler.py @@ -1,19 +1,26 @@ +from typing import Any from unittest.mock import MagicMock import pytest from anthropic import AsyncAnthropic from anthropic.types import Message, TextBlock, ToolUseBlock, Usage from mcp.types import ( + AudioContent, CreateMessageResult, CreateMessageResultWithTools, + ImageContent, ModelHint, ModelPreferences, SamplingMessage, TextContent, + ToolResultContent, ToolUseContent, ) -from fastmcp.client.sampling.handlers.anthropic import AnthropicSamplingHandler +from fastmcp.client.sampling.handlers.anthropic import ( + AnthropicSamplingHandler, + _image_content_to_anthropic_block, +) def test_convert_sampling_messages_to_anthropic_messages(): @@ -34,15 +41,137 @@ def test_convert_sampling_messages_to_anthropic_messages(): ] -def test_convert_to_anthropic_messages_raises_on_non_text(): - from fastmcp.utilities.types import Image +def test_image_content_to_anthropic_block(): + block = _image_content_to_anthropic_block( + ImageContent(type="image", data="YWJj", mimeType="image/png") + ) + + assert block == { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "YWJj", + }, + } + + +def test_image_content_unsupported_mime_type_raises(): + with pytest.raises(ValueError, match="Unsupported image MIME type"): + _image_content_to_anthropic_block( + ImageContent(type="image", data="YWJj", mimeType="image/bmp") + ) + + +def test_convert_single_image_content_to_anthropic_message(): + msgs = AnthropicSamplingHandler._convert_to_anthropic_messages( + messages=[ + SamplingMessage( + role="user", + content=ImageContent(type="image", data="YWJj", mimeType="image/png"), + ) + ], + ) + + assert len(msgs) == 1 + assert msgs[0] == { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "YWJj", + }, + } + ], + } + - with pytest.raises(ValueError): +def test_convert_single_audio_content_raises(): + with pytest.raises(ValueError, match="AudioContent is not supported"): AnthropicSamplingHandler._convert_to_anthropic_messages( messages=[ SamplingMessage( role="user", - content=Image(data=b"abc").to_image_content(), + content=AudioContent( + type="audio", data="YWJj", mimeType="audio/wav" + ), + ) + ], + ) + + +def test_convert_list_content_with_image_and_text(): + msgs = AnthropicSamplingHandler._convert_to_anthropic_messages( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(type="text", text="Describe this image"), + ImageContent(type="image", data="YWJj", mimeType="image/jpeg"), + ], + ) + ], + ) + + assert len(msgs) == 1 + assert msgs[0] == { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/jpeg", + "data": "YWJj", + }, + }, + ], + } + + +def test_convert_list_content_with_audio_raises(): + with pytest.raises(ValueError, match="AudioContent is not supported"): + AnthropicSamplingHandler._convert_to_anthropic_messages( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(type="text", text="Listen to this"), + AudioContent(type="audio", data="YWJj", mimeType="audio/wav"), + ], + ) + ], + ) + + +def test_convert_image_in_assistant_message_raises(): + with pytest.raises(ValueError, match="ImageContent is only supported in user"): + AnthropicSamplingHandler._convert_to_anthropic_messages( + messages=[ + SamplingMessage( + role="assistant", + content=ImageContent( + type="image", data="YWJj", mimeType="image/png" + ), + ) + ], + ) + + +def test_convert_list_image_in_assistant_message_raises(): + with pytest.raises(ValueError, match="ImageContent is only supported in user"): + AnthropicSamplingHandler._convert_to_anthropic_messages( + messages=[ + SamplingMessage( + role="assistant", + content=[ + TextContent(type="text", text="Here's the image"), + ImageContent(type="image", data="YWJj", mimeType="image/png"), + ], ) ], ) @@ -61,7 +190,7 @@ def test_convert_to_anthropic_messages_raises_on_non_text(): (["unknown-model"], "fallback-model"), ], ) -def test_select_model_from_preferences(prefs, expected): +def test_select_model_from_preferences(prefs: Any, expected: str) -> None: mock_client = MagicMock(spec=AsyncAnthropic) handler = AnthropicSamplingHandler( default_model="fallback-model", client=mock_client @@ -220,8 +349,6 @@ def test_convert_messages_with_tool_use_content(): def test_convert_messages_with_tool_result_content(): """Test converting messages that include tool result content from user.""" - from mcp.types import ToolResultContent - msgs = AnthropicSamplingHandler._convert_to_anthropic_messages( messages=[ SamplingMessage( diff --git a/tests/client/sampling/handlers/test_google_genai_handler.py b/tests/client/sampling/handlers/test_google_genai_handler.py index 92403461e5..7eb0da188a 100644 --- a/tests/client/sampling/handlers/test_google_genai_handler.py +++ b/tests/client/sampling/handlers/test_google_genai_handler.py @@ -1,3 +1,4 @@ +import base64 from unittest.mock import MagicMock import pytest @@ -14,9 +15,12 @@ UserContent, ) from mcp.types import ( + AudioContent, CreateMessageResult, + ImageContent, ModelHint, ModelPreferences, + SamplingMessage, TextContent, ToolChoice, ToolResultContent, @@ -42,8 +46,6 @@ def test_convert_sampling_messages_to_google_genai_content(): - from mcp.types import SamplingMessage, TextContent - msgs = _convert_messages_to_google_genai_content( messages=[ SamplingMessage( @@ -62,20 +64,98 @@ def test_convert_sampling_messages_to_google_genai_content(): assert msgs[1].parts[0].text == "ok" -def test_convert_to_google_genai_messages_raises_on_non_text(): - from mcp.types import SamplingMessage +def test_convert_single_image_content_to_google_genai(): + part = _sampling_content_to_google_genai_part( + ImageContent(type="image", data="YWJj", mimeType="image/png") + ) + + assert part.inline_data is not None + assert part.inline_data.data == base64.b64decode("YWJj") + assert part.inline_data.mime_type == "image/png" + + +def test_convert_single_audio_content_to_google_genai(): + part = _sampling_content_to_google_genai_part( + AudioContent(type="audio", data="YWJj", mimeType="audio/wav") + ) + + assert part.inline_data is not None + assert part.inline_data.data == base64.b64decode("YWJj") + assert part.inline_data.mime_type == "audio/wav" + + +def test_convert_image_message_to_google_genai_content(): + msgs = _convert_messages_to_google_genai_content( + messages=[ + SamplingMessage( + role="user", + content=ImageContent(type="image", data="YWJj", mimeType="image/jpeg"), + ) + ], + ) + + assert len(msgs) == 1 + assert isinstance(msgs[0], UserContent) + assert msgs[0].parts[0].inline_data is not None + assert msgs[0].parts[0].inline_data.mime_type == "image/jpeg" + + +def test_convert_audio_message_to_google_genai_content(): + msgs = _convert_messages_to_google_genai_content( + messages=[ + SamplingMessage( + role="user", + content=AudioContent(type="audio", data="YWJj", mimeType="audio/mp3"), + ) + ], + ) + + assert len(msgs) == 1 + assert isinstance(msgs[0], UserContent) + assert msgs[0].parts[0].inline_data is not None + assert msgs[0].parts[0].inline_data.mime_type == "audio/mp3" + + +def test_convert_list_content_with_image_and_text(): + msgs = _convert_messages_to_google_genai_content( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(type="text", text="What is in this image?"), + ImageContent(type="image", data="YWJj", mimeType="image/png"), + ], + ) + ], + ) - from fastmcp.utilities.types import Image + assert len(msgs) == 1 + assert isinstance(msgs[0], UserContent) + assert len(msgs[0].parts) == 2 + assert msgs[0].parts[0].text == "What is in this image?" + assert msgs[0].parts[1].inline_data is not None + assert msgs[0].parts[1].inline_data.mime_type == "image/png" - with pytest.raises(ValueError): - _convert_messages_to_google_genai_content( - messages=[ - SamplingMessage( - role="user", - content=Image(data=b"abc").to_image_content(), - ) - ], - ) + +def test_convert_list_content_with_audio_and_text(): + msgs = _convert_messages_to_google_genai_content( + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(type="text", text="Transcribe this audio"), + AudioContent(type="audio", data="YWJj", mimeType="audio/wav"), + ], + ) + ], + ) + + assert len(msgs) == 1 + assert isinstance(msgs[0], UserContent) + assert len(msgs[0].parts) == 2 + assert msgs[0].parts[0].text == "Transcribe this audio" + assert msgs[0].parts[1].inline_data is not None + assert msgs[0].parts[1].inline_data.mime_type == "audio/wav" def test_get_model(): @@ -207,8 +287,6 @@ def test_sampling_content_to_google_genai_part_tool_result_no_underscore(): def test_convert_messages_with_tool_use(): """Test converting messages containing ToolUseContent.""" - from mcp.types import SamplingMessage - msgs = _convert_messages_to_google_genai_content( messages=[ SamplingMessage( @@ -236,8 +314,6 @@ def test_convert_messages_with_tool_use(): def test_convert_messages_with_tool_result(): """Test converting messages containing ToolResultContent.""" - from mcp.types import SamplingMessage - msgs = _convert_messages_to_google_genai_content( messages=[ SamplingMessage( @@ -245,7 +321,7 @@ def test_convert_messages_with_tool_result(): content=ToolResultContent( type="tool_result", toolUseId="get_weather_123", - content=[TextContent(type="text", text="Sunny, 72°F")], + content=[TextContent(type="text", text="Sunny, 72F")], ), ), ], @@ -259,8 +335,6 @@ def test_convert_messages_with_tool_result(): def test_convert_messages_with_multiple_content_blocks(): """Test converting messages with multiple content blocks (list content).""" - from mcp.types import SamplingMessage - msgs = _convert_messages_to_google_genai_content( messages=[ SamplingMessage( diff --git a/tests/client/sampling/handlers/test_openai_handler.py b/tests/client/sampling/handlers/test_openai_handler.py index 29f12d7499..3a263bcc76 100644 --- a/tests/client/sampling/handlers/test_openai_handler.py +++ b/tests/client/sampling/handlers/test_openai_handler.py @@ -1,25 +1,36 @@ +from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest from mcp.types import ( + AudioContent, CreateMessageRequestParams, CreateMessageResult, + ImageContent, ModelHint, ModelPreferences, SamplingMessage, TextContent, + ToolUseContent, ) from openai import AsyncOpenAI from openai.types.chat import ( ChatCompletion, ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam, + ChatCompletionContentPartTextParam, ChatCompletionMessage, ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam, ) from openai.types.chat.chat_completion import Choice -from fastmcp.client.sampling.handlers.openai import OpenAISamplingHandler +from fastmcp.client.sampling.handlers.openai import ( + OpenAISamplingHandler, + _audio_content_to_openai_part, + _image_content_to_openai_part, +) def test_convert_sampling_messages_to_openai_messages(): @@ -42,16 +53,189 @@ def test_convert_sampling_messages_to_openai_messages(): ] -def test_convert_to_openai_messages_raises_on_non_text(): - from fastmcp.utilities.types import Image +def test_image_content_to_openai_part(): + part = _image_content_to_openai_part( + ImageContent(type="image", data="YWJj", mimeType="image/png") + ) + + assert part == ChatCompletionContentPartImageParam( + type="image_url", + image_url={"url": "data:image/png;base64,YWJj"}, + ) + + +def test_audio_content_to_openai_part_wav(): + part = _audio_content_to_openai_part( + AudioContent(type="audio", data="YWJj", mimeType="audio/wav") + ) + + assert part == ChatCompletionContentPartInputAudioParam( + type="input_audio", + input_audio={"data": "YWJj", "format": "wav"}, + ) + + +def test_audio_content_to_openai_part_mp3(): + part = _audio_content_to_openai_part( + AudioContent(type="audio", data="YWJj", mimeType="audio/mpeg") + ) + + assert part["input_audio"]["format"] == "mp3" + + +def test_audio_content_to_openai_part_unsupported_raises(): + with pytest.raises(ValueError, match="Unsupported audio MIME type"): + _audio_content_to_openai_part( + AudioContent(type="audio", data="YWJj", mimeType="audio/ogg") + ) + + +def test_image_content_to_openai_part_unsupported_raises(): + with pytest.raises(ValueError, match="Unsupported image MIME type"): + _image_content_to_openai_part( + ImageContent(type="image", data="YWJj", mimeType="image/bmp") + ) + + +def test_convert_single_image_content_to_openai_message(): + msgs = OpenAISamplingHandler._convert_to_openai_messages( + system_prompt=None, + messages=[ + SamplingMessage( + role="user", + content=ImageContent(type="image", data="YWJj", mimeType="image/png"), + ) + ], + ) + + assert len(msgs) == 1 + assert msgs[0] == ChatCompletionUserMessageParam( + role="user", + content=[ + ChatCompletionContentPartImageParam( + type="image_url", + image_url={"url": "data:image/png;base64,YWJj"}, + ) + ], + ) + + +def test_convert_single_audio_content_to_openai_message(): + msgs = OpenAISamplingHandler._convert_to_openai_messages( + system_prompt=None, + messages=[ + SamplingMessage( + role="user", + content=AudioContent(type="audio", data="YWJj", mimeType="audio/wav"), + ) + ], + ) + + assert len(msgs) == 1 + assert msgs[0] == ChatCompletionUserMessageParam( + role="user", + content=[ + ChatCompletionContentPartInputAudioParam( + type="input_audio", + input_audio={"data": "YWJj", "format": "wav"}, + ) + ], + ) + + +def test_convert_list_content_with_image_and_text(): + msgs = OpenAISamplingHandler._convert_to_openai_messages( + system_prompt=None, + messages=[ + SamplingMessage( + role="user", + content=[ + TextContent(type="text", text="What is in this image?"), + ImageContent(type="image", data="YWJj", mimeType="image/jpeg"), + ], + ) + ], + ) + + assert len(msgs) == 1 + assert msgs[0] == ChatCompletionUserMessageParam( + role="user", + content=[ + ChatCompletionContentPartTextParam( + type="text", text="What is in this image?" + ), + ChatCompletionContentPartImageParam( + type="image_url", + image_url={"url": "data:image/jpeg;base64,YWJj"}, + ), + ], + ) + + +def test_convert_image_in_assistant_message_raises(): + with pytest.raises(ValueError, match="ImageContent is only supported in user"): + OpenAISamplingHandler._convert_to_openai_messages( + system_prompt=None, + messages=[ + SamplingMessage( + role="assistant", + content=ImageContent( + type="image", data="YWJj", mimeType="image/png" + ), + ) + ], + ) + + +def test_convert_audio_in_assistant_message_raises(): + with pytest.raises(ValueError, match="AudioContent is only supported in user"): + OpenAISamplingHandler._convert_to_openai_messages( + system_prompt=None, + messages=[ + SamplingMessage( + role="assistant", + content=AudioContent( + type="audio", data="YWJj", mimeType="audio/wav" + ), + ) + ], + ) + + +def test_convert_list_image_in_assistant_message_raises(): + """Image/audio in an assistant list-content message should raise, not silently drop.""" + with pytest.raises(ValueError, match="only supported in user messages"): + OpenAISamplingHandler._convert_to_openai_messages( + system_prompt=None, + messages=[ + SamplingMessage( + role="assistant", + content=[ + TextContent(type="text", text="Here's the image"), + ImageContent(type="image", data="YWJj", mimeType="image/png"), + ], + ) + ], + ) + - with pytest.raises(ValueError): +def test_convert_list_tool_calls_with_image_raises(): + """Image/audio alongside tool_calls in assistant list should raise.""" + with pytest.raises(ValueError, match="only supported in user messages"): OpenAISamplingHandler._convert_to_openai_messages( system_prompt=None, messages=[ SamplingMessage( - role="user", - content=Image(data=b"abc").to_image_content(), + role="assistant", + content=[ + ToolUseContent( + type="tool_use", + id="call_1", + name="my_tool", + input={"arg": "val"}, + ), + ImageContent(type="image", data="YWJj", mimeType="image/png"), + ], ) ], ) @@ -67,7 +251,7 @@ def test_convert_to_openai_messages_raises_on_non_text(): (["unknown-model"], "fallback-model"), ], ) -def test_select_model_from_preferences(prefs, expected): +def test_select_model_from_preferences(prefs: Any, expected: str) -> None: mock_client = MagicMock(spec=AsyncOpenAI) handler = OpenAISamplingHandler(default_model="fallback-model", client=mock_client) # type: ignore[arg-type] assert handler._select_model_from_preferences(prefs) == expected