diff --git a/tests/entrypoints/openai/responses/test_sampling_params.py b/tests/entrypoints/openai/responses/test_sampling_params.py new file mode 100644 index 000000000000..b8d1aa664047 --- /dev/null +++ b/tests/entrypoints/openai/responses/test_sampling_params.py @@ -0,0 +1,113 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for ResponsesRequest.to_sampling_params() parameter mapping.""" + +import pytest + +from vllm.entrypoints.openai.responses.protocol import ResponsesRequest + + +class TestResponsesRequestSamplingParams: + """Test that ResponsesRequest correctly maps parameters to SamplingParams.""" + + def test_basic_sampling_params(self): + """Test basic sampling parameters are correctly mapped.""" + request = ResponsesRequest( + model="test-model", + input="test input", + temperature=0.8, + top_p=0.95, + top_k=50, + max_output_tokens=100, + ) + + sampling_params = request.to_sampling_params(default_max_tokens=1000) + + assert sampling_params.temperature == 0.8 + assert sampling_params.top_p == 0.95 + assert sampling_params.top_k == 50 + assert sampling_params.max_tokens == 100 + + def test_extra_sampling_params(self): + """Test extra sampling parameters are correctly mapped.""" + request = ResponsesRequest( + model="test-model", + input="test input", + repetition_penalty=1.2, + seed=42, + stop=["END", "STOP"], + ignore_eos=True, + vllm_xargs={"custom": "value"}, + ) + + sampling_params = request.to_sampling_params(default_max_tokens=1000) + + assert sampling_params.repetition_penalty == 1.2 + assert sampling_params.seed == 42 + assert sampling_params.stop == ["END", "STOP"] + assert sampling_params.ignore_eos is True + assert sampling_params.extra_args == {"custom": "value"} + + def test_stop_string_conversion(self): + """Test that single stop string is converted to list.""" + request = ResponsesRequest( + model="test-model", + input="test input", + stop="STOP", + ) + + sampling_params = request.to_sampling_params(default_max_tokens=1000) + + assert sampling_params.stop == ["STOP"] + + def test_default_values(self): + """Test default values for optional parameters.""" + request = ResponsesRequest( + model="test-model", + input="test input", + ) + + sampling_params = request.to_sampling_params(default_max_tokens=1000) + + assert sampling_params.repetition_penalty == 1.0 # None → 1.0 + assert sampling_params.stop == [] # Empty list + assert sampling_params.extra_args == {} # Empty dict + + def test_seed_bounds_validation(self): + """Test that seed values outside torch.long bounds are rejected.""" + import torch + from pydantic import ValidationError + + # Test seed below minimum + with pytest.raises(ValidationError) as exc_info: + ResponsesRequest( + model="test-model", + input="test input", + seed=torch.iinfo(torch.long).min - 1, + ) + assert "greater_than_equal" in str(exc_info.value).lower() + + # Test seed above maximum + with pytest.raises(ValidationError) as exc_info: + ResponsesRequest( + model="test-model", + input="test input", + seed=torch.iinfo(torch.long).max + 1, + ) + assert "less_than_equal" in str(exc_info.value).lower() + + # Test valid seed at boundaries + request_min = ResponsesRequest( + model="test-model", + input="test input", + seed=torch.iinfo(torch.long).min, + ) + assert request_min.seed == torch.iinfo(torch.long).min + + request_max = ResponsesRequest( + model="test-model", + input="test input", + seed=torch.iinfo(torch.long).max, + ) + assert request_max.seed == torch.iinfo(torch.long).max diff --git a/tests/entrypoints/openai/responses/test_simple.py b/tests/entrypoints/openai/responses/test_simple.py index 30423788bf79..8f07b02a308c 100644 --- a/tests/entrypoints/openai/responses/test_simple.py +++ b/tests/entrypoints/openai/responses/test_simple.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import pytest import pytest_asyncio from openai import OpenAI @@ -147,3 +146,27 @@ async def test_max_tokens(client: OpenAI, model_name: str): assert response is not None assert response.status == "incomplete" assert response.incomplete_details.reason == "max_output_tokens" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_extra_sampling_params(client: OpenAI, model_name: str): + """Test that extra sampling parameters are accepted and work.""" + # Test with multiple sampling parameters - just verify they're accepted + response = await client.responses.create( + model=model_name, + input="Write a short sentence", + max_output_tokens=50, + temperature=0.7, + top_p=0.9, + extra_body={ + "top_k": 40, + "repetition_penalty": 1.2, + "seed": 42, + }, + ) + + # Verify request succeeded and parameters were accepted + assert response.status in ["completed", "incomplete"] + assert len(response.output) > 0 + assert response.output[0].content[0].text # Has text output diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index 81abebee291d..9a471852ba24 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -6,6 +6,7 @@ import time from typing import Any, Literal, TypeAlias +import torch from openai.types.responses import ( ResponseCodeInterpreterCallCodeDeltaEvent, ResponseCodeInterpreterCallCodeDoneEvent, @@ -77,6 +78,8 @@ logger = init_logger(__name__) +_LONG_INFO = torch.iinfo(torch.long) + class InputTokensDetails(OpenAIBaseModel): cached_tokens: int @@ -230,6 +233,18 @@ class ResponsesRequest(OpenAIBaseModel): # this cannot be used in conjunction with previous_response_id # TODO: consider supporting non harmony messages as well previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None + + repetition_penalty: float | None = None + seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: str | list[str] | None = [] + ignore_eos: bool = False + vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( + default=None, + description=( + "Additional request parameters with (list of) string or " + "numeric values, used by custom extensions." + ), + ) # --8<-- [end:responses-extra-params] def build_chat_params( @@ -297,6 +312,10 @@ def to_sampling_params( top_k = default_sampling_params.get( "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] ) + + if (repetition_penalty := self.repetition_penalty) is None: + repetition_penalty = default_sampling_params.get("repetition_penalty", 1.0) + stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output @@ -313,7 +332,10 @@ def to_sampling_params( elif response_format.type == "json_object": raise NotImplementedError("json_object is not supported") - # TODO: add more parameters + stop = self.stop if self.stop else [] + if isinstance(stop, str): + stop = [stop] + return SamplingParams.from_optional( temperature=temperature, top_p=top_p, @@ -321,11 +343,16 @@ def to_sampling_params( max_tokens=max_tokens, logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, + stop=stop, + repetition_penalty=repetition_penalty, + seed=self.seed, + ignore_eos=self.ignore_eos, output_kind=( RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY ), structured_outputs=structured_outputs, logit_bias=self.logit_bias, + extra_args=self.vllm_xargs or {}, skip_clone=True, # Created fresh per request, safe to skip clone skip_special_tokens=self.skip_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output,