diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py
index 0cc064cd8f12..c480adcc11bf 100644
--- a/tests/entrypoints/openai/test_chat.py
+++ b/tests/entrypoints/openai/test_chat.py
@@ -3,6 +3,7 @@
 
 # imports for structured outputs tests
 import json
+from collections import defaultdict
 
 import jsonschema
 import openai  # use the official client for correctness check
@@ -13,6 +14,11 @@
 import torch
 from openai import BadRequestError
 
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+)
+from vllm.sampling_params import SamplingParams
+
 from ...utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
@@ -815,3 +821,203 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA
 
     assert chat_output.keys() == invocation_output.keys()
     assert chat_output["choices"] == invocation_output["choices"]
+
+
+# Tests for the n parameter in chat completions: verify that n works
+# correctly for regular sampling (non-beam search) in chat completions,
+# addressing issue #34305.
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_chat_completion_n_parameter_non_streaming(
+    client: openai.AsyncOpenAI, model_name: str
+):
+    """Test that n returns multiple choices for non-streaming requests."""
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the opposite of big?"},
+    ]
+
+    # Test with n=3
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=20,
+        temperature=0.7,
+        n=3,
+        stream=False,
+    )
+
+    assert len(chat_completion.choices) == 3
+
+    # Verify each choice has content and the correct index
+    for i, choice in enumerate(chat_completion.choices):
+        assert choice.index == i
+        assert choice.message.content is not None
+        assert len(choice.message.content) > 0
+
+    # Verify the responses differ (highly likely with temperature > 0)
+    contents = [choice.message.content for choice in chat_completion.choices]
+    assert len(set(contents)) > 1, "Expected different responses with n=3"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_chat_completion_n_parameter_streaming(
+    client: openai.AsyncOpenAI, model_name: str
+):
+    """Test that n returns multiple choices for streaming requests."""
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "What is the capital of France?"},
+    ]
+
+    stream = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=15,
+        temperature=0.7,
+        n=2,
+        stream=True,
+    )
+
+    # Collect content chunks per choice index using a defaultdict
+    chunks_by_index = defaultdict(list)
+    async for chunk in stream:
+        for choice in chunk.choices:
+            if choice.delta.content:
+                chunks_by_index[choice.index].append(choice.delta.content)
+
+    # Verify both choices received content
+    assert len(chunks_by_index[0]) > 0, "Choice 0 received no content chunks"
+    assert len(chunks_by_index[1]) > 0, "Choice 1 received no content chunks"
+
+    # Reconstruct the full responses
+    response_0 = "".join(chunks_by_index[0])
+    response_1 = "".join(chunks_by_index[1])
+
+    assert len(response_0) > 0, "Choice 0 has empty response"
+    assert len(response_1) > 0, "Choice 1 has empty response"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
"model_name", + [MODEL_NAME], +) +async def test_chat_completion_n_with_seed(client: openai.AsyncOpenAI, model_name: str): + """Test that n parameter works correctly with seed parameter.""" + messages = [ + {"role": "user", "content": "Say hello."}, + ] + + # Test that seed parameter is accepted and works with n > 1 + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.8, + n=2, + seed=42, + stream=False, + ) + + # Verify we get n=2 choices + assert len(chat_completion.choices) == 2 + + # Verify both choices have valid content + for i, choice in enumerate(chat_completion.choices): + assert choice.index == i + assert choice.message.content is not None + assert len(choice.message.content) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_chat_completion_n_equals_1(client: openai.AsyncOpenAI, model_name: str): + """Test that n=1 (default) still works correctly.""" + messages = [ + {"role": "user", "content": "Hello!"}, + ] + + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_completion_tokens=10, + temperature=0.7, + n=1, + stream=False, + ) + + assert len(chat_completion.choices) == 1 + assert chat_completion.choices[0].index == 0 + assert chat_completion.choices[0].message.content is not None + + +# Unit tests for n parameter in ChatCompletionRequest.to_sampling_params() +def test_chat_completion_request_n_parameter_to_sampling_params(): + """Test that n parameter is correctly passed to SamplingParams.""" + # Test with n=3 + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + n=3, + max_tokens=10, + ) + + sampling_params = request.to_sampling_params( + max_tokens=10, + default_sampling_params={}, + ) + + assert isinstance(sampling_params, SamplingParams) + assert sampling_params.n == 3, f"Expected n=3, got n={sampling_params.n}" + + +def test_chat_completion_request_n_parameter_default(): + """Test that n parameter defaults to 1.""" + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Hello"}], + # n not specified, should default to 1 + max_tokens=10, + ) + + assert request.n == 1, "n should default to 1" + sampling_params = request.to_sampling_params( + max_tokens=10, + default_sampling_params={}, + ) + + # SamplingParams.from_optional converts None to 1 + assert sampling_params.n == 1, f"Expected n=1 (default), got n={sampling_params.n}" + + +def test_chat_completion_request_n_parameter_various_values(): + """Test n parameter with various values.""" + for n_value in [1, 2, 5, 10]: + request = ChatCompletionRequest( + model="test-model", + messages=[{"role": "user", "content": "Test"}], + n=n_value, + max_tokens=10, + ) + + sampling_params = request.to_sampling_params( + max_tokens=10, + default_sampling_params={}, + ) + + assert sampling_params.n == n_value, ( + f"Expected n={n_value}, got n={sampling_params.n}" + )