diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md
index cd66863a1df8..f9c9737919d2 100644
--- a/docs/features/reasoning_outputs.md
+++ b/docs/features/reasoning_outputs.md
@@ -240,9 +240,21 @@ response = client.chat.completions.create(
 )
 ```
 
+The same `chat_template_kwargs` override is also supported on the `/v1/responses`
+endpoint:
+
+```python
+response = client.responses.create(
+    model=model,
+    input="Compute 23 * 17 and explain briefly.",
+    reasoning={"effort": "low"},
+    extra_body={"chat_template_kwargs": {"enable_thinking": True}},
+)
+```
+
 ## Limitations
 
-- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
+- The reasoning content is only available for online serving's chat completion and responses endpoints (`/v1/chat/completions` and `/v1/responses`).
 
 ## How to support a new reasoning model
 
diff --git a/tests/entrypoints/openai/responses/test_chat_template_kwargs.py b/tests/entrypoints/openai/responses/test_chat_template_kwargs.py
new file mode 100644
index 000000000000..9496c4ba6b9e
--- /dev/null
+++ b/tests/entrypoints/openai/responses/test_chat_template_kwargs.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import pytest_asyncio
+from openai import OpenAI
+
+from tests.utils import RemoteOpenAIServer
+
+from .conftest import BASE_TEST_ENV
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--reasoning-parser",
+        "qwen3",
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+        "--max-model-len",
+        "4096",
+        "--default-chat-template-kwargs",
+        '{"enable_thinking": false}',
+    ]
+    env_dict = {
+        **BASE_TEST_ENV,
+        "VLLM_ENABLE_RESPONSES_API_STORE": "1",
+    }
+    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_responses_honors_default_chat_template_kwargs(
+    client: OpenAI, model_name: str
+):
+    response = await client.responses.create(
+        model=model_name,
+        input="Compute 17 * 19 and explain briefly.",
+        reasoning={"effort": "low"},
+        temperature=0.0,
+    )
+
+    reasoning_items = [item for item in response.output if item.type == "reasoning"]
+
+    assert response.status == "completed"
+    assert response.output_text
+    assert not reasoning_items
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_responses_request_chat_template_kwargs_override_server_default(
+    client: OpenAI, model_name: str
+):
+    response = await client.responses.create(
+        model=model_name,
+        input="Compute 23 * 17 and explain briefly.",
+        reasoning={"effort": "low"},
+        temperature=0.0,
+        extra_body={"chat_template_kwargs": {"enable_thinking": True}},
+    )
+
+    reasoning_items = [item for item in response.output if item.type == "reasoning"]
+
+    assert response.status == "completed"
+    assert response.usage is not None
+    assert response.usage.output_tokens_details.reasoning_tokens > 0
+    assert reasoning_items
+    assert reasoning_items[0].content
diff --git a/tests/entrypoints/openai/responses/test_protocol.py b/tests/entrypoints/openai/responses/test_protocol.py
index db5d7d692490..cf19760f263f 100644
--- a/tests/entrypoints/openai/responses/test_protocol.py
+++ b/tests/entrypoints/openai/responses/test_protocol.py
@@ -5,6 +5,7 @@
 )
 
 from vllm.entrypoints.openai.responses.protocol import (
+    ResponsesRequest,
     serialize_message,
     serialize_messages,
 )
@@ -37,3 +38,32 @@ def test_serialize_messages() -> None:
     }
     msg = Message.from_dict(msg_value)
     assert serialize_messages([msg, dict_value]) == [msg_value, dict_value]
+
+
+def test_responses_request_accepts_chat_template_kwargs() -> None:
+    request = ResponsesRequest(
+        input="Hello",
+        chat_template_kwargs={"enable_thinking": False},
+    )
+
+    assert request.chat_template_kwargs == {"enable_thinking": False}
+
+
+def test_build_chat_params_merges_responses_chat_template_kwargs() -> None:
+    request = ResponsesRequest(
+        input="Hello",
+        chat_template_kwargs={"enable_thinking": False},
+        reasoning={"effort": "low"},
+    )
+
+    chat_params = request.build_chat_params(
+        default_template=None,
+        default_template_content_format="auto",
+    )
+
+    assert chat_params.chat_template_kwargs == {
+        "enable_thinking": False,
+        "add_generation_prompt": True,
+        "continue_final_message": False,
+        "reasoning_effort": "low",
+    }
diff --git a/tests/entrypoints/openai/responses/test_serving_responses.py b/tests/entrypoints/openai/responses/test_serving_responses.py
index b5d2b24a63a5..acf53266d79e 100644
--- a/tests/entrypoints/openai/responses/test_serving_responses.py
+++ b/tests/entrypoints/openai/responses/test_serving_responses.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from contextlib import AsyncExitStack
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 import pytest_asyncio
@@ -619,6 +619,138 @@ def _make_serving_instance_with_reasoning():
     return serving
 
 
+@pytest.mark.asyncio
+async def test_make_request_passes_default_chat_template_kwargs():
+    engine_client = MagicMock()
+    model_config = MagicMock()
+    model_config.max_model_len = 100
+    model_config.hf_config.model_type = "test"
+    model_config.get_diff_sampling_param.return_value = {}
+    engine_client.model_config = model_config
+    engine_client.input_processor = MagicMock()
+    engine_client.io_processor = MagicMock()
+    engine_client.renderer = MagicMock()
+
+    openai_serving_render = MagicMock()
+    openai_serving_render.preprocess_chat = AsyncMock(return_value=([], [object()]))
+
+    serving = OpenAIServingResponses(
+        engine_client=engine_client,
+        models=MagicMock(),
+        openai_serving_render=openai_serving_render,
+        request_logger=None,
+        chat_template=None,
+        chat_template_content_format="auto",
+        default_chat_template_kwargs={"enable_thinking": False},
+    )
+
+    request = ResponsesRequest(
+        input="hi",
+        tools=[],
+        chat_template_kwargs={"enable_thinking": True},
+    )
+
+    await serving._make_request(request, None)
+
+    assert openai_serving_render.preprocess_chat.await_count == 1
+    assert openai_serving_render.preprocess_chat.await_args.kwargs[
+        "default_template_kwargs"
+    ] == {"enable_thinking": False}
+
+
+@pytest.mark.asyncio
+async def test_reasoning_parser_receives_merged_chat_template_kwargs():
+    serving = _make_serving_instance_with_reasoning()
+    serving.default_chat_template_kwargs = {"enable_thinking": False}
+
+    mock_parser = MagicMock()
+    mock_parser.count_reasoning_tokens.return_value = 0
+    serving.parser = MagicMock()
+    serving.parser.reasoning_parser_cls = MagicMock(return_value=mock_parser)
+    serving.parser.tool_parser_cls = None
+
+    tokenizer = MagicMock()
+    context = SimpleContext()
+    completion = CompletionOutput(
+        index=0,
+        text="final",
+        token_ids=[20],
+        cumulative_logprob=0.0,
+        logprobs=None,
+        finish_reason="stop",
+        stop_reason=None,
+    )
+    req_output = RequestOutput(
+        request_id="req",
+        prompt="hi",
+        prompt_token_ids=[7, 8],
+        prompt_logprobs=None,
+        outputs=[completion],
+        finished=True,
+        num_cached_tokens=0,
+    )
+    context.append_output(req_output)
+
+    async def dummy_result_generator():
+        yield None
+
+    request = ResponsesRequest(
+        input="hi",
+        tools=[],
+        stream=False,
+        chat_template_kwargs={"enable_thinking": True},
+    )
+    sampling_params = SamplingParams(max_tokens=16)
+    metadata = RequestResponseMetadata(request_id="req")
+
+    await serving.responses_full_generator(
+        request=request,
+        sampling_params=sampling_params,
+        result_generator=dummy_result_generator(),
+        context=context,
+        model_name="test-model",
+        tokenizer=tokenizer,
+        request_metadata=metadata,
+    )
+
+    serving.parser.reasoning_parser_cls.assert_called_once_with(
+        tokenizer,
+        chat_template_kwargs={"enable_thinking": True},
+    )
+
+
+def test_make_response_output_items_passes_merged_chat_template_kwargs():
+    serving = _make_serving_instance_with_reasoning()
+    serving.default_chat_template_kwargs = {"enable_thinking": False}
+
+    mock_parser = MagicMock()
+    mock_parser.extract_response_outputs.return_value = []
+    serving.parser = MagicMock(return_value=mock_parser)
+
+    request = ResponsesRequest(
+        input="hi",
+        tools=[],
+        chat_template_kwargs={"enable_thinking": True},
+    )
+    final_output = CompletionOutput(
+        index=0,
+        text="final",
+        token_ids=[20],
+        cumulative_logprob=0.0,
+        logprobs=None,
+        finish_reason="stop",
+        stop_reason=None,
+    )
+    tokenizer = MagicMock()
+
+    serving._make_response_output_items(request, final_output, tokenizer)
+
+    serving.parser.assert_called_once_with(
+        tokenizer,
+        chat_template_kwargs={"enable_thinking": True},
+    )
+
+
 def _identity_increment(event):
     """Simple identity callable for _increment_sequence_number_and_return."""
     seq = getattr(_identity_increment, "_counter", 0)
diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py
index c81c295e4597..7366b6c8c40c 100644
--- a/vllm/entrypoints/openai/generate/api_router.py
+++ b/vllm/entrypoints/openai/generate/api_router.py
@@ -92,6 +92,7 @@ async def init_generate_state(
             enable_prompt_tokens_details=args.enable_prompt_tokens_details,
             enable_force_include_usage=args.enable_force_include_usage,
             enable_log_outputs=args.enable_log_outputs,
+            default_chat_template_kwargs=args.default_chat_template_kwargs,
         )
         if "generate" in supported_tasks
         else None
diff --git a/vllm/entrypoints/openai/parser/responses_parser.py b/vllm/entrypoints/openai/parser/responses_parser.py
index b5518f0f108a..411e7448ef7c 100644
--- a/vllm/entrypoints/openai/parser/responses_parser.py
+++ b/vllm/entrypoints/openai/parser/responses_parser.py
@@ -36,9 +36,10 @@ def __init__(
         self,
         *,
         tokenizer: TokenizerLike,
-        reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
+        reasoning_parser_cls: Callable[..., ReasoningParser],
         response_messages: list[ResponseInputOutputItem],
         request: ResponsesRequest,
+        chat_template_kwargs: dict | None,
         tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
     ):
         self.response_messages: list[ResponseInputOutputItem] = (
@@ -49,7 +50,10 @@ def __init__(
         self.tokenizer = tokenizer
         self.request = request
-        self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
+        self.reasoning_parser_instance = reasoning_parser_cls(
+            tokenizer,
+            chat_template_kwargs=chat_template_kwargs,
+        )
 
         self.tool_parser_instance = None
         if tool_parser_cls is not None:
             self.tool_parser_instance = tool_parser_cls(tokenizer)
@@ -159,9 +163,10 @@ def make_response_output_items_from_parsable_context(
 def get_responses_parser_for_simple_context(
     *,
     tokenizer: TokenizerLike,
-    reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
+    reasoning_parser_cls: Callable[..., ReasoningParser],
     response_messages: list[ResponseInputOutputItem],
     request: ResponsesRequest,
+    chat_template_kwargs: dict | None,
     tool_parser_cls,
 ) -> ResponsesParser:
     """Factory function to create a ResponsesParser with
@@ -175,5 +180,6 @@ def get_responses_parser_for_simple_context(
         reasoning_parser_cls=reasoning_parser_cls,
         response_messages=response_messages,
         request=request,
+        chat_template_kwargs=chat_template_kwargs,
         tool_parser_cls=tool_parser_cls,
     )
diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py
index a4c55c23c588..2dd2182d3b9c 100644
--- a/vllm/entrypoints/openai/responses/context.py
+++ b/vllm/entrypoints/openai/responses/context.py
@@ -273,8 +273,9 @@ def __init__(
         *,
         response_messages: list[ResponseInputOutputItem],
         tokenizer: TokenizerLike,
-        reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
+        reasoning_parser_cls: Callable[..., ReasoningParser] | None,
         request: ResponsesRequest,
+        chat_template_kwargs: dict[str, Any] | None,
         available_tools: list[str] | None,
         tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
         chat_template: str | None,
@@ -295,6 +296,7 @@ def __init__(
             reasoning_parser_cls=reasoning_parser_cls,
             response_messages=response_messages,
             request=request,
+            chat_template_kwargs=chat_template_kwargs,
             tool_parser_cls=tool_parser_cls,
         )
         self.tool_parser_cls = tool_parser_cls
diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py
index 43fbba1dd43f..d3bed5bbdee8 100644
--- a/vllm/entrypoints/openai/responses/protocol.py
+++ b/vllm/entrypoints/openai/responses/protocol.py
@@ -181,6 +181,13 @@ class ResponsesRequest(OpenAIBaseModel):
             "and vLLM will ignore it."
         ),
     )
+    chat_template_kwargs: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Additional keyword args to pass to the template renderer. "
+            "Will be accessible by the chat template."
+        ),
+    )
 
     # --8<-- [start:responses-extra-params]
     request_id: str = Field(
@@ -276,7 +283,7 @@ def build_chat_params(
             chat_template=default_template,
             chat_template_content_format=default_template_content_format,
             chat_template_kwargs=merge_kwargs(  # To remove unset values
-                {},
+                self.chat_template_kwargs,
                 dict(
                     add_generation_prompt=not continue_final,
                     continue_final_message=continue_final,
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index 53c28693ade7..08b722850b76 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -183,6 +183,7 @@ def __init__(
         enable_prompt_tokens_details: bool = False,
         enable_force_include_usage: bool = False,
         enable_log_outputs: bool = False,
+        default_chat_template_kwargs: dict[str, Any] | None = None,
     ) -> None:
         super().__init__(
             engine_client=engine_client,
@@ -194,6 +195,7 @@ def __init__(
         self.openai_serving_render = openai_serving_render
         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        self.default_chat_template_kwargs = default_chat_template_kwargs or {}
         self.enable_log_outputs = enable_log_outputs
 
         # Set up the unified parser - either a unified parser or fall back to
@@ -267,6 +269,24 @@ def __init__(
 
         self.tool_server = tool_server
 
+    def _get_chat_template_kwargs(self, request: ResponsesRequest) -> dict[str, Any]:
+        return self._prepare_extra_chat_template_kwargs(
+            request.chat_template_kwargs,
+            self.default_chat_template_kwargs,
+        )
+
+    def _make_reasoning_parser(
+        self,
+        tokenizer: TokenizerLike,
+        request: ResponsesRequest,
+    ):
+        if self.parser is None or self.parser.reasoning_parser_cls is None:
+            return None
+        return self.parser.reasoning_parser_cls(
+            tokenizer,
+            chat_template_kwargs=self._get_chat_template_kwargs(request),
+        )
+
     def _validate_generator_input(
         self,
         engine_prompt: ProcessorInputs,
@@ -453,6 +473,7 @@ async def create_responses(
                 if self.parser
                 else None,
                 request=request,
+                chat_template_kwargs=self._get_chat_template_kwargs(request),
                 tool_parser_cls=self.parser.tool_parser_cls
                 if self.parser
                 else None,
@@ -464,7 +485,8 @@
             context = SimpleContext()
 
         if self.parser and self.parser.reasoning_parser_cls is not None:
-            reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
+            reasoning_parser = self._make_reasoning_parser(tokenizer, request)
+            assert reasoning_parser is not None
             if (
                 isinstance(
                     struct_out := sampling_params.structured_outputs,
@@ -591,7 +613,7 @@ async def _make_request(
             messages,
             default_template=self.chat_template,
             default_template_content_format=self.chat_template_content_format,
-            default_template_kwargs=None,
+            default_template_kwargs=self.default_chat_template_kwargs,
             tool_dicts=tool_dicts,
             tool_parser=self.parser.tool_parser_cls if self.parser else None,
         )
@@ -615,7 +637,7 @@ async def _render_next_turn(
             new_messages,
             default_template=chat_template,
             default_template_content_format=chat_template_content_format,
-            default_template_kwargs=None,
+            default_template_kwargs=self.default_chat_template_kwargs,
             tool_dicts=tool_dicts,
             tool_parser=tool_parser,
         )
@@ -835,7 +857,8 @@ async def responses_full_generator(
             and self.parser.reasoning_parser_cls is not None
             and isinstance(context, (SimpleContext, ParsableContext))
         ):
-            reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
+            reasoning_parser = self._make_reasoning_parser(tokenizer, request)
+            assert reasoning_parser is not None
             accumulated = getattr(context, "_accumulated_token_ids", []) or []
             num_reasoning_tokens = reasoning_parser.count_reasoning_tokens(accumulated)
 
@@ -1003,7 +1026,10 @@
 
         # Use parser to extract and create response output items
         if self.parser:
-            parser = self.parser(tokenizer)
+            parser = self.parser(
+                tokenizer,
+                chat_template_kwargs=self._get_chat_template_kwargs(request),
+            )
             return parser.extract_response_outputs(
                 model_output=final_output.text,
                 model_output_token_ids=final_output.token_ids,
@@ -1343,7 +1369,7 @@ async def _process_simple_streaming_events(
         current_item_id = ""
         reasoning_parser = None
         if self.parser and self.parser.reasoning_parser_cls:
-            reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
+            reasoning_parser = self._make_reasoning_parser(tokenizer, request)
         tool_parser = None
         if self.parser and self.parser.tool_parser_cls:
             tool_parser = self.parser.tool_parser_cls(tokenizer)
diff --git a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py
index dd9dc94237dc..483f84d735d6 100644
--- a/vllm/parser/abstract_parser.py
+++ b/vllm/parser/abstract_parser.py
@@ -542,10 +542,12 @@ class _WrappedParser(DelegatingParser):
     reasoning_parser_cls: type[ReasoningParser] | None = None
     tool_parser_cls: type[ToolParser] | None = None
 
-    def __init__(self, tokenizer: TokenizerLike):
+    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
         super().__init__(tokenizer)
         # Instantiate the underlying parsers from class attributes
         if self.__class__.reasoning_parser_cls is not None:
-            self._reasoning_parser = self.__class__.reasoning_parser_cls(tokenizer)
+            self._reasoning_parser = self.__class__.reasoning_parser_cls(
+                tokenizer, *args, **kwargs
+            )
         if self.__class__.tool_parser_cls is not None:
            self._tool_parser = self.__class__.tool_parser_cls(tokenizer)
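
Taken together, the server-side default and the per-request override behave as sketched below. This is a usage sketch only, not part of the patch: it assumes a vLLM OpenAI-compatible server launched with the `--default-chat-template-kwargs '{"enable_thinking": false}'` flag exercised in the new tests, reachable at a hypothetical `http://localhost:8000/v1` with a placeholder API key.

```python
# Usage sketch (not part of the patch). Assumes a server started roughly as in
# the new test fixture, e.g.:
#   vllm serve Qwen/Qwen3-0.6B --reasoning-parser qwen3 \
#       --default-chat-template-kwargs '{"enable_thinking": false}'
# The base URL and API key below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# With no per-request override, the server default ({"enable_thinking": false})
# applies, so no reasoning items are expected in the output.
default_resp = client.responses.create(
    model="Qwen/Qwen3-0.6B",
    input="Compute 17 * 19 and explain briefly.",
)

# A per-request chat_template_kwargs value takes precedence over the server
# default, re-enabling thinking for this request only.
override_resp = client.responses.create(
    model="Qwen/Qwen3-0.6B",
    input="Compute 23 * 17 and explain briefly.",
    extra_body={"chat_template_kwargs": {"enable_thinking": True}},
)
print([item.type for item in override_resp.output])
```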