129 changes: 129 additions & 0 deletions tests/entrypoints/openai/test_serving_responses.py
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import AsyncExitStack
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio

from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.openai.protocol import ResponsesRequest
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.tool_server import ToolServer


class MockConversationContext(ConversationContext):
    """Mock conversation context for testing"""

    def __init__(self):
        self.init_tool_sessions_called = False
        self.init_tool_sessions_args = None
        self.init_tool_sessions_kwargs = None

    def append_output(self, output) -> None:
        pass

    async def call_tool(self):
        return []

    def need_builtin_tool_call(self) -> bool:
        return False

    def render_for_completion(self):
        return []

    async def init_tool_sessions(self, tool_server, exit_stack, request_id,
                                 mcp_tools):
        self.init_tool_sessions_called = True
        self.init_tool_sessions_args = (tool_server, exit_stack, request_id,
                                        mcp_tools)

    async def cleanup_session(self) -> None:
        pass


@pytest.fixture
def mock_serving_responses():
    """Create a mock OpenAIServingResponses instance"""
    serving_responses = MagicMock(spec=OpenAIServingResponses)
    serving_responses.tool_server = MagicMock(spec=ToolServer)
    return serving_responses


@pytest.fixture
def mock_context():
    """Create a mock conversation context"""
    return MockConversationContext()


@pytest.fixture
def mock_exit_stack():
    """Create a mock async exit stack"""
    return MagicMock(spec=AsyncExitStack)


class TestInitializeToolSessions:
    """Test class for _initialize_tool_sessions method"""

    @pytest_asyncio.fixture
    async def serving_responses_instance(self):
        """Create a real OpenAIServingResponses instance for testing"""
        # Create minimal mocks for required dependencies
        engine_client = MagicMock()
        engine_client.get_model_config = AsyncMock()

        model_config = MagicMock()
        model_config.hf_config.model_type = "test"
        model_config.get_diff_sampling_param.return_value = {}

        models = MagicMock()

        tool_server = MagicMock(spec=ToolServer)

        # Create the actual instance
        instance = OpenAIServingResponses(
            engine_client=engine_client,
            model_config=model_config,
            models=models,
            request_logger=None,
            chat_template=None,
            chat_template_content_format="auto",
            tool_server=tool_server,
        )

        return instance

    @pytest.mark.asyncio
    async def test_initialize_tool_sessions(self, serving_responses_instance,
                                            mock_context, mock_exit_stack):
        """Tool sessions are initialized only when the request carries tools"""

        request = ResponsesRequest(input="test input", tools=[])

        # With an empty tool list, the method should return early without
        # touching the context
        await serving_responses_instance._initialize_tool_sessions(
            request, mock_context, mock_exit_stack)
        assert mock_context.init_tool_sessions_called is False

        # Create a request with built-in (non-MCP) tools
        tools = [
            {
                "type": "web_search_preview"
            },
            {
                "type": "code_interpreter",
                "container": {
                    "type": "auto"
                }
            },
        ]

        request = ResponsesRequest(input="test input", tools=tools)

        # Call the method again
        await serving_responses_instance._initialize_tool_sessions(
            request, mock_context, mock_exit_stack)

        # Verify that init_tool_sessions was called this time
        assert mock_context.init_tool_sessions_called
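Worth noting about the second call in this test: neither web_search_preview nor code_interpreter has type "mcp", so the comprehension inside _initialize_tool_sessions produces an empty mapping, yet sessions are still initialized because the tool list itself is non-empty. A minimal standalone sketch of that filtering, using SimpleNamespace stand-ins rather than vLLM's protocol classes:

from types import SimpleNamespace

tools = [
    SimpleNamespace(type="web_search_preview", server_label=None),
    SimpleNamespace(type="code_interpreter", server_label=None),
]
# Same shape as the comprehension in _initialize_tool_sessions: built-in
# tools are filtered out, so the MCP mapping is empty here, but the
# non-empty list still reaches init_tool_sessions.
mcp_tools = {t.server_label: t for t in tools if t.type == "mcp"}
assert mcp_tools == {}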
31 changes: 19 additions & 12 deletions vllm/entrypoints/openai/serving_responses.py
@@ -445,6 +445,19 @@ def _make_request_with_harmony(

        return messages, [prompt_token_ids], [engine_prompt]

+    async def _initialize_tool_sessions(self, request: ResponsesRequest,
+                                        context: ConversationContext,
+                                        exit_stack: AsyncExitStack):
+        # We should only initialize tool sessions if the request needs tools
+        if len(request.tools) == 0:
+            return
+        mcp_tools = {
+            tool.server_label: tool
+            for tool in request.tools if tool.type == "mcp"
+        }
+        await context.init_tool_sessions(self.tool_server, exit_stack,
+                                         request.request_id, mcp_tools)

Review thread on the new early-return guard (if len(request.tools) == 0):

Collaborator: I believe @Hanchenli also mentioned this issue. Can we also cover this in a unit test?

Author: Let me know if you have any thoughts, @Hanchenli? I added a unit test; this should be ready for review.

Contributor: This change looks good to me. The issue I mentioned was that the model might generate function_call requests even if we do not provide any tools to it.
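To make the thread's point concrete: with an empty tool list the helper returns before touching the context, while any non-empty list initializes sessions, passing along only the "mcp"-typed entries keyed by server_label. A standalone sketch of that contract (FakeTool and mcp_tools_for are hypothetical stand-ins, not vLLM code):

from dataclasses import dataclass

@dataclass
class FakeTool:
    type: str
    server_label: str = ""

def mcp_tools_for(tools):
    # Mirrors the guard and comprehension in _initialize_tool_sessions.
    if len(tools) == 0:
        return None  # early return: no sessions are initialized
    return {t.server_label: t for t in tools if t.type == "mcp"}

assert mcp_tools_for([]) is None
mixed = [FakeTool("mcp", "wiki"), FakeTool("web_search_preview")]
assert mcp_tools_for(mixed) == {"wiki": mixed[0]}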

    async def responses_full_generator(
        self,
        request: ResponsesRequest,

@@ -461,12 +474,8 @@ async def responses_full_generator(

        async with AsyncExitStack() as exit_stack:
            try:
-                mcp_tools = {
-                    tool.server_label: tool
-                    for tool in request.tools if tool.type == "mcp"
-                }
-                await context.init_tool_sessions(self.tool_server, exit_stack,
-                                                 request.request_id, mcp_tools)
+                await self._initialize_tool_sessions(request, context,
+                                                     exit_stack)
                async for _ in result_generator:
                    pass
            except asyncio.CancelledError:
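A design note on this caller (and the streaming caller below): the sessions opened by _initialize_tool_sessions are registered on the surrounding AsyncExitStack, so they are torn down automatically when the async with block exits, including when the request is cancelled. A minimal self-contained sketch of that lifecycle, with a hypothetical open_session standing in for an MCP tool session:

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager

@asynccontextmanager
async def open_session(label: str):
    print(f"open {label}")  # stand-in for establishing an MCP session
    try:
        yield label
    finally:
        print(f"close {label}")  # runs even if the request is cancelled

async def handle_request():
    async with AsyncExitStack() as exit_stack:
        await exit_stack.enter_async_context(open_session("browser"))
        await exit_stack.enter_async_context(open_session("python"))
        # ... generate the response; sessions close in reverse order on exit

asyncio.run(handle_request())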
@@ -1650,12 +1659,10 @@ def _increment_sequence_number_and_return(
        async with AsyncExitStack() as exit_stack:
            processer = None
            if self.use_harmony:
-                mcp_tools = {
-                    tool.server_label: tool
-                    for tool in request.tools if tool.type == "mcp"
-                }
-                await context.init_tool_sessions(self.tool_server, exit_stack,
-                                                 request.request_id, mcp_tools)
+                # TODO: in streaming, we noticed this bug:
+                # https://github.com/vllm-project/vllm/issues/25697
+                await self._initialize_tool_sessions(request, context,
+                                                     exit_stack)
                processer = self._process_harmony_streaming_events
            else:
                processer = self._process_simple_streaming_events