Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/entrypoints/openai/responses/test_function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,12 @@ async def test_function_calling_with_streaming_expected_arguments(
"tool_choice",
["auto", "required", {"type": "function", "name": "get_current_weather"}],
)
@pytest.mark.parametrize(
"enable_thinking",
[True, False],
)
async def test_function_calling_with_streaming_types(
client: openai.AsyncOpenAI, model_name: str, tool_choice
client: openai.AsyncOpenAI, model_name: str, tool_choice, enable_thinking: bool
):
# this links the "done" type with the "start" type
# so every "done" type should have a corresponding "start" type
Expand Down Expand Up @@ -436,6 +440,7 @@ async def test_function_calling_with_streaming_types(
input=input_list,
tools=tools,
tool_choice=tool_choice,
extra_body={"chat_template_kwargs": {"enable_thinking": enable_thinking}},
stream=True,
)

Expand Down
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/generate/api_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ async def init_generate_state(
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
enable_log_outputs=args.enable_log_outputs,
default_chat_template_kwargs=args.default_chat_template_kwargs,
)
if "generate" in supported_tasks
else None
Expand Down
9 changes: 8 additions & 1 deletion vllm/entrypoints/openai/responses/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,13 @@ class ResponsesRequest(OpenAIBaseModel):
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
chat_template_kwargs: dict[str, Any] | None = Field(
default=None,
description=(
"Additional keyword args to pass to the chat template renderer. "
"Will be accessible by the template."
),
)
# --8<-- [end:responses-extra-params]

def build_chat_params(
Expand All @@ -296,7 +303,7 @@ def build_chat_params(
chat_template=default_template,
chat_template_content_format=default_template_content_format,
chat_template_kwargs=merge_kwargs( # To remove unset values
{},
self.chat_template_kwargs,
dict(
add_generation_prompt=not continue_final,
continue_final_message=continue_final,
Expand Down
22 changes: 14 additions & 8 deletions vllm/entrypoints/openai/responses/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
enable_prompt_tokens_details: bool = False,
enable_force_include_usage: bool = False,
enable_log_outputs: bool = False,
default_chat_template_kwargs: dict[str, Any] | None = None,
) -> None:
super().__init__(
engine_client=engine_client,
Expand All @@ -178,6 +179,7 @@ def __init__(
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.chat_template_kwargs = default_chat_template_kwargs or {}
self.enable_log_outputs = enable_log_outputs

# Set up the unified parser - either a unified parser or fall back to
Expand Down Expand Up @@ -254,10 +256,14 @@ def __init__(
def _effective_chat_template_kwargs(
self, request: ResponsesRequest
) -> dict[str, Any]:
return request.build_chat_params(
self.chat_template,
self.chat_template_content_format,
).chat_template_kwargs
return (
request.build_chat_params(
self.chat_template,
self.chat_template_content_format,
)
.with_defaults(self.chat_template_kwargs)
.chat_template_kwargs
)

def _validate_generator_input(
self,
Expand Down Expand Up @@ -601,13 +607,13 @@ async def _make_request(
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None,
)

chat_template_kwargs = self._effective_chat_template_kwargs(request)
_, engine_inputs = await self.openai_serving_render.preprocess_chat(
request,
messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
default_template_kwargs=chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=self.parser.tool_parser_cls if self.parser else None,
reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None,
Expand All @@ -626,13 +632,13 @@ async def _render_next_turn(
new_messages = construct_input_messages(
request_input=messages,
)

chat_template_kwargs = self._effective_chat_template_kwargs(request)
_, engine_inputs = await self.openai_serving_render.preprocess_chat(
request,
new_messages,
default_template=chat_template,
default_template_content_format=chat_template_content_format,
default_template_kwargs=None,
default_template_kwargs=chat_template_kwargs,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
reasoning_parser=self.parser.reasoning_parser_cls if self.parser else None,
Expand Down
Loading