Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions tests/test_chat_template_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for chat template kwargs forwarding."""

from unittest.mock import MagicMock, patch

from fastapi.testclient import TestClient

import vllm_mlx.server as srv
from vllm_mlx.engine.base import GenerationOutput


def test_chat_completion_request_preserves_chat_template_kwargs():
    """The request model round-trips chat_template_kwargs unchanged."""
    template_kwargs = {"enable_thinking": False}
    req = srv.ChatCompletionRequest(
        model="test-model",
        messages=[srv.Message(role="user", content="Hello")],
        chat_template_kwargs=template_kwargs,
    )

    # Pydantic must keep the dict's contents intact, not drop or coerce it.
    assert req.chat_template_kwargs == {"enable_thinking": False}


def test_batched_engine_applies_chat_template_kwargs():
    """Extra chat-template kwargs are forwarded to tokenizer.apply_chat_template."""
    with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False):
        # Import inside the patch so module-level model detection is stubbed.
        from vllm_mlx.engine.batched import BatchedEngine

        engine = BatchedEngine("test-model")
        tokenizer = MagicMock()
        tokenizer.apply_chat_template.return_value = "prompt"
        engine._tokenizer = tokenizer

        rendered = engine._apply_chat_template(
            [{"role": "user", "content": "Hello"}],
            chat_template_kwargs={"enable_thinking": False},
        )

        assert rendered == "prompt"
        tokenizer.apply_chat_template.assert_called_once()
        # The custom kwarg must reach the tokenizer call verbatim.
        forwarded = tokenizer.apply_chat_template.call_args.kwargs
        assert forwarded["enable_thinking"] is False


def test_chat_completion_endpoint_forwards_chat_template_kwargs():
    """A chat_template_kwargs field in the HTTP body reaches the engine's chat()."""
    seen = {}

    class StubEngine:
        # Minimal attribute surface the endpoint reads from the engine.
        model_name = "test-model"
        is_mllm = False
        preserve_native_tool_format = False

        async def chat(self, messages, **kwargs):
            seen["messages"] = messages
            seen["kwargs"] = kwargs
            return GenerationOutput(
                text="ORBIT",
                prompt_tokens=4,
                completion_tokens=1,
                finish_reason="stop",
            )

    client = TestClient(srv.app)
    # Swap in the stub engine; restore the real globals no matter what.
    saved_engine, saved_name = srv._engine, srv._model_name
    srv._engine = StubEngine()
    srv._model_name = "test-model"
    try:
        payload = {
            "model": "test-model",
            "messages": [{"role": "user", "content": "Reply with ORBIT."}],
            "max_tokens": 8,
            "chat_template_kwargs": {"enable_thinking": False},
        }
        response = client.post("/v1/chat/completions", json=payload)
    finally:
        srv._engine = saved_engine
        srv._model_name = saved_name

    assert response.status_code == 200
    assert seen["kwargs"]["chat_template_kwargs"] == {"enable_thinking": False}
    assert response.json()["choices"][0]["message"]["content"] == "ORBIT"
3 changes: 3 additions & 0 deletions vllm_mlx/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import time
import uuid
from typing import Any

from pydantic import BaseModel, Field, computed_field

Expand Down Expand Up @@ -169,6 +170,8 @@ class ChatCompletionRequest(BaseModel):
tool_choice: str | dict | None = None # "auto", "none", or specific tool
# Structured output
response_format: ResponseFormat | dict | None = None
# Extra kwargs forwarded to tokenizer.apply_chat_template
chat_template_kwargs: dict[str, Any] | None = None
# MLLM-specific parameters
video_fps: float | None = None
video_max_frames: int | None = None
Expand Down
39 changes: 30 additions & 9 deletions vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def _apply_chat_template(
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
num_images: int = 0,
chat_template_kwargs: dict[str, Any] | None = None,
) -> str:
"""Apply chat template to messages.

Expand Down Expand Up @@ -367,19 +368,20 @@ def _apply_chat_template(
"tokenize": False,
"add_generation_prompt": True,
}
if tools:
if chat_template_kwargs:
template_kwargs.update(chat_template_kwargs)
if tools and "tools" not in template_kwargs:
template_kwargs["tools"] = tools

try:
return template_applicator.apply_chat_template(
messages, **template_kwargs
)
except TypeError as e:
# Some templates don't accept 'tools'; retry without them.
# Some templates don't accept extra kwargs; retry without them.
logger.debug(f"Chat template TypeError, retrying without extras: {e}")
for key in ["tools"]:
if key in template_kwargs:
del template_kwargs[key]
for key in ["tools", *(chat_template_kwargs or {}).keys()]:
template_kwargs.pop(key, None)
return template_applicator.apply_chat_template(
messages, **template_kwargs
)
Expand Down Expand Up @@ -620,12 +622,14 @@ async def chat(

# Convert tools for template
template_tools = convert_tools_for_template(tools) if tools else None
chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})

# Apply chat template
prompt = self._apply_chat_template(
messages,
template_tools,
num_images=len(all_images),
chat_template_kwargs=chat_template_kwargs,
)

return await self.generate(
Expand All @@ -639,7 +643,10 @@ async def chat(
)

def _compute_prefix_boundary(
self, messages: list[dict[str, Any]], tools: list[dict] | None = None
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
) -> int:
"""Compute token count for the shared prefix across message variations.

Expand All @@ -661,15 +668,23 @@ def _compute_prefix_boundary(
template_tools = convert_tools_for_template(tools) if tools else None

# Tokenize the real prompt
real_prompt = self._apply_chat_template(messages, template_tools)
real_prompt = self._apply_chat_template(
messages,
template_tools,
chat_template_kwargs=chat_template_kwargs,
)

# Build a dummy variant with different last user content
dummy_messages = list(messages)
dummy_messages[last_user_idx] = {
**messages[last_user_idx],
"content": "XXXXXXXXXX",
}
dummy_prompt = self._apply_chat_template(dummy_messages, template_tools)
dummy_prompt = self._apply_chat_template(
dummy_messages,
template_tools,
chat_template_kwargs=chat_template_kwargs,
)

tokenizer = self.tokenizer
if hasattr(tokenizer, "tokenizer"):
Expand Down Expand Up @@ -731,16 +746,22 @@ async def stream_chat(

# Convert tools for template
template_tools = convert_tools_for_template(tools) if tools else None
chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})

# Apply chat template
prompt = self._apply_chat_template(
messages,
template_tools,
num_images=len(all_images),
chat_template_kwargs=chat_template_kwargs,
)

# Compute prefix boundary for cache
prefix_boundary = self._compute_prefix_boundary(messages, tools)
prefix_boundary = self._compute_prefix_boundary(
messages,
tools,
chat_template_kwargs=chat_template_kwargs,
)
if prefix_boundary > 0:
kwargs["prefix_boundary"] = prefix_boundary

Expand Down
2 changes: 2 additions & 0 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
chat_kwargs["specprefill"] = request.specprefill
if request.specprefill_keep_pct is not None:
chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct
if request.chat_template_kwargs:
chat_kwargs["chat_template_kwargs"] = dict(request.chat_template_kwargs)

# Add tools if provided
if request.tools:
Expand Down