85 changes: 85 additions & 0 deletions tests/entrypoints/test_chat_utils.py
@@ -19,6 +19,7 @@
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
@@ -37,6 +38,7 @@
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str)


@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")

tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])

chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}

model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)

# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)

# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs


# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
# yapf: disable
54 changes: 52 additions & 2 deletions vllm/entrypoints/chat_utils.py
@@ -11,7 +11,12 @@
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast)

import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
@@ -50,7 +55,7 @@
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils import random_uuid, supports_kw

logger = init_logger(__name__)

@@ -1548,6 +1553,46 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()


# Adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# Only the parse logic needed to resolve chat template kwargs is preserved.
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
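
A hedged illustration (not part of this diff): the extension's only job here is to let the sandboxed parser accept HF's `{% generation %}` markers, so templates that use them can still be inspected.

env = jinja2.sandbox.ImmutableSandboxedEnvironment(
    extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
# Parses without raising TemplateSyntaxError; rendering support is not
# needed, because only the parse tree is inspected downstream.
env.parse("{% generation %}{{ content }}{% endgeneration %}")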


def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}

env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)

# We exclude chat_template from the kwargs here because the chat
# template has already been resolved at this stage.
unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
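
A minimal usage sketch (illustrative only; `DummyTokenizer` is a hypothetical stand-in for a real HF tokenizer):

class DummyTokenizer:
    # Only the signature of apply_chat_template matters to supports_kw.
    def apply_chat_template(self, conversation, add_generation_prompt=False):
        return ""

resolved = resolve_chat_template_kwargs(
    DummyTokenizer(),
    chat_template="{% if enable_thinking %}<think>{% endif %}{{ messages }}",
    chat_template_kwargs={
        "enable_thinking": True,        # kept: referenced by the template
        "add_generation_prompt": True,  # kept: accepted by apply_chat_template
        "unknown_flag": 1,              # dropped: used by neither
    },
)
assert set(resolved) == {"enable_thinking", "add_generation_prompt"}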


def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
@@ -1573,12 +1618,17 @@ def apply_hf_chat_template(
)

try:
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
Comment on lines +1621 to +1625
Member: is it called by each request? do we need to cache it?

Member: Yeah, it's called on each request. This function doesn't compile the chat template, so I don't think it introduces much overhead. Given that the kwargs can vary with each request's extra body, caching isn't really applicable.

Member: Added a cache in #26227.

Contributor: Using an NPU (Ascend 910C), this costs about 10-20 ms per request, so it would be great to cache it.
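
As a hedged sketch of the caching idea (the actual fix landed in #26227; this is an illustration, not the merged implementation), the parse step can be memoized because the variable set depends only on the template string:

from functools import lru_cache

@lru_cache(maxsize=32)
def _cached_template_vars(chat_template: str) -> frozenset:
    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    # Parsing, not rendering, is the per-request cost discussed above; the
    # result depends only on the template string, so memoizing is safe.
    return frozenset(
        jinja2.meta.find_undeclared_variables(env.parse(chat_template)))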

It seems this change might be linked to a ~10x latency jump we saw in our benchmarks on 0.11.0 (where CPU hit 100%). The caching in 0.11.1 appears to solve it, but upgrading is a bit tricky since it requires CUDA ≥ 12.9.

Member: You can still use older CUDA versions if you install vLLM from source.

return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
**resolved_kwargs,
)

# External library exceptions can sometimes occur despite the framework's
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -1696,6 +1696,7 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.
10 changes: 7 additions & 3 deletions vllm/entrypoints/openai/cli_args.py
@@ -103,9 +103,13 @@ class FrontendArgs:
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to OpenAI
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer."""
response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None
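
A hedged client-side example of the new flag's effect (endpoint, API key, and model name are assumptions): with the default `trust_request_chat_template = False`, a request that carries its own chat template is refused.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
# Fails unless the server was started with --trust-request-chat-template:
client.chat.completions.create(
    model="Qwen/Qwen3-8B",
    messages=[{"role": "user", "content": "Hi"}],
    extra_body={"chat_template": "{{ messages }}"},
)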
14 changes: 13 additions & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -68,6 +68,7 @@ def __init__(
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "",
enable_auto_tools: bool = False,
Expand All @@ -89,6 +90,7 @@ def __init__(
self.response_role = response_role
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs

# set up tool use
@@ -220,6 +222,16 @@ async def create_chat_completion(

if not self.use_harmony:
# Common case.
request_chat_template = request.chat_template
chat_template_kwargs = request.chat_template_kwargs
if not self.trust_request_chat_template and (
request_chat_template is not None or
(chat_template_kwargs and
chat_template_kwargs.get("chat_template") is not None)):
return self.create_error_response(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template.")
(
conversation,
request_prompts,
Expand All @@ -228,7 +240,7 @@ async def create_chat_completion(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template=request_chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,