diff --git a/tests/test_chat_template_kwargs.py b/tests/test_chat_template_kwargs.py new file mode 100644 index 00000000..534dc6e0 --- /dev/null +++ b/tests/test_chat_template_kwargs.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for chat template kwargs forwarding.""" + +from unittest.mock import MagicMock, patch + +from fastapi.testclient import TestClient + +import vllm_mlx.server as srv +from vllm_mlx.engine.base import GenerationOutput + + +def test_chat_completion_request_preserves_chat_template_kwargs(): + request = srv.ChatCompletionRequest( + model="test-model", + messages=[srv.Message(role="user", content="Hello")], + chat_template_kwargs={"enable_thinking": False}, + ) + + assert request.chat_template_kwargs == {"enable_thinking": False} + + +def test_batched_engine_applies_chat_template_kwargs(): + with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False): + from vllm_mlx.engine.batched import BatchedEngine + + engine = BatchedEngine("test-model") + engine._tokenizer = MagicMock() + engine._tokenizer.apply_chat_template.return_value = "prompt" + + prompt = engine._apply_chat_template( + [{"role": "user", "content": "Hello"}], + chat_template_kwargs={"enable_thinking": False}, + ) + + assert prompt == "prompt" + engine._tokenizer.apply_chat_template.assert_called_once() + assert ( + engine._tokenizer.apply_chat_template.call_args.kwargs["enable_thinking"] + is False + ) + + +def test_chat_completion_endpoint_forwards_chat_template_kwargs(): + captured = {} + + class FakeEngine: + model_name = "test-model" + is_mllm = False + preserve_native_tool_format = False + + async def chat(self, messages, **kwargs): + captured["messages"] = messages + captured["kwargs"] = kwargs + return GenerationOutput( + text="ORBIT", + prompt_tokens=4, + completion_tokens=1, + finish_reason="stop", + ) + + client = TestClient(srv.app) + original_engine = srv._engine + original_model_name = srv._model_name + srv._engine = FakeEngine() + srv._model_name = "test-model" + try: + response = client.post( + "/v1/chat/completions", + json={ + "model": "test-model", + "messages": [{"role": "user", "content": "Reply with ORBIT."}], + "max_tokens": 8, + "chat_template_kwargs": {"enable_thinking": False}, + }, + ) + finally: + srv._engine = original_engine + srv._model_name = original_model_name + + assert response.status_code == 200 + assert captured["kwargs"]["chat_template_kwargs"] == {"enable_thinking": False} + assert response.json()["choices"][0]["message"]["content"] == "ORBIT" diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index 32b26e03..f7bcaaaa 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -11,6 +11,7 @@ import time import uuid +from typing import Any from pydantic import BaseModel, Field, computed_field @@ -169,6 +170,8 @@ class ChatCompletionRequest(BaseModel): tool_choice: str | dict | None = None # "auto", "none", or specific tool # Structured output response_format: ResponseFormat | dict | None = None + # Extra kwargs forwarded to tokenizer.apply_chat_template + chat_template_kwargs: dict[str, Any] | None = None # MLLM-specific parameters video_fps: float | None = None video_max_frames: int | None = None diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ce33e628..c30421e1 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -335,6 +335,7 @@ def _apply_chat_template( messages: list[dict[str, Any]], tools: list[dict] | None = None, num_images: int = 0, + chat_template_kwargs: dict[str, Any] | None = None, ) -> str: """Apply chat template to messages. @@ -367,7 +368,9 @@ def _apply_chat_template( "tokenize": False, "add_generation_prompt": True, } - if tools: + if chat_template_kwargs: + template_kwargs.update(chat_template_kwargs) + if tools and "tools" not in template_kwargs: template_kwargs["tools"] = tools try: @@ -375,11 +378,10 @@ def _apply_chat_template( messages, **template_kwargs ) except TypeError as e: - # Some templates don't accept 'tools'; retry without them. + # Some templates don't accept extra kwargs; retry without them. logger.debug(f"Chat template TypeError, retrying without extras: {e}") - for key in ["tools"]: - if key in template_kwargs: - del template_kwargs[key] + for key in ["tools", *(chat_template_kwargs or {}).keys()]: + template_kwargs.pop(key, None) return template_applicator.apply_chat_template( messages, **template_kwargs ) @@ -620,12 +622,14 @@ async def chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) # Apply chat template prompt = self._apply_chat_template( messages, template_tools, num_images=len(all_images), + chat_template_kwargs=chat_template_kwargs, ) return await self.generate( @@ -639,7 +643,10 @@ async def chat( ) def _compute_prefix_boundary( - self, messages: list[dict[str, Any]], tools: list[dict] | None = None + self, + messages: list[dict[str, Any]], + tools: list[dict] | None = None, + chat_template_kwargs: dict[str, Any] | None = None, ) -> int: """Compute token count for the shared prefix across message variations. @@ -661,7 +668,11 @@ def _compute_prefix_boundary( template_tools = convert_tools_for_template(tools) if tools else None # Tokenize the real prompt - real_prompt = self._apply_chat_template(messages, template_tools) + real_prompt = self._apply_chat_template( + messages, + template_tools, + chat_template_kwargs=chat_template_kwargs, + ) # Build a dummy variant with different last user content dummy_messages = list(messages) @@ -669,7 +680,11 @@ def _compute_prefix_boundary( **messages[last_user_idx], "content": "XXXXXXXXXX", } - dummy_prompt = self._apply_chat_template(dummy_messages, template_tools) + dummy_prompt = self._apply_chat_template( + dummy_messages, + template_tools, + chat_template_kwargs=chat_template_kwargs, + ) tokenizer = self.tokenizer if hasattr(tokenizer, "tokenizer"): @@ -731,16 +746,22 @@ async def stream_chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) # Apply chat template prompt = self._apply_chat_template( messages, template_tools, num_images=len(all_images), + chat_template_kwargs=chat_template_kwargs, ) # Compute prefix boundary for cache - prefix_boundary = self._compute_prefix_boundary(messages, tools) + prefix_boundary = self._compute_prefix_boundary( + messages, + tools, + chat_template_kwargs=chat_template_kwargs, + ) if prefix_boundary > 0: kwargs["prefix_boundary"] = prefix_boundary diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index a0038d5f..bec5b2b8 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1420,6 +1420,8 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re chat_kwargs["specprefill"] = request.specprefill if request.specprefill_keep_pct is not None: chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + if request.chat_template_kwargs: + chat_kwargs["chat_template_kwargs"] = dict(request.chat_template_kwargs) # Add tools if provided if request.tools: