diff --git a/requirements/test.in b/requirements/test.in index f0941d3c5918..46c75ebc34a4 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -55,4 +55,5 @@ fastsafetensors>=0.1.10 pydantic>=2.12 # 2.11 leads to error on python 3.13 decord==0.6.0 terratorch @ git+https://github.com/IBM/terratorch.git@1.1.rc3 # required for PrithviMAE test -gpt-oss >= 0.0.7; python_version > '3.11' \ No newline at end of file +gpt-oss >= 0.0.7; python_version > '3.11' +fastmcp # required for MCP tests diff --git a/tests/entrypoints/mcp/__init__.py b/tests/entrypoints/mcp/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/tests/entrypoints/mcp/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/tests/entrypoints/mcp/test_mcp_utils.py b/tests/entrypoints/mcp/test_mcp_utils.py new file mode 100644 index 000000000000..697fe04494be --- /dev/null +++ b/tests/entrypoints/mcp/test_mcp_utils.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for MCP utils.""" + +from openai.types.responses.tool import ( + CodeInterpreter, + CodeInterpreterContainerCodeInterpreterToolAuto, + FunctionTool, + Mcp, + WebSearchPreviewTool, +) + +from vllm.entrypoints.mcp.mcp_utils import normalize_tool_to_mcp + + +def test_normalize_mcp_tool_passthrough(): + """MCP tools should pass through unchanged.""" + mcp_tool = Mcp( + type="mcp", server_label="weather", server_url="http://localhost:8765/sse" + ) + result = normalize_tool_to_mcp(mcp_tool) + assert result == mcp_tool + assert result.server_label == "weather" + assert result.server_url == "http://localhost:8765/sse" + + +def test_normalize_code_interpreter(): + """CodeInterpreter should convert to MCP with server_label='code_interpreter'.""" + # For test purposes we provide a minimal container (required by Pydantic) + # Just 
testing that type is correctly converted + code_tool = CodeInterpreter( + type="code_interpreter", + container=CodeInterpreterContainerCodeInterpreterToolAuto(type="auto"), + ) + result = normalize_tool_to_mcp(code_tool) + + assert isinstance(result, Mcp) + assert result.type == "mcp" + assert result.server_label == "code_interpreter" + # Container field is intentionally discarded + + +def test_normalize_web_search_preview(): + """WebSearchPreviewTool should convert to MCP with server_label='browser'.""" + search_tool = WebSearchPreviewTool( + type="web_search_preview", + search_context_size="medium", + ) + result = normalize_tool_to_mcp(search_tool) + + assert isinstance(result, Mcp) + assert result.type == "mcp" + assert result.server_label == "browser" + # search_context_size is intentionally discarded + + +def test_normalize_other_tools_passthrough(): + """Other tool types should pass through unchanged.""" + # Using a FunctionTool as an example of a non-converted tool type + function_tool = FunctionTool( + type="function", + name="test_func", + function={"name": "test_func", "description": "Test function"}, + ) + + result = normalize_tool_to_mcp(function_tool) + + # Should be unchanged + assert result == function_tool + assert result.type == "function" diff --git a/tests/entrypoints/openai/memory_mcp_server.py b/tests/entrypoints/openai/memory_mcp_server.py new file mode 100755 index 000000000000..5c1b38c04dc4 --- /dev/null +++ b/tests/entrypoints/openai/memory_mcp_server.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Standalone Memory MCP Server + +This is a standalone MCP server that provides memory storage capabilities +with isolation based on x-memory-id headers. 
+
+Tools provided:
+- store: Store a key-value pair in memory
+- retrieve: Retrieve a value by key
+- list_keys: List all keys in the current memory space
+- delete: Delete a key from memory
+
+Run standalone:
+    python memory_mcp_server.py --port 8765
+"""
+
+import argparse
+import os
+import socket
+import subprocess
+import sys
+import time
+
+from fastmcp import Context, FastMCP
+
+# In-memory storage: {memory_id: {key: value}}
+memories: dict[str, dict[str, str]] = {}
+
+# Default memory space for requests without x-memory-id header
+DEFAULT_MEMORY_ID = "default"
+
+# Create FastMCP app
+mcp = FastMCP("memory")
+
+
+def extract_memory_id(ctx: Context) -> str:
+    """Extract memory_id from request context headers."""
+    try:
+        # Try to get HTTP request from context
+        http_request = ctx.get_http_request()
+        if http_request and hasattr(http_request, "headers"):
+            headers = http_request.headers
+            # Headers may be case-insensitive, check variations
+            memory_id = headers.get("x-memory-id") or headers.get("X-Memory-Id")
+            if memory_id:
+                return memory_id
+    except Exception:
+        pass  # best-effort: fall back to the default space below
+
+    return DEFAULT_MEMORY_ID
+
+
+@mcp.tool()
+def store(key: str, value: str, ctx: Context) -> str:
+    """Store a key-value pair in memory.
+
+    Args:
+        key: The key to store
+        value: The value to store
+    """
+    memory_id = extract_memory_id(ctx)
+
+    # Ensure memory space exists
+    if memory_id not in memories:
+        memories[memory_id] = {}
+
+    memories[memory_id][key] = value
+    return (
+        f"Successfully stored key '{key}' with value '{value}' "
+        f"in memory space '{memory_id}'"
+    )
+
+
+@mcp.tool()
+def retrieve(key: str, ctx: Context) -> str:
+    """Retrieve a value by key from memory.
+
+    Args:
+        key: The key to retrieve
+    """
+    memory_id = extract_memory_id(ctx)
+
+    # Read-only lookup: an unknown memory space behaves like an empty one,
+    # so no dict is allocated for ids that never stored anything.
+    space = memories.get(memory_id, {})
+
+    value = space.get(key)
+    if value is None:
+        return f"Key '{key}' not found in memory space '{memory_id}'"
+    return f"Retrieved value for key '{key}': {value}"
+
+
+@mcp.tool()
+def list_keys(ctx: Context) -> str:
+    """List all keys in the current memory space."""
+    memory_id = extract_memory_id(ctx)
+
+    # Read-only lookup: an unknown memory space is reported as empty
+    # without allocating a dict for it.
+    space = memories.get(memory_id, {})
+
+    keys = list(space)
+    if not keys:
+        return f"No keys found in memory space '{memory_id}'"
+    return f"Keys in memory space '{memory_id}': {', '.join(keys)}"
+
+
+@mcp.tool()
+def delete(key: str, ctx: Context) -> str:
+    """Delete a key from memory.
+
+    Args:
+        key: The key to delete
+    """
+    memory_id = extract_memory_id(ctx)
+
+    # An unknown memory space is treated as empty; deleting never creates
+    # one, so ids that were never written to leave no entry behind.
+    space = memories.get(memory_id, {})
+
+    if key in space:
+        del space[key]
+        return f"Successfully deleted key '{key}' from memory space '{memory_id}'"
+    return f"Key '{key}' not found in memory space '{memory_id}'"
+
+
+def start_test_server(port: int) -> subprocess.Popen:
+    """Start memory MCP server for testing and wait until it accepts TCP.
+
+    Args:
+        port: Port to run server on
+
+    Returns:
+        subprocess.Popen object for the running server
+
+    Raises:
+        RuntimeError: If the server exits early or is not listening in time
+    """
+    process = subprocess.Popen(
+        [sys.executable, os.path.abspath(__file__), "--port", str(port)],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+
+    deadline = time.monotonic() + 10  # wall-clock startup budget (seconds)
+    while time.monotonic() < deadline and process.poll() is None:
+        try:
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+                sock.settimeout(0.5)
+                if sock.connect_ex(("localhost", port)) == 0:
+                    return process
+        except OSError:
+            pass  # transient socket error: retry until the deadline
+        time.sleep(0.1)
+
+    # Exited early or never listened: surface the child's captured output
+    process.kill()
+    stdout, stderr = process.communicate()
+    raise RuntimeError(
+        f"Memory MCP server failed to start.\n"
+        f"stdout: {stdout.decode()}\nstderr: {stderr.decode()}"
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Memory MCP Server")
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8765,
+        help="Port to run the server on (default: 8765)",
+    )
+    args = parser.parse_args()
+
+    print(f"Starting Memory MCP Server on port {args.port}...")
+    mcp.run(port=args.port, transport="sse")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/entrypoints/openai/test_memory_mcp_server.py b/tests/entrypoints/openai/test_memory_mcp_server.py
new file mode 100644
index 000000000000..bbc5b0f9644e
--- /dev/null
+++ b/tests/entrypoints/openai/test_memory_mcp_server.py
@@ -0,0 +1,338 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tests for the standalone Memory MCP Server.
+
+These tests verify that the memory MCP server works correctly:
+- Basic store/retrieve operations
+- Memory isolation with x-memory-id headers
+- Default memory pool behavior
+- List keys functionality
+- Delete functionality
+"""
+
+import asyncio
+import os
+import socket
+import subprocess
+import sys
+from contextlib import asynccontextmanager
+
+import pytest
+import pytest_asyncio
+from mcp import ClientSession
+from mcp.client.sse import sse_client
+
+from tests.utils import find_free_port
+
+
+@asynccontextmanager
+async def mcp_client_session(server_url: str, headers: dict[str, str] | None = None):
+    """Create an MCP client session with optional custom headers."""
+    async with (
+        sse_client(url=server_url, headers=headers or {}) as streams,
+        ClientSession(*streams) as session,
+    ):
+        await session.initialize()
+        yield session
+
+
+@pytest_asyncio.fixture(scope="function")
+async def memory_server():
+    """Start memory MCP server as subprocess on random port."""
+    port = find_free_port()
+    server_url = f"http://127.0.0.1:{port}/sse"
+
+    # Get the path to the memory_mcp_server.py script
+    script_path = os.path.join(
+        os.path.dirname(__file__),
+        "memory_mcp_server.py",
+    )
+
+    # Start server process
+    process = subprocess.Popen(
+        [sys.executable, script_path, "--port", str(port)],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
+    )
+
+    # Wait for server to be ready (max 10 seconds)
+    server_ready = False
+    for _ in range(50):  # 50 attempts * 0.2s = 10s max
+        if process.poll() is not None:
+            break  # server already exited; its output is reported below
+        try:
+            # Simple TCP connection check; "with" always closes the probe socket
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+                sock.settimeout(0.5)
+                server_ready = sock.connect_ex(("127.0.0.1", port)) == 0
+        except OSError:
+            pass  # transient socket error: retry on the next attempt
+        if server_ready:
+            # Give it a tiny bit more time to fully initialize
+            await asyncio.sleep(0.1)
+            break
+        await asyncio.sleep(0.2)
+
+    if not server_ready:
+        process.kill()
+        stdout, stderr = process.communicate()
+        pytest.fail(
+            f"Memory MCP
server failed to start within 10 seconds.\n" + f"stdout: {stdout}\nstderr: {stderr}" + ) + + # Yield server URL for tests + yield server_url + + # Cleanup: kill process + process.kill() + process.wait() + + +@pytest.mark.asyncio +async def test_basic_store_retrieve(memory_server): + """Test basic store and retrieve operations.""" + async with mcp_client_session(memory_server) as session: + # Store a value + store_result = await session.call_tool( + "store", arguments={"key": "test_key", "value": "test_value"} + ) + assert store_result is not None + assert len(store_result.content) > 0 + assert "Successfully stored" in store_result.content[0].text + + # Retrieve the value + retrieve_result = await session.call_tool( + "retrieve", arguments={"key": "test_key"} + ) + assert retrieve_result is not None + assert len(retrieve_result.content) > 0 + assert "test_value" in retrieve_result.content[0].text + + +@pytest.mark.asyncio +async def test_retrieve_nonexistent_key(memory_server): + """Test retrieving a key that doesn't exist.""" + async with mcp_client_session(memory_server) as session: + retrieve_result = await session.call_tool( + "retrieve", arguments={"key": "nonexistent_key"} + ) + assert retrieve_result is not None + assert len(retrieve_result.content) > 0 + assert "not found" in retrieve_result.content[0].text + + +@pytest.mark.asyncio +async def test_memory_isolation_with_headers(memory_server): + """Test that different x-memory-id headers isolate data.""" + # Session 1 with x-memory-id: "user1" + async with mcp_client_session( + memory_server, headers={"x-memory-id": "user1"} + ) as session1: + # Store in user1's space + await session1.call_tool("store", arguments={"key": "name", "value": "Alice"}) + + # Session 2 with x-memory-id: "user2" + async with mcp_client_session( + memory_server, headers={"x-memory-id": "user2"} + ) as session2: + # Store in user2's space + await session2.call_tool("store", arguments={"key": "name", "value": "Bob"}) + + # Verify 
session1 gets Alice + async with mcp_client_session( + memory_server, headers={"x-memory-id": "user1"} + ) as session1: + result = await session1.call_tool("retrieve", arguments={"key": "name"}) + assert "Alice" in result.content[0].text + assert "Bob" not in result.content[0].text + + # Verify session2 gets Bob + async with mcp_client_session( + memory_server, headers={"x-memory-id": "user2"} + ) as session2: + result = await session2.call_tool("retrieve", arguments={"key": "name"}) + assert "Bob" in result.content[0].text + assert "Alice" not in result.content[0].text + + +@pytest.mark.asyncio +async def test_default_memory_pool(memory_server): + """Test that no header uses default pool.""" + # Session without header + async with mcp_client_session(memory_server) as session1: + await session1.call_tool("store", arguments={"key": "shared", "value": "data"}) + + # Another session without header should see the same data + async with mcp_client_session(memory_server) as session2: + result = await session2.call_tool("retrieve", arguments={"key": "shared"}) + assert "data" in result.content[0].text + + +@pytest.mark.asyncio +async def test_default_pool_isolated_from_custom_headers(memory_server): + """Test that default pool is isolated from custom memory IDs.""" + # Store in default pool + async with mcp_client_session(memory_server) as session_default: + await session_default.call_tool( + "store", arguments={"key": "isolation_test", "value": "default_value"} + ) + + # Try to retrieve from custom memory ID - should not find it + async with mcp_client_session( + memory_server, headers={"x-memory-id": "custom"} + ) as session_custom: + result = await session_custom.call_tool( + "retrieve", arguments={"key": "isolation_test"} + ) + assert "not found" in result.content[0].text + + +@pytest.mark.asyncio +async def test_list_keys(memory_server): + """Test listing keys in memory space.""" + async with mcp_client_session( + memory_server, headers={"x-memory-id": "list_test"} + 
) as session: + # Initially should be empty + result = await session.call_tool("list_keys", arguments={}) + assert "No keys found" in result.content[0].text + + # Store multiple keys + await session.call_tool("store", arguments={"key": "key1", "value": "value1"}) + await session.call_tool("store", arguments={"key": "key2", "value": "value2"}) + await session.call_tool("store", arguments={"key": "key3", "value": "value3"}) + + # List keys - should see all three + result = await session.call_tool("list_keys", arguments={}) + result_text = result.content[0].text + assert "key1" in result_text + assert "key2" in result_text + assert "key3" in result_text + + +@pytest.mark.asyncio +async def test_list_keys_isolation(memory_server): + """Test that list_keys respects memory isolation.""" + # Store keys in different memory spaces + async with mcp_client_session( + memory_server, headers={"x-memory-id": "space1"} + ) as session1: + await session1.call_tool( + "store", arguments={"key": "space1_key", "value": "val1"} + ) + + async with mcp_client_session( + memory_server, headers={"x-memory-id": "space2"} + ) as session2: + await session2.call_tool( + "store", arguments={"key": "space2_key", "value": "val2"} + ) + + # List keys in space1 - should only see space1_key + async with mcp_client_session( + memory_server, headers={"x-memory-id": "space1"} + ) as session1: + result = await session1.call_tool("list_keys", arguments={}) + result_text = result.content[0].text + assert "space1_key" in result_text + assert "space2_key" not in result_text + + # List keys in space2 - should only see space2_key + async with mcp_client_session( + memory_server, headers={"x-memory-id": "space2"} + ) as session2: + result = await session2.call_tool("list_keys", arguments={}) + result_text = result.content[0].text + assert "space2_key" in result_text + assert "space1_key" not in result_text + + +@pytest.mark.asyncio +async def test_delete(memory_server): + """Test deleting keys.""" + async with 
mcp_client_session( + memory_server, headers={"x-memory-id": "delete_test"} + ) as session: + # Store a key + await session.call_tool( + "store", arguments={"key": "to_delete", "value": "temp"} + ) + + # Verify it exists + result = await session.call_tool("retrieve", arguments={"key": "to_delete"}) + assert "temp" in result.content[0].text + + # Delete it + delete_result = await session.call_tool( + "delete", arguments={"key": "to_delete"} + ) + assert "Successfully deleted" in delete_result.content[0].text + + # Verify it's gone + result = await session.call_tool("retrieve", arguments={"key": "to_delete"}) + assert "not found" in result.content[0].text + + +@pytest.mark.asyncio +async def test_delete_nonexistent_key(memory_server): + """Test deleting a key that doesn't exist.""" + async with mcp_client_session( + memory_server, headers={"x-memory-id": "delete_test2"} + ) as session: + result = await session.call_tool("delete", arguments={"key": "nonexistent"}) + assert "not found" in result.content[0].text + + +@pytest.mark.asyncio +async def test_delete_isolation(memory_server): + """Test that delete respects memory isolation.""" + # Store same key in different memory spaces + async with mcp_client_session( + memory_server, headers={"x-memory-id": "del_space1"} + ) as session1: + await session1.call_tool( + "store", arguments={"key": "shared_key", "value": "value1"} + ) + + async with mcp_client_session( + memory_server, headers={"x-memory-id": "del_space2"} + ) as session2: + await session2.call_tool( + "store", arguments={"key": "shared_key", "value": "value2"} + ) + + # Delete from space1 + async with mcp_client_session( + memory_server, headers={"x-memory-id": "del_space1"} + ) as session1: + await session1.call_tool("delete", arguments={"key": "shared_key"}) + + # Verify it's gone from space1 + result = await session1.call_tool("retrieve", arguments={"key": "shared_key"}) + assert "not found" in result.content[0].text + + # Verify it still exists in space2 + 
async with mcp_client_session( + memory_server, headers={"x-memory-id": "del_space2"} + ) as session2: + result = await session2.call_tool("retrieve", arguments={"key": "shared_key"}) + assert "value2" in result.content[0].text + + +@pytest.mark.asyncio +async def test_server_info(memory_server): + """Test that server identifies itself correctly.""" + async with mcp_client_session(memory_server) as session: + # The session.initialize() call should have set the server info + # We can verify the server name through list_tools + tools_result = await session.list_tools() + assert tools_result is not None + assert len(tools_result.tools) == 4 + tool_names = [tool.name for tool in tools_result.tools] + assert "store" in tool_names + assert "retrieve" in tool_names + assert "list_keys" in tool_names + assert "delete" in tool_names diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools.py b/tests/entrypoints/openai/test_response_api_mcp_tools.py deleted file mode 100644 index 653d44f20b44..000000000000 --- a/tests/entrypoints/openai/test_response_api_mcp_tools.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import pytest_asyncio -from openai import OpenAI - -from ...utils import RemoteOpenAIServer - -MODEL_NAME = "openai/gpt-oss-20b" - - -@pytest.fixture(scope="module") -def monkeypatch_module(): - from _pytest.monkeypatch import MonkeyPatch - - mpatch = MonkeyPatch() - yield mpatch - mpatch.undo() - - -@pytest.fixture(scope="module") -def mcp_disabled_server(monkeypatch_module: pytest.MonkeyPatch): - args = ["--enforce-eager", "--tool-server", "demo"] - - with monkeypatch_module.context() as m: - m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") - m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="function") -def 
mcp_enabled_server(monkeypatch_module: pytest.MonkeyPatch): - args = ["--enforce-eager", "--tool-server", "demo"] - - with monkeypatch_module.context() as m: - m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") - m.setenv("PYTHON_EXECUTION_BACKEND", "dangerously_use_uv") - m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "code_interpreter,container") - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest_asyncio.fixture -async def mcp_disabled_client(mcp_disabled_server): - async with mcp_disabled_server.get_async_client() as async_client: - yield async_client - - -@pytest_asyncio.fixture -async def mcp_enabled_client(mcp_enabled_server): - async with mcp_enabled_server.get_async_client() as async_client: - yield async_client - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") -async def test_mcp_tool_env_flag_enabled(mcp_enabled_client: OpenAI, model_name: str): - response = await mcp_enabled_client.responses.create( - model=model_name, - # TODO: Ideally should be able to set max tool calls - # to prevent multi-turn, but it is not currently supported - # would speed up the test - input=( - "What's the first 4 digits after the decimal point of " - "cube root of `19910212 * 20250910`? " - "Show only the digits. The python interpreter is not stateful " - "and you must print to see the output." 
- ), - tools=[ - { - "type": "mcp", - "server_label": "code_interpreter", - # URL unused for DemoToolServer - "server_url": "http://localhost:8888", - } - ], - ) - assert response is not None - assert response.status == "completed" - assert response.usage.output_tokens_details.tool_output_tokens > 0 - - -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.skip(reason="Code interpreter tool is not available in CI yet.") -async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_name: str): - response = await mcp_disabled_client.responses.create( - model=model_name, - # TODO: Ideally should be able to set max tool calls - # to prevent multi-turn, but it is not currently supported - # would speed up the test - input=( - "What's the first 4 digits after the decimal point of " - "cube root of `19910212 * 20250910`? " - "Show only the digits. The python interpreter is not stateful " - "and you must print to see the output." - ), - tools=[ - { - "type": "mcp", - "server_label": "code_interpreter", - # URL unused for DemoToolServer - "server_url": "http://localhost:8888", - } - ], - ) - assert response is not None - assert response.status == "completed" - assert response.usage.output_tokens_details.tool_output_tokens == 0 diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools_custom.py b/tests/entrypoints/openai/test_response_api_mcp_tools_custom.py new file mode 100644 index 000000000000..d618cbe743ea --- /dev/null +++ b/tests/entrypoints/openai/test_response_api_mcp_tools_custom.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import subprocess + +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + + mpatch = MonkeyPatch() + yield 
mpatch + mpatch.undo() + + +@pytest.fixture(scope="module") +def memory_mcp_server(): + """Start Memory MCP server as subprocess.""" + from tests.utils import find_free_port + + from .memory_mcp_server import start_test_server + + # Find a free port + port = find_free_port() + + # Start memory MCP server using helper + process = start_test_server(port) + + yield f"http://localhost:{port}/sse", port + + # Cleanup + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + +@pytest.fixture(scope="module") +def memory_custom_server(monkeypatch_module, memory_mcp_server): + """vLLM server with Memory MCP tool as custom (not elevated).""" + server_url, port = memory_mcp_server + args = ["--enforce-eager", "--tool-server", f"localhost:{port}"] + + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + # NO GPT_OSS_SYSTEM_TOOL_MCP_LABELS - memory is custom tool + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def memory_custom_client(memory_custom_server): + async with memory_custom_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_memory_mcp_custom(memory_custom_client, model_name: str): + """Test Memory MCP tool as custom (not elevated). + + When memory is NOT in GPT_OSS_SYSTEM_TOOL_MCP_LABELS: + - Tool should be in developer message (not system message) + - Tool calls should be on commentary channel + - Tool responses should be on commentary channel + """ + response = await memory_custom_client.responses.create( + model=model_name, + instructions=( + "You must use the memory.store and memory.retrieve tools. " + "Never simulate tool execution." 
+ ), + input=("Store the key 'test_key' with value 'test_value' and then retrieve it"), + tools=[ + { + "type": "mcp", + "server_label": "memory", + # URL unused, connection via --tool-server + "server_url": "http://unused", + "headers": {"x-memory-id": "test-session-custom"}, + } + ], + extra_body={"enable_response_messages": True}, + ) + + assert response is not None + assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens > 0 + + # Verify input messages: Should have developer message with tool + developer_messages = [ + msg for msg in response.input_messages if msg["author"]["role"] == "developer" + ] + assert len(developer_messages) > 0, "Developer message expected for custom tools" + + # Verify output messages: Tool calls and responses on commentary channel + tool_call_found = False + tool_response_found = False + for message in response.output_messages: + recipient = message.get("recipient") + if recipient and recipient.startswith("memory."): + tool_call_found = True + assert message.get("channel") == "commentary", ( + "Tool call should be on commentary channel" + ) + author = message.get("author", {}) + if ( + author.get("role") == "tool" + and author.get("name") + and author.get("name").startswith("memory.") + ): + tool_response_found = True + assert message.get("channel") == "commentary", ( + "Tool response should be on commentary channel" + ) + + assert tool_call_found, "Should have found at least one memory tool call" + assert tool_response_found, "Should have found at least one memory tool response" + + # Verify McpCall items (tool invocations) + from openai.types.responses.response_output_item import McpCall + + mcp_calls = [ + item for item in reversed(response.output) if isinstance(item, McpCall) + ] + + assert len(mcp_calls) > 0, "Should have at least one McpCall" + + for mcp_call in mcp_calls: + # Verify it's a memory tool call + assert mcp_call.server_label == "memory" + assert mcp_call.name in 
["store", "retrieve"] + + # Verify arguments make sense + assert mcp_call.arguments is not None + args = json.loads(mcp_call.arguments) + assert "key" in args + + # Verify output was populated + assert mcp_call.output is not None + assert len(mcp_call.output) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_memory_mcp_with_headers(memory_custom_client, model_name: str): + """Test Memory MCP tool with custom headers for memory isolation. + + Different x-memory-id headers should provide isolated memory spaces. + This tests that headers are properly passed through the MCP protocol. + """ + # First request with session-1 + response1 = await memory_custom_client.responses.create( + model=model_name, + instructions=( + "You must use the memory.store tool. Never simulate tool execution." + ), + input="Store the key 'isolated_key' with value 'session_1_value'", + tools=[ + { + "type": "mcp", + "server_label": "memory", + # URL unused, connection via --tool-server + "server_url": "http://unused", + "headers": {"x-memory-id": "session-1"}, + } + ], + ) + + assert response1.status == "completed" + assert response1.usage.output_tokens_details.tool_output_tokens > 0 + + # Second request with session-2 (different memory space) + response2 = await memory_custom_client.responses.create( + model=model_name, + instructions=( + "You must use the memory.retrieve tool. Never simulate tool execution." 
+ ), + input="Retrieve the value for key 'isolated_key'", + tools=[ + { + "type": "mcp", + "server_label": "memory", + # URL unused, connection via --tool-server + "server_url": "http://unused", + "headers": {"x-memory-id": "session-2"}, + } + ], + ) + + assert response2.status == "completed" + # The key should NOT be found in session-2 (memory isolation working) + # Check McpCall output field for exact error message + from openai.types.responses.response_output_item import McpCall + + mcp_call_output = None + for item in response2.output: + if isinstance(item, McpCall) and item.output: + mcp_call_output = item.output + break + + # Memory isolation: key from session-1 should not be in session-2 + assert mcp_call_output is not None, "Should have McpCall with output" + assert "Key 'isolated_key' not found in memory space 'session-2'" in mcp_call_output diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools_elevated.py b/tests/entrypoints/openai/test_response_api_mcp_tools_elevated.py new file mode 100644 index 000000000000..b06db753186d --- /dev/null +++ b/tests/entrypoints/openai/test_response_api_mcp_tools_elevated.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import subprocess + +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer, find_free_port +from .memory_mcp_server import start_test_server + +MODEL_NAME = "openai/gpt-oss-20b" + + +@pytest.fixture(scope="module") +def monkeypatch_module(): + from _pytest.monkeypatch import MonkeyPatch + + mpatch = MonkeyPatch() + yield mpatch + mpatch.undo() + + +@pytest.fixture(scope="module") +def memory_mcp_server(): + """Start Memory MCP server as subprocess.""" + # Find a free port + port = find_free_port() + + # Start memory MCP server using helper + process = start_test_server(port) + + yield f"http://localhost:{port}/sse", port + + # Cleanup + process.terminate() + try: + 
process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait() + + +@pytest.fixture(scope="module") +def memory_elevated_server(monkeypatch_module, memory_mcp_server): + """vLLM server with Memory MCP tool elevated.""" + server_url, port = memory_mcp_server + args = ["--enforce-eager", "--tool-server", f"localhost:{port}"] + + with monkeypatch_module.context() as m: + m.setenv("VLLM_ENABLE_RESPONSES_API_STORE", "1") + # Use system instructions to ensure model follows directions + m.setenv("VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", "1") + m.setenv("GPT_OSS_SYSTEM_TOOL_MCP_LABELS", "memory") # Elevate memory tool + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def memory_elevated_client(memory_elevated_server): + async with memory_elevated_server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_memory_mcp_elevated(memory_elevated_client, model_name: str): + """Test Memory MCP tool as elevated. + + When memory IS in GPT_OSS_SYSTEM_TOOL_MCP_LABELS: + - Tool should be in system message (not developer message) + - Tool calls should be on analysis channel + - Tool responses should be on analysis channel + """ + response = await memory_elevated_client.responses.create( + model=model_name, + instructions=( + "You must use the memory.store and memory.retrieve tools. " + "Never simulate tool execution. Call the tool using json " + "on the analysis channel like a normal system tool." 
+ ), + input=( + "Store the key 'elevated_key' with value 'elevated_value' and retrieve it" + ), + tools=[ + { + "type": "mcp", + "server_label": "memory", + # URL unused, connection via --tool-server + "server_url": "http://unused", + "headers": {"x-memory-id": "test-session-elevated"}, + } + ], + extra_body={"enable_response_messages": True}, + ) + + assert response is not None + assert response.status == "completed" + assert response.usage.output_tokens_details.tool_output_tokens > 0 + + # Verify input messages: Should have system message with tool, NO developer message + # (since all tools are elevated) + developer_messages = [ + msg for msg in response.input_messages if msg["author"]["role"] == "developer" + ] + assert len(developer_messages) == 0, ( + "No developer message expected for elevated tools" + ) + + # Verify output messages: Tool calls and responses on analysis channel + tool_call_found = False + tool_response_found = False + for message in response.output_messages: + recipient = message.get("recipient") + if recipient and recipient.startswith("memory."): + tool_call_found = True + assert message.get("channel") == "analysis", ( + "Tool call should be on analysis channel" + ) + author = message.get("author", {}) + if ( + author.get("role") == "tool" + and author.get("name") + and author.get("name").startswith("memory.") + ): + tool_response_found = True + assert message.get("channel") == "analysis", ( + "Tool response should be on analysis channel" + ) + + assert tool_call_found, "Should have found at least one memory tool call" + assert tool_response_found, "Should have found at least one memory tool response" + + # Verify functional correctness + output_text = "" + for item in response.output: + if hasattr(item, "content"): + for content_item in item.content: + if hasattr(content_item, "text"): + output_text += content_item.text + assert ( + "elevated_value" in output_text.lower() or "successfully" in output_text.lower() + ) diff --git 
a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 0720c8aa5121..77d075bc4138 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -24,7 +24,10 @@ def server(): args = ["--enforce-eager", "--tool-server", "demo"] env_dict = dict( VLLM_ENABLE_RESPONSES_API_STORE="1", + # Use system instructions to ensure model follows directions + VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS="1", PYTHON_EXECUTION_BACKEND="dangerously_use_uv", + GPT_OSS_SYSTEM_TOOL_MCP_LABELS="code_interpreter", ) with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: @@ -439,26 +442,41 @@ async def test_web_search(client: OpenAI, model_name: str): async def test_code_interpreter(client: OpenAI, model_name: str): response = await client.responses.create( model=model_name, - # TODO: Ideally should be able to set max tool calls - # to prevent multi-turn, but it is not currently supported - # would speed up the test - input=( - "What's the first 4 digits after the decimal point of " - "cube root of `19910212 * 20250910`? " - "Show only the digits. The python interpreter is not stateful " - "and you must print to see the output." + instructions=( + "You must use the Python tool to execute code. Never simulate execution." 
), + input="import random; print(random.randint(1, 1000000))", tools=[{"type": "code_interpreter", "container": {"type": "auto"}}], temperature=0.0, # More deterministic output in response + extra_body={"enable_response_messages": True}, ) assert response is not None assert response.status == "completed" assert response.usage.output_tokens_details.tool_output_tokens > 0 - for item in response.output: - if item.type == "message": - output_string = item.content[0].text - print("output_string: ", output_string, flush=True) - assert "5846" in output_string + + # Verify output messages: Tool calls and responses on analysis channel + tool_call_found = False + tool_response_found = False + for message in response.output_messages: + recipient = message.get("recipient") + if recipient and recipient.startswith("python"): + tool_call_found = True + assert message.get("channel") == "analysis", ( + "Tool call should be on analysis channel" + ) + author = message.get("author", {}) + if ( + author.get("role") == "tool" + and author.get("name") + and author.get("name").startswith("python") + ): + tool_response_found = True + assert message.get("channel") == "analysis", ( + "Tool response should be on analysis channel" + ) + + assert tool_call_found, "Should have found at least one Python tool call" + assert tool_response_found, "Should have found at least one Python tool response" def get_weather(latitude, longitude): @@ -672,22 +690,6 @@ async def test_function_calling_required(client: OpenAI, model_name: str): ) -@pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_system_message_with_tools(client: OpenAI, model_name: str): - from vllm.entrypoints.harmony_utils import get_system_message - - # Test with custom tools enabled - commentary channel should be available - sys_msg = get_system_message(with_custom_tools=True) - valid_channels = sys_msg.content[0].channel_config.valid_channels - assert "commentary" in valid_channels - - # Test with 
custom tools disabled - commentary channel should be removed - sys_msg = get_system_message(with_custom_tools=False) - valid_channels = sys_msg.content[0].channel_config.valid_channels - assert "commentary" not in valid_channels - - @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_function_calling_full_history(client: OpenAI, model_name: str): diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py index 263b076db183..9a670d6db347 100644 --- a/tests/entrypoints/openai/test_serving_responses.py +++ b/tests/entrypoints/openai/test_serving_responses.py @@ -28,7 +28,7 @@ def append_output(self, output) -> None: async def call_tool(self): return [] - def need_builtin_tool_call(self) -> bool: + def need_server_side_tool_call(self) -> bool: return False def render_for_completion(self): diff --git a/tests/entrypoints/test_context.py b/tests/entrypoints/test_context.py index b0faa870a927..2a87634443f2 100644 --- a/tests/entrypoints/test_context.py +++ b/tests/entrypoints/test_context.py @@ -81,7 +81,7 @@ def mock_parser(): def test_single_turn_token_counting(): """Test token counting behavior for a single turn.""" # Create a context - context = HarmonyContext(messages=[], available_tools=[]) + context = HarmonyContext(messages=[], enabled_tool_namespaces=[]) # Create a mock RequestOutput with specific token counts mock_output = create_mock_request_output( @@ -109,7 +109,7 @@ def test_single_turn_token_counting(): async def test_multi_turn_token_counting(): """Test token counting behavior across multiple turns with tool output.""" # Create a context - context = HarmonyContext(messages=[], available_tools=["browser"]) + context = HarmonyContext(messages=[], enabled_tool_namespaces=["browser"]) # Simulate a conversation with 3 turns # Turn 1: prefill 5, decode 3, tool 7 @@ -159,7 +159,7 @@ async def test_multi_turn_token_counting(): def test_empty_output_tokens(): """Test 
behavior when RequestOutput has empty output tokens.""" - context = HarmonyContext(messages=[], available_tools=[]) + context = HarmonyContext(messages=[], enabled_tool_namespaces=[]) # Create a RequestOutput with empty output tokens mock_output = create_mock_request_output( @@ -179,7 +179,7 @@ def test_empty_output_tokens(): def test_missing_prompt_token_ids(): """Test behavior when RequestOutput has None prompt_token_ids.""" - context = HarmonyContext(messages=[], available_tools=[]) + context = HarmonyContext(messages=[], enabled_tool_namespaces=[]) mock_output = create_mock_request_output( prompt_token_ids=None, # No prompt token IDs @@ -200,7 +200,7 @@ def test_missing_prompt_token_ids(): def test_reasoning_tokens_counting(mock_parser): """Test that reasoning tokens are counted correctly.""" - context = HarmonyContext(messages=[], available_tools=[]) + context = HarmonyContext(messages=[], enabled_tool_namespaces=[]) # Mock parser to simulate reasoning channel mock_parser.current_channel = "analysis" # Reasoning channel @@ -220,7 +220,7 @@ def test_reasoning_tokens_counting(mock_parser): def test_zero_tokens_edge_case(): """Test behavior with all zero token counts.""" - context = HarmonyContext(messages=[], available_tools=[]) + context = HarmonyContext(messages=[], enabled_tool_namespaces=[]) # Create a request with empty lists (not None) for both prompt and # output tokens @@ -245,7 +245,7 @@ async def test_single_turn_no_tool_output(): """Test that first turn never generates tool output tokens.""" context = HarmonyContext( messages=[], - available_tools=["browser"], # Tools available + enabled_tool_namespaces=["browser"], # Tools available ) # Even with large prompt in first turn, no tool tokens should be counted @@ -268,7 +268,7 @@ async def test_negative_tool_tokens_edge_case(): tokens. 
We should log an error and clamp the value to 0.""" # Use patch to check if logger.error was called with patch("vllm.entrypoints.context.logger.error") as mock_log: - context = HarmonyContext(messages=[], available_tools=["browser"]) + context = HarmonyContext(messages=[], enabled_tool_namespaces=["browser"]) # First turn mock_output1 = create_mock_request_output( @@ -312,7 +312,7 @@ async def test_streaming_multi_turn_token_counting(mock_parser): message boundaries. """ # Create a streaming context - context = StreamingHarmonyContext(messages=[], available_tools=["browser"]) + context = StreamingHarmonyContext(messages=[], enabled_tool_namespaces=["browser"]) # Simulate three turns of conversation: # Turn 1: stream tokens one by one, then finish the message @@ -465,7 +465,9 @@ async def test_streaming_message_synchronization(mock_parser): recipient=Role.ASSISTANT, ) ] - context = StreamingHarmonyContext(messages=initial_messages, available_tools=[]) + context = StreamingHarmonyContext( + messages=initial_messages, enabled_tool_namespaces=[] + ) # Verify initial state assert len(context._messages) == 1 diff --git a/tests/entrypoints/test_harmony_utils.py b/tests/entrypoints/test_harmony_utils.py new file mode 100644 index 000000000000..8990c2f0cbdf --- /dev/null +++ b/tests/entrypoints/test_harmony_utils.py @@ -0,0 +1,1010 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for harmony_utils.py module.""" + +from unittest.mock import MagicMock, patch + +import pytest +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseReasoningItem, +) +from openai.types.responses.response_output_item import McpCall +from openai_harmony import ( + Author, + Message, + ReasoningEffort, + Role, + StreamableParser, + ToolDescription, + ToolNamespaceConfig, +) + +from vllm.entrypoints.harmony_utils import ( + BUILTIN_TOOLS, + REASONING_EFFORT, + 
build_system_and_developer_messages, + create_function_tools_namespace, + create_tool_definition, + get_developer_message, + get_encoding, + get_stop_tokens_for_assistant_actions, + get_streamable_parser_for_assistant, + get_system_message, + get_user_message, + parse_chat_input, + parse_chat_output, + parse_output_into_messages, + parse_output_message, + parse_remaining_state, + parse_response_input, + render_for_completion, +) +from vllm.entrypoints.openai.protocol import ChatCompletionToolsParam + + +class TestConstants: + """Test module constants.""" + + def test_reasoning_effort_mapping(self): + """Test that REASONING_EFFORT contains correct mappings.""" + assert REASONING_EFFORT["high"] == ReasoningEffort.HIGH + assert REASONING_EFFORT["medium"] == ReasoningEffort.MEDIUM + assert REASONING_EFFORT["low"] == ReasoningEffort.LOW + assert len(REASONING_EFFORT) == 3 + + def test_builtin_tools(self): + """Test BUILTIN_TOOLS set contains expected tools.""" + assert "web_search_preview" in BUILTIN_TOOLS + assert "code_interpreter" in BUILTIN_TOOLS + assert "container" in BUILTIN_TOOLS + assert len(BUILTIN_TOOLS) == 3 + + +class TestGetEncoding: + """Test get_encoding() function.""" + + def test_get_encoding_returns_encoding(self): + """Test that get_encoding returns a harmony encoding.""" + encoding = get_encoding() + assert encoding is not None + assert hasattr(encoding, "render_conversation_for_completion") + + def test_get_encoding_caches_result(self): + """Test that get_encoding caches the result.""" + encoding1 = get_encoding() + encoding2 = get_encoding() + assert encoding1 is encoding2 + + +class TestCreateToolDefinition: + """Test create_tool_definition() function.""" + + def test_create_tool_definition_from_chat_completion_tool(self): + """Test creating tool definition from ChatCompletionToolsParam.""" + tool = ChatCompletionToolsParam( + type="function", + function={ + "name": "get_weather", + "description": "Get weather information", + "parameters": { + 
"type": "object", + "properties": {"location": {"type": "string"}}, + }, + }, + ) + + tool_desc = create_tool_definition(tool) + + assert isinstance(tool_desc, ToolDescription) + assert tool_desc.name == "get_weather" + assert tool_desc.description == "Get weather information" + assert tool_desc.parameters["type"] == "object" + + +class TestCreateFunctionToolsNamespace: + """Test create_function_tools_namespace() function.""" + + def test_create_namespace_single_tool(self): + """Test creating namespace with a single function tool.""" + tool = ChatCompletionToolsParam( + type="function", + function={ + "name": "test_func", + "description": "Test function", + "parameters": {"type": "object", "properties": {}}, + }, + ) + + namespace = create_function_tools_namespace([tool]) + + assert isinstance(namespace, ToolNamespaceConfig) + assert namespace.name == "functions" + assert namespace.description == "" + assert len(namespace.tools) == 1 + assert namespace.tools[0].name == "test_func" + + def test_create_namespace_multiple_tools(self): + """Test creating namespace with multiple function tools.""" + tools = [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "func1", + "description": "Function 1", + "parameters": {"type": "object"}, + }, + ), + ChatCompletionToolsParam( + type="function", + function={ + "name": "func2", + "description": "Function 2", + "parameters": {"type": "object"}, + }, + ), + ] + + namespace = create_function_tools_namespace(tools) + + assert len(namespace.tools) == 2 + assert namespace.tools[0].name == "func1" + assert namespace.tools[1].name == "func2" + + def test_create_namespace_empty_tools(self): + """Test creating namespace with empty tools list.""" + namespace = create_function_tools_namespace([]) + assert len(namespace.tools) == 0 + + +class TestGetUserMessage: + """Test get_user_message() function.""" + + def test_get_user_message_simple(self): + """Test creating a simple user message.""" + msg = 
get_user_message("Hello!") + + assert isinstance(msg, Message) + assert msg.author.role == Role.USER + assert len(msg.content) == 1 + assert msg.content[0].text == "Hello!" + + def test_get_user_message_multiline(self): + """Test creating user message with multiline content.""" + content = "Line 1\nLine 2\nLine 3" + msg = get_user_message(content) + + assert msg.content[0].text == content + + +class TestGetSystemMessage: + """Test get_system_message() function.""" + + def test_get_system_message_minimal(self): + """Test creating system message with minimal parameters.""" + msg = get_system_message() + + assert isinstance(msg, Message) + assert msg.author.role == Role.SYSTEM + assert len(msg.content) == 1 + + def test_get_system_message_with_model_identity(self): + """Test system message with model identity.""" + msg = get_system_message(model_identity="gpt-oss-test") + + assert msg.content[0].model_identity == "gpt-oss-test" + + def test_get_system_message_with_reasoning_effort(self): + """Test system message with reasoning effort.""" + msg = get_system_message(reasoning_effort="high") + + assert msg.content[0].reasoning_effort == ReasoningEffort.HIGH + + def test_get_system_message_with_start_date(self): + """Test system message with custom start date.""" + date = "2024-01-01" + msg = get_system_message(start_date=date) + + assert msg.content[0].conversation_start_date == date + + @patch( + "vllm.entrypoints.harmony_utils.envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", + True, + ) + def test_get_system_message_with_instructions_in_system(self): + """Test system message with instructions in system content.""" + instructions = "Be concise" + msg = get_system_message(instructions=instructions) + + assert instructions in msg.content[0].model_identity + + def test_get_system_message_with_elevated_tools(self): + """Test system message with elevated tool namespaces.""" + tool_namespace = ToolNamespaceConfig( + name="elevated_tool", + description="Test elevated tool", + 
tools=[], + ) + + msg = get_system_message(elevated_namespace_descriptions=[tool_namespace]) + + # Verify message was created - internals vary by harmony version + assert isinstance(msg, Message) + assert msg.author.role == Role.SYSTEM + + def test_get_system_message_with_custom_tools(self): + """Test that commentary channel is available with custom tools.""" + tool_namespace = ToolNamespaceConfig( + name="custom_tool", + description="Test custom tool", + tools=[], + ) + + msg = get_system_message(custom_namespace_descriptions=[tool_namespace]) + + valid_channels = msg.content[0].channel_config.valid_channels + assert "commentary" in valid_channels + + def test_get_system_message_without_custom_tools(self): + """Test that commentary channel is removed without custom tools.""" + msg = get_system_message() + + valid_channels = msg.content[0].channel_config.valid_channels + assert "commentary" not in valid_channels + + +class TestGetDeveloperMessage: + """Test get_developer_message() function.""" + + def test_get_developer_message_none_when_empty(self): + """Test that None is returned when no instructions or tools.""" + msg = get_developer_message() + assert msg is None + + @patch( + "vllm.entrypoints.harmony_utils.envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", + False, + ) + def test_get_developer_message_with_instructions(self): + """Test developer message with instructions.""" + instructions = "Test instructions" + msg = get_developer_message(instructions=instructions) + + assert msg is not None + assert msg.author.role == Role.DEVELOPER + assert msg.content[0].instructions == instructions + + @patch( + "vllm.entrypoints.harmony_utils.envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", + True, + ) + def test_get_developer_message_no_instructions_when_system_mode(self): + """Test developer message doesn't include instructions in system mode.""" + instructions = "Test instructions" + msg = get_developer_message(instructions=instructions) + + assert msg is None + + def 
test_get_developer_message_with_tool_namespaces(self): + """Test developer message with tool namespaces.""" + tool_namespace = ToolNamespaceConfig( + name="test_tool", + description="Test tool", + tools=[], + ) + + msg = get_developer_message(tool_namespaces=[tool_namespace]) + + assert msg is not None + assert msg.author.role == Role.DEVELOPER + # Just verify message was created - internals vary by harmony version + + @patch( + "vllm.entrypoints.harmony_utils.envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS", + False, + ) + def test_get_developer_message_with_both(self): + """Test developer message with instructions and tools.""" + instructions = "Test instructions" + tool_namespace = ToolNamespaceConfig( + name="test_tool", + description="Test tool", + tools=[], + ) + + msg = get_developer_message( + instructions=instructions, tool_namespaces=[tool_namespace] + ) + + assert msg is not None + assert msg.content[0].instructions == instructions + # Just verify message was created - internals vary by harmony version + + +class TestParseResponseInput: + """Test parse_response_input() function.""" + + def test_parse_user_message_string(self): + """Test parsing user message with string content.""" + response_msg = {"role": "user", "content": "Hello"} + + msg = parse_response_input(response_msg, []) + + assert msg.author.role == Role.USER + assert msg.content[0].text == "Hello" + + def test_parse_user_message_with_type(self): + """Test parsing user message with input_text type.""" + response_msg = { + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + + msg = parse_response_input(response_msg, []) + + assert msg.author.role == Role.USER + assert msg.content[0].text == "Hello" + + def test_parse_system_message_converts_to_developer(self): + """Test that system messages are converted to developer role.""" + response_msg = {"role": "system", "content": "System prompt"} + + msg = parse_response_input(response_msg, []) + + assert msg.author.role == 
Role.DEVELOPER + assert "Instructions:" in msg.content[0].text + assert "System prompt" in msg.content[0].text + + def test_parse_assistant_message(self): + """Test parsing assistant message.""" + response_msg = {"role": "assistant", "content": "Hello"} + + msg = parse_response_input(response_msg, []) + + assert msg.author.role == Role.ASSISTANT + assert msg.channel == "final" + assert msg.content[0].text == "Hello" + + def test_parse_function_call_output(self): + """Test parsing function call output message.""" + # First create a function call + function_call = ResponseFunctionToolCall( + call_id="call_123", + type="function_call", + name="get_weather", + arguments='{"location": "SF"}', + id="fc_123", + ) + + output_msg = { + "type": "function_call_output", + "call_id": "call_123", + "output": "70 degrees", + } + + msg = parse_response_input(output_msg, [function_call]) + + assert msg.author.role == Role.TOOL + assert msg.author.name == "functions.get_weather" + assert msg.content[0].text == "70 degrees" + + def test_parse_function_call_output_no_matching_call(self): + """Test parsing function call output with no matching call.""" + output_msg = { + "type": "function_call_output", + "call_id": "call_nonexistent", + "output": "result", + } + + with pytest.raises(ValueError, match="No call message found"): + parse_response_input(output_msg, []) + + def test_parse_reasoning_item(self): + """Test parsing reasoning item.""" + response_msg = { + "type": "reasoning", + "content": [{"text": "Let me think..."}], + } + + msg = parse_response_input(response_msg, []) + + assert msg.author.role == Role.ASSISTANT + assert msg.content[0].text == "Let me think..." 
+ + def test_parse_function_call(self): + """Test parsing function call.""" + response_msg = { + "type": "function_call", + "name": "get_weather", + "arguments": '{"location": "SF"}', + } + + msg = parse_response_input(response_msg, []) + + assert msg.author.role == Role.ASSISTANT + assert msg.channel == "commentary" + assert msg.recipient == "functions.get_weather" + assert msg.content_type == "json" + + def test_parse_unknown_type(self): + """Test parsing unknown message type.""" + response_msg = {"type": "unknown_type"} + + with pytest.raises(ValueError, match="Unknown input type"): + parse_response_input(response_msg, []) + + +class TestParseChatInput: + """Test parse_chat_input() function.""" + + def test_parse_simple_user_message(self): + """Test parsing simple user message.""" + chat_msg = {"role": "user", "content": "Hello"} + + msgs = parse_chat_input(chat_msg) + + assert len(msgs) == 1 + assert msgs[0].author.role == Role.USER + assert msgs[0].content[0].text == "Hello" + + def test_parse_assistant_with_tool_calls(self): + """Test parsing assistant message with tool calls.""" + chat_msg = { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "get_weather", + "arguments": '{"location": "SF"}', + } + }, + { + "function": { + "name": "get_time", + "arguments": '{"timezone": "PST"}', + } + }, + ], + } + + msgs = parse_chat_input(chat_msg) + + assert len(msgs) == 2 + assert all(msg.author.role == Role.ASSISTANT for msg in msgs) + assert all(msg.channel == "commentary" for msg in msgs) + assert msgs[0].recipient == "functions.get_weather" + assert msgs[1].recipient == "functions.get_time" + + def test_parse_tool_message(self): + """Test parsing tool role message.""" + chat_msg = { + "role": "tool", + "name": "get_weather", + "content": "70 degrees", + } + + msgs = parse_chat_input(chat_msg) + + assert len(msgs) == 1 + assert msgs[0].author.role == Role.TOOL + assert msgs[0].author.name == "functions.get_weather" + assert 
msgs[0].content[0].text == "70 degrees" + assert msgs[0].channel == "commentary" + + def test_parse_tool_message_with_array_content(self): + """Test parsing tool message with array content.""" + chat_msg = { + "role": "tool", + "name": "search", + "content": [ + {"type": "text", "text": "Result 1"}, + {"type": "text", "text": "Result 2"}, + ], + } + + msgs = parse_chat_input(chat_msg) + + assert len(msgs) == 1 + assert msgs[0].content[0].text == "Result 1Result 2" + + def test_parse_message_with_array_content(self): + """Test parsing message with array content.""" + chat_msg = { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": " World"}, + ], + } + + msgs = parse_chat_input(chat_msg) + + assert len(msgs) == 1 + assert len(msgs[0].content) == 2 + assert msgs[0].content[0].text == "Hello" + assert msgs[0].content[1].text == " World" + + +class TestRenderForCompletion: + """Test render_for_completion() function.""" + + def test_render_simple_conversation(self): + """Test rendering a simple conversation.""" + messages = [ + get_user_message("What is 2+2?"), + ] + + token_ids = render_for_completion(messages) + + assert isinstance(token_ids, list) + assert len(token_ids) > 0 + assert all(isinstance(tid, int) for tid in token_ids) + + def test_render_multi_turn(self): + """Test rendering multi-turn conversation.""" + messages = [ + get_user_message("Hello"), + Message.from_role_and_content(Role.ASSISTANT, "Hi there!"), + get_user_message("How are you?"), + ] + + token_ids = render_for_completion(messages) + + assert isinstance(token_ids, list) + assert len(token_ids) > 0 + + +class TestParseOutputMessage: + """Test parse_output_message() function.""" + + def test_parse_final_channel_message(self): + """Test parsing message on final channel.""" + msg = Message.from_role_and_content(Role.ASSISTANT, "Hello!").with_channel( + "final" + ) + + output_items = parse_output_message(msg) + + assert len(output_items) == 1 + assert 
isinstance(output_items[0], ResponseOutputMessage) + assert output_items[0].type == "message" + assert output_items[0].content[0].text == "Hello!" + + def test_parse_analysis_channel_message(self): + """Test parsing message on analysis channel.""" + msg = Message.from_role_and_content( + Role.ASSISTANT, "Let me think..." + ).with_channel("analysis") + + output_items = parse_output_message(msg) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + assert output_items[0].type == "reasoning" + assert output_items[0].content[0].text == "Let me think..." + + def test_parse_function_call_on_commentary(self): + """Test parsing function call on commentary channel.""" + msg = ( + Message.from_role_and_content(Role.ASSISTANT, '{"location": "SF"}') + .with_channel("commentary") + .with_recipient("functions.get_weather") + .with_content_type("json") + ) + + output_items = parse_output_message(msg) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseFunctionToolCall) + assert output_items[0].type == "function_call" + assert output_items[0].name == "get_weather" + assert output_items[0].arguments == '{"location": "SF"}' + + def test_parse_mcp_call_on_commentary(self): + """Test parsing MCP call on commentary channel.""" + msg = ( + Message.from_role_and_content(Role.ASSISTANT, '{"key": "value"}') + .with_channel("commentary") + .with_recipient("memory.store") + .with_content_type("json") + ) + + output_items = parse_output_message(msg) + + assert len(output_items) == 1 + assert isinstance(output_items[0], McpCall) + assert output_items[0].type == "mcp_call" + assert output_items[0].name == "store" + assert output_items[0].server_label == "memory" + assert output_items[0].arguments == '{"key": "value"}' + assert output_items[0].output is None + + def test_parse_tool_response_updates_mcp_call(self): + """Test that tool response updates matching MCP call.""" + # Create an MCP call + mcp_call = McpCall( + 
id="mcp_123", + type="mcp_call", + name="store", + server_label="memory", + arguments='{"key": "value"}', + output=None, + error=None, + ) + + # Create tool response + tool_msg = Message.from_author_and_content( + Author.new(Role.TOOL, "memory.store"), "Success" + ) + + output_items = parse_output_message(tool_msg, output_items_so_far=[mcp_call]) + + assert len(output_items) == 0 # Tool response doesn't create new items + assert mcp_call.output == "Success" # But updates the existing call + + def test_parse_tool_response_no_matching_call(self): + """Test tool response with no matching call.""" + tool_msg = Message.from_author_and_content( + Author.new(Role.TOOL, "nonexistent.tool"), "Result" + ) + + # Should log error but not crash + output_items = parse_output_message(tool_msg, output_items_so_far=[]) + + assert len(output_items) == 0 + + def test_parse_builtin_tool_on_commentary_becomes_reasoning(self): + """Test that built-in tools on commentary become reasoning items.""" + msg = ( + Message.from_role_and_content(Role.ASSISTANT, "print('hello')") + .with_channel("commentary") + .with_recipient("python") + ) + + output_items = parse_output_message(msg) + + assert len(output_items) == 1 + assert isinstance(output_items[0], ResponseReasoningItem) + + def test_parse_non_assistant_message_returns_empty(self): + """Test that non-assistant messages return empty list.""" + msg = Message.from_role_and_content(Role.USER, "Hello") + + output_items = parse_output_message(msg) + + assert len(output_items) == 0 + + def test_parse_unknown_channel_raises_error(self): + """Test that unknown channel raises error.""" + msg = Message.from_role_and_content(Role.ASSISTANT, "Test").with_channel( + "unknown_channel" + ) + + with pytest.raises(ValueError, match="Unknown channel"): + parse_output_message(msg) + + +class TestParseRemainingState: + """Test parse_remaining_state() function.""" + + def test_parse_remaining_analysis_content(self): + """Test parsing remaining analysis 
content.""" + parser = MagicMock(spec=StreamableParser) + parser.current_content = "Incomplete thought..." + parser.current_role = Role.ASSISTANT + parser.current_channel = "analysis" + parser.current_recipient = None + + items = parse_remaining_state(parser) + + assert len(items) == 1 + assert isinstance(items[0], ResponseReasoningItem) + assert items[0].content[0].text == "Incomplete thought..." + + def test_parse_remaining_final_content(self): + """Test parsing remaining final content.""" + parser = MagicMock(spec=StreamableParser) + parser.current_content = "Incomplete answer..." + parser.current_role = Role.ASSISTANT + parser.current_channel = "final" + parser.current_recipient = None + + items = parse_remaining_state(parser) + + assert len(items) == 1 + assert isinstance(items[0], ResponseOutputMessage) + assert items[0].status == "incomplete" + assert items[0].content[0].text == "Incomplete answer..." + + def test_parse_remaining_empty_content(self): + """Test parsing with empty content.""" + parser = MagicMock(spec=StreamableParser) + parser.current_content = "" + + items = parse_remaining_state(parser) + + assert len(items) == 0 + + def test_parse_remaining_non_assistant_role(self): + """Test parsing with non-assistant role.""" + parser = MagicMock(spec=StreamableParser) + parser.current_content = "Some content" + parser.current_role = Role.USER + + items = parse_remaining_state(parser) + + assert len(items) == 0 + + def test_parse_remaining_browser_recipient_skipped(self): + """Test that browser recipients are skipped.""" + parser = MagicMock(spec=StreamableParser) + parser.current_content = "Search query" + parser.current_role = Role.ASSISTANT + parser.current_channel = "commentary" + parser.current_recipient = "browser.search" + + items = parse_remaining_state(parser) + + assert len(items) == 0 + + +class TestGetStopTokens: + """Test get_stop_tokens_for_assistant_actions() function.""" + + def test_get_stop_tokens(self): + """Test getting stop 
tokens.""" + stop_tokens = get_stop_tokens_for_assistant_actions() + + assert isinstance(stop_tokens, list) + assert len(stop_tokens) > 0 + assert all(isinstance(t, int) for t in stop_tokens) + + +class TestGetStreamableParser: + """Test get_streamable_parser_for_assistant() function.""" + + def test_get_streamable_parser(self): + """Test getting streamable parser.""" + parser = get_streamable_parser_for_assistant() + + assert isinstance(parser, StreamableParser) + # Parser doesn't expose role directly, just check it's created + + +class TestParseOutputIntoMessages: + """Test parse_output_into_messages() function.""" + + def test_parse_simple_tokens(self): + """Test parsing simple token sequence.""" + # Get a simple message and render it + msg = Message.from_role_and_content(Role.ASSISTANT, "Hello").with_channel( + "final" + ) + token_ids = render_for_completion([get_user_message("Hi"), msg]) + + # Parse back + parser = parse_output_into_messages(token_ids) + + assert isinstance(parser, StreamableParser) + assert len(parser.messages) >= 0 + + +class TestParseChatOutput: + """Test parse_chat_output() function.""" + + def test_parse_during_reasoning(self): + """Test parsing output stopped during reasoning.""" + # Create a message with just reasoning + msg = Message.from_role_and_content(Role.ASSISTANT, "Test").with_channel( + "analysis" + ) + + token_ids = render_for_completion([get_user_message("Hi"), msg]) + + reasoning, final, is_tool_call = parse_chat_output(token_ids) + + assert isinstance(reasoning, (str, type(None))) + assert isinstance(final, (str, type(None))) + assert isinstance(is_tool_call, bool) + + def test_parse_complete_output(self): + """Test parsing complete output with reasoning and final.""" + # This is an integration-style test + token_ids = render_for_completion([get_user_message("Hello")]) + + reasoning, final, is_tool_call = parse_chat_output(token_ids) + + # Just verify types - actual content depends on model + assert reasoning is None or
isinstance(reasoning, str) + assert final is None or isinstance(final, str) + assert isinstance(is_tool_call, bool) + + +class TestBuildSystemAndDeveloperMessages: + """Test build_system_and_developer_messages() function.""" + + def test_build_with_no_tools(self): + """Test building messages with no tools.""" + messages = build_system_and_developer_messages( + request_tools=[], + tool_server=None, + ) + + assert len(messages) == 1 # Just system message + assert messages[0].author.role == Role.SYSTEM + + def test_build_with_function_tools(self): + """Test building messages with function tools.""" + tools = [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "get_weather", + "description": "Get weather", + "parameters": {"type": "object"}, + }, + ) + ] + + messages = build_system_and_developer_messages( + request_tools=tools, + tool_server=None, + ) + + # Should have system + developer messages + assert len(messages) == 2 + assert messages[0].author.role == Role.SYSTEM + assert messages[1].author.role == Role.DEVELOPER + + @patch( + "vllm.entrypoints.harmony_utils.envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS", + ["elevated_tool"], + ) + def test_build_with_elevated_mcp_tools(self): + """Test building messages with elevated MCP tools.""" + from types import SimpleNamespace + + # Mock tool server + tool_server = MagicMock() + tool_server.has_namespace.return_value = True + tool_namespace = ToolNamespaceConfig( + name="elevated_tool", + description="Elevated tool", + tools=[], + ) + tool_server.get_tool_description.return_value = tool_namespace + + # Create tool with attributes (not dict) + tools = [SimpleNamespace(type="mcp", server_label="elevated_tool")] + + messages = build_system_and_developer_messages( + request_tools=tools, + tool_server=tool_server, + ) + + # Should have just system message (elevated tools go in system) + assert len(messages) == 1 + assert messages[0].author.role == Role.SYSTEM + + def test_build_with_custom_mcp_tools(self): + """Test 
building messages with custom MCP tools.""" + from types import SimpleNamespace + + # Mock tool server + tool_server = MagicMock() + tool_server.has_namespace.return_value = True + tool_namespace = ToolNamespaceConfig( + name="custom_tool", + description="Custom tool", + tools=[], + ) + tool_server.get_tool_description.return_value = tool_namespace + + # Create tool with attributes (not dict) + tools = [SimpleNamespace(type="mcp", server_label="custom_tool")] + + messages = build_system_and_developer_messages( + request_tools=tools, + tool_server=tool_server, + ) + + # Should have system + developer messages + assert len(messages) == 2 + assert messages[0].author.role == Role.SYSTEM + assert messages[1].author.role == Role.DEVELOPER + + def test_build_with_instructions(self): + """Test building messages with custom instructions.""" + messages = build_system_and_developer_messages( + request_tools=[], + tool_server=None, + instructions="Be concise", + ) + + # Instructions should be in system or developer message + assert len(messages) >= 1 + + def test_build_with_reasoning_effort(self): + """Test building messages with reasoning effort.""" + messages = build_system_and_developer_messages( + request_tools=[], + tool_server=None, + reasoning_effort="high", + ) + + assert len(messages) >= 1 + assert messages[0].content[0].reasoning_effort == ReasoningEffort.HIGH + + def test_build_with_missing_mcp_namespace(self): + """Test building messages with missing MCP namespace.""" + from types import SimpleNamespace + + tool_server = MagicMock() + tool_server.has_namespace.return_value = False + tool_server.harmony_tool_descriptions = {} + + # Create tool with attributes (not dict) + tools = [SimpleNamespace(type="mcp", server_label="nonexistent")] + + with pytest.raises(ValueError, match="not available in tool server"): + build_system_and_developer_messages( + request_tools=tools, + tool_server=tool_server, + ) + + def test_build_with_invalid_tool_type(self): + """Test 
building messages with invalid tool type.""" + from types import SimpleNamespace + + # Create tool with attributes (not dict) + tools = [SimpleNamespace(type="invalid_type")] + + error_msg = "should be of type 'mcp' or 'function'" + with pytest.raises(ValueError, match=error_msg): + build_system_and_developer_messages( + request_tools=tools, + tool_server=None, + ) + + +class TestIntegration: + """Integration tests for harmony_utils.""" + + def test_round_trip_user_message(self): + """Test round-trip encoding and parsing of user message.""" + msg = get_user_message("What is 2+2?") + token_ids = render_for_completion([msg]) + + assert len(token_ids) > 0 + assert all(isinstance(tid, int) for tid in token_ids) + + def test_function_tools_end_to_end(self): + """Test creating function tools and building messages.""" + tools = [ + ChatCompletionToolsParam( + type="function", + function={ + "name": "calculator", + "description": "Perform calculations", + "parameters": { + "type": "object", + "properties": {"expression": {"type": "string"}}, + }, + }, + ) + ] + + messages = build_system_and_developer_messages( + request_tools=tools, + tool_server=None, + instructions="Use the calculator when needed", + ) + + # Should have system and developer messages + assert len(messages) == 2 + + # Verify we can render them + token_ids = render_for_completion(messages) + assert len(token_ids) > 0 diff --git a/tests/utils.py b/tests/utils.py index 8fee50708438..68fb5b05c41f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,6 +10,7 @@ import os import random import signal +import socket import subprocess import sys import tempfile @@ -91,6 +92,24 @@ def _nvml(): """Path to root of the vLLM repository.""" +def find_free_port() -> int: + """ + Find and return a free port number. + + This is useful for starting test servers on dynamic ports to avoid + conflicts when running tests in parallel or when a specific port is + already in use. 
+ + Returns: + int: A free port number + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + class RemoteOpenAIServer: DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key diff --git a/vllm/entrypoints/context.py b/vllm/entrypoints/context.py index c694bcfaaa75..f70a0742b38b 100644 --- a/vllm/entrypoints/context.py +++ b/vllm/entrypoints/context.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from contextlib import AsyncExitStack -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from openai.types.responses.tool import Mcp from openai_harmony import Author, Message, Role, StreamState, TextContent @@ -16,6 +16,7 @@ get_streamable_parser_for_assistant, render_for_completion, ) +from vllm.entrypoints.mcp.mcp_utils import call_mcp_tool from vllm.entrypoints.tool import Tool from vllm.entrypoints.tool_server import ToolServer from vllm.outputs import RequestOutput @@ -25,25 +26,6 @@ logger = logging.getLogger(__name__) -# This is currently needed as the tool type doesn't 1:1 match the -# tool namespace, which is what is used to look up the -# connection to the tool server -_TOOL_NAME_TO_TYPE_MAP = { - "browser": "web_search_preview", - "python": "code_interpreter", - "container": "container", -} - - -def _map_tool_name_to_tool_type(tool_name: str) -> str: - if tool_name not in _TOOL_NAME_TO_TYPE_MAP: - available_tools = ", ".join(_TOOL_NAME_TO_TYPE_MAP.keys()) - raise ValueError( - f"Built-in tool name '{tool_name}' not defined in mapping. 
" - f"Available tools: {available_tools}" - ) - return _TOOL_NAME_TO_TYPE_MAP[tool_name] - class TurnTokens: """Tracks token counts for a single conversation turn.""" @@ -72,7 +54,7 @@ async def call_tool(self) -> list[Message]: pass @abstractmethod - def need_builtin_tool_call(self) -> bool: + def need_server_side_tool_call(self) -> bool: pass @abstractmethod @@ -111,7 +93,7 @@ def append_output(self, output) -> None: self.num_cached_tokens = output.num_cached_tokens or 0 self.num_output_tokens += len(output.outputs[0].token_ids or []) - def need_builtin_tool_call(self) -> bool: + def need_server_side_tool_call(self) -> bool: return False async def call_tool(self) -> list[Message]: @@ -137,13 +119,20 @@ class HarmonyContext(ConversationContext): def __init__( self, messages: list, - available_tools: list[str], + enabled_tool_namespaces: list[str], ): + """Initialize HarmonyContext for managing conversation state. + + Args: + messages: Initial conversation messages + enabled_tool_namespaces: List of all enabled tool namespaces + (includes both elevated and custom MCP tools) + """ self._messages = messages self.finish_reason: str | None = None - self.available_tools = available_tools + self.available_tools = enabled_tool_namespaces self._tool_sessions: dict[str, ClientSession | Tool] = {} - self.called_tools: set[str] = set() + self.called_namespaces: set[str] = set() self.parser = get_streamable_parser_for_assistant() self.num_init_messages = len(messages) @@ -279,14 +268,74 @@ def _update_decode_token_usage(self, output: RequestOutput) -> int: def messages(self) -> list: return self._messages - def need_builtin_tool_call(self) -> bool: + def _resolve_namespace(self, recipient: str) -> str: + """Map recipient to tool namespace. + + Most tools use recipient prefix as namespace + (e.g., "browser.search" → "browser"). + Exception: "python" → "code_interpreter" for gpt-oss specifically. 
+ + Args: + recipient: The recipient string from the message + + Returns: + The namespace string + """ + if recipient.startswith("python"): + return "code_interpreter" + return recipient.split(".")[0] if "." in recipient else recipient + + def _resolve_tool_name(self, recipient: str) -> str: + """Map recipient to tool name. + + Most tools use recipient suffix as tool_name + (e.g., "browser.search" → "search"). + Exception: "python" → "python" for gpt-oss specifically. + + Args: + recipient: The recipient string from the message + + Returns: + The tool_name string + """ + if recipient.startswith("python"): + return "python" + return recipient.split(".")[-1] if "." in recipient else recipient + + def need_server_side_tool_call(self) -> bool: + """Check if the last message requires a server-side tool call. + + Returns: + True if recipient is set, not a client-side function, + and namespace is available + False otherwise + """ + if not self.messages: + return False + last_msg = self.messages[-1] recipient = last_msg.recipient - return recipient is not None and ( - recipient.startswith("browser.") - or recipient.startswith("python") - or recipient.startswith("container.") - ) + + if not recipient: + return False + + # Client-side function tools are handled by client + if recipient.startswith("functions."): + return False + + # Validate that the namespace is actually available + namespace = self._resolve_namespace(recipient) + if namespace not in self.available_tools: + logger.warning( + "Model requested unknown tool namespace: %s (from recipient: %s). " + "Available: %s. 
Ignoring tool call.", + namespace, + recipient, + self.available_tools, + ) + return False + + return True async def call_tool(self) -> list[Message]: if not self.messages: @@ -294,68 +343,49 @@ async def call_tool(self) -> list[Message]: last_msg = self.messages[-1] recipient = last_msg.recipient if recipient is not None: - if recipient.startswith("browser."): - return await self.call_search_tool( - self._tool_sessions["browser"], last_msg + namespace = self._resolve_namespace(recipient) + if namespace not in self._tool_sessions: + available = list(self._tool_sessions.keys()) + raise ValueError( + f"Tool session for namespace '{namespace}' not found. " + f"Model requested recipient '{recipient}' but the " + f"Available namespaces are: {available}" ) - elif recipient.startswith("python"): - return await self.call_python_tool( - self._tool_sessions["python"], last_msg + tool_session = self._tool_sessions[namespace] + if isinstance(tool_session, Tool): + return await tool_session.get_result(self) + + tool_name = self._resolve_tool_name(recipient) + # Using str here to do str -> json error handling + # in one spot in call_mcp_tool + tool_args_str = "" + # code_interpreter is special as the model outputs code not json + if namespace == "code_interpreter": + tool_args_str = json.dumps( + { + "code": last_msg.content[0].text, + } ) - elif recipient.startswith("container."): - return await self.call_container_tool( - self._tool_sessions["container"], last_msg + else: + tool_args_str = last_msg.content[0].text + + self.called_namespaces.add(namespace) + tool_output_str = await call_mcp_tool( + tool_session, tool_name, tool_args_str + ) + return [ + Message( + author=Author(role=Role.TOOL, name=recipient), + content=[TextContent(text=tool_output_str)], + recipient=Role.ASSISTANT, + channel=last_msg.channel, ) + ] raise ValueError("No tool call found") def render_for_completion(self) -> list[int]: return render_for_completion(self.messages) - async def call_search_tool( - 
self, tool_session: Union["ClientSession", Tool], last_msg: Message - ) -> list[Message]: - self.called_tools.add("browser") - if isinstance(tool_session, Tool): - return await tool_session.get_result(self) - tool_name = last_msg.recipient.split(".")[1] - args = json.loads(last_msg.content[0].text) - result = await tool_session.call_tool(tool_name, args) - result_str = result.content[0].text - content = TextContent(text=result_str) - author = Author(role=Role.TOOL, name=last_msg.recipient) - return [ - Message( - author=author, - content=[content], - recipient=Role.ASSISTANT, - channel=last_msg.channel, - ) - ] - - async def call_python_tool( - self, tool_session: Union["ClientSession", Tool], last_msg: Message - ) -> list[Message]: - self.called_tools.add("python") - if isinstance(tool_session, Tool): - return await tool_session.get_result(self) - param = { - "code": last_msg.content[0].text, - } - result = await tool_session.call_tool("python", param) - result_str = result.content[0].text - - content = TextContent(text=result_str) - author = Author(role=Role.TOOL, name="python") - - return [ - Message( - author=author, - content=[content], - channel=last_msg.channel, - recipient=Role.ASSISTANT, - ) - ] - async def init_tool_sessions( self, tool_server: ToolServer | None, @@ -364,55 +394,17 @@ async def init_tool_sessions( mcp_tools: dict[str, Mcp], ): if tool_server: - for tool_name in self.available_tools: - if tool_name not in self._tool_sessions: - tool_type = _map_tool_name_to_tool_type(tool_name) + for namespace in self.available_tools: + if namespace not in self._tool_sessions: headers = ( - mcp_tools[tool_type].headers if tool_type in mcp_tools else None + mcp_tools[namespace].headers if namespace in mcp_tools else None ) tool_session = await exit_stack.enter_async_context( - tool_server.new_session(tool_name, request_id, headers) + tool_server.new_session(namespace, request_id, headers) ) - self._tool_sessions[tool_name] = tool_session + 
self._tool_sessions[namespace] = tool_session exit_stack.push_async_exit(self.cleanup_session) - async def call_container_tool( - self, tool_session: Union["ClientSession", Tool], last_msg: Message - ) -> list[Message]: - """ - Call container tool. Expect this to be run in a stateful docker - with command line terminal. - The official container tool would at least - expect the following format: - - for tool name: exec - - args: - { - "cmd":List[str] "command to execute", - "workdir":optional[str] "current working directory", - "env":optional[object/dict] "environment variables", - "session_name":optional[str] "session name", - "timeout":optional[int] "timeout in seconds", - "user":optional[str] "user name", - } - """ - self.called_tools.add("container") - if isinstance(tool_session, Tool): - return await tool_session.get_result(self) - tool_name = last_msg.recipient.split(".")[1].split(" ")[0] - args = json.loads(last_msg.content[0].text) - result = await tool_session.call_tool(tool_name, args) - result_str = result.content[0].text - content = TextContent(text=result_str) - author = Author(role=Role.TOOL, name=last_msg.recipient) - return [ - Message( - author=author, - content=[content], - recipient=Role.ASSISTANT, - channel=last_msg.channel, - ) - ] - async def cleanup_session(self, *args, **kwargs) -> None: """Can be used as coro to used in __aexit__""" @@ -427,7 +419,7 @@ async def cleanup_tool_session(tool_session): await asyncio.gather( *( cleanup_tool_session(self._tool_sessions[tool]) - for tool in self.called_tools + for tool in self.called_namespaces ) ) diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index fe581e5484e1..4e4298c7ba6a 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -19,6 +19,7 @@ ActionSearch, ResponseFunctionWebSearch, ) +from openai.types.responses.response_output_item import McpCall from openai.types.responses.response_reasoning_item import ( Content as 
ResponseReasoningTextContent, ) @@ -36,6 +37,7 @@ SystemContent, TextContent, ToolDescription, + ToolNamespaceConfig, load_harmony_encoding, ) @@ -44,8 +46,12 @@ ChatCompletionToolsParam, ResponseInputOutputItem, ) +from vllm.entrypoints.tool_server import ToolServer +from vllm.logger import init_logger from vllm.utils import random_uuid +logger = init_logger(__name__) + REASONING_EFFORT = { "high": ReasoningEffort.HIGH, "medium": ReasoningEffort.MEDIUM, @@ -65,8 +71,92 @@ } -def has_custom_tools(tool_types: list[str]) -> bool: - return not set(tool_types).issubset(BUILTIN_TOOLS) +def build_system_and_developer_messages( + # Tool for ResponsesAPI, ChatCompletionToolsParam for CompletionsAPI + request_tools: list[Tool] | list[ChatCompletionToolsParam], + tool_server: ToolServer | None, + instructions: str | None = None, + reasoning_effort: Literal["high", "medium", "low"] | None = None, + start_date: str | None = None, + model_identity: str | None = None, +) -> list[Message]: + """Builds system and developer messages for a Harmony request. + + This function standardizes message construction between Responses API + and Chat API. It handles tool elevation, message construction, and + namespace collection. 
+ + Args: + request_tools: List of tools (already normalized to MCP format) + tool_server: Tool server for fetching tool descriptions + instructions: Custom instructions for the assistant + reasoning_effort: Reasoning effort level + start_date: Start date for the conversation + model_identity: Model identity string + + Returns: + List of system message and developer message if needed + """ + messages = [] + + # Get elevation list from environment + elevated_namespaces = envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS or [] + + # Classify tools by elevation status + elevated_namespace_descriptions = [] + custom_namespace_descriptions = [] + function_tools = [] + + for tool in request_tools: + if tool.type == "mcp": + if tool_server and tool_server.has_namespace(tool.server_label): + tool_description = tool_server.get_tool_description(tool.server_label) + else: + available = ( + list(tool_server.harmony_tool_descriptions.keys()) + if tool_server + else [] + ) + raise ValueError( + f"MCP namespace '{tool.server_label}' in the request " + f"is not available in tool server. " + f"Available namespaces: {available}" + ) + if tool.server_label in elevated_namespaces: + elevated_namespace_descriptions.append(tool_description) + else: + custom_namespace_descriptions.append(tool_description) + # type is function for responses and completions luckily + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError( + f"Tools should be of type 'mcp' or 'function', got {tool.type}" + f" Tool type conversion should happen before this point.
" + ) + if function_tools: + custom_namespace_descriptions.append( + create_function_tools_namespace(function_tools) + ) + + sys_msg = get_system_message( + model_identity=model_identity, + reasoning_effort=reasoning_effort, + start_date=start_date, + elevated_namespace_descriptions=elevated_namespace_descriptions, + custom_namespace_descriptions=custom_namespace_descriptions, + instructions=instructions, + ) + messages.append(sys_msg) + + dev_msg = get_developer_message( + instructions=instructions, + tool_namespaces=custom_namespace_descriptions, + ) + if dev_msg is not None: + messages.append(dev_msg) + + return messages def get_encoding(): @@ -76,16 +166,73 @@ def get_encoding(): return _harmony_encoding +def create_function_tools_namespace( + function_tools: list[Tool | ChatCompletionToolsParam], +) -> ToolNamespaceConfig: + """ + Create a Harmony ToolNamespaceConfig from function tools. + + Function tools are converted to a namespace called "functions" that can be + included in either system or developer messages. 
+ + Args: + function_tools: List of function-type tools + + Returns: + ToolNamespaceConfig with namespace="functions" and all function tool definitions + """ + tool_descriptions = [create_tool_definition(tool) for tool in function_tools] + + # Create namespace config with "functions" as the namespace name + namespace_config = ToolNamespaceConfig( + name="functions", + # Empty to match harmony implementation of functions namespace + description="", + tools=tool_descriptions, + ) + + return namespace_config + + +def create_tool_definition(tool: ChatCompletionToolsParam | Tool): + """Convert a tool to a Harmony ToolDescription.""" + if isinstance(tool, ChatCompletionToolsParam): + return ToolDescription.new( + name=tool.function.name, + description=tool.function.description, + parameters=tool.function.parameters, + ) + return ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) + + def get_system_message( model_identity: str | None = None, reasoning_effort: Literal["high", "medium", "low"] | None = None, start_date: str | None = None, - browser_description: str | None = None, - python_description: str | None = None, - container_description: str | None = None, + elevated_namespace_descriptions: list | None = None, + custom_namespace_descriptions: list | None = None, instructions: str | None = None, - with_custom_tools: bool = False, ) -> Message: + """ + Construct system message for gpt-oss models. 
+ + Args: + model_identity: Model identity string + reasoning_effort: Reasoning effort level (high/medium/low) + start_date: Conversation start date + elevated_namespace_descriptions: List of ToolNamespaceConfig + for elevated namespaces + custom_namespace_descriptions: List of ToolNamespaceConfig + for custom namespaces + instructions: User-provided instructions + + Returns: + System message for Harmony protocol + """ sys_msg_content = SystemContent.new() if model_identity is not None: sys_msg_content = sys_msg_content.with_model_identity(model_identity) @@ -103,13 +250,14 @@ def get_system_message( # NOTE(woosuk): This brings non-determinism in vLLM. Be careful. start_date = datetime.datetime.now().strftime("%Y-%m-%d") sys_msg_content = sys_msg_content.with_conversation_start_date(start_date) - if browser_description is not None: - sys_msg_content = sys_msg_content.with_tools(browser_description) - if python_description is not None: - sys_msg_content = sys_msg_content.with_tools(python_description) - if container_description is not None: - sys_msg_content = sys_msg_content.with_tools(container_description) - if not with_custom_tools: + + # Elevated namespaces are registered in the system message + if elevated_namespace_descriptions is not None: + for tool_namespace in elevated_namespace_descriptions: + sys_msg_content = sys_msg_content.with_tools(tool_namespace) + + # If no custom namespaces are provided, remove the "commentary" channel + if not custom_namespace_descriptions: channel_config = sys_msg_content.channel_config invalid_channel = "commentary" new_config = ChannelConfig.require_channels( @@ -120,52 +268,38 @@ def get_system_message( return sys_msg -def create_tool_definition(tool: ChatCompletionToolsParam | Tool): - if isinstance(tool, ChatCompletionToolsParam): - return ToolDescription.new( - name=tool.function.name, - description=tool.function.description, - parameters=tool.function.parameters, - ) - return ToolDescription.new( - name=tool.name, - 
description=tool.description, - parameters=tool.parameters, - ) - - def get_developer_message( instructions: str | None = None, - tools: list[Tool | ChatCompletionToolsParam] | None = None, -) -> Message: + tool_namespaces: list | None = None, +) -> Message | None: + """ + Construct developer message for custom (non-elevated) tools. + + Args: + instructions: User-provided instructions + tool_namespaces: List of ToolNamespaceConfig for all custom tools + (MCP and function) + + Returns: + Developer message for Harmony protocol, if needed + """ + developer_instructions = ( + instructions if not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS else None + ) + if not developer_instructions and not tool_namespaces: + return None + dev_msg_content = DeveloperContent.new() - if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: - dev_msg_content = dev_msg_content.with_instructions(instructions) - if tools is not None: - function_tools: list[Tool | ChatCompletionToolsParam] = [] - for tool in tools: - if tool.type in ( - "web_search_preview", - "code_interpreter", - "container", - "mcp", - ): - # These are built-in tools that are added to the system message. 
- # Adding in MCP for now until we support MCP tools executed - # server side - pass + if developer_instructions: + dev_msg_content = dev_msg_content.with_instructions(developer_instructions) + + # Add all tool namespaces + if tool_namespaces: + for tool_namespace in tool_namespaces: + # Use with_tools instead of with_function_tools to simplify + # adding non-functions namespaces to developer message + dev_msg_content = dev_msg_content.with_tools(tool_namespace) - elif tool.type == "function": - function_tools.append(tool) - else: - raise ValueError(f"tool type {tool.type} not supported") - if function_tools: - function_tool_descriptions = [ - create_tool_definition(tool) for tool in function_tools - ] - dev_msg_content = dev_msg_content.with_function_tools( - function_tool_descriptions - ) dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) return dev_msg @@ -287,14 +421,50 @@ def render_for_completion(messages: list[Message]) -> list[int]: return token_ids -def parse_output_message(message: Message) -> list[ResponseOutputItem]: +def parse_output_message( + message: Message, + output_items_so_far: list[ResponseOutputItem] | None = None, +) -> list[ResponseOutputItem]: """ Parse a Harmony message into a list of output response items. + + Args: + message: The message to parse + output_items_so_far: List of output items parsed so far. When we see + a tool response message, we search backward to find the most recent + matching McpCall (by tool name) that has no output yet. """ + # Handle tool response messages (look-behind pattern) + if message.author.role == "tool": + # This is a tool response. Search backward to find matching tool call. 
+ if not output_items_so_far: + logger.warning( + "Tool response with no prior output items: %s", message.author.name + ) + return [] + + # Find the most recent McpCall that matches this tool and has no output + tool_name = message.author.name # e.g., "memory.store" + matching_call = None + + for item in reversed(output_items_so_far): + if isinstance(item, McpCall): + call_full_name = f"{item.server_label}.{item.name}" + if call_full_name == tool_name and item.output is None: + matching_call = item + break + + if matching_call: + matching_call.output = message.content[0].text if message.content else None + return [] + else: + # We should error here, but it wouldn't make much sense + # before we switch to using McpCall for all tool calls + output + logger.error("Tool call output not output for tool: %s", tool_name) + return [] + if message.author.role != "assistant": - # This is a message from a tool to the assistant (e.g., search result). - # Don't include it in the final output for now. This aligns with - # OpenAI's behavior on models like o4-mini. + # This is some other role (not assistant, not tool) - skip it return [] output_items: list[ResponseOutputItem] = [] @@ -360,6 +530,9 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: or recipient.startswith("browser") or recipient.startswith("container") ): + # Built-in tools on commentary channel → reasoning items + # For legacy compatibility for now + # TODO: Use McpCall here too for content in message.content: reasoning_item = ResponseReasoningItem( id=f"rs_{random_uuid()}", @@ -373,8 +546,22 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: status=None, ) output_items.append(reasoning_item) - else: - raise ValueError(f"Unknown recipient: {recipient}") + elif recipient is not None: + # Any other non-function recipient on commentary channel → MCP call + namespace = recipient.split(".")[0] if "." in recipient else recipient + tool_name = recipient.split(".")[1] if "." 
in recipient else recipient + + for content in message.content: + mcp_call = McpCall( + id=f"mcp_{random_uuid()}", + type="mcp_call", + name=tool_name, + server_label=namespace, + arguments=content.text, + output=None, + error=None, + ) + output_items.append(mcp_call) elif message.channel == "final": contents = [] for content in message.content: diff --git a/vllm/entrypoints/mcp/__init__.py b/vllm/entrypoints/mcp/__init__.py new file mode 100644 index 000000000000..46782ba6a6b3 --- /dev/null +++ b/vllm/entrypoints/mcp/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""MCP (Model Context Protocol) utilities.""" diff --git a/vllm/entrypoints/mcp/mcp_utils.py b/vllm/entrypoints/mcp/mcp_utils.py new file mode 100644 index 000000000000..0df3557b545e --- /dev/null +++ b/vllm/entrypoints/mcp/mcp_utils.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""MCP tool utilities for backward compatibility and tool normalization.""" + +import json +from typing import TYPE_CHECKING + +from openai.types.responses.tool import ( + CodeInterpreter, + Mcp, + Tool, + WebSearchPreviewTool, +) + +if TYPE_CHECKING: + from mcp import ClientSession + + +def normalize_tool_to_mcp(tool: Tool) -> Tool: + """ + Convert legacy tool types to MCP format for unified handling. + + This provides backward compatibility by converting legacy tool types + (CodeInterpreter, WebSearchPreviewTool) to the unified MCP format. + All downstream code can then handle tools uniformly via the MCP protocol. 
+ + Args: + tool: Any Tool type from OpenAI protocol + + Returns: + - If already MCP: returns as-is + - If CodeInterpreter: converts to MCP with server_label="code_interpreter" + Note: container field is discarded (not needed for MCP protocol) + - If WebSearchPreviewTool: converts to MCP with server_label="browser" + Note: search_context_size and user_location fields are discarded + - Otherwise: returns as-is (function tools, etc. pass through unchanged) + """ + # Already MCP - return as-is + if isinstance(tool, Mcp): + return tool + + # CodeInterpreter → MCP with server_label="code_interpreter" + # Note: Discarding container field as it's not needed for MCP protocol + if isinstance(tool, CodeInterpreter): + return Mcp( + type="mcp", + server_label="code_interpreter", + ) + + # WebSearchPreviewTool → MCP with server_label="browser" + # Note: Discarding search_context_size and user_location fields + # These could be passed as headers in the future if needed + if isinstance(tool, WebSearchPreviewTool): + return Mcp( + type="mcp", + server_label="browser", + ) + + # All other tool types (FunctionTool, FileSearchTool, etc.) 
pass through unchanged + return tool + + +async def call_mcp_tool( + tool_session: "ClientSession", + tool_name: str, + tool_args_str: str, +) -> str: + """Generic MCP tool call handler + + Args: + tool_session: The MCP client session or Tool instance + tool_name: The tool name to call + tool_args: The args for the tool call + + Returns: + A string representation of the MCP call output + """ + # TODO: Env variable for returning json parsing error to model + # instead of erroring the request + tool_args = json.loads(tool_args_str) + result = await tool_session.call_tool(tool_name, tool_args) + # TODO: Support handling structured MCP call output + result_str = result.content[0].text + return result_str diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 96a0947c4bd3..d50f98f440ef 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1660,7 +1660,7 @@ async def init_app_state( await tool_server.init_and_validate() elif args.tool_server: tool_server = MCPToolServer() - await tool_server.add_tool_server(args.tool_server) + await tool_server.add_mcp_server(args.tool_server) else: tool_server = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 96525f206859..564e934f842d 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -23,10 +23,9 @@ make_tool_call_id, ) from vllm.entrypoints.harmony_utils import ( - get_developer_message, + build_system_and_developer_messages, get_stop_tokens_for_assistant_actions, get_streamable_parser_for_assistant, - get_system_message, parse_chat_input, parse_chat_output, render_for_completion, @@ -1726,17 +1725,15 @@ def _make_request_with_harmony( # if the model supports it. TODO: Support browsing. 
assert not self.supports_browsing assert not self.supports_code_interpreter - sys_msg = get_system_message( + + # Use unified helper to build system and developer messages + sys_dev_messages = build_system_and_developer_messages( + request_tools=request.tools if request.tools else [], + tool_server=None, # Chat API doesn't use tool server + instructions=None, reasoning_effort=request.reasoning_effort, - browser_description=None, - python_description=None, - with_custom_tools=request.tools is not None, ) - messages.append(sys_msg) - - # Add developer message. - dev_msg = get_developer_message(tools=request.tools) - messages.append(dev_msg) + messages.extend(sys_dev_messages) # Add user message. for chat_msg in request.messages: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a041950ffd20..f97fa286242b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1225,7 +1225,7 @@ async def _generate_with_builtin_tools( # NOTE(woosuk): The stop condition is handled by the engine. yield context - if not context.need_builtin_tool_call(): + if not context.need_server_side_tool_call(): # The model did not ask for a tool call, so we're done. 
break diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 3b9015efd305..40f90c5815da 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -61,17 +61,16 @@ StreamingHarmonyContext, ) from vllm.entrypoints.harmony_utils import ( - get_developer_message, + build_system_and_developer_messages, get_stop_tokens_for_assistant_actions, - get_system_message, get_user_message, - has_custom_tools, parse_output_message, parse_remaining_state, parse_response_input, render_for_completion, ) from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.mcp.mcp_utils import normalize_tool_to_mcp from vllm.entrypoints.openai.protocol import ( DeltaMessage, ErrorResponse, @@ -283,6 +282,12 @@ async def create_responses( else: prev_response = None + # Normalize all tools to MCP format EARLY (before message construction). + # This converts legacy CodeInterpreter and WebSearchPreviewTool to Mcp objects. + # Must happen before _make_request_with_harmony() so that message construction + # sees normalized tools with server_label attributes. + request.tools = [normalize_tool_to_mcp(tool) for tool in request.tools] + try: lora_request = self._maybe_get_adapters(request) model_name = self.models.model_name(lora_request) @@ -314,20 +319,13 @@ async def create_responses( # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[ConversationContext, None]] = [] - builtin_tool_list: list[str] = [] - if self.use_harmony and self.tool_server is not None: - if self.tool_server.has_tool("browser"): - builtin_tool_list.append("browser") - if self.tool_server.has_tool("python"): - builtin_tool_list.append("python") - if self.tool_server.has_tool("container"): - builtin_tool_list.append("container") - - if self.tool_server is not None: - available_tools = builtin_tool_list - else: - assert len(builtin_tool_list) == 0 - available_tools = [] + enabled_tool_namespaces: set[str] = set() + if self.use_harmony: + # Add all MCP tools from the request (they've been normalized already) + for tool in request.tools: + if tool.type == "mcp": + enabled_tool_namespaces.add(tool.server_label) + try: for i, engine_prompt in enumerate(engine_prompts): maybe_error = self._validate_generator_input(engine_prompt) @@ -351,9 +349,15 @@ async def create_responses( context: ConversationContext if self.use_harmony: if request.stream: - context = StreamingHarmonyContext(messages, available_tools) + context = StreamingHarmonyContext( + messages, + list(enabled_tool_namespaces), + ) else: - context = HarmonyContext(messages, available_tools) + context = HarmonyContext( + messages, + list(enabled_tool_namespaces), + ) else: context = SimpleContext() generator = self._generate_with_builtin_tools( @@ -506,6 +510,7 @@ async def _initialize_tool_sessions( # we should only initialize the tool session if the request needs tools if len(request.tools) == 0: return + mcp_tools = { tool.server_label: tool for tool in request.tools if tool.type == "mcp" } @@ -782,8 +787,15 @@ def _make_response_output_items_with_harmony( ) -> list[ResponseOutputItem]: output_items: list[ResponseOutputItem] = [] num_init_messages = context.num_init_messages + for msg in context.messages[num_init_messages:]: - output_items.extend(parse_output_message(msg)) + output_items.extend( + parse_output_message( + msg, + 
output_items_so_far=output_items, + ) + ) + # Handle the generation stopped in the middle (if any). last_items = parse_remaining_state(context.parser) if last_items: @@ -837,56 +849,17 @@ def _construct_input_messages_with_harmony( ) -> list[OpenAIHarmonyMessage]: messages: list[OpenAIHarmonyMessage] = [] if prev_response is None: - # New conversation. + # New conversation - build system and developer messages reasoning_effort = request.reasoning.effort if request.reasoning else None - tool_types = [tool.type for tool in request.tools] - - # Allow the MCP Tool type to enable built in tools if the - # server_label is allowlisted in - # envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS - if envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS: - for tool in request.tools: - if ( - tool.type == "mcp" - and tool.server_label in envs.GPT_OSS_SYSTEM_TOOL_MCP_LABELS - ): - tool_types.append(tool.server_label) - enable_browser = ( - "web_search_preview" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("browser") - ) - enable_code_interpreter = ( - "code_interpreter" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("python") - ) - enable_container = ( - "container" in tool_types - and self.tool_server is not None - and self.tool_server.has_tool("container") - ) - with_custom_tools = has_custom_tools(tool_types) - sys_msg = get_system_message( - reasoning_effort=reasoning_effort, - browser_description=self.tool_server.get_tool_description("browser") - if enable_browser and self.tool_server is not None - else None, - python_description=self.tool_server.get_tool_description("python") - if enable_code_interpreter and self.tool_server is not None - else None, - container_description=self.tool_server.get_tool_description("container") - if enable_container and self.tool_server is not None - else None, + + # Use unified helper to build messages + sys_dev_messages = build_system_and_developer_messages( + request_tools=request.tools, + 
tool_server=self.tool_server, instructions=request.instructions, - with_custom_tools=with_custom_tools, + reasoning_effort=reasoning_effort, ) - messages.append(sys_msg) - if with_custom_tools: - dev_msg = get_developer_message( - instructions=request.instructions, tools=request.tools - ) - messages.append(dev_msg) + messages.extend(sys_dev_messages) else: # Continue the previous conversation. # FIXME(woosuk): Currently, request params like reasoning and @@ -1637,7 +1610,7 @@ async def _process_harmony_streaming_events( previous_item = ctx.parser.messages[-1] if ( self.tool_server is not None - and self.tool_server.has_tool("browser") + and self.tool_server.has_namespace("browser") and previous_item.recipient is not None and previous_item.recipient.startswith("browser.") ): @@ -1722,7 +1695,7 @@ async def _process_harmony_streaming_events( if ( self.tool_server is not None - and self.tool_server.has_tool("python") + and self.tool_server.has_namespace("code_interpreter") and previous_item.recipient is not None and previous_item.recipient.startswith("python") ): diff --git a/vllm/entrypoints/tool_server.py b/vllm/entrypoints/tool_server.py index 0d83031ef69f..436a54dd2236 100644 --- a/vllm/entrypoints/tool_server.py +++ b/vllm/entrypoints/tool_server.py @@ -4,6 +4,8 @@ from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import TYPE_CHECKING, Any +from mcp import ClientSession +from mcp.client.sse import sse_client from openai_harmony import ToolDescription, ToolNamespaceConfig from vllm.entrypoints.tool import HarmonyBrowserTool, HarmonyPythonTool, Tool @@ -16,9 +18,6 @@ async def list_server_and_tools(server_url: str): - from mcp import ClientSession - from mcp.client.sse import sse_client - async with ( sse_client(url=server_url) as streams, ClientSession(*streams) as session, @@ -28,6 +27,7 @@ async def list_server_and_tools(server_url: str): return initialize_response, list_tools_response +# TODO: This is a harmony specific change, 
migrate to harmony_utils def trim_schema(schema: dict) -> dict: # Turn JSON Schema from MCP generated into Harmony's variant. if "title" in schema: @@ -73,42 +73,41 @@ def post_process_tools_description( class ToolServer(ABC): @abstractmethod - def has_tool(self, tool_name: str) -> bool: + def has_namespace(self, namespace: str) -> bool: """ - Return True if the tool is supported, False otherwise. + Return True if the namespace is supported, False otherwise. """ pass @abstractmethod - def get_tool_description(self, tool_name: str) -> ToolNamespaceConfig | None: + def get_tool_description(self, namespace: str) -> ToolNamespaceConfig | None: """ - Return the tool description for the given tool name. - If the tool is not supported, return None. + Return the tool description for the given namespace. + If the namespace is not supported, return None. """ pass @abstractmethod def new_session( - self, tool_name: str, session_id: str, headers: dict[str, str] | None = None + self, namespace: str, session_id: str, headers: dict[str, str] | None = None ) -> AbstractAsyncContextManager[Any]: """ - Create a session for the tool. + Create a session for the namespace. """ ... class MCPToolServer(ToolServer): def __init__(self): - try: - import mcp # noqa: F401 - except ImportError: - raise ImportError( - "mcp is not installed. Please run `pip install mcp` to use " - "MCPToolServer." - ) from None self.harmony_tool_descriptions = {} - async def add_tool_server(self, server_url: str): + async def add_mcp_server(self, server_url: str): + """ + Add an MCP server. 
+ + Args: + server_url: URL to connect to + """ tool_urls = server_url.split(",") self.harmony_tool_descriptions = {} self.urls: dict[str, str] = {} @@ -116,10 +115,12 @@ async def add_tool_server(self, server_url: str): url = f"http://{url}/sse" initialize_response, list_tools_response = await list_server_and_tools(url) + server_name = initialize_response.serverInfo.name + list_tools_response = post_process_tools_description(list_tools_response) tool_from_mcp = ToolNamespaceConfig( - name=initialize_response.serverInfo.name, + name=server_name, # This is the namespace (== server_label) description=initialize_response.instructions, tools=[ ToolDescription.new( @@ -130,39 +131,43 @@ async def add_tool_server(self, server_url: str): for tool in list_tools_response.tools ], ) - self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp - if tool_from_mcp.name not in self.urls: - self.urls[tool_from_mcp.name] = url - else: + + # Check for namespace collision (keep existing logic) + if tool_from_mcp.name in self.urls: logger.warning( - "Tool %s already exists. Ignoring duplicate tool server %s", - tool_from_mcp.name, + "MCP server at %s provides namespace '%s' which is already " + "registered from %s. 
Ignoring duplicate registration.", url, + tool_from_mcp.name, + self.urls[tool_from_mcp.name], ) + continue + + # Add to registry + self.harmony_tool_descriptions[tool_from_mcp.name] = tool_from_mcp + self.urls[tool_from_mcp.name] = url + logger.info( "MCPToolServer initialized with tools: %s", list(self.harmony_tool_descriptions.keys()), ) - def has_tool(self, tool_name: str): - return tool_name in self.harmony_tool_descriptions + def has_namespace(self, namespace: str): + return namespace in self.harmony_tool_descriptions - def get_tool_description(self, tool_name: str): - return self.harmony_tool_descriptions.get(tool_name) + def get_tool_description(self, namespace: str): + return self.harmony_tool_descriptions.get(namespace) @asynccontextmanager async def new_session( - self, tool_name: str, session_id: str, headers: dict[str, str] | None = None + self, namespace: str, session_id: str, headers: dict[str, str] | None = None ): - from mcp import ClientSession - from mcp.client.sse import sse_client - - url = self.urls.get(tool_name) + url = self.urls.get(namespace) request_headers = {"x-session-id": session_id} if headers is not None: request_headers.update(headers) if not url: - raise KeyError(f"Tool '{tool_name}' is not supported") + raise KeyError(f"Namespace '{namespace}' is not supported") async with ( sse_client(url=url, headers=request_headers) as streams, ClientSession(*streams) as session, @@ -171,6 +176,7 @@ async def new_session( yield session +# TODO: Move this as it is harmony specific, as the tools return harmony messages class DemoToolServer(ToolServer): def __init__(self): self.tools: dict[str, Tool] = {} @@ -182,28 +188,28 @@ async def init_and_validate(self): if browser_tool.enabled: self.tools["browser"] = browser_tool if python_tool.enabled: - self.tools["python"] = python_tool + self.tools["code_interpreter"] = python_tool # Use namespace, not "python" logger.info( "DemoToolServer initialized with tools: %s", list(self.tools.keys()) ) - def 
has_tool(self, tool_name: str) -> bool: - return tool_name in self.tools + def has_namespace(self, namespace: str) -> bool: + return namespace in self.tools - def get_tool_description(self, tool_name: str) -> ToolNamespaceConfig | None: - if tool_name not in self.tools: + def get_tool_description(self, namespace: str) -> ToolNamespaceConfig | None: + if namespace not in self.tools: return None - if tool_name == "browser": + if namespace == "browser": return ToolNamespaceConfig.browser() - elif tool_name == "python": + elif namespace == "code_interpreter": return ToolNamespaceConfig.python() else: - raise ValueError(f"Unknown tool {tool_name}") + raise ValueError(f"Unknown namespace {namespace}") @asynccontextmanager async def new_session( - self, tool_name: str, session_id: str, headers: dict[str, str] | None = None + self, namespace: str, session_id: str, headers: dict[str, str] | None = None ): - if tool_name not in self.tools: - raise KeyError(f"Tool '{tool_name}' is not supported") - yield self.tools[tool_name] + if namespace not in self.tools: + raise KeyError(f"Namespace '{namespace}' is not supported") + yield self.tools[namespace] diff --git a/vllm/envs.py b/vllm/envs.py index c3686477d88d..e08513e95091 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1376,12 +1376,13 @@ def get_vllm_port() -> int | None: # The number of SMs to allocate for communication kernels when running DBO # the rest of the SMs on the device will be allocated to compute "VLLM_DBO_COMM_SMS": lambda: int(os.getenv("VLLM_DBO_COMM_SMS", "20")), - # Valid values are container,code_interpreter,web_search_preview - # ex GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter - "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": env_list_with_choices( - "GPT_OSS_SYSTEM_TOOL_MCP_LABELS", - [], - ["container", "code_interpreter", "web_search_preview"], + # Comma-separated list of MCP server labels to elevate to system prompt + # Default to maintain backwards compatibility as the code had special + # treatment 
for these three tools previously + "GPT_OSS_SYSTEM_TOOL_MCP_LABELS": lambda: ( + ["web_search_preview", "container", "code_interpreter"] + if "GPT_OSS_SYSTEM_TOOL_MCP_LABELS" not in os.environ + else os.environ["GPT_OSS_SYSTEM_TOOL_MCP_LABELS"].split(",") ), # Enable max_autotune & coordinate_descent_tuning in inductor_config # to compile static shapes passed from compile_sizes in compilation_config