diff --git a/tests/entrypoints/openai/responses/test_function_call.py b/tests/entrypoints/openai/responses/test_function_call.py index 8ca43feaca4f..9dcbd74c890b 100644 --- a/tests/entrypoints/openai/responses/test_function_call.py +++ b/tests/entrypoints/openai/responses/test_function_call.py @@ -325,8 +325,12 @@ async def test_function_calling_with_streaming_expected_arguments( "tool_choice", ["auto", "required", {"type": "function", "name": "get_current_weather"}], ) +@pytest.mark.parametrize( + "enable_thinking", + [True, False], +) async def test_function_calling_with_streaming_types( - client: openai.AsyncOpenAI, model_name: str, tool_choice + client: openai.AsyncOpenAI, model_name: str, tool_choice, enable_thinking: bool ): # this links the "done" type with the "start" type # so every "done" type should have a corresponding "start" type @@ -436,6 +440,7 @@ async def test_function_calling_with_streaming_types( input=input_list, tools=tools, tool_choice=tool_choice, + extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}}, stream=True, ) diff --git a/vllm/entrypoints/openai/generate/api_router.py b/vllm/entrypoints/openai/generate/api_router.py index 4386baa14e10..84a7fddeabe3 100644 --- a/vllm/entrypoints/openai/generate/api_router.py +++ b/vllm/entrypoints/openai/generate/api_router.py @@ -103,6 +103,7 @@ async def init_generate_state( enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, + default_chat_template_kwargs=args.default_chat_template_kwargs, ) if "generate" in supported_tasks else None diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index b5d69ea1cccc..10aa5bde392b 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -276,6 +276,13 @@ class ResponsesRequest(OpenAIBaseModel): default=None, description="KVTransfer parameters used for disaggregated serving.", ) + chat_template_kwargs: dict[str, Any] | None = Field( + default=None, + description=( + "Additional keyword args to pass to the chat template renderer. " + "Will be accessible by the template." + ), + ) # --8<-- [end:responses-extra-params] def build_chat_params( @@ -296,7 +303,7 @@ def build_chat_params( chat_template=default_template, chat_template_content_format=default_template_content_format, chat_template_kwargs=merge_kwargs( # To remove unset values - {}, + self.chat_template_kwargs, dict( add_generation_prompt=not continue_final, continue_final_message=continue_final, diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index fb28a2256ad0..92b19f175ced 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -167,6 +167,7 @@ def __init__( enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + default_chat_template_kwargs: dict[str, Any] | None = None, ) -> None: super().__init__( engine_client=engine_client, @@ -178,6 +179,7 @@ def __init__( self.openai_serving_render = openai_serving_render self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.chat_template_kwargs = default_chat_template_kwargs or {} self.enable_log_outputs = enable_log_outputs # Set up the unified parser - either a unified parser or fall back to @@ -254,10 +256,14 @@ def __init__( def _effective_chat_template_kwargs( self, request: ResponsesRequest ) -> dict[str, Any]: - return request.build_chat_params( - self.chat_template, - self.chat_template_content_format, - ).chat_template_kwargs + return ( + request.build_chat_params( + self.chat_template, + self.chat_template_content_format, + ) + .with_defaults(self.chat_template_kwargs) + .chat_template_kwargs + ) def _validate_generator_input( self, @@ -601,13 +607,13 @@ async def _make_request( prev_msg=self.msg_store.get(prev_response.id) if prev_response else None, prev_response_output=prev_response.output if prev_response else None, ) - + chat_template_kwargs = self._effective_chat_template_kwargs(request) _, engine_inputs = await self.openai_serving_render.preprocess_chat( request, messages, default_template=self.chat_template, default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, + default_template_kwargs=chat_template_kwargs, tool_dicts=tool_dicts, tool_parser=self.parser.tool_parser_cls if self.parser else None, reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None, @@ -626,13 +632,13 @@ async def _render_next_turn( new_messages = construct_input_messages( request_input=messages, ) - + chat_template_kwargs = self._effective_chat_template_kwargs(request) _, engine_inputs = await self.openai_serving_render.preprocess_chat( request, new_messages, default_template=chat_template, default_template_content_format=chat_template_content_format, - default_template_kwargs=None, + default_template_kwargs=chat_template_kwargs, tool_dicts=tool_dicts, tool_parser=tool_parser, reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None,