Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions tests/entrypoints/openai/responses/test_sampling_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Unit tests for ResponsesRequest.to_sampling_params() parameter mapping."""

import pytest

from vllm.entrypoints.openai.responses.protocol import ResponsesRequest


class TestResponsesRequestSamplingParams:
    """Verify ResponsesRequest.to_sampling_params() maps request fields correctly."""

    @staticmethod
    def _to_params(**fields):
        """Build a minimal ResponsesRequest with *fields* and convert it."""
        request = ResponsesRequest(
            model="test-model",
            input="test input",
            **fields,
        )
        return request.to_sampling_params(default_max_tokens=1000)

    def test_basic_sampling_params(self):
        """Core knobs (temperature/top_p/top_k/max_output_tokens) pass through."""
        params = self._to_params(
            temperature=0.8,
            top_p=0.95,
            top_k=50,
            max_output_tokens=100,
        )

        assert params.temperature == 0.8
        assert params.top_p == 0.95
        assert params.top_k == 50
        assert params.max_tokens == 100

    def test_extra_sampling_params(self):
        """vLLM-specific extras (penalty/seed/stop/ignore_eos/xargs) pass through."""
        params = self._to_params(
            repetition_penalty=1.2,
            seed=42,
            stop=["END", "STOP"],
            ignore_eos=True,
            vllm_xargs={"custom": "value"},
        )

        assert params.repetition_penalty == 1.2
        assert params.seed == 42
        assert params.stop == ["END", "STOP"]
        assert params.ignore_eos is True
        assert params.extra_args == {"custom": "value"}

    def test_stop_string_conversion(self):
        """A bare stop string is normalized into a single-element list."""
        params = self._to_params(stop="STOP")
        assert params.stop == ["STOP"]

    def test_default_values(self):
        """Omitted optional fields fall back to sane defaults."""
        params = self._to_params()

        assert params.repetition_penalty == 1.0  # None → 1.0
        assert params.stop == []  # Empty list
        assert params.extra_args == {}  # Empty dict

    def test_seed_bounds_validation(self):
        """Seeds outside the torch.long range are rejected; boundaries pass."""
        import torch
        from pydantic import ValidationError

        seed_min = torch.iinfo(torch.long).min
        seed_max = torch.iinfo(torch.long).max

        # One past the lower bound must fail the ge= constraint.
        with pytest.raises(ValidationError) as exc_info:
            ResponsesRequest(
                model="test-model",
                input="test input",
                seed=seed_min - 1,
            )
        assert "greater_than_equal" in str(exc_info.value).lower()

        # One past the upper bound must fail the le= constraint.
        with pytest.raises(ValidationError) as exc_info:
            ResponsesRequest(
                model="test-model",
                input="test input",
                seed=seed_max + 1,
            )
        assert "less_than_equal" in str(exc_info.value).lower()

        # Exact boundary values are accepted and preserved.
        at_min = ResponsesRequest(
            model="test-model",
            input="test input",
            seed=seed_min,
        )
        assert at_min.seed == seed_min

        at_max = ResponsesRequest(
            model="test-model",
            input="test input",
            seed=seed_max,
        )
        assert at_max.seed == seed_max
25 changes: 24 additions & 1 deletion tests/entrypoints/openai/responses/test_simple.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import pytest
import pytest_asyncio
from openai import OpenAI
Expand Down Expand Up @@ -147,3 +146,27 @@ async def test_max_tokens(client: OpenAI, model_name: str):
assert response is not None
assert response.status == "incomplete"
assert response.incomplete_details.reason == "max_output_tokens"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_extra_sampling_params(client: OpenAI, model_name: str):
    """Smoke-test that extra (vLLM-specific) sampling params are accepted."""
    # top_k / repetition_penalty / seed are not part of the OpenAI schema,
    # so they travel via extra_body; the request should still succeed.
    resp = await client.responses.create(
        model=model_name,
        input="Write a short sentence",
        max_output_tokens=50,
        temperature=0.7,
        top_p=0.9,
        extra_body={
            "top_k": 40,
            "repetition_penalty": 1.2,
            "seed": 42,
        },
    )

    # The server may hit the token cap, so "incomplete" is also acceptable.
    assert resp.status in ("completed", "incomplete")
    assert len(resp.output) > 0
    assert resp.output[0].content[0].text  # Has text output
29 changes: 28 additions & 1 deletion vllm/entrypoints/openai/responses/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
from typing import Any, Literal, TypeAlias

import torch
from openai.types.responses import (
ResponseCodeInterpreterCallCodeDeltaEvent,
ResponseCodeInterpreterCallCodeDoneEvent,
Expand Down Expand Up @@ -77,6 +78,8 @@

logger = init_logger(__name__)

_LONG_INFO = torch.iinfo(torch.long)


class InputTokensDetails(OpenAIBaseModel):
cached_tokens: int
Expand Down Expand Up @@ -230,6 +233,18 @@ class ResponsesRequest(OpenAIBaseModel):
# this cannot be used in conjunction with previous_response_id
# TODO: consider supporting non harmony messages as well
previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None

repetition_penalty: float | None = None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compared to the previous version, these additions make it much more reasonable.

Let's invite others to review this.

/cc @qandrew @yeqcharlotte @DarkLight1337

seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: str | list[str] | None = []
ignore_eos: bool = False
vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field(
default=None,
description=(
"Additional request parameters with (list of) string or "
"numeric values, used by custom extensions."
),
)
# --8<-- [end:responses-extra-params]

def build_chat_params(
Expand Down Expand Up @@ -297,6 +312,10 @@ def to_sampling_params(
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)

if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get("repetition_penalty", 1.0)

stop_token_ids = default_sampling_params.get("stop_token_ids")

# Structured output
Expand All @@ -313,19 +332,27 @@ def to_sampling_params(
elif response_format.type == "json_object":
raise NotImplementedError("json_object is not supported")

# TODO: add more parameters
stop = self.stop if self.stop else []
if isinstance(stop, str):
stop = [stop]

return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_tokens=max_tokens,
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
stop_token_ids=stop_token_ids,
stop=stop,
repetition_penalty=repetition_penalty,
seed=self.seed,
ignore_eos=self.ignore_eos,
output_kind=(
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
),
structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
extra_args=self.vllm_xargs or {},
skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
Expand Down