206 changes: 206 additions & 0 deletions tests/entrypoints/openai/test_chat.py
@@ -3,6 +3,7 @@

# imports for structured outputs tests
import json
from collections import defaultdict

import jsonschema
import openai # use the official client for correctness check
@@ -13,6 +14,11 @@
import torch
from openai import BadRequestError

from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.sampling_params import SamplingParams

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
@@ -815,3 +821,203 @@ async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenA

assert chat_output.keys() == invocation_output.keys()
assert chat_output["choices"] == invocation_output["choices"]


# Tests for the n parameter in chat completions: verify that regular
# sampling (non-beam search) returns n choices, addressing issue #34305.


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_chat_completion_n_parameter_non_streaming(
client: openai.AsyncOpenAI, model_name: str
):
"""Test that n parameter returns multiple choices for non-streaming requests."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the opposite of big?"},
]

# Test with n=3
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=20,
temperature=0.7,
n=3,
stream=False,
)

assert len(chat_completion.choices) == 3

# Verify each choice has content and correct index
for i, choice in enumerate(chat_completion.choices):
assert choice.index == i
assert choice.message.content is not None
assert len(choice.message.content) > 0

    # Verify the responses are not all identical (highly likely with temperature > 0)
    contents = [choice.message.content for choice in chat_completion.choices]
    assert len(set(contents)) > 1, "Expected at least two distinct responses with n=3"
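
# For reference, the client call above corresponds to the following JSON
# body posted to /v1/chat/completions (a sketch; field names follow the
# OpenAI chat API, which vLLM's frontend mirrors):
#
#   {
#     "model": "<model_name>",
#     "messages": [...],
#     "max_completion_tokens": 20,
#     "temperature": 0.7,
#     "n": 3,
#     "stream": false
#   }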


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_chat_completion_n_parameter_streaming(
client: openai.AsyncOpenAI, model_name: str
):
"""Test that n parameter returns multiple choices for streaming requests."""
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the capital of France?"},
]

stream = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=15,
temperature=0.7,
n=2,
stream=True,
)

    # Collect content chunks per choice, keyed by choice index
chunks_by_index = defaultdict(list)
async for chunk in stream:
for choice in chunk.choices:
if choice.delta.content:
chunks_by_index[choice.index].append(choice.delta.content)

# Verify both choices received content
assert len(chunks_by_index[0]) > 0, "Choice 0 received no content chunks"
assert len(chunks_by_index[1]) > 0, "Choice 1 received no content chunks"

# Reconstruct full responses
response_0 = "".join(chunks_by_index[0])
response_1 = "".join(chunks_by_index[1])

assert len(response_0) > 0, "Choice 0 has empty response"
assert len(response_1) > 0, "Choice 1 has empty response"
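

# The per-index reconstruction above generalizes to arbitrary n. A minimal
# helper sketch (assumes only the OpenAI streaming chunk shape exercised
# above; not used by the tests):
async def _collect_streamed_choices(stream, n: int) -> list[str]:
    parts: dict[int, list[str]] = defaultdict(list)
    async for chunk in stream:
        for choice in chunk.choices:
            if choice.delta.content:
                parts[choice.index].append(choice.delta.content)
    return ["".join(parts[i]) for i in range(n)]
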


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_chat_completion_n_with_seed(client: openai.AsyncOpenAI, model_name: str):
"""Test that n parameter works correctly with seed parameter."""
messages = [
{"role": "user", "content": "Say hello."},
]

# Test that seed parameter is accepted and works with n > 1
chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.8,
n=2,
seed=42,
stream=False,
)

# Verify we get n=2 choices
assert len(chat_completion.choices) == 2

# Verify both choices have valid content
for i, choice in enumerate(chat_completion.choices):
assert choice.index == i
assert choice.message.content is not None
assert len(choice.message.content) > 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_chat_completion_n_equals_1(client: openai.AsyncOpenAI, model_name: str):
"""Test that n=1 (default) still works correctly."""
messages = [
{"role": "user", "content": "Hello!"},
]

chat_completion = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=10,
temperature=0.7,
n=1,
stream=False,
)

assert len(chat_completion.choices) == 1
assert chat_completion.choices[0].index == 0
assert chat_completion.choices[0].message.content is not None


# Unit tests for n parameter in ChatCompletionRequest.to_sampling_params()
def test_chat_completion_request_n_parameter_to_sampling_params():
"""Test that n parameter is correctly passed to SamplingParams."""
# Test with n=3
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello"}],
n=3,
max_tokens=10,
)

sampling_params = request.to_sampling_params(
max_tokens=10,
default_sampling_params={},
)

assert isinstance(sampling_params, SamplingParams)
assert sampling_params.n == 3, f"Expected n=3, got n={sampling_params.n}"


def test_chat_completion_request_n_parameter_default():
"""Test that n parameter defaults to 1."""
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Hello"}],
# n not specified, should default to 1
max_tokens=10,
)

assert request.n == 1, "n should default to 1"
sampling_params = request.to_sampling_params(
max_tokens=10,
default_sampling_params={},
)

# SamplingParams.from_optional converts None to 1
assert sampling_params.n == 1, f"Expected n=1 (default), got n={sampling_params.n}"
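

# A minimal illustration of the None -> 1 conversion noted above (a
# sketch; assumes SamplingParams.from_optional accepts n as an optional
# keyword, as used by the request-to-params conversion in vLLM):
def test_sampling_params_from_optional_none_n():
    params = SamplingParams.from_optional(n=None)
    assert params.n == 1, f"Expected n=1 for n=None, got n={params.n}"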


def test_chat_completion_request_n_parameter_various_values():
"""Test n parameter with various values."""
for n_value in [1, 2, 5, 10]:
request = ChatCompletionRequest(
model="test-model",
messages=[{"role": "user", "content": "Test"}],
n=n_value,
max_tokens=10,
)

sampling_params = request.to_sampling_params(
max_tokens=10,
default_sampling_params={},
)

assert sampling_params.n == n_value, (
f"Expected n={n_value}, got n={sampling_params.n}"
)
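

# The loop above could equivalently be written with pytest's parametrize,
# which reports each n value as its own test case (an alternative sketch,
# equivalent to test_chat_completion_request_n_parameter_various_values):
@pytest.mark.parametrize("n_value", [1, 2, 5, 10])
def test_chat_completion_request_n_parametrized(n_value: int):
    request = ChatCompletionRequest(
        model="test-model",
        messages=[{"role": "user", "content": "Test"}],
        n=n_value,
        max_tokens=10,
    )

    sampling_params = request.to_sampling_params(
        max_tokens=10,
        default_sampling_params={},
    )

    assert sampling_params.n == n_value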