From a510953e50e849e9a70d61ad5d2ab3139bf9e3aa Mon Sep 17 00:00:00 2001
From: Jan Hilgard
Date: Thu, 12 Feb 2026 19:46:20 +0100
Subject: [PATCH 01/45] Add resumable model download with retry, timeout, and
 offline mode

Large model downloads via huggingface_hub often hang or fail around 10GB.
This adds a pre-download step with configurable retry/timeout before
load_model() is called, so interrupted downloads can be resumed.

New CLI flags for `serve`: --download-timeout, --download-retries, --offline
New subcommand: `vllm-mlx download <model>` for pre-warming caches

Closes #75

Co-Authored-By: Claude Opus 4.6
---
 tests/test_download.py      | 196 ++++++++++++++++++++++++++++++++++++
 vllm_mlx/cli.py             |  74 ++++++++++++++
 vllm_mlx/utils/__init__.py  |   3 +-
 vllm_mlx/utils/download.py  | 144 ++++++++++++++++++++++++++
 vllm_mlx/utils/tokenizer.py |  13 +--
 5 files changed, 420 insertions(+), 10 deletions(-)
 create mode 100644 tests/test_download.py
 create mode 100644 vllm_mlx/utils/download.py

diff --git a/tests/test_download.py b/tests/test_download.py
new file mode 100644
index 000000000..9eba711bb
--- /dev/null
+++ b/tests/test_download.py
@@ -0,0 +1,196 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for resumable model download with retry/timeout support."""
+
+import os
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from vllm_mlx.utils.download import (
+    LLM_ALLOW_PATTERNS,
+    MLLM_ALLOW_PATTERNS,
+    DownloadConfig,
+    ensure_model_downloaded,
+)
+
+
+class TestLocalPath:
+    """Tests for local path handling."""
+
+    def test_local_path_skips_download(self, tmp_path):
+        """Existing local directory is returned without downloading."""
+        with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
+            result = ensure_model_downloaded(str(tmp_path))
+            mock_download.assert_not_called()
+            assert result == tmp_path
+
+
+class TestRetryLogic:
+    """Tests for download retry behavior."""
+
+    def test_retry_on_failure(self):
+        """Failed downloads are retried up to max_retries times."""
+        config = DownloadConfig(max_retries=3, retry_backoff_base=0.01)
+        fake_path = "/fake/cache/path"
+
+        with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
+            mock_download.side_effect = [
+                ConnectionError("timeout"),
+                ConnectionError("timeout"),
+                fake_path,
+            ]
+            result = ensure_model_downloaded("org/model", config=config)
+            assert result == Path(fake_path)
+            assert mock_download.call_count == 3
+
+    def test_retry_exhaustion(self):
+        """RuntimeError is raised after all retries are exhausted."""
+        config = DownloadConfig(max_retries=2, retry_backoff_base=0.01)
+
+        with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
+            mock_download.side_effect = ConnectionError("timeout")
+            with pytest.raises(RuntimeError, match="Failed to download"):
+                ensure_model_downloaded("org/model", config=config)
+            assert mock_download.call_count == 2
+
+    def test_keyboard_interrupt_not_retried(self):
+        """KeyboardInterrupt propagates immediately without retry."""
+        config = DownloadConfig(max_retries=3, retry_backoff_base=0.01)
+
+        with patch("vllm_mlx.utils.download.snapshot_download") as mock_download:
+            mock_download.side_effect = KeyboardInterrupt()
+            with pytest.raises(KeyboardInterrupt):
+                ensure_model_downloaded("org/model", config=config)
+            assert mock_download.call_count == 1
+
+
+class TestOfflineMode:
+    """Tests for offline mode behavior."""
+
+    def test_offline_mode_cached(self):
+        """Offline mode finds cached model successfully."""
+        config = 
DownloadConfig(offline=True) + fake_path = "/fake/cache/path" + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.return_value = fake_path + result = ensure_model_downloaded("org/model", config=config) + assert result == Path(fake_path) + mock_download.assert_called_once_with("org/model", local_files_only=True) + + def test_offline_mode_missing(self): + """Offline mode raises clear error when model is not cached.""" + config = DownloadConfig(offline=True) + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = Exception("not found locally") + with pytest.raises(RuntimeError, match="not found in local cache"): + ensure_model_downloaded("org/model", config=config) + + +class TestTimeout: + """Tests for download timeout configuration.""" + + def test_hf_timeout_env_set(self): + """HF_HUB_DOWNLOAD_TIMEOUT env var is set during download.""" + config = DownloadConfig(download_timeout=600, max_retries=1) + fake_path = "/fake/cache/path" + captured_timeout = {} + + original_env = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + + def capture_env(*args, **kwargs): + captured_timeout["value"] = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + return fake_path + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = capture_env + ensure_model_downloaded("org/model", config=config) + + assert captured_timeout["value"] == "600" + # Env var should be restored after download + assert os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") == original_env + + def test_hf_timeout_env_restored_on_failure(self): + """HF_HUB_DOWNLOAD_TIMEOUT is restored even after failure.""" + config = DownloadConfig( + download_timeout=999, max_retries=1, retry_backoff_base=0.01 + ) + original_env = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = ConnectionError("fail") + with pytest.raises(RuntimeError): + ensure_model_downloaded("org/model", config=config) + + assert os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") == original_env + + +class TestAllowPatterns: + """Tests for LLM vs MLLM download patterns.""" + + def test_llm_patterns_used_by_default(self): + """LLM allow patterns are used when is_mllm=False.""" + config = DownloadConfig(max_retries=1) + fake_path = "/fake/cache/path" + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.return_value = fake_path + ensure_model_downloaded("org/model", config=config, is_mllm=False) + mock_download.assert_called_once_with( + "org/model", allow_patterns=LLM_ALLOW_PATTERNS + ) + + def test_mllm_patterns_used(self): + """MLLM allow patterns are used when is_mllm=True.""" + config = DownloadConfig(max_retries=1) + fake_path = "/fake/cache/path" + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.return_value = fake_path + ensure_model_downloaded("org/model", config=config, is_mllm=True) + mock_download.assert_called_once_with( + "org/model", allow_patterns=MLLM_ALLOW_PATTERNS + ) + + +class TestCLIDownloadCommand: + """Tests for CLI download subcommand argument parsing.""" + + def test_cli_download_command(self): + """Download subcommand parses arguments correctly.""" + import argparse + + # We test argparse by calling parse_args directly + # (main() would try to actually run the command) + with patch("sys.argv", ["vllm-mlx", "download", "org/model"]): + parser = 
argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument("model") + download_parser.add_argument("--timeout", type=int, default=300) + download_parser.add_argument("--retries", type=int, default=3) + download_parser.add_argument("--mllm", action="store_true") + + args = parser.parse_args(["download", "org/model", "--timeout", "600"]) + assert args.command == "download" + assert args.model == "org/model" + assert args.timeout == 600 + assert args.retries == 3 + assert args.mllm is False + + def test_cli_download_mllm_flag(self): + """Download subcommand parses --mllm flag.""" + import argparse + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument("model") + download_parser.add_argument("--timeout", type=int, default=300) + download_parser.add_argument("--retries", type=int, default=3) + download_parser.add_argument("--mllm", action="store_true") + + args = parser.parse_args(["download", "org/vl-model", "--mllm"]) + assert args.mllm is True diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index dcbee8acd..4f50c6169 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -105,6 +105,21 @@ def serve_command(args): print(" Reasoning: Use --reasoning-parser to enable") print("=" * 60) + # Pre-download model with retry/timeout + from .api.utils import is_mllm_model + from .utils.download import DownloadConfig, ensure_model_downloaded + + download_config = DownloadConfig( + download_timeout=args.download_timeout, + max_retries=args.download_retries, + offline=getattr(args, "offline", False), + ) + ensure_model_downloaded( + args.model, + config=download_config, + is_mllm=is_mllm_model(args.model), + ) + print(f"Loading model: {args.model}") print(f"Default max tokens: {args.max_tokens}") @@ -194,6 +209,23 @@ def serve_command(args): uvicorn.run(app, host=args.host, port=args.port, log_level="info") +def download_command(args): + """Download a model to local cache without starting a server.""" + from .utils.download import DownloadConfig, ensure_model_downloaded + + config = DownloadConfig( + download_timeout=args.timeout, + max_retries=args.retries, + ) + print(f"Downloading model: {args.model}") + path = ensure_model_downloaded( + args.model, + config=config, + is_mllm=args.mllm, + ) + print(f"Model ready at: {path}") + + def bench_command(args): """Run benchmark.""" import asyncio @@ -827,6 +859,24 @@ def main(): default=None, help="Pre-load an embedding model at startup (e.g. 
mlx-community/embeddinggemma-300m-6bit)", ) + # Download options + serve_parser.add_argument( + "--download-timeout", + type=int, + default=300, + help="Per-file download timeout in seconds (default: 300)", + ) + serve_parser.add_argument( + "--download-retries", + type=int, + default=3, + help="Number of download retry attempts (default: 3)", + ) + serve_parser.add_argument( + "--offline", + action="store_true", + help="Offline mode — only use locally cached models", + ) # Bench command bench_parser = subparsers.add_parser("bench", help="Run benchmark") bench_parser.add_argument("model", type=str, help="Model to benchmark") @@ -962,6 +1012,28 @@ def main(): help="Quantization group size (default: 64)", ) + # Download command + download_parser = subparsers.add_parser( + "download", help="Download a model to local cache without starting a server" + ) + download_parser.add_argument("model", type=str, help="Model to download") + download_parser.add_argument( + "--timeout", + type=int, + default=300, + help="Per-file download timeout in seconds (default: 300)", + ) + download_parser.add_argument( + "--retries", + type=int, + default=3, + help="Number of retry attempts (default: 3)", + ) + download_parser.add_argument( + "--mllm", + action="store_true", + help="Download as multimodal model (broader file patterns)", + ) args = parser.parse_args() if args.command == "serve": @@ -972,6 +1044,8 @@ def main(): bench_detok_command(args) elif args.command == "bench-kv-cache": bench_kv_cache_command(args) + elif args.command == "download": + download_command(args) else: parser.print_help() sys.exit(1) diff --git a/vllm_mlx/utils/__init__.py b/vllm_mlx/utils/__init__.py index e808515ad..14d5de5c8 100644 --- a/vllm_mlx/utils/__init__.py +++ b/vllm_mlx/utils/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Utility modules for vllm-mlx.""" +from .download import DownloadConfig, ensure_model_downloaded from .tokenizer import load_model_with_fallback -__all__ = ["load_model_with_fallback"] +__all__ = ["DownloadConfig", "ensure_model_downloaded", "load_model_with_fallback"] diff --git a/vllm_mlx/utils/download.py b/vllm_mlx/utils/download.py new file mode 100644 index 000000000..39941c7af --- /dev/null +++ b/vllm_mlx/utils/download.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Resumable model download with retry/timeout support. + +Pre-downloads models via huggingface_hub.snapshot_download() with +configurable timeout and retry logic before passing to mlx-lm/mlx-vlm. +""" + +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path + +from huggingface_hub import snapshot_download + +logger = logging.getLogger(__name__) + +# Mirrors mlx_lm.utils._download() default allow_patterns +LLM_ALLOW_PATTERNS = [ + "*.json", + "model*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + "tiktoken.model", + "*.txt", + "*.jsonl", + "*.jinja", +] + +# Mirrors mlx_vlm.utils.get_model_path() allow_patterns +MLLM_ALLOW_PATTERNS = [ + "*.json", + "*.safetensors", + "*.py", + "*.model", + "*.tiktoken", + "*.txt", + "*.jinja", +] + + +@dataclass +class DownloadConfig: + """Configuration for model download behavior.""" + + download_timeout: int = 300 + max_retries: int = 3 + retry_backoff_base: float = 2.0 + offline: bool = False + + +def ensure_model_downloaded( + model_name: str, + config: DownloadConfig | None = None, + is_mllm: bool = False, +) -> Path: + """ + Ensure a model is available locally, downloading with retry if needed. 
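+
+    A local directory path is returned as-is without touching the Hub. An
+    interrupted download can simply be re-run: snapshot_download() reuses
+    whatever is already complete in the local cache instead of starting over.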
+ + Args: + model_name: HuggingFace model name or local path. + config: Download configuration. Uses defaults if None. + is_mllm: If True, use MLLM download patterns (broader file set). + + Returns: + Path to the local model directory. + + Raises: + RuntimeError: If download fails after all retries. + KeyboardInterrupt: Propagated immediately without retry. + """ + if config is None: + config = DownloadConfig() + + model_path = Path(model_name) + if model_path.exists(): + logger.info(f"Model found at local path: {model_path}") + return model_path + + if config.offline: + logger.info(f"Offline mode: looking for cached {model_name}") + try: + result = Path(snapshot_download(model_name, local_files_only=True)) + logger.info(f"Found cached model at {result}") + return result + except Exception as e: + raise RuntimeError( + f"Model '{model_name}' not found in local cache. " + f"Download it first without --offline flag." + ) from e + + allow_patterns = MLLM_ALLOW_PATTERNS if is_mllm else LLM_ALLOW_PATTERNS + + # Set HF download timeout via environment variable + old_timeout = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = str(config.download_timeout) + + last_error = None + try: + for attempt in range(1, config.max_retries + 1): + try: + logger.info( + f"Downloading model {model_name} " + f"(attempt {attempt}/{config.max_retries}, " + f"timeout={config.download_timeout}s)" + ) + result = Path( + snapshot_download( + model_name, + allow_patterns=allow_patterns, + ) + ) + logger.info(f"Model downloaded successfully to {result}") + return result + except KeyboardInterrupt: + logger.warning("Download interrupted by user.") + raise + except Exception as e: + last_error = e + if attempt < config.max_retries: + wait = config.retry_backoff_base**attempt + logger.warning( + f"Download attempt {attempt} failed: {e}. " + f"Retrying in {wait:.0f}s..." + ) + time.sleep(wait) + else: + logger.error( + f"Download failed after {config.max_retries} attempts." + ) + + raise RuntimeError( + f"Failed to download '{model_name}' after {config.max_retries} " + f"attempts. Last error: {last_error}\n" + f"Run the same command again to resume the download." 
+ ) + finally: + # Restore original env var + if old_timeout is None: + os.environ.pop("HF_HUB_DOWNLOAD_TIMEOUT", None) + else: + os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = old_timeout diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index 3c12e7ae2..e5c29f884 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -9,7 +9,6 @@ import json import logging -from pathlib import Path from .chat_templates import DEFAULT_CHATML_TEMPLATE, NEMOTRON_CHAT_TEMPLATE @@ -65,16 +64,12 @@ def _load_with_tokenizer_fallback(model_name: str): """Load model with fallback tokenizer for non-standard models like Nemotron.""" from mlx_lm.utils import load_model - logger.info("Loading with tokenizer fallback...") + from .download import ensure_model_downloaded - # Get model path - use local path if it exists, otherwise download from Hub - local_path = Path(model_name) - if local_path.is_dir(): - model_path = local_path - else: - from huggingface_hub import snapshot_download + logger.info("Loading with tokenizer fallback...") - model_path = Path(snapshot_download(model_name)) + # Get model path (with retry/timeout support) + model_path = ensure_model_downloaded(model_name, is_mllm=False) # Load model model, _ = load_model(model_path) From c6cf2186005ef56996b50f8bfec5f1f68c80ccaf Mon Sep 17 00:00:00 2001 From: Manus McAuliffe Date: Sat, 28 Mar 2026 03:06:51 +0000 Subject: [PATCH 02/45] fix: populate tokens field in BatchedEngine.generate() The output_token_ids from AsyncEngineCore were tracked internally but never forwarded to GenerationOutput, leaving tokens always []. Also adds tests for the generate() output fields. Co-Authored-By: Claude Sonnet 4.6 --- tests/test_batched_engine.py | 90 ++++++++++++++++++++++++++++++++++++ vllm_mlx/engine/batched.py | 1 + 2 files changed, 91 insertions(+) create mode 100644 tests/test_batched_engine.py diff --git a/tests/test_batched_engine.py b/tests/test_batched_engine.py new file mode 100644 index 000000000..65abbf09e --- /dev/null +++ b/tests/test_batched_engine.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for BatchedEngine generate() output.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestBatchedEngineGenerate: + """Test BatchedEngine.generate() output fields.""" + + def _make_engine(self): + """Create a BatchedEngine instance with loading bypassed.""" + from vllm_mlx.engine.batched import BatchedEngine + + with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False): + engine = BatchedEngine("test-model") + + engine._loaded = True + engine._is_mllm = False + return engine + + def _make_mock_request_output( + self, + output_text="Paris", + output_token_ids=None, + prompt_tokens=10, + completion_tokens=3, + finish_reason="stop", + ): + """Build a mock RequestOutput (as returned by AsyncEngineCore).""" + mock = MagicMock() + mock.output_text = output_text + mock.output_token_ids = output_token_ids if output_token_ids is not None else [3681, 374, 279] + mock.prompt_tokens = prompt_tokens + mock.completion_tokens = completion_tokens + mock.finish_reason = finish_reason + return mock + + @pytest.mark.asyncio + async def test_tokens_field_is_populated(self): + """tokens should contain the output token IDs from AsyncEngineCore.""" + engine = self._make_engine() + token_ids = [3681, 374, 279] + mock_output = self._make_mock_request_output(output_token_ids=token_ids) + + mock_engine = MagicMock() + mock_engine.generate = AsyncMock(return_value=mock_output) + 
engine._engine = mock_engine
+
+        result = await engine.generate(prompt="What is the capital of France?", max_tokens=10)
+
+        assert result.tokens == token_ids
+
+    @pytest.mark.asyncio
+    async def test_tokens_field_empty_when_no_tokens_generated(self):
+        """tokens should be an empty list when output_token_ids is empty."""
+        engine = self._make_engine()
+        mock_output = self._make_mock_request_output(output_token_ids=[])
+
+        mock_engine = MagicMock()
+        mock_engine.generate = AsyncMock(return_value=mock_output)
+        engine._engine = mock_engine
+
+        result = await engine.generate(prompt="test", max_tokens=10)
+
+        assert result.tokens == []
+
+    @pytest.mark.asyncio
+    async def test_other_output_fields_still_populated(self):
+        """Existing fields (text, prompt_tokens, etc.) must remain correct."""
+        engine = self._make_engine()
+        mock_output = self._make_mock_request_output(
+            output_text="Paris",
+            output_token_ids=[3681],
+            prompt_tokens=7,
+            completion_tokens=1,
+            finish_reason="stop",
+        )
+
+        mock_engine = MagicMock()
+        mock_engine.generate = AsyncMock(return_value=mock_output)
+        engine._engine = mock_engine
+
+        result = await engine.generate(prompt="Capital of France?", max_tokens=5)
+
+        assert result.text == "Paris"
+        assert result.prompt_tokens == 7
+        assert result.completion_tokens == 1
+        assert result.finish_reason == "stop"
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index ce33e628e..49fae5439 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -492,6 +492,7 @@ async def generate(
 
         return GenerationOutput(
             text=text,
+            tokens=output.output_token_ids,
             prompt_tokens=output.prompt_tokens,
             completion_tokens=output.completion_tokens,
             finish_reason=output.finish_reason,

From 59fb02057c79937bc1c508d6a43853f05104e934 Mon Sep 17 00:00:00 2001
From: Stuart Swerdloff
Date: Sun, 29 Mar 2026 22:40:07 +1300
Subject: [PATCH 03/45] feat: add MiniMax tool call parsing support

Parse MiniMax-M2.5's XML tool call format:
<minimax:tool_call><invoke name="fn"><parameter name="arg">value</parameter></invoke></minimax:tool_call>

Handles single/multiple tool calls, JSON parameter values, no-parameter
calls, and preserves <think> blocks. 9 unit tests included.
---
 tests/test_minimax_tool_calling.py | 130 +++++++++++++++++++++++++++++
 vllm_mlx/api/tool_calling.py       |  42 ++++++++++
 2 files changed, 172 insertions(+)
 create mode 100644 tests/test_minimax_tool_calling.py

diff --git a/tests/test_minimax_tool_calling.py b/tests/test_minimax_tool_calling.py
new file mode 100644
index 000000000..ad329aa03
--- /dev/null
+++ b/tests/test_minimax_tool_calling.py
@@ -0,0 +1,130 @@
+"""Tests for MiniMax tool call parsing."""
+
+import json
+import unittest
+
+from vllm_mlx.api.tool_calling import parse_tool_calls
+
+
+class TestMiniMaxToolCallParsing(unittest.TestCase):
+    """Test parsing of MiniMax-style tool calls."""
+
+    def test_single_tool_call(self):
+        text = '''
+<minimax:tool_call>
+<invoke name="get_weather">
+<parameter name="city">Wanaka</parameter>
+<parameter name="units">celsius</parameter>
+</invoke>
+</minimax:tool_call>
+'''
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(len(tool_calls), 1)
+        self.assertEqual(tool_calls[0].function.name, "get_weather")
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args["city"], "Wanaka")
+        self.assertEqual(args["units"], "celsius")
+        self.assertEqual(cleaned, "")
+
+    def test_tool_call_with_surrounding_text(self):
+        text = '''Let me check the weather for you.
+
+<minimax:tool_call>
+<invoke name="get_weather">
+<parameter name="city">Wanaka</parameter>
+</invoke>
+</minimax:tool_call>
+'''
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(len(tool_calls), 1)
+        self.assertIn("Let me check", cleaned)
+
+    def test_multiple_tool_calls(self):
+        text = '''
+<minimax:tool_call>
+<invoke name="search">
+<parameter name="query">MiniMax M2.5</parameter>
+</invoke>
+</minimax:tool_call>
+
+<minimax:tool_call>
+<invoke name="read_file">
+<parameter name="path">/tmp/test.txt</parameter>
+</invoke>
+</minimax:tool_call>
+'''
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(len(tool_calls), 2)
+        self.assertEqual(tool_calls[0].function.name, "search")
+        self.assertEqual(tool_calls[1].function.name, "read_file")
+
+    def test_json_parameter_value(self):
+        text = '''
+<minimax:tool_call>
+<invoke name="create_event">
+<parameter name="title">Meeting</parameter>
+<parameter name="attendees">["stuart", "frida"]</parameter>
+</invoke>
+</minimax:tool_call>
+'''
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args["title"], "Meeting")
+        self.assertEqual(args["attendees"], ["stuart", "frida"])
+
+    def test_numeric_parameter(self):
+        text = '''
+<minimax:tool_call>
+<invoke name="set_value">
+<parameter name="value">42</parameter>
+</invoke>
+</minimax:tool_call>
+'''
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args["value"], 42)
+
+    def test_no_parameters(self):
+        text = '''
+<minimax:tool_call>
+<invoke name="get_time">
+</invoke>
+</minimax:tool_call>
+'''
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertEqual(tool_calls[0].function.name, "get_time")
+        args = json.loads(tool_calls[0].function.arguments)
+        self.assertEqual(args, {})
+
+    def test_with_think_tags_preserved(self):
+        text = '''<think>
+I should check the weather first.
+</think>
+
+<minimax:tool_call>
+<invoke name="get_weather">
+<parameter name="city">Wanaka</parameter>
+</invoke>
+</minimax:tool_call>
+'''
+
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNotNone(tool_calls)
+        self.assertIn("<think>", cleaned)
+
+    def test_no_minimax_tool_calls(self):
+        text = "Just a regular message with no tool calls."
+        cleaned, tool_calls = parse_tool_calls(text)
+        self.assertIsNone(tool_calls)
+        self.assertEqual(cleaned, text)
+
+    def test_tool_call_id_format(self):
+        text = '''
+<minimax:tool_call>
+<invoke name="ping">
+<parameter name="count">1</parameter>
+</invoke>
+</minimax:tool_call>
+'''
+
+        _, tool_calls = parse_tool_calls(text)
+        self.assertTrue(tool_calls[0].id.startswith("call_"))
+        self.assertEqual(tool_calls[0].type, "function")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py
index 1443c1674..364b65993 100644
--- a/vllm_mlx/api/tool_calling.py
+++ b/vllm_mlx/api/tool_calling.py
@@ -89,6 +89,7 @@ def parse_tool_calls(
     Parse tool calls from model output.
 
     Supports multiple formats:
+    - MiniMax: <minimax:tool_call><invoke name="fn"><parameter name="arg">v</parameter></invoke></minimax:tool_call>
    - Qwen3 bracket: [Calling tool: function_name({"arg": "value"})]
    - Qwen: {"name": "...", "arguments": {...}}
    - Llama: {"arg": "value"}
@@ -106,6 +107,47 @@ def parse_tool_calls(
     tool_calls = []
     cleaned_text = text
 
+    # Pattern for MiniMax-style: <invoke name="fn"><parameter name="arg">v</parameter></invoke>
+    minimax_pattern = r"<minimax:tool_call>\s*(.*?)\s*</minimax:tool_call>"
+    minimax_matches = re.findall(minimax_pattern, text, re.DOTALL)
+
+    for invoke_block in minimax_matches:
+        # Parse <invoke> blocks within the tool_call
+        invoke_pattern = r'<invoke name="([^"]+)">(.*?)</invoke>'
+        invoke_matches = re.findall(invoke_pattern, invoke_block, re.DOTALL)
+
+        for name, params_block in invoke_matches:
+            # Parse <parameter name="...">value</parameter> pairs
+            param_pattern = r'<parameter name="([^"]+)">\s*(.*?)\s*</parameter>'
+            params = re.findall(param_pattern, params_block, re.DOTALL)
+            arguments = {}
+            for p_name, p_value in params:
+                # Try to parse value as JSON (for nested objects/arrays/numbers)
+                try:
+                    arguments[p_name] = json.loads(p_value)
+                except (json.JSONDecodeError, ValueError):
+                    arguments[p_name] = p_value
+
+            tool_calls.append(
+                ToolCall(
+                    id=f"call_{uuid.uuid4().hex[:8]}",
+                    type="function",
+                    function=FunctionCall(
+                        name=name.strip(),
+                        arguments=json.dumps(arguments),
+                    ),
+                )
+            )
+
+    # Remove MiniMax tool call tags from cleaned text
+    if minimax_matches:
+        cleaned_text = re.sub(
+            r"<minimax:tool_call>\s*.*?\s*</minimax:tool_call>",
+            "",
+            cleaned_text,
+            flags=re.DOTALL,
+        ).strip()
+
     # Pattern for Qwen3 bracket-style: [Calling tool: function_name({...})]
     bracket_pattern = r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]"
     bracket_matches = re.findall(bracket_pattern, text, re.DOTALL)

From d34b758c591751bd18bcc5c0588a2abf22f5e523 Mon Sep 17 00:00:00 2001
From: Penumbra Forge
Date: Sun, 29 Mar 2026 15:01:55 -0700
Subject: [PATCH 04/45] perf(reasoning): replace O(N) text scanning with O(1)
 state machine in streaming parser
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The streaming reasoning parser (BaseThinkingReasoningParser) scans the
full accumulated output text for <think>/</think> on every token via `in`
checks on previous_text and current_text. This is O(N) per token and
O(N²) over a full generation, becoming measurable at longer outputs
(5ms+ at 2k tokens, 141ms at 10k tokens).

Replace with a three-phase state machine (pre_think → thinking → content)
that tracks transitions using only the delta text. Each token is now O(1)
regardless of output length.

Benchmark results (streaming parser overhead, simulated server loop):

  Tokens   Old (scan)   New (state)   Speedup
  ------   ----------   -----------   -------
     500       0.37ms        0.04ms      8.6x
    1000       1.38ms        0.10ms     13.5x
    2000       5.28ms        0.28ms     19.1x
    5000      34.03ms        2.05ms     16.6x
   10000     141.26ms       10.16ms     13.9x

At 50 tok/s decode on Apple Silicon, each token has a 20ms budget. The
old parser consumed 0.3ms/tok at 2k tokens and 1.4ms/tok at 10k — up to
7% of the budget on overhead alone. The new parser is <0.01ms/tok at any
length.

Changes:
- think_parser.py: Rewrote extract_reasoning_streaming() as a state
  machine with _phase tracking. reset_state() initializes the phase.
  All three scenarios preserved (explicit tags, implicit mode, no tags).
  Method signature unchanged for backward compatibility.
- benchmarks/bench_reasoning_parser.py: Added streaming parser benchmark.

No changes to extract_reasoning() (non-streaming path) — it only runs
once per request and is not on the hot path.
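
A sketch of the difference (illustrative only; the real logic lives in
extract_reasoning_streaming() in think_parser.py):

    # Old: membership checks over all accumulated text, O(len(text)) per token
    if parser.end_token in current_text:
        ...

    # New: one phase comparison plus a search limited to the delta, O(len(delta))
    if self._phase == "thinking" and parser.end_token in delta_text:
        self._phase = "content"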
---
 benchmarks/bench_reasoning_parser.py |  55 ++++++++
 vllm_mlx/reasoning/think_parser.py   | 203 ++++++++++++---------
 2 files changed, 146 insertions(+), 112 deletions(-)
 create mode 100644 benchmarks/bench_reasoning_parser.py

diff --git a/benchmarks/bench_reasoning_parser.py b/benchmarks/bench_reasoning_parser.py
new file mode 100644
index 000000000..c7a2ba13f
--- /dev/null
+++ b/benchmarks/bench_reasoning_parser.py
@@ -0,0 +1,55 @@
+"""Benchmark: reasoning parser streaming performance.
+
+Measures per-token overhead of extract_reasoning_streaming() at various
+output lengths. Demonstrates the difference between O(N²) accumulated
+text scanning and O(1) state-machine tracking.
+
+Usage:
+    python benchmarks/bench_reasoning_parser.py
+"""
+
+import time
+
+from vllm_mlx.reasoning.qwen3_parser import Qwen3ReasoningParser
+
+
+def bench_streaming(parser, n_tokens: int, label: str) -> float:
+    """Simulate n_tokens of streaming through the parser. Returns total ms."""
+    parser.reset_state()
+
+    # Simulate: <think> + N reasoning tokens + </think> + 10 content tokens
+    tokens = ["<think>"]
+    tokens += [f"word{i} " for i in range(n_tokens)]
+    tokens += ["</think>"]
+    tokens += [f"answer{i} " for i in range(10)]
+
+    accumulated = ""
+    start = time.perf_counter()
+    for tok in tokens:
+        prev = accumulated
+        accumulated += tok
+        parser.extract_reasoning_streaming(prev, accumulated, tok)
+    elapsed = (time.perf_counter() - start) * 1000
+
+    print(f"  {label}: {n_tokens:>6} tokens -> {elapsed:>8.2f}ms "
+          f"({elapsed / (n_tokens + 12):.3f}ms/tok)")
+    return elapsed
+
+
+def main():
+    parser = Qwen3ReasoningParser()
+
+    print("Reasoning parser streaming benchmark")
+    print("=" * 60)
+    print()
+
+    for n in [50, 100, 200, 500, 1000, 2000, 5000]:
+        bench_streaming(parser, n, f"{n} tokens")
+
+    print()
+    print("At 50 tok/s, per-token budget is 20ms.")
+    print("Parser overhead should be <0.1ms/tok to be negligible.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm_mlx/reasoning/think_parser.py b/vllm_mlx/reasoning/think_parser.py
index 136348206..dad085f15 100644
--- a/vllm_mlx/reasoning/think_parser.py
+++ b/vllm_mlx/reasoning/think_parser.py
@@ -9,6 +9,11 @@
 1. Both tags in output: <think>reasoning</think>content
 2. Only closing tag (think injected in prompt): reasoning</think>content
 3. No tags: pure content
+
+Performance: The streaming parser uses a simple state machine to track the
+current phase (pre-think / thinking / content). Each token is classified in
+O(1) by checking only the delta text — the accumulated output is never
+rescanned. This keeps per-token overhead constant regardless of output length.
 """
 
 from abc import abstractmethod
@@ -27,8 +32,12 @@ class BaseThinkingReasoningParser(ReasoningParser):
     and only </think> appears in the model output. This is common with AI agents like
     OpenCode that force models to reason by injecting thinking tags.
 
-    The parser tracks state during streaming to correctly separate reasoning
-    from content as tokens arrive incrementally.
+    The streaming parser uses a state machine with three phases:
+
+        pre_think -> thinking -> content
+
+    Transitions happen when start/end tokens are detected in the delta text.
+    No accumulated text scanning is performed — each token is O(1).
     
""" @property @@ -43,6 +52,12 @@ def end_token(self) -> str: def __init__(self, tokenizer=None): super().__init__(tokenizer) + # Streaming state — reset per request via reset_state() + self._phase: str = "pre_think" # "pre_think" | "thinking" | "content" + + def reset_state(self): + """Reset state machine for a new streaming request.""" + self._phase = "pre_think" def extract_reasoning( self, @@ -66,14 +81,11 @@ def extract_reasoning( # Case 1: Both tags present (normal case) if self.start_token in text and self.end_token in text: - # Get everything after start token _, _, after_start = text.partition(self.start_token) - # Split on end token reasoning, _, content = after_start.partition(self.end_token) return reasoning.strip() or None, content.strip() or None # Case 2: Only closing tag (think was injected in prompt) - # Everything before is reasoning if self.end_token in text: reasoning, _, content = text.partition(self.end_token) return reasoning.strip() or None, content.strip() or None @@ -83,7 +95,7 @@ def extract_reasoning( _, _, reasoning = text.partition(self.start_token) return reasoning.strip() or None, None - # Case 4: No tags at all - pure content + # Case 4: No tags at all — pure content return None, model_output def extract_reasoning_streaming( @@ -93,123 +105,90 @@ def extract_reasoning_streaming( delta_text: str, ) -> DeltaMessage | None: """ - Extract reasoning from streaming delta using text-based detection. + Extract reasoning from a streaming delta using state-machine tracking. + + Instead of rescanning the full accumulated text on every token, this + method tracks the current phase (pre_think / thinking / content) and + only inspects the delta for tag transitions. This makes each call O(1) + regardless of how much text has been generated. + + The method signature is kept compatible with the base class — previous_text + and current_text are accepted but not used for phase detection (they remain + available for subclasses that need them). - Handles implicit reasoning mode where was in the prompt - and only appears in the output. + Handles three scenarios: + 1. Explicit ... in model output + 2. Implicit mode ( in prompt, only in output) + 3. No tags at all (pure content after first token with no reasoning) Args: - previous_text: Text accumulated before this delta. - current_text: Text including this delta. - delta_text: Just the new text. + previous_text: Text accumulated before this delta (unused by state machine). + current_text: Text including this delta (unused by state machine). + delta_text: Just the new text in this chunk. Returns: - DeltaMessage with reasoning/content, or None to skip. + DeltaMessage with reasoning and/or content, or None to skip. 
""" - # Skip if delta is just the special tokens themselves - stripped_delta = delta_text.strip() - if stripped_delta == self.start_token: - return None - if stripped_delta == self.end_token: + if not delta_text: return None - # Check token positions in text (stateless text-based detection) - start_in_prev = self.start_token in previous_text - start_in_current = self.start_token in current_text - end_in_prev = self.end_token in previous_text - end_in_delta = self.end_token in delta_text - - # Case 1: Explicit found in text - standard behavior - if start_in_current: - return self._handle_explicit_think( - previous_text, delta_text, start_in_prev, end_in_prev, end_in_delta - ) - - # Case 2: No but found - implicit reasoning mode - # This handles when was injected in the prompt - if self.end_token in current_text: - return self._handle_implicit_think(delta_text, end_in_prev, end_in_delta) - - # Case 3: No think tags seen yet - # We can't know if was in the prompt, so we must make a choice: - # - Treat as content (safe, but loses reasoning if think was in prompt) - # - Treat as reasoning (risky, wrong if no thinking at all) - # We choose to treat as reasoning IF we haven't seen yet, - # because if think was in prompt, we want to capture the reasoning. - # This will be corrected once is seen. - return DeltaMessage(reasoning=delta_text) - - def _handle_explicit_think( - self, - previous_text: str, - delta_text: str, - start_in_prev: bool, - end_in_prev: bool, - end_in_delta: bool, - ) -> DeltaMessage | None: - """Handle case where tag is explicitly in the output.""" - start_in_delta = self.start_token in delta_text - - if start_in_prev: - # We're after the start token - if end_in_delta: - # Transition: end token in this delta - idx = delta_text.find(self.end_token) - reasoning_part = delta_text[:idx] - content_part = delta_text[idx + len(self.end_token) :] + start_tok = self.start_token + end_tok = self.end_token + + # ── Phase: pre_think ────────────────────────────────────── + # Haven't seen any tags yet. 
Could be:
+        #   - About to see <think> (explicit reasoning)
+        #   - Already inside implicit reasoning (think was in prompt)
+        #   - No reasoning at all (pure content model)
+        if self._phase == "pre_think":
+            # Check for start tag in this delta
+            if start_tok in delta_text:
+                self._phase = "thinking"
+                idx = delta_text.find(start_tok) + len(start_tok)
+                after = delta_text[idx:]
+                # Edge case: both tags in same delta
+                if end_tok in after:
+                    self._phase = "content"
+                    eidx = after.find(end_tok)
+                    reasoning = after[:eidx]
+                    content = after[eidx + len(end_tok):]
+                    return DeltaMessage(
+                        reasoning=reasoning or None,
+                        content=content or None,
+                    )
+                return DeltaMessage(reasoning=after) if after else None
+
+            # Check for end tag (implicit mode — think was in prompt)
+            if end_tok in delta_text:
+                self._phase = "content"
+                idx = delta_text.find(end_tok)
+                reasoning = delta_text[:idx]
+                content = delta_text[idx + len(end_tok):]
+                return DeltaMessage(
+                    reasoning=reasoning or None,
+                    content=content or None,
+                )
+
+            # No tags — default to reasoning (implicit mode assumption).
+            # If the model doesn't use thinking at all, the server's
+            # non-parser path handles it. This path only activates when
+            # a reasoning parser is explicitly configured.
+            return DeltaMessage(reasoning=delta_text)
+
+        # ── Phase: thinking ───────────────────────────────────────
+        # Inside a reasoning block, waiting for end tag.
+        if self._phase == "thinking":
+            if end_tok in delta_text:
+                self._phase = "content"
+                idx = delta_text.find(end_tok)
+                reasoning = delta_text[:idx]
+                content = delta_text[idx + len(end_tok):]
+                return DeltaMessage(
+                    reasoning=reasoning or None,
+                    content=content or None,
+                )
+            return DeltaMessage(reasoning=delta_text)
+
+        # ── Phase: content ──────────────────────────────────────
+        # Past the reasoning block — everything is content.
+        
return DeltaMessage(content=delta_text) - - def _handle_implicit_think( - self, - delta_text: str, - end_in_prev: bool, - end_in_delta: bool, - ) -> DeltaMessage | None: - """Handle case where was in prompt (only in output).""" - if end_in_delta: - # Transition: end token in this delta - idx = delta_text.find(self.end_token) - reasoning_part = delta_text[:idx] - content_part = delta_text[idx + len(self.end_token) :] - return DeltaMessage( - reasoning=reasoning_part if reasoning_part else None, - content=content_part if content_part else None, - ) - elif end_in_prev: - # Already past reasoning phase - pure content - return DeltaMessage(content=delta_text) - else: - # Still in implicit reasoning phase - return DeltaMessage(reasoning=delta_text) From e8cb2326806b9074362884c03478f5f4310b162a Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 31 Mar 2026 18:43:16 -0500 Subject: [PATCH 05/45] fix: normalize messages before chat template application Add _normalize_messages() to server.py and call it in all request paths before apply_chat_template. Maps non-standard roles (developer -> system, per OpenAI Responses API) and merges consecutive same-role messages. Fixes agent crashes from: - OpenAI Responses API sending role="developer" (unrecognized by Qwen3.5 template) - OpenCode sending [system, system, user, user] (rejected by alternating-role templates) Applied in create_chat_completion (both MLLM and LLM paths), create_anthropic_message, and _stream_anthropic_messages. --- tests/test_normalize_messages.py | 176 +++++++++++++++++++++++++++++++ vllm_mlx/server.py | 62 +++++++++++ 2 files changed, 238 insertions(+) create mode 100644 tests/test_normalize_messages.py diff --git a/tests/test_normalize_messages.py b/tests/test_normalize_messages.py new file mode 100644 index 000000000..ae9061e7d --- /dev/null +++ b/tests/test_normalize_messages.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for _normalize_messages() in vllm_mlx.server. + +_normalize_messages() maps non-standard roles (developer -> system) and merges +consecutive same-role messages before chat template application. This prevents +crashes from Qwen 3.5 and Llama templates that require alternating roles. 
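+
+A shape-level example of the normalization (mirrors the tests below):
+
+    [developer, system, user, user]  ->  [system, user]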
+""" + +import pytest + + +class TestNormalizeMessages: + """Test _normalize_messages() for handling real-world client formats.""" + + def test_merge_consecutive_system_messages(self): + """Consecutive system messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "Always respond in JSON."}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert "helpful assistant" in result[0]["content"] + assert "JSON" in result[0]["content"] + assert result[1]["role"] == "user" + assert result[1]["content"] == "Hello" + + def test_merge_consecutive_user_messages(self): + """Consecutive user messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "First part"}, + {"role": "user", "content": "Second part"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[1]["role"] == "user" + assert "First part" in result[1]["content"] + assert "Second part" in result[1]["content"] + + def test_opencode_format(self): + """OpenCode's system+system+user+user format is normalized.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "System prompt part 1"}, + {"role": "system", "content": "System prompt part 2"}, + {"role": "user", "content": "User instruction"}, + {"role": "user", "content": "User question"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_developer_role_mapped_to_system(self): + """OpenAI Responses API 'developer' role is mapped to 'system'.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "developer", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_developer_and_system_merged(self): + """developer + system consecutive messages are merged after role mapping.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "developer", "content": "Part 1"}, + {"role": "system", "content": "Part 2"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert "Part 1" in result[0]["content"] + assert "Part 2" in result[0]["content"] + + def test_already_alternating_unchanged(self): + """Well-formed alternating messages pass through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "Bye"}, + ] + result = _normalize_messages(messages) + assert result == messages + + def test_single_message_unchanged(self): + """Single message passes through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [{"role": "user", "content": "Hello"}] + result = _normalize_messages(messages) + assert result == messages + + def test_empty_messages(self): + """Empty message list passes through.""" + from vllm_mlx.server import 
_normalize_messages + + assert _normalize_messages([]) == [] + + def test_multimodal_content_preserved(self): + """Messages with list content (multimodal) are not merged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "user", "content": "Describe this:"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": {"url": "http://example.com/img.png"}, + }, + ], + }, + ] + result = _normalize_messages(messages) + # List content can't be trivially merged with string - kept separate + assert len(result) >= 1 + + def test_preserves_non_content_fields(self): + """Fields other than role/content are preserved on the first merged message.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1", "name": "sys1"}, + {"role": "system", "content": "Part 2"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + + def test_null_content_not_merged(self): + """Messages with None content (tool_calls pattern) are not merged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]}, + {"role": "assistant", "content": "Follow-up"}, + ] + result = _normalize_messages(messages) + # None content can't be merged with string - kept separate + assert len(result) == 2 + + def test_three_consecutive_system_messages(self): + """Three consecutive system messages merge into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1"}, + {"role": "system", "content": "Part 2"}, + {"role": "system", "content": "Part 3"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert "Part 1" in result[0]["content"] + assert "Part 3" in result[0]["content"] diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index f0328d4e6..8de98efef 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1326,12 +1326,14 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re messages.append(msg_dict) images, videos = [], [] # MLLM extracts these from messages logger.debug(f"MLLM: Processing {len(messages)} messages") + messages = _normalize_messages(messages) else: # For LLM, extract text, images, and videos separately messages, images, videos = extract_multimodal_content( request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) has_media = bool(images or videos) @@ -1434,6 +1436,64 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) +def _normalize_messages(messages: list[dict]) -> list[dict]: + """Normalize message roles and merge consecutive same-role messages. + + 1. Maps non-standard roles to standard ones (e.g. ``developer`` -> ``system``). + 2. Merges consecutive same-role messages to satisfy chat template constraints + (Qwen 3.5, Llama, etc. require alternating roles). + + Only merges when both messages have string content. Messages with list + content (multimodal) are left as-is to preserve image/video attachments. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Returns: + New list with normalized roles and consecutive same-role messages merged. + """ + # OpenAI Responses API uses "developer" instead of "system". 
+ # Map it so chat templates don't fail and fall back to raw prefill. + _ROLE_MAP = {"developer": "system"} + + if not messages: + return messages + + merged = [messages[0].copy()] + if merged[0]["role"] in _ROLE_MAP: + merged[0]["role"] = _ROLE_MAP[merged[0]["role"]] + for msg in messages[1:]: + prev = merged[-1] + role = _ROLE_MAP.get(msg["role"], msg["role"]) + if ( + role == prev["role"] + and isinstance(prev.get("content"), str) + and isinstance(msg.get("content"), str) + ): + # Merge string content with double newline separator + prev["content"] = prev["content"] + "\n\n" + msg["content"] + logger.debug( + f"Merged consecutive {role} messages " + f"({len(prev['content'])} chars total)" + ) + else: + copy = msg.copy() + copy["role"] = role + merged.append(copy) + + mapped_roles = sum(1 for m in messages if m["role"] in _ROLE_MAP) + merged_count = len(messages) - len(merged) + if mapped_roles or merged_count: + parts = [] + if mapped_roles: + parts.append(f"mapped {mapped_roles} role(s)") + if merged_count: + parts.append(f"merged {len(messages)} -> {len(merged)}") + logger.info(f"Normalized messages: {', '.join(parts)}") + + return merged + + def _inject_json_instruction(messages: list, instruction: str) -> list: """ Inject JSON instruction into messages. @@ -1529,6 +1589,7 @@ async def create_anthropic_message( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens, @@ -1686,6 +1747,7 @@ async def _stream_anthropic_messages( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens, From af33ec9c35f2a36866a71893924d24de51464e69 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 31 Mar 2026 18:55:49 -0500 Subject: [PATCH 06/45] fix: remove unused pytest import --- tests/test_normalize_messages.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_normalize_messages.py b/tests/test_normalize_messages.py index ae9061e7d..3061692f5 100644 --- a/tests/test_normalize_messages.py +++ b/tests/test_normalize_messages.py @@ -7,7 +7,6 @@ crashes from Qwen 3.5 and Llama templates that require alternating roles. """ -import pytest class TestNormalizeMessages: From d19a8d3d84a33804f4dc91667e9527875a1498dd Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 31 Mar 2026 19:41:01 -0500 Subject: [PATCH 07/45] style: format test file with black --- tests/test_normalize_messages.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_normalize_messages.py b/tests/test_normalize_messages.py index 3061692f5..6d88437ad 100644 --- a/tests/test_normalize_messages.py +++ b/tests/test_normalize_messages.py @@ -8,7 +8,6 @@ """ - class TestNormalizeMessages: """Test _normalize_messages() for handling real-world client formats.""" From 705586b05df07a18656552c1293bcf6983d09ceb Mon Sep 17 00:00:00 2001 From: yiheng chen Date: Fri, 3 Apr 2026 13:45:44 -0400 Subject: [PATCH 08/45] feat: add Gemma 4 multimodal model support Add detection and inference support for Google's Gemma 4 models (e.g. mlx-community/gemma-4-e2b-it-mxfp4) which include vision and audio capabilities via mlx-vlm >= 0.4.3. 
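Example, using the checkpoint named above:

    vllm-mlx serve mlx-community/gemma-4-e2b-it-mxfp4
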
Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/reference/models.md | 5 +++-- pyproject.toml | 2 +- vllm_mlx/api/utils.py | 2 ++ vllm_mlx/engine/batched.py | 13 +++++++------ vllm_mlx/models/mllm.py | 2 ++ vllm_mlx/multimodal_processor.py | 2 +- 6 files changed, 16 insertions(+), 10 deletions(-) diff --git a/docs/reference/models.md b/docs/reference/models.md index a45550e4d..d378de003 100644 --- a/docs/reference/models.md +++ b/docs/reference/models.md @@ -12,7 +12,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun | Mistral / Devstral | 7B, Mixtral 8x7B | 4-bit, 8-bit | | Qwen2/Qwen3 | 0.5B to 72B | Various | | DeepSeek V3, R1 | 7B, 33B, 67B | 4-bit | -| Gemma 2, 3 | 2B, 9B, 27B | 4-bit | +| Gemma 2, 3, 4 | 2B, 9B, 27B | 4-bit | | GLM-4.7 | Flash, Base | 4-bit, 8-bit | | Kimi K2 | Various | 4-bit | | Phi-3 | 3.8B, 14B | 4-bit | @@ -35,6 +35,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun | **Qwen-VL** | `Qwen3-VL-4B-Instruct-3bit`, `Qwen3-VL-8B-Instruct-4bit`, `Qwen2-VL-2B/7B-Instruct-4bit` | | **LLaVA** | `llava-1.5-7b-4bit`, `llava-v1.6-mistral-7b-4bit`, `llava-llama-3-8b-v1_1-4bit` | | **Idefics** | `Idefics3-8B-Llama3-4bit`, `idefics2-8b-4bit` | +| **Gemma 4** | `gemma-4-e2b-it-mxfp4` (vision + audio) | | **PaliGemma** | `paligemma2-3b-mix-224-4bit`, `paligemma-3b-mix-224-8bit` | | **Pixtral** | `pixtral-12b-4bit`, `pixtral-12b-8bit` | | **Molmo** | `Molmo-7B-D-0924-4bit`, `Molmo-7B-D-0924-8bit` | @@ -72,7 +73,7 @@ vllm-mlx auto-detects multimodal models by name patterns: - Contains "VL", "Vision", "vision" - Contains "llava", "idefics", "paligemma" - Contains "pixtral", "molmo", "deepseek-vl" -- Contains "MedGemma", "Gemma-3" (vision variants) +- Contains "MedGemma", "Gemma-3", "Gemma-4" (multimodal variants) ## Using Models diff --git a/pyproject.toml b/pyproject.toml index 6ccc45282..87b1974df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ classifiers = [ dependencies = [ "mlx>=0.29.0", "mlx-lm>=0.31.0", # 0.31+ required for ArraysCache native batching (hybrid models) - "mlx-vlm>=0.1.0", # VLM support + "mlx-vlm>=0.4.3", # 0.4.3+ required for Gemma 4 support "transformers>=5.0.0", # mlx-lm 0.30.5+ requires transformers 5.0 (rc3 bug fixed in stable) "tokenizers>=0.19.0", "huggingface-hub>=0.23.0", diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 9fdbfef13..6dea67150 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -339,6 +339,8 @@ def flush(self) -> list[tuple[str, str]]: "PaliGemma", # PaliGemma "gemma-3", "gemma3", # Gemma 3 (multimodal) + "gemma-4", + "gemma4", # Gemma 4 (multimodal: vision + audio) "medgemma", "MedGemma", # MedGemma (medical multimodal with SigLIP vision encoder) "pixtral", diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 3ac52b4b0..51a76aad4 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -89,23 +89,24 @@ class MLLMModelWrapper: but MLLM models return LanguageModelOutput objects. This wrapper extracts the logits from the output. - Also handles Gemma 3's required pixel_values argument by injecting None + Also handles Gemma 3/4's required pixel_values argument by injecting None for text-only requests. 
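+
+    Illustrative use (a sketch; the call signature depends on the caller):
+
+        wrapped = MLLMModelWrapper(mllm_model)
+        logits = wrapped(input_ids)  # plain logits array, not LanguageModelOutput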
""" def __init__(self, model): self._model = model - # Detect if this is a Gemma 3 model (requires pixel_values as positional arg) - self._is_gemma3 = ( + # Detect if this is a Gemma 3/4 model (requires pixel_values as positional arg) + model_type = str(getattr(model, "model_type", "")).lower() + self._is_gemma_multimodal = ( hasattr(model, "model_type") - and "gemma3" in str(getattr(model, "model_type", "")).lower() + and ("gemma3" in model_type or "gemma4" in model_type) ) def __call__(self, *args, **kwargs): """Call the model and extract logits from LanguageModelOutput.""" - # Gemma 3 requires pixel_values as a positional argument, unlike Qwen + # Gemma 3/4 requires pixel_values as a positional argument, unlike Qwen # which makes it optional. Inject pixel_values=None for text-only requests. - if self._is_gemma3 and "pixel_values" not in kwargs: + if self._is_gemma_multimodal and "pixel_values" not in kwargs: kwargs["pixel_values"] = None output = self._model(*args, **kwargs) diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index fcf3537f4..a6c67226e 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -2091,6 +2091,8 @@ def is_mllm_model(model_name: str) -> bool: "PaliGemma", "gemma-3", "gemma3", # Gemma 3 (multimodal) + "gemma-4", + "gemma4", # Gemma 4 (multimodal: vision + audio) "medgemma", "MedGemma", # MedGemma (medical multimodal) "pixtral", diff --git a/vllm_mlx/multimodal_processor.py b/vllm_mlx/multimodal_processor.py index a5c861216..2905e9abb 100644 --- a/vllm_mlx/multimodal_processor.py +++ b/vllm_mlx/multimodal_processor.py @@ -147,7 +147,7 @@ def process( logger.warning(f"Failed to process video: {e}") # Determine add_special_tokens based on model type - if self.config and self.config.model_type in ["gemma3", "gemma3n"]: + if self.config and self.config.model_type in ["gemma3", "gemma3n", "gemma4"]: add_special_tokens = not hasattr(self.processor, "chat_template") # Prepare inputs using mlx_vlm From 98c682d87c25cb1b5b4660e7cd4641e78c5da634 Mon Sep 17 00:00:00 2001 From: Jan Hilgard Date: Mon, 6 Apr 2026 11:50:19 +0200 Subject: [PATCH 09/45] fix: Gemma 4 BatchKVCache, reasoning parser, and MLLM stop tokens - Patch gemma4 Attention to snapshot cache.offset before mutation (mx.array.__iadd__ is in-place, causes wrong RoPE positions) - Add Gemma 4 reasoning parser with channel name stripping (strips "thought"/"response" prefixes, supports both and <|channel>response transition formats) - Configure Gemma 4 EOS/stop tokens to prevent uncontrolled generation - Add 16 Gemma 4 parser tests (non-streaming + streaming) Co-Authored-By: Claude Opus 4.6 --- tests/test_reasoning_parser.py | 266 ++++++++++++++++++++++++++++ vllm_mlx/mllm_batch_generator.py | 67 ++++++- vllm_mlx/mllm_scheduler.py | 21 ++- vllm_mlx/patches/gemma4_mllm.py | 121 +++++++++++++ vllm_mlx/reasoning/__init__.py | 2 + vllm_mlx/reasoning/gemma4_parser.py | 170 ++++++++++++++++++ 6 files changed, 639 insertions(+), 8 deletions(-) create mode 100644 vllm_mlx/patches/gemma4_mllm.py create mode 100644 vllm_mlx/reasoning/gemma4_parser.py diff --git a/tests/test_reasoning_parser.py b/tests/test_reasoning_parser.py index e2d0184e7..4bcb5ab3f 100644 --- a/tests/test_reasoning_parser.py +++ b/tests/test_reasoning_parser.py @@ -6,6 +6,7 @@ - Parser registry (registration, lookup, listing) - Qwen3 parser (non-streaming and streaming) - DeepSeek-R1 parser (non-streaming and streaming) +- Gemma 4 parser (channel protocol, streaming, channel name stripping) - Edge cases (no tags, partial tags, etc.) 
""" @@ -28,6 +29,7 @@ def test_list_parsers_includes_builtin(self): parsers = list_parsers() assert "qwen3" in parsers assert "deepseek_r1" in parsers + assert "gemma4" in parsers def test_get_parser_qwen3(self): """Should be able to get Qwen3 parser.""" @@ -920,3 +922,267 @@ def test_constrain_tokens_stripped(self, parser): reasoning, content = parser.extract_reasoning(output) assert "<|constrain|>" not in (content or "") assert "<|channel|>" not in (content or "") + + +class TestGemma4Parser: + """Tests for the Gemma 4 reasoning parser (channel-based protocol).""" + + @pytest.fixture + def parser(self): + """Create a fresh Gemma 4 parser for each test.""" + return get_parser("gemma4")() + + # --- Non-streaming tests --- + + def test_extract_standard_format(self, parser): + """Standard format: <|channel>thought...response.""" + output = ( + "<|channel>thought\nLet me think step by step.\nThe answer is 42." + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Let me think step by step." + assert content == "The answer is 42." + + def test_extract_alternative_format(self, parser): + """Alternative format: <|channel>thought...<|channel>response...""" + output = "<|channel>thought\nAnalyzing the problem.\n<|channel>response\nThe result is 7." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Analyzing the problem." + assert content == "The result is 7." + + def test_extract_strips_thought_prefix(self, parser): + """Channel name 'thought' should be stripped from reasoning.""" + output = "<|channel>thought\nActual reasoning hereContent" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Actual reasoning here" + assert "thought" not in reasoning + + def test_extract_no_tags_pure_content(self, parser): + """No channel tags at all should return pure content.""" + output = "Just a regular response without thinking." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_extract_only_start_tag(self, parser): + """Only start tag means incomplete reasoning (no content yet).""" + output = "<|channel>thought\nStill thinking..." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Still thinking..." + assert content is None + + def test_extract_only_end_tag(self, parser): + """Only end tag (think injected in prompt).""" + output = "thought\nImplicit reasoningThe answer" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Implicit reasoning" + assert content == "The answer" + + def test_extract_empty_reasoning(self, parser): + """Empty reasoning should return None.""" + output = "<|channel>thought\nOnly content here." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == "Only content here." + + def test_extract_multiline_reasoning(self, parser): + """Should preserve multiline reasoning content.""" + output = ( + "<|channel>thought\n" + "Step 1: Understand the question.\n" + "Step 2: Analyze the data.\n" + "Step 3: Form conclusion.\n" + "The conclusion is clear." + ) + reasoning, content = parser.extract_reasoning(output) + assert "Step 1" in reasoning + assert "Step 2" in reasoning + assert "Step 3" in reasoning + assert content == "The conclusion is clear." 
+ + def test_extract_unicode_reasoning(self, parser): + """Should handle Unicode in reasoning.""" + output = "<|channel>thought\n日本語テスト 🤔\n答えは42" + reasoning, content = parser.extract_reasoning(output) + assert "日本語テスト" in reasoning + assert "🤔" in reasoning + assert "42" in content + + def test_registry_includes_gemma4(self): + """gemma4 should be in the parser registry.""" + assert "gemma4" in list_parsers() + + # --- Streaming tests --- + + def test_streaming_no_tags_plain_content(self, parser): + """Streaming without any channel tags should return content.""" + parser.reset_state() + result = parser.extract_reasoning_streaming("", "Hello", "Hello") + assert result is not None + assert result.content == "Hello" + assert result.reasoning is None + + def test_streaming_standard_format(self, parser): + """Test streaming through <|channel>thought...content flow.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Let me ", + "think.", + "", + "The ", + "answer.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + # "thought\n" prefix should be stripped + assert "thought" not in full_reasoning or "thought" in "Let me think." + assert "Let me think." in full_reasoning + assert "The answer." in full_content + + def test_streaming_alternative_format(self, parser): + """Test streaming with <|channel>response transition.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Analyzing.", + "<|channel>response", + "\n", + "Result: ", + "42", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_content = "".join(content_parts) + assert "Result: 42" in full_content + + def test_streaming_suppresses_channel_names(self, parser): + """Channel names 'thought' and 'response' should not appear in output.""" + parser.reset_state() + + # Simulate realistic Gemma 4 output + tokens = [ + "<|channel>", + "thought", + "\n", + "Real ", + "reasoning.", + "", + "Real ", + "content.", + ] + + accumulated = "" + all_output = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + all_output.append(("r", result.reasoning)) + if result.content: + all_output.append(("c", result.content)) + + # Verify no raw "thought" token leaked as reasoning + reasoning_text = "".join(t for tag, t in all_output if tag == "r") + content_text = "".join(t for tag, t in all_output if tag == "c") + + assert "Real reasoning." in reasoning_text + assert "Real content." in content_text + + def test_streaming_token_by_token(self, parser): + """Test character-by-character streaming (worst case).""" + parser.reset_state() + + output = "<|channel>thought\nStep 1: Think\nStep 2: Analyze\nFinal answer: 42." 
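+        # (one character per delta below, so channel tags arrive split)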
+ + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for char in output: + prev = accumulated + accumulated += char + result = parser.extract_reasoning_streaming(prev, accumulated, char) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + assert "Step 1: Think" in full_reasoning + assert "Step 2: Analyze" in full_reasoning + assert "Final answer: 42." in full_content + + def test_streaming_long_thinking_no_end_tag(self, parser): + """When model generates long thinking without end tag, all goes to reasoning.""" + parser.reset_state() + + # Simulate model that hits max_tokens before + tokens = [ + "<|channel>", + "thought", + "\n", + "This is a very long ", + "reasoning process ", + "that continues ", + "without ending.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + assert "very long reasoning process" in full_reasoning + assert len(content_parts) == 0 diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index ee8d8da7b..fa462ad88 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -324,6 +324,11 @@ def __init__( "MLLMBatchGenerator: Model does not have language_model, using model directly" ) + # Patch attention for BatchKVCache compatibility + from .patches.gemma4_mllm import patch_gemma4_attention_for_batching + + patch_gemma4_attention_for_batching() + self.max_tokens = max_tokens self.stop_tokens = stop_tokens or set() self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) @@ -340,6 +345,9 @@ def __init__( # Statistics self._stats = MLLMBatchStats() + # Error responses for requests that failed during preprocessing + self._pending_error_responses: List[MLLMBatchResponse] = [] + # Vision embedding cache for repeated images self.vision_cache = VisionEmbeddingCache( max_pixel_entries=vision_cache_size, @@ -666,7 +674,7 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: # KVCache.merge() creates a BatchKVCache with proper left-padding # alignment, so all requests share a single batched cache for # subsequent generation steps. - from mlx_lm.models.cache import KVCache + from mlx_lm.models.cache import KVCache, RotatingKVCache sample_cache = per_request_caches[0][0] if not isinstance(sample_cache, KVCache): @@ -676,6 +684,26 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: f"when using multimodal models with --continuous-batching." ) + # Fix: RotatingKVCache._update_concat does NOT trim on first call — + # if prompt length > max_size, the buffer grows beyond max_size. + # BatchRotatingKVCache.merge() then hits a shape mismatch when + # copying via _temporal_order (full buffer) into a max_size slice. + # Trim buffer to max_size before merging. 
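+        # (e.g. a 1200-token prompt with max_size=1024 leaves a 1200-wide
+        # buffer; the oldest 176 positions are dropped before merge)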
+ for rc in per_request_caches: + for layer_cache in rc: + if isinstance(layer_cache, RotatingKVCache): + if layer_cache.keys is not None: + buf_len = layer_cache.keys.shape[2] + if buf_len > layer_cache.max_size: + trim_size = buf_len - layer_cache.max_size + layer_cache.keys = layer_cache._trim( + trim_size, layer_cache.keys + ) + layer_cache.values = layer_cache._trim( + trim_size, layer_cache.values + ) + layer_cache._idx = layer_cache.max_size + try: batch_cache = [ per_request_caches[0][layer_idx].merge( @@ -764,15 +792,40 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None return [] - new_batch = self._process_prompts(requests) - self.unprocessed_requests = self.unprocessed_requests[len(requests) :] - self.active_batch = new_batch - prompt_processing = True + try: + new_batch = self._process_prompts(requests) + self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + self.active_batch = new_batch + prompt_processing = True + except Exception as e: + logger.error( + f"Failed to process batch of {len(requests)} prompts: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + # Remove failed requests to avoid infinite retry loop + self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + for req in requests: + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + # Collect any pending error responses (from failed preprocessing) + error_responses = [] + if self._pending_error_responses: + error_responses = list(self._pending_error_responses) + self._pending_error_responses.clear() # Generate next token for active batch batch = self.active_batch if batch is None: - return [] + return error_responses y, logprobs = batch.y, batch.logprobs batch.y, batch.logprobs = self._step(y[:, None], batch.cache) @@ -841,7 +894,7 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None self._stats.generation_tokens += len(responses) - return responses + return error_responses + responses def next(self) -> List[MLLMBatchResponse]: """ diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 555b230f2..9623ca27f 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -219,7 +219,7 @@ def __init__( self.total_completion_tokens = 0 def _get_stop_tokens(self) -> Set[int]: - """Get stop token IDs from tokenizer.""" + """Get stop token IDs from tokenizer and generation_config.json.""" stop_tokens = set() tokenizer = ( self.processor.tokenizer @@ -239,6 +239,25 @@ def _get_stop_tokens(self) -> Set[int]: else: stop_tokens.add(tokenizer.eos_token_ids) + # Also read generation_config.json which may have additional EOS tokens + # (e.g., Gemma 4 has =106, <|tool_response>=50 as EOS) + model_path = getattr(tokenizer, "name_or_path", None) + if model_path: + import json + from pathlib import Path + + gc_path = Path(model_path) / "generation_config.json" + if gc_path.exists(): + try: + gc = json.loads(gc_path.read_text()) + gc_eos = gc.get("eos_token_id") + if isinstance(gc_eos, list): + stop_tokens.update(gc_eos) + elif gc_eos is not None: + stop_tokens.add(gc_eos) + except Exception: + pass + return stop_tokens def _ensure_batch_generator(self) -> None: diff --git a/vllm_mlx/patches/gemma4_mllm.py b/vllm_mlx/patches/gemma4_mllm.py new file mode 100644 index 000000000..dc041cf31 --- /dev/null +++ b/vllm_mlx/patches/gemma4_mllm.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +""" 
+Runtime patch for mlx-vlm's Gemma 4 attention to support BatchKVCache. + +Gemma 4 Attention reads cache.offset into a local variable before calling +update_and_fetch, then uses the same variable later for RoPE on queries: + + offset = cache.offset # reference to mx.array([22]) + keys = self.rope(keys, offset=offset) + keys, values = cache.update_and_fetch(keys, values) + # ^^^ self.offset += 1 mutates the SAME mx.array in-place! + queries = self.rope(queries, offset=offset) # offset is now 23! + +For KVCache, cache.offset is a Python int (immutable), so the local copy +is unaffected. For BatchKVCache, cache.offset is an mx.array and +mx.array.__iadd__ is *in-place*, so the local reference is silently +mutated by update_and_fetch, giving queries the wrong RoPE position. + +This patch replaces Gemma4 Attention.__call__ with a version that +snapshots cache.offset as a defensive copy before any mutation can occur. +The mx.array copy preserves per-sequence offsets needed for correct RoPE +in continuous batching (unlike int conversion which would lose this info). +""" + +import logging +from typing import Any, Optional + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +def _snapshot_cache_offset(cache): + """Snapshot cache offset, making a defensive copy if it's an mx.array. + + BatchKVCache stores offset as mx.array (per-batch-item). + mx.array.__iadd__ is in-place, so update_and_fetch mutates the original. + We return a copy to preserve the pre-update value for RoPE on queries. + """ + if cache is None: + return 0 + off = cache.offset + if isinstance(off, int): + return off + if isinstance(off, mx.array): + return off + 0 # defensive copy — new array, same values + return off + + +def patch_gemma4_attention_for_batching() -> bool: + """Monkey-patch Gemma4 Attention.__call__ to snapshot offset before update. + + Returns True if patch was applied, False if mlx-vlm is not installed + or Gemma 4 module not available. + """ + try: + from mlx_vlm.models.gemma4.language import Attention as Gemma4Attention + from mlx_vlm.models.base import scaled_dot_product_attention + except ImportError: + logger.debug("[Gemma4 patch] mlx-vlm Gemma4 module not available") + return False + + if getattr(Gemma4Attention, "_batch_patched", False): + logger.debug("[Gemma4 patch] Already patched") + return True + + _orig_call = Gemma4Attention.__call__ + + def _patched_call( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, _ = x.shape + + queries = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim) + queries = self.q_norm(queries) + + # Snapshot offset BEFORE update_and_fetch can mutate it in-place. + # Preserves per-sequence mx.array offsets for correct batched RoPE. 
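+        # (e.g. a BatchKVCache offset of mx.array([17, 42]) tracks two
+        # sequences at different positions; converting to int would
+        # collapse that per-sequence information)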
+ offset = _snapshot_cache_offset(cache) + + if self.is_kv_shared_layer and cache is not None: + state = cache.state + keys, values = state[0], state[1] + else: + keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + if self.use_k_eq_v: + values = keys + else: + values = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + keys = self.k_norm(keys) + values = self.v_norm(values) + values = values.transpose(0, 2, 1, 3) + + keys = keys.transpose(0, 2, 1, 3) + keys = self.rope(keys, offset=offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + queries = queries.transpose(0, 2, 1, 3) + queries = self.rope(queries, offset=offset) + + if mask is not None and isinstance(mask, mx.array): + if mask.shape[-1] != keys.shape[-2]: + mask = mask[..., -keys.shape[-2] :] + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + Gemma4Attention.__call__ = _patched_call + Gemma4Attention._batch_patched = True + logger.info("[Gemma4 patch] Attention patched for BatchKVCache support") + return True diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py index f138796ff..49d13a26b 100644 --- a/vllm_mlx/reasoning/__init__.py +++ b/vllm_mlx/reasoning/__init__.py @@ -76,6 +76,7 @@ def list_parsers() -> list[str]: def _register_builtin_parsers(): """Register built-in parsers.""" from .deepseek_r1_parser import DeepSeekR1ReasoningParser + from .gemma4_parser import Gemma4ReasoningParser from .gpt_oss_parser import GptOssReasoningParser from .harmony_parser import HarmonyReasoningParser from .qwen3_parser import Qwen3ReasoningParser @@ -84,6 +85,7 @@ def _register_builtin_parsers(): register_parser("deepseek_r1", DeepSeekR1ReasoningParser) register_parser("gpt_oss", GptOssReasoningParser) register_parser("harmony", HarmonyReasoningParser) + register_parser("gemma4", Gemma4ReasoningParser) # Register built-in parsers on module load diff --git a/vllm_mlx/reasoning/gemma4_parser.py b/vllm_mlx/reasoning/gemma4_parser.py new file mode 100644 index 000000000..8b6dd8149 --- /dev/null +++ b/vllm_mlx/reasoning/gemma4_parser.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser for Gemma 4 models. + +Gemma 4 uses a channel-based protocol for reasoning: + + <|channel>thought + ...thinking content... + + ...response content... + +Where: + <|channel> = token 100 (channel switch marker) + = token 101 (end-of-channel marker) + +The channel names "thought" and "response" appear as text after the +special tokens and should be stripped from the output. + +Some model variants may use <|channel>response instead of +to transition from thinking to response mode. This parser handles both. + +When thinking is disabled or not triggered, output contains no tags. +""" + +from .base import DeltaMessage +from .think_parser import BaseThinkingReasoningParser + +# Channel names that follow <|channel> — stripped from output +_THOUGHT_PREFIX = "thought" +_RESPONSE_MARKER = "<|channel>response" + + +def _strip_channel_name(text: str, prefix: str) -> str: + """Strip channel name and leading whitespace/newline from text start.""" + if text.startswith(prefix): + text = text[len(prefix) :] + return text.lstrip("\n") + + +class Gemma4ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for Gemma 4 models. + + Handles two transition formats: + 1. 
<|channel>thought...response (standard: token 100 + 101) + 2. <|channel>thought...<|channel>response (alternative: token 100 + 100) + + Channel names ("thought", "response") are stripped from output. + + Example: + Input: "<|channel>thought\\nLet me think...The answer is 42." + Output: reasoning="Let me think...", content="The answer is 42." + + When no tags are present, the entire output is treated as content. + """ + + @property + def start_token(self) -> str: + return "<|channel>" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning from complete output. + + Handles both and <|channel>response as transition markers. + Strips channel names ("thought", "response") from output. + """ + text = model_output + + # Try standard format first: <|channel>thought...response + if self.start_token in text and self.end_token in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Try alternative format: <|channel>thought...<|channel>response... + if text.count(self.start_token) >= 2 and _RESPONSE_MARKER in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(_RESPONSE_MARKER) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.lstrip("\n").strip() + return reasoning or None, content or None + + # Only closing tag (think injected in prompt) + if self.end_token in text: + reasoning, _, content = text.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Only start tag (incomplete reasoning, no end yet) + if self.start_token in text: + _, _, reasoning = text.partition(self.start_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + return reasoning or None, None + + # No tags at all — pure content + return None, model_output + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """ + Extract reasoning from streaming delta. 
+ + Handles: + - No tags: treat as content (Gemma 4 doesn't inject tags in prompt) + - <|channel>thought: enter reasoning mode, strip channel name + - or <|channel>response: transition to content mode + """ + # No channel tokens at all — plain content + if self.start_token not in current_text and self.end_token not in current_text: + return DeltaMessage(content=delta_text) + + # Check for alternative transition: <|channel>response + if _RESPONSE_MARKER in current_text: + if _RESPONSE_MARKER not in previous_text: + # Transition happening in this delta + # Find what (if any) content comes after the marker + marker_pos = current_text.find(_RESPONSE_MARKER) + after_marker = current_text[marker_pos + len(_RESPONSE_MARKER) :] + after_marker = after_marker.lstrip("\n") + if after_marker: + return DeltaMessage(content=after_marker) + return None # Suppress the marker itself + else: + # Already past transition — pure content + # But we need to only emit the NEW text (delta) + return DeltaMessage(content=delta_text) + + # Delegate to base class for standard <|channel>/ handling + result = super().extract_reasoning_streaming( + previous_text, current_text, delta_text + ) + + # Strip "thought" channel name from initial reasoning + if result is not None and result.reasoning is not None: + r = result.reasoning + # First reasoning delta after <|channel> will be "thought" or "thought\n" + if self.start_token in current_text: + # Check if this is the very first reasoning content + after_channel = current_text.split(self.start_token, 1)[1] + if after_channel.startswith(_THOUGHT_PREFIX): + # Remove "thought" prefix from the accumulated reasoning so far + clean = after_channel[len(_THOUGHT_PREFIX) :].lstrip("\n") + # Compute what portion of clean text is in this delta + prev_after = "" + if self.start_token in previous_text: + prev_after = previous_text.split(self.start_token, 1)[1] + if prev_after.startswith(_THOUGHT_PREFIX): + prev_after = prev_after[len(_THOUGHT_PREFIX) :].lstrip("\n") + # The new reasoning text is clean minus what was already emitted + new_reasoning = clean[len(prev_after) :] + if new_reasoning: + return DeltaMessage(reasoning=new_reasoning) + return None # Suppress channel name token + + return result From dc2279d067a543aa19f5a47d1b75dd011e6ed181 Mon Sep 17 00:00:00 2001 From: Jack Neil Date: Wed, 8 Apr 2026 20:27:11 -0400 Subject: [PATCH 10/45] fix: RotatingKVCache support in MLLM batching and missing return in tokenizer - Accept RotatingKVCache (used by Gemma 4) in batch cache validation - Add missing return statement in load_model_with_fallback Co-Authored-By: Claude Opus 4.6 (1M context) --- vllm_mlx/mllm_batch_generator.py | 9 +++++---- vllm_mlx/utils/tokenizer.py | 2 ++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index fa462ad88..a8845c5e8 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -677,11 +677,12 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: from mlx_lm.models.cache import KVCache, RotatingKVCache sample_cache = per_request_caches[0][0] - if not isinstance(sample_cache, KVCache): + if not isinstance(sample_cache, (KVCache, RotatingKVCache)): raise ValueError( - f"MLLM continuous batching requires standard KVCache but got " - f"{type(sample_cache).__name__}. Disable --kv-cache-quantization " - f"when using multimodal models with --continuous-batching." 
+ f"MLLM continuous batching requires KVCache or RotatingKVCache " + f"but got {type(sample_cache).__name__}. Disable " + f"--kv-cache-quantization when using multimodal models with " + f"--continuous-batching." ) # Fix: RotatingKVCache._update_concat does NOT trim on first call — diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index a50883951..0cb4b5d82 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -67,6 +67,8 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): return _load_strict_false(model_name, tokenizer_config) raise + return model, tokenizer + def _load_strict_false(model_name: str, tokenizer_config: dict = None): """Load model with strict=False to discard extra weights (e.g., vision tower, MTP).""" From 0c47a679a656363efe8022a9ddaf1ea135482b05 Mon Sep 17 00:00:00 2001 From: Tim Perry Date: Fri, 10 Apr 2026 13:44:49 -0700 Subject: [PATCH 11/45] Upgrade mlx-vlm and torchvision so Qwen3.5 multimodal will run This depends on PR 215 or PR 243 being applied first. --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6ccc45282..2df59d437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ classifiers = [ dependencies = [ "mlx>=0.29.0", "mlx-lm>=0.31.0", # 0.31+ required for ArraysCache native batching (hybrid models) - "mlx-vlm>=0.1.0", # VLM support + "mlx-vlm>=0.4.4", # VLM support "transformers>=5.0.0", # mlx-lm 0.30.5+ requires transformers 5.0 (rc3 bug fixed in stable) "tokenizers>=0.19.0", "huggingface-hub>=0.23.0", @@ -44,7 +44,7 @@ dependencies = [ # Video processing for VLM "opencv-python>=4.8.0", # Vision processor (required for transformers AutoProcessor) - "torchvision>=0.18.0", + "torchvision>=0.21.0", # Resource monitoring "psutil>=5.9.0", # Server @@ -75,7 +75,7 @@ vllm = [ ] vision = [ "torch>=2.3.0", - "torchvision>=0.18.0", + "torchvision>=0.21.0", ] # Audio dependencies for TTS/STT (mlx-audio) audio = [ From 588d206abf6bcc24055635d8b19d9bf0d2299a1c Mon Sep 17 00:00:00 2001 From: Wayner Barrios Date: Fri, 10 Apr 2026 17:36:49 -0400 Subject: [PATCH 12/45] handle error finish_reason in mllm_scheduler, format batched.py error responses with token=0 were falling through to the detokenizer and decoding garbage text. now they skip decoding and set the request status to FINISHED_ABORTED. added a test for this case. also ran black on batched.py to fix CI. 
--- tests/test_mllm_continuous_batching.py | 47 ++++++++++++++++++++++++++ vllm_mlx/engine/batched.py | 5 ++- vllm_mlx/mllm_scheduler.py | 6 ++-- 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/tests/test_mllm_continuous_batching.py b/tests/test_mllm_continuous_batching.py index 28b26b219..3ff525dc6 100644 --- a/tests/test_mllm_continuous_batching.py +++ b/tests/test_mllm_continuous_batching.py @@ -129,6 +129,53 @@ def test_finished_response(self): assert resp.finish_reason == "stop" + def test_error_response_skips_decoding(self): + """Error responses must not decode token=0 as content.""" + from unittest.mock import MagicMock, PropertyMock + + from vllm_mlx.mllm_batch_generator import MLLMBatchResponse + from vllm_mlx.mllm_scheduler import MLLMScheduler + from vllm_mlx.request import RequestStatus + + # Build a minimal scheduler with mocked internals + scheduler = MLLMScheduler.__new__(MLLMScheduler) + scheduler._detokenizer_pool = {} + scheduler.uid_to_request_id = {0: "req-err"} + scheduler.total_completion_tokens = 0 + scheduler.num_requests_processed = 0 + + mock_tokenizer = MagicMock() + mock_tokenizer.decode.return_value = "" + mock_processor = MagicMock() + mock_processor.tokenizer = mock_tokenizer + scheduler.processor = mock_processor + + # Create a running request + mock_request = MagicMock() + mock_request.request_id = "req-err" + mock_request.output_tokens = [] + mock_request.num_output_tokens = 0 + mock_request.num_prompt_tokens = 10 + mock_request.status = RequestStatus.RUNNING + scheduler.running = {"req-err": mock_request} + + error_resp = MLLMBatchResponse( + uid=0, + request_id="req-err", + token=0, + logprobs=mx.array([0.0]), + finish_reason="error", + ) + + outputs, finished = scheduler._process_batch_responses([error_resp]) + + assert "req-err" in finished + assert mock_request.status == RequestStatus.FINISHED_ABORTED + # token=0 should not have been decoded through a detokenizer + assert "req-err" not in scheduler._detokenizer_pool + assert len(outputs) == 1 + assert outputs[0].new_text == "" + class TestMLLMBatch: """Tests for MLLMBatch class.""" diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 51a76aad4..e47cd4fc6 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -97,9 +97,8 @@ def __init__(self, model): self._model = model # Detect if this is a Gemma 3/4 model (requires pixel_values as positional arg) model_type = str(getattr(model, "model_type", "")).lower() - self._is_gemma_multimodal = ( - hasattr(model, "model_type") - and ("gemma3" in model_type or "gemma4" in model_type) + self._is_gemma_multimodal = hasattr(model, "model_type") and ( + "gemma3" in model_type or "gemma4" in model_type ) def __call__(self, *args, **kwargs): diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 9623ca27f..945992045 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -477,8 +477,8 @@ def _process_batch_responses( request.num_output_tokens = len(request.output_tokens) # Decode the new token using streaming detokenizer (UTF-8 safe). - # Skip stop tokens — they are not content. - if response.finish_reason == "stop": + # Skip stop tokens and error placeholders — they are not content. 
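+        # (error responses carry a placeholder token=0 that would otherwise
+        # decode to garbage text; see the error path in MLLMBatchGenerator)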
+ if response.finish_reason in ("stop", "error"): new_text = "" else: if request_id not in self._detokenizer_pool: @@ -508,6 +508,8 @@ def _process_batch_responses( request.status = RequestStatus.FINISHED_STOPPED elif response.finish_reason == "length": request.status = RequestStatus.FINISHED_LENGTH_CAPPED + elif response.finish_reason == "error": + request.status = RequestStatus.FINISHED_ABORTED output.finished = True output.finish_reason = response.finish_reason From 6f0efc2ee61fd3c974116f5cb329f4bd6dab47d2 Mon Sep 17 00:00:00 2001 From: Jan Hilgard Date: Fri, 10 Apr 2026 21:39:50 +0200 Subject: [PATCH 13/45] fix: patch Gemma 4 attention and RotatingKVCache for BatchKVCache - Fix BatchKVCache offset bug: mx.array.__iadd__ mutates in-place, causing incorrect RoPE positions and token repetition - Fix RotatingKVCache.max_size returning mx.array instead of int - Add Gemma 4 reasoning parser (--reasoning-parser gemma4) - Read additional EOS tokens from generation_config.json - Fix RotatingKVCache prefix cache extraction (negative left_padding) - Relax isinstance guard to accept RotatingKVCache for sliding window models like Gemma 4 (fixes ValueError on continuous batching) - Remove unused _make_batch_cache() dead code - Fix Anthropic endpoint JSON parsing for clients sending invalid escape sequences (e.g. \s, \d in regex patterns within tool defs) Co-Authored-By: Claude Opus 4.6 --- vllm_mlx/api/anthropic_models.py | 4 +- vllm_mlx/memory_cache.py | 176 ++++++++++++++++++++++++---- vllm_mlx/mllm_batch_generator.py | 77 ++++++------ vllm_mlx/server.py | 195 +++++++++++++++++++++++-------- 4 files changed, 347 insertions(+), 105 deletions(-) diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py index a5bc6f776..e8854a5e6 100644 --- a/vllm_mlx/api/anthropic_models.py +++ b/vllm_mlx/api/anthropic_models.py @@ -84,8 +84,10 @@ class AnthropicUsage(BaseModel): class AnthropicResponseContentBlock(BaseModel): """A content block in the Anthropic response.""" - type: str # "text" or "tool_use" + type: str # "text", "thinking", or "tool_use" text: str | None = None + # thinking block + thinking: str | None = None # tool_use fields id: str | None = None name: str | None = None diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index f43763541..111439330 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -255,44 +255,121 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry: def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: - """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced. + """Create copies of cache layers with the last ``trim_by`` positions removed. This is used when returning a cached KV state to the scheduler so that the last N positions are "freed" and the model will recompute them on the next forward pass (preventing duplicate KV entries). - Supports both KVCache (keys/values are arrays) and QuantizedKVCache - (keys/values are 3-tuples of arrays). - """ - from mlx_lm.models.cache import KVCache + For plain KVCache: reduces offset (surplus data beyond offset is harmless + since merge slices to ``keys[:, :, :offset, :]``). - try: - from mlx_lm.models.cache import QuantizedKVCache - except ImportError: - QuantizedKVCache = None # noqa: N806 + For RotatingKVCache: actually trims the circular buffer — reducing offset + alone breaks ``size()`` / ``_temporal_order`` invariants. + + Supports KVCache, RotatingKVCache, and _QuantizedCacheWrapper. 
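+
+    Plain-KVCache sketch: offset=100 with trim_by=3 yields a copy with
+    offset=97, so the next forward pass recomputes positions 97-99.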
+ """ + import mlx.core as mx + from mlx_lm.models.cache import RotatingKVCache trimmed: list[Any] = [] + eval_targets: list[Any] = [] for layer_cache in cache: - if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache): - tc = QuantizedKVCache.__new__(QuantizedKVCache) + if isinstance(layer_cache, _QuantizedCacheWrapper): + # Shallow copy with reduced offset + tc = _QuantizedCacheWrapper.__new__(_QuantizedCacheWrapper) tc.keys = layer_cache.keys tc.values = layer_cache.values tc.offset = max(layer_cache.offset - trim_by, 0) - tc.group_size = layer_cache.group_size tc.bits = layer_cache.bits + tc.group_size = layer_cache.group_size + tc.orig_type = layer_cache.orig_type + tc.orig_attrs = layer_cache.orig_attrs + trimmed.append(tc) + elif isinstance(layer_cache, RotatingKVCache): + if layer_cache.keys is None or trim_by <= 0: + trimmed.append(layer_cache) + continue + # RotatingKVCache: must trim buffer, not just offset. + # The buffer stores the last min(offset, max_size) tokens in a + # circular arrangement. Trimming excess positions from the END + # means removing the newest entries (chronologically last). + old_offset = layer_cache.offset + new_offset = max(old_offset - trim_by, 0) + old_size = min(old_offset, layer_cache.max_size) + entries_to_keep = max(0, old_size - trim_by) + + orig_cls = type(layer_cache) + tc = orig_cls.__new__(orig_cls) + tc.offset = new_offset + tc.max_size = layer_cache.max_size + tc.keep = getattr(layer_cache, "keep", 0) + tc.step = getattr(layer_cache, "step", layer_cache.max_size) + + if entries_to_keep <= 0: + # All buffer content is beyond the trim point — clear + tc.keys = None + tc.values = None + tc._idx = 0 + elif entries_to_keep < old_size: + # Reorder to temporal order, keep the oldest entries + ordered_k = layer_cache._temporal_order(layer_cache.keys) + ordered_v = layer_cache._temporal_order(layer_cache.values) + kept_k = ordered_k[:, :, :entries_to_keep, :] + kept_v = ordered_v[:, :, :entries_to_keep, :] + + if new_offset >= tc.max_size: + # Invariant: when offset >= max_size, buffer must be + # full (keys.shape[2] == max_size). Left-pad with + # zeros to restore the full buffer. Zeros represent + # positions evicted long ago; _idx = max_size so + # _temporal_order returns as-is and _update_in_place + # rotates to overwrite zeros first. 
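+                    # (e.g. max_size=1024 with entries_to_keep=900 pads
+                    # 124 zero slots to the left of the kept entries)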
+ pad_n = tc.max_size - entries_to_keep + pad_k = mx.zeros( + (kept_k.shape[0], kept_k.shape[1], pad_n, kept_k.shape[3]), + dtype=kept_k.dtype, + ) + pad_v = mx.zeros( + (kept_v.shape[0], kept_v.shape[1], pad_n, kept_v.shape[3]), + dtype=kept_v.dtype, + ) + tc.keys = mx.concatenate([pad_k, kept_k], axis=2) + tc.values = mx.concatenate([pad_v, kept_v], axis=2) + tc._idx = tc.max_size + else: + tc.keys = kept_k + tc.values = kept_v + tc._idx = entries_to_keep + eval_targets.extend([tc.keys, tc.values]) + else: + # No entries removed (trim_by == 0 already handled above, + # this covers entries_to_keep == old_size edge case) + tc.keys = layer_cache.keys + tc.values = layer_cache.values + tc._idx = layer_cache._idx trimmed.append(tc) elif ( hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys") and not isinstance(layer_cache.keys, (list, tuple)) ): - tc = KVCache.__new__(KVCache) + orig_cls = type(layer_cache) + tc = orig_cls.__new__(orig_cls) tc.keys = layer_cache.keys tc.values = layer_cache.values tc.offset = max(layer_cache.offset - trim_by, 0) + # Preserve type-specific attrs (max_size, keep, step, _idx) + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer_cache, attr): + setattr(tc, attr, getattr(layer_cache, attr)) trimmed.append(tc) else: trimmed.append(layer_cache) + + if eval_targets: + mx.eval(*eval_targets) + return trimmed @@ -353,28 +430,72 @@ def _trim_to_offset(cache: list[Any]) -> list[Any]: return trimmed +class _QuantizedCacheWrapper: + """Lightweight wrapper storing quantized KV arrays + original cache metadata. + + Unlike ``QuantizedKVCache``, this preserves enough info to reconstruct + the *original* cache type (KVCache, RotatingKVCache, etc.) on dequantize. + """ + + __slots__ = ( + "keys", + "values", + "offset", + "bits", + "group_size", + "orig_type", + "orig_attrs", + ) + + def __init__(self, layer: Any, bits: int, group_size: int): + import mlx.core as mx + + self.keys = mx.quantize(layer.keys, group_size=group_size, bits=bits) + self.values = mx.quantize(layer.values, group_size=group_size, bits=bits) + self.offset = layer.offset + self.bits = bits + self.group_size = group_size + self.orig_type = type(layer) + # Preserve RotatingKVCache-specific attrs + self.orig_attrs = {} + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer, attr): + self.orig_attrs[attr] = getattr(layer, attr) + + def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> list[Any]: - """Quantize KVCache layers to reduce memory. Non-KVCache layers are kept as-is.""" + """Quantize KV cache layers to reduce memory. + + Only plain KVCache layers are quantized. RotatingKVCache (sliding window) + is left as-is because its internal _idx/rotation state is tightly coupled + with update_and_fetch logic and cannot survive quantize/dequantize roundtrip. + RotatingKVCache is typically small (max_size=1024) so skipping it is fine. + """ from mlx_lm.models.cache import KVCache quantized = [] for layer in cache: - if isinstance(layer, KVCache) and layer.keys is not None: - quantized.append(layer.to_quantized(group_size=group_size, bits=bits)) + if type(layer) is KVCache and getattr(layer, "keys", None) is not None: + quantized.append(_QuantizedCacheWrapper(layer, bits, group_size)) else: quantized.append(layer) return quantized def _dequantize_cache(cache: list[Any]) -> list[Any]: - """Dequantize QuantizedKVCache layers back to regular KVCache.""" + """Dequantize _QuantizedCacheWrapper layers and copy non-quantized layers. 
+ + All layers are copied (never returned by reference) so that the model's + ``update_and_fetch`` mutations don't corrupt the stored cache entry. + """ import mlx.core as mx - from mlx_lm.models.cache import KVCache, QuantizedKVCache result = [] for layer in cache: - if isinstance(layer, QuantizedKVCache) and layer.keys is not None: - kv = KVCache() + if isinstance(layer, _QuantizedCacheWrapper): + # Reconstruct original cache type from quantized data + orig_cls = layer.orig_type + kv = orig_cls.__new__(orig_cls) kv.keys = mx.dequantize( *layer.keys, group_size=layer.group_size, bits=layer.bits ) @@ -382,6 +503,21 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]: *layer.values, group_size=layer.group_size, bits=layer.bits ) kv.offset = layer.offset + # Restore type-specific attrs (max_size, keep, step, _idx) + for attr, val in layer.orig_attrs.items(): + setattr(kv, attr, val) + result.append(kv) + elif hasattr(layer, "keys") and hasattr(layer, "offset"): + # Deep-copy non-quantized cache layers (e.g. RotatingKVCache) + # so model's in-place mutations don't corrupt stored entries + orig_cls = type(layer) + kv = orig_cls.__new__(orig_cls) + kv.keys = mx.array(layer.keys) if layer.keys is not None else None + kv.values = mx.array(layer.values) if layer.values is not None else None + kv.offset = layer.offset + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer, attr): + setattr(kv, attr, getattr(layer, attr)) result.append(kv) else: result.append(layer) diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index a8845c5e8..1de137587 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -156,15 +156,44 @@ def extend(self, other: "MLLMBatch") -> None: def extract_cache(self, idx: int) -> List[Any]: """ - Extract cache for a single request (for caching). + Extract cache for a single request (for prefix caching). - Args: - idx: Index of request in batch - - Returns: - Cache state for that request + Handles BatchRotatingKVCache negative left_padding bug: + during generation with rotation, left_padding becomes negative, + causing extract() to use Python negative indexing and truncate + the buffer to only generation tokens instead of the full window. 
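+
+        Sketch of the handling below: clamp left_padding to >= 0, roll a
+        rotated buffer back into temporal order, then slice off the padding.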
""" - return [c.extract(idx) if hasattr(c, "extract") else None for c in self.cache] + from mlx_lm.models.cache import ( + BatchRotatingKVCache, + RotatingKVCache, + ) + + result = [] + for c in self.cache: + if not hasattr(c, "extract"): + result.append(None) + elif isinstance(c, BatchRotatingKVCache): + # Custom extraction: clamp left_padding to >= 0 + cache = RotatingKVCache(c.max_size) + padding = max(0, c.left_padding[idx].item()) + offset = c.offset[idx].item() + cache.keys = c.keys[idx : idx + 1] + cache.values = c.values[idx : idx + 1] + cache._idx = c._idx + if c.rotated: + cache.keys = mx.roll(cache.keys, -c._idx, axis=2) + cache.values = mx.roll(cache.values, -c._idx, axis=2) + cache._idx = c.max_size + cache.keys = mx.contiguous(cache.keys[:, :, padding : cache._idx]) + cache.values = mx.contiguous(cache.values[:, :, padding : cache._idx]) + cache.offset = offset + cache._idx = cache.keys.shape[2] + cache.step = getattr(c, "step", c.max_size) + cache.keep = getattr(c, "keep", 0) + result.append(cache) + else: + result.append(c.extract(idx)) + return result class MLLMBatchStats: @@ -205,32 +234,6 @@ def to_dict(self) -> Dict[str, Any]: } -def _make_batch_cache(model: nn.Module, left_padding: List[int]) -> List[Any]: - """ - Create batch-aware KV cache for the language model. - - Args: - model: The language model (model.language_model from VLM) - left_padding: Padding amounts for left-padded prompts - - Returns: - List of BatchKVCache objects for each layer - """ - from mlx_lm.models.cache import BatchKVCache, KVCache - - def to_batch_cache(c): - if isinstance(c, KVCache): - return BatchKVCache(left_padding) - else: - raise ValueError(f"{type(c)} does not yet support batching") - - if hasattr(model, "make_cache"): - cache = model.make_cache() - return [to_batch_cache(c) for c in cache] - else: - return [BatchKVCache(left_padding) for _ in model.layers] - - def _left_pad_prompts( prompts: List[List[int]], max_length: Optional[int] = None ) -> mx.array: @@ -679,10 +682,10 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: sample_cache = per_request_caches[0][0] if not isinstance(sample_cache, (KVCache, RotatingKVCache)): raise ValueError( - f"MLLM continuous batching requires KVCache or RotatingKVCache " - f"but got {type(sample_cache).__name__}. Disable " - f"--kv-cache-quantization when using multimodal models with " - f"--continuous-batching." + f"MLLM continuous batching requires KVCache or " + f"RotatingKVCache but got {type(sample_cache).__name__}. " + f"Disable --kv-cache-quantization when using multimodal " + f"models with --continuous-batching." 
) # Fix: RotatingKVCache._update_concat does NOT trim on first call — diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index af10e7341..9b66bbe2f 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -42,6 +42,7 @@ import json import logging import os +import re import secrets import tempfile import threading @@ -57,8 +58,13 @@ # Import from new modular API # Re-export for backwards compatibility with tests -from .api.anthropic_adapter import anthropic_to_openai, openai_to_anthropic -from .api.anthropic_models import AnthropicRequest +from .api.anthropic_adapter import anthropic_to_openai +from .api.anthropic_models import ( + AnthropicRequest, + AnthropicResponse, + AnthropicResponseContentBlock, + AnthropicUsage, +) from .api.models import ( AssistantMessage, # noqa: F401 ChatCompletionChoice, # noqa: F401 @@ -1535,6 +1541,17 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: # ============================================================================= +def _convert_anthropic_stop_reason(openai_reason: str | None) -> str: + """Convert OpenAI finish_reason to Anthropic stop_reason.""" + mapping = { + "stop": "end_turn", + "tool_calls": "tool_use", + "length": "max_tokens", + "content_filter": "end_turn", + } + return mapping.get(openai_reason or "", "end_turn") + + @app.post("/v1/messages") async def create_anthropic_message( request: Request, @@ -1549,8 +1566,19 @@ async def create_anthropic_message( """ engine = get_engine() - # Parse the raw body to handle Anthropic request format - body = await request.json() + # Parse the raw body to handle Anthropic request format. + # Some clients (e.g. Claude Code) may send JSON with invalid escape + # sequences like \s, \d in regex patterns within tool definitions. + # Python's json.loads is strict per RFC 8259 and rejects these. 
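+    # Example: a tool schema containing {"pattern": "\s+"} fails with
+    # json.JSONDecodeError: Invalid \escape.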
+ try: + body = await request.json() + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + raw = await request.body() + # Replace lone backslashes (not valid JSON escapes) with \\ + body = json.loads(re.sub(rb'\\(?!["\\/bfnrtu])', rb"\\\\", raw)) + else: + raise anthropic_request = AnthropicRequest(**body) _validate_model_name(anthropic_request.model) @@ -1627,35 +1655,63 @@ async def create_anthropic_message( output.text, openai_request ) + # Extract reasoning if parser is configured + reasoning_text = None + if _reasoning_parser and not tool_calls: + text_to_parse = cleaned_text or output.text + reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning( + text_to_parse + ) + # Clean output text final_content = None if cleaned_text: final_content = clean_output_text(cleaned_text) - # Determine finish reason - finish_reason = "tool_calls" if tool_calls else output.finish_reason + # Build Anthropic content blocks directly (with thinking support) + content_blocks = [] - # Build OpenAI response to convert - openai_response = ChatCompletionResponse( - model=_model_name, - choices=[ - ChatCompletionChoice( - message=AssistantMessage( - content=final_content, - tool_calls=tool_calls, - ), - finish_reason=finish_reason, + if reasoning_text: + content_blocks.append( + AnthropicResponseContentBlock(type="thinking", thinking=reasoning_text) + ) + + if final_content: + content_blocks.append( + AnthropicResponseContentBlock(type="text", text=final_content) + ) + + if tool_calls: + for tc in tool_calls: + try: + tool_input = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError): + tool_input = {} + content_blocks.append( + AnthropicResponseContentBlock( + type="tool_use", + id=tc.id, + name=tc.function.name, + input=tool_input, + ) ) - ], - usage=Usage( - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - total_tokens=output.prompt_tokens + output.completion_tokens, - ), + + if not content_blocks: + content_blocks.append(AnthropicResponseContentBlock(type="text", text="")) + + stop_reason = _convert_anthropic_stop_reason( + "tool_calls" if tool_calls else output.finish_reason ) - # Convert to Anthropic response - anthropic_response = openai_to_anthropic(openai_response, _model_name) + anthropic_response = AnthropicResponse( + model=_model_name, + content=content_blocks, + stop_reason=stop_reason, + usage=AnthropicUsage( + input_tokens=output.prompt_tokens, + output_tokens=output.completion_tokens, + ), + ) return Response( content=anthropic_response.model_dump_json(exclude_none=True), media_type="application/json", @@ -1836,26 +1892,39 @@ async def _stream_anthropic_messages( # Stream pipeline: raw text → tool call filter → think router → emit # - Tool call filter strips tool call markup (emitted as structured blocks later) - # - Think router separates content into Anthropic thinking blocks + # - Think router separates reasoning from content into Anthropic blocks + # + # When a reasoning parser is configured (e.g. --reasoning-parser gemma4), + # it replaces the generic StreamingThinkRouter to handle model-specific + # reasoning formats (e.g. Gemma 4 <|channel>thought...). accumulated_text = "" + use_reasoning_parser = _reasoning_parser is not None tool_filter = StreamingToolCallFilter() - # Detect if the model's chat template injects into the - # generation prompt. If so, the model starts in thinking mode and - # the opening tag never appears in the output stream. 
- _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None - _chat_template = "" - if _tokenizer and hasattr(_tokenizer, "chat_template"): - _chat_template = _tokenizer.chat_template or "" - _starts_thinking = ( - "" in _chat_template and "add_generation_prompt" in _chat_template - ) - think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking) + + if use_reasoning_parser: + _reasoning_parser.reset_state() + think_router = None + else: + # Detect if the model's chat template injects into the + # generation prompt. If so, the model starts in thinking mode and + # the opening tag never appears in the output stream. + _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None + _chat_template = "" + if _tokenizer and hasattr(_tokenizer, "chat_template"): + _chat_template = _tokenizer.chat_template or "" + _starts_thinking = ( + "" in _chat_template and "add_generation_prompt" in _chat_template + ) + think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking) + prompt_tokens = 0 completion_tokens = 0 # Track which content blocks we've started current_block_type = None # "thinking" or "text" block_index = 0 + # For reasoning parser: track accumulated text for parser context + reasoning_accumulated = "" async for output in engine.stream_chat(messages=messages, **chat_kwargs): delta_text = output.new_text @@ -1878,30 +1947,62 @@ async def _stream_anthropic_messages( filtered = tool_filter.process(content) if not filtered: continue - # Stage 2: route thinking vs text - pieces = think_router.process(filtered) + + if use_reasoning_parser: + # Stage 2a: use reasoning parser for model-specific formats + prev = reasoning_accumulated + reasoning_accumulated += filtered + delta_msg = _reasoning_parser.extract_reasoning_streaming( + prev, reasoning_accumulated, filtered + ) + if delta_msg is None: + continue + pieces = [] + if delta_msg.reasoning: + pieces.append(("thinking", delta_msg.reasoning)) + if delta_msg.content: + pieces.append(("text", delta_msg.content)) + else: + # Stage 2b: generic tag router + pieces = think_router.process(filtered) + events, current_block_type, block_index = _emit_content_pieces( pieces, current_block_type, block_index ) for event in events: yield event - # Flush remaining from both filters + # Flush remaining from tool filter remaining = tool_filter.flush() if remaining: + if use_reasoning_parser: + prev = reasoning_accumulated + reasoning_accumulated += remaining + delta_msg = _reasoning_parser.extract_reasoning_streaming( + prev, reasoning_accumulated, remaining + ) + pieces = [] + if delta_msg: + if delta_msg.reasoning: + pieces.append(("thinking", delta_msg.reasoning)) + if delta_msg.content: + pieces.append(("text", delta_msg.content)) + else: + pieces = think_router.process(remaining) events, current_block_type, block_index = _emit_content_pieces( - think_router.process(remaining), current_block_type, block_index + pieces, current_block_type, block_index ) for event in events: yield event - flush_pieces = think_router.flush() - if flush_pieces: - events, current_block_type, block_index = _emit_content_pieces( - flush_pieces, current_block_type, block_index - ) - for event in events: - yield event + if not use_reasoning_parser: + flush_pieces = think_router.flush() + if flush_pieces: + events, current_block_type, block_index = _emit_content_pieces( + flush_pieces, current_block_type, block_index + ) + for event in events: + yield event # Close final content block if current_block_type is not None: From 
c50d7db0771daa6161ce93322269c1ac68331b5b Mon Sep 17 00:00:00 2001 From: Jack Neil <25892470+jackneil@users.noreply.github.com> Date: Fri, 10 Apr 2026 18:22:05 -0400 Subject: [PATCH 14/45] feat: add Gemma 4 tool call parser (#269) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: add Gemma 4 tool parser tests (red) Co-Authored-By: Claude Opus 4.6 (1M context) * feat: add Gemma 4 tool call parser Co-Authored-By: Claude Opus 4.6 (1M context) * feat: register Gemma 4 parser, add streaming tests and wiring Co-Authored-By: Claude Opus 4.6 (1M context) * test: add edge case tests for DC review findings - Unclosed tool call block (server fallback path) - String containing colon (step-ordering guard) - String with real newline and double quote (JSON escaping) Co-Authored-By: Claude Opus 4.6 (1M context) * test: verify Gemma 4 tool calls produce exact OpenAI format for Claude Code Integration tests that verify the full pipeline (parser → server models → JSON serialization) matches what Claude Code expects: tool_calls structure, null content, function.arguments as JSON string, correct finish_reason. Co-Authored-By: Claude Opus 4.6 (1M context) * add Gemma 4 auto-detection to AutoToolParser integrates Gemma 4 format as the first format tried in auto-detection, adds streaming markers for tool call start/end. based on keegoid's approach in #254. * remove unused pytest imports * run black on tool parser, tests, and server --------- Co-authored-by: Jack Neil Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Wayner Barrios --- tests/test_gemma4_openai_format.py | 160 +++++++++++++ tests/test_gemma4_tool_parser.py | 240 ++++++++++++++++++++ tests/test_native_tool_format.py | 2 + tests/test_tool_parsers.py | 3 + vllm_mlx/api/utils.py | 1 + vllm_mlx/server.py | 5 +- vllm_mlx/tool_parsers/__init__.py | 3 + vllm_mlx/tool_parsers/auto_tool_parser.py | 26 ++- vllm_mlx/tool_parsers/gemma4_tool_parser.py | 237 +++++++++++++++++++ 9 files changed, 668 insertions(+), 9 deletions(-) create mode 100644 tests/test_gemma4_openai_format.py create mode 100644 tests/test_gemma4_tool_parser.py create mode 100644 vllm_mlx/tool_parsers/gemma4_tool_parser.py diff --git a/tests/test_gemma4_openai_format.py b/tests/test_gemma4_openai_format.py new file mode 100644 index 000000000..f680c911b --- /dev/null +++ b/tests/test_gemma4_openai_format.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Integration test: verify Gemma 4 tool calls produce valid OpenAI API responses. + +Claude Code (and other OpenAI-compatible clients) expect: +- response.choices[0].message.tool_calls[0].type == "function" +- response.choices[0].message.tool_calls[0].function.name == "read_file" +- response.choices[0].message.tool_calls[0].function.arguments == '{"path":"/tmp/test.py"}' +- response.choices[0].message.content is None (not empty string) +- response.choices[0].finish_reason == "tool_calls" + +This test verifies the FULL pipeline from parser output → server wrapping → JSON response, +not just the parser in isolation. 
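+
+Abridged example of the expected wire shape (id value illustrative):
+
+    {"choices": [{"message": {"content": null,
+        "tool_calls": [{"id": "call_abc", "type": "function",
+            "function": {"name": "read_file",
+                         "arguments": "{\"path\":\"/tmp/test.py\"}"}}]},
+        "finish_reason": "tool_calls"}]}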
+""" + +import json + +from vllm_mlx.api.models import ( + AssistantMessage, + ChatCompletionChoice, + ChatCompletionResponse, + FunctionCall, + ToolCall, + Usage, +) +from vllm_mlx.tool_parsers.gemma4_tool_parser import Gemma4ToolParser + + +def _build_response_from_parser(parser_output, model_name="gemma-4-27b-it"): + """Simulate what server.py does at lines 1494-1511 to build the HTTP response.""" + if parser_output.tools_called: + tool_calls = [ + ToolCall( + id=tc.get("id", "call_test"), + type="function", + function=FunctionCall( + name=tc["name"], + arguments=tc["arguments"], + ), + ) + for tc in parser_output.tool_calls + ] + content = parser_output.content if parser_output.content else None + finish_reason = "tool_calls" + else: + tool_calls = None + content = parser_output.content + finish_reason = "stop" + + return ChatCompletionResponse( + model=model_name, + choices=[ + ChatCompletionChoice( + message=AssistantMessage( + content=content, + tool_calls=tool_calls, + ), + finish_reason=finish_reason, + ) + ], + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + +class TestGemma4OpenAIFormat: + """Verify the full response matches what Claude Code expects.""" + + def setup_method(self): + self.parser = Gemma4ToolParser() + + def test_tool_call_response_has_correct_structure(self): + """The JSON response must have the exact OpenAI structure.""" + output = '<|tool_call>call:read_file{path:<|"|>/tmp/test.py<|"|>}' + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + + # Serialize to JSON (this is what goes over the wire) + data = json.loads(response.model_dump_json(exclude_none=True)) + + choice = data["choices"][0] + msg = choice["message"] + + # finish_reason must be "tool_calls" not "stop" + assert choice["finish_reason"] == "tool_calls" + + # content must be absent or null when tool_calls present + assert msg.get("content") is None + + # tool_calls must be a list + assert isinstance(msg["tool_calls"], list) + assert len(msg["tool_calls"]) == 1 + + tc = msg["tool_calls"][0] + + # type must be "function" + assert tc["type"] == "function" + + # id must be present and non-empty + assert tc["id"] + assert isinstance(tc["id"], str) + + # function.name must be the function name + assert tc["function"]["name"] == "read_file" + + # function.arguments must be a JSON string (not a dict!) 
+ assert isinstance(tc["function"]["arguments"], str) + args = json.loads(tc["function"]["arguments"]) + assert args == {"path": "/tmp/test.py"} + + def test_multiple_tool_calls_response(self): + """Multiple tool calls in one response.""" + output = ( + "<|tool_call>" + 'call:read_file{path:<|"|>/a.py<|"|>}' + 'call:read_file{path:<|"|>/b.py<|"|>}' + "" + ) + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + tcs = data["choices"][0]["message"]["tool_calls"] + assert len(tcs) == 2 + assert tcs[0]["function"]["name"] == "read_file" + assert tcs[1]["function"]["name"] == "read_file" + # Each must have a unique id + assert tcs[0]["id"] != tcs[1]["id"] + + def test_content_before_tool_call_preserved(self): + """Text before the tool call goes in content field.""" + output = 'Let me check that.\n<|tool_call>call:read_file{path:<|"|>/tmp/x<|"|>}' + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + msg = data["choices"][0]["message"] + assert msg["content"] == "Let me check that." + assert len(msg["tool_calls"]) == 1 + + def test_no_tool_call_response(self): + """Plain text response has no tool_calls field.""" + output = "The answer is 42." + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + msg = data["choices"][0]["message"] + assert msg["content"] == "The answer is 42." + assert "tool_calls" not in msg # excluded by exclude_none + assert data["choices"][0]["finish_reason"] == "stop" + + def test_complex_arguments_serialize_correctly(self): + """Nested objects and arrays must survive JSON round-trip.""" + output = '<|tool_call>call:configure{settings:{enabled:true,tags:[<|"|>a<|"|>,<|"|>b<|"|>]}}' + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + tc = data["choices"][0]["message"]["tool_calls"][0] + args = json.loads(tc["function"]["arguments"]) + assert args == {"settings": {"enabled": True, "tags": ["a", "b"]}} diff --git a/tests/test_gemma4_tool_parser.py b/tests/test_gemma4_tool_parser.py new file mode 100644 index 000000000..179915442 --- /dev/null +++ b/tests/test_gemma4_tool_parser.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Gemma 4 tool call parser.""" + +import json + +from vllm_mlx.tool_parsers.gemma4_tool_parser import Gemma4ToolParser + + +class TestGemma4ToolParserExtract: + """Test extract_tool_calls on complete model output.""" + + def setup_method(self): + self.parser = Gemma4ToolParser() + + def test_single_tool_call_string_arg(self): + output = '<|tool_call>call:read_file{path:<|"|>/tmp/foo.py<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc["name"] == "read_file" + args = json.loads(tc["arguments"]) + assert args == {"path": "/tmp/foo.py"} + assert result.content is None + + def test_single_tool_call_numeric_arg(self): + output = "<|tool_call>call:search{limit:10,verbose:false}" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == 
{"limit": 10, "verbose": False} + + def test_mixed_types(self): + output = '<|tool_call>call:search{query:<|"|>hello world<|"|>,limit:10,verbose:false}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"query": "hello world", "limit": 10, "verbose": False} + + def test_nested_object(self): + output = '<|tool_call>call:configure{settings:{enabled:true,name:<|"|>test<|"|>}}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"settings": {"enabled": True, "name": "test"}} + + def test_array_argument(self): + output = '<|tool_call>call:tag{items:[<|"|>foo<|"|>,<|"|>bar<|"|>]}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"items": ["foo", "bar"]} + + def test_multiple_tool_calls_in_one_block(self): + output = ( + "<|tool_call>" + 'call:glob{pattern:<|"|>README*.md<|"|>}' + 'call:glob{pattern:<|"|>CONTRIBUTING.md<|"|>}' + "" + ) + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 2 + args0 = json.loads(result.tool_calls[0]["arguments"]) + args1 = json.loads(result.tool_calls[1]["arguments"]) + assert args0 == {"pattern": "README*.md"} + assert args1 == {"pattern": "CONTRIBUTING.md"} + + def test_content_before_tool_call(self): + output = 'Let me read that file for you.\n<|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert result.content == "Let me read that file for you." + assert len(result.tool_calls) == 1 + + def test_no_tool_calls(self): + output = "Hello, how can I help you today?" 
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is False
+        assert result.tool_calls == []
+        assert result.content == output
+
+    def test_empty_tool_call_block(self):
+        output = "<|tool_call>"
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is False
+        assert result.tool_calls == []
+
+    def test_tool_call_id_generated(self):
+        output = '<|tool_call>call:read_file{path:<|"|>/tmp/a<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        tc = result.tool_calls[0]
+        assert "id" in tc
+        assert tc["id"].startswith("call_")
+
+    def test_string_with_special_chars(self):
+        output = '<|tool_call>call:write{content:<|"|>line1\\nline2<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args["content"] == "line1\\nline2"
+
+    def test_deeply_nested_objects(self):
+        output = "<|tool_call>call:update{a:{b:{c:1,d:true}}}"
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {"a": {"b": {"c": 1, "d": True}}}
+
+    def test_null_value(self):
+        output = "<|tool_call>call:clear{target:null}"
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {"target": None}
+
+    def test_unicode_emoji_in_args(self):
+        output = '<|tool_call>call:search{query:<|"|>hello world \U0001f30d \u4f60\u597d<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {"query": "hello world \U0001f30d \u4f60\u597d"}
+
+    def test_braces_inside_string_value(self):
+        output = '<|tool_call>call:run{code:<|"|>if (x) { return y; }<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {"code": "if (x) { return y; }"}
+
+    def test_quoted_keys(self):
+        output = '<|tool_call>call:read{<|"|>path<|"|>:<|"|>/tmp/foo<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {"path": "/tmp/foo"}
+
+    def test_think_tags_stripped(self):
+        output = '<think>Let me think about this...</think><|tool_call>call:search{query:<|"|>test<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+
+    def test_missing_end_delimiter(self):
+        """Unclosed tool call block still parses (server fallback path)."""
+        output = '<|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {"path": "/tmp/foo"}
+
+    def test_string_with_colon(self):
+        """String containing colon pattern must not be corrupted by bare-key quoting."""
+        output = '<|tool_call>call:connect{url:<|"|>host:8080<|"|>}'
+        result = self.parser.extract_tool_calls(output)
+        assert result.tools_called is True
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args == {"url": "host:8080"}
+
+    def test_string_with_newline_and_quote(self):
+        """Real newline and double quote inside string values are JSON-escaped."""
+        output = 
'<|tool_call>call:write{text:<|"|>line1\nline2 said "hello"<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"text": 'line1\nline2 said "hello"'} + + +class TestGemma4ToolParserStreaming: + """Test streaming tool call extraction.""" + + def setup_method(self): + self.parser = Gemma4ToolParser() + self.parser.reset() + + def test_streaming_no_tool_call(self): + """Normal text passes through as content.""" + result = self.parser.extract_tool_calls_streaming( + previous_text="", + current_text="Hello", + delta_text="Hello", + ) + assert result == {"content": "Hello"} + + def test_streaming_suppresses_during_tool_call(self): + """Returns None while inside tool call block (buffering).""" + r1 = self.parser.extract_tool_calls_streaming( + previous_text="", + current_text="Sure. ", + delta_text="Sure. ", + ) + assert r1 == {"content": "Sure. "} + + r2 = self.parser.extract_tool_calls_streaming( + previous_text="Sure. ", + current_text="Sure. <|tool_call>call:read", + delta_text="<|tool_call>call:read", + ) + assert r2 is None + + r3 = self.parser.extract_tool_calls_streaming( + previous_text="Sure. <|tool_call>call:read", + current_text='Sure. <|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}', + delta_text='_file{path:<|"|>/tmp/foo<|"|>}', + ) + assert r3 is None + + def test_streaming_emits_on_close(self): + """Emits structured tool_calls when end delimiter arrives.""" + full_text = ( + 'Sure. <|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}' + ) + result = self.parser.extract_tool_calls_streaming( + previous_text='Sure. <|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}', + current_text=full_text, + delta_text="", + ) + assert result is not None + assert "tool_calls" in result + assert len(result["tool_calls"]) == 1 + tc = result["tool_calls"][0] + assert tc["function"]["name"] == "read_file" + assert tc["type"] == "function" + assert tc["index"] == 0 + + +class TestGemma4Registration: + """Test parser registration and flags.""" + + def test_registered_in_manager(self): + from vllm_mlx.tool_parsers import ToolParserManager + + parser_cls = ToolParserManager.get_tool_parser("gemma4") + assert parser_cls is Gemma4ToolParser + + def test_native_format_false(self): + assert Gemma4ToolParser.SUPPORTS_NATIVE_TOOL_FORMAT is False + assert Gemma4ToolParser.supports_native_format() is False diff --git a/tests/test_native_tool_format.py b/tests/test_native_tool_format.py index 184116171..c4182a6f8 100644 --- a/tests/test_native_tool_format.py +++ b/tests/test_native_tool_format.py @@ -12,6 +12,7 @@ AutoToolParser, DeepSeekToolParser, FunctionaryToolParser, + Gemma4ToolParser, GraniteToolParser, HermesToolParser, KimiToolParser, @@ -53,6 +54,7 @@ def test_parsers_without_native_support(self): NemotronToolParser, xLAMToolParser, AutoToolParser, + Gemma4ToolParser, ] for parser_cls in non_native_parsers: assert ( diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py index dfe2bb6a1..7caaffbf5 100644 --- a/tests/test_tool_parsers.py +++ b/tests/test_tool_parsers.py @@ -9,6 +9,7 @@ AutoToolParser, DeepSeekToolParser, FunctionaryToolParser, + Gemma4ToolParser, GraniteToolParser, HermesToolParser, KimiToolParser, @@ -39,6 +40,7 @@ def test_list_registered(self): "nemotron", "xlam", "functionary", + "gemma4", ] for p in expected: assert p in parsers, f"Parser '{p}' not found" @@ -68,6 +70,7 @@ def test_get_tool_parser_by_name(self): ("meetkai", FunctionaryToolParser), 
("hermes", HermesToolParser), ("nous", HermesToolParser), + ("gemma4", Gemma4ToolParser), ] for name, expected_cls in test_cases: parser_cls = ToolParserManager.get_tool_parser(name) diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 6dea67150..8c52915ee 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -121,6 +121,7 @@ def clean_output_text(text: str) -> str: ("", ""), ("", ""), (""), + ("<|tool_call>", ""), ("[TOOL_CALL]", "[/TOOL_CALL]"), ("[Calling tool", "]\n"), # Qwen3 bracket-style: [Calling tool: func({...})]\n ] diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 9b66bbe2f..ebaefe8d8 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -2293,7 +2293,10 @@ async def stream_chat_completion( tool_parser and tool_accumulated_text and not tool_calls_detected - and "" in tool_accumulated_text + and ( + "" in tool_accumulated_text + or "<|tool_call>" in tool_accumulated_text + ) ): result = tool_parser.extract_tool_calls(tool_accumulated_text) if result.tools_called: diff --git a/vllm_mlx/tool_parsers/__init__.py b/vllm_mlx/tool_parsers/__init__.py index 16f744080..685631a73 100644 --- a/vllm_mlx/tool_parsers/__init__.py +++ b/vllm_mlx/tool_parsers/__init__.py @@ -19,6 +19,7 @@ - functionary/meetkai: MeetKai Functionary models - glm47/glm4: GLM-4.7 and GLM-4.7-Flash models - harmony/gpt-oss: GPT-OSS models (Harmony format with channels) +- gemma4: Google Gemma 4 models (<|tool_call>call:name{} format) Usage: from vllm_mlx.tool_parsers import ToolParserManager @@ -57,6 +58,7 @@ from .xlam_tool_parser import xLAMToolParser from .glm47_tool_parser import Glm47ToolParser from .harmony_tool_parser import HarmonyToolParser +from .gemma4_tool_parser import Gemma4ToolParser __all__ = [ # Base classes @@ -77,4 +79,5 @@ "FunctionaryToolParser", "Glm47ToolParser", "HarmonyToolParser", + "Gemma4ToolParser", ] diff --git a/vllm_mlx/tool_parsers/auto_tool_parser.py b/vllm_mlx/tool_parsers/auto_tool_parser.py index fc02d8fc6..ac9058c88 100644 --- a/vllm_mlx/tool_parsers/auto_tool_parser.py +++ b/vllm_mlx/tool_parsers/auto_tool_parser.py @@ -16,6 +16,7 @@ ToolParser, ToolParserManager, ) +from .gemma4_tool_parser import Gemma4ToolParser def generate_tool_id() -> str: @@ -29,12 +30,13 @@ class AutoToolParser(ToolParser): Auto-detecting tool call parser. Tries multiple formats in order: - 1. Mistral: [TOOL_CALLS] ... - 2. Qwen bracket: [Calling tool: func_name({...})] - 3. Qwen/Hermes XML: {"name": "...", "arguments": {...}} - 4. Llama: {"arg": "value"} - 5. Nemotron: ... - 6. Raw JSON: {"name": "...", "arguments": {...}} + 1. Gemma 4: <|tool_call>call:name{...} + 2. Mistral: [TOOL_CALLS] ... + 3. Qwen bracket: [Calling tool: func_name({...})] + 4. Qwen/Hermes XML: {"name": "...", "arguments": {...}} + 5. Llama: {"arg": "value"} + 6. Nemotron: ... + 7. Raw JSON: {"name": "...", "arguments": {...}} This is the default parser when no specific parser is selected. """ @@ -63,7 +65,14 @@ def extract_tool_calls( tool_calls: list[dict[str, Any]] = [] cleaned_text = model_output - # 1. Try Mistral format + # 1. Try Gemma 4 format (most distinctive marker) + if "<|tool_call>" in model_output: + gemma_parser = Gemma4ToolParser() + result = gemma_parser.extract_tool_calls(model_output, request) + if result.tools_called: + return result + + # 2. 
Try Mistral format if self.MISTRAL_TOKEN in model_output: parts = model_output.split(self.MISTRAL_TOKEN) content = parts[0].strip() @@ -327,6 +336,7 @@ def extract_tool_calls_streaming( """ # Check for any tool call markers markers = [ + "<|tool_call>", self.MISTRAL_TOKEN, "[Calling tool:", "", @@ -339,7 +349,7 @@ def extract_tool_calls_streaming( return {"content": delta_text} # Check for completion markers - end_markers = ["", "", ")]"] + end_markers = ["", "", "", ")]"] if any(m in delta_text for m in end_markers): result = self.extract_tool_calls(current_text) if result.tools_called: diff --git a/vllm_mlx/tool_parsers/gemma4_tool_parser.py b/vllm_mlx/tool_parsers/gemma4_tool_parser.py new file mode 100644 index 000000000..a32fd90cf --- /dev/null +++ b/vllm_mlx/tool_parsers/gemma4_tool_parser.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Gemma 4 tool call parser for vllm-mlx. + +Handles Gemma 4's native tool call format: + <|tool_call>call:func_name{<|"|>key<|"|>: <|"|>value<|"|>, num: 42} + +Gemma 4 uses special tokens instead of JSON: +- <|tool_call> / delimit tool call blocks +- <|"|> replaces " for string values +- Keys are unquoted bare identifiers +- Multiple call:name{...} can appear in a single block + +Reference: mlx-lm PR #1105, vllm PR #38837 +""" + +import json +import logging +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from .abstract_tool_parser import ( + ExtractedToolCallInformation, + ToolParser, + ToolParserManager, +) + +logger = logging.getLogger(__name__) + +# Delimiters +TOOL_CALL_START = "<|tool_call>" +TOOL_CALL_END = "" + +# Placeholder token used during <|"|> extraction. Matches \x00 + digits + \x00. +_PLACEHOLDER_RE = re.compile(r"\x00(\d+)\x00") + +# Pattern to extract <|"|>-delimited strings (non-greedy, supports multiline) +_STRING_DELIM_RE = re.compile(r'<\|"\|>(.*?)<\|"\|>', re.DOTALL) + +# Pattern to match call:name followed by a { (we extract balanced braces manually) +_CALL_PREFIX = re.compile(r"call:(\w+)\s*\{") + +# Pattern to quote bare keys: word followed by : at start or after , or { +_BARE_KEY = re.compile(r"(?<=[{,])\s*(\w+)\s*:") + +# Max arg block length to prevent runaway parsing on malformed input (1 MB) +_MAX_ARG_BLOCK_LEN = 1_048_576 + + +def _find_balanced_brace(text: str, start: int) -> int: + """Find the index of the closing } that balances the { at `start`. + + Before counting braces, <|"|>-delimited strings are conceptually opaque -- + we skip over <|"|>...<|"|> regions so that braces inside string values + (e.g. code snippets) don't affect depth counting. + + Args: + text: The string to search (may contain <|"|> tokens) + start: Index of the opening { + + Returns: + Index of the matching } in the ORIGINAL text, or -1 if not found + """ + if len(text) - start > _MAX_ARG_BLOCK_LEN: + return -1 + + depth = 0 + i = start + in_string = False + while i < len(text): + if text.startswith('<|"|>', i): + in_string = not in_string + i += 5 + continue + if not in_string: + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + if depth == 0: + return i + i += 1 + return -1 + + +def _gemma4_args_to_json(text: str) -> str: + """Convert Gemma 4 tool call args to valid JSON. + + Three-step conversion (ORDER MATTERS): + 1. Extract <|"|>-delimited strings into numbered \\x00N\\x00 placeholders. + This protects string contents from step 2's bare-key quoting -- without + this, a string value like "key: value" would be corrupted. + 2. 
Quote bare keys (word: -> "word":) now that strings are safe. + 3. Restore placeholders as properly JSON-escaped strings via json.dumps(). + Uses a single re.sub pass (O(len(text))) instead of per-placeholder replace. + """ + strings: list[str] = [] + + def _capture(m: re.Match) -> str: + strings.append(m.group(1)) + return f"\x00{len(strings) - 1}\x00" + + # Step 1: Extract <|"|>-delimited strings + text = _STRING_DELIM_RE.sub(_capture, text) + + # Step 2: Quote bare keys + text = _BARE_KEY.sub(r'"\1":', text) + + # Step 3: Restore captured strings as properly escaped JSON strings + def _restore(m: re.Match) -> str: + idx = int(m.group(1)) + return json.dumps(strings[idx]) if idx < len(strings) else m.group(0) + + text = _PLACEHOLDER_RE.sub(_restore, text) + + return text + + +def generate_tool_id() -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:8]}" + + +@ToolParserManager.register_module("gemma4") +class Gemma4ToolParser(ToolParser): + """ + Tool call parser for Gemma 4 models. + + Parses: <|tool_call>call:func{<|"|>key<|"|>: <|"|>val<|"|>} + + Used when --enable-auto-tool-choice --tool-call-parser gemma4 are set. + """ + + def extract_tool_calls( + self, model_output: str, request: dict[str, Any] | None = None + ) -> ExtractedToolCallInformation: + """Extract tool calls from a complete Gemma 4 model response.""" + cleaned = self.strip_think_tags(model_output) + + start_idx = cleaned.find(TOOL_CALL_START) + if start_idx == -1: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + content_before = cleaned[:start_idx].strip() or None + + block_start = start_idx + len(TOOL_CALL_START) + end_idx = cleaned.find(TOOL_CALL_END, block_start) + if end_idx == -1: + block = cleaned[block_start:] + else: + block = cleaned[block_start:end_idx] + + tool_calls: list[dict[str, Any]] = [] + + pos = 0 + while pos < len(block): + m = _CALL_PREFIX.search(block, pos) + if not m: + break + + func_name = m.group(1) + brace_start = m.end() - 1 + + brace_end = _find_balanced_brace(block, brace_start) + if brace_end == -1: + pos = m.end() + continue + + args_raw = block[brace_start : brace_end + 1] + try: + args_json = _gemma4_args_to_json(args_raw) + json.loads(args_json) + tool_calls.append( + { + "id": generate_tool_id(), + "name": func_name, + "arguments": args_json, + } + ) + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + f"Gemma 4 tool parser: failed to parse args for " + f"call:{func_name}: {e}" + ) + + pos = brace_end + 1 + + if tool_calls: + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content_before, + ) + else: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int] | None = None, + current_token_ids: Sequence[int] | None = None, + delta_token_ids: Sequence[int] | None = None, + request: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + """Extract tool calls from streaming Gemma 4 model output.""" + has_start = TOOL_CALL_START in current_text + + if not has_start: + return {"content": delta_text} + + if TOOL_CALL_END in delta_text: + result = self.extract_tool_calls(current_text) + if result.tools_called: + return { + "tool_calls": [ + { + "index": i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": 
tc["arguments"], + }, + } + for i, tc in enumerate(result.tool_calls) + ] + } + + return None From 31017b05f28dc68a39c9e39c365dbb80d1b1a2ca Mon Sep 17 00:00:00 2001 From: Jan Hilgard <89418784+janhilgard@users.noreply.github.com> Date: Sat, 11 Apr 2026 00:22:10 +0200 Subject: [PATCH 15/45] feat: add full sampling params support for MLLM continuous batching (#258) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends MLLM batch generator to support top_k, min_p, and presence_penalty alongside the existing repetition_penalty. This gives the MLLM path full parity with the LLM/SimpleEngine sampling parameter coverage. Changes: - MLLMBatchRequest: add top_k, min_p, presence_penalty fields - MLLMBatch: add per-request samplers list (filter/extend support) - _process_prompts: build per-request logits processors for presence_penalty and per-request samplers for top_k/min_p - _step: accept and apply per-request samplers - SamplingParams: add presence_penalty field - MLLMScheduler: propagate new params from kwargs to batch requests - BatchedEngine: pass new params through generate/stream_generate When a request uses default values (top_k=0, min_p=0.0, presence_penalty=0.0), no extra processors or samplers are created — zero overhead for standard requests. Co-authored-by: Claude Opus 4.6 --- vllm_mlx/engine/batched.py | 16 +++++ vllm_mlx/mllm_batch_generator.py | 118 +++++++++++++++++++++++++++++-- vllm_mlx/mllm_scheduler.py | 8 +++ vllm_mlx/request.py | 1 + 4 files changed, 139 insertions(+), 4 deletions(-) diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index e47cd4fc6..d96d14bbe 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -464,6 +464,10 @@ async def generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), ) return GenerationOutput( @@ -480,6 +484,10 @@ async def generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), stop=stop or [], ) @@ -536,6 +544,10 @@ async def stream_generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), ) async for output in self._mllm_scheduler.stream_outputs(request_id): @@ -556,6 +568,10 @@ async def stream_generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), stop=stop or [], ) diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index 1de137587..9c4ea8e44 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -47,6 +47,10 @@ class MLLMBatchRequest: max_tokens: int = 256 temperature: float = 0.7 top_p: float = 0.9 + top_k: int = 0 + min_p: float = 0.0 + presence_penalty: float = 0.0 + repetition_penalty: float = 1.0 # Processed inputs (set after vision preprocessing) input_ids: Optional[mx.array] = None @@ -98,6 +102,8 @@ class MLLMBatch: 
num_tokens: List[int] # Tokens generated per request cache: List[Any] # BatchKVCache for language model requests: List[MLLMBatchRequest] # Full request data + logits_processors: Optional[List[Optional[List[Callable]]]] = None + samplers: Optional[List[Optional[Callable]]] = None def __len__(self) -> int: return len(self.uids) @@ -115,6 +121,10 @@ def filter(self, keep_idx: List[int]) -> None: self.max_tokens = [self.max_tokens[k] for k in keep_idx] self.num_tokens = [self.num_tokens[k] for k in keep_idx] self.requests = [self.requests[k] for k in keep_idx] + if self.logits_processors is not None: + self.logits_processors = [self.logits_processors[k] for k in keep_idx] + if self.samplers is not None: + self.samplers = [self.samplers[k] for k in keep_idx] keep_idx_array = mx.array(keep_idx, mx.int32) self.y = self.y[keep_idx_array] @@ -139,6 +149,20 @@ def extend(self, other: "MLLMBatch") -> None: self.max_tokens.extend(other.max_tokens) self.requests.extend(other.requests) + # Extend logits_processors + if self.logits_processors is not None or other.logits_processors is not None: + self_len = len(self.uids) - len(other.uids) + self_lp = self.logits_processors or [None] * self_len + other_lp = other.logits_processors or [None] * len(other.uids) + self.logits_processors = list(self_lp) + list(other_lp) + + # Extend samplers + if self.samplers is not None or other.samplers is not None: + self_len = len(self.uids) - len(other.uids) + self_s = self.samplers or [None] * self_len + other_s = other.samplers or [None] * len(other.uids) + self.samplers = list(self_s) + list(other_s) + # Extend cache - handle None and incompatible caches for c, o in zip(self.cache, other.cache): if c is not None and o is not None and hasattr(c, "extend"): @@ -724,6 +748,51 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: # Create initial y (first generated tokens) y = mx.array(first_tokens) + # Build per-request logits processors (repetition_penalty, presence_penalty) + from mlx_lm.sample_utils import make_logits_processors, make_sampler + + batch_logits_processors = [] + has_any_lp = False + for req in requests: + need_rep = req.repetition_penalty and req.repetition_penalty != 1.0 + need_pres = req.presence_penalty and req.presence_penalty != 0.0 + if need_rep or need_pres: + lp_kwargs = {} + if need_rep: + lp_kwargs["repetition_penalty"] = req.repetition_penalty + if need_pres: + lp_kwargs["presence_penalty"] = req.presence_penalty + lp = make_logits_processors(**lp_kwargs) + batch_logits_processors.append(lp) + has_any_lp = True + logger.info( + f"[sampling] request={req.request_id[:12]} " + f"rep_penalty={req.repetition_penalty} " + f"pres_penalty={req.presence_penalty}" + ) + else: + batch_logits_processors.append(None) + + # Build per-request samplers for top_k/min_p + batch_samplers = [] + has_any_sampler = False + for req in requests: + if req.top_k != 0 or req.min_p != 0.0: + s = make_sampler( + temp=req.temperature, + top_p=req.top_p, + top_k=req.top_k, + min_p=req.min_p, + ) + batch_samplers.append(s) + has_any_sampler = True + logger.info( + f"[sampling] request={req.request_id[:12]} " + f"top_k={req.top_k} min_p={req.min_p}" + ) + else: + batch_samplers.append(None) + self._stats.prompt_time += time.perf_counter() - tic return MLLMBatch( @@ -735,10 +804,17 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: num_tokens=[0] * len(requests), cache=batch_cache, requests=requests, + logits_processors=batch_logits_processors if has_any_lp else None, + 
samplers=batch_samplers if has_any_sampler else None, ) def _step( - self, input_tokens: mx.array, cache: List[Any] + self, + input_tokens: mx.array, + cache: List[Any], + logits_processors: Optional[List[Optional[List[Callable]]]] = None, + output_tokens: Optional[List[List[int]]] = None, + samplers: Optional[List[Optional[Callable]]] = None, ) -> Tuple[mx.array, List[mx.array]]: """ Run one generation step through the language model. @@ -746,6 +822,9 @@ def _step( Args: input_tokens: Input tokens [batch_size, 1] or [batch_size] cache: BatchKVCache for the language model + logits_processors: Per-request logits processors (e.g. repetition penalty) + output_tokens: Per-request generated tokens so far (needed by processors) + samplers: Per-request sampler functions (for top_k/min_p) Returns: Tuple of (sampled tokens, logprobs list) @@ -765,9 +844,29 @@ def _step( logits = logits[:, -1, :] - # Sample + # Apply per-request logits processors (repetition penalty etc.) + if logits_processors and output_tokens and any(logits_processors): + processed_logits = [] + for e in range(logits.shape[0]): + sample_logits = logits[e : e + 1] + if logits_processors[e]: + for processor in logits_processors[e]: + sample_logits = processor( + mx.array(output_tokens[e]), sample_logits + ) + processed_logits.append(sample_logits) + logits = mx.concatenate(processed_logits, axis=0) + + # Sample — per-request samplers for top_k/min_p support logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - sampled = self.sampler(logprobs) + if samplers and any(samplers): + sampled_list = [] + for e in range(logprobs.shape[0]): + s = samplers[e] if samplers[e] else self.sampler + sampled_list.append(s(logprobs[e : e + 1])) + sampled = mx.concatenate(sampled_list, axis=0) + else: + sampled = self.sampler(logprobs) return sampled, list(logprobs) @@ -832,7 +931,18 @@ def _next(self) -> List[MLLMBatchResponse]: return error_responses y, logprobs = batch.y, batch.logprobs - batch.y, batch.logprobs = self._step(y[:, None], batch.cache) + output_tokens = ( + [req.output_tokens for req in batch.requests] + if batch.logits_processors + else None + ) + batch.y, batch.logprobs = self._step( + y[:, None], + batch.cache, + batch.logits_processors, + output_tokens, + batch.samplers, + ) mx.async_eval(batch.y, batch.logprobs) y = y.tolist() diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 945992045..1a1c0c45e 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -316,6 +316,10 @@ def add_request( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), ) request = MLLMRequest( @@ -422,6 +426,10 @@ def _schedule_waiting(self) -> List[MLLMRequest]: max_tokens=request.sampling_params.max_tokens, temperature=request.sampling_params.temperature, top_p=request.sampling_params.top_p, + top_k=request.sampling_params.top_k, + min_p=request.sampling_params.min_p, + presence_penalty=request.sampling_params.presence_penalty, + repetition_penalty=request.sampling_params.repetition_penalty, ) batch_requests.append(batch_req) diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py index 41679c0ba..f18b238d8 100644 --- a/vllm_mlx/request.py +++ b/vllm_mlx/request.py @@ -57,6 +57,7 @@ class SamplingParams: top_p: float = 0.9 top_k: int = 0 # 0 means disabled min_p: float = 0.0 + presence_penalty: float = 0.0 
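+    # Additive OpenAI-style penalty applied once to any token already present
+    # in the output; repetition_penalty below is the multiplicative variant.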
repetition_penalty: float = 1.0 stop: Optional[List[str]] = None stop_token_ids: Optional[List[int]] = None From 5d93852ec73a0d1ded5173896c79bbd9e799b0c9 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Sat, 11 Apr 2026 00:22:14 +0200 Subject: [PATCH 16/45] prefix_cache: preserve hybrid recurrent state across blocks (#217) * Fix Qwen3.5 hybrid paged cache reconstruction * fix: add deduplication safety test and remove duplicate tokenizer hunk Add test confirming deduplicated terminal blocks correctly isolate recurrent state per sequence. Remove the duplicate tokenizer return fix that already ships in PR #215. * style: format hybrid cache follow-up --- tests/test_paged_cache.py | 149 ++++++++++++++++++++++++++ vllm_mlx/prefix_cache.py | 216 ++++++++++++++++++++++++-------------- 2 files changed, 288 insertions(+), 77 deletions(-) diff --git a/tests/test_paged_cache.py b/tests/test_paged_cache.py index 8e3082c34..167b60944 100644 --- a/tests/test_paged_cache.py +++ b/tests/test_paged_cache.py @@ -725,3 +725,152 @@ def test_clear(self): stats = cache.get_stats() # After clear, null block is still allocated (vLLM style) assert stats["allocated_blocks"] == 1 # only null block + + def test_reconstructs_hybrid_cache_from_boundary_snapshot(self): + from mlx_lm.models.cache import ArraysCache, KVCache + import mlx.core as mx + + from vllm_mlx.paged_cache import PagedCacheManager + from vllm_mlx.prefix_cache import BlockAwarePrefixCache + + paged_manager = PagedCacheManager(block_size=4, max_blocks=10) + cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager) + + tokens = list(range(8)) + kv_keys = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3) + kv_values = mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3) + linear_state = [ + mx.arange(1 * 3 * 8).reshape(1, 3, 8), + mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4), + ] + extracted = [ + { + "state": (kv_keys, kv_values), + "meta_state": "", + "class_ref": KVCache, + "class_name": "KVCache", + }, + { + "state": linear_state, + "meta_state": "", + "class_ref": ArraysCache, + "class_name": "ArraysCache", + }, + ] + + block_table = cache.store_cache("req-1", tokens, extracted) + first_block = paged_manager.allocated_blocks[block_table.block_ids[0]] + last_block = paged_manager.allocated_blocks[block_table.block_ids[-1]] + + assert first_block.cache_data[0] is not None + assert first_block.cache_data[1] is None + assert last_block.cache_data[1] is not None + + reconstructed = cache.reconstruct_cache(block_table) + + assert reconstructed is not None + assert isinstance(reconstructed[0], KVCache) + assert isinstance(reconstructed[1], ArraysCache) + assert reconstructed[0].state[0].tolist() == kv_keys.tolist() + assert reconstructed[0].state[1].tolist() == kv_values.tolist() + assert reconstructed[1].state[0].tolist() == linear_state[0].tolist() + assert reconstructed[1].state[1].tolist() == linear_state[1].tolist() + + def test_rejects_hybrid_prefix_without_boundary_snapshot(self): + from mlx_lm.models.cache import ArraysCache, KVCache + import mlx.core as mx + + from vllm_mlx.paged_cache import BlockTable, PagedCacheManager + from vllm_mlx.prefix_cache import BlockAwarePrefixCache + + paged_manager = PagedCacheManager(block_size=4, max_blocks=10) + cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager) + + extracted = [ + { + "state": ( + mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3), + mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3), + ), + "meta_state": "", + "class_ref": 
KVCache,
+                "class_name": "KVCache",
+            },
+            {
+                "state": [
+                    mx.arange(1 * 3 * 8).reshape(1, 3, 8),
+                    mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4),
+                ],
+                "meta_state": "",
+                "class_ref": ArraysCache,
+                "class_name": "ArraysCache",
+            },
+        ]
+
+        block_table = cache.store_cache("req-1", list(range(8)), extracted)
+        prefix_table = BlockTable(
+            request_id="req-prefix",
+            block_ids=[block_table.block_ids[0]],
+            num_tokens=4,
+        )
+
+        assert cache.reconstruct_cache(prefix_table) is None
+
+    def test_deduplicated_terminal_uses_correct_recurrent_snapshot(self):
+        """Deduplication must not leak recurrent state across sequences."""
+        from mlx_lm.models.cache import ArraysCache, KVCache
+        import mlx.core as mx
+
+        from vllm_mlx.paged_cache import PagedCacheManager
+        from vllm_mlx.prefix_cache import BlockAwarePrefixCache
+
+        paged_manager = PagedCacheManager(block_size=4, max_blocks=20)
+        cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager)
+
+        # Request A: 8 tokens across 2 blocks. B2 is terminal.
+        kv_a = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3)
+        recurrent_a = [mx.ones((1, 3, 8)), mx.ones((1, 2, 4, 4))]
+        extracted_a = [
+            {
+                "state": (kv_a, kv_a),
+                "meta_state": "",
+                "class_ref": KVCache,
+                "class_name": "KVCache",
+            },
+            {
+                "state": recurrent_a,
+                "meta_state": "",
+                "class_ref": ArraysCache,
+                "class_name": "ArraysCache",
+            },
+        ]
+        bt_a = cache.store_cache("req-a", list(range(8)), extracted_a)
+
+        # Request B: 12 tokens, first 8 identical. B1/B2 deduplicated, B3 new terminal.
+        kv_b = mx.arange(1 * 2 * 12 * 3).reshape(1, 2, 12, 3)
+        recurrent_b = [mx.full((1, 3, 8), 2.0), mx.full((1, 2, 4, 4), 2.0)]
+        extracted_b = [
+            {
+                "state": (kv_b, kv_b),
+                "meta_state": "",
+                "class_ref": KVCache,
+                "class_name": "KVCache",
+            },
+            {
+                "state": recurrent_b,
+                "meta_state": "",
+                "class_ref": ArraysCache,
+                "class_name": "ArraysCache",
+            },
+        ]
+        bt_b = cache.store_cache("req-b", list(range(12)), extracted_b)
+
+        # Reconstruct A: should use A's recurrent state (ones), not B's (twos)
+        recon_a = cache.reconstruct_cache(bt_a)
+        assert recon_a is not None
+        assert recon_a[1].state[0].tolist() == recurrent_a[0].tolist()
+
+        # Reconstruct B: should use B's recurrent state (twos)
+        recon_b = cache.reconstruct_cache(bt_b)
+        assert recon_b is not None
+        assert recon_b[1].state[0].tolist() == recurrent_b[0].tolist()
diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py
index e8f47a324..a419f3973 100644
--- a/vllm_mlx/prefix_cache.py
+++ b/vllm_mlx/prefix_cache.py
@@ -586,7 +586,7 @@ def store_cache(
             # Extract and store actual tensor slices for this block
             if is_tensor_data and HAS_MLX:
                 block_kv_data = self._extract_block_tensor_slice(
-                    cache_data, global_start, global_end
+                    cache_data, global_start, global_end, len(tokens)
                 )
                 if block_kv_data:
                     block.cache_data = block_kv_data
@@ -629,56 +629,122 @@ def _extract_block_tensor_slice(
         cache_data: List[Dict[str, Any]],
         start_idx: int,
         end_idx: int,
-    ) -> Optional[List[Tuple[Any, Any]]]:
+        total_tokens: int,
+    ) -> Optional[List[Optional[Dict[str, Any]]]]:
         """
-        Extract tensor slices for a single block from cache data.
+        Extract per-layer cache data for a single block.
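+
+        Sequence-backed KV layers are sliced along the token axis for every
+        block ("concat" storage). Recurrent layers such as ArraysCache cannot
+        be sliced per token, so only the block that ends the sequence stores
+        a snapshot ("latest" storage); earlier blocks store None for them.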
Args: - cache_data: List of layer states, each containing 'state': (keys, values) + cache_data: List of extracted layer states start_idx: Start token index in the sequence end_idx: End token index in the sequence + total_tokens: Total number of tokens covered by cache_data Returns: - List of (keys_slice, values_slice) for each layer, or None on failure + Per-layer block cache state, or None on failure """ if not HAS_MLX or not cache_data: return None try: - block_slices = [] + block_slices: List[Optional[Dict[str, Any]]] = [] for layer_state in cache_data: if "state" not in layer_state: + block_slices.append(None) continue - keys, values = layer_state["state"] + state = layer_state["state"] + meta_state = layer_state.get("meta_state") + class_ref = layer_state.get("class_ref") + class_name = layer_state.get("class_name") - # KV cache shape: (batch, n_kv_heads, seq_len, head_dim) - # Slice along seq_len dimension (axis 2) - seq_len = keys.shape[2] if hasattr(keys, "shape") else 0 + if self._can_concatenate_cache_state(state): + state_slice = self._slice_concat_cache_state( + state, start_idx, end_idx + ) + block_slices.append( + { + "state": state_slice, + "meta_state": meta_state, + "class_ref": class_ref, + "class_name": class_name, + "storage": "concat", + "seq_axis": 2, + } + ) + continue - if end_idx > seq_len: - # Requested range extends beyond available data - logger.debug( - f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}" + if end_idx == total_tokens: + block_slices.append( + { + "state": state, + "meta_state": meta_state, + "class_ref": class_ref, + "class_name": class_name, + "storage": "latest", + } ) - # Use whatever is available - actual_end = min(end_idx, seq_len) - if start_idx >= actual_end: - continue - keys_slice = keys[:, :, start_idx:actual_end, :] - values_slice = values[:, :, start_idx:actual_end, :] else: - keys_slice = keys[:, :, start_idx:end_idx, :] - values_slice = values[:, :, start_idx:end_idx, :] + block_slices.append(None) - block_slices.append((keys_slice, values_slice)) - - return block_slices if block_slices else None + return ( + block_slices + if any(entry is not None for entry in block_slices) + else None + ) except Exception as e: logger.warning(f"Failed to extract block tensor slice: {e}") return None + def _can_concatenate_cache_state(self, state: Any) -> bool: + """Return True when cache state can be concatenated block-by-block.""" + if not isinstance(state, (list, tuple)) or not state: + return False + return all( + tensor is not None and hasattr(tensor, "shape") and len(tensor.shape) == 4 + for tensor in state + ) + + def _slice_concat_cache_state( + self, + state: Tuple[Any, ...] | List[Any], + start_idx: int, + end_idx: int, + ) -> Tuple[Any, ...] | List[Any]: + """Slice a sequence-backed cache state across the token axis.""" + seq_len = state[0].shape[2] + actual_end = min(end_idx, seq_len) + if start_idx >= actual_end: + raise ValueError( + f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}" + ) + + def _slice_tensor(tensor: Any) -> Any: + slices = [slice(None)] * len(tensor.shape) + slices[2] = slice(start_idx, actual_end) + return tensor[tuple(slices)] + + sliced = [_slice_tensor(tensor) for tensor in state] + return tuple(sliced) if isinstance(state, tuple) else sliced + + def _concat_cache_states( + self, + states: List[Tuple[Any, ...] | List[Any]], + seq_axis: int, + ) -> Optional[Tuple[Any, ...] 
| List[Any]]: + """Concatenate state fragments for a sequence-backed cache layer.""" + if not states: + return None + arity = len(states[0]) + concatenated = [] + for idx in range(arity): + parts = [state[idx] for state in states] + if any(part is None for part in parts): + return None + concatenated.append(mx.concatenate(parts, axis=seq_axis)) + return tuple(concatenated) if isinstance(states[0], tuple) else concatenated + def get_cache_for_generation( self, request_id: str, @@ -763,10 +829,11 @@ def reconstruct_cache( block_table: BlockTable, ) -> Optional[List[Any]]: """ - Reconstruct KVCache objects from stored block tensor data. + Reconstruct cache objects from stored block tensor data. - This method concatenates tensor slices from all blocks and - creates new KVCache objects that can be used for inference. + Sequence-backed caches are concatenated block-by-block. Recurrent + caches such as ArraysCache are restored from the latest sequence + boundary snapshot that was actually stored. Args: block_table: BlockTable containing block IDs to reconstruct from @@ -800,67 +867,62 @@ def reconstruct_cache( if not all_block_data: return None - # Get number of layers from first block - num_layers = len(all_block_data[0]) + # Get number of layers from the richest block + num_layers = max(len(block_data) for block_data in all_block_data) if num_layers == 0: return None - # Concatenate tensors for each layer reconstructed_caches = [] - for layer_idx in range(num_layers): - layer_keys = [] - layer_values = [] + layer_entries = [ + block_data[layer_idx] + for block_data in all_block_data + if layer_idx < len(block_data) + ] + layer_entries = [entry for entry in layer_entries if entry is not None] + if not layer_entries: + return None - for block_data in all_block_data: - if layer_idx < len(block_data): - keys_slice, values_slice = block_data[layer_idx] - layer_keys.append(keys_slice) - layer_values.append(values_slice) + layer_meta = layer_entries[-1] + state = layer_meta["state"] + if layer_meta["storage"] == "concat": + state = self._concat_cache_states( + [entry["state"] for entry in layer_entries], + layer_meta["seq_axis"], + ) + elif layer_meta["storage"] == "latest": + state = layer_entries[-1]["state"] - if not layer_keys: - continue + if state is None: + return None - # Concatenate along sequence dimension (axis 2) - # Shape: (batch, n_kv_heads, seq_len, head_dim) - concat_keys = mx.concatenate(layer_keys, axis=2) - concat_values = mx.concatenate(layer_values, axis=2) + cache_cls = layer_meta.get("class_ref") + meta_state = layer_meta.get("meta_state") - # Create KVCache object - # Try to use mlx_lm's KVCache.from_state if available - try: + if cache_cls is not None and hasattr(cache_cls, "from_state"): + from mlx_lm.models.cache import ( + BatchKVCache as _BatchKVCache, + KVCache as _KVCache, + ) + + if cache_cls is _BatchKVCache: + keys, values = state[0], state[1] + cache = _KVCache() + cache.keys = keys + cache.values = values + cache.offset = keys.shape[2] + else: + cache = cache_cls.from_state(state, meta_state) + else: from mlx_lm.models.cache import KVCache - # Create new cache and set its state + if len(state) != 2: + return None cache = KVCache() - seq_len = concat_keys.shape[2] - - # Set internal state directly - # KVCache stores keys/values and offset - cache.keys = concat_keys - cache.values = concat_values - cache.offset = seq_len - - reconstructed_caches.append(cache) - - except ImportError: - # Fallback: create a simple cache-like object - class SimpleKVCache: - def 
__init__(self, keys, values): - self.keys = keys - self.values = values - self.offset = keys.shape[2] - - @property - def state(self): - return (self.keys, self.values) - - @property - def meta_state(self): - return (str(self.offset),) - - cache = SimpleKVCache(concat_keys, concat_values) - reconstructed_caches.append(cache) + cache.keys, cache.values = state + cache.offset = cache.keys.shape[2] + + reconstructed_caches.append(cache) if not reconstructed_caches: return None From 55cff4d0749c6160f4cd4230bc00ab1f5753498b Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Sat, 11 Apr 2026 00:22:19 +0200 Subject: [PATCH 17/45] engine: keep SimpleEngine serialized across cancellation (#220) * fix: keep simple engine serialized across cancellation (#8) * fix: avoid nested simple engine generation locks * fix: catch BaseException in cancellation handler, fix async test markers _run_blocking_serialized catches CancelledError (a BaseException subclass) from the outer scope, but the inner try/except used Exception which would let a second CancelledError during await task escape unhandled. Changed to BaseException to suppress any exception from the draining await. Also fix test_simple_engine.py to use pytest.mark.anyio instead of pytest.mark.asyncio (pytest-asyncio is not configured), and add the anyio_backend fixture to conftest.py restricting to asyncio only since trio is not installed. * fix: preserve prompt token accounting after upstream refresh * fix: restore specprefill fallback helper scope --- tests/conftest.py | 6 + tests/test_simple_engine.py | 12 +- ...test_simple_engine_cancel_serialization.py | 143 ++++ vllm_mlx/engine/simple.py | 745 +++++++++--------- 4 files changed, 531 insertions(+), 375 deletions(-) create mode 100644 tests/test_simple_engine_cancel_serialization.py diff --git a/tests/conftest.py b/tests/conftest.py index d0c7f026b..f699c08bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,3 +50,9 @@ def pytest_collection_modifyitems(config, items): def server_url(request): """Get server URL from command line.""" return request.config.getoption("--server-url") + + +@pytest.fixture(params=["asyncio"]) +def anyio_backend(request): + """Run anyio-marked tests on asyncio only (trio is not installed).""" + return request.param diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index cce42bfc3..7202f625f 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -6,6 +6,8 @@ import pytest +pytestmark = pytest.mark.anyio + class TestSimpleEngineConcurrency: """Test SimpleEngine lock behavior with concurrent requests.""" @@ -65,7 +67,7 @@ def chat_side_effect(**kwargs): model.chat = MagicMock(side_effect=chat_side_effect) return model - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_generate(self, mock_model): """Test that the lock prevents concurrent generate calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -89,7 +91,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model): "The lock is not working correctly." ) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_chat(self, mock_llm_model): """Test that the lock prevents concurrent chat calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -115,7 +117,7 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model): "The lock is not working correctly." 
) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" from vllm_mlx.engine.simple import SimpleEngine @@ -178,7 +180,7 @@ async def try_stream(): result = await stream_task assert len(result) == 3, f"Expected 3 chunks, got {len(result)}" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_engine_initialization_creates_lock(self): """Test that SimpleEngine creates a lock on initialization.""" from vllm_mlx.engine.simple import SimpleEngine @@ -189,7 +191,7 @@ async def test_engine_initialization_creates_lock(self): assert hasattr(engine, "_generation_lock") assert isinstance(engine._generation_lock, asyncio.Lock) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_requests_complete_in_order(self, mock_model): """Test that concurrent requests complete (may be in any order due to lock).""" from vllm_mlx.engine.simple import SimpleEngine diff --git a/tests/test_simple_engine_cancel_serialization.py b/tests/test_simple_engine_cancel_serialization.py new file mode 100644 index 000000000..28c25868e --- /dev/null +++ b/tests/test_simple_engine_cancel_serialization.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression test for cancellation-safe SimpleEngine serialization.""" + +from __future__ import annotations + +import asyncio +import threading +import unittest +from unittest.mock import MagicMock, patch + + +class SimpleEngineCancelSerializationTests(unittest.IsolatedAsyncioTestCase): + async def test_cancellation_does_not_release_lock_before_worker_finishes(self): + """A cancelled request must not let a second MLX worker overlap.""" + from vllm_mlx.engine.simple import SimpleEngine + + model = MagicMock() + model.tokenizer = MagicMock() + model.tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + model._concurrent_count = 0 + model._max_concurrent = 0 + + first_started = threading.Event() + release_workers = threading.Event() + call_count = 0 + call_lock = threading.Lock() + + def generate_side_effect(**kwargs): + nonlocal call_count + with call_lock: + call_count += 1 + current_call = call_count + model._concurrent_count += 1 + model._max_concurrent = max( + model._max_concurrent, model._concurrent_count + ) + if current_call == 1: + first_started.set() + + release_workers.wait(timeout=1.0) + + with call_lock: + model._concurrent_count -= 1 + + result = MagicMock() + result.text = f"response-{current_call}" + result.tokens = [1, 2, 3] + result.finish_reason = "stop" + return result + + model.generate = MagicMock(side_effect=generate_side_effect) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + + task1 = asyncio.create_task(engine.generate(prompt="first", max_tokens=8)) + await asyncio.to_thread(first_started.wait, 1.0) + + task1.cancel() + task2 = asyncio.create_task(engine.generate(prompt="second", max_tokens=8)) + + await asyncio.sleep(0.05) + release_workers.set() + + with self.assertRaises(asyncio.CancelledError): + await task1 + result2 = await task2 + + self.assertEqual(result2.text, "response-2") + self.assertEqual( + model._max_concurrent, + 1, + "cancellation released the generation lock before the first worker finished", + ) + + async def test_specprefill_path_does_not_prelock_serialized_runner(self): + """Specprefill streaming must let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple 
import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + self.assertFalse(engine._generation_lock.locked()) + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._model = MagicMock() + engine._model.model = MagicMock() + engine._model.tokenizer = MagicMock() + engine._draft_model = MagicMock() + engine._run_blocking_serialized = fake_serialized # type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_specprefill( + prompt="hello", + tokens=[1, 2, 3, 4], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0].finished) + self.assertEqual(outputs[0].completion_tokens, 0) + + async def test_text_mtp_path_does_not_prelock_serialized_runner(self): + """Text-only MTP streaming must let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + self.assertFalse(engine._generation_lock.locked()) + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.make_mtp_cache = MagicMock(return_value=[]) + engine._text_tokenizer = MagicMock() + engine._text_tokenizer.apply_chat_template = MagicMock(return_value="hello") + engine._text_tokenizer.bos_token = None + engine._draft_model = None + engine._run_blocking_serialized = fake_serialized # type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0].finished) + self.assertEqual(outputs[0].completion_tokens, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index da3ccfc18..8c7212ca4 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -226,6 +226,24 @@ async def stop(self) -> None: self._system_kv_token_count = 0 logger.info("SimpleEngine stopped") + async def _run_blocking_serialized(self, func, /, *args, **kwargs): + """Run a blocking MLX operation under the generation lock. + + Cancellation must not release the async lock before the worker thread + finishes, or a follow-up request can enter MLX/Metal concurrently and + corrupt the command-buffer state. 
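+
+        asyncio.shield keeps the worker task alive when this coroutine is
+        cancelled; we then drain the task (suppressing any error) before
+        re-raising, so the lock is released only after the thread has
+        actually left MLX.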
+ """ + async with self._generation_lock: + task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs)) + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + try: + await task + except BaseException: + pass + raise + async def generate( self, prompt: str, @@ -252,30 +270,28 @@ async def generate( if not self._loaded: await self.start() - async with self._generation_lock: - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.generate, - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - **kwargs, - ) - - # Clean output text - text = clean_output_text(output.text) + output = await self._run_blocking_serialized( + self._model.generate, + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ) - return GenerationOutput( - text=text, - tokens=getattr(output, "tokens", []), - prompt_tokens=getattr(output, "prompt_tokens", 0), - completion_tokens=getattr( - output, "completion_tokens", len(getattr(output, "tokens", [])) - ), - finish_reason=output.finish_reason, - ) + # Clean output text + text = clean_output_text(output.text) + + return GenerationOutput( + text=text, + tokens=getattr(output, "tokens", []), + prompt_tokens=getattr(output, "prompt_tokens", 0), + completion_tokens=getattr( + output, "completion_tokens", len(getattr(output, "tokens", [])) + ), + finish_reason=output.finish_reason, + ) async def stream_generate( self, @@ -440,55 +456,51 @@ async def chat( # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None - async with self._generation_lock: - if self._is_mllm: - # For MLLM, use the chat method which handles images/videos - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, - ) - text = clean_output_text(output.text) - return GenerationOutput( - text=text, - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - finish_reason=output.finish_reason, - ) - else: - # For LLM, use the chat method - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - tools=template_tools, - **kwargs, - ) - text = clean_output_text(output.text) - # Count prompt tokens from the full templated prompt - tokenizer = self._model.tokenizer - template_kwargs = { - "tokenize": True, - "add_generation_prompt": True, - } - if template_tools: - template_kwargs["tools"] = template_tools - prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs) - prompt_token_count = len(prompt_ids) - return GenerationOutput( - text=text, - tokens=output.tokens, - prompt_tokens=prompt_token_count, - completion_tokens=len(output.tokens), - finish_reason=output.finish_reason, - ) + if self._is_mllm: + output = await self._run_blocking_serialized( + self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, + ) + text = clean_output_text(output.text) + return GenerationOutput( + text=text, + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + finish_reason=output.finish_reason, + ) + else: + output = await self._run_blocking_serialized( + 
self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=template_tools, + **kwargs, + ) + text = clean_output_text(output.text) + # Preserve upstream prompt accounting while routing the blocking + # chat call through the cancellation-safe serialized runner. + tokenizer = self._model.tokenizer + template_kwargs = { + "tokenize": True, + "add_generation_prompt": True, + } + if template_tools: + template_kwargs["tools"] = template_tools + prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs) + prompt_token_count = len(prompt_ids) + return GenerationOutput( + text=text, + tokens=output.tokens, + prompt_tokens=prompt_token_count, + completion_tokens=len(output.tokens), + finish_reason=output.finish_reason, + ) async def stream_chat( self, @@ -548,42 +560,41 @@ async def stream_chat( # For MLLM, use stream_chat which yields tokens incrementally. # Must hold _generation_lock to prevent concurrent Metal access # (e.g. OpenCode sends title + main request simultaneously). - async with self._generation_lock: - accumulated_text = "" - token_count = 0 - - # Run stream_chat in thread pool since it's synchronous - def run_stream(): - return list( - self._model.stream_chat( - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, - ) + accumulated_text = "" + token_count = 0 + + # Run stream_chat in thread pool since it's synchronous + def run_stream(): + return list( + self._model.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, ) + ) - chunks = await asyncio.to_thread(run_stream) + chunks = await self._run_blocking_serialized(run_stream) - for chunk in chunks: - token_count += 1 - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - accumulated_text += new_text + for chunk in chunks: + token_count += 1 + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + accumulated_text += new_text - finished = chunk.finish_reason is not None + finished = chunk.finish_reason is not None - yield GenerationOutput( - text=accumulated_text, - new_text=new_text, - prompt_tokens=getattr(chunk, "prompt_tokens", 0), - completion_tokens=token_count, - finished=finished, - finish_reason=chunk.finish_reason if finished else None, - ) + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=getattr(chunk, "prompt_tokens", 0), + completion_tokens=token_count, + finished=finished, + finish_reason=chunk.finish_reason if finished else None, + ) - if finished: - break + if finished: + break return # For LLM, apply chat template and stream @@ -647,129 +658,125 @@ async def _stream_generate_specprefill( tokenizer = self._model.tokenizer n_tokens = len(tokens) - async with self._generation_lock: - - def _run_all(): - try: - return _run_specprefill() - except Exception as e: - logger.error( - "SpecPrefill failed, falling back to normal path: %s", e - ) - return _run_normal() + def _run_all(): + try: + return _run_specprefill() + except Exception as e: + logger.error("SpecPrefill failed, falling back to normal path: %s", e) + return _run_normal() + + def _run_specprefill(): + """Score tokens, sparse prefill, generate autoregressively.""" + import time + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) - def _run_specprefill(): - """Score tokens, sparse prefill, generate autoregressively.""" - import 
time - from types import SimpleNamespace + cache = make_prompt_cache(model) - from ..specprefill import ( - cleanup_rope, - score_tokens, - select_chunks, - sparse_prefill, + try: + # Phase 1: Score with draft model + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + tokens, + prefill_step_size=self._prefill_step_size, ) + t_score = time.monotonic() - t0 - cache = make_prompt_cache(model) + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] - try: - # Phase 1: Score with draft model - t0 = time.monotonic() - importance = score_tokens( - self._draft_model, - tokens, - prefill_step_size=self._prefill_step_size, - ) - t_score = time.monotonic() - t0 - - # Phase 2: Select important chunks - effective_keep = specprefill_keep_pct or self._specprefill_keep_pct - selected = select_chunks(importance, keep_pct=effective_keep) - n_selected = selected.shape[0] - - # Phase 3: Sparse prefill on target model - t0 = time.monotonic() - logits = sparse_prefill( - model, - tokens, - selected, - cache, - step_size=self._prefill_step_size, - ) - t_prefill = time.monotonic() - t0 + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + tokens, + selected, + cache, + step_size=self._prefill_step_size, + ) + t_prefill = time.monotonic() - t0 - logger.info( - "SpecPrefill: scored %d tokens in %.1fs, " - "sparse prefill %d/%d (keep=%.0f%%) in %.1fs", - n_tokens, - t_score, - n_selected, - n_tokens, - n_selected / n_tokens * 100, - t_prefill, - ) + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs", + n_tokens, + t_score, + n_selected, + n_tokens, + n_selected / n_tokens * 100, + t_prefill, + ) - # Phase 4: Generate (simple autoregressive, no MTP) - sampler = make_sampler(temp=temperature, top_p=top_p) - eos_id = tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Phase 4: Generate (simple autoregressive, no MTP) + sampler = make_sampler(temp=temperature, top_p=top_p) + eos_id = tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) - results = [] - generated_ids = [] - prev_decoded = "" + results = [] + generated_ids = [] + prev_decoded = "" - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) - decoded = tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded + decoded = tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, ) + ) - if is_eos: - break + if is_eos: + break - logits = model(y.reshape(1, -1), cache=cache) - y = sampler(logits[:, -1, :]) - mx.eval(y) + logits = model(y.reshape(1, -1), cache=cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) - return results + return results - finally: - cleanup_rope(model) + finally: + cleanup_rope(model) - def _run_normal(): - """Fallback: normal generation without specprefill.""" - from types import SimpleNamespace + def _run_normal(): + """Fallback: normal generation without specprefill.""" + from types import 
SimpleNamespace - results = [] - for chunk in self._model.stream_generate( - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - **kwargs, - ): - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - results.append( - SimpleNamespace( - text=new_text, - finish_reason=getattr(chunk, "finish_reason", None), - ) + results = [] + for chunk in self._model.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ): + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + results.append( + SimpleNamespace( + text=new_text, + finish_reason=getattr(chunk, "finish_reason", None), ) - return results + ) + return results - all_resps = await asyncio.to_thread(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" @@ -1010,194 +1017,192 @@ async def _stream_generate_text( ) use_specprefill = False - # Run under generation lock, all Metal ops in single thread - async with self._generation_lock: + # Run all Metal ops in a single serialized thread. + def _run_all(): + nonlocal backbone_cache, prompt_to_send - def _run_all(): - nonlocal backbone_cache, prompt_to_send + model = self._text_model - model = self._text_model + # Cache MISS with valid prefix: prefill system tokens and snapshot + if ( + not cache_hit + and system_token_count > 0 + and system_tokens is not None + and suffix_tokens is not None + ): + mc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + + # Prefill system tokens in chunks (matching generate_step) + step = self._prefill_step_size + while sys_arr.size > step: + model(sys_arr[:step][None], cache=mc) + mx.eval([c.state for c in mc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=mc) + mx.eval([c.state for c in mc]) + + # Snapshot backbone cache (immutable mx.arrays, safe to reuse) + snapshot = [c.state for c in mc] + mx.eval([s for pair in snapshot for s in pair]) + + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + + backbone_cache = mc + prompt_to_send = mx.array(suffix_tokens) + logger.info( + "System KV cache: stored %d-token snapshot (%.1f MB), " + "prefilling %d remaining", + system_token_count, + sum(c.nbytes for c in mc) / 1e6, + len(suffix_tokens), + ) - # Cache MISS with valid prefix: prefill system tokens and snapshot - if ( - not cache_hit - and system_token_count > 0 - and system_tokens is not None - and suffix_tokens is not None - ): - mc = make_prompt_cache(model) - sys_arr = mx.array(system_tokens) - - # Prefill system tokens in chunks (matching generate_step) - step = self._prefill_step_size - while sys_arr.size > step: - model(sys_arr[:step][None], cache=mc) - mx.eval([c.state for c in mc]) - sys_arr = sys_arr[step:] - mx.clear_cache() - if sys_arr.size > 0: - model(sys_arr[None], cache=mc) - mx.eval([c.state for c in mc]) - - # Snapshot backbone cache (immutable mx.arrays, safe to reuse) - snapshot = [c.state for c in mc] - mx.eval([s for pair in snapshot for s in pair]) - - self._system_kv_snapshot = snapshot - self._system_kv_hash = system_hash - self._system_kv_token_count = system_token_count - - backbone_cache = mc - prompt_to_send = mx.array(suffix_tokens) - logger.info( - "System KV cache: stored %d-token snapshot (%.1f MB), " - "prefilling %d remaining", - system_token_count, - sum(c.nbytes for c in mc) / 1e6, - 
len(suffix_tokens), + # --- SpecPrefill path (with fallback to normal on failure) --- + if use_specprefill: + try: + return _run_specprefill(model, backbone_cache) + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal MTP path: %s", + e, ) + # Discard potentially corrupted cache + backbone_cache = None + prompt_to_send = full_prompt + + # --- Normal path (MTP via mlx_lm stream_generate) --- + prompt_cache = None + if backbone_cache is not None: + # Add MTP cache on top of backbone + if hasattr(model, "make_mtp_cache"): + mtp_cache = model.make_mtp_cache() + prompt_cache = backbone_cache + mtp_cache + else: + prompt_cache = backbone_cache - # --- SpecPrefill path (with fallback to normal on failure) --- - if use_specprefill: - try: - return _run_specprefill(model, backbone_cache) - except Exception as e: - logger.error( - "SpecPrefill failed, falling back to normal MTP path: %s", - e, - ) - # Discard potentially corrupted cache - backbone_cache = None - prompt_to_send = full_prompt - - # --- Normal path (MTP via mlx_lm stream_generate) --- - prompt_cache = None - if backbone_cache is not None: - # Add MTP cache on top of backbone - if hasattr(model, "make_mtp_cache"): - mtp_cache = model.make_mtp_cache() - prompt_cache = backbone_cache + mtp_cache - else: - prompt_cache = backbone_cache + results = [] + gen_kwargs = dict( + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prefill_step_size=self._prefill_step_size, + ) + if prompt_cache is not None: + gen_kwargs["prompt_cache"] = prompt_cache + + for resp in mlx_stream_generate( + model, + self._text_tokenizer, + prompt=prompt_to_send, + **gen_kwargs, + ): + results.append(resp) + return results + + def _run_specprefill(model, bc): + """Score tokens, sparse prefill, generate without MTP.""" + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) - results = [] - gen_kwargs = dict( - max_tokens=max_tokens, - sampler=sampler, - mtp=True, + # Create backbone cache if not already from system KV + if bc is None: + bc = make_prompt_cache(model) + + try: + # Phase 1: Score with draft model + import time + + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + specprefill_tokens, prefill_step_size=self._prefill_step_size, ) - if prompt_cache is not None: - gen_kwargs["prompt_cache"] = prompt_cache + t_score = time.monotonic() - t0 - for resp in mlx_stream_generate( - model, - self._text_tokenizer, - prompt=prompt_to_send, - **gen_kwargs, - ): - results.append(resp) - return results + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] + n_total = len(specprefill_tokens) - def _run_specprefill(model, bc): - """Score tokens, sparse prefill, generate without MTP.""" - from types import SimpleNamespace + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + specprefill_tokens, + selected, + bc, + step_size=self._prefill_step_size, + position_offset=specprefill_offset, + ) + t_prefill = time.monotonic() - t0 - from ..specprefill import ( - cleanup_rope, - score_tokens, - select_chunks, - sparse_prefill, + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " + "(offset=%d, effective_keep=%.2f)", + n_total, + t_score, + n_selected, + n_total, + n_selected / n_total 
* 100, + t_prefill, + specprefill_offset, + effective_keep, ) - # Create backbone cache if not already from system KV - if bc is None: - bc = make_prompt_cache(model) + # Phase 4: Generate (simple autoregressive, no MTP) + eos_id = self._text_tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) - try: - # Phase 1: Score with draft model - import time - - t0 = time.monotonic() - importance = score_tokens( - self._draft_model, - specprefill_tokens, - prefill_step_size=self._prefill_step_size, - ) - t_score = time.monotonic() - t0 - - # Phase 2: Select important chunks - effective_keep = specprefill_keep_pct or self._specprefill_keep_pct - selected = select_chunks(importance, keep_pct=effective_keep) - n_selected = selected.shape[0] - n_total = len(specprefill_tokens) - - # Phase 3: Sparse prefill on target model - t0 = time.monotonic() - logits = sparse_prefill( - model, - specprefill_tokens, - selected, - bc, - step_size=self._prefill_step_size, - position_offset=specprefill_offset, - ) - t_prefill = time.monotonic() - t0 + results = [] + generated_ids = [] + prev_decoded = "" - logger.info( - "SpecPrefill: scored %d tokens in %.1fs, " - "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " - "(offset=%d, effective_keep=%.2f)", - n_total, - t_score, - n_selected, - n_total, - n_selected / n_total * 100, - t_prefill, - specprefill_offset, - effective_keep, - ) + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) - # Phase 4: Generate (simple autoregressive, no MTP) - eos_id = self._text_tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Incremental text decode + decoded = self._text_tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded - results = [] - generated_ids = [] - prev_decoded = "" - - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) - - # Incremental text decode - decoded = self._text_tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded - - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, ) + ) - if is_eos: - break + if is_eos: + break - # Next token - logits = model(y.reshape(1, -1), cache=bc) - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Next token + logits = model(y.reshape(1, -1), cache=bc) + y = sampler(logits[:, -1, :]) + mx.eval(y) - return results + return results - finally: - cleanup_rope(model) + finally: + cleanup_rope(model) - all_resps = await asyncio.to_thread(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" From 8767fe6151cf5b38a1cbcc8b8e12b1190809d2a8 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Sat, 11 Apr 2026 00:22:24 +0200 Subject: [PATCH 18/45] scheduler: preserve prompt checkpoints in chunked prefill resume path (#221) * Fix chunked prefill for mlx-lm prompt checkpoints * fix: invoke prompt_checkpoint_callback in chunked-prefill path The upstream BatchGenerator contract requires prompt_checkpoint_callback to fire after cache finalization, before the checkpoint tail model call. The chunked-prefill monkeypatch preserved the checkpoint field but never invoked the callback, breaking the upstream checkpoint contract. 
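The ordering the upstream contract expects is roughly the following
(a sketch in BatchGenerator terms, mirroring the fix below rather than
verbatim mlx-lm code):

    for c in prompt_cache:
        c.finalize()
    if self.prompt_checkpoint_callback is not None:
        self.prompt_checkpoint_callback(
            [(uid, checkpoint, _lazy_extract_cache(prompt_cache, i))
             for i, uid in enumerate(uids)]
        )
    # only now replay the checkpoint tail and call _step()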
Wire _lazy_extract_cache from mlx-lm and invoke the callback at the correct semantic boundary. Add regression test verifying the callback fires with the correct uid and checkpoint offset. * test: cover checkpoint tail replay on upstream refresh * style: format prompt checkpoint refresh * fix: tolerate mlx-lm Batch export drift in chunked prefill --- tests/test_batching.py | 358 +++++++++++++++++++++++++++++++++++++++++ vllm_mlx/scheduler.py | 116 +++++++++++-- 2 files changed, 463 insertions(+), 11 deletions(-) diff --git a/tests/test_batching.py b/tests/test_batching.py index 7dc050ee3..fc2eefcde 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -7,8 +7,10 @@ """ import asyncio +import importlib import pytest from unittest.mock import MagicMock +import mlx.core as mx from vllm_mlx.request import ( Request, @@ -20,8 +22,11 @@ Scheduler, SchedulerConfig, SchedulingPolicy, + _install_chunked_prefill, ) +mlx_generate = importlib.import_module("mlx_lm.generate") + class TestRequest: """Tests for Request class.""" @@ -211,6 +216,359 @@ def mock_model(self): """Create a mock model.""" return MagicMock() + def test_chunked_prefill_accepts_prompt_checkpoints(self, monkeypatch): + """Chunked prefill must match mlx-lm's 7-field prompt tuples.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + + class FakeBatchGenerator: + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3, 4, 5], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = set() + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = None + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + _install_chunked_prefill(batch_gen, budget=4) + + responses = batch_gen._next() + + assert responses == [] + assert batch_gen._partial is not None + assert batch_gen._partial["prompt_checkpoint"] == 3 + assert batch_gen._partial["processed"] == 2 + + def test_chunked_prefill_invokes_checkpoint_callback(self, monkeypatch): + """prompt_checkpoint_callback must fire after finalization.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + def extract(self, idx): + return self + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + generation_tokens = 0 + + callback_payloads = [] + + from collections import namedtuple + + _Response = namedtuple( + "Response", ["uid", "token", "logprobs", "finish_reason", "cache"] + ) + + class FakeBatchGenerator: + Response = _Response + + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3], + 16, + 
[FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = set() + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = ( + lambda entries: callback_payloads.extend(entries) + ) + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + def _step(self, inputs, cache, samplers, logits_processors, tokens): + return mx.array([99]), mx.array([-1.0]) + + def _generation_step(self): + if self.active_batch is not None: + self.active_batch = None + return [] + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + batch_gen.stop_tokens = {99} + _install_chunked_prefill(batch_gen, budget=1) + + # First _next: starts partial prefill (processes 1 token) + batch_gen._next() + assert batch_gen._partial is not None + + # Second _next: finishes prefill, fires checkpoint callback, + # then runs generation step which completes (stop token). + batch_gen._next() + + assert len(callback_payloads) == 1 + uid, checkpoint, _cache_gen = callback_payloads[0] + assert uid == 7 + assert checkpoint == 1 + + def test_chunked_prefill_replays_checkpoint_tail_before_step(self, monkeypatch): + """checkpoint tails >1 must be replayed after finalize before _step.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + def extract(self, idx): + return self + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + generation_tokens = 0 + + callback_payloads = [] + model_calls = [] + step_inputs = [] + + from collections import namedtuple + + _Response = namedtuple( + "Response", ["uid", "token", "logprobs", "finish_reason", "cache"] + ) + + class FakeBatchGenerator: + Response = _Response + + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3, 4, 5], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = {99} + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = ( + lambda entries: callback_payloads.extend(entries) + ) + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + + def model(self, inputs, cache=None): + model_calls.append(inputs.tolist()) + + def _step(self, inputs, cache, samplers, logits_processors, tokens): + step_inputs.append(inputs.tolist()) + return mx.array([99]), mx.array([-1.0]) + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + _install_chunked_prefill(batch_gen, budget=2) + + # First _next: process the first chunk and leave a 3-token checkpoint tail. 
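+        # (prompt [1, 2, 3, 4, 5] with checkpoint field 2 -> a 5 - 2 = 3 token
+        # tail, so budget=2 prefills exactly [1, 2] in this call).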
+ batch_gen._next() + assert batch_gen._partial is not None + assert batch_gen._partial["prompt_checkpoint"] == 3 + + # Second _next: finalize, fire callback, replay the checkpoint tail, step. + batch_gen._next() + + assert model_calls == [[[1, 2]], [[3, 4]]] + assert step_inputs[0] == [[3, 4, 5]] + assert len(callback_payloads) == 1 + uid, checkpoint, _cache_gen = callback_payloads[0] + assert uid == 7 + assert checkpoint == 3 + + def test_chunked_prefill_works_without_private_mlx_generate_exports( + self, monkeypatch + ): + """Chunked prefill should tolerate missing private mlx_lm.generate exports.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + def extract(self, idx): + return self + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + generation_tokens = 0 + + from collections import namedtuple + + _Response = namedtuple( + "Response", ["uid", "token", "logprobs", "finish_reason", "cache"] + ) + + class FakeBatchGenerator: + Response = _Response + + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = {99} + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = None + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + def _step(self, inputs, cache, samplers, logits_processors, tokens): + return mx.array([99]), mx.array([-1.0]) + + def _generation_step(self): + if self.active_batch is not None: + self.active_batch = None + return [] + + monkeypatch.delattr(mlx_generate, "Batch", raising=False) + monkeypatch.delattr(mlx_generate, "_lazy_extract_cache", raising=False) + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + _install_chunked_prefill(batch_gen, budget=1) + + batch_gen._next() + assert batch_gen._partial is not None + batch_gen._next() + assert batch_gen.active_batch is None + def test_scheduler_creation(self, mock_model, mock_tokenizer): """Test scheduler creation.""" scheduler = Scheduler( diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index ec4684049..7b71ada22 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -148,13 +148,66 @@ def _install_chunked_prefill( import time as _time from mlx_lm.generate import ( - Batch, _left_pad_prompts, _make_cache, _merge_caches, _right_pad_prompts, ) + try: + from mlx_lm.generate import _lazy_extract_cache + except ImportError: + + def _lazy_extract_cache(cache, idx): + return (c.extract(idx) for c in cache) + + try: + from mlx_lm.generate import Batch as _batch_cls + except ImportError: + + @dataclass + class _batch_cls: + uids: List[int] + y: Any + logprobs: List[Any] + max_tokens: List[int] + num_tokens: List[int] + cache: List[Any] + samplers: List[Any] + logits_processors: List[Any] + tokens: List[Any] + + def __len__(self): + return len(self.uids) + + def filter(self, keep_idx: List[int]): + self.uids = 
[self.uids[k] for k in keep_idx] + self.logprobs = [self.logprobs[k] for k in keep_idx] + self.max_tokens = [self.max_tokens[k] for k in keep_idx] + self.num_tokens = [self.num_tokens[k] for k in keep_idx] + self.samplers = [self.samplers[k] for k in keep_idx] + self.logits_processors = [self.logits_processors[k] for k in keep_idx] + self.tokens = [self.tokens[k] for k in keep_idx] + keep_idx_mx = mx.array(keep_idx, mx.int32) + self.y = self.y[keep_idx_mx] + for c in self.cache: + c.filter(keep_idx_mx) + + def extend(self, other): + self.uids.extend(other.uids) + self.y = mx.concatenate([self.y, other.y]) + self.logprobs.extend(other.logprobs) + self.num_tokens.extend(other.num_tokens) + self.max_tokens.extend(other.max_tokens) + self.samplers.extend(other.samplers) + self.logits_processors.extend(other.logits_processors) + self.tokens.extend(other.tokens) + for c, o in zip(self.cache, other.cache): + c.extend(o) + + def extract_cache(self, idx): + return [c.extract(idx) for c in self.cache] + # Keep references to originals _orig_next = batch_gen._next _orig_remove = batch_gen.remove @@ -268,8 +321,13 @@ def _chunked_next(self=batch_gen): # noqa: C901 inputs = partial["inputs"] prompt_cache = partial["cache"] remaining = inputs.shape[1] + prompt_checkpoint = max(1, int(partial.get("prompt_checkpoint", 1))) - n_to_process = min(budget, remaining - 1) if remaining > 1 else 0 + n_to_process = ( + min(budget, remaining - prompt_checkpoint) + if remaining > prompt_checkpoint + else 0 + ) if n_to_process > 0: self.model(mx.contiguous(inputs[:, :n_to_process]), cache=prompt_cache) @@ -294,8 +352,8 @@ def _chunked_next(self=batch_gen): # noqa: C901 if partial.get("is_cached"): mx.clear_cache() - # Check if prefill is done (only 1 token left or 0) - if inputs.shape[1] <= 1: + # Check if prefill is done once only the checkpoint tail remains. + if inputs.shape[1] <= prompt_checkpoint: # Finalize if partial.get("is_cached"): mx.eval([c.state for c in prompt_cache]) @@ -303,8 +361,31 @@ def _chunked_next(self=batch_gen): # noqa: C901 for c in prompt_cache: c.finalize() + + if self.prompt_checkpoint_callback is not None: + self.prompt_checkpoint_callback( + [ + ( + uid, + prompt_checkpoint, + _lazy_extract_cache(prompt_cache, i), + ) + for i, uid in enumerate(partial["uids"]) + ] + ) mx.clear_cache() + # Mirror upstream BatchGenerator semantics: after finalize() and + # the checkpoint callback, replay the remaining checkpoint tail + # except for the final token, which _step() consumes. + if prompt_checkpoint > 1: + self.model( + mx.contiguous(inputs[:, : prompt_checkpoint - 1]), + cache=prompt_cache, + ) + mx.eval([c.state for c in prompt_cache]) + mx.clear_cache() + y, logprobs = self._step( inputs, prompt_cache, @@ -314,10 +395,10 @@ def _chunked_next(self=batch_gen): # noqa: C901 ) mx.async_eval(y, logprobs) - new_batch = Batch( + new_batch = _batch_cls( list(partial["uids"]), y, - logprobs, + list(logprobs), list(partial["max_tokens"]), [0] * len(partial["uids"]), prompt_cache, @@ -393,12 +474,20 @@ def _chunked_next(self=batch_gen): # noqa: C901 caches, samplers, logits_processors, - _prompt_checkpoints, + prompt_checkpoints, ) = zip(*batch_prompts) lengths = [len(p) for p in inputs_raw] max_length = max(lengths) padding = [max_length - ln for ln in lengths] tokens = [mx.array(inp) for inp in inputs_raw] + # Match mlx-lm's prompt_checkpoint contract: positive values + # name the checkpoint token position in the prompt, while + # non-positive values already encode an offset from the end. 
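+                    # e.g. for a 5-token prompt: pc=2 -> tail of 3 tokens,
+                    # and pc=-3 -> tail of 3 tokens.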
+ checkpoint_offsets = [ + (ln - pc if pc > 0 else -pc) + for ln, pc in zip(lengths, prompt_checkpoints) + ] + prompt_checkpoint = max(1, max(checkpoint_offsets)) is_cached = not all(c[0].empty() for c in caches) self._stats.prompt_tokens += sum(lengths) @@ -409,12 +498,14 @@ def _chunked_next(self=batch_gen): # noqa: C901 self.model, padding, self.max_kv_size ) else: - last_inputs = mx.array([p[-1:] for p in inputs_raw]) + last_inputs = mx.array( + [p[-prompt_checkpoint:] for p in inputs_raw] + ) padded = _right_pad_prompts(inputs_raw, max_length=max_length) prompt_cache = _merge_caches(caches) for c in prompt_cache: c.prepare( - lengths=[ln - 1 for ln in lengths], + lengths=[ln - prompt_checkpoint for ln in lengths], right_padding=padding, ) @@ -437,9 +528,11 @@ def _chunked_next(self=batch_gen): # noqa: C901 _pb = getattr(_req0, "prefix_boundary", 0) if _req0 else 0 _cached = getattr(_req0, "cached_tokens", 0) if _req0 else 0 _adjusted_pb = _pb - _cached - if 0 < _adjusted_pb < padded.shape[1]: + if 0 < _adjusted_pb < padded.shape[1] - prompt_checkpoint + 1: _first_chunk = _adjusted_pb - n_to_process = min(_first_chunk, padded.shape[1] - 1) + n_to_process = min( + _first_chunk, padded.shape[1] - prompt_checkpoint + ) if n_to_process > 0: self.model( mx.contiguous(padded[:, :n_to_process]), @@ -458,6 +551,7 @@ def _chunked_next(self=batch_gen): # noqa: C901 "max_tokens": list(max_tokens_list), "samplers": list(samplers), "logits_processors": list(logits_processors), + "prompt_checkpoint": prompt_checkpoint, "processed": n_to_process, "total": max_length, "is_cached": is_cached, From fc25c8c54dd051795d213bd04325910f0b3801ba Mon Sep 17 00:00:00 2001 From: Manus McAuliffe Date: Fri, 10 Apr 2026 23:43:10 +0100 Subject: [PATCH 19/45] fix: include tokens in GenerationOutput for BatchedEngine when model is mllm --- vllm_mlx/engine/batched.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 49fae5439..d49eb7c1a 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -468,6 +468,7 @@ async def generate( return GenerationOutput( text=clean_output_text(output.output_text), + tokens=output.output_token_ids, prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens, finish_reason=output.finish_reason, From b274f276c40be1aa7973f61ce61583efb1aae6d2 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Fri, 10 Apr 2026 19:57:26 -0500 Subject: [PATCH 20/45] fix(tests): apply black to batched engine generate test (#279) --- tests/test_batched_engine.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_batched_engine.py b/tests/test_batched_engine.py index 65abbf09e..73a7e8ffa 100644 --- a/tests/test_batched_engine.py +++ b/tests/test_batched_engine.py @@ -31,7 +31,9 @@ def _make_mock_request_output( """Build a mock RequestOutput (as returned by AsyncEngineCore).""" mock = MagicMock() mock.output_text = output_text - mock.output_token_ids = output_token_ids if output_token_ids is not None else [3681, 374, 279] + mock.output_token_ids = ( + output_token_ids if output_token_ids is not None else [3681, 374, 279] + ) mock.prompt_tokens = prompt_tokens mock.completion_tokens = completion_tokens mock.finish_reason = finish_reason @@ -48,7 +50,9 @@ async def test_tokens_field_is_populated(self): mock_engine.generate = AsyncMock(return_value=mock_output) engine._engine = mock_engine - result = await engine.generate(prompt="What is the capital of France?", max_tokens=10) + result = 
await engine.generate( + prompt="What is the capital of France?", max_tokens=10 + ) assert result.tokens == token_ids From 3c00c44c1e117b46bcade0afb18581b29e6c0a5e Mon Sep 17 00:00:00 2001 From: Thump604 Date: Fri, 10 Apr 2026 21:36:57 -0500 Subject: [PATCH 21/45] fix(reasoning): detect split think tags across deltas --- vllm_mlx/reasoning/think_parser.py | 64 +++++++++++++++++------------- 1 file changed, 37 insertions(+), 27 deletions(-) diff --git a/vllm_mlx/reasoning/think_parser.py b/vllm_mlx/reasoning/think_parser.py index dad085f15..a2e9cb727 100644 --- a/vllm_mlx/reasoning/think_parser.py +++ b/vllm_mlx/reasoning/think_parser.py @@ -11,9 +11,10 @@ 3. No tags: pure content Performance: The streaming parser uses a simple state machine to track the -current phase (pre-think / thinking / content). Each token is classified in -O(1) by checking only the delta text — the accumulated output is never -rescanned. This keeps per-token overhead constant regardless of output length. +current phase (pre-think / thinking / content). Tag completion is detected +against the accumulated text for correctness when `` / `` are +split across delta boundaries, but phase tracking still avoids the old +whole-output rescanning behavior. """ from abc import abstractmethod @@ -36,8 +37,8 @@ class BaseThinkingReasoningParser(ReasoningParser): pre_think -> thinking -> content - Transitions happen when start/end tokens are detected in the delta text. - No accumulated text scanning is performed — each token is O(1). + Transitions are tracked by parser state. Accumulated text is consulted only + to detect when a start/end tag has completed across delta boundaries. """ @property @@ -109,12 +110,8 @@ def extract_reasoning_streaming( Instead of rescanning the full accumulated text on every token, this method tracks the current phase (pre_think / thinking / content) and - only inspects the delta for tag transitions. This makes each call O(1) - regardless of how much text has been generated. - - The method signature is kept compatible with the base class — previous_text - and current_text are accepted but not used for phase detection (they remain - available for subclasses that need them). + only consults accumulated text to detect completed start/end tags that + were split across delta boundaries. Handles three scenarios: 1. Explicit ... in model output @@ -122,8 +119,8 @@ def extract_reasoning_streaming( 3. No tags at all (pure content after first token with no reasoning) Args: - previous_text: Text accumulated before this delta (unused by state machine). - current_text: Text including this delta (unused by state machine). + previous_text: Text accumulated before this delta. + current_text: Text including this delta. delta_text: Just the new text in this chunk. Returns: @@ -136,34 +133,41 @@ def extract_reasoning_streaming( end_tok = self.end_token # ── Phase: pre_think ────────────────────────────────────── - # Haven't seen any tags yet. Could be: + # Haven't seen a completed tag yet. 
Could be: # - About to see (explicit reasoning) # - Already inside implicit reasoning (think was in prompt) # - No reasoning at all (pure content model) if self._phase == "pre_think": - # Check for start tag in this delta - if start_tok in delta_text: + if start_tok in current_text: self._phase = "thinking" - idx = delta_text.find(start_tok) + len(start_tok) - after = delta_text[idx:] - # Edge case: both tags in same delta + idx = delta_text.find(start_tok) + after = delta_text[idx + len(start_tok) :] if idx >= 0 else delta_text + if end_tok in after: self._phase = "content" eidx = after.find(end_tok) reasoning = after[:eidx] - content = after[eidx + len(end_tok):] + content = after[eidx + len(end_tok) :] + if not reasoning and not content: + return None return DeltaMessage( reasoning=reasoning or None, content=content or None, ) return DeltaMessage(reasoning=after) if after else None - # Check for end tag (implicit mode — think was in prompt) - if end_tok in delta_text: + # Implicit mode: completed without an explicit . + if end_tok in current_text: self._phase = "content" idx = delta_text.find(end_tok) - reasoning = delta_text[:idx] - content = delta_text[idx + len(end_tok):] + if idx >= 0: + reasoning = delta_text[:idx] + content = delta_text[idx + len(end_tok) :] + else: + reasoning = None + content = delta_text + if not reasoning and not content: + return None return DeltaMessage( reasoning=reasoning or None, content=content or None, @@ -178,11 +182,17 @@ def extract_reasoning_streaming( # ── Phase: thinking ─────────────────────────────────────── # Inside a reasoning block, waiting for end tag. if self._phase == "thinking": - if end_tok in delta_text: + if end_tok in current_text and end_tok not in previous_text: self._phase = "content" idx = delta_text.find(end_tok) - reasoning = delta_text[:idx] - content = delta_text[idx + len(end_tok):] + if idx >= 0: + reasoning = delta_text[:idx] + content = delta_text[idx + len(end_tok) :] + else: + reasoning = delta_text + content = None + if not reasoning and not content: + return None return DeltaMessage( reasoning=reasoning or None, content=content or None, From 660552e7c7d0d571461d150766ae1a8132c1f0dc Mon Sep 17 00:00:00 2001 From: Michael Ledin Date: Sat, 11 Apr 2026 06:23:23 +0300 Subject: [PATCH 22/45] integrate tool call parser into reasoning parser streaming path (#253) * fix(server): integrate tool call parser into reasoning parser streaming path * use _model_name instead of request.model in reasoning tool chunk --------- Co-authored-by: Wayner Barrios --- tests/test_server.py | 222 ++++++++++++++++++++++++++++++++++++++++++- vllm_mlx/server.py | 44 +++++++++ 2 files changed, 265 insertions(+), 1 deletion(-) diff --git a/tests/test_server.py b/tests/test_server.py index 9fb86a3e5..ad8e0a9b9 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the OpenAI-compatible API server.""" +import json import platform import sys @@ -304,7 +305,7 @@ def test_rate_limiter_enforces_limit(self): # First 3 requests should be allowed for i in range(3): allowed, retry_after = limiter.is_allowed("client1") - assert allowed is True, f"Request {i+1} should be allowed" + assert allowed is True, f"Request {i + 1} should be allowed" assert retry_after == 0 # 4th request should be blocked @@ -677,6 +678,225 @@ def test_rate_limiter_window_cleanup(self): assert allowed is True +class TestStreamChatCompletion: + """Tests for streaming chat completion behavior.""" + + 
@pytest.mark.asyncio + async def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch): + """Tool markup after should emit tool_calls chunks.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.reasoning import DeltaMessage + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput(text="", new_text="reasoning", finished=False), + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput( + text="", new_text='{"name":"search"}', finished=False + ), + GenerationOutput( + text="", + new_text="", + finished=True, + finish_reason="stop", + prompt_tokens=7, + completion_tokens=3, + ), + ] + for chunk in chunks: + yield chunk + + class FakeReasoningParser: + def reset_state(self): + self._in_reasoning = False + + def extract_reasoning_streaming( + self, previous_text, current_text, delta_text + ): + if delta_text == "": + self._in_reasoning = True + return None + if delta_text == "": + self._in_reasoning = False + return None + if self._in_reasoning: + return DeltaMessage(reasoning=delta_text) + return DeltaMessage(content=delta_text) + + class FakeToolParser: + def reset(self): + pass + + def extract_tool_calls_streaming( + self, previous_text, current_text, delta_text + ): + if "" in current_text: + return { + "tool_calls": [ + { + "index": 0, + "id": "call_123", + "type": "function", + "function": { + "name": "search", + "arguments": '{"q":"weather"}', + }, + } + ] + } + return None + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", FakeReasoningParser()) + monkeypatch.setattr(server, "_enable_auto_tool_choice", True) + monkeypatch.setattr(server, "_tool_call_parser", "fake") + monkeypatch.setattr(server, "_tool_parser_instance", FakeToolParser()) + + request = ChatCompletionRequest( + model="request-model", + messages=[Message(role="user", content="hi")], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + + tool_payloads = [ + payload + for payload in payloads + if payload["choices"] and payload["choices"][0]["delta"].get("tool_calls") + ] + + assert payloads[0]["choices"][0]["delta"]["role"] == "assistant" + assert payloads[1]["choices"][0]["delta"]["reasoning"] == "reasoning" + assert len(tool_payloads) == 1 + assert ( + tool_payloads[0]["choices"][0]["delta"]["tool_calls"][0]["function"]["name"] + == "search" + ) + assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls" + assert tool_payloads[0]["usage"] == { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + } + + @pytest.mark.asyncio + async def test_reasoning_stream_skips_tool_parser_until_markup_appears( + self, monkeypatch + ): + """Plain post-reasoning content should stream normally on the fast path.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.reasoning import DeltaMessage + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + 
model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput(text="", new_text="reasoning", finished=False), + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput( + text="", + new_text="final answer", + finished=True, + finish_reason="stop", + ), + ] + for chunk in chunks: + yield chunk + + class FakeReasoningParser: + def reset_state(self): + self._in_reasoning = False + + def extract_reasoning_streaming( + self, previous_text, current_text, delta_text + ): + if delta_text == "": + self._in_reasoning = True + return None + if delta_text == "": + self._in_reasoning = False + return None + if self._in_reasoning: + return DeltaMessage(reasoning=delta_text) + return DeltaMessage(content=delta_text) + + class TrackingToolParser: + def __init__(self): + self.calls = [] + + def reset(self): + self.calls.clear() + + def extract_tool_calls_streaming( + self, previous_text, current_text, delta_text + ): + self.calls.append((previous_text, current_text, delta_text)) + return {"content": delta_text} + + tool_parser = TrackingToolParser() + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", FakeReasoningParser()) + monkeypatch.setattr(server, "_enable_auto_tool_choice", True) + monkeypatch.setattr(server, "_tool_call_parser", "fake") + monkeypatch.setattr(server, "_tool_parser_instance", tool_parser) + + request = ChatCompletionRequest( + model="request-model", + messages=[Message(role="user", content="hi")], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + + assert tool_parser.calls == [] + assert payloads[1]["choices"][0]["delta"]["reasoning"] == "reasoning" + assert payloads[2]["choices"][0]["delta"]["content"] == "final answer" + assert payloads[2]["choices"][0]["finish_reason"] == "stop" + + # ============================================================================= # Integration Tests (require running server) # ============================================================================= diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index ebaefe8d8..adb63d4ea 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -2195,6 +2195,50 @@ async def stream_chat_completion( # Skip this chunk (e.g., token itself) continue + # Run tool parser on content delta (post-reasoning text only). + # The reasoning parser suppresses tokens, so tool calls + # only appear in delta_msg.content after is emitted. 
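+                # Fast path: until a '<' shows up in a delta we only
+                # accumulate text and let the chunk stream through
+                # untouched; once markup is possible, every subsequent
+                # delta goes through the tool parser.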
+ if tool_parser and delta_msg.content: + content_delta = delta_msg.content + if not tool_markup_possible and "<" not in content_delta: + tool_accumulated_text += content_delta + else: + if not tool_markup_possible: + tool_markup_possible = True + + tool_previous = tool_accumulated_text + tool_accumulated_text += content_delta + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, content_delta + ) + + if tool_result is None: + # Inside tool markup - suppress content + continue + + if "tool_calls" in tool_result: + tool_calls_detected = True + tool_chunk = ChatCompletionChunk( + id=response_id, + model=_model_name, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=tool_result["tool_calls"] + ), + finish_reason=( + "tool_calls" if output.finished else None + ), + ) + ], + usage=get_usage(output) if output.finished else None, + ) + yield f"data: {tool_chunk.model_dump_json()}\n\n" + continue + + # Update content with what the tool parser returned + delta_msg.content = tool_result.get("content", "") or None + chunk = ChatCompletionChunk( id=response_id, model=_model_name, From 7b1bfd49312ec7bbcd682bb3744f45b6a6144762 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Fri, 10 Apr 2026 22:29:19 -0500 Subject: [PATCH 23/45] honor tool_choice=none by stripping tools and suppressing parsing (#173) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When tool_choice='none', models should never return tool calls. Two fixes: 1. Strip tools from chat template context — prevents templates from activating tool-call token generation. 2. Suppress tool call parsing — _parse_tool_calls_with_parser() returns early with no tools, streaming parser skips initialization. Applied across all server paths: chat completions (streaming + non-streaming), Anthropic adapter (streaming + non-streaming). Fixes #162 --- tests/test_tool_choice_none.py | 65 ++++++++++++++++++++++++++++++++++ vllm_mlx/server.py | 17 ++++++--- 2 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 tests/test_tool_choice_none.py diff --git a/tests/test_tool_choice_none.py b/tests/test_tool_choice_none.py new file mode 100644 index 000000000..d4af223fe --- /dev/null +++ b/tests/test_tool_choice_none.py @@ -0,0 +1,65 @@ +"""Tests for tool_choice='none' handling.""" + + +class TestToolChoiceNoneParserSuppression: + """Verify tool call parsing is suppressed when tool_choice='none'.""" + + def test_parse_tool_calls_skipped_when_tool_choice_none(self): + """_parse_tool_calls_with_parser should return no tools when tool_choice='none'.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + # Text that looks like a tool call + text = '{"name": "get_weather", "arguments": {"city": "London"}}' + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "Hello"}], + tool_choice="none", + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + # With tool_choice="none", parser should be suppressed + assert tool_calls is None + assert cleaned == text # text returned unchanged + + def test_parse_tool_calls_works_when_tool_choice_auto(self): + """Tool parsing should work normally when tool_choice is not 'none'.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + text = "Hello, how can I help?" 
+ request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "Hello"}], + tool_choice="auto", + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + # No tool markup in text, so no tools found — but parser was NOT skipped + assert tool_calls is None + + def test_parse_tool_calls_works_when_tool_choice_absent(self): + """Tool parsing should work when tool_choice is not set.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + text = "Hello, how can I help?" + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "Hello"}], + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + assert tool_calls is None + + def test_tool_markup_ignored_when_tool_choice_none(self): + """Even Qwen bracket-style tool calls should be suppressed.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + text = '[Calling tool: get_weather({"city": "London"})]' + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "weather?"}], + tool_choice="none", + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + assert tool_calls is None + assert cleaned == text diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index adb63d4ea..1c0e06d17 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -379,6 +379,14 @@ def _parse_tool_calls_with_parser( request_dict = request.model_dump() if request else None + # tool_choice="none" means never return tool calls — skip all parsing + if request is not None: + tool_choice = getattr(request, "tool_choice", None) + if tool_choice is None and request_dict: + tool_choice = request_dict.get("tool_choice") + if tool_choice == "none": + return output_text, None + # If auto tool choice is not enabled, use the generic parser if not _enable_auto_tool_choice or not _tool_call_parser: return parse_tool_calls(output_text, request_dict) @@ -1430,7 +1438,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct # Add tools if provided - if request.tools: + if request.tools and request.tool_choice != "none": chat_kwargs["tools"] = convert_tools_for_template(request.tools) if request.stream: @@ -1630,7 +1638,7 @@ async def create_anthropic_message( "top_p": openai_request.top_p, } - if openai_request.tools: + if openai_request.tools and openai_request.tool_choice != "none": chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) start_time = time.perf_counter() @@ -1868,7 +1876,7 @@ async def _stream_anthropic_messages( "top_p": openai_request.top_p, } - if openai_request.tools: + if openai_request.tools and openai_request.tool_choice != "none": chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) # Emit message_start @@ -2156,7 +2164,8 @@ async def stream_chat_completion( tool_accumulated_text = "" tool_calls_detected = False tool_markup_possible = False # Fast path: skip parsing until '<' seen - if _enable_auto_tool_choice and _tool_call_parser: + tool_choice = getattr(request, "tool_choice", None) + if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": # Initialize parser if needed (same as _parse_tool_calls_with_parser) if _tool_parser_instance is None: try: From b9f2a5f349e90a9eb120c6925ae3d44b4e1583de Mon Sep 17 00:00:00 2001 From: Jan Hilgard 
<89418784+janhilgard@users.noreply.github.com>
Date: Sat, 11 Apr 2026 06:12:26 +0200
Subject: [PATCH 24/45] strip billing header from Anthropic system prompt for
 prefix cache (#277)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Claude Code injects `x-anthropic-billing-header: cc_version=...; cch=HASH;`
into the system prompt. The `cch=` hash changes with every request, causing
token sequences to diverge at position ~40 and completely defeating prefix
cache reuse across turn boundaries.

Strip this header before tokenization so consecutive requests from the same
conversation share 99%+ of their token prefix.

Result: 50s → 3.65s per request (13.7x speedup) on Gemma 4 26B-A4B with
60K-token prompts.

Co-authored-by: Claude Opus 4.6
---
 vllm_mlx/api/anthropic_adapter.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm_mlx/api/anthropic_adapter.py b/vllm_mlx/api/anthropic_adapter.py
index dbb94200f..62c6757b5 100644
--- a/vllm_mlx/api/anthropic_adapter.py
+++ b/vllm_mlx/api/anthropic_adapter.py
@@ -9,6 +9,7 @@
 """

 import json
+import re
 import uuid

 from .anthropic_models import (
@@ -60,6 +61,10 @@ def anthropic_to_openai(request: AnthropicRequest) -> ChatCompletionRequest:
             system_text = "\n".join(parts)
         else:
             system_text = str(request.system)
+        # Strip per-request billing/tracking headers injected by some
+        # clients (e.g. Claude Code). These contain a per-request hash
+        # that prevents prefix-cache reuse across turn boundaries.
+        system_text = re.sub(r"x-anthropic-billing-header:[^\n]*\n?", "", system_text)
         messages.append(Message(role="system", content=system_text))

     # Convert each message

From f61d34e120f4eb6aab162403ac3b09abf5fe099f Mon Sep 17 00:00:00 2001
From: Jan Hilgard
Date: Sat, 11 Apr 2026 09:35:07 +0200
Subject: [PATCH 25/45] =?UTF-8?q?feat:=20backport=20production=20features?=
 =?UTF-8?q?=20=E2=80=94=20MTP,=20tool=20parsers,=20sampling,=20prefill?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

New files:
- patches/qwen3_5_mllm.py: BatchKVCache offset fix for Qwen3.5
- patches/qwen3_5_mtp.py: Runtime MTP injection for Qwen3.5
- tool_parsers/minimax_tool_parser.py: MiniMax-M2 tool parser
- scripts/add_mtp_weights_qwen35.py: Extract MTP weights from BF16

Key changes:
- mllm_batch_generator: chunked prefill, mid-batch extend, MTP hooks,
  patch registration, repetition penalty, prefill abort, think-suffix
  stripping for prefix cache
- mllm_scheduler: request status, cache config, prefill abort
- server: enable_thinking, tool_choice=none, tool argument coercion
- engines: MTP injection, enable_thinking, gpu_memory_utilization
- memory_cache: block LCP for hybrid models (SSM can't be rewound)

Prefix cache fix: enable_thinking=True adds <think>\n to the generation
prompt, breaking PREFIX match across conversation turns. Strip these
tokens from cache keys in both store and fetch paths so stored entries
match as clean prefixes.

Tested: 3.12s → 0.39s (8x) for 1400-token prompts on Qwen3.5-122B
hybrid model.
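The stripping itself is a two-line change at store/fetch time; a minimal
sketch (names illustrative, not the exact attributes in this patch):

    def _cache_key(token_ids: list[int], think_suffix_len: int) -> list[int]:
        # Drop the trailing "<think>\n" tokens so the stored key ends at
        # tokens the next request will reproduce verbatim.
        s = think_suffix_len
        return token_ids[:-s] if s > 0 else token_ids

    # store: prefix_cache.store(_cache_key(ids, s), kv_state)
    # fetch: cached, remaining = prefix_cache.fetch(_cache_key(ids, s))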
Co-Authored-By: Claude Opus 4.6 --- scripts/add_mtp_weights_qwen35.py | 470 ++++++++ vllm_mlx/api/models.py | 6 + vllm_mlx/api/utils.py | 6 +- vllm_mlx/cli.py | 22 +- vllm_mlx/engine/batched.py | 140 ++- vllm_mlx/engine/simple.py | 15 +- vllm_mlx/engine_core.py | 17 +- vllm_mlx/memory_cache.py | 10 +- vllm_mlx/mllm_batch_generator.py | 1012 +++++++++++++++++- vllm_mlx/mllm_scheduler.py | 268 ++++- vllm_mlx/models/mllm.py | 12 +- vllm_mlx/multimodal_processor.py | 2 +- vllm_mlx/patches/qwen3_5_mllm.py | 120 +++ vllm_mlx/patches/qwen3_5_mtp.py | 399 +++++++ vllm_mlx/scheduler.py | 46 +- vllm_mlx/server.py | 574 ++++++---- vllm_mlx/text_model_from_vlm.py | 22 +- vllm_mlx/tool_parsers/__init__.py | 9 +- vllm_mlx/tool_parsers/auto_tool_parser.py | 10 +- vllm_mlx/tool_parsers/minimax_tool_parser.py | 172 +++ vllm_mlx/utils/tokenizer.py | 119 +- 21 files changed, 3094 insertions(+), 357 deletions(-) create mode 100644 scripts/add_mtp_weights_qwen35.py create mode 100644 vllm_mlx/patches/qwen3_5_mllm.py create mode 100644 vllm_mlx/patches/qwen3_5_mtp.py create mode 100644 vllm_mlx/tool_parsers/minimax_tool_parser.py diff --git a/scripts/add_mtp_weights_qwen35.py b/scripts/add_mtp_weights_qwen35.py new file mode 100644 index 000000000..1044dc894 --- /dev/null +++ b/scripts/add_mtp_weights_qwen35.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 +""" +Add MTP (Multi-Token Prediction) weights to an existing MLX Qwen3.5 model. + +This script: +1. Fetches the safetensors index from the original BF16 HuggingFace model +2. Identifies shards containing MTP weights (mtp.* keys) +3. Downloads only those shards via curl -C - +4. Extracts MTP weights +5. For MoE models: stacks expert weights (256×) into switch_mlp format +6. Applies norm shift (HF weight → MLX weight+1.0) for RMSNorm keys +7. Quantizes to match the MLX model's quantization scheme +8. 
Saves as mtp/weights.safetensors (subdirectory avoids mlx_vlm glob) + +Supports both: +- MoE models (Qwen3.5-122B-A10B, 35B-A3B): 256 experts, sparse MTP attention +- Dense models (Qwen3.5-27B): full MTP with k/v projections and norms + +Usage: + python add_mtp_weights_qwen35.py --mlx-model-path PATH --source-model MODEL + +Requirements: + pip install mlx +""" + +import argparse +import json +import subprocess +import sys +import tempfile +from pathlib import Path + +# Known model configurations +MODEL_CONFIGS = { + "Qwen/Qwen3.5-122B-A10B": { + "num_experts": 256, + "hidden_size": 3072, + "is_moe": True, + }, + "Qwen/Qwen3.5-35B-A3B": { + "num_experts": 256, + "hidden_size": 2048, + "is_moe": True, + }, + "Qwen/Qwen3.5-27B": { + "num_experts": 0, + "hidden_size": 5120, + "is_moe": False, + }, +} + + +def find_snapshot_dir(model_path: str) -> Path: + """Find the latest snapshot directory in HF cache structure.""" + snapshots_dir = Path(model_path) / "snapshots" + if not snapshots_dir.exists(): + if (Path(model_path) / "config.json").exists(): + return Path(model_path) + raise FileNotFoundError(f"No snapshots found in {model_path}") + snapshots = sorted(snapshots_dir.iterdir(), key=lambda p: p.stat().st_mtime) + if not snapshots: + raise FileNotFoundError(f"No snapshots in {snapshots_dir}") + return snapshots[-1] + + +def fetch_shard_index(source_model: str, download_dir: Path) -> dict: + """Fetch model.safetensors.index.json from HuggingFace.""" + index_url = f"https://huggingface.co/{source_model}/resolve/main/model.safetensors.index.json" + index_path = download_dir / "source_index.json" + + print(f"Fetching shard index from {source_model}...") + result = subprocess.run( + ["curl", "-L", "-C", "-", "-o", str(index_path), index_url], + check=False, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to fetch index: return code {result.returncode}") + + with open(index_path) as f: + return json.load(f) + + +def identify_mtp_shards(index: dict) -> tuple[dict[str, str], set[str]]: + """Identify which shards contain MTP weights. 
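+
+    Only the ``mtp.*`` entries of the index's ``weight_map`` are relevant
+    here. An entry looks like the following (shard filename illustrative):
+        "mtp.layers.0.input_layernorm.weight": "model-00045-of-00048.safetensors"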
+ + Returns: + Tuple of (mtp_key_to_shard mapping, set of shard filenames to download) + """ + weight_map = index.get("weight_map", {}) + mtp_keys = {} + shards_needed = set() + + for key, shard in weight_map.items(): + if key.startswith("mtp."): + mtp_keys[key] = shard + shards_needed.add(shard) + + return mtp_keys, shards_needed + + +def download_shards( + shards: set[str], source_model: str, download_dir: Path +) -> dict[str, Path]: + """Download required shards using curl with resume support.""" + shard_paths = {} + for shard_name in sorted(shards): + shard_url = f"https://huggingface.co/{source_model}/resolve/main/{shard_name}" + shard_path = download_dir / shard_name + + if shard_path.exists(): + size_gb = shard_path.stat().st_size / 1e9 + print(f" {shard_name}: exists ({size_gb:.2f} GB)") + shard_paths[shard_name] = shard_path + continue + + print(f" Downloading {shard_name}...") + result = subprocess.run( + ["curl", "-L", "-C", "-", "-o", str(shard_path), shard_url], + check=False, + ) + if result.returncode != 0: + raise RuntimeError( + f"Download failed for {shard_name}: code {result.returncode}" + ) + + size_gb = shard_path.stat().st_size / 1e9 + print(f" {shard_name}: {size_gb:.2f} GB") + shard_paths[shard_name] = shard_path + + return shard_paths + + +def extract_and_quantize_mtp_weights( + mtp_keys: dict[str, str], + shard_paths: dict[str, Path], + snapshot_dir: Path, + is_moe: bool, + num_experts: int, + no_quantize: bool = False, +): + """Extract MTP weights from BF16 shards, optionally quantize, and save.""" + import mlx.core as mx + + mx.set_default_device(mx.cpu) + + # Read MLX model's quantization config + config_path = snapshot_dir / "config.json" + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + quant_config = text_config.get("quantization", config.get("quantization", {})) + bits = quant_config.get("bits", 4) + group_size = quant_config.get("group_size", 64) + if no_quantize: + print("MTP weights will be saved in BF16 (no quantization)") + else: + print(f"Target quantization: {bits}-bit, group_size={group_size}") + + # Group MTP keys by shard for efficient I/O + shard_to_keys: dict[str, list[str]] = {} + for key, shard in mtp_keys.items(): + shard_to_keys.setdefault(shard, []).append(key) + + # Load all MTP weights + print(f"\nExtracting MTP weights from {len(shard_paths)} shards...") + all_mtp_weights: dict[str, mx.array] = {} + + for shard_name, keys in sorted(shard_to_keys.items()): + shard_path = shard_paths[shard_name] + print(f" Loading {shard_name} ({len(keys)} MTP keys)...") + shard_data = mx.load(str(shard_path)) + for key in keys: + if key in shard_data: + all_mtp_weights[key] = shard_data[key] + del shard_data + + print(f"Loaded {len(all_mtp_weights)} MTP weight tensors") + + # Norm keys that need +1.0 shift (HF centered ~0 → MLX centered ~1) + norm_suffixes = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + ".q_norm.weight", + ".k_norm.weight", + ".pre_fc_norm_hidden.weight", + ".pre_fc_norm_embedding.weight", + "mtp.norm.weight", + ) + + # Keys to keep as FP (not quantize) + skip_quantize_suffixes = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + ".q_norm.weight", + ".k_norm.weight", + "mtp.fc.weight", + "mtp.norm.weight", + "mtp.pre_fc_norm_hidden.weight", + "mtp.pre_fc_norm_embedding.weight", + ".shared_expert_gate.weight", + ) + + def _quantize_one(key: str, weight: mx.array) -> dict[str, mx.array]: + """Quantize a single weight, apply norm 
adjustment.""" + # Norm shift: +1.0 for RMSNorm weights + if any(key.endswith(s) for s in norm_suffixes) and weight.ndim == 1: + weight = weight + 1.0 + mx.eval(weight) + print(f" Norm shift: {key}") + + if no_quantize: + print(f" BF16: {key} {weight.shape}") + return {key: weight} + elif any(key.endswith(s) for s in skip_quantize_suffixes): + print(f" Keep FP: {key} {weight.shape}") + return {key: weight} + elif weight.ndim >= 2 and weight.shape[-1] >= group_size: + q_w, q_s, q_b = mx.quantize(weight, group_size=group_size, bits=bits) + mx.eval(q_w, q_s, q_b) + print(f" Quantize {bits}-bit: {key} {q_w.shape}") + return { + key: q_w, + key.replace(".weight", ".scales"): q_s, + key.replace(".weight", ".biases"): q_b, + } + else: + print(f" Keep FP (small): {key} {weight.shape}") + return {key: weight} + + quantized_weights: dict[str, mx.array] = {} + + if is_moe and num_experts > 0: + # Stack expert weights ONE PROJECTION AT A TIME to minimize peak memory + for proj in ["up_proj", "down_proj", "gate_proj"]: + expert_keys = [ + f"mtp.layers.0.mlp.experts.{e}.{proj}.weight" + for e in range(num_experts) + ] + if all(k in all_mtp_weights for k in expert_keys): + stacked = mx.stack([all_mtp_weights.pop(k) for k in expert_keys]) + mx.eval(stacked) + stacked_key = f"mtp.layers.0.mlp.switch_mlp.{proj}.weight" + print(f" Stacked {num_experts} experts for {proj}: {stacked.shape}") + quantized_weights.update(_quantize_one(stacked_key, stacked)) + del stacked + else: + present = sum(1 for k in expert_keys if k in all_mtp_weights) + if present > 0: + print(f" WARNING: Only {present}/{num_experts} experts for {proj}") + + # Quantize remaining non-expert weights + for key in sorted(all_mtp_weights.keys()): + weight = all_mtp_weights.pop(key) + quantized_weights.update(_quantize_one(key, weight)) + del weight + del all_mtp_weights + + # Save to mtp/ subdirectory (avoids mlx_vlm glob loading all *.safetensors) + mtp_output_dir = snapshot_dir / "mtp" + mtp_output_dir.mkdir(exist_ok=True) + mtp_output_file = mtp_output_dir / "weights.safetensors" + mode_str = "BF16" if no_quantize else "quantized" + print( + f"\nSaving {len(quantized_weights)} {mode_str} MTP weights to {mtp_output_file}" + ) + mx.save_safetensors(str(mtp_output_file), quantized_weights) + + total_bytes = sum(v.nbytes for v in quantized_weights.values()) + print(f"MTP weights size: {total_bytes / 1e6:.1f} MB ({mode_str})") + + return mtp_output_file, list(quantized_weights.keys()) + + +def update_model_index(snapshot_dir: Path, mtp_keys: list[str]): + """Update model.safetensors.index.json to include MTP weight keys.""" + index_path = snapshot_dir / "model.safetensors.index.json" + if not index_path.exists(): + print(f"WARNING: No index file found at {index_path}, skipping index update") + return + + with open(index_path) as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + for key in mtp_keys: + weight_map[key] = "model-mtp.safetensors" + + index["weight_map"] = weight_map + + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + + print(f"Updated {index_path} with {len(mtp_keys)} MTP weight entries") + + +def update_config(snapshot_dir: Path): + """Update config.json to signal MTP availability. + + For Qwen3.5, mtp_num_hidden_layers already exists in text_config. + We add num_nextn_predict_layers at top level for vllm-mlx compatibility. 
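+
+    Illustrative effect (values assumed), given mtp_num_hidden_layers=1:
+        config["num_nextn_predict_layers"] = 1
+        config["text_config"]["num_nextn_predict_layers"] = 1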
+ """ + config_path = snapshot_dir / "config.json" + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + + if num_mtp > 0: + # Set num_nextn_predict_layers for vllm-mlx MTP detection + config["num_nextn_predict_layers"] = num_mtp + text_config["num_nextn_predict_layers"] = num_mtp + if "text_config" in config: + config["text_config"] = text_config + + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"Updated config: num_nextn_predict_layers={num_mtp}") + else: + print("WARNING: mtp_num_hidden_layers not found in config") + + +def main(): + parser = argparse.ArgumentParser(description="Add MTP weights to MLX Qwen3.5 model") + parser.add_argument( + "--mlx-model-path", + type=str, + required=True, + help="Path to MLX model directory (HF cache or direct path)", + ) + parser.add_argument( + "--source-model", + type=str, + required=True, + help="HuggingFace BF16 model to download MTP shards from (e.g., Qwen/Qwen3.5-122B-A10B)", + ) + parser.add_argument( + "--download-dir", + type=str, + default=None, + help="Directory to download shards to (default: temp dir)", + ) + parser.add_argument( + "--skip-download", + action="store_true", + help="Skip download (use existing shards in download-dir)", + ) + parser.add_argument( + "--keep-shards", + action="store_true", + help="Don't delete downloaded BF16 shards after extraction", + ) + parser.add_argument( + "--no-quantize", + action="store_true", + help="Save MTP weights in BF16 (no quantization). Required for correct MTP predictions.", + ) + args = parser.parse_args() + + print("=" * 60) + print("MTP Weight Addition for Qwen3.5 MLX Model") + print("=" * 60) + + # Find snapshot directory + snapshot_dir = find_snapshot_dir(args.mlx_model_path) + print(f"\nMLX model snapshot: {snapshot_dir}") + + # Read config + config_path = snapshot_dir / "config.json" + if not config_path.exists(): + print(f"ERROR: No config.json found in {snapshot_dir}") + sys.exit(1) + + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + model_type = text_config.get("model_type", config.get("model_type", "unknown")) + hidden_size = text_config.get("hidden_size", "?") + num_experts = text_config.get("num_experts", 0) + is_moe = num_experts > 0 + mtp_layers = text_config.get("mtp_num_hidden_layers", 0) + + print(f"Model type: {model_type}") + print(f"Hidden size: {hidden_size}") + print(f"Num experts: {num_experts} ({'MoE' if is_moe else 'Dense'})") + print(f"MTP layers: {mtp_layers}") + + if mtp_layers == 0: + print("ERROR: Model has no MTP layers configured (mtp_num_hidden_layers=0)") + sys.exit(1) + + # Check if MTP weights already exist + mtp_file = snapshot_dir / "mtp" / "weights.safetensors" + if mtp_file.exists(): + size_mb = mtp_file.stat().st_size / 1e6 + print(f"\nWARNING: mtp/weights.safetensors already exists ({size_mb:.1f} MB)") + print("Delete it first if you want to regenerate.") + sys.exit(0) + + # Setup download directory + if args.download_dir: + download_dir = Path(args.download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + else: + download_dir = Path(tempfile.mkdtemp(prefix="qwen35_mtp_")) + print(f"\nDownload dir: {download_dir}") + + # Fetch shard index and identify MTP shards + source_index = fetch_shard_index(args.source_model, download_dir) + mtp_key_map, shards_needed = identify_mtp_shards(source_index) + + print( + f"\nFound {len(mtp_key_map)} MTP weight 
keys across {len(shards_needed)} shards:"
+    )
+    for shard in sorted(shards_needed):
+        count = sum(1 for v in mtp_key_map.values() if v == shard)
+        print(f"  {shard}: {count} keys")
+
+    # Download shards
+    if not args.skip_download:
+        print(f"\nDownloading {len(shards_needed)} shards...")
+        shard_paths = download_shards(shards_needed, args.source_model, download_dir)
+    else:
+        shard_paths = {}
+        for shard_name in shards_needed:
+            p = download_dir / shard_name
+            if p.exists():
+                shard_paths[shard_name] = p
+            else:
+                print(f"ERROR: Shard not found: {p}")
+                sys.exit(1)
+
+    # Extract, optionally quantize, and save MTP weights
+    mtp_file, mtp_weight_keys = extract_and_quantize_mtp_weights(
+        mtp_key_map,
+        shard_paths,
+        snapshot_dir,
+        is_moe,
+        num_experts,
+        no_quantize=args.no_quantize,
+    )
+
+    # NOTE: Do NOT update model.safetensors.index.json — mlx_vlm's glob
+    # would try to load MTP weights and fail with strict loading.
+    # MTP weights are loaded separately by inject_mtp_support().
+
+    # Update config
+    update_config(snapshot_dir)
+
+    # Cleanup downloaded shards
+    if not args.keep_shards and not args.skip_download:
+        print("\nCleaning up downloaded shards...")
+        for shard_path in shard_paths.values():
+            shard_path.unlink(missing_ok=True)
+            print(f"  Deleted {shard_path.name}")
+
+    print("\n" + "=" * 60)
+    print("SUCCESS! MTP weights added to MLX model.")
+    print("=" * 60)
+    print(f"\nMTP weight file: {mtp_file}")
+    print(f"Total MTP keys: {len(mtp_weight_keys)}")
+    print("\nTo use MTP, start the server with --enable-mtp:")
+    print(f"  vllm-mlx serve {args.mlx_model_path} --enable-mtp")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py
index 32b26e035..e450cd5bc 100644
--- a/vllm_mlx/api/models.py
+++ b/vllm_mlx/api/models.py
@@ -172,12 +172,16 @@ class ChatCompletionRequest(BaseModel):
     # MLLM-specific parameters
     video_fps: float | None = None
     video_max_frames: int | None = None
+    # Sampling penalties
+    repetition_penalty: float | None = None  # mlx-lm style (>1.0 penalizes)
     # Request timeout in seconds (None = use server default)
     timeout: float | None = None
     # SpecPrefill: per-request enable/disable (None = server decides)
     specprefill: bool | None = None
     # SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default)
     specprefill_keep_pct: float | None = None
+    # Enable/disable thinking mode (None = server default, typically True)
+    enable_thinking: bool | None = None


 class AssistantMessage(BaseModel):
@@ -239,6 +243,8 @@ class CompletionRequest(BaseModel):
     max_tokens: int | None = None
     stream: bool = False
     stop: list[str] | None = None
+    # Sampling penalties
+    repetition_penalty: float | None = None  # mlx-lm style (>1.0 penalizes)
     # Request timeout in seconds (None = use server default)
     timeout: float | None = None

diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index 8c52915ee..6218dce7d 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -20,7 +20,8 @@
     r"<\|im_end\|>|<\|im_start\|>|<\|endoftext\|>|"
     r"<\|end\|>|<\|eot_id\|>|<\|start_header_id\|>|<\|end_header_id\|>|"
     r"<\|channel\|>|<\|message\|>|<\|start\|>|<\|return\|>|<\|call\|>|<\|constrain\|>|"
-    r"<s>|</s>|<pad>|\[PAD\]|\[SEP\]|\[CLS\]"
+    r"<s>|</s>|<pad>|\[PAD\]|\[SEP\]|\[CLS\]|"
+    r"\[e~\[|\]~b\][a-z]*|\]~!b\["
 )


@@ -356,6 +358,8 @@ def flush(self) -> list[tuple[str, str]]:
     "InternVL",  # InternVL
     "deepseek-vl",
     "DeepSeek-VL",  # DeepSeek-VL
+    "Qwen3.5-",
+    "qwen3_5",  # Qwen3.5 MoE (natively multimodal, hybrid ArraysCache+KVCache)
 ]


diff --git
a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 8a90bc9be..cca034ad8 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -37,6 +37,13 @@ def serve_command(args): print("Example: --enable-auto-tool-choice --tool-call-parser mistral") sys.exit(1) + # Validate gpu-memory-utilization range + if not (0.0 < args.gpu_memory_utilization <= 1.0): + print( + "Error: --gpu-memory-utilization must be between 0.0 (exclusive) and 1.0 (inclusive)" + ) + sys.exit(1) + # Configure server security settings server._api_key = args.api_key server._default_timeout = args.timeout @@ -196,7 +203,8 @@ def serve_command(args): scheduler_config=scheduler_config, stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, - force_mllm=args.mllm, + force_mllm=getattr(args, "mllm", False), + gpu_memory_utilization=args.gpu_memory_utilization, served_model_name=args.served_model_name, mtp=args.enable_mtp, prefill_step_size=args.prefill_step_size, @@ -704,6 +712,14 @@ def main(): action="store_true", help="Enable continuous batching for multiple concurrent users (slower for single user)", ) + serve_parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.90, + help="Fraction of device memory for Metal allocation limit and emergency " + "cache clear threshold (0.0-1.0, default: 0.90). Increase to 0.95 for " + "large models (200GB+) that need more memory headroom.", + ) # Paged cache options (experimental) serve_parser.add_argument( "--use-paged-cache", @@ -838,12 +854,14 @@ def main(): "nemotron", "xlam", "functionary", + "gemma4", "glm47", + "minimax", ], help=( "Select the tool call parser for the model. Options: " "auto (auto-detect), mistral, qwen, qwen3_coder, llama, hermes, " - "deepseek, kimi, granite, nemotron, xlam, functionary, glm47. " + "deepseek, gemma4, kimi, granite, nemotron, xlam, functionary, glm47, minimax. " "Required for --enable-auto-tool-choice." ), ) diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 52f3ed92d..34b39c84d 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -89,23 +89,23 @@ class MLLMModelWrapper: but MLLM models return LanguageModelOutput objects. This wrapper extracts the logits from the output. - Also handles Gemma 3/4's required pixel_values argument by injecting None + Also handles Gemma 3's required pixel_values argument by injecting None for text-only requests. """ def __init__(self, model): self._model = model - # Detect if this is a Gemma 3/4 model (requires pixel_values as positional arg) - model_type = str(getattr(model, "model_type", "")).lower() - self._is_gemma_multimodal = hasattr(model, "model_type") and ( - "gemma3" in model_type or "gemma4" in model_type + # Detect if this is a Gemma 3 model (requires pixel_values as positional arg) + self._is_gemma3 = ( + hasattr(model, "model_type") + and "gemma3" in str(getattr(model, "model_type", "")).lower() ) def __call__(self, *args, **kwargs): """Call the model and extract logits from LanguageModelOutput.""" - # Gemma 3/4 requires pixel_values as a positional argument, unlike Qwen + # Gemma 3 requires pixel_values as a positional argument, unlike Qwen # which makes it optional. Inject pixel_values=None for text-only requests. 
- if self._is_gemma_multimodal and "pixel_values" not in kwargs: + if self._is_gemma3 and "pixel_values" not in kwargs: kwargs["pixel_values"] = None output = self._model(*args, **kwargs) @@ -137,6 +137,7 @@ def __init__( scheduler_config: Any | None = None, stream_interval: int = 1, force_mllm: bool = False, + gpu_memory_utilization: float = 0.90, ): """ Initialize the batched engine. @@ -147,11 +148,14 @@ def __init__( scheduler_config: Optional scheduler configuration stream_interval: Tokens to batch before streaming (1=every token) force_mllm: Force loading as MLLM even if not auto-detected + gpu_memory_utilization: Fraction of device memory for Metal allocation + limit and emergency threshold (0.0-1.0, default 0.90) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._scheduler_config = scheduler_config self._stream_interval = stream_interval + self._gpu_memory_utilization = gpu_memory_utilization self._is_mllm = force_mllm or is_mllm_model(model_name) self._model = None @@ -207,6 +211,10 @@ async def _start_mllm(self) -> None: self._model = self._mllm_instance.model self._processor = self._mllm_instance.processor + # Inject MTP support if enabled + if self._scheduler_config and self._scheduler_config.enable_mtp: + self._inject_mtp_mllm() + # Create MLLM scheduler config with batch generator support if self._scheduler_config and hasattr(self._scheduler_config, "max_num_seqs"): max_num_seqs = self._scheduler_config.max_num_seqs @@ -219,12 +227,28 @@ async def _start_mllm(self) -> None: self._scheduler_config, "completion_batch_size", 16 ) + cache_memory_mb = getattr(self._scheduler_config, "cache_memory_mb", None) + enable_mtp = ( + self._scheduler_config.enable_mtp if self._scheduler_config else False + ) + mtp_num_draft = getattr(self._scheduler_config, "mtp_num_draft_tokens", 1) + kv_quant = getattr(self._scheduler_config, "kv_cache_quantization", False) + kv_bits = getattr(self._scheduler_config, "kv_cache_quantization_bits", 8) + kv_group_size = getattr( + self._scheduler_config, "kv_cache_quantization_group_size", 64 + ) mllm_config = MLLMSchedulerConfig( max_num_seqs=max_num_seqs, prefill_batch_size=prefill_batch_size, completion_batch_size=completion_batch_size, enable_vision_cache=True, vision_cache_size=100, + cache_memory_mb=cache_memory_mb, + enable_mtp=enable_mtp, + mtp_num_draft_tokens=mtp_num_draft, + kv_cache_quantization=kv_quant, + kv_cache_quantization_bits=kv_bits, + kv_cache_quantization_group_size=kv_group_size, ) # Create and start MLLM scheduler @@ -241,6 +265,54 @@ async def _start_mllm(self) -> None: f"completion_batch={completion_batch_size}" ) + def _inject_mtp_mllm(self) -> None: + """Inject MTP weights into the MLLM model's language_model.""" + import json + from pathlib import Path + + from mlx_lm.utils import _download + + model = self._model + model_path = Path(_download(self._model_name)) + config_path = model_path / "config.json" + if not config_path.exists(): + logger.warning("[MTP-MLLM] No config.json found, skipping MTP") + return + + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + if num_mtp == 0: + num_mtp = text_config.get( + "num_nextn_predict_layers", + config.get("num_nextn_predict_layers", 0), + ) + if num_mtp == 0: + logger.info("[MTP-MLLM] No MTP layers in config, skipping") + return + + # Navigate to text model + text_model = model + if hasattr(model, "language_model"): + text_model = 
model.language_model + if getattr(text_model, "mtp", None) is not None: + logger.info("[MTP-MLLM] Model already has MTP, skipping injection") + return + + model_type = text_config.get("model_type", config.get("model_type", "")) + if "qwen3_5" in model_type: + from ..patches.qwen3_5_mtp import inject_mtp_support + + ok = inject_mtp_support(model, model_path, config) + if ok: + logger.info("[MTP-MLLM] Qwen3.5 MTP injected successfully") + else: + logger.warning("[MTP-MLLM] Qwen3.5 MTP injection failed") + else: + logger.info(f"[MTP-MLLM] MTP not supported for model_type={model_type}") + async def _start_llm(self) -> None: """Start the LLM engine with AsyncEngineCore.""" from ..engine_core import AsyncEngineCore, EngineConfig @@ -261,9 +333,10 @@ async def _start_llm(self) -> None: # Validate MTP support if enabled if self._scheduler_config and self._scheduler_config.enable_mtp: + from ..patches.qwen3_5_mtp import validate_mtp_support as validate_35 from ..patches.qwen3_next_mtp import validate_mtp_support - if validate_mtp_support(self._model): + if validate_mtp_support(self._model) or validate_35(self._model): logger.info("[MTP] Model validated for MTP speculative decoding") else: logger.warning( @@ -283,13 +356,14 @@ async def _start_llm(self) -> None: device_info.get("memory_size", 0), ) if max_recommended > 0: - soft_limit = int(max_recommended * 0.90) + soft_limit = int(max_recommended * self._gpu_memory_utilization) mx.set_memory_limit(soft_limit) mx.set_cache_limit(32 * 1024 * 1024 * 1024) # 32GB + pct = self._gpu_memory_utilization * 100 logger.info( f"Metal memory limits set: " f"allocation_limit={soft_limit / 1e9:.1f}GB " - f"(90% of {max_recommended / 1e9:.1f}GB), " + f"({pct:.0f}% of {max_recommended / 1e9:.1f}GB), " f"cache_limit=32GB" ) except Exception as e: @@ -301,6 +375,7 @@ async def _start_llm(self) -> None: model_name=self._model_name, scheduler_config=scheduler_config, stream_interval=self._stream_interval, + gpu_memory_utilization=self._gpu_memory_utilization, ) # Create async engine @@ -335,6 +410,7 @@ def _apply_chat_template( messages: list[dict[str, Any]], tools: list[dict] | None = None, num_images: int = 0, + enable_thinking: bool | None = None, ) -> str: """Apply chat template to messages. @@ -363,9 +439,13 @@ def _apply_chat_template( if self._is_mllm and num_images > 0: messages = self._prepare_mllm_messages(messages) + # Per-request enable_thinking override; default: True unless coder model. + if enable_thinking is None: + enable_thinking = "coder" not in self._model_name.lower() template_kwargs = { "tokenize": False, "add_generation_prompt": True, + "enable_thinking": enable_thinking, } if tools: template_kwargs["tools"] = tools @@ -375,9 +455,10 @@ def _apply_chat_template( messages, **template_kwargs ) except TypeError as e: - # Some templates don't accept 'tools'; retry without them. + # Some templates don't accept 'tools' or 'enable_thinking'; + # retry without them. 
logger.debug(f"Chat template TypeError, retrying without extras: {e}") - for key in ["tools"]: + for key in ["tools", "enable_thinking"]: if key in template_kwargs: del template_kwargs[key] return template_applicator.apply_chat_template( @@ -639,11 +720,15 @@ async def chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + # Per-request enable_thinking override + enable_thinking = kwargs.pop("enable_thinking", None) + # Apply chat template prompt = self._apply_chat_template( messages, template_tools, num_images=len(all_images), + enable_thinking=enable_thinking, ) return await self.generate( @@ -750,11 +835,15 @@ async def stream_chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + # Per-request enable_thinking override + enable_thinking = kwargs.pop("enable_thinking", None) + # Apply chat template prompt = self._apply_chat_template( messages, template_tools, num_images=len(all_images), + enable_thinking=enable_thinking, ) # Compute prefix boundary for cache @@ -786,14 +875,27 @@ def get_stats(self) -> dict[str, Any]: if self._mllm_scheduler: mllm_stats = self._mllm_scheduler.get_stats() stats["mllm_scheduler"] = mllm_stats - # Promote Metal memory stats to top-level for /v1/status + # Promote stats to top-level for /v1/status and monitoring for key in ( + "running", + "num_running", + "num_waiting", + "num_requests_processed", + "total_prompt_tokens", + "total_completion_tokens", "metal_active_memory_gb", "metal_peak_memory_gb", "metal_cache_memory_gb", + "memory_aware_cache", + "paged_cache", + "prefix_cache", + "requests", ): if key in mllm_stats: stats[key] = mllm_stats[key] + # MLLM engine is always "running" once loaded + if "running" not in stats: + stats["running"] = self._loaded elif self._engine: stats.update(self._engine.get_stats()) @@ -801,20 +903,28 @@ def get_stats(self) -> dict[str, Any]: def get_cache_stats(self) -> dict[str, Any] | None: """Get cache statistics.""" - if self._mllm_scheduler and self._mllm_scheduler.vision_cache: - return self._mllm_scheduler.vision_cache.get_stats() + if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + return self._mllm_scheduler.batch_generator.get_vision_cache_stats() elif self._engine: return self._engine.get_cache_stats() return None def save_cache_to_disk(self, cache_dir: str) -> bool: """Save prefix cache to disk for persistence across restarts.""" + if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + pc = self._mllm_scheduler.batch_generator.prefix_cache + if pc is not None: + return pc.save_to_disk(cache_dir) if self._engine: return self._engine.save_cache_to_disk(cache_dir) return False def load_cache_from_disk(self, cache_dir: str) -> int: """Load prefix cache from disk. 
Returns number of entries loaded.""" + if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + pc = self._mllm_scheduler.batch_generator.prefix_cache + if pc is not None: + return pc.load_from_disk(cache_dir) if self._engine: return self._engine.load_cache_from_disk(cache_dir) return 0 diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 8c7212ca4..39cfa849d 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -600,9 +600,10 @@ def run_stream(): # For LLM, apply chat template and stream tokenizer = self._model.tokenizer if hasattr(tokenizer, "apply_chat_template"): - # Disable thinking mode for coder models since it interferes - # with tool call parsing (tags leak as raw text). - enable_thinking = "coder" not in self._model_name.lower() + # Per-request enable_thinking override; default: True unless coder model. + enable_thinking = kwargs.pop("enable_thinking", None) + if enable_thinking is None: + enable_thinking = "coder" not in self._model_name.lower() template_kwargs = { "tokenize": False, "add_generation_prompt": True, @@ -842,9 +843,11 @@ async def _stream_generate_text( specprefill_override = kwargs.pop("specprefill", None) specprefill_keep_pct = kwargs.pop("specprefill_keep_pct", None) - # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path) - enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") - enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + # Per-request enable_thinking override; fall back to env var / default True. + enable_thinking = kwargs.pop("enable_thinking", None) + if enable_thinking is None: + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") # Apply chat template for full prompt template_kwargs = { diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py index d21928824..ae75fd39e 100644 --- a/vllm_mlx/engine_core.py +++ b/vllm_mlx/engine_core.py @@ -36,6 +36,7 @@ class EngineConfig: scheduler_config: Optional[SchedulerConfig] = None step_interval: float = 0.001 # 1ms between steps stream_interval: int = 1 # Tokens to batch before streaming (1=every token) + gpu_memory_utilization: float = 0.90 # Fraction of device memory for allocation class EngineCore: @@ -150,18 +151,12 @@ async def _engine_loop(self) -> None: stream_interval = self.config.stream_interval use_simple_streaming = stream_interval == 1 - # Emergency memory pressure threshold — use 85% of Metal's - # max recommended working set so this scales with system RAM. 
+ # Emergency memory pressure threshold — dynamic based on gpu_memory_utilization + _gpu_mem_util = self.config.gpu_memory_utilization try: - _device_info = mx.device_info() - _max_recommended = _device_info.get( - "max_recommended_working_set_size", - _device_info.get("memory_size", 0), - ) - _memory_pressure_threshold = ( - int(_max_recommended * 0.85) - if _max_recommended > 0 - else 200 * 1024 * 1024 * 1024 + _device_mem = mx.device_info().get("memory_size", 200 * 1024 * 1024 * 1024) + _memory_pressure_threshold = int( + _device_mem * min(_gpu_mem_util + 0.05, 0.99) ) except Exception: _memory_pressure_threshold = 200 * 1024 * 1024 * 1024 diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index 111439330..2668c3cec 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -771,7 +771,15 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: f"layer_types={[type(lc).__name__ for lc in best_lcp_entry.cache[:3]]}" ) - if not has_non_trimmable: + if has_non_trimmable: + # Hybrid model (SSM+Attention): SSM state can't be rewound. + # Block LCP for hybrid models — use think-suffix stripping + # in the engine layer to get clean PREFIX matches instead. + logger.debug( + "[cache_fetch] LCP skipped: non-trimmable cache layers " + "(hybrid model, SSM state can't be rewound)" + ) + else: trimmed_cache = _trim_cache_offset(best_lcp_entry.cache, excess) self._entries.move_to_end(best_lcp_entry.tokens) self._stats.hits += 1 diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index 9c4ea8e44..a6a59afba 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -24,12 +24,21 @@ import mlx.core as mx import mlx.nn as nn +from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig, _trim_cache_offset from .multimodal_processor import MultimodalProcessor from .vision_embedding_cache import VisionEmbeddingCache logger = logging.getLogger(__name__) +class PrefillAbortedError(Exception): + """Raised when a prefill is aborted due to client disconnect.""" + + def __init__(self, request_id: str): + self.request_id = request_id + super().__init__(f"Prefill aborted for request {request_id}") + + @dataclass class MLLMBatchRequest: """ @@ -59,6 +68,9 @@ class MLLMBatchRequest: image_grid_thw: Optional[mx.array] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + # Text-only flag (no images/videos — eligible for prefix cache) + is_text_only: bool = False + # Generation state num_tokens: int = 0 # Tokens generated so far output_tokens: List[int] = field(default_factory=list) @@ -151,6 +163,7 @@ def extend(self, other: "MLLMBatch") -> None: # Extend logits_processors if self.logits_processors is not None or other.logits_processors is not None: + # At this point self.uids already includes other.uids from extend above self_len = len(self.uids) - len(other.uids) self_lp = self.logits_processors or [None] * self_len other_lp = other.logits_processors or [None] * len(other.uids) @@ -163,17 +176,14 @@ def extend(self, other: "MLLMBatch") -> None: other_s = other.samplers or [None] * len(other.uids) self.samplers = list(self_s) + list(other_s) - # Extend cache - handle None and incompatible caches + # Extend cache - handle both BatchKVCache (.keys/.values) and + # ArraysCache (.cache list) from hybrid models like Qwen3.5 for c, o in zip(self.cache, other.cache): if c is not None and o is not None and hasattr(c, "extend"): try: - # Only extend if both caches have valid keys - if ( - hasattr(c, 
"keys") - and c.keys is not None - and hasattr(o, "keys") - and o.keys is not None - ): + has_kv = hasattr(c, "keys") and c.keys is not None + has_arrays = hasattr(c, "cache") + if has_kv or has_arrays: c.extend(o) except Exception as e: logger.warning(f"Failed to extend cache: {e}") @@ -316,6 +326,7 @@ def __init__( prefill_step_size: int = 1024, enable_vision_cache: bool = True, vision_cache_size: int = 100, + prefix_cache_config: Optional[MemoryCacheConfig] = None, ): """ Initialize MLLM batch generator. @@ -332,6 +343,7 @@ def __init__( prefill_step_size: Tokens to process per prefill step enable_vision_cache: Enable vision embedding caching vision_cache_size: Max entries in vision cache + prefix_cache_config: Config for KV prefix cache (text-only requests) """ self.model = model self.processor = processor @@ -352,8 +364,10 @@ def __init__( ) # Patch attention for BatchKVCache compatibility + from .patches.qwen3_5_mllm import patch_qwen35_attention_for_batching from .patches.gemma4_mllm import patch_gemma4_attention_for_batching + patch_qwen35_attention_for_batching() patch_gemma4_attention_for_batching() self.max_tokens = max_tokens @@ -375,6 +389,15 @@ def __init__( # Error responses for requests that failed during preprocessing self._pending_error_responses: List[MLLMBatchResponse] = [] + # Per-request prefill progress: request_id → (processed_tokens, total_tokens) + self._prefill_progress: Dict[str, Tuple[int, int]] = {} + + # Aborted request IDs — checked between prefill chunks to allow + # early termination when a client disconnects during long prefill. + # Set operations are GIL-protected, safe across event-loop and + # executor threads. + self._aborted_request_ids: set = set() + # Vision embedding cache for repeated images self.vision_cache = VisionEmbeddingCache( max_pixel_entries=vision_cache_size, @@ -386,6 +409,33 @@ def __init__( f"MLLMBatchGenerator: Vision cache enabled (size={vision_cache_size})" ) + # KV prefix cache for text-only requests + self.prefix_cache: Optional[MemoryAwarePrefixCache] = None + if prefix_cache_config is not None: + self.prefix_cache = MemoryAwarePrefixCache( + model=self.language_model, + config=prefix_cache_config, + ) + logger.info("MLLMBatchGenerator: KV prefix cache enabled") + + # Normalize chat template for prefix-cache stability. + # Qwen3.5 chat template retroactively changes formatting of earlier + # assistant messages based on last_query_index (position of last + # non-tool user message). When a user text message is appended, + # last_query_index jumps forward, removing blocks from + # earlier assistant turns — shifting tokens mid-sequence and + # breaking prefix match. Fix: always use plain format for + # historical assistant turns (thinking is still added by the + # generation prompt at the end). + self._normalize_chat_template_for_prefix_cache() + + # Compute think-suffix length for prefix cache key stripping. + # Models with enable_thinking=True add \n to the generation + # prompt. This breaks prefix cache (stored key ends with + # but next request has actual response at that position). + # Stripping the suffix from cache keys enables clean PREFIX match. 
+        self._think_suffix_len = self._compute_think_suffix_len()
+
         # Generation stream
         if MLLMBatchGenerator._stream is None:
             MLLMBatchGenerator._stream = mx.new_stream(mx.default_device())
@@ -397,6 +447,132 @@ def __init__(
                 mx.device_info()["max_recommended_working_set_size"]
             )

+    def _normalize_chat_template_for_prefix_cache(self) -> None:
+        """Patch chat template so historical assistant turns are prefix-stable.
+
+        Qwen3.5's chat template computes ``last_query_index`` — the position
+        of the last non-tool-response user message — and conditionally wraps
+        assistant turns after that index in ``<think>\\n...\\n</think>\\n\\n``.
+        When a new user text message is appended, ``last_query_index`` jumps
+        forward, retroactively removing these ``<think>`` wrappers from
+        earlier assistant turns. This shifts tokens mid-sequence and breaks
+        prefix cache.
+
+        Fix: replace the conditional with the plain (ELSE) branch so ALL
+        historical assistant messages use ``<|im_start|>assistant\\ncontent``
+        without any injected ``<think>`` block. The generation prompt still
+        adds ``<think>\\n`` at the very end, so the model generates thinking.
+        """
+        if self.prefix_cache is None:
+            return  # No prefix cache — no need to normalize
+
+        # Find the chat template. VLM processors (e.g. Qwen3VLProcessor)
+        # keep a SEPARATE copy of chat_template from their tokenizer — both
+        # must be patched. The processor's copy is used by
+        # BatchedEngine._apply_chat_template() (text rendering), while the
+        # tokenizer's copy is used by _compute_think_suffix_len().
+        tokenizer = getattr(self.processor, "tokenizer", self.processor)
+        # Prefer the processor's own template (it's the one used for rendering)
+        template = getattr(self.processor, "chat_template", None)
+        if not template:
+            template = getattr(tokenizer, "chat_template", None)
+        if not template or "last_query_index" not in template:
+            return  # Not affected
+
+        import re
+
+        # The pattern in Qwen3.5 template:
+        #   {%- if loop.index0 > ns.last_query_index %}
+        #     {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
+        #   {%- else %}
+        #     {{- '<|im_start|>' + message.role + '\n' + content }}
+        #   {%- endif %}
+        #
+        # Replace with just the ELSE branch (always plain format).
+        pattern = (
+            r"\{%-\s*if\s+loop\.index0\s*>\s*ns\.last_query_index\s*%\}"
+            r".*?"
+            r"\{%-\s*else\s*%\}"
+            r"\s*(\{\{-.*?content.*?\}\})"
+            r"\s*\{%-\s*endif\s*%\}"
+        )
+        new_template = re.sub(pattern, r"\1", template, flags=re.DOTALL)
+        if new_template != template:
+            # Patch ALL copies: processor, tokenizer, and any dict variants.
+            if hasattr(self.processor, "chat_template"):
+                self.processor.chat_template = new_template
+            tokenizer.chat_template = new_template
+            logger.info(
+                "[prefix_cache] Normalized chat template: removed "
+                "last_query_index conditional for prefix-stable assistant turns"
+            )
+        else:
+            logger.debug(
+                "[prefix_cache] Chat template has last_query_index but "
+                "regex did not match — template may use a different pattern"
+            )
+
+    def _compute_think_suffix_len(self) -> int:
+        """Compute how many extra tokens enable_thinking=True adds at the END.
+
+        Compares the generation prompt suffix with and without
+        ``enable_thinking`` to find the think-tag suffix length
+        (typically ``<think>\\n`` = 2 tokens for Qwen3/Qwen3.5).
+
+        Returns 0 if the template doesn't support ``enable_thinking``.
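+
+        For a Qwen3-style template the two renderings end differently:
+            enable_thinking=True:  ...<|im_start|>assistant\\n<think>\\n
+            enable_thinking=False: ...<|im_start|>assistant\\n
+        so the detected suffix is ``<think>\\n`` (2 tokens).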
+        """
+        try:
+            # Find something with apply_chat_template
+            applicator = None
+            for candidate in [
+                getattr(self.processor, "tokenizer", None),
+                self.processor,
+            ]:
+                if candidate is not None and hasattr(candidate, "apply_chat_template"):
+                    applicator = candidate
+                    break
+
+            if applicator is None:
+                return 0
+
+            dummy = [{"role": "user", "content": "x"}]
+
+            try:
+                text_with = applicator.apply_chat_template(
+                    dummy,
+                    tokenize=False,
+                    add_generation_prompt=True,
+                    enable_thinking=True,
+                )
+                text_without = applicator.apply_chat_template(
+                    dummy,
+                    tokenize=False,
+                    add_generation_prompt=True,
+                    enable_thinking=False,
+                )
+            except TypeError:
+                return 0
+
+            # Check if enable_thinking adds a known think tag at the end.
+            # enable_thinking may also change the system prompt, so we can't
+            # simply compare lengths — we look at the ending instead.
+            for tag in ["<think>\n", "<think>"]:
+                if text_with.endswith(tag) and not text_without.endswith(tag):
+                    tokenizer = getattr(self.processor, "tokenizer", self.processor)
+                    suffix_tokens = tokenizer.encode(tag)
+                    # encode("") yields only the tokenizer's BOS/prefix
+                    # tokens; subtract them so suffix_len counts just the
+                    # tag tokens.
+                    base_tokens = tokenizer.encode("")
+                    suffix_len = len(suffix_tokens) - len(base_tokens)
+                    if suffix_len > 0:
+                        logger.info(
+                            f"[think_suffix] Detected think tag "
+                            f"'{tag.strip()}' = {suffix_len} token(s)"
+                        )
+                        return max(0, suffix_len)
+
+            return 0
+        except Exception:
+            return 0
+
     def close(self) -> None:
         """Release resources and reset wired limit."""
         if self._old_wired_limit is not None:
@@ -404,6 +580,16 @@ def close(self) -> None:
             mx.set_wired_limit(self._old_wired_limit)
             self._old_wired_limit = None

+    def abort_prefill(self, request_id: str) -> None:
+        """Signal that a request's prefill should be aborted.
+
+        Called from the event loop thread when a client disconnects.
+        The prefill loop checks this set between chunks and raises
+        PrefillAbortedError to exit early.
+        """
+        self._aborted_request_ids.add(request_id)
+        logger.info(f"[abort_prefill] Marked {request_id} for prefill abort")
+
     def __del__(self):
         try:
             self.close()
@@ -580,12 +766,81 @@ def _preprocess_request(self, request: MLLMBatchRequest) -> None:
         self._stats.num_images_processed += len(all_images)
         self._stats.vision_encoding_time += processing_time

+        # Mark text-only requests (eligible for prefix cache)
+        request.is_text_only = not bool(all_images)
+
         logger.debug(
             f"Preprocessed request {request.request_id}: "
             f"{len(all_images)} images, {request.input_ids.size if request.input_ids is not None else 0} tokens "
             f"({processing_time:.2f}s)"
         )

+    def _run_chunked_text_prefill(
+        self, request: MLLMBatchRequest, cache: List[Any]
+    ) -> mx.array:
+        """
+        Run prefill in chunks for text-only requests, reporting real progress.
+
+        Processes input_ids in prefill_step_size chunks through the language
+        model, updating ``_prefill_progress`` after each chunk so the status
+        endpoint can report accurate prefill percentage.
+
+        Returns:
+            Logits from the last chunk (same contract as _run_vision_encoding).
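+
+        For example, with prefill_step_size=1024 a 5000-token prompt is
+        processed as four 1024-token chunks plus a final 904-token chunk
+        whose logits are returned for sampling.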
+ """ + input_ids = request.input_ids + if input_ids.ndim == 1: + input_ids = input_ids[None, :] + + total = input_ids.shape[1] + step = self.prefill_step_size + + # Short prompt — process in one shot (no chunking overhead) + if total <= step: + self._prefill_progress[request.request_id] = (total, total) + output = self.language_model(input_ids, cache=cache) + request.vision_encoded = True + if hasattr(output, "logits"): + return output.logits + return output + + # Process all chunks except the last + processed = 0 + chunk_count = 0 + while processed + step < total: + # Check for abort between chunks (client disconnect) + if request.request_id in self._aborted_request_ids: + self._aborted_request_ids.discard(request.request_id) + logger.info( + f"[chunked_prefill] Aborted {request.request_id} at " + f"{processed}/{total} tokens" + ) + raise PrefillAbortedError(request.request_id) + + chunk = input_ids[:, processed : processed + step] + self.language_model(chunk, cache=cache) + mx.eval([c.state for c in cache]) + processed += step + chunk_count += 1 + self._prefill_progress[request.request_id] = (processed, total) + + # Release Metal buffer pool periodically. Full-attention layers + # produce attention score buffers that grow each chunk (1024 × + # growing_context). Old smaller buffers can't be reused, so the + # pool accumulates O(N²) memory without clearing. + if chunk_count % 4 == 0: + mx.clear_cache() + + # Last chunk — return logits for sampling + last_chunk = input_ids[:, processed:] + output = self.language_model(last_chunk, cache=cache) + request.vision_encoded = True + self._prefill_progress[request.request_id] = (total, total) + + if hasattr(output, "logits"): + return output.logits + return output + def _run_vision_encoding( self, request: MLLMBatchRequest, cache: Optional[List[Any]] = None ) -> mx.array: @@ -648,75 +903,275 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: tic = time.perf_counter() - # Preprocess all requests + # Preprocess all requests (per-request error handling) + failed_requests = [] for req in requests: - self._preprocess_request(req) + try: + self._preprocess_request(req) + except Exception as e: + logger.error( + f"Failed to preprocess request {req.request_id}: " + f"{type(e).__name__}: {e}" + ) + failed_requests.append(req) + + # Remove failed requests from batch and create error responses + if failed_requests: + for req in failed_requests: + requests.remove(req) + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + if not requests: + # All requests failed + return None total_prompt_tokens = sum( req.input_ids.size if req.input_ids is not None else 1 for req in requests ) self._stats.prompt_tokens += total_prompt_tokens - # Guard against excessive memory usage during cache merge. - # Each token in the batch requires KV entries across all layers. + # Log large prompts for monitoring (was previously a hard check that + # caused infinite retry loops when requests exceeded the limit). max_batch_tokens = self.prefill_step_size * len(requests) if total_prompt_tokens > max_batch_tokens: - raise ValueError( - f"Total prompt tokens ({total_prompt_tokens}) exceeds safe limit " - f"({max_batch_tokens}) for {len(requests)} requests. " - f"Reduce prompt length or batch size." + logger.warning( + f"Large batch prefill: {total_prompt_tokens} tokens " + f"(step_size={self.prefill_step_size}, requests={len(requests)}). 
" + f"Processing may be slow." ) # Run vision encoding for each request with its own KVCache. # Vision encoding cannot be batched because each request may have # different images/pixel values. We pass a per-request KVCache to # the VLM so the language model writes its KV state directly into it. + # + # For text-only requests, we check the prefix cache first. If there's + # a hit, we skip the full VLM forward and run only the language model + # on the remaining (uncached) tokens. first_tokens = [] all_logprobs = [] per_request_caches = [] + aborted_requests = [] for req in requests: - # Create a fresh KVCache for this request's language model prefill - request_cache = make_prompt_cache(self.language_model) + try: + # Check abort before starting prefill + if req.request_id in self._aborted_request_ids: + self._aborted_request_ids.discard(req.request_id) + raise PrefillAbortedError(req.request_id) + + # Try prefix cache for all requests (text-only and multimodal). + # VLM forward writes the same KV state as language model forward + # for text tokens, so cached KV from a previous VLM run is valid. + # However, if the remaining (uncached) tokens contain image + # placeholders, we must fall back to VLM forward instead of + # running them through the language model alone. + cached_kv = None + remaining_ids = None + if self.prefix_cache is not None and req.input_ids is not None: + input_ids_list = req.input_ids.reshape(-1).tolist() + # Strip think suffix from lookup key so stored entries + # (also stripped) match as clean PREFIX. + S = self._think_suffix_len + lookup_ids = input_ids_list[:-S] if S > 0 else input_ids_list + cached_kv, remaining_ids = self.prefix_cache.fetch(lookup_ids) + # Append think suffix back to remaining so the model + # sees the full generation prompt (\n). + if cached_kv is not None and S > 0: + remaining_ids = list(remaining_ids) + input_ids_list[-S:] + + # If remaining tokens contain image placeholders, the + # language-model-only path cannot handle them — clear the + # cache hit so we fall through to full VLM forward. 
+ if cached_kv is not None and remaining_ids: + img_tok = getattr( + getattr(self.model, "config", None), + "image_token_index", + None, + ) + if img_tok is not None and img_tok in remaining_ids: + cached_kv = None + remaining_ids = None + + if cached_kv is not None and remaining_ids: + # Prefix/LCP match — run language model on remaining tokens + request_cache = cached_kv + remaining = mx.array(remaining_ids)[None, :] + cached_count = len(input_ids_list) - len(remaining_ids) + total_tokens = len(input_ids_list) + remaining_count = len(remaining_ids) + + with mx.stream(MLLMBatchGenerator._stream): + step = self.prefill_step_size + if remaining_count <= step: + # Short remaining — process in one shot + self._prefill_progress[req.request_id] = ( + total_tokens, + total_tokens, + ) + logits = self.language_model(remaining, cache=request_cache) + else: + # Chunked prefill on remaining tokens + self._prefill_progress[req.request_id] = ( + cached_count, + total_tokens, + ) + processed = 0 + chunk_count = 0 + while processed + step < remaining_count: + # Check for abort between chunks + if req.request_id in self._aborted_request_ids: + self._aborted_request_ids.discard(req.request_id) + logger.info( + f"[chunked_prefill] Aborted {req.request_id} " + f"at {cached_count + processed}/{total_tokens} tokens" + ) + raise PrefillAbortedError(req.request_id) + + chunk = remaining[:, processed : processed + step] + self.language_model(chunk, cache=request_cache) + mx.eval([c.state for c in request_cache]) + processed += step + chunk_count += 1 + self._prefill_progress[req.request_id] = ( + cached_count + processed, + total_tokens, + ) + if chunk_count % 4 == 0: + mx.clear_cache() + # Last chunk — return logits + remaining = remaining[:, processed:] + logits = self.language_model(remaining, cache=request_cache) + self._prefill_progress[req.request_id] = ( + total_tokens, + total_tokens, + ) - with mx.stream(MLLMBatchGenerator._stream): - # Run VLM forward pass — cache= flows through to language_model - logits = self._run_vision_encoding(req, cache=request_cache) + if hasattr(logits, "logits"): + logits = logits.logits - # Extract last token logits and sample - last_logits = logits[:, -1, :] - logprobs = last_logits - mx.logsumexp( - last_logits, axis=-1, keepdims=True - ) - sampled = self.sampler(logprobs) + last_logits = logits[:, -1, :] + logprobs = last_logits - mx.logsumexp( + last_logits, axis=-1, keepdims=True + ) + sampled = self.sampler(logprobs) + mx.eval(sampled, logprobs) + + first_tokens.append(sampled.item()) + all_logprobs.append(logprobs.squeeze(0)) + + per_request_caches.append(request_cache) + req.vision_encoded = True + logger.debug( + f"Prefix cache hit for {req.request_id}: " + f"cached={cached_count}, " + f"remaining={remaining_count}" + ) - mx.eval(sampled, logprobs) + elif cached_kv is not None and not remaining_ids: + # Exact/supersequence match — cache has all tokens, + # but we still need logits for the last token. + # fetch() with trim-by-1 store always returns remaining=[last_token]. + # If we get here (empty remaining), re-run on last token. 
+ request_cache = cached_kv + last_token = req.input_ids[:, -1:] + total_tokens = len(input_ids_list) + self._prefill_progress[req.request_id] = ( + total_tokens, + total_tokens, + ) - first_tokens.append(sampled.item()) - all_logprobs.append(logprobs.squeeze(0)) + with mx.stream(MLLMBatchGenerator._stream): + logits = self.language_model(last_token, cache=request_cache) + if hasattr(logits, "logits"): + logits = logits.logits - per_request_caches.append(request_cache) + last_logits = logits[:, -1, :] + logprobs = last_logits - mx.logsumexp( + last_logits, axis=-1, keepdims=True + ) + sampled = self.sampler(logprobs) + mx.eval(sampled, logprobs) - # Merge per-request KVCaches into a single BatchKVCache. - # KVCache.merge() creates a BatchKVCache with proper left-padding - # alignment, so all requests share a single batched cache for - # subsequent generation steps. - from mlx_lm.models.cache import KVCache, RotatingKVCache + first_tokens.append(sampled.item()) + all_logprobs.append(logprobs.squeeze(0)) - sample_cache = per_request_caches[0][0] - if not isinstance(sample_cache, (KVCache, RotatingKVCache)): - raise ValueError( - f"MLLM continuous batching requires KVCache or " - f"RotatingKVCache but got {type(sample_cache).__name__}. " - f"Disable --kv-cache-quantization when using multimodal " - f"models with --continuous-batching." - ) + per_request_caches.append(request_cache) + req.vision_encoded = True + logger.debug( + f"Prefix cache exact hit for {req.request_id}: " + f"all {total_tokens} tokens cached" + ) + + else: + # Cache miss — full forward pass + request_cache = make_prompt_cache(self.language_model) + + with mx.stream(MLLMBatchGenerator._stream): + # Text-only: chunked prefill with real progress tracking + # Multimodal: atomic VLM forward (vision encoder needs full input) + if req.is_text_only: + logits = self._run_chunked_text_prefill( + req, cache=request_cache + ) + else: + logits = self._run_vision_encoding(req, cache=request_cache) + + # Extract last token logits and sample + last_logits = logits[:, -1, :] + logprobs = last_logits - mx.logsumexp( + last_logits, axis=-1, keepdims=True + ) + sampled = self.sampler(logprobs) + + mx.eval(sampled, logprobs) + + first_tokens.append(sampled.item()) + all_logprobs.append(logprobs.squeeze(0)) + per_request_caches.append(request_cache) + + except PrefillAbortedError: + aborted_requests.append(req) + self._prefill_progress.pop(req.request_id, None) + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="abort", + ) + ) + + # Remove aborted requests — they have no entries in the parallel + # lists (first_tokens, all_logprobs, per_request_caches) + if aborted_requests: + for req in aborted_requests: + requests.remove(req) + mx.clear_cache() + if not requests: + return None + + # Merge per-request caches into batched caches. + # Both KVCache.merge() and ArraysCache.merge() produce batch-aware + # caches that support filter/extend/extract for continuous batching. + # # Fix: RotatingKVCache._update_concat does NOT trim on first call — # if prompt length > max_size, the buffer grows beyond max_size. # BatchRotatingKVCache.merge() then hits a shape mismatch when # copying via _temporal_order (full buffer) into a max_size slice. # Trim buffer to max_size before merging. 
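+        # Worked example of the overflow (hypothetical sizes): with
+        # max_size=512 and a 600-token prompt, the first _update_concat
+        # leaves keys.shape[2] == 600, so merge() would copy a 600-entry
+        # _temporal_order view into a 512-slot target. Trimming
+        # 600 - 512 = 88 entries first restores keys.shape[2] <= max_size.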
+ from mlx_lm.models.cache import RotatingKVCache + for rc in per_request_caches: for layer_cache in rc: if isinstance(layer_cache, RotatingKVCache): @@ -731,6 +1186,22 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: trim_size, layer_cache.values ) layer_cache._idx = layer_cache.max_size + # Normalize wrapped rotating cache for merge: + # after rotation _idx wraps around but merge() + # expects _idx == actual buffer size. + # Use keys.shape[2] (actual entries) NOT size() + # which can be inconsistent after prefix cache trim + # (size() = min(offset, max_size) but buffer may + # have fewer entries when trimmed). + actual_buf = layer_cache.keys.shape[2] + if layer_cache._idx != actual_buf and actual_buf > 0: + layer_cache.keys = layer_cache._temporal_order( + layer_cache.keys + ) + layer_cache.values = layer_cache._temporal_order( + layer_cache.values + ) + layer_cache._idx = actual_buf try: batch_cache = [ @@ -740,8 +1211,10 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: for layer_idx in range(len(per_request_caches[0])) ] except Exception as e: + sample_type = type(per_request_caches[0][0]).__name__ logger.error( - f"Failed to merge per-request KV caches: {type(e).__name__}: {e}" + f"Failed to merge per-request caches ({sample_type}): " + f"{type(e).__name__}: {e}" ) raise @@ -888,6 +1361,8 @@ def _next(self) -> List[MLLMBatchResponse]: # merged into a single BatchKVCache. Merging into an active batch # mid-generation would cause shape mismatches in attention layers, # so queued requests wait until the current batch finishes. + # Exception: text-only requests can be extended into an active batch + # via the elif branch below (they skip vision encoding entirely). if num_active == 0: requests = self.unprocessed_requests[: self.completion_batch_size] @@ -896,8 +1371,11 @@ def _next(self) -> List[MLLMBatchResponse]: return [] try: + # Save count before _process_prompts which modifies + # `requests` in-place via .remove() for failed items. + num_to_consume = len(requests) new_batch = self._process_prompts(requests) - self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + self.unprocessed_requests = self.unprocessed_requests[num_to_consume:] self.active_batch = new_batch prompt_processing = True except Exception as e: @@ -919,6 +1397,49 @@ def _next(self) -> List[MLLMBatchResponse]: ) ) + # Mid-batch extend: text-only requests can join an active batch + # without vision encoding (no shape mismatch risk). + elif self.unprocessed_requests: + text_only = [ + r for r in self.unprocessed_requests if not r.images and not r.videos + ][: self.completion_batch_size] + + if text_only: + try: + # Capture UIDs before _process_prompts modifies + # text_only in-place via .remove() for failed items. 
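+                    # e.g. if text_only held uids {1, 2, 3} and uid 2 failed
+                    # preprocessing, _process_prompts would leave text_only
+                    # as [1, 3]; capturing all three uids first ensures the
+                    # failed request is also dropped from
+                    # unprocessed_requests below.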
+ all_uids = {r.uid for r in text_only} + new_batch = self._process_prompts(text_only) + # Remove ALL requested (both successful and failed) + self.unprocessed_requests = [ + r for r in self.unprocessed_requests if r.uid not in all_uids + ] + if new_batch is not None: + batch.extend(new_batch) + prompt_processing = True + except Exception as e: + logger.warning( + f"Failed to extend batch with text-only requests: " + f"{type(e).__name__}: {e}" + ) + # Remove failed requests to avoid infinite retry loop + processed_uids = {r.uid for r in text_only} + self.unprocessed_requests = [ + r + for r in self.unprocessed_requests + if r.uid not in processed_uids + ] + for req in text_only: + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + # Collect any pending error responses (from failed preprocessing) error_responses = [] if self._pending_error_responses: @@ -988,6 +1509,8 @@ def _next(self) -> List[MLLMBatchResponse]: if finish_reason is not None: # Extract cache for this request cache_fn = lambda idx=i: batch.extract_cache(idx) + # Cleanup prefill progress tracking + self._prefill_progress.pop(request_id, None) responses.append( MLLMBatchResponse( @@ -1000,6 +1523,9 @@ def _next(self) -> List[MLLMBatchResponse]: ) ) + # Store caches for finished text-only requests BEFORE filtering + self._maybe_store_prefix_cache(batch, end_idx) + # Remove finished requests from batch if end_idx: if keep_idx: @@ -1030,10 +1556,404 @@ def stats(self) -> MLLMBatchStats: self._stats.peak_memory = mx.get_peak_memory() / 1e9 return self._stats + def _maybe_store_prefix_cache( + self, batch: MLLMBatch, end_indices: List[int] + ) -> None: + """Store KV caches for finished text-only requests into prefix cache. + + Must be called BEFORE batch.filter() so that indices are still valid. + """ + if self.prefix_cache is None or not end_indices: + return + for i in end_indices: + req = batch.requests[i] + if req.input_ids is not None: + try: + extracted = batch.extract_cache(i) + input_ids_list = req.input_ids.reshape(-1).tolist() + # Store prompt-only KV (trim output tokens + 1 so next + # fetch returns remaining=[last_prompt_token] at minimum). + # Also strip think suffix from key so next request's + # (also stripped) key matches as a clean PREFIX. 
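+                    # e.g. with S == 2: cache_key drops the two suffix
+                    # tokens, and total_trim = output_count + 1 + S removes
+                    # the generated tokens, the suffix, and one final
+                    # prompt token, so the stored entry covers
+                    # cache_key[:-1] -- exactly the trim-by-1 shape that
+                    # fetch() expects for a later exact-prompt lookup.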
+ output_count = batch.num_tokens[i] + S = self._think_suffix_len + total_trim = output_count + 1 + S + prompt_cache = _trim_cache_offset(extracted, total_trim) + cache_key = input_ids_list[:-S] if S > 0 else input_ids_list + self.prefix_cache.store(cache_key, prompt_cache) + except Exception as e: + logger.warning( + f"Failed to store prefix cache for {req.request_id}: {type(e).__name__}: {e}" + ) + + def get_prefill_progress(self, request_id: str) -> Optional[Tuple[int, int]]: + """Return (processed_tokens, total_tokens) or None.""" + return self._prefill_progress.get(request_id) + def get_vision_cache_stats(self) -> Dict[str, Any]: """Get vision cache statistics.""" return self.vision_cache.get_stats() + def get_prefix_cache_stats(self) -> Dict[str, Any]: + """Get KV prefix cache statistics.""" + if self.prefix_cache is not None: + return self.prefix_cache.get_stats() + return { + "hits": 0, + "misses": 0, + "hit_rate": 0.0, + "evictions": 0, + "tokens_saved": 0, + "current_memory_mb": 0.0, + "max_memory_mb": 0.0, + "memory_utilization": 0.0, + "entry_count": 0, + } + def has_pending(self) -> bool: """Check if there are pending or active requests.""" return bool(self.unprocessed_requests or self.active_batch) + + +def install_mtp_mllm( + batch_gen: "MLLMBatchGenerator", + language_model: Any, + num_draft_tokens: int = 1, +) -> None: + """Install MTP (Multi-Token Prediction) on an MLLMBatchGenerator. + + Adapts the always-advance MTP strategy from scheduler._install_mtp + for the MLLM batched generation path. Handles hybrid model caches + (BatchKVCache for attention + ArraysCache for recurrent layers). + + Flow per generation step: + 1. Use skip_state logits/hidden OR run model forward -> sample primary + 2. MTP head drafts one token + 3. Verify [primary, draft] in one model call (always advances cache) + 4. Accept: skip_state from pos 1, defer draft for next step emission + Reject: trim KV by 2 + restore RNN state + re-advance with primary + 5. 
Draft is emitted in the NEXT generation step after primary + """ + from .scheduler import make_sampler + + _orig_step = batch_gen._step + _draft_sampler = make_sampler(temp=0.0) + + # Skip state: stored logits + hidden from verify pass + _skip_state: list = [None] + + # Deferred drafts keyed by UID + _deferred_drafts: Dict[int, dict] = {} + + # MTP stats + _mtp_stats = {"accepted": 0, "rejected": 0, "errors": 0} + + def _mtp_step( + input_tokens: mx.array, + cache: List[Any], + logits_processors: Optional[List[Optional[List[Callable]]]] = None, + output_tokens: Optional[List[List[int]]] = None, + samplers: Optional[List[Optional[Callable]]] = None, + ) -> Tuple[mx.array, List[mx.array]]: + """Extended _step with MTP always-advance strategy.""" + batch_size = input_tokens.shape[0] + + # Prefill guard: skip MTP for multi-token input or when no active batch + # Also skip MTP when batch has multiple active requests (MTP overhead + # hurts aggregate throughput in concurrent scenarios) + if ( + input_tokens.shape[1] > 1 + or batch_gen.active_batch is None + or len(batch_gen.active_batch) > 1 + ): + _skip_state[0] = None + return _orig_step( + input_tokens, cache, logits_processors, output_tokens, samplers + ) + + # Check skip state + skip = _skip_state[0] + if skip is not None and skip["logits"].shape[0] != batch_size: + skip = None + _skip_state[0] = None + + if skip is not None: + logits = skip["logits"] + hidden_states = skip["hidden"] + _skip_state[0] = None + else: + # Normal forward with return_hidden + model_output = language_model(input_tokens, cache=cache, return_hidden=True) + if isinstance(model_output, tuple): + logits, hidden_states = model_output + else: + return _orig_step( + input_tokens, cache, logits_processors, output_tokens, samplers + ) + logits = logits[:, -1, :] + + # Apply logits processors before sampling + if logits_processors and output_tokens and any(logits_processors): + processed_logits = [] + for e in range(batch_size): + sample_logits = logits[e : e + 1] + if logits_processors[e]: + for processor in logits_processors[e]: + sample_logits = processor( + mx.array(output_tokens[e]), sample_logits + ) + processed_logits.append(sample_logits) + logits = mx.concatenate(processed_logits, axis=0) + + # Sample primary (use per-request sampler if available) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + if samplers and any(samplers): + sampled_list = [] + for e in range(logprobs.shape[0]): + s = samplers[e] if samplers[e] else batch_gen.sampler + sampled_list.append(s(logprobs[e : e + 1])) + primary_tokens = mx.concatenate(sampled_list, axis=0) + else: + primary_tokens = batch_gen.sampler(logprobs) + + current_uids = list(batch_gen.active_batch.uids) + + # MTP draft + always-advance verify + try: + draft_logits = language_model.mtp_forward( + hidden_states[:, -1:, :], + primary_tokens[:, None], + mtp_cache=None, + ) + draft_logits = draft_logits[:, -1, :] + draft_logprobs = draft_logits - mx.logsumexp( + draft_logits, axis=-1, keepdims=True + ) + draft_tokens = _draft_sampler(draft_logprobs) + + # Snapshot RNN state for hybrid models + _rnn_snapshots = {} + for _ci, _c in enumerate(cache): + if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()): + if hasattr(_c, "state"): + _rnn_snapshots[_ci] = [ + mx.array(s) if s is not None else None for s in _c.state + ] + + # Verify [primary, draft] + verify_input = mx.concatenate( + [primary_tokens[:, None], draft_tokens[:, None]], axis=1 + ) + verify_output = language_model( + verify_input, cache=cache, 
return_hidden=True + ) + if isinstance(verify_output, tuple): + verify_logits, verify_hidden = verify_output + else: + verify_logits = verify_output + verify_hidden = None + + # Verified mode: check if draft matches verify prediction + verify_pred = mx.argmax(verify_logits[:, 0, :], axis=-1) + mx.eval(verify_pred, draft_tokens) + pred_list = verify_pred.tolist() + draft_list = draft_tokens.tolist() + all_accepted = pred_list == draft_list + + if all_accepted and verify_hidden is not None: + # ACCEPT + _skip_state[0] = { + "logits": verify_logits[:, 1, :], + "hidden": verify_hidden[:, -1:, :], + } + mx.async_eval(_skip_state[0]["logits"], _skip_state[0]["hidden"]) + verify_lp = verify_logits[:, 0, :] - mx.logsumexp( + verify_logits[:, 0, :], axis=-1, keepdims=True + ) + for e in range(batch_size): + uid = current_uids[e] + _deferred_drafts[uid] = { + "token": draft_list[e], + "logprobs": verify_lp[e], + } + _mtp_stats["accepted"] += 1 + + else: + # REJECT + if _rnn_snapshots: + # Hybrid model: undo entire verify, re-advance with primary + for c in cache: + if ( + hasattr(c, "is_trimmable") + and c.is_trimmable() + and hasattr(c, "trim") + ): + c.trim(2) + for _ci, _snap in _rnn_snapshots.items(): + cache[_ci].state = _snap + rerun_out = language_model( + primary_tokens[:, None], + cache=cache, + return_hidden=True, + ) + if isinstance(rerun_out, tuple): + rerun_logits, rerun_hidden = rerun_out + else: + rerun_logits = rerun_out + rerun_hidden = None + if rerun_hidden is not None: + _skip_state[0] = { + "logits": rerun_logits[:, -1, :], + "hidden": rerun_hidden[:, -1:, :], + } + mx.async_eval( + _skip_state[0]["logits"], + _skip_state[0]["hidden"], + ) + else: + _skip_state[0] = None + else: + # Pure attention model: simple trim + for c in cache: + if ( + hasattr(c, "is_trimmable") + and c.is_trimmable() + and hasattr(c, "trim") + ): + c.trim(1) + if verify_hidden is not None: + _skip_state[0] = { + "logits": verify_logits[:, 0, :], + "hidden": verify_hidden[:, 0:1, :], + } + mx.async_eval( + _skip_state[0]["logits"], + _skip_state[0]["hidden"], + ) + else: + _skip_state[0] = None + for uid in current_uids: + _deferred_drafts.pop(uid, None) + _mtp_stats["rejected"] += 1 + + except Exception as e: + logger.warning(f"[MTP-MLLM] draft/verify failed: {e}") + _skip_state[0] = None + _mtp_stats["errors"] += 1 + + # Log MTP stats every 50 steps + total = _mtp_stats["accepted"] + _mtp_stats["rejected"] + _mtp_stats["errors"] + if total > 0 and total % 50 == 0: + acc = _mtp_stats["accepted"] + rej = _mtp_stats["rejected"] + err = _mtp_stats["errors"] + rate = acc / (acc + rej) * 100 if (acc + rej) > 0 else 0 + logger.info( + f"[MTP-MLLM] stats: accepted={acc} rejected={rej} " + f"errors={err} acceptance={rate:.0f}%" + ) + + return primary_tokens, list(logprobs) + + # Wrap _next to emit deferred MTP drafts + batch_gen._inner_next = batch_gen._next + + def _mtp_next() -> List[MLLMBatchResponse]: + """Wrapper around _next that emits deferred MTP draft tokens.""" + if batch_gen.active_batch is None: + _skip_state[0] = None + _deferred_drafts.clear() + + # Save deferred drafts from previous step + prev_deferred: Dict[int, dict] = {} + if batch_gen.active_batch is not None: + for uid in batch_gen.active_batch.uids: + if uid in _deferred_drafts: + prev_deferred[uid] = _deferred_drafts.pop(uid) + + responses = batch_gen._inner_next() + + if not prev_deferred or not responses: + return responses + + # Augment responses with deferred drafts + augmented: List[MLLMBatchResponse] = [] + draft_end_uids: set = 
set()

+        for r in responses:
+            uid = r.uid
+            augmented.append(r)
+
+            if r.finish_reason is not None:
+                _deferred_drafts.pop(uid, None)
+                prev_deferred.pop(uid, None)
+                continue
+
+            if uid in prev_deferred:
+                draft_info = prev_deferred.pop(uid)
+                draft_t = draft_info["token"]
+                draft_lp = draft_info["logprobs"]
+
+                if draft_t in batch_gen.stop_tokens:
+                    augmented.append(
+                        MLLMBatchResponse(
+                            uid=uid,
+                            request_id=r.request_id,
+                            token=draft_t,
+                            logprobs=draft_lp,
+                            finish_reason="stop",
+                        )
+                    )
+                    draft_end_uids.add(uid)
+                else:
+                    draft_finish = None
+                    batch = batch_gen.active_batch
+                    if batch is not None:
+                        for e, bu in enumerate(batch.uids):
+                            if bu == uid:
+                                batch.num_tokens[e] += 1
+                                batch.requests[e].output_tokens.append(draft_t)
+                                if batch.num_tokens[e] >= batch.max_tokens[e]:
+                                    draft_finish = "length"
+                                    draft_end_uids.add(uid)
+                                break
+
+                    augmented.append(
+                        MLLMBatchResponse(
+                            uid=uid,
+                            request_id=r.request_id,
+                            token=draft_t,
+                            logprobs=draft_lp,
+                            finish_reason=draft_finish,
+                        )
+                    )
+
+        # Store prefix caches for draft-ended sequences BEFORE filtering
+        if draft_end_uids and batch_gen.active_batch is not None:
+            end_indices = [
+                e
+                for e, u in enumerate(batch_gen.active_batch.uids)
+                if u in draft_end_uids
+            ]
+            batch_gen._maybe_store_prefix_cache(batch_gen.active_batch, end_indices)
+
+            keep = [
+                e
+                for e, u in enumerate(batch_gen.active_batch.uids)
+                if u not in draft_end_uids
+            ]
+            if keep:
+                batch_gen.active_batch.filter(keep)
+            else:
+                batch_gen.active_batch = None
+
+        return augmented
+
+    batch_gen._step = _mtp_step
+    batch_gen._next = _mtp_next
+
+    # Stats accumulate in _mtp_stats and are logged every 50 steps in _mtp_step.
+    logger.info(
+        f"[MTP-MLLM] installed with num_draft_tokens={num_draft_tokens}, "
+        f"always-advance verified mode"
+    )
diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py
index 1a1c0c45e..04c7cac2a 100644
--- a/vllm_mlx/mllm_scheduler.py
+++ b/vllm_mlx/mllm_scheduler.py
@@ -19,6 +19,7 @@
 """
 import asyncio
+import concurrent.futures
 import logging
 import time
 import uuid
@@ -35,7 +36,6 @@
     MLLMBatchRequest,
     MLLMBatchResponse,
 )
-from .mllm_cache import MLLMCacheManager
 from .multimodal_processor import MultimodalProcessor
 from .request import RequestOutput, RequestStatus, SamplingParams
@@ -62,8 +62,22 @@ class MLLMSchedulerConfig:
     default_max_tokens: int = 256
     # Default video FPS for frame extraction
     default_video_fps: float = 2.0
+    # KV cache memory limit (from --cache-memory-mb)
+    cache_memory_mb: Optional[int] = None
     # Maximum video frames
     max_video_frames: int = 128
+    # Enable MTP speculative decoding
+    enable_mtp: bool = False
+    # Number of draft tokens for MTP
+    mtp_num_draft_tokens: int = 1
+    # Enable KV prefix cache for text-only requests
+    enable_prefix_cache: bool = True
+    # Memory limit for prefix cache (None = auto-detect)
+    prefix_cache_memory_mb: Optional[int] = None
+    # KV cache quantization for prefix cache store/fetch
+    kv_cache_quantization: bool = False
+    kv_cache_quantization_bits: int = 8
+    kv_cache_quantization_group_size: int = 64
@@ -94,6 +108,9 @@ class MLLMRequest:
     num_prompt_tokens: int = 0
     num_output_tokens: int = 0
+    # Timing
+    first_token_time: Optional[float] = None
+
 @dataclass
 class MLLMSchedulerOutput:
@@ -176,13 +193,6 @@ def __init__(
             config=self.model_config,
         )
-        # Vision cache for repeated images
-        self.vision_cache: Optional[MLLMCacheManager] = None
-        if self.config.enable_vision_cache:
-            self.vision_cache = MLLMCacheManager(
-                max_entries=self.config.vision_cache_size
-            )
-
         # Get stop tokens from tokenizer
         self.stop_tokens = 
self._get_stop_tokens() @@ -218,6 +228,10 @@ def __init__( self.total_prompt_tokens = 0 self.total_completion_tokens = 0 + # Memory management: periodic mx.clear_cache() to free Metal buffers + self._step_count = 0 + self._clear_cache_interval = 32 + def _get_stop_tokens(self) -> Set[int]: """Get stop token IDs from tokenizer and generation_config.json.""" stop_tokens = set() @@ -265,9 +279,24 @@ def _ensure_batch_generator(self) -> None: if self.batch_generator is None: from mlx_lm.sample_utils import make_sampler + from .memory_cache import MemoryCacheConfig + # Default sampler (can be overridden per-request in future) sampler = make_sampler(temp=0.7, top_p=0.9) + # Configure KV prefix cache for text-only requests + # KV cache quantization reduces prefix cache memory ~4x (BF16→Q8). + # Quantization happens on store(), dequantization on fetch() — + # the model always receives normal KVCache with plain arrays. + prefix_cache_config = None + if self.config.enable_prefix_cache: + prefix_cache_config = MemoryCacheConfig( + max_memory_mb=self.config.prefix_cache_memory_mb, + kv_quantize=self.config.kv_cache_quantization, + kv_bits=self.config.kv_cache_quantization_bits, + kv_group_size=self.config.kv_cache_quantization_group_size, + ) + self.batch_generator = MLLMBatchGenerator( model=self.model, processor=self.processor, @@ -278,8 +307,21 @@ def _ensure_batch_generator(self) -> None: prefill_batch_size=self.config.prefill_batch_size, completion_batch_size=self.config.completion_batch_size, prefill_step_size=self.config.prefill_step_size, + prefix_cache_config=prefix_cache_config, ) + # Install MTP if enabled and language model supports it + if self.config.enable_mtp: + lm = self.batch_generator.language_model + if hasattr(lm, "mtp") and lm.mtp is not None: + from .mllm_batch_generator import install_mtp_mllm + + install_mtp_mllm( + self.batch_generator, + lm, + num_draft_tokens=self.config.mtp_num_draft_tokens, + ) + # ========== Sync API (step-based) ========== def add_request( @@ -330,6 +372,19 @@ def add_request( sampling_params=sampling_params, ) + # Estimate prompt token count for monitoring (text tokens only; + # vision tokens are added during prefill but this gives a useful + # approximation for the status endpoint). + tokenizer = ( + self.processor.tokenizer + if hasattr(self.processor, "tokenizer") + else self.processor + ) + try: + request.num_prompt_tokens = len(tokenizer.encode(prompt)) + except Exception: + pass + self.requests[request_id] = request self.waiting.append(request) @@ -354,6 +409,12 @@ def abort_request(self, request_id: str) -> bool: if request is None: return False + # Signal batch generator to abort any in-progress prefill for this + # request. The prefill loop checks _aborted_request_ids between + # chunks and raises PrefillAbortedError to exit early. 
+ if self.batch_generator is not None: + self.batch_generator.abort_prefill(request_id) + # Remove from waiting queue if request.status == RequestStatus.WAITING: try: @@ -480,21 +541,41 @@ def _process_batch_responses( if request is None: continue + # Handle error responses from failed preprocessing + if response.finish_reason == "error": + output = RequestOutput( + request_id=request_id, + new_token_ids=[], + new_text="", + output_token_ids=[], + prompt_tokens=0, + completion_tokens=0, + finished=True, + finish_reason="error", + ) + request.status = RequestStatus.FINISHED_ABORTED + request.output_text = "" + request.finish_reason = "error" + finished_ids.add(request_id) + self.num_requests_processed += 1 + logger.warning(f"Request {request_id} failed during preprocessing") + outputs.append(output) + continue + # Append token to request request.output_tokens.append(response.token) request.num_output_tokens = len(request.output_tokens) + if request.first_token_time is None and request.num_output_tokens > 0: + request.first_token_time = time.time() + # Decode the new token using streaming detokenizer (UTF-8 safe). - # Skip stop tokens and error placeholders — they are not content. - if response.finish_reason in ("stop", "error"): + # Skip stop tokens — they are not content. + if response.finish_reason == "stop": new_text = "" else: if request_id not in self._detokenizer_pool: - if hasattr(tokenizer, "detokenizer"): - detok = tokenizer.detokenizer - else: - detok = NaiveStreamingDetokenizer(tokenizer) - detok.reset() + detok = NaiveStreamingDetokenizer(tokenizer) self._detokenizer_pool[request_id] = detok detok = self._detokenizer_pool[request_id] detok.add_token(response.token) @@ -516,15 +597,13 @@ def _process_batch_responses( request.status = RequestStatus.FINISHED_STOPPED elif response.finish_reason == "length": request.status = RequestStatus.FINISHED_LENGTH_CAPPED - elif response.finish_reason == "error": - request.status = RequestStatus.FINISHED_ABORTED output.finished = True output.finish_reason = response.finish_reason finished_ids.add(request_id) # Finalize streaming detokenizer and get full output - detok = self._detokenizer_pool.get(request_id) + detok = self._detokenizer_pool.pop(request_id, None) if detok is not None: detok.finalize() output.output_text = detok.text @@ -532,7 +611,6 @@ def _process_batch_responses( output.output_text = tokenizer.decode(request.output_tokens) request.output_text = output.output_text request.finish_reason = response.finish_reason - self._detokenizer_pool.pop(request_id, None) self.total_completion_tokens += request.num_output_tokens self.num_requests_processed += 1 @@ -553,6 +631,9 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: if request_id in self.running: del self.running[request_id] + # Drain from requests dict to prevent linear memory growth + self.requests.pop(request_id, None) + # Remove UID mappings if request_id in self.request_id_to_uid: uid = self.request_id_to_uid[request_id] @@ -560,10 +641,17 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: del self.uid_to_request_id[uid] del self.request_id_to_uid[request_id] + # Clean up detokenizer pool (handles abort/timeout cases) + self._detokenizer_pool.pop(request_id, None) + # Track as finished self.finished_req_ids.add(request_id) self.requests.pop(request_id, None) + # Clear Metal buffer pool after cleanup to release memory + if finished_ids: + mx.clear_cache() + def step(self) -> MLLMSchedulerOutput: """ Execute one scheduling step. 
@@ -663,14 +751,33 @@ async def stop(self) -> None: logger.info("MLLM Scheduler stopped") async def _process_loop(self) -> None: - """Main async processing loop.""" + """Main async processing loop. + + Uses a thread pool executor for steps that involve prefill + (waiting requests or partial prefill in progress) so that the + event loop stays responsive for health checks and other HTTP + endpoints. Decode-only steps are fast (<3 ms) and run inline. + """ + _executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mllm-step" + ) + loop = asyncio.get_running_loop() + while self._running: try: if self.has_requests(): - # Run one step - self.step() - # Yield to other tasks - await asyncio.sleep(0) + has_waiting = self.get_num_waiting() > 0 + has_partial = ( + self.batch_generator is not None + and getattr(self.batch_generator, "_partial", None) is not None + ) + needs_executor = has_waiting or has_partial + + if needs_executor: + await loop.run_in_executor(_executor, self.step) + else: + self.step() + await asyncio.sleep(0) else: # No work, wait a bit await asyncio.sleep(0.01) @@ -678,7 +785,7 @@ async def _process_loop(self) -> None: except asyncio.CancelledError: break except Exception as e: - logger.error(f"Error in MLLM process loop: {e}") + logger.error(f"Error in MLLM process loop: {e}", exc_info=True) await asyncio.sleep(0.1) async def add_request_async( @@ -807,6 +914,77 @@ async def generate( # ========== Stats and utilities ========== + def get_running_requests_info(self) -> List[Dict[str, Any]]: + """Per-request details for status endpoint.""" + now = time.time() + result = [] + + # Waiting requests + for req in self.waiting: + result.append( + { + "request_id": req.request_id, + "status": "waiting", + "phase": "queued", + "elapsed_s": round(now - req.arrival_time, 2), + "prompt_tokens": req.num_prompt_tokens, + "completion_tokens": 0, + "max_tokens": req.sampling_params.max_tokens, + "progress": 0.0, + "tokens_per_second": None, + "ttft_s": None, + "cache_hit_type": None, + "cached_tokens": 0, + } + ) + + # Running requests + for req in self.running.values(): + n_out = req.num_output_tokens + elapsed = now - req.arrival_time + + if n_out == 0: + phase = "prefill" + else: + phase = "generation" + + tok_s = None + ttft = None + if req.first_token_time is not None: + ttft = round(req.first_token_time - req.arrival_time, 3) + gen_elapsed = now - req.first_token_time + if gen_elapsed > 0 and n_out > 0: + tok_s = round(n_out / gen_elapsed, 1) + + max_tokens = req.sampling_params.max_tokens + if phase == "prefill" and self.batch_generator is not None: + pp = self.batch_generator.get_prefill_progress(req.request_id) + if pp is not None: + progress = round(pp[0] / pp[1], 3) if pp[1] > 0 else 0.0 + else: + progress = 0.0 + else: + progress = round(n_out / max_tokens, 3) if max_tokens > 0 else 0.0 + + result.append( + { + "request_id": req.request_id, + "status": "running", + "phase": phase, + "elapsed_s": round(elapsed, 2), + "prompt_tokens": req.num_prompt_tokens, + "completion_tokens": n_out, + "max_tokens": max_tokens, + "progress": min(progress, 1.0), + "tokens_per_second": tok_s, + "ttft_s": ttft, + "cache_hit_type": None, + "cached_tokens": 0, + } + ) + + return result + def get_stats(self) -> Dict[str, Any]: """Get scheduler statistics.""" stats = { @@ -816,27 +994,45 @@ def get_stats(self) -> Dict[str, Any]: "num_requests_processed": self.num_requests_processed, "total_prompt_tokens": self.total_prompt_tokens, "total_completion_tokens": 
self.total_completion_tokens,
+            "requests": self.get_running_requests_info(),
         }

         if self.batch_generator is not None:
             batch_stats = self.batch_generator.stats()
             stats["batch_generator"] = batch_stats.to_dict()
-            # Add vision embedding cache stats from batch generator
-            stats["vision_embedding_cache"] = (
-                self.batch_generator.get_vision_cache_stats()
-            )
-
-        if self.vision_cache:
-            stats["vision_cache"] = self.vision_cache.get_stats()
+            # Vision embedding cache stats from batch generator
+            vec_stats = self.batch_generator.get_vision_cache_stats()
+            stats["vision_embedding_cache"] = vec_stats

         # Include Metal memory stats
         try:
             if mx.metal.is_available():
-                stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2)
-                stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2)
-                stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2)
+                active_gb = round(mx.get_active_memory() / 1e9, 2)
+                peak_gb = round(mx.get_peak_memory() / 1e9, 2)
+                cache_gb = round(mx.get_cache_memory() / 1e9, 2)
+                stats["metal_active_memory_gb"] = active_gb
+                stats["metal_peak_memory_gb"] = peak_gb
+                stats["metal_cache_memory_gb"] = cache_gb
         except Exception:
-            pass
+            # Metal memory stats are optional; omit them if unavailable.
+            pass

+        # KV prefix cache stats for /v1/status and monitoring UI.
+        if self.batch_generator is not None:
+            prefix_stats = self.batch_generator.get_prefix_cache_stats()
+        else:
+            prefix_stats = {
+                "hits": 0,
+                "misses": 0,
+                "hit_rate": 0.0,
+                "evictions": 0,
+                "tokens_saved": 0,
+                "current_memory_mb": 0.0,
+                "max_memory_mb": 0.0,
+                "memory_utilization": 0.0,
+                "entry_count": 0,
+            }
+        stats["memory_aware_cache"] = prefix_stats

         return stats
diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py
index a6c67226e..5a3551eb1 100644
--- a/vllm_mlx/models/mllm.py
+++ b/vllm_mlx/models/mllm.py
@@ -465,8 +465,9 @@ def save_base64_image(base64_string: str) -> str:
     """Save base64 image to temp file and return path. Caches identical images."""
     import hashlib

-    # Hash the base64 string to check cache
-    image_hash = hashlib.md5(base64_string.encode()).hexdigest()
+    # Hash the full base64 string to prevent collisions between images
+    # with identical headers (e.g. 
JPEG images sharing first 1000 chars) + image_hash = hashlib.sha256(base64_string.encode()).hexdigest() # Return cached path if available and file still exists if image_hash in _base64_image_cache: @@ -1328,6 +1329,7 @@ def chat( video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) tools = kwargs.pop("tools", None) use_cache = kwargs.pop("use_cache", True) + enable_thinking = kwargs.pop("enable_thinking", True) # Collect video inputs from messages _msg_video_inputs = self._collect_video_inputs(messages) @@ -1453,11 +1455,11 @@ def chat( template_extra_kwargs["tools"] = tools try: - # Use get_chat_template directly since messages are already properly formatted formatted_prompt = get_chat_template( self.processor, chat_messages, add_generation_prompt=True, + enable_thinking=enable_thinking, **template_extra_kwargs, ) except Exception as e: @@ -1724,6 +1726,7 @@ def stream_chat( video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) tools = kwargs.pop("tools", None) use_cache = kwargs.pop("use_cache", True) + enable_thinking = kwargs.pop("enable_thinking", True) # Collect video inputs from messages _msg_video_inputs = self._collect_video_inputs(messages) @@ -1838,6 +1841,7 @@ def stream_chat( self.processor, chat_messages, add_generation_prompt=True, + enable_thinking=enable_thinking, **template_extra_kwargs, ) except Exception as e: @@ -2091,8 +2095,6 @@ def is_mllm_model(model_name: str) -> bool: "PaliGemma", "gemma-3", "gemma3", # Gemma 3 (multimodal) - "gemma-4", - "gemma4", # Gemma 4 (multimodal: vision + audio) "medgemma", "MedGemma", # MedGemma (medical multimodal) "pixtral", diff --git a/vllm_mlx/multimodal_processor.py b/vllm_mlx/multimodal_processor.py index 2905e9abb..a5c861216 100644 --- a/vllm_mlx/multimodal_processor.py +++ b/vllm_mlx/multimodal_processor.py @@ -147,7 +147,7 @@ def process( logger.warning(f"Failed to process video: {e}") # Determine add_special_tokens based on model type - if self.config and self.config.model_type in ["gemma3", "gemma3n", "gemma4"]: + if self.config and self.config.model_type in ["gemma3", "gemma3n"]: add_special_tokens = not hasattr(self.processor, "chat_template") # Prepare inputs using mlx_vlm diff --git a/vllm_mlx/patches/qwen3_5_mllm.py b/vllm_mlx/patches/qwen3_5_mllm.py new file mode 100644 index 000000000..c592928da --- /dev/null +++ b/vllm_mlx/patches/qwen3_5_mllm.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime patch for mlx-vlm's Qwen3.5 attention to support BatchKVCache. + +mlx-vlm's Qwen3_5Attention uses cache.offset directly for kv_seq_len +computation and mask slicing. BatchKVCache stores offset as mx.array +(per-batch-item), not int, causing: + + mask = mask[..., :kv_seq_len] + ValueError: Slice indices must be integers or None. + +This patch replaces Qwen3_5Attention.__call__ with a version that +converts cache.offset to int before using it for arithmetic/slicing, +while leaving the actual cache.offset untouched so update_and_fetch +still works correctly with per-batch offsets. 
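+
+For example, with a hypothetical batch of three sequences at per-item
+offsets [5, 9, 7], the patched attention slices the mask with the scalar
+int(offset.max()) == 9, while update_and_fetch still sees the array:
+
+    off = mx.array([5, 9, 7])
+    kv_len = int(off.max().item())   # 9 -- safe for mask[..., :kv_len]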
+""" + +import logging +from typing import Optional + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +def _cache_offset_to_int(cache) -> int: + """Extract cache offset as int, handling BatchKVCache mx.array offset.""" + if cache is None: + return 0 + off = cache.offset + if isinstance(off, int): + return off + if isinstance(off, mx.array): + return int(off.max().item()) if off.ndim > 0 else int(off.item()) + return int(off) + + +def patch_qwen35_attention_for_batching() -> bool: + """Monkey-patch Qwen3_5Attention.__call__ to handle BatchKVCache. + + Returns True if patch was applied, False if mlx-vlm is not installed + or Qwen3.5 module not available. + """ + try: + from mlx_vlm.models.qwen3_5.language import ( + Qwen3_5Attention, + apply_multimodal_rotary_pos_emb, + ) + from mlx_lm.models.base import scaled_dot_product_attention + except ImportError: + logger.debug("[Qwen3.5 patch] mlx-vlm Qwen3.5 module not available") + return False + + if getattr(Qwen3_5Attention, "_batch_patched", False): + logger.debug("[Qwen3.5 patch] Already patched") + return True + + def _patched_call( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache=None, + position_ids: Optional[mx.array] = None, + ) -> mx.array: + B, L, D = x.shape + + q_proj_output = self.q_proj(x) + queries, gate = mx.split( + q_proj_output.reshape(B, L, self.num_attention_heads, -1), + 2, + axis=-1, + ) + gate = gate.reshape(B, L, -1) + + keys, values = self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries).transpose(0, 2, 1, 3) + keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( + 0, 2, 1, 3 + ) + + kv_seq_len = keys.shape[-2] + + # Convert cache.offset to int for slice compatibility. + # BatchKVCache stores offset as mx.array (per-batch-item), + # but kv_seq_len must be int for mask[..., :kv_seq_len]. + _offset = _cache_offset_to_int(cache) + + if position_ids is None: + kv_seq_len += _offset + 1 + position_ids = mx.arange(_offset, _offset + L) + position_ids = mx.expand_dims(position_ids, axis=0) + position_ids = mx.tile(position_ids, (3, 1, 1)) + else: + kv_seq_len += _offset + 1 if cache is not None else 0 + + cos, sin = self.rotary_emb(values, position_ids) + + if mask is not None and isinstance(mask, mx.array): + mask = mask[..., :kv_seq_len] + + queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output * mx.sigmoid(gate)) + + Qwen3_5Attention.__call__ = _patched_call + Qwen3_5Attention._batch_patched = True + logger.info("[Qwen3.5 patch] Attention patched for BatchKVCache support") + return True diff --git a/vllm_mlx/patches/qwen3_5_mtp.py b/vllm_mlx/patches/qwen3_5_mtp.py new file mode 100644 index 000000000..3d5f3e632 --- /dev/null +++ b/vllm_mlx/patches/qwen3_5_mtp.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime MTP (Multi-Token Prediction) support for Qwen3.5 models. + +Qwen3.5 models may include a built-in MTP head that predicts token n+2 +from hidden states + token n+1. MTP weights are added to the quantized +MLX model via scripts/add_mtp_weights_qwen35.py. 
+
+Since mlx_lm's qwen3_5.py does NOT define MTP module/methods, this
+module provides:
+  - inject_mtp_support(): dynamically creates MTP module, loads weights,
+    and monkey-patches the model class with return_hidden, mtp_forward,
+    and make_mtp_cache
+  - validate_mtp_support(): checks whether a loaded model has working MTP
+
+Supports both Dense (27B) and MoE (122B-A10B, 35B-A3B) architectures.
+
+The actual MTP scheduling logic lives in:
+  - vllm_mlx/scheduler.py (_install_mtp, _mtp_step, _mtp_next)
+"""
+
+import logging
+from pathlib import Path
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+def _fixup_moe_mtp(mtp, inner_model, loaded_keys: set, mx) -> None:
+    """Fix missing weights in MoE MTP module.
+
+    MoE MTP checkpoints (122B, 35B) only contain: fc, q_proj, o_proj,
+    shared_expert.*, and per-expert weights. Missing:
+      - k_proj, v_proj → zero out (attention becomes no-op)
+      - gate, shared_expert_gate → copy from main model's last full-attn layer
+      - norms → already at identity (weight=1.0), no action needed
+    """
+    import mlx.utils
+
+    mtp_layer = mtp.layers[0]
+
+    # Find last full-attention layer in main model for gate weights
+    last_fa_layer = None
+    for layer in reversed(inner_model.layers):
+        if not layer.is_linear:
+            last_fa_layer = layer
+            break
+
+    if last_fa_layer is None:
+        logger.warning("[MTP fixup] No full-attention layer found in main model")
+        return
+
+    # Copy expert routing gate if not in checkpoint
+    if "layers.0.mlp.gate.weight" not in loaded_keys:
+        src = getattr(last_fa_layer.mlp, "gate", None)
+        dst = getattr(mtp_layer.mlp, "gate", None)
+        if src is not None and dst is not None:
+            src_params = mlx.utils.tree_flatten(src.parameters())
+            dst.load_weights(src_params)
+            mx.eval(dst.parameters())
+            logger.info("[MTP fixup] Copied mlp.gate from main model last layer")
+
+    # Copy shared_expert_gate if not in checkpoint
+    if "layers.0.mlp.shared_expert_gate.weight" not in loaded_keys:
+        src = getattr(last_fa_layer.mlp, "shared_expert_gate", None)
+        dst = getattr(mtp_layer.mlp, "shared_expert_gate", None)
+        if src is not None and dst is not None:
+            src_params = mlx.utils.tree_flatten(src.parameters())
+            dst.load_weights(src_params)
+            mx.eval(dst.parameters())
+            logger.info(
+                "[MTP fixup] Copied shared_expert_gate from main model last layer"
+            )
+
+    # Zero out k_proj and v_proj → attention becomes no-op
+    attn = getattr(mtp_layer, "self_attn", None)
+    if attn is None:
+        return
+
+    for proj_name in ("k_proj", "v_proj"):
+        key = f"layers.0.self_attn.{proj_name}.weight"
+        if key not in loaded_keys:
+            proj = getattr(attn, proj_name, None)
+            if proj is None:
+                continue
+            # For quantized layers: zero scales+biases → dequantized = 0
+            if hasattr(proj, "scales"):
+                proj.scales = mx.zeros_like(proj.scales)
+                proj.biases = mx.zeros_like(proj.biases)
+            else:
+                proj.weight = mx.zeros_like(proj.weight)
+            mx.eval(proj.parameters())
+            logger.info(f"[MTP fixup] Zeroed {proj_name} (not in checkpoint)")
+
+
+def inject_mtp_support(model: Any, model_path, config: dict) -> bool:
+    """Inject MTP module into a loaded Qwen3.5 model.
+
+    mlx_lm's qwen3_5.py does not define MTP layers, so we:
+    1. Create MTP module matching the weight structure
+    2. Load MTP weights from model-mtp.safetensors, dequantized to BF16
+    3. Fix up weights missing from MoE checkpoints
+    4. 
Monkey-patch Model with return_hidden, mtp_forward, make_mtp_cache + + Args: + model: A model loaded via mlx_lm (strict=False, MTP weights ignored) + model_path: Path to model directory (contains model-mtp.safetensors) + config: Parsed config.json dict + + Returns: + True if MTP was successfully injected, False otherwise. + """ + import mlx.core as mx + import mlx.nn as nn + + # Navigate nested config: text_config for VLM wrappers + text_config = config.get("text_config", config) + num_mtp_layers = text_config.get("mtp_num_hidden_layers", 0) + if num_mtp_layers == 0: + # Fallback: check flat config for num_nextn_predict_layers + num_mtp_layers = text_config.get( + "num_nextn_predict_layers", + config.get("num_nextn_predict_layers", 0), + ) + if num_mtp_layers == 0: + logger.info("[MTP inject] No MTP layers configured, skipping") + return False + + model_path = Path(model_path) + # Look for MTP weights in mtp/ subdirectory first (avoids mlx_vlm glob), + # then fall back to model-mtp.safetensors in model dir. + mtp_file = model_path / "mtp" / "weights.safetensors" + if not mtp_file.exists(): + mtp_file = model_path / "model-mtp.safetensors" + if not mtp_file.exists(): + logger.warning(f"[MTP inject] MTP weights not found in {model_path}") + return False + + # Get model args — navigate VLM wrapper if needed + # Model hierarchy: Model → language_model (TextModel) → model (Qwen3_5TextModel) + text_model = model + if hasattr(model, "language_model"): + text_model = model.language_model + + args = text_model.args + + # When loaded via mlx_vlm, args may be a TextConfig object missing fields + # that mlx_lm's TextModelArgs defines (rope_theta, partial_rotary_factor, + # rope_scaling, etc.). Build a proper TextModelArgs from the config dict. + from mlx_lm.models.qwen3_5 import TextModelArgs + + if not isinstance(args, TextModelArgs): + logger.info("[MTP inject] Building TextModelArgs from config dict") + args = TextModelArgs.from_dict(text_config) + + # Detect MoE vs Dense from args + num_experts = getattr(args, "num_experts", 0) + is_moe = num_experts > 0 + + # Import model components + from mlx_lm.models.base import create_attention_mask, create_ssm_mask + from mlx_lm.models.cache import KVCache + from mlx_lm.models.qwen3_5 import DecoderLayer + + logger.info( + f"[MTP inject] Creating MTP module ({num_mtp_layers} layers, " + f"{'MoE' if is_moe else 'Dense'})" + ) + + # MTP decoder uses full attention (not GatedDeltaNet). + # layer_idx = full_attention_interval - 1 ensures is_linear=False. + fa_idx = args.full_attention_interval - 1 + + class _MTPModule(nn.Module): + def __init__(self, args, n_layers): + super().__init__() + self.pre_fc_norm_hidden = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.pre_fc_norm_embedding = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + self.layers = [ + DecoderLayer(args, layer_idx=fa_idx) for _ in range(n_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + mtp = _MTPModule(args, num_mtp_layers) + + # --- Load MTP weights in BF16 (no quantization) --- + # MTP head is extremely sensitive to quantization — even 4-bit destroys + # prediction quality (0% acceptance). Keep MTP in full precision. 
+    # See: https://github.com/vllm-project/vllm/issues/36331
+    quant_config = text_config.get("quantization", config.get("quantization", {}))
+    bits = quant_config.get("bits", 4) if quant_config else 4
+    group_size = quant_config.get("group_size", 64) if quant_config else 64
+
+    logger.info(
+        f"[MTP inject] Loading weights from {mtp_file.name} (BF16, no quantization)"
+    )
+    raw = mx.load(str(mtp_file))
+    raw_mtp = {
+        k.removeprefix("mtp."): v for k, v in raw.items() if k.startswith("mtp.")
+    }
+    del raw
+
+    # Dequantize any quantized weight triplets (weight + scales + biases)
+    mtp_weights: dict[str, mx.array] = {}
+    processed = set()
+    for key in sorted(raw_mtp.keys()):
+        if key in processed:
+            continue
+        if key.endswith(".scales") or key.endswith(".biases"):
+            continue
+
+        scales_key = key.replace(".weight", ".scales")
+        biases_key = key.replace(".weight", ".biases")
+
+        if scales_key in raw_mtp and biases_key in raw_mtp:
+            # Quantized triplet → dequantize to BF16
+            dq = mx.dequantize(
+                raw_mtp[key],
+                raw_mtp[scales_key],
+                raw_mtp[biases_key],
+                group_size=group_size,
+                bits=bits,
+            )
+            mtp_weights[key] = dq
+            processed.update([key, scales_key, biases_key])
+        else:
+            # Already FP (norms, fc, shared_expert_gate)
+            mtp_weights[key] = raw_mtp[key]
+            processed.add(key)
+    del raw_mtp
+
+    mtp.load_weights(list(mtp_weights.items()), strict=False)
+    mx.eval(mtp.parameters())
+
+    dq_count = sum(1 for k in mtp_weights if not k.endswith((".scales", ".biases")))
+    has_quantized = any(k.endswith(".scales") for k in processed)
+    mode = "dequantized from quantized" if has_quantized else "native BF16"
+    logger.info(f"[MTP inject] Loaded {dq_count} MTP weight tensors ({mode})")
+
+    # --- Step 3: Fix missing MoE MTP weights ---
+    # MoE checkpoints lack: k_proj, v_proj, gate, shared_expert_gate, norms.
+    # Norms default to identity (weight=1.0) which is correct.
+    # k_proj/v_proj: zero out → attention becomes no-op, MLP does prediction.
+    # gate/shared_expert_gate: copy from main model's last full-attention layer.
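+    # Why zeroing scales and biases works for the quantized projections:
+    # per-group dequantization reconstructs w = q * scales + biases, so
+    # zero scales with zero biases yield an all-zero weight regardless of
+    # the packed q values. A k_proj/v_proj that always emits zeros turns
+    # the MTP layer's attention into a no-op, leaving the MLP path to do
+    # the prediction.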
+ if is_moe: + loaded_key_set = set(mtp_weights.keys()) + _fixup_moe_mtp(mtp, text_model.model, loaded_key_set, mx) + + # --- Attach MTP and monkey-patch model class --- + text_model.mtp = mtp + + original_class = text_model.__class__ + + class _Qwen3_5MTP(original_class): + """Qwen3.5 with MTP support (injected at runtime).""" + + def __call__( + self, + inputs, + cache=None, + return_hidden: bool = False, + input_embeddings=None, + **kwargs, + ): + inner = self.model + if input_embeddings is not None: + hidden_states = input_embeddings + else: + hidden_states = inner.embed_tokens(inputs) + + if cache is None: + cache = [None] * len(inner.layers) + + fa_mask = create_attention_mask(hidden_states, cache[inner.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, cache[inner.ssm_idx]) + + for layer, c in zip(inner.layers, cache): + mask = ssm_mask if layer.is_linear else fa_mask + hidden_states = layer(hidden_states, mask=mask, cache=c) + + normed = inner.norm(hidden_states) + + if self.args.tie_word_embeddings: + out = inner.embed_tokens.as_linear(normed) + else: + out = self.lm_head(normed) + + if return_hidden: + return out, normed # post-norm hidden states (MTP expects post-norm) + return out + + def mtp_forward( + self, + hidden_states, + next_token_ids, + cache=None, + mtp_cache=None, + ): + """Run MTP head: predict token n+2 from hidden states + token n+1.""" + input_embeds = self.model.embed_tokens(next_token_ids) + e = self.mtp.pre_fc_norm_embedding(input_embeds) + h = self.mtp.pre_fc_norm_hidden(hidden_states) + x = self.mtp.fc(mx.concatenate([e, h], axis=-1)) + + layer = self.mtp.layers[0] + c = mtp_cache[0] if mtp_cache else None + mask = create_attention_mask(x, c) + x = layer(x, mask=mask, cache=c) + + x = self.mtp.norm(x) + + if self.args.tie_word_embeddings: + return self.model.embed_tokens.as_linear(x) + return self.lm_head(x) + + def make_mtp_cache(self): + """Create KV cache for MTP layers.""" + if self.mtp is None: + return None + return [KVCache() for _ in self.mtp.layers] + + text_model.__class__ = _Qwen3_5MTP + logger.info("[MTP inject] Model class patched with MTP support") + + # If we patched the inner language_model, also expose MTP on the outer Model + if hasattr(model, "language_model") and model.language_model is text_model: + model.mtp = mtp + + return True + + +def validate_mtp_support(model: Any) -> bool: + """Validate that a loaded model has working MTP support. + + Checks: + 1. model.mtp exists and is not None + 2. model.mtp has layers with loaded weights + 3. model has return_hidden support in __call__ + 4. model has mtp_forward method + 5. model has make_mtp_cache method + + Args: + model: A model loaded via mlx_lm.load() + + Returns: + True if MTP is fully functional, False otherwise. + """ + # Navigate to text model if VLM wrapper + text_model = model + if hasattr(model, "language_model"): + text_model = model.language_model + + mtp = getattr(text_model, "mtp", None) + if mtp is None: + args = getattr(text_model, "args", None) + if args is not None: + num_mtp = getattr(args, "mtp_num_hidden_layers", 0) + if num_mtp == 0: + num_mtp = getattr(args, "num_nextn_predict_layers", 0) + if num_mtp > 0: + logger.warning( + "[MTP] Model config has MTP layers=%d but model.mtp is None. 
" + "Run scripts/add_mtp_weights_qwen35.py to add weights.", + num_mtp, + ) + return False + + mtp_layers = getattr(mtp, "layers", []) + if not mtp_layers: + logger.warning("[MTP] model.mtp exists but has no layers.") + return False + + import inspect + + call_sig = inspect.signature(type(text_model).__call__) + if "return_hidden" not in call_sig.parameters: + logger.warning("[MTP] Model.__call__ does not accept return_hidden parameter.") + return False + + if not hasattr(text_model, "mtp_forward") or not callable(text_model.mtp_forward): + logger.warning("[MTP] Model does not have mtp_forward() method.") + return False + + if not hasattr(text_model, "make_mtp_cache") or not callable( + text_model.make_mtp_cache + ): + logger.warning("[MTP] Model does not have make_mtp_cache() method.") + return False + + logger.info( + "[MTP] Qwen3.5 model has working MTP support: %d MTP layer(s)", + len(mtp_layers), + ) + return True diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 7b71ada22..32057c19a 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -19,7 +19,7 @@ import mlx.core as mx from mlx_lm.generate import BatchGenerator -from mlx_lm.sample_utils import make_sampler +from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig @@ -254,6 +254,10 @@ def _generation_step(self=batch_gen): batch.tokens, ) mx.async_eval(batch.y, batch.logprobs) + # Evaluate accumulated tokens to prevent Metal buffer buildup + # from lazy mx.concatenate() chains holding AGXAllocation handles + if batch.tokens: + mx.async_eval(*batch.tokens) y = y.tolist() self._stats.generation_time += _time.perf_counter() - tic_gen @@ -742,6 +746,10 @@ def _mtp_step( # --- Apply logits processors + sample primary --- if any(logits_processors): + logger.debug( + f"[logits_proc] applying {sum(len(lp) for lp in logits_processors)} " + f"processors to batch_size={batch_size}" + ) processed_logits = [] for e in range(batch_size): sample_logits = logits[e : e + 1] @@ -1188,11 +1196,7 @@ def _decode_tokens(self, token_ids: List[int]) -> str: def _get_detokenizer(self, request_id: str) -> Any: """Get or create a streaming detokenizer for a request.""" if request_id not in self._detokenizer_pool: - if hasattr(self.tokenizer, "detokenizer"): - detok = self.tokenizer.detokenizer - else: - detok = NaiveStreamingDetokenizer(self._actual_tokenizer) - detok.reset() + detok = NaiveStreamingDetokenizer(self._actual_tokenizer) self._detokenizer_pool[request_id] = detok return self._detokenizer_pool[request_id] @@ -1885,15 +1889,30 @@ def _schedule_waiting(self) -> List[Request]: request.remaining_tokens = request.prompt_token_ids tokens_to_process = request.prompt_token_ids + # Build per-request logits_processors from repetition_penalty + rep_penalty = request.sampling_params.repetition_penalty + lp = None + if rep_penalty and rep_penalty != 1.0: + lp = make_logits_processors(repetition_penalty=rep_penalty) + logger.info( + f"[rep_penalty] request={request.request_id[:12]} " + f"penalty={rep_penalty} processors={len(lp)}" + ) + # Insert into BatchGenerator with optional cache. # Wrap in try/except: if cache shapes are incompatible # (e.g. stale entry after BatchGenerator recreation), # fall back to no-cache insert instead of crashing. 
+                insert_kwargs = {
+                    "max_tokens": [request.sampling_params.max_tokens],
+                    "caches": [cache_to_use] if cache_to_use else None,
+                }
+                if lp:
+                    insert_kwargs["logits_processors"] = [lp]
                 try:
                     uids = self.batch_generator.insert(
                         [tokens_to_process],
-                        max_tokens=[request.sampling_params.max_tokens],
-                        caches=[cache_to_use] if cache_to_use else None,
+                        **insert_kwargs,
                     )
                 except Exception as e:
                     if cache_to_use is not None:
@@ -1906,10 +1925,10 @@ def _schedule_waiting(self) -> List[Request]:
                         request.cached_tokens = 0
                         request.remaining_tokens = request.prompt_token_ids
                         tokens_to_process = request.prompt_token_ids
+                        insert_kwargs["caches"] = None
                         uids = self.batch_generator.insert(
                             [tokens_to_process],
-                            max_tokens=[request.sampling_params.max_tokens],
-                            caches=None,
+                            **insert_kwargs,
                         )
                     else:
                         raise
@@ -1930,11 +1949,16 @@ def _schedule_waiting(self) -> List[Request]:
                     else ""
                 )
                 tokens_to_prefill = len(tokens_to_process)
+                rep_info = (
+                    f" rep_penalty={rep_penalty}"
+                    if rep_penalty and rep_penalty != 1.0
+                    else ""
+                )
                 logger.info(
                     f"[schedule] request={request.request_id[:12]} uid={uid} "
                     f"prompt_tokens={request.num_prompt_tokens} "
                     f"tokens_to_prefill={tokens_to_prefill}{cache_info} "
-                    f"max_tokens={request.sampling_params.max_tokens} "
+                    f"max_tokens={request.sampling_params.max_tokens}{rep_info} "
                     f"running={len(self.running)} waiting={len(self.waiting)}"
                 )
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 1c0e06d17..531a8d69a 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -104,8 +104,6 @@
 )
 from .api.utils import (
     SPECIAL_TOKENS_PATTERN,
-    StreamingThinkRouter,
-    StreamingToolCallFilter,
     clean_output_text,
     extract_multimodal_content,
     is_mllm_model,  # noqa: F401
@@ -169,6 +167,11 @@ def _resolve_top_p(request_value: float | None) -> float:
 _tool_call_parser: str | None = None  # Parser name: auto, mistral, qwen, llama, hermes
 _tool_parser_instance = None  # Instantiated parser
+# Pattern to strip leaked tool call markup from content output.
+# Safety net: the tool parser should consume these, but if it doesn't
+# (e.g. malformed JSON, stray closing tags), strip them before emitting.
+_TOOL_MARKUP_PATTERN = re.compile(r"<tool_call>|</tool_call>")
+
 def _load_prefix_cache_from_disk() -> None:
     """Load prefix cache from disk during startup."""
@@ -349,6 +352,53 @@ def get_engine() -> BaseEngine:
     return _engine
+def _coerce_tool_arguments(
+    arguments_json: str, tool_name: str, tools: list[dict] | None
+) -> str:
+    """
+    Coerce tool call arguments to match the tool schema.
+
+    If a schema field expects "string" but the model produced an object/array,
+    JSON-stringify the value. This fixes a common LLM failure mode where models
+    output raw JSON objects instead of JSON strings for file content, etc. 
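+
+    Example (hypothetical schema): for a tool whose "content" parameter is
+    declared as {"type": "string"}, arguments of
+
+        {"content": {"a": 1}}
+
+    become {"content": "{\n  \"a\": 1\n}"} -- the object is JSON-stringified
+    (indent=2), while fields that already match their schema pass through
+    unchanged.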
+ """ + if not tools: + return arguments_json + + # Find the schema for this tool + schema = None + for tool in tools: + if isinstance(tool, dict) and tool.get("function", {}).get("name") == tool_name: + schema = tool["function"].get("parameters", {}) + break + + if not schema or "properties" not in schema: + return arguments_json + + try: + arguments = json.loads(arguments_json) + except (json.JSONDecodeError, TypeError): + return arguments_json + + if not isinstance(arguments, dict): + return arguments_json + + properties = schema.get("properties", {}) + changed = False + + for key, value in arguments.items(): + if key in properties: + expected_type = properties[key].get("type") + if expected_type == "string" and isinstance(value, (dict, list)): + arguments[key] = json.dumps(value, ensure_ascii=False, indent=2) + changed = True + + if changed: + return json.dumps(arguments, ensure_ascii=False) + + return arguments_json + + def _validate_model_name(request_model: str) -> None: """Validate that the request model name matches the served model.""" if _model_name and request_model != _model_name: @@ -414,13 +464,16 @@ def _parse_tool_calls_with_parser( _tool_parser_instance.reset() result = _tool_parser_instance.extract_tool_calls(output_text, request_dict) if result.tools_called: + tools = request_dict.get("tools") if request_dict else None tool_calls = [ ToolCall( id=tc.get("id", f"call_{uuid.uuid4().hex[:8]}"), type="function", function=FunctionCall( name=tc["name"], - arguments=tc["arguments"], + arguments=_coerce_tool_arguments( + tc["arguments"], tc["name"], tools + ), ), ) for tc in result.tool_calls @@ -499,6 +552,7 @@ def load_model( stream_interval: int = 1, max_tokens: int = 32768, force_mllm: bool = False, + gpu_memory_utilization: float = 0.90, served_model_name: str | None = None, mtp: bool = False, prefill_step_size: int = 2048, @@ -542,6 +596,7 @@ def load_model( scheduler_config=scheduler_config, stream_interval=stream_interval, force_mllm=force_mllm, + gpu_memory_utilization=gpu_memory_utilization, ) # BatchedEngine will be started in lifespan (uvicorn's event loop) # Just log for now @@ -605,14 +660,11 @@ async def health(): "tools_available": len(_mcp_manager.get_all_tools()), } - engine_stats = _engine.get_stats() if _engine else {} - return { "status": "healthy", "model_loaded": _engine is not None, "model_name": _model_name, "model_type": "mllm" if (_engine and _engine.is_mllm) else "llm", - "engine_type": engine_stats.get("engine_type", "unknown"), "mcp": mcp_info, } @@ -1041,15 +1093,19 @@ async def _disconnect_guard( generator: AsyncIterator[str], raw_request: Request, poll_interval: float = 0.5, + heartbeat_interval: float = 5.0, ) -> AsyncIterator[str]: """Wrap streaming generator to abort on client disconnect. Uses asyncio racing: each __anext__() on the inner generator is - raced against a disconnect poller. This catches disconnects even - during prefill when no chunks are being yielded for tens of seconds. - - On disconnect, aclose() propagates down the generator chain to - engine_core.stream_outputs() finally-block → abort_request(). + raced against a disconnect poller. When neither completes within + ``heartbeat_interval`` seconds, an SSE comment is yielded as a + heartbeat. This forces an ASGI write which triggers broken-pipe + detection — without heartbeats, ``is_disconnected()`` stays False + during long prefill because no data is written to the socket. 
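+
+    ``heartbeat_interval`` trades detection latency against overhead: a
+    smaller value notices dead sockets sooner but writes more no-op frames.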
+ + On disconnect, the cancellation propagates to stream_outputs() + finally-block → abort_request() → abort_prefill(). """ import time as _time @@ -1058,7 +1114,9 @@ async def _disconnect_guard( def _elapsed(): return f"{_time.monotonic() - _t0:.1f}s" - logger.info(f"[disconnect_guard] START poll_interval={poll_interval}s") + logger.info( + f"[disconnect_guard] START poll={poll_interval}s heartbeat={heartbeat_interval}s" + ) async def _wait_disconnect(): poll_count = 0 @@ -1075,21 +1133,28 @@ async def _wait_disconnect(): return chunk_count = 0 + heartbeat_count = 0 disconnect_task: asyncio.Task | None = None anext_task: asyncio.Task | None = None try: aiter = generator.__aiter__() disconnect_task = asyncio.create_task(_wait_disconnect()) + anext_task = None while True: - anext_task = asyncio.ensure_future(aiter.__anext__()) + if anext_task is None: + anext_task = asyncio.ensure_future(aiter.__anext__()) + done, _ = await asyncio.wait( [anext_task, disconnect_task], return_when=asyncio.FIRST_COMPLETED, + timeout=heartbeat_interval, ) + if disconnect_task in done: logger.info( f"[disconnect_guard] CLIENT DISCONNECTED after " - f"{chunk_count} chunks, elapsed={_elapsed()}" + f"{chunk_count} chunks, {heartbeat_count} heartbeats, " + f"elapsed={_elapsed()}" ) anext_task.cancel() try: @@ -1097,20 +1162,32 @@ async def _wait_disconnect(): except (asyncio.CancelledError, StopAsyncIteration): pass break - try: - chunk = anext_task.result() - except StopAsyncIteration: - logger.info( - f"[disconnect_guard] generator exhausted normally, " - f"{chunk_count} chunks, elapsed={_elapsed()}" - ) - break - chunk_count += 1 - if chunk_count == 1: - logger.info( - f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}" - ) - yield chunk + + if anext_task in done: + try: + chunk = anext_task.result() + except StopAsyncIteration: + logger.info( + f"[disconnect_guard] generator exhausted normally, " + f"{chunk_count} chunks, elapsed={_elapsed()}" + ) + break + chunk_count += 1 + if chunk_count == 1: + logger.info( + f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}" + ) + yield chunk + anext_task = None + continue + + # Timeout — no chunk and no disconnect detected yet. + # Send SSE comment as heartbeat to force an ASGI write. + # If the client has disconnected, this write will fail and + # the next is_disconnected() poll will return True. 
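+            # (SSE treats lines starting with ":" as comment frames, so
+            # spec-compliant clients silently ignore the heartbeats.)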
+ heartbeat_count += 1 + yield ": heartbeat\n\n" + except GeneratorExit: logger.info( f"[disconnect_guard] GeneratorExit after {chunk_count} chunks, elapsed={_elapsed()}" @@ -1130,7 +1207,8 @@ async def _wait_disconnect(): # anext_task.cancel() → CancelledError in stream_outputs() # → finally block → abort_request() → request removed from scheduler logger.info( - f"[disconnect_guard] CLEANUP done, {chunk_count} chunks total, elapsed={_elapsed()}" + f"[disconnect_guard] CLEANUP done, {chunk_count} chunks, " + f"{heartbeat_count} heartbeats, elapsed={_elapsed()}" ) @@ -1235,10 +1313,18 @@ async def create_completion(request: CompletionRequest, raw_request: Request): f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}" ) + # Resolve repetition penalty for completions + comp_rep_penalty = request.repetition_penalty + if request.stream: return StreamingResponse( _disconnect_guard( - stream_completion(engine, prompts[0], request), + stream_completion( + engine, + prompts[0], + request, + repetition_penalty=comp_rep_penalty, + ), raw_request, ), media_type="text/event-stream", @@ -1252,14 +1338,16 @@ async def create_completion(request: CompletionRequest, raw_request: Request): total_prompt_tokens = 0 for i, prompt in enumerate(prompts): + gen_kwargs = { + "max_tokens": request.max_tokens or _default_max_tokens, + "temperature": _resolve_temperature(request.temperature), + "top_p": _resolve_top_p(request.top_p), + "stop": request.stop, + } + if comp_rep_penalty is not None: + gen_kwargs["repetition_penalty"] = comp_rep_penalty output = await _wait_with_disconnect( - engine.generate( - prompt=prompt, - max_tokens=request.max_tokens or _default_max_tokens, - temperature=_resolve_temperature(request.temperature), - top_p=_resolve_top_p(request.top_p), - stop=request.stop, - ), + engine.generate(prompt=prompt, **gen_kwargs), raw_request, timeout=timeout, ) @@ -1415,12 +1503,17 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re # Inject JSON instruction into messages messages = _inject_json_instruction(messages, json_instruction) + # Resolve repetition penalty + rep_penalty = request.repetition_penalty + # Prepare kwargs chat_kwargs = { "max_tokens": request.max_tokens or _default_max_tokens, "temperature": _resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), } + if rep_penalty is not None: + chat_kwargs["repetition_penalty"] = rep_penalty # Add multimodal content if has_media: @@ -1437,6 +1530,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re if request.specprefill_keep_pct is not None: chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + # Enable/disable thinking mode per request + if request.enable_thinking is not None: + chat_kwargs["enable_thinking"] = request.enable_thinking + # Add tools if provided if request.tools and request.tool_choice != "none": chat_kwargs["tools"] = convert_tools_for_template(request.tools) @@ -1472,8 +1569,9 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, request) # Extract reasoning content FIRST (strips channel tokens before JSON extraction) + # Skip reasoning parser when enable_thinking=False (no think tags expected) reasoning_text = None - if _reasoning_parser and not tool_calls: + if _reasoning_parser and not tool_calls and request.enable_thinking is not False: text_to_parse = cleaned_text or output.text reasoning_text, cleaned_text = 
_reasoning_parser.extract_reasoning(
                 text_to_parse
@@ -1860,6 +1958,10 @@ async def _stream_anthropic_messages(
     Converts OpenAI streaming chunks to Anthropic event format:
     message_start -> content_block_start -> content_block_delta* ->
     content_block_stop -> message_delta -> message_stop
+
+    When a reasoning parser is active, emits a ``thinking`` content block
+    (index 0) for reasoning tokens and a ``text`` content block (index 1)
+    for the actual response, matching the Anthropic extended thinking format.
     """
     msg_id = f"msg_{uuid.uuid4().hex[:24]}"
     start_time = time.perf_counter()
@@ -1898,160 +2000,171 @@ async def _stream_anthropic_messages(
     }
     yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
 
-    # Stream pipeline: raw text → tool call filter → think router → emit
-    # - Tool call filter strips tool call markup (emitted as structured blocks later)
-    # - Think router separates reasoning from content into Anthropic blocks
-    #
-    # When a reasoning parser is configured (e.g. --reasoning-parser gemma4),
-    # it replaces the generic StreamingThinkRouter to handle model-specific
-    # reasoning formats (e.g. Gemma 4 <|channel>thought...).
-    accumulated_text = ""
-    use_reasoning_parser = _reasoning_parser is not None
-    tool_filter = StreamingToolCallFilter()
+    use_reasoning = _reasoning_parser is not None
 
-    if use_reasoning_parser:
+    if use_reasoning:
         _reasoning_parser.reset_state()
-        think_router = None
-    else:
-        # Detect if the model's chat template injects <think> into the
-        # generation prompt. If so, the model starts in thinking mode and
-        # the opening tag never appears in the output stream.
-        _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None
-        _chat_template = ""
-        if _tokenizer and hasattr(_tokenizer, "chat_template"):
-            _chat_template = _tokenizer.chat_template or ""
-        _starts_thinking = (
-            "<think>" in _chat_template and "add_generation_prompt" in _chat_template
-        )
-        think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking)
 
-    prompt_tokens = 0
+    # Block index tracking: with reasoning parser we use index 0 for
+    # thinking and index 1 for text; without parser, index 0 for text.
+    thinking_block_started = False
+    text_block_started = False
+    thinking_index = 0
+    text_index = 1 if use_reasoning else 0
+
+    if not use_reasoning:
+        # No reasoning parser — start text block immediately
+        yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n"
+        text_block_started = True
+
+    # Stream content deltas
+    accumulated_text = ""
     completion_tokens = 0
 
-    # Track which content blocks we've started
-    current_block_type = None  # "thinking" or "text"
-    block_index = 0
-
-    # For reasoning parser: track accumulated text for parser context
-    reasoning_accumulated = ""
+    # Tool call streaming suppression — prevents raw tool markup from leaking
+    # as text_delta events. Mirrors the OpenAI streaming path logic.
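+    # The parser is instantiated lazily on first use and cached in the
+    # module-level global so later streams reuse the same instance.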
+ global _tool_parser_instance + tool_parser = None + tool_accumulated_text = "" + tool_markup_possible = False + tool_choice = getattr(openai_request, "tool_choice", None) + if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": + if _tool_parser_instance is None: + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = None + if _engine is not None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + _tool_parser_instance = parser_cls(tokenizer) + except Exception: + pass + if _tool_parser_instance is not None: + tool_parser = _tool_parser_instance + tool_parser.reset() async for output in engine.stream_chat(messages=messages, **chat_kwargs): delta_text = output.new_text # Track token counts - if hasattr(output, "prompt_tokens") and output.prompt_tokens: - prompt_tokens = output.prompt_tokens if hasattr(output, "completion_tokens") and output.completion_tokens: completion_tokens = output.completion_tokens - if delta_text: - # Accumulate raw text BEFORE special token cleaning for tool parsing - accumulated_text += delta_text + if not delta_text: + continue - # Filter special tokens for display - content = SPECIAL_TOKENS_PATTERN.sub("", delta_text) + # Filter special tokens + filtered = SPECIAL_TOKENS_PATTERN.sub("", delta_text) + if not filtered: + continue - if content: - # Stage 1: strip tool call markup - filtered = tool_filter.process(content) - if not filtered: - continue - - if use_reasoning_parser: - # Stage 2a: use reasoning parser for model-specific formats - prev = reasoning_accumulated - reasoning_accumulated += filtered - delta_msg = _reasoning_parser.extract_reasoning_streaming( - prev, reasoning_accumulated, filtered + if not use_reasoning: + # Simple path — no reasoning parsing + accumulated_text += filtered + content_to_emit = filtered + + # Filter tool call markup during streaming + if tool_parser and content_to_emit: + if not tool_markup_possible and "<" not in content_to_emit: + tool_accumulated_text += content_to_emit + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += content_to_emit + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, content_to_emit ) - if delta_msg is None: + if tool_result is None or "tool_calls" in tool_result: + # Inside tool markup or tool calls detected — suppress + continue + content_to_emit = tool_result.get("content", "") + if content_to_emit: + content_to_emit = _TOOL_MARKUP_PATTERN.sub("", content_to_emit) + if not content_to_emit: continue - pieces = [] - if delta_msg.reasoning: - pieces.append(("thinking", delta_msg.reasoning)) - if delta_msg.content: - pieces.append(("text", delta_msg.content)) - else: - # Stage 2b: generic tag router - pieces = think_router.process(filtered) - events, current_block_type, block_index = _emit_content_pieces( - pieces, current_block_type, block_index - ) - for event in events: - yield event - - # Flush remaining from tool filter - remaining = tool_filter.flush() - if remaining: - if use_reasoning_parser: - prev = reasoning_accumulated - reasoning_accumulated += remaining - delta_msg = _reasoning_parser.extract_reasoning_streaming( - prev, reasoning_accumulated, remaining - ) - pieces = [] - if delta_msg: - if delta_msg.reasoning: - pieces.append(("thinking", delta_msg.reasoning)) - if delta_msg.content: - pieces.append(("text", delta_msg.content)) - else: - pieces = think_router.process(remaining) - events, 
current_block_type, block_index = _emit_content_pieces( - pieces, current_block_type, block_index + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': content_to_emit}})}\n\n" + continue + + # Reasoning parser path + previous_text = accumulated_text + accumulated_text += filtered + delta_msg = _reasoning_parser.extract_reasoning_streaming( + previous_text, accumulated_text, filtered ) - for event in events: - yield event - - if not use_reasoning_parser: - flush_pieces = think_router.flush() - if flush_pieces: - events, current_block_type, block_index = _emit_content_pieces( - flush_pieces, current_block_type, block_index - ) - for event in events: - yield event - # Close final content block - if current_block_type is not None: - yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n" - block_index += 1 + if delta_msg is None: + continue + + if delta_msg.reasoning: + if not thinking_block_started: + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': thinking_index, 'content_block': {'type': 'thinking', 'thinking': ''}})}\n\n" + thinking_block_started = True + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': thinking_index, 'delta': {'type': 'thinking_delta', 'thinking': delta_msg.reasoning}})}\n\n" + + if delta_msg.content: + content_to_emit = delta_msg.content + + # Filter tool call markup during streaming + if tool_parser and content_to_emit: + if not tool_markup_possible and "<" not in content_to_emit: + tool_accumulated_text += content_to_emit + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += content_to_emit + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, content_to_emit + ) + if tool_result is None or "tool_calls" in tool_result: + # Inside tool markup or tool calls detected — suppress + continue + content_to_emit = tool_result.get("content", "") + if content_to_emit: + content_to_emit = _TOOL_MARKUP_PATTERN.sub("", content_to_emit) + if not content_to_emit: + continue + + if thinking_block_started and not text_block_started: + # Close thinking block, open text block + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': thinking_index})}\n\n" + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + text_block_started = True + elif not text_block_started: + # No thinking was emitted, start text block at index 0 + text_index = 0 + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + text_block_started = True + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': text_index, 'delta': {'type': 'text_delta', 'text': content_to_emit}})}\n\n" + + # Close any open thinking block that was never followed by text + if thinking_block_started and not text_block_started: + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': thinking_index})}\n\n" + # Emit empty text block so response always has text content + text_index = thinking_index + 1 + yield f"event: content_block_start\ndata: 
{json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + text_block_started = True # Check for tool calls in accumulated text _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request) + # Close text block + if text_block_started: + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': text_index})}\n\n" + # If there are tool calls, emit tool_use blocks + next_index = (text_index + 1) if text_block_started else 0 if tool_calls: for i, tc in enumerate(tool_calls): - tool_index = block_index + i + tool_index = next_index + i try: tool_input = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError): tool_input = {} - # content_block_start for tool_use - tool_block_start = { - "type": "content_block_start", - "index": tool_index, - "content_block": { - "type": "tool_use", - "id": tc.id, - "name": tc.function.name, - "input": {}, - }, - } - yield f"event: content_block_start\ndata: {json.dumps(tool_block_start)}\n\n" - - # Send input as a single delta - input_json = json.dumps(tool_input) - input_delta = { - "type": "content_block_delta", - "index": tool_index, - "delta": {"type": "input_json_delta", "partial_json": input_json}, - } - yield f"event: content_block_delta\ndata: {json.dumps(input_delta)}\n\n" - - # content_block_stop + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': tool_index, 'content_block': {'type': 'tool_use', 'id': tc.id, 'name': tc.function.name, 'input': {}}})}\n\n" + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': tool_index, 'delta': {'type': 'input_json_delta', 'partial_json': json.dumps(tool_input)}})}\n\n" yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': tool_index})}\n\n" # Determine stop reason @@ -2061,7 +2174,7 @@ async def _stream_anthropic_messages( message_delta = { "type": "message_delta", "delta": {"stop_reason": stop_reason, "stop_sequence": None}, - "usage": {"input_tokens": prompt_tokens, "output_tokens": completion_tokens}, + "usage": {"output_tokens": completion_tokens}, } yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n" @@ -2069,7 +2182,7 @@ async def _stream_anthropic_messages( elapsed = time.perf_counter() - start_time tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 logger.info( - f"Anthropic messages (stream): prompt={prompt_tokens} + completion={completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" ) # Emit message_stop @@ -2085,15 +2198,18 @@ async def stream_completion( engine: BaseEngine, prompt: str, request: CompletionRequest, + repetition_penalty: float | None = None, ) -> AsyncIterator[str]: """Stream completion response.""" - async for output in engine.stream_generate( - prompt=prompt, - max_tokens=request.max_tokens or _default_max_tokens, - temperature=_resolve_temperature(request.temperature), - top_p=_resolve_top_p(request.top_p), - stop=request.stop, - ): + gen_kwargs = { + "max_tokens": request.max_tokens or _default_max_tokens, + "temperature": _resolve_temperature(request.temperature), + "top_p": _resolve_top_p(request.top_p), + "stop": request.stop, + } + if repetition_penalty is not None: + gen_kwargs["repetition_penalty"] = repetition_penalty + async for output in 
engine.stream_generate(prompt=prompt, **gen_kwargs):
         data = {
             "id": f"cmpl-{uuid.uuid4().hex[:8]}",
             "object": "text_completion",
@@ -2192,8 +2308,8 @@ async def stream_chat_completion(
         if hasattr(output, "completion_tokens") and output.completion_tokens:
             completion_tokens = output.completion_tokens
 
-        # Use reasoning parser if enabled
-        if _reasoning_parser and delta_text:
+        # Use reasoning parser if enabled (skip when enable_thinking=False)
+        if _reasoning_parser and delta_text and request.enable_thinking is not False:
             previous_text = accumulated_text
             accumulated_text += delta_text
             delta_msg = _reasoning_parser.extract_reasoning_streaming(
@@ -2204,36 +2320,84 @@ async def stream_chat_completion(
                 # Skip this chunk (e.g., <think> token itself)
                 continue
 
-            # Run tool parser on content delta (post-reasoning text only).
-            # The reasoning parser suppresses <think> tokens, so tool calls
-            # only appear in delta_msg.content after </think> is emitted.
-            if tool_parser and delta_msg.content:
-                content_delta = delta_msg.content
-                if not tool_markup_possible and "<" not in content_delta:
-                    tool_accumulated_text += content_delta
+            content = delta_msg.content
+            reasoning = delta_msg.reasoning
+
+            # Some models (e.g. MiniMax) wrap tool calls in <think>
+            # blocks, so reasoning parser captures tool call XML as
+            # reasoning while content stays None. Redirect reasoning
+            # to the content stream so the tool parser can handle it.
+            if tool_parser and reasoning and not content:
+                _check = tool_accumulated_text + reasoning
+                if (
+                    "<tool_call>" in _check
+                    or "</tool_call>" in _check
+                    or '<invoke name="' in _check
+                ):
+                    content = reasoning
+                    reasoning = None

+        if text_config.get("mtp_num_hidden_layers", 0) > 0:
+            from .patches.qwen3_5_mtp import inject_mtp_support
+
+            inject_mtp_support(text_model, model_path, config)
+
         if hasattr(text_model, "mtp") and text_model.mtp is not None:
             mx.eval(text_model.mtp.parameters())
-            logger.info(
-                "TextModel built with MTP support (%d layers)",
-                args.mtp_num_hidden_layers,
+            num_mtp = text_config.get(
+                "mtp_num_hidden_layers",
+                text_config.get("num_nextn_predict_layers", 0),
             )
+            logger.info("TextModel built with MTP support (%d layers)", num_mtp)
         else:
-            logger.info("TextModel built without MTP (mtp_num_hidden_layers=0)")
+            logger.info("TextModel built without MTP")
 
         return text_model

diff --git a/vllm_mlx/tool_parsers/__init__.py b/vllm_mlx/tool_parsers/__init__.py
index 685631a73..cd76ad418 100644
--- a/vllm_mlx/tool_parsers/__init__.py
+++ b/vllm_mlx/tool_parsers/__init__.py
@@ -10,6 +10,7 @@
 - mistral: Mistral models ([TOOL_CALLS] format)
 - qwen/qwen3: Qwen models (<tool_call> and [Calling tool:] formats)
 - llama/llama3/llama4: Llama models (<function> format)
+- gemma4/gemma_4: Google Gemma 4 models (<|tool_call>call:name{} format)
 - hermes/nous: Hermes/NousResearch models
 - deepseek/deepseek_v3/deepseek_r1: DeepSeek models (unicode tokens)
 - kimi/kimi_k2/moonshot: Kimi/Moonshot models
@@ -19,7 +20,7 @@
 - functionary/meetkai: MeetKai Functionary models
 - glm47/glm4: GLM-4.7 and GLM-4.7-Flash models
 - harmony/gpt-oss: GPT-OSS models (Harmony format with channels)
-- gemma4: Google Gemma 4 models (<|tool_call>call:name{} format)
+- minimax: MiniMax-M2 models
 
 Usage:
     from vllm_mlx.tool_parsers import ToolParserManager
@@ -48,6 +49,7 @@
 from .auto_tool_parser import AutoToolParser
 from .deepseek_tool_parser import DeepSeekToolParser
 from .functionary_tool_parser import FunctionaryToolParser
+from .gemma4_tool_parser import Gemma4ToolParser
 from .granite_tool_parser import GraniteToolParser
 from .hermes_tool_parser import HermesToolParser
 from .kimi_tool_parser import KimiToolParser
@@ -58,7 +60,7 @@
 from .xlam_tool_parser import xLAMToolParser
 from
.glm47_tool_parser import Glm47ToolParser
 from .harmony_tool_parser import HarmonyToolParser
-from .gemma4_tool_parser import Gemma4ToolParser
+from .minimax_tool_parser import MiniMaxToolParser
 
 __all__ = [
     # Base classes
@@ -67,6 +69,7 @@
     "ExtractedToolCallInformation",
     # Specific parsers
     "AutoToolParser",
+    "Gemma4ToolParser",
     "MistralToolParser",
     "QwenToolParser",
     "LlamaToolParser",
@@ -79,5 +82,5 @@
     "FunctionaryToolParser",
     "Glm47ToolParser",
     "HarmonyToolParser",
-    "Gemma4ToolParser",
+    "MiniMaxToolParser",
 ]
diff --git a/vllm_mlx/tool_parsers/auto_tool_parser.py b/vllm_mlx/tool_parsers/auto_tool_parser.py
index ac9058c88..37ab10d74 100644
--- a/vllm_mlx/tool_parsers/auto_tool_parser.py
+++ b/vllm_mlx/tool_parsers/auto_tool_parser.py
@@ -122,7 +122,7 @@ def extract_tool_calls(
             content=content if content else None,
         )
 
-        # 2. Try Qwen bracket pattern
+        # 3. Try Qwen bracket pattern
         bracket_matches = self.QWEN_BRACKET_PATTERN.findall(model_output)
         for name, args_str in bracket_matches:
             try:
@@ -150,7 +150,7 @@ def extract_tool_calls(
         if bracket_matches:
             cleaned_text = self.QWEN_BRACKET_PATTERN.sub("", cleaned_text).strip()
 
-        # 3. Try Nemotron pattern (before Qwen XML as it's more specific)
+        # 4. Try Nemotron pattern (before Qwen XML as it's more specific)
         nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text)
         for name, params_block in nemotron_matches:
             params = self.NEMOTRON_PARAM_PATTERN.findall(params_block)
@@ -166,7 +166,7 @@ def extract_tool_calls(
         if nemotron_matches:
             cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip()
 
-        # 4. Try Qwen/Hermes XML pattern
+        # 5. Try Qwen/Hermes XML pattern
         xml_matches = self.QWEN_XML_PATTERN.findall(cleaned_text)
         for match in xml_matches:
             try:
@@ -191,7 +191,7 @@ def extract_tool_calls(
         if xml_matches:
             cleaned_text = self.QWEN_XML_PATTERN.sub("", cleaned_text).strip()
 
-        # 5. Try Llama pattern
+        # 6. Try Llama pattern
         llama_matches = self.LLAMA_PATTERN.findall(cleaned_text)
         for name, args_str in llama_matches:
             try:
@@ -219,7 +219,7 @@ def extract_tool_calls(
         if llama_matches:
             cleaned_text = self.LLAMA_PATTERN.sub("", cleaned_text).strip()
 
-        # 6. Fallback: Try raw JSON
+        # 7. Fallback: Try raw JSON
         if not tool_calls:
             raw_calls = self._parse_raw_json_tool_calls(cleaned_text)
             if raw_calls:
diff --git a/vllm_mlx/tool_parsers/minimax_tool_parser.py b/vllm_mlx/tool_parsers/minimax_tool_parser.py
new file mode 100644
index 000000000..7459fe97f
--- /dev/null
+++ b/vllm_mlx/tool_parsers/minimax_tool_parser.py
@@ -0,0 +1,172 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+MiniMax tool call parser for vllm-mlx.
+
+Parses the MiniMax-M2 native XML tool call format:
+
+    <tool_call>
+    <invoke name="tool-name">
+    <parameter name="param-name">param-value</parameter>
+    </invoke>
+    </tool_call>
+"""
+
+import json
+import re
+import uuid
+from collections.abc import Sequence
+from typing import Any
+
+from .abstract_tool_parser import (
+    ExtractedToolCallInformation,
+    ToolParser,
+    ToolParserManager,
+)
+
+
+def generate_tool_id() -> str:
+    return f"call_{uuid.uuid4().hex[:8]}"
+
+
+@ToolParserManager.register_module(["minimax", "minimax_m2"])
+class MiniMaxToolParser(ToolParser):
+    """
+    Parser for MiniMax-M2 tool call format.
+
+    Format:
+
+        <tool_call>
+        <invoke name="tool_name">
+        <parameter name="param">value</parameter>
+        </invoke>
+        </tool_call>
+    """
+
+    TOOL_CALL_BLOCK = re.compile(
+        r"<tool_call>(.*?)</tool_call>", re.DOTALL
+    )
+    INVOKE_PATTERN = re.compile(r'<invoke name="([^"]+)">(.*?)</invoke>', re.DOTALL)
+    PARAM_PATTERN = re.compile(
+        r'<parameter name="([^"]+)">(.*?)</parameter>', re.DOTALL
+    )
+    THINK_PATTERN = re.compile(r"<think>.*?</think>", re.DOTALL)
+
+    def _extract_invokes(self, text: str) -> list[dict[str, Any]]:
+        """Extract tool calls from invoke elements, with or without wrapper."""
+        tool_calls: list[dict[str, Any]] = []
+        invokes = self.INVOKE_PATTERN.findall(text)
+        for func_name, params_block in invokes:
+            params = self.PARAM_PATTERN.findall(params_block)
+            # Skip bare <invoke> tags without parameters (hallucinated junk)
+            if not params:
+                continue
+            arguments = {}
+            for p_name, p_value in params:
+                p_value = p_value.strip()
+                try:
+                    arguments[p_name] = json.loads(p_value)
+                except (json.JSONDecodeError, ValueError):
+                    arguments[p_name] = p_value
+
+            tool_calls.append(
+                {
+                    "id": generate_tool_id(),
+                    "name": func_name.strip(),
+                    "arguments": json.dumps(arguments, ensure_ascii=False),
+                }
+            )
+        return tool_calls
+
+    def extract_tool_calls(
+        self, model_output: str, request: dict[str, Any] | None = None
+    ) -> ExtractedToolCallInformation:
+        # Try wrapped format first: <tool_call><invoke>...</invoke></tool_call>
+        blocks = self.TOOL_CALL_BLOCK.findall(model_output)
+        if blocks:
+            tool_calls: list[dict[str, Any]] = []
+            for block in blocks:
+                tool_calls.extend(self._extract_invokes(block))
+
+            cleaned = self.TOOL_CALL_BLOCK.sub("", model_output).strip()
+            cleaned = self.THINK_PATTERN.sub("", cleaned).strip()
+            cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip()
+
+            return ExtractedToolCallInformation(
+                tools_called=bool(tool_calls),
+                tool_calls=tool_calls,
+                content=cleaned if cleaned else None,
+            )
+
+        # Fallback: bare <invoke> without <tool_call> wrapper
+        # (model sometimes emits tool calls inside <think> without wrapper)
+        tool_calls = self._extract_invokes(model_output)
+        if tool_calls:
+            # Strip matched invoke blocks and thinking from content
+            cleaned = self.INVOKE_PATTERN.sub("", model_output).strip()
+            cleaned = self.THINK_PATTERN.sub("", cleaned).strip()
+            cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip()
+            # Remove leftover closing tags
+            cleaned = cleaned.replace("</tool_call>", "").strip()
+
+            return ExtractedToolCallInformation(
+                tools_called=True,
+                tool_calls=tool_calls,
+                content=cleaned if cleaned else None,
+            )
+
+        return ExtractedToolCallInformation(
+            tools_called=False, tool_calls=[], content=model_output
+        )
+
+    def _has_tool_start(self, text: str) -> bool:
+        """Check if text contains the start of a tool call block."""
+        return "<tool_call>" in text or (
+            '<invoke name="' in text
+        )
+
+    def _has_tool_end(self, current: str, previous: str) -> bool:
+        """Check if a tool call block just completed."""
+        # Wrapped: </tool_call> just appeared
+        if "<tool_call>" in current:
+            return (
+                "</tool_call>" in current
+                and "</tool_call>" not in previous
+            )
+        # Bare invoke: </invoke> just appeared
+        if "</invoke>" in current and "</invoke>" not in previous:
+            return True
+        return False
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int] | None = None,
+        current_token_ids: Sequence[int] | None = None,
+        delta_token_ids: Sequence[int] | None = None,
+        request: dict[str, Any] | None = None,
+    ) -> dict[str, Any] | None:
+        # Not inside a tool call block yet — pass content through
+        if not self._has_tool_start(current_text):
+            return {"content": delta_text}
+
+        # Tool call block just completed
+        if self._has_tool_end(current_text, previous_text):
+            result = self.extract_tool_calls(current_text)
+            if result.tools_called:
+                return {
+                    "tool_calls": [
+                        {
+                            "index": i,
+                            "id": tc["id"],
+                            "type": "function",
+                            "function": {
+                                "name": tc["name"],
+                                "arguments": tc["arguments"],
+                            },
+                        }
+                        for i, tc in
enumerate(result.tool_calls) + ] + } + + # Inside tool call block but not yet complete — suppress output + return None diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index 0cb4b5d82..3f45e2d5c 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -28,6 +28,27 @@ def _needs_tokenizer_fallback(model_name: str) -> bool: return any(pattern.lower() in model_lower for pattern in FALLBACK_MODELS) +def _needs_strict_false(model_name: str) -> bool: + """Check if model needs strict=False loading (VLM models with extra weights). + + VLM models (e.g., Qwen3.5) have vision_tower weights that don't match + the text-only model class. Loading with strict=True fails and wastes + memory by loading all weights (~100 GB) before raising ValueError. + Detect these models up-front to avoid the double-load penalty. + """ + from mlx_lm.utils import _download, load_config + + try: + model_path = _download(model_name) + config = load_config(model_path) + except Exception: + return False + # VLM models have vision_config or text_config with a separate model_type + if "vision_config" in config and "text_config" in config: + return True + return False + + def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): """ Load model and tokenizer with fallback for non-standard tokenizers. @@ -50,6 +71,15 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): ) return _load_with_tokenizer_fallback(model_name) + # VLM models (e.g., Qwen3.5) have extra vision weights that cause + # strict=True to fail. Skip the first load attempt to avoid loading + # ~100 GB of weights twice (which can cause OOM on 256 GB systems). + if _needs_strict_false(model_name): + logger.info( + f"Model {model_name} detected as VLM, loading directly with strict=False" + ) + return _load_strict_false(model_name, tokenizer_config) + try: model, tokenizer = load(model_name, tokenizer_config=tokenizer_config) except ValueError as e: @@ -59,44 +89,89 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): return _load_with_tokenizer_fallback(model_name) # Fallback for models with extra weights (e.g., vision tower, MTP layers). # Retry with strict=False to discard extra weights. - if "parameters not in model" in str(e): + elif "parameters not in model" in str(e): logger.warning( f"Extra parameters found (e.g., vision tower / MTP weights), " f"retrying with strict=False: {e}" ) + # Clear traceback references to free memory from the failed first load. + # Without this, large models (200GB+) cause OOM during retry because + # the traceback holds references to the first load's weight tensors. + e.__traceback__ = None + del e + import gc + + gc.collect() return _load_strict_false(model_name, tokenizer_config) - raise + else: + raise + # After successful load, check if MTP weights exist but were stripped by sanitize() + _try_inject_mtp_post_load(model, model_name) return model, tokenizer def _load_strict_false(model_name: str, tokenizer_config: dict = None): - """Load model with strict=False to discard extra weights (e.g., vision tower, MTP).""" - from mlx_lm.utils import load_model, load_tokenizer - - local_path = Path(model_name) - if local_path.is_dir(): - model_path = local_path - else: - from huggingface_hub import snapshot_download + """Load model with strict=False to discard extra weights. 
- model_path = Path(snapshot_download(model_name)) + Handles models with extra parameters that the text-only model class + doesn't define (e.g., vision tower weights in VLM models like Qwen3.5, + or MTP layers). The model's own sanitize() handles key remapping + (e.g., language_model.* prefix), and strict=False silently drops + unmatched keys. + """ + import mlx.core as mx + from mlx_lm.utils import _download, load_model, load_tokenizer + model_path = _download(model_name) model, config = load_model(model_path, strict=False) + + # Verify weights loaded correctly + from mlx.utils import tree_flatten + + params = tree_flatten(model.parameters()) + total_params = len(params) + zero_params = sum(1 for _, v in params if mx.all(v == 0).item()) + logger.info( + f"[strict=False] Loaded {total_params} parameters, " + f"{zero_params} all-zero tensors" + ) + # Spot-check embedding weights + if hasattr(model, "language_model"): + emb = model.language_model.model.embed_tokens.weight + logger.info( + f"[strict=False] embed_tokens: shape={emb.shape}, " + f"dtype={emb.dtype}, mean={mx.mean(emb.astype(mx.float32)).item():.4f}" + ) + tokenizer = load_tokenizer( model_path, tokenizer_config or {}, eos_token_ids=config.get("eos_token_id", None), ) - # Inject MTP support if model has MTP config + weights _try_inject_mtp(model, model_path, config) return model, tokenizer def _try_inject_mtp(model, model_path, config): """Inject MTP support if model has MTP config + weights.""" + # Qwen3-Next: flat num_nextn_predict_layers if config.get("num_nextn_predict_layers", 0) > 0: - from ..patches.qwen3_next_mtp import inject_mtp_support + # Detect Qwen3.5 vs Qwen3-Next by checking text_config or model_type + text_config = config.get("text_config", config) + model_type = text_config.get("model_type", config.get("model_type", "")) + if "qwen3_5" in model_type: + from ..patches.qwen3_5_mtp import inject_mtp_support + else: + from ..patches.qwen3_next_mtp import inject_mtp_support + inject_mtp_support(model, model_path, config) + return + + # Qwen3.5: mtp_num_hidden_layers in text_config + text_config = config.get("text_config", config) + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + if num_mtp > 0: + from ..patches.qwen3_5_mtp import inject_mtp_support inject_mtp_support(model, model_path, config) @@ -113,13 +188,21 @@ def _try_inject_mtp_post_load(model, model_name): return with open(config_path) as f: config = json.load(f) - # Also check text_config for nested configs + # Check for MTP in flat config and nested text_config + text_config = config.get("text_config", {}) num_mtp = config.get("num_nextn_predict_layers", 0) if num_mtp == 0: - text_config = config.get("text_config", {}) num_mtp = text_config.get("num_nextn_predict_layers", 0) - if num_mtp > 0 and getattr(model, "mtp", None) is None: - mtp_file = Path(model_path) / "model-mtp.safetensors" + if num_mtp == 0: + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + # Also check mtp attribute on language_model for VLM wrappers + check_model = model + if hasattr(model, "language_model"): + check_model = model.language_model + if num_mtp > 0 and getattr(check_model, "mtp", None) is None: + mtp_file = Path(model_path) / "mtp" / "weights.safetensors" + if not mtp_file.exists(): + mtp_file = Path(model_path) / "model-mtp.safetensors" if mtp_file.exists(): logger.info( f"[MTP] Found MTP config (layers={num_mtp}) and weights, injecting..." 
@@ -128,7 +211,7 @@ def _try_inject_mtp_post_load(model, model_name): else: logger.info( f"[MTP] Config has num_nextn_predict_layers={num_mtp} " - "but model-mtp.safetensors not found, skipping MTP." + "but MTP weights not found, skipping MTP." ) From 6111dd5f102379cc78b4b79f8b8cebf1bf70173c Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sat, 11 Apr 2026 08:31:19 -0500 Subject: [PATCH 26/45] chore: add Apache 2.0 license file --- LICENSE | 176 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..d9a10c0d8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,176 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+   END OF TERMS AND CONDITIONS

From d67fec4e82982cc48d725102358474c68fe3fd4e Mon Sep 17 00:00:00 2001
From: Jan Hilgard
Date: Sat, 11 Apr 2026 16:29:55 +0200
Subject: [PATCH 27/45] Add <function=...> format support to Qwen tool parser

- Parse <function=name><parameter=key>value</parameter></function> format
  natively generated by Qwen3.5 models (both <parameter=...> parameter tags
  and JSON body)
- Add streaming partial-marker buffering for <function=...>
---
 tests/test_tool_parsers.py                | 174 ++++++++++++++++++++++
 vllm_mlx/server.py                        |   3 +-
 vllm_mlx/tool_parsers/qwen_tool_parser.py | 143 +++++++++++++++++-
 vllm_mlx/utils/tokenizer.py               |   1 +
 4 files changed, 318 insertions(+), 3 deletions(-)

diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py
index 7caaffbf5..4f3c287d1 100644
--- a/tests/test_tool_parsers.py
+++ b/tests/test_tool_parsers.py
@@ -1163,3 +1163,177 @@ def test_streaming_bare_multi_function_blocks(self):
         assert len(emitted_calls) == 2
         assert emitted_calls[0]["function"]["name"] == "func1"
         assert emitted_calls[1]["function"]["name"] == "func2"
+
+
+class TestQwenFunctionFormat:
+    """Test Qwen parser's <function=...> format support."""
+
+    @pytest.fixture
+    def parser(self):
+        return QwenToolParser()
+
+    def test_function_format_with_parameters(self, parser):
+        """Test <function=name><parameter=key>value</parameter></function>."""
+        text = "<function=get_weather><parameter=city>Prague</parameter></function>"
+        result = parser.extract_tool_calls(text)
+        assert result.tools_called
+        assert result.tool_calls[0]["name"] == "get_weather"
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args["city"] == "Prague"
+
+    def test_function_format_with_json(self, parser):
+        """Test <function=name>{"key": "val"}</function>."""
+        text = '<function=get_weather>{"city": "Prague"}</function>'
+        result = parser.extract_tool_calls(text)
+        assert result.tools_called
+        assert result.tool_calls[0]["name"] == "get_weather"
+        args = json.loads(result.tool_calls[0]["arguments"])
+        assert args["city"] == "Prague"
+
+    def test_function_format_multiple(self, parser):
+        """Test multiple <function=...> blocks."""
+        text = (
+            '<function=read_file>{"path": "/a.py"}</function>'
+            '<function=write_file>{"path": "/b.py", "content": "hello"}</function>'
+        )
+        result = parser.extract_tool_calls(text)
+        assert result.tools_called
+        assert len(result.tool_calls) == 2
+        assert result.tool_calls[0]["name"] == "read_file"
+        assert result.tool_calls[1]["name"] == "write_file"
+
+    def test_function_format_with_think_tags(self, parser):
+        """Test <function=...> with think tags."""
+        text = (
+            "<think>I need to check the weather.</think>\n"
+            '<function=get_weather>{"city": "Prague"}</function>'
+        )
+        result = parser.extract_tool_calls(text)
+        assert result.tools_called
+        assert result.tool_calls[0]["name"] == "get_weather"
+
+
+class TestQwenStreamingBuffering:
+    """Test Qwen parser streaming with partial-marker buffering."""
+
+    @pytest.fixture
+    def parser(self):
+        return QwenToolParser()
+
+    def test_streaming_function_format_complete(self, parser):
+        """Test streaming with <function=...>...</function> format."""
+        chunks = [
+            "<function=get_weather>",
+            "<parameter=city>Prague</parameter>",
+            "</function>",
+        ]
+        accumulated = ""
+        tool_calls_found = False
+        for chunk in chunks:
+            prev = accumulated
+            accumulated += chunk
+            r = parser.extract_tool_calls_streaming(
+                previous_text=prev,
+                current_text=accumulated,
+                delta_text=chunk,
+            )
+            if r is not None and "tool_calls" in r:
+                tool_calls_found = True
+                assert r["tool_calls"][0]["function"]["name"] == "get_weather"
+                break
+        assert tool_calls_found
+
+    def test_streaming_partial_marker_buffered(self, parser):
+        """Test that partial '<function=' marker is buffered, not emitted."""
+        r = parser.extract_tool_calls_streaming(
+            previous_text="",
+            current_text="<func",
+            delta_text="<func",
+        )
+        # Partial prefix of "<function=" — buffered, nothing emitted
+        assert r is None
+
+    def test_streaming_angle_bracket_not_marker(self, parser):
+        """Test that a '<' that resolves to plain markup is emitted."""
+        # Delta completes "<div>" — not a tool marker
+        r = parser.extract_tool_calls_streaming(
+            previous_text="Hello<",
+            current_text="Hello<div>
", + delta_text="div>", + ) + assert r is not None + assert "content" in r + assert "<" in r["content"] + assert "div>" in r["content"] + + def test_streaming_multiple_function_blocks(self, parser): + """Test streaming with multiple {"a": 1}', + "\n", + "", + "2", + "", + ] + accumulated = "" + emitted_calls = [] + for chunk in chunks: + prev = accumulated + accumulated += chunk + r = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=accumulated, + delta_text=chunk, + ) + if r is not None and "tool_calls" in r: + emitted_calls.extend(r["tool_calls"]) + assert len(emitted_calls) == 2 + assert emitted_calls[0]["function"]["name"] == "func1" + assert emitted_calls[1]["function"]["name"] == "func2" diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index d4f24de22..e080be8ff 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -2590,7 +2590,7 @@ async def stream_chat_completion( yield f"data: {chunk.model_dump_json()}\n\n" # Fallback: if tool parser accumulated text but never emitted tool_calls - # (e.g., never arrived - incomplete tool call) + # (e.g., never arrived, or " in tool_accumulated_text or "<|tool_call>" in tool_accumulated_text + or "{"name": "func", "arguments": {...}} - Bracket style: [Calling tool: func_name({"arg": "value"})] +- Function style: value """ +import ast import json import re import uuid @@ -20,6 +22,24 @@ ) +def _parse_param_value(val: str) -> Any: + """Parse a parameter value, handling JSON literals and plain strings.""" + try: + return json.loads(val) + except (json.JSONDecodeError, ValueError): + pass + try: + python_val = ast.literal_eval(val) + if isinstance(python_val, set): + python_val = sorted(python_val, key=str) + if isinstance(python_val, (complex, bytes)): + return val + json.dumps(python_val) + return python_val + except (ValueError, SyntaxError, TypeError): + return val + + def generate_tool_id() -> str: """Generate a unique tool call ID.""" return f"call_{uuid.uuid4().hex[:8]}" @@ -33,6 +53,7 @@ class QwenToolParser(ToolParser): Supports multiple Qwen tool call formats: - XML: {"name": "func", "arguments": {...}} - Bracket: [Calling tool: func_name({"arg": "value"})] + - Function: value Used when --enable-auto-tool-choice --tool-call-parser qwen are set. """ @@ -43,6 +64,12 @@ class QwenToolParser(ToolParser): # Pattern for bracket-style: [Calling tool: func_name({...})] BRACKET_PATTERN = re.compile(r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]", re.DOTALL) + # Pattern for function-style: ... + FUNCTION_PATTERN = re.compile(r"]+)>(.*?)", re.DOTALL) + + # Pattern for parameter extraction: value + PARAM_PATTERN = re.compile(r"]+)>\s*(.*?)\s*", re.DOTALL) + def extract_tool_calls( self, model_output: str, request: dict[str, Any] | None = None ) -> ExtractedToolCallInformation: @@ -101,6 +128,41 @@ def extract_tool_calls( if xml_matches: cleaned_text = self.XML_PATTERN.sub("", cleaned_text).strip() + # Try function-style: value + # Qwen3.5 generates this format natively. + if not tool_calls: + func_matches = self.FUNCTION_PATTERN.findall(cleaned_text) + for name, params_block in func_matches: + # Try JSON arguments first (e.g. 
{"key": "val"}) + params_block_stripped = params_block.strip() + if params_block_stripped.startswith("{"): + try: + arguments = json.loads(params_block_stripped) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + continue + except json.JSONDecodeError: + pass + # Parse value tags + params = self.PARAM_PATTERN.findall(params_block) + arguments = {} + for p_name, p_value in params: + arguments[p_name.strip()] = _parse_param_value(p_value.strip()) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + if func_matches: + cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip() + if tool_calls: return ExtractedToolCallInformation( tools_called=True, @@ -112,6 +174,30 @@ def extract_tool_calls( tools_called=False, tool_calls=[], content=model_output ) + # Partial marker prefixes — when current_text ends with one of these, + # we suppress output until the next token confirms or denies a tool call. + # These are long enough to avoid false positives on normal text. + _PARTIAL_MARKERS = (" bool: + """Check if text ends with an incomplete tool call marker prefix.""" + return self._get_partial_marker_len(text) > 0 + + def _get_partial_marker_len(self, text: str) -> int: + """Return the length of a partial tool call marker suffix at end of text.""" + tail = text[-20:] + best = 0 + for marker in self._PARTIAL_MARKERS: + for length in range(len(marker), 0, -1): + if tail.endswith(marker[:length]) and length > best: + best = length + break + return best + + def _was_buffering(self, previous_text: str) -> bool: + """Check if the previous call was buffering a partial marker.""" + return self._has_partial_marker(previous_text) + def extract_tool_calls_streaming( self, previous_text: str, @@ -125,14 +211,67 @@ def extract_tool_calls_streaming( """ Extract tool calls from streaming Qwen model output. """ - # Check for tool call markers + # Check for complete tool call markers has_tool_marker = ( - "" in current_text or "[Calling tool:" in current_text + "" in current_text + or "[Calling tool:" in current_text + or "... 
+        if "<function=" in current_text:
+            func_close_count = current_text.count("</function>")
+            prev_func_close = previous_text.count("</function>")
+
+            if current_text.count("<function=") > func_close_count:
+                # Inside an incomplete function block, suppress output
+                return None
+
+            if func_close_count > prev_func_close:
+                # New function block(s) completed
+                result = self.extract_tool_calls(current_text)
+                if result.tools_called:
+                    new_calls = result.tool_calls[prev_func_close:]
+                    if new_calls:
+                        return {
+                            "tool_calls": [
+                                {
+                                    "index": prev_func_close + i,
+                                    "id": tc["id"],
+                                    "type": "function",
+                                    "function": {
+                                        "name": tc["name"],
+                                        "arguments": tc["arguments"],
+                                    },
+                                }
+                                for i, tc in enumerate(new_calls)
+                            ]
+                        }
+
+            return None
+
         # If we're in a tool call, accumulate and parse at the end
         # For simplicity, return None during accumulation
         if "</tool_call>" in delta_text or ")]" in delta_text:

From e65f0e0083e4be72b5becfcd41bead48953bb721 Mon Sep 17 00:00:00 2001
From: Thump604
Date: Sat, 11 Apr 2026 10:17:43 -0500
Subject: [PATCH 28/45] fix: import Path in tokenizer utils

---
 vllm_mlx/utils/tokenizer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py
index 55dec9577..9d200ab9f 100644
--- a/vllm_mlx/utils/tokenizer.py
+++ b/vllm_mlx/utils/tokenizer.py
@@ -9,6 +9,7 @@
 
 import json
 import logging
+from pathlib import Path
 
 from .chat_templates import DEFAULT_CHATML_TEMPLATE, NEMOTRON_CHAT_TEMPLATE

From 78255e8198369d0eb90ab49471937d101239c2fb Mon Sep 17 00:00:00 2001
From: Wayner Barrios
Date: Sat, 11 Apr 2026 11:25:44 -0400
Subject: [PATCH 29/45] add Rapid-MLX to README acknowledgments

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index b77d116dd..7fc61fe46 100644
--- a/README.md
+++ b/README.md
@@ -398,4 +398,5 @@ If you use vLLM-MLX in your research or project, please cite:
 - [mlx-vlm](https://github.com/Blaizzy/mlx-vlm) - Vision-language models
 - [mlx-audio](https://github.com/Blaizzy/mlx-audio) - Text-to-Speech and Speech-to-Text
 - [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings) - Text embeddings
+- [Rapid-MLX](https://github.com/raullenchai/Rapid-MLX) - Community fork of vllm-mlx
 - [vLLM](https://github.com/vllm-project/vllm) - High-throughput LLM serving

From 380b12d5e72bfba3a038759a80fbb3fd0c1433e3 Mon Sep 17 00:00:00 2001
From: Thump604
Date: Sat, 11 Apr 2026 10:59:22 -0500
Subject: [PATCH 30/45] fix: skip optimistic mtp rnn snapshots

---
 vllm_mlx/scheduler.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py
index 32057c19a..32f024917 100644
--- a/vllm_mlx/scheduler.py
+++ b/vllm_mlx/scheduler.py
@@ -800,12 +800,13 @@ def _mtp_step(
             # RNN snapshot, then re-advance with just P so both cache
             # types end up consistent at [..., P].
_rnn_snapshots = {} - for _ci, _c in enumerate(prompt_cache): - if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()): - if hasattr(_c, "state"): - _rnn_snapshots[_ci] = [ - s.copy() if s is not None else None for s in _c.state - ] + if not optimistic: + for _ci, _c in enumerate(prompt_cache): + if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()): + if hasattr(_c, "state"): + _rnn_snapshots[_ci] = [ + s.copy() if s is not None else None for s in _c.state + ] verify_input = mx.concatenate( [primary_tokens[:, None], draft_tokens[:, None]], axis=1 From fd34746561bb3f3a9e62ced0980b28edaf2d46b6 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sat, 11 Apr 2026 11:23:14 -0500 Subject: [PATCH 31/45] style: format MiniMax tool call tests --- tests/test_minimax_tool_calling.py | 32 +++++++++++++++--------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/test_minimax_tool_calling.py b/tests/test_minimax_tool_calling.py index ad329aa03..2b94f967b 100644 --- a/tests/test_minimax_tool_calling.py +++ b/tests/test_minimax_tool_calling.py @@ -10,12 +10,12 @@ class TestMiniMaxToolCallParsing(unittest.TestCase): """Test parsing of MiniMax-style tool calls.""" def test_single_tool_call(self): - text = ''' + text = """ Wanaka celsius -''' +""" cleaned, tool_calls = parse_tool_calls(text) self.assertIsNotNone(tool_calls) @@ -27,12 +27,12 @@ def test_single_tool_call(self): self.assertEqual(cleaned, "") def test_tool_call_with_surrounding_text(self): - text = '''Let me check the weather for you. + text = """Let me check the weather for you. Wanaka -''' +""" cleaned, tool_calls = parse_tool_calls(text) self.assertIsNotNone(tool_calls) @@ -40,7 +40,7 @@ def test_tool_call_with_surrounding_text(self): self.assertIn("Let me check", cleaned) def test_multiple_tool_calls(self): - text = ''' + text = """ MiniMax M2.5 @@ -49,7 +49,7 @@ def test_multiple_tool_calls(self): /tmp/test.txt -''' +""" cleaned, tool_calls = parse_tool_calls(text) self.assertIsNotNone(tool_calls) @@ -58,12 +58,12 @@ def test_multiple_tool_calls(self): self.assertEqual(tool_calls[1].function.name, "read_file") def test_json_parameter_value(self): - text = ''' + text = """ Meeting ["stuart", "frida"] -''' +""" cleaned, tool_calls = parse_tool_calls(text) self.assertIsNotNone(tool_calls) @@ -72,21 +72,21 @@ def test_json_parameter_value(self): self.assertEqual(args["attendees"], ["stuart", "frida"]) def test_numeric_parameter(self): - text = ''' + text = """ 42 -''' +""" cleaned, tool_calls = parse_tool_calls(text) args = json.loads(tool_calls[0].function.arguments) self.assertEqual(args["value"], 42) def test_no_parameters(self): - text = ''' + text = """ -''' +""" cleaned, tool_calls = parse_tool_calls(text) self.assertIsNotNone(tool_calls) @@ -95,14 +95,14 @@ def test_no_parameters(self): self.assertEqual(args, {}) def test_with_think_tags_preserved(self): - text = ''' + text = """ I should check the weather first. 
Wanaka -''' +""" cleaned, tool_calls = parse_tool_calls(text) self.assertIsNotNone(tool_calls) @@ -115,11 +115,11 @@ def test_no_minimax_tool_calls(self): self.assertEqual(cleaned, text) def test_tool_call_id_format(self): - text = ''' + text = """ 1 -''' +""" _, tool_calls = parse_tool_calls(text) self.assertTrue(tool_calls[0].id.startswith("call_")) From 89cfad4e645418c22e8f4ad2944bbc7c934437c3 Mon Sep 17 00:00:00 2001 From: Christopher Albert Date: Tue, 24 Mar 2026 09:15:02 +0100 Subject: [PATCH 32/45] feat: expose harmony tool parser in serve CLI --- tests/test_server.py | 35 +++++++++++++++++++++++++++++++++++ vllm_mlx/cli.py | 14 ++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/tests/test_server.py b/tests/test_server.py index ad8e0a9b9..c0450548d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -168,6 +168,41 @@ def test_basic_completion_request(self): assert request.max_tokens is None # uses _default_max_tokens when None +class TestServeCli: + """Test serve CLI argument parsing.""" + + def test_tool_call_parser_accepts_harmony_aliases(self): + """GPT-OSS/Harmony parsers should be selectable from the serve CLI.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + [ + "serve", + "lmstudio-community/gpt-oss-20b-MLX-8bit", + "--enable-auto-tool-choice", + "--tool-call-parser", + "harmony", + ] + ) + + assert args.command == "serve" + assert args.tool_call_parser == "harmony" + assert args.enable_auto_tool_choice is True + + args = parser.parse_args( + [ + "serve", + "lmstudio-community/gpt-oss-20b-MLX-8bit", + "--enable-auto-tool-choice", + "--tool-call-parser", + "gpt-oss", + ] + ) + + assert args.tool_call_parser == "gpt-oss" + + # ============================================================================= # Helper Function Tests # ============================================================================= diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index bba5163d4..ee6d7cba5 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -633,7 +633,8 @@ def bench_kv_cache_command(args): ) -def main(): +def create_parser() -> argparse.ArgumentParser: + """Build the top-level CLI parser.""" parser = argparse.ArgumentParser( description="vllm-mlx: Apple Silicon MLX backend for vLLM", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -880,6 +881,8 @@ def main(): "qwen3_coder", "llama", "hermes", + "harmony", + "gpt-oss", "deepseek", "kimi", "granite", @@ -893,7 +896,8 @@ def main(): help=( "Select the tool call parser for the model. Options: " "auto (auto-detect), mistral, qwen, qwen3_coder, llama, hermes, " - "deepseek, gemma4, kimi, granite, nemotron, xlam, functionary, glm47, minimax. " + "harmony, gpt-oss, deepseek, gemma4, kimi, granite, nemotron, " + "xlam, functionary, glm47, minimax. " "Required for --enable-auto-tool-choice." 
         ),
     )
@@ -1113,6 +1117,12 @@ def main():
         action="store_true",
         help="Download as multimodal model (broader file patterns)",
     )
+
+    return parser
+
+
+def main():
+    parser = create_parser()
     args = parser.parse_args()
 
     if args.command == "serve":

From f16f854282002485998db31ad039aefe96a26615 Mon Sep 17 00:00:00 2001
From: Christopher Albert
Date: Tue, 24 Mar 2026 19:21:09 +0100
Subject: [PATCH 33/45] simple-engine: unify tool-enabled chat on streaming path (#10)

* fix: unify tool-enabled simple chat on streaming path

* fix: preserve simple chat contracts on streaming path

* fix: keep tool chat on the streaming execution path

* fix: preserve streamed completion token counts
---
 tests/test_simple_engine.py | 57 +++++++++++++++++++++++++++++++++++++
 vllm_mlx/engine/simple.py | 30 +++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py
index 7202f625f..7c0956693 100644
--- a/tests/test_simple_engine.py
+++ b/tests/test_simple_engine.py
@@ -12,6 +12,10 @@ class TestSimpleEngineConcurrency:
     """Test SimpleEngine lock behavior with concurrent requests."""
 
+    @pytest.fixture
+    def anyio_backend(self):
+        return "asyncio"
+
     @pytest.fixture
     def mock_model(self):
         """Create a mock model that tracks concurrent calls."""
@@ -117,6 +121,59 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
             "The lock is not working correctly."
         )
 
+    async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
+        """Tool-enabled non-stream chat should use the streaming path."""
+        from vllm_mlx.engine.simple import SimpleEngine
+
+        async def fake_stream_chat(*args, **kwargs):
+            yield MagicMock(
+                text="partial",
+                tokens=[],
+                prompt_tokens=11,
+                completion_tokens=1,
+                finish_reason=None,
+                finished=False,
+            )
+            yield MagicMock(
+                text="<|im_end|><tool_call>{\"name\":\"bash\",\"arguments\":{\"command\":\"pwd\"}}</tool_call>",
+                tokens=[],
+                prompt_tokens=11,
+                completion_tokens=4,
+                finish_reason="stop",
+                finished=True,
+            )
+
+        with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
+            engine = SimpleEngine("test-model")
+            engine._model = mock_llm_model
+            engine._loaded = True
+            engine._model.tokenizer.encode = MagicMock(return_value=[7, 8, 9])
+            engine.stream_chat = fake_stream_chat  # type: ignore[method-assign]
+
+            output = await engine.chat(
+                messages=[{"role": "user", "content": "run pwd"}],
+                max_tokens=16,
+                tools=[
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": "bash",
+                            "parameters": {"type": "object", "properties": {}},
+                        },
+                    }
+                ],
+            )
+
+            assert output.text.startswith("<tool_call>")
+            assert output.tokens == [7, 8, 9]
+            assert output.prompt_tokens == 11
+            assert output.completion_tokens == 4
+            assert output.finish_reason == "stop"
+            mock_llm_model.chat.assert_not_called()
+            engine._model.tokenizer.encode.assert_called_once_with(
+                output.text, add_special_tokens=False
+            )
+
     @pytest.mark.anyio
     async def test_lock_serializes_stream_generate(self, mock_model):
         """Test that stream_generate uses the same lock as other methods."""
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index 39cfa849d..d376b101e 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -453,6 +453,36 @@ async def chat(
         if not self._loaded:
             await self.start()
 
+        # mlx-lm non-streaming chat with tools can stall indefinitely on some
+        # local models, while the streaming path completes normally. Reuse the
+        # streaming implementation and aggregate its final state so both chat
+        # APIs share the same tool-capable execution path.
+        if tools and not self._is_mllm:
+            final_output = GenerationOutput(text="")
+            async for output in self.stream_chat(
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                tools=tools,
+                images=images,
+                videos=videos,
+                **kwargs,
+            ):
+                final_output = output
+            text = clean_output_text(final_output.text)
+            try:
+                tokens = self._model.tokenizer.encode(text, add_special_tokens=False)
+            except TypeError:
+                tokens = self._model.tokenizer.encode(text)
+            return GenerationOutput(
+                text=text,
+                tokens=tokens,
+                prompt_tokens=final_output.prompt_tokens,
+                completion_tokens=final_output.completion_tokens,
+                finish_reason=final_output.finish_reason,
+            )
+
         # Convert tools for template if provided
         template_tools = convert_tools_for_template(tools) if tools else None

From 51b4f6912089ec7ef6626bae5c8498399b326916 Mon Sep 17 00:00:00 2001
From: Christopher Albert
Date: Tue, 24 Mar 2026 20:04:47 +0100
Subject: [PATCH 34/45] fix: preserve streamed tool-chat token ids

---
 tests/test_simple_engine.py | 6 +-----
 vllm_mlx/engine/simple.py | 2 +-
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py
index 7c0956693..3db3cc92c 100644
--- a/tests/test_simple_engine.py
+++ b/tests/test_simple_engine.py
@@ -147,7 +147,6 @@ async def fake_stream_chat(*args, **kwargs):
             engine = SimpleEngine("test-model")
             engine._model = mock_llm_model
             engine._loaded = True
-            engine._model.tokenizer.encode = MagicMock(return_value=[7, 8, 9])
             engine.stream_chat = fake_stream_chat  # type: ignore[method-assign]
 
             output = await engine.chat(
@@ -165,14 +164,11 @@ async def fake_stream_chat(*args, **kwargs):
             )
 
             assert output.text.startswith("<tool_call>")
-            assert output.tokens == [7, 8, 9]
+            assert output.tokens == []
             assert output.prompt_tokens == 11
             assert output.completion_tokens == 4
             assert output.finish_reason == "stop"
             mock_llm_model.chat.assert_not_called()
-            engine._model.tokenizer.encode.assert_called_once_with(
-                output.text, add_special_tokens=False
-            )
 
     @pytest.mark.anyio
     async def test_lock_serializes_stream_generate(self, mock_model):
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index d376b101e..3dfbf7b09 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -477,7 +477,7 @@ async def chat(
             )
             return GenerationOutput(
                 text=text,
-                tokens=tokens,
+                tokens=list(final_output.tokens),
                 prompt_tokens=final_output.prompt_tokens,
                 completion_tokens=final_output.completion_tokens,
                 finish_reason=final_output.finish_reason,

From 014edebf7dfe2d5dd3bd14da8fd7a11f0a087288 Mon Sep 17 00:00:00 2001
From: Christopher Albert
Date: Thu, 26 Mar 2026 01:06:21 +0100
Subject: [PATCH 35/45] remove dead token-encode block in tool-chat fallback

The try/except block computing `tokens` via tokenizer.encode() was
unused -- the return statement already reads from final_output.tokens.
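The surviving fallback is a plain last-value accumulation over an async
stream. A minimal, self-contained sketch of that pattern (the Chunk type and
function names here are illustrative stand-ins, not the engine's real API):

    import asyncio
    from dataclasses import dataclass, field

    @dataclass
    class Chunk:
        text: str = ""
        tokens: list = field(default_factory=list)
        finish_reason: str | None = None

    async def fake_stream_chat():
        # Stands in for SimpleEngine.stream_chat(); every chunk carries
        # cumulative state, so only the last one matters.
        yield Chunk(text="partial", tokens=[1])
        yield Chunk(text="partial final", tokens=[1, 2, 3], finish_reason="stop")

    async def chat_non_streaming() -> Chunk:
        final = Chunk()
        async for chunk in fake_stream_chat():
            final = chunk  # keep only the final (cumulative) chunk
        return final

    print(asyncio.run(chat_non_streaming()))
    # Chunk(text='partial final', tokens=[1, 2, 3], finish_reason='stop')

Because each streamed chunk already carries the cumulative text, token ids,
and usage counters, keeping only the final chunk preserves them without any
re-encoding step.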
---
 vllm_mlx/engine/simple.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index 3dfbf7b09..b93c20c0a 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -471,10 +471,6 @@ async def chat(
             ):
                 final_output = output
             text = clean_output_text(final_output.text)
-            try:
-                tokens = self._model.tokenizer.encode(text, add_special_tokens=False)
-            except TypeError:
-                tokens = self._model.tokenizer.encode(text)
             return GenerationOutput(
                 text=text,
                 tokens=list(final_output.tokens),

From 040a724118d194d77639067fc8a5dc49701b64a2 Mon Sep 17 00:00:00 2001
From: Christopher Albert
Date: Thu, 9 Apr 2026 09:31:12 +0200
Subject: [PATCH 36/45] style: format simple engine tool-chat test

---
 tests/test_simple_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py
index 3db3cc92c..c507b7a90 100644
--- a/tests/test_simple_engine.py
+++ b/tests/test_simple_engine.py
@@ -135,7 +135,7 @@ async def fake_stream_chat(*args, **kwargs):
             yield MagicMock(
-                text="<|im_end|><tool_call>{\"name\":\"bash\",\"arguments\":{\"command\":\"pwd\"}}</tool_call>",
+                text='<|im_end|><tool_call>{"name":"bash","arguments":{"command":"pwd"}}</tool_call>',

From 2990f7b93c7b711d84699da9757aa8668b872eae Mon Sep 17 00:00:00 2001
From: Thump604
Date: Sat, 11 Apr 2026 11:30:40 -0500
Subject: [PATCH 37/45] test: align tool-chat aggregation regression

---
 tests/test_simple_engine.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py
index c507b7a90..b06b48971 100644
--- a/tests/test_simple_engine.py
+++ b/tests/test_simple_engine.py
@@ -128,7 +128,7 @@ async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
         async def fake_stream_chat(*args, **kwargs):
             yield MagicMock(
                 text="partial",
-                tokens=[],
+                tokens=[1],
                 prompt_tokens=11,
                 completion_tokens=1,
                 finish_reason=None,
@@ -136,7 +136,7 @@ async def fake_stream_chat(*args, **kwargs):
             )
             yield MagicMock(
                 text='<|im_end|><tool_call>{"name":"bash","arguments":{"command":"pwd"}}</tool_call>',
-                tokens=[],
+                tokens=[7, 8, 9],
                 prompt_tokens=11,
                 completion_tokens=4,
                 finish_reason="stop",
@@ -163,8 +163,8 @@ async def fake_stream_chat(*args, **kwargs):
             )
 
-            assert output.text.startswith("<tool_call>")
-            assert output.tokens == []
+            assert output.text == '<tool_call>{"name":"bash","arguments":{"command":"pwd"}}</tool_call>'
+            assert output.tokens == [7, 8, 9]
             assert output.prompt_tokens == 11
             assert output.completion_tokens == 4
             assert output.finish_reason == "stop"
             mock_llm_model.chat.assert_not_called()

From 9d69467b0df54f03685075d398181804160cffbb Mon Sep 17 00:00:00 2001
From: Thump604
Date: Sun, 12 Apr 2026 00:34:45 -0500
Subject: [PATCH 38/45] fix(specprefill): avoid dense tail expansion within cache window (#291)

---
 tests/test_specprefill_rotating_cache.py | 84 ++++++++++++++++++++++++
 vllm_mlx/specprefill.py | 10 +--
 2 files changed, 89 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_specprefill_rotating_cache.py

diff --git a/tests/test_specprefill_rotating_cache.py b/tests/test_specprefill_rotating_cache.py
new file mode 100644
index 000000000..a944c0ee2
--- /dev/null
+++ b/tests/test_specprefill_rotating_cache.py
@@ -0,0 +1,84 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Regression tests for RotatingKVCache handling in sparse_prefill."""
+
+from __future__ import annotations
+
+import pytest
+
+try:
+    import mlx.core as mx
+
+    HAS_MLX = True
+except ImportError:
+    HAS_MLX = False
+
+pytestmark = 
pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +class _FakeAttention: + def __init__(self): + self.num_heads = 1 + self.q_proj = lambda x: x + + +class _FakeLayer: + def __init__(self): + self.block_type = "*" + self.mixer = _FakeAttention() + + +class _FakeModel: + def __init__(self): + self.layers = [_FakeLayer()] + self.calls: list[list[int]] = [] + + def __call__(self, x, cache=None): + self.calls.append(x.tolist()) + logits = mx.zeros((1, x.shape[1], 8), dtype=mx.float32) + return logits + + +class RotatingKVCache: + def __init__(self, max_size: int, keep: int = 0): + self.max_size = max_size + self.keep = keep + self.offset = 0 + self.state = mx.array([0], dtype=mx.float32) + + +def _run_sparse_prefill(total_tokens: int, selected_indices: list[int], max_size: int): + from vllm_mlx.specprefill import sparse_prefill + + model = _FakeModel() + tokens = list(range(total_tokens)) + cache = [RotatingKVCache(max_size=max_size, keep=0)] + sparse_prefill( + model, + tokens, + selected_indices, + cache, + step_size=64, + ) + return model.calls + + +def test_sparse_prefill_does_not_expand_tail_when_prompt_fits_window(): + calls = _run_sparse_prefill( + total_tokens=6, + selected_indices=[0, 2, 4], + max_size=8, + ) + + flattened = [token for chunk in calls for row in chunk for token in row] + assert flattened == [0, 2, 4] + + +def test_sparse_prefill_expands_tail_when_prompt_exceeds_window(): + calls = _run_sparse_prefill( + total_tokens=10, + selected_indices=[0, 2], + max_size=8, + ) + + flattened = [token for chunk in calls for row in chunk for token in row] + assert flattened == [0, 2, 3, 4, 5, 6, 7, 8, 9] diff --git a/vllm_mlx/specprefill.py b/vllm_mlx/specprefill.py index 5ea985ff0..9ebe4401a 100644 --- a/vllm_mlx/specprefill.py +++ b/vllm_mlx/specprefill.py @@ -640,15 +640,15 @@ def sparse_prefill( M = tokens.shape[0] - # Detect RotatingKVCache and ensure tail tokens are included. - # Models with sliding window attention (e.g., GPT-OSS) use RotatingKVCache - # which evicts old entries. We must include the last `max_size` positions - # so sliding window layers have valid recent context for decode. + # Detect RotatingKVCache and ensure tail tokens are included only when the + # prompt actually exceeds the live cache window. If the full prompt still + # fits inside ``max_size`` there is no eviction yet, so forcing the entire + # tail back in would collapse sparse prefill into dense work. max_rotating_size = 0 for c in cache: if type(c).__name__ == "RotatingKVCache": max_rotating_size = max(max_rotating_size, getattr(c, "max_size", 0)) - if max_rotating_size > 0: + if max_rotating_size > 0 and M > max_rotating_size: tail_start = max(0, M - max_rotating_size) tail_indices = set(range(tail_start, M)) existing = set(selected_indices.tolist()) From 4c79879c4ba0cbfa104c212474235c670aa61809 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sun, 12 Apr 2026 01:41:07 -0500 Subject: [PATCH 39/45] feat: full sampling parameter support (top_k, min_p, presence_penalty, repetition_penalty) (#213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pass all OpenAI-compatible sampling parameters through to mlx-lm's make_sampler and make_logits_processors. Previously only temperature, top_p and max_tokens reached the engine — top_k, min_p, presence_penalty and repetition_penalty were silently dropped. 
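Taken together, the new helpers reduce to two calls into mlx-lm's sampling
utilities. A rough sketch with arbitrary example values (the keyword names
mirror the _create_sampler/_create_logits_processors helpers in the diff
below; only the concrete numbers are invented here):

    from mlx_lm.sample_utils import make_logits_processors, make_sampler

    # Sampler covers temperature/top_p/top_k/min_p; 0 and 0.0 leave
    # top_k and min_p disabled.
    sampler = make_sampler(temp=0.7, top_p=0.9, top_k=40, min_p=0.05)

    # Penalties become logits processors; passing None disables each one.
    logits_processors = make_logits_processors(
        repetition_penalty=1.1,
        presence_penalty=0.5,
    )

Both objects are then handed to mlx_lm.generate()/stream_generate() through
the sampler= and logits_processors= keyword arguments.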
Changes: - api/models.py: Add fields to ChatCompletionRequest and CompletionRequest - request.py: Add presence_penalty to SamplingParams dataclass - server.py: Extract and pass all params in every code path (6 locations), log all params on request - models/llm.py: Build sampler with top_k/min_p, build logits_processors for presence_penalty/repetition_penalty - engine/simple.py: Fix enable_thinking to read VLLM_MLX_ENABLE_THINKING env var instead of hardcoding based on model name Tested with all 4 Unsloth Qwen 3.5 sampling profiles on 122B model. --- vllm_mlx/api/models.py | 8 ++++++ vllm_mlx/models/llm.py | 59 +++++++++++++++++++++++++++++++++++++----- vllm_mlx/server.py | 27 ++++++++++++++++++- 3 files changed, 87 insertions(+), 7 deletions(-) diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index e450cd5bc..2e78d772d 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -158,6 +158,10 @@ class ChatCompletionRequest(BaseModel): messages: list[Message] temperature: float | None = None top_p: float | None = None + top_k: int | None = None + min_p: float | None = None + presence_penalty: float | None = None + repetition_penalty: float | None = None max_tokens: int | None = None stream: bool = False stream_options: StreamOptions | None = ( @@ -240,6 +244,10 @@ class CompletionRequest(BaseModel): prompt: str | list[str] temperature: float | None = None top_p: float | None = None + top_k: int | None = None + min_p: float | None = None + presence_penalty: float | None = None + repetition_penalty: float | None = None max_tokens: int | None = None stream: bool = False stop: list[str] | None = None diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 75bbab852..71f0af3ca 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -111,6 +111,8 @@ def _create_sampler( self, temperature: float = 0.7, top_p: float = 0.9, + top_k: int = 0, + min_p: float = 0.0, ): """Create a sampler for text generation.""" from mlx_lm.sample_utils import make_sampler @@ -118,16 +120,38 @@ def _create_sampler( return make_sampler( temp=temperature, top_p=top_p, + top_k=top_k, + min_p=min_p, ) + def _create_logits_processors( + self, + presence_penalty: float = 0.0, + repetition_penalty: float = 1.0, + ): + """Create logits processors for penalty-based sampling.""" + from mlx_lm.sample_utils import make_logits_processors + + processors = make_logits_processors( + repetition_penalty=( + repetition_penalty if repetition_penalty != 1.0 else None + ), + presence_penalty=presence_penalty if presence_penalty != 0.0 else None, + ) + return processors if processors else None + def generate( self, prompt: str, max_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9, + top_k: int = 0, + min_p: float = 0.0, + presence_penalty: float = 0.0, repetition_penalty: float = 1.0, stop: list[str] | None = None, + **kwargs, ) -> GenerationOutput: """ Generate text from a prompt. 
@@ -137,7 +161,10 @@ def generate( max_tokens: Maximum number of tokens to generate temperature: Sampling temperature (0 = greedy) top_p: Top-p (nucleus) sampling parameter - repetition_penalty: Penalty for repeating tokens + top_k: Top-k sampling (0 = disabled) + min_p: Minimum probability threshold + presence_penalty: Additive penalty for token presence + repetition_penalty: Multiplicative penalty for repeating tokens stop: List of stop sequences Returns: @@ -148,8 +175,11 @@ def generate( from mlx_lm import generate - # Create sampler with parameters - sampler = self._create_sampler(temperature, top_p) + # Create sampler and logits processors with full Unsloth params + sampler = self._create_sampler(temperature, top_p, top_k, min_p) + logits_processors = self._create_logits_processors( + presence_penalty, repetition_penalty + ) # Generate text output_text = generate( @@ -158,6 +188,7 @@ def generate( prompt=prompt, max_tokens=max_tokens, sampler=sampler, + logits_processors=logits_processors, verbose=False, ) @@ -179,8 +210,13 @@ def stream_generate( max_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9, + top_k: int = 0, + min_p: float = 0.0, + presence_penalty: float = 0.0, repetition_penalty: float = 1.0, stop: list[str] | None = None, + logits_processors: list | None = None, + **kwargs, ) -> Iterator[StreamingOutput]: """ Stream text generation token by token. @@ -190,7 +226,10 @@ def stream_generate( max_tokens: Maximum number of tokens to generate temperature: Sampling temperature (0 = greedy) top_p: Top-p (nucleus) sampling parameter - repetition_penalty: Penalty for repeating tokens + top_k: Top-k sampling (0 = disabled) + min_p: Minimum probability threshold + presence_penalty: Additive penalty for token presence + repetition_penalty: Multiplicative penalty for repeating tokens stop: List of stop sequences Yields: @@ -201,8 +240,15 @@ def stream_generate( from mlx_lm import stream_generate - # Create sampler with parameters - sampler = self._create_sampler(temperature, top_p) + # Create sampler and logits processors with full Unsloth params + sampler = self._create_sampler(temperature, top_p, top_k, min_p) + penalty_processors = self._create_logits_processors( + presence_penalty, repetition_penalty + ) + # Merge any externally-provided logits_processors with penalty processors + all_processors = None + if penalty_processors or logits_processors: + all_processors = (logits_processors or []) + (penalty_processors or []) # Count prompt tokens once upfront num_prompt_tokens = len(self.tokenizer.encode(prompt)) @@ -220,6 +266,7 @@ def stream_generate( prompt=prompt, max_tokens=max_tokens, sampler=sampler, + logits_processors=all_processors, **mtp_kwargs, ): token_count += 1 diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index e080be8ff..03da23cd9 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1310,6 +1310,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request): logger.info( f"[REQUEST] POST /v1/completions stream={request.stream} " f"max_tokens={request.max_tokens} temp={request.temperature} " + f"top_p={request.top_p} top_k={request.top_k} min_p={request.min_p} " + f"presence_penalty={request.presence_penalty} " + f"repetition_penalty={request.repetition_penalty} " f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}" ) @@ -1342,6 +1345,9 @@ async def create_completion(request: CompletionRequest, raw_request: Request): "max_tokens": request.max_tokens or _default_max_tokens, "temperature": 
_resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, "stop": request.stop, } if comp_rep_penalty is not None: @@ -1447,7 +1453,11 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re logger.info( f"[REQUEST] POST /v1/chat/completions stream={request.stream} " f"model={request.model!r} max_tokens={request.max_tokens} " - f"temp={request.temperature} msgs={n_msgs} roles={msg_roles} " + f"temp={request.temperature} top_p={request.top_p} " + f"top_k={request.top_k} min_p={request.min_p} " + f"presence_penalty={request.presence_penalty} " + f"repetition_penalty={request.repetition_penalty} " + f"msgs={n_msgs} roles={msg_roles} " f"total_chars={total_chars} tools={n_tools} " f"response_format={request.response_format}" ) @@ -1513,6 +1523,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re "max_tokens": request.max_tokens or _default_max_tokens, "temperature": _resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, + "repetition_penalty": request.repetition_penalty or 1.0, } if rep_penalty is not None: chat_kwargs["repetition_penalty"] = rep_penalty @@ -1795,6 +1809,10 @@ async def create_anthropic_message( "max_tokens": openai_request.max_tokens or _default_max_tokens, "temperature": openai_request.temperature, "top_p": openai_request.top_p, + "top_k": openai_request.top_k or 0, + "min_p": openai_request.min_p or 0.0, + "presence_penalty": openai_request.presence_penalty or 0.0, + "repetition_penalty": openai_request.repetition_penalty or 1.0, } if openai_request.tools and openai_request.tool_choice != "none": @@ -2038,6 +2056,10 @@ async def _stream_anthropic_messages( "max_tokens": openai_request.max_tokens or _default_max_tokens, "temperature": openai_request.temperature, "top_p": openai_request.top_p, + "top_k": openai_request.top_k or 0, + "min_p": openai_request.min_p or 0.0, + "presence_penalty": openai_request.presence_penalty or 0.0, + "repetition_penalty": openai_request.repetition_penalty or 1.0, } if openai_request.tools and openai_request.tool_choice != "none": @@ -2267,6 +2289,9 @@ async def stream_completion( "max_tokens": request.max_tokens or _default_max_tokens, "temperature": _resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, "stop": request.stop, } if repetition_penalty is not None: From d2e7f8883c8f9b2f8d057e8009b6d69a5bf53439 Mon Sep 17 00:00:00 2001 From: Wayner Barrios Date: Sun, 12 Apr 2026 11:29:29 -0500 Subject: [PATCH 40/45] compatibility with mlx-lm 0.31.x (#294) * compatibility with mlx-lm 0.31.x BatchGenerator API The backport in f61d34e assumed internal BatchGenerator APIs that were refactored in mlx-lm 0.31.x. This breaks bench and serve for all users on v0.2.7. 
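The core of the compatibility fix is a small shape check on the value
returned by BatchGenerator.next(). A minimal standalone sketch of that check
(the helper name is local to this example, not part of the codebase):

    def generation_responses(step_result):
        """Normalize BatchGenerator.next() output across mlx-lm versions.

        mlx-lm >= 0.31.x returns (prompt_responses, generation_responses);
        older versions return a flat list of generation responses.
        """
        if isinstance(step_result, tuple):
            return step_result[1]
        return step_result

    assert generation_responses((["p0"], ["g0", "g1"])) == ["g0", "g1"]
    assert generation_responses(["g0", "g1"]) == ["g0", "g1"]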
Changes: - Set prompt_progress_callback as instance attribute instead of passing it to BatchGenerator constructor (not a valid parameter) - Guard _install_chunked_prefill with hasattr check and log warning when skipped (relies on removed _process_prompts, active_batch) - Handle next() returning (prompt_responses, generation_responses) tuple instead of flat list - Add hasattr guard for active_batch in periodic cache eval Benchmark (Llama-3.2-1B-Instruct-4bit, mlx-lm 0.31.2): Total time: 2.38s Prompts: 10 Prompts/second: 4.19 Total prompt tokens: 80 Total completion tokens: 960 Total tokens: 1040 Tokens/second: 402.52 Throughput: 436.06 tok/s Closes #293 * bump to 0.2.8 --- pyproject.toml | 2 +- vllm_mlx/scheduler.py | 30 +++++++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ed5b2deff..1191954c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vllm-mlx" -version = "0.2.7" +version = "0.2.8" description = "vLLM-like inference for Apple Silicon - GPU-accelerated Text, Image, Video & Audio on Mac" readme = "README.md" license = {text = "Apache-2.0"} diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 32f024917..6c66878e7 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -1257,15 +1257,25 @@ def _prefill_progress(progress_list): prefill_batch_size=self.config.prefill_batch_size, completion_batch_size=self.config.completion_batch_size, prefill_step_size=self.config.prefill_step_size, - prompt_progress_callback=_prefill_progress, ) + # Set callback as attribute — used by _install_chunked_prefill + # monkey-patch. Not a BatchGenerator constructor parameter. + bg.prompt_progress_callback = _prefill_progress # Install chunked prefill when explicitly configured OR when # memory-aware cache is active (needed for prefix_boundary saves # in agentic multi-turn workloads with hybrid Mamba+Transformer models). chunked_budget = self.config.chunked_prefill_tokens need_chunked = chunked_budget > 0 or self.memory_aware_cache is not None - if need_chunked: + + # The chunked prefill monkey-patch relies on BatchGenerator internals + # (_process_prompts, active_batch, _step, etc.) that were refactored + # in mlx-lm 0.31.x. Skip gracefully when the required API is absent. + chunked_compatible = hasattr(bg, "_process_prompts") and hasattr( + bg, "active_batch" + ) + + if need_chunked and chunked_compatible: if chunked_budget <= 0: # No explicit budget — use a very large value so normal # prompts pass through unchanged. Prefix boundary splits @@ -1288,6 +1298,12 @@ def _prefill_progress(progress_list): uid_to_request_id=self.uid_to_request_id, requests=self.requests, ) + elif need_chunked and not chunked_compatible: + logger.warning( + "Chunked prefill disabled: mlx-lm BatchGenerator lacks required " + "internals (_process_prompts, active_batch). Upgrade mlx-lm or " + "check compatibility." + ) # Install MTP if the model supports it if self.config.enable_mtp: @@ -2335,9 +2351,16 @@ def step(self, max_retries: int = 1) -> SchedulerOutput: # Run generation step if we have running requests if self.batch_generator is not None and self.running: - responses = self.batch_generator.next() + result = self.batch_generator.next() output.has_work = True + # mlx-lm >=0.31.x returns (prompt_responses, generation_responses); + # older versions returned a flat list. 
+ if isinstance(result, tuple): + responses = result[1] # generation_responses only + else: + responses = result + if responses: outputs, finished_ids = self._process_batch_responses(responses) output.outputs = outputs @@ -2404,6 +2427,7 @@ def step(self, max_retries: int = 1) -> SchedulerOutput: # Evaluate batch tokens to collapse lazy concatenation chains if ( self.batch_generator is not None + and hasattr(self.batch_generator, "active_batch") and self.batch_generator.active_batch is not None and hasattr(self.batch_generator.active_batch, "tokens") ): From 7cfae144bd667c292fed10bb05a1561b07b7bb1d Mon Sep 17 00:00:00 2001 From: Kolden Prue <74475667+kol22@users.noreply.github.com> Date: Sun, 12 Apr 2026 15:17:53 -0500 Subject: [PATCH 41/45] feat: add --prefill-step-size CLI flag (#105) * feat: add --prefill-step-size CLI flag Expose prefill_step_size as a CLI argument for both serve and bench commands. Default of 0 means "use engine default" (2048 for LLM, 1024 for MLLM), preserving existing behavior. Vision models routinely exceed 1024 tokens per prompt (images alone contribute 1400+), hitting the MLLM batch generator's safe limit. This flag lets users raise the limit without patching source code. * Clarify MLLM prefill step override behavior * refactor: clarify MLLM prefill CLI flag and validate override --- vllm_mlx/cli.py | 10 ++++++++++ vllm_mlx/engine/batched.py | 13 ++++++++++++- vllm_mlx/scheduler.py | 6 ++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index ee6d7cba5..07dd17fe1 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -172,6 +172,9 @@ def serve_command(args): kv_cache_quantization_bits=args.kv_cache_quantization_bits, kv_cache_quantization_group_size=args.kv_cache_quantization_group_size, kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens, + mllm_prefill_step_size=( + args.mllm_prefill_step_size if args.mllm_prefill_step_size > 0 else None + ), ) print("Mode: Continuous batching (for multiple concurrent users)") @@ -289,6 +292,7 @@ async def run_benchmark(): kv_cache_quantization_group_size=args.kv_cache_quantization_group_size, kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens, ) + engine_config = EngineConfig( model_name=args.model, scheduler_config=scheduler_config, @@ -668,6 +672,12 @@ def create_parser() -> argparse.ArgumentParser: serve_parser.add_argument( "--completion-batch-size", type=int, default=32, help="Completion batch size" ) + serve_parser.add_argument( + "--mllm-prefill-step-size", + type=int, + default=0, + help="Override MLLM prefill-step guard (0=use MLLM default: 1024)", + ) serve_parser.add_argument( "--enable-prefix-cache", action="store_true", diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 34b39c84d..2f29d66ff 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -237,6 +237,15 @@ async def _start_mllm(self) -> None: kv_group_size = getattr( self._scheduler_config, "kv_cache_quantization_group_size", 64 ) + + # Forward MLLM prefill-step override only when explicitly configured. + # This keeps default behavior unchanged for MLLM (1024) unless set. 
+ prefill_step_size = getattr( + self._scheduler_config, "mllm_prefill_step_size", None + ) + mllm_extra = {} + if prefill_step_size is not None: + mllm_extra["prefill_step_size"] = prefill_step_size mllm_config = MLLMSchedulerConfig( max_num_seqs=max_num_seqs, prefill_batch_size=prefill_batch_size, @@ -249,6 +258,7 @@ async def _start_mllm(self) -> None: kv_cache_quantization=kv_quant, kv_cache_quantization_bits=kv_bits, kv_cache_quantization_group_size=kv_group_size, + **mllm_extra, ) # Create and start MLLM scheduler @@ -262,7 +272,8 @@ async def _start_mllm(self) -> None: logger.info( f"MLLM Scheduler started with continuous batching: " f"max_num_seqs={max_num_seqs}, prefill_batch={prefill_batch_size}, " - f"completion_batch={completion_batch_size}" + f"completion_batch={completion_batch_size}, " + f"prefill_step_size={mllm_config.prefill_step_size}" ) def _inject_mtp_mllm(self) -> None: diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 6c66878e7..c706c85b5 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -62,6 +62,8 @@ class SchedulerConfig: prefill_batch_size: int = 8 completion_batch_size: int = 32 prefill_step_size: int = 2048 + # Optional override for MLLM prefill guard (None = use MLLM default). + mllm_prefill_step_size: Optional[int] = None # Prefix cache settings enable_prefix_cache: bool = True @@ -102,6 +104,10 @@ class SchedulerConfig: mtp_num_draft_tokens: int = 1 # Number of draft tokens from MTP head mtp_optimistic: bool = False # Skip acceptance check for max speed + def __post_init__(self) -> None: + if self.mllm_prefill_step_size is not None and self.mllm_prefill_step_size <= 0: + raise ValueError("mllm_prefill_step_size must be > 0 when provided") + @dataclass class SchedulerOutput: From ab68d94c4075e060bb438bda0c87febb037f1125 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sun, 12 Apr 2026 15:23:42 -0500 Subject: [PATCH 42/45] refactor(simple): SimpleEngine.generate() becomes thin accumulator over stream_generate (#266) stream_generate() is the only code path that consumes per-request SpecPrefill overrides (`specprefill`, `specprefill_keep_pct`) and routes through _stream_generate_specprefill() when engaged. The prior direct self._model.generate() path silently dropped those overrides: server.py's create_completion() extracts them from extra_body and forwards to engine.generate(), engine.generate() forwards via **kwargs to _model.generate(), but _model.generate() (mlx_lm.generate) does not consume them. Non-streaming /v1/completions clients that sent `{"extra_body": {"specprefill": true}}` had their overrides silently no-op'd. Fix: make SimpleEngine.generate() a thin accumulator that iterates self.stream_generate() and returns the last GenerationOutput. Matches the pattern PR #222 established for tool-enabled chat(). Non-streaming clients now get: - SpecPrefill engagement when `specprefill=true` is set (top-level or extra_body fallback via whatever helper server.py uses) - Accurate `prompt_tokens` reporting (the old path returned 0 because mlx_lm.generate never populates it) - Chat-template and reasoning-parser behavior consistent with the streaming path - Same thread-safety (stream_generate holds self._generation_lock around the MLX call) Scope: only generate() changes. chat() stays on its current path; extending chat() to the full accumulator pattern is a separate follow-up on top of PR #222. 
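For context, this is the kind of non-streaming request that now works end to
end, sketched with the openai Python client (base URL, model name, and prompt
are placeholders, and reaching /v1/completions over the wire additionally
needs the companion schema change in PR #265):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
    resp = client.completions.create(
        model="placeholder-model",
        prompt="...",  # any long prompt
        max_tokens=32,
        extra_body={"specprefill": True, "specprefill_keep_pct": 0.3},
    )
    # usage.prompt_tokens now reports the real prompt length instead of 0,
    # and the request engages SpecPrefill with a 30% keep percentage.
    print(resp.usage.prompt_tokens)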
Tests: - New test_generate_accumulates_over_stream_generate stubs stream_generate with an async generator, calls generate() with per-request specprefill kwargs, and asserts: * final output fields (text, tokens, prompt_tokens, completion_tokens, finish_reason, finished) match the last yielded chunk * specprefill / specprefill_keep_pct were forwarded through to stream_generate - New test_generate_empty_stream_returns_safe_default covers the empty-stream edge case (returns GenerationOutput(text="", finish_reason="stop") rather than raising) - Existing mock_model fixture extended with stream_generate tracking so test_lock_prevents_concurrent_generate still observes serialization through the new accumulator path Verified live against Qwen3.5-4B SimpleEngine + SpecPrefill on M2 Ultra with a ~6K token prompt and extra_body.specprefill=true forcing SpecPrefill below the 8192 threshold: SpecPrefill: scored 6007 tokens in 5.3s, sparse prefill 1815/6007 (keep=30%) in 1.1s prompt_tokens reporting is now 6007 (was always 0 before). Related: companion PR #265 (CompletionRequest schema + server-side extract_body -> gen_kwargs threading) which opens the wire from /v1/completions to engine.generate(). This PR closes the wire on the engine side. --- tests/test_simple_engine.py | 104 ++++++++++++++++++++++++++++++++++++ vllm_mlx/engine/simple.py | 39 +++++++++----- 2 files changed, 131 insertions(+), 12 deletions(-) diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index b06b48971..2cf4a6daf 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -42,6 +42,27 @@ def generate_side_effect(**kwargs): return result model.generate = MagicMock(side_effect=generate_side_effect) + + # stream_generate tracks concurrency the same way so tests that + # exercise SimpleEngine.generate() (which is now an accumulator + # over stream_generate) see the same serialization behavior. + def stream_generate_side_effect(**kwargs): + model._concurrent_count += 1 + model._max_concurrent = max(model._max_concurrent, model._concurrent_count) + import time + + time.sleep(0.05) + model._concurrent_count -= 1 + chunk = MagicMock() + chunk.text = "test response" + chunk.tokens = [1, 2, 3] + chunk.finished = True + chunk.finish_reason = "stop" + chunk.prompt_tokens = 3 + chunk.completion_tokens = 3 + yield chunk + + model.stream_generate = MagicMock(side_effect=stream_generate_side_effect) return model @pytest.fixture @@ -266,3 +287,86 @@ async def test_requests_complete_in_order(self, mock_model): assert len(results) == 3 for result in results: assert result.text == "test response" + + @pytest.mark.asyncio + async def test_generate_accumulates_over_stream_generate(self): + """generate() should iterate stream_generate() and return the last + yielded GenerationOutput, forwarding per-request kwargs (including + SpecPrefill overrides) through so they reach _stream_generate_specprefill. 
+ """ + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.engine.simple import SimpleEngine + + captured_kwargs = {} + + async def fake_stream_generate(**kwargs): + captured_kwargs.update(kwargs) + # First chunk: mid-generation + yield GenerationOutput( + text="partial", + new_text="partial", + tokens=[1, 2], + prompt_tokens=11, + completion_tokens=2, + finished=False, + finish_reason=None, + ) + # Final chunk: finished + yield GenerationOutput( + text="partial final", + new_text=" final", + tokens=[1, 2, 3], + prompt_tokens=11, + completion_tokens=3, + finished=True, + finish_reason="stop", + ) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + output = await engine.generate( + prompt="say hi", + max_tokens=16, + temperature=0.6, + top_p=0.95, + specprefill=True, + specprefill_keep_pct=0.2, + ) + + # Accumulator returns the last GenerationOutput's fields + assert output.text == "partial final" + assert output.tokens == [1, 2, 3] + assert output.prompt_tokens == 11 + assert output.completion_tokens == 3 + assert output.finish_reason == "stop" + assert output.finished is True + + # Per-request SpecPrefill overrides reach stream_generate + assert captured_kwargs.get("prompt") == "say hi" + assert captured_kwargs.get("max_tokens") == 16 + assert captured_kwargs.get("specprefill") is True + assert captured_kwargs.get("specprefill_keep_pct") == 0.2 + + @pytest.mark.asyncio + async def test_generate_empty_stream_returns_safe_default(self): + """If stream_generate yields nothing, generate() returns an empty + stop-reason GenerationOutput rather than raising. + """ + from vllm_mlx.engine.simple import SimpleEngine + + async def empty_stream_generate(**kwargs): + return + yield # unreachable; makes this a generator + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine.stream_generate = empty_stream_generate # type: ignore[method-assign] + + output = await engine.generate(prompt="anything", max_tokens=5) + + assert output.text == "" + assert output.finish_reason == "stop" diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index b93c20c0a..30e4b71f8 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -256,13 +256,27 @@ async def generate( """ Generate a complete response (non-streaming). + Thin accumulator over stream_generate(). stream_generate() is the + only code path that consumes per-request SpecPrefill overrides + (`specprefill`, `specprefill_keep_pct`) and routes through + _stream_generate_specprefill() when engaged. The prior direct + self._model.generate() path silently dropped those overrides for + non-streaming /v1/completions callers, so extra_body.specprefill + was advertised by the server but had no effect on this route. + + By iterating stream_generate() and returning the last + GenerationOutput, non-streaming clients get the same SpecPrefill + engagement, accurate prompt_tokens reporting, and per-request + override support as streaming clients. 
+ Args: prompt: Input text max_tokens: Maximum tokens to generate temperature: Sampling temperature top_p: Top-p sampling stop: Stop sequences - **kwargs: Additional model-specific parameters + **kwargs: Additional parameters forwarded to stream_generate, + including per-request `specprefill` / `specprefill_keep_pct` Returns: GenerationOutput with complete text @@ -270,27 +284,28 @@ async def generate( if not self._loaded: await self.start() - output = await self._run_blocking_serialized( - self._model.generate, + last_output: GenerationOutput | None = None + async for output in self.stream_generate( prompt=prompt, max_tokens=max_tokens, temperature=temperature, top_p=top_p, stop=stop, **kwargs, - ) + ): + last_output = output - # Clean output text - text = clean_output_text(output.text) + if last_output is None: + return GenerationOutput(text="", finish_reason="stop") + text = clean_output_text(last_output.text) return GenerationOutput( text=text, - tokens=getattr(output, "tokens", []), - prompt_tokens=getattr(output, "prompt_tokens", 0), - completion_tokens=getattr( - output, "completion_tokens", len(getattr(output, "tokens", [])) - ), - finish_reason=output.finish_reason, + tokens=list(last_output.tokens), + prompt_tokens=last_output.prompt_tokens, + completion_tokens=last_output.completion_tokens, + finish_reason=last_output.finish_reason, + finished=True, ) async def stream_generate( From 0cafaabe7b540603adda12b50eba495fb9dfb740 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sun, 12 Apr 2026 15:28:21 -0500 Subject: [PATCH 43/45] feat(api): per-request SpecPrefill overrides on /v1/completions (#265) * feat(api): per-request SpecPrefill overrides on /v1/completions ChatCompletionRequest already accepts per-request `specprefill` and `specprefill_keep_pct` overrides, and /v1/chat/completions threads them into engine.chat(). CompletionRequest does not, so /v1/completions clients cannot opt a single request into (or out of) SpecPrefill, nor tune the keep percentage per request. Changes: - vllm_mlx/api/models.py: add specprefill and specprefill_keep_pct to CompletionRequest, matching the existing ChatCompletionRequest fields. - vllm_mlx/server.py::create_completion: extract both and thread into engine.generate(**gen_kwargs), mirroring the pattern used at server.py:1421 in create_chat_completion. - vllm_mlx/server.py::stream_completion: apply the same extraction so streaming /v1/completions clients get the same control. Both new fields default to None, so existing behavior is unchanged for clients that do not set them. No schema changes to ChatCompletionRequest. No engine-side changes needed: SimpleEngine.stream_generate already consumes these kwargs (see simple.py:307-308). 
* style(server): align completions kwargs handling --- vllm_mlx/api/models.py | 4 ++++ vllm_mlx/server.py | 24 ++++++++++++++++++------ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index 2e78d772d..eccde0ce0 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -255,6 +255,10 @@ class CompletionRequest(BaseModel): repetition_penalty: float | None = None # mlx-lm style (>1.0 penalizes) # Request timeout in seconds (None = use server default) timeout: float | None = None + # SpecPrefill: per-request enable/disable (None = server decides) + specprefill: bool | None = None + # SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default) + specprefill_keep_pct: float | None = None class CompletionChoice(BaseModel): diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 03da23cd9..89b75f062 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1341,7 +1341,8 @@ async def create_completion(request: CompletionRequest, raw_request: Request): total_prompt_tokens = 0 for i, prompt in enumerate(prompts): - gen_kwargs = { + generate_kwargs = { + "prompt": prompt, "max_tokens": request.max_tokens or _default_max_tokens, "temperature": _resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), @@ -1351,9 +1352,14 @@ async def create_completion(request: CompletionRequest, raw_request: Request): "stop": request.stop, } if comp_rep_penalty is not None: - gen_kwargs["repetition_penalty"] = comp_rep_penalty + generate_kwargs["repetition_penalty"] = comp_rep_penalty + if request.specprefill is not None: + generate_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + output = await _wait_with_disconnect( - engine.generate(prompt=prompt, **gen_kwargs), + engine.generate(**generate_kwargs), raw_request, timeout=timeout, ) @@ -2285,7 +2291,8 @@ async def stream_completion( repetition_penalty: float | None = None, ) -> AsyncIterator[str]: """Stream completion response.""" - gen_kwargs = { + generate_kwargs = { + "prompt": prompt, "max_tokens": request.max_tokens or _default_max_tokens, "temperature": _resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), @@ -2295,8 +2302,13 @@ async def stream_completion( "stop": request.stop, } if repetition_penalty is not None: - gen_kwargs["repetition_penalty"] = repetition_penalty - async for output in engine.stream_generate(prompt=prompt, **gen_kwargs): + generate_kwargs["repetition_penalty"] = repetition_penalty + if request.specprefill is not None: + generate_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + + async for output in engine.stream_generate(**generate_kwargs): data = { "id": f"cmpl-{uuid.uuid4().hex[:8]}", "object": "text_completion", From 88b60b6d5ea3145964ed8faca8a2701ac14d6265 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sun, 12 Apr 2026 15:30:28 -0500 Subject: [PATCH 44/45] test: use anyio in async regression slice (#288) --- tests/test_batching.py | 2 +- tests/test_batching_deterministic.py | 22 +++++++++++----------- tests/test_continuous_batching.py | 2 +- tests/test_server.py | 12 ++++-------- tests/test_streaming_latency.py | 2 +- 5 files changed, 18 insertions(+), 22 deletions(-) diff --git a/tests/test_batching.py b/tests/test_batching.py index fc2eefcde..6cb536aa5 
100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -790,7 +790,7 @@ def test_multiple_concurrent_requests(self, model_and_tokenizer): assert len(finished) == len(prompts), f"Only {len(finished)} requests finished" -@pytest.mark.asyncio +@pytest.mark.anyio class TestEngineAsync: """Async tests for the engine.""" diff --git a/tests/test_batching_deterministic.py b/tests/test_batching_deterministic.py index 52b0fd49b..0e6072ce9 100644 --- a/tests/test_batching_deterministic.py +++ b/tests/test_batching_deterministic.py @@ -37,7 +37,7 @@ def sampling_params(): class TestDeterministicSingleRequest: """Test single request determinism.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_same_prompt_same_output(self, model_and_tokenizer, sampling_params): """Same prompt should produce same output with temp=0.""" from vllm_mlx import AsyncEngineCore, EngineConfig, SchedulerConfig @@ -68,7 +68,7 @@ async def test_same_prompt_same_output(self, model_and_tokenizer, sampling_param assert len(outputs) == 3 assert outputs[0] == outputs[1] == outputs[2], f"Outputs differ: {outputs}" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_token_streaming_order(self, model_and_tokenizer, sampling_params): """Tokens should stream in order.""" from vllm_mlx import AsyncEngineCore @@ -94,7 +94,7 @@ async def test_token_streaming_order(self, model_and_tokenizer, sampling_params) class TestDeterministicConcurrentRequests: """Test concurrent request handling with determinism.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_concurrent_same_prompt(self, model_and_tokenizer): """Multiple concurrent requests with same prompt should get same output.""" from vllm_mlx import ( @@ -137,7 +137,7 @@ async def get_output(rid): # All should be the same assert all(r == results[0] for r in results), f"Outputs differ: {results}" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_concurrent_different_prompts(self, model_and_tokenizer): """Different prompts should get different (but deterministic) outputs.""" from vllm_mlx import ( @@ -191,7 +191,7 @@ async def get_output(rid): class TestBatchingPerformance: """Test that batching improves throughput.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_batched_faster_than_sequential(self, model_and_tokenizer): """Batched requests should be faster than sequential.""" from vllm_mlx import ( @@ -274,7 +274,7 @@ async def get_output(rid): class TestRequestManagement: """Test request lifecycle management.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_abort_request(self, model_and_tokenizer): """Test aborting a request mid-generation.""" from vllm_mlx import AsyncEngineCore, SamplingParams @@ -304,7 +304,7 @@ async def test_abort_request(self, model_and_tokenizer): stats = engine.get_stats() assert stats["active_requests"] == 0 - @pytest.mark.asyncio + @pytest.mark.anyio async def test_engine_stats(self, model_and_tokenizer): """Test engine statistics tracking.""" from vllm_mlx import ( @@ -343,7 +343,7 @@ async def test_engine_stats(self, model_and_tokenizer): class TestSchedulerPolicy: """Test scheduler policies.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_fcfs_ordering(self, model_and_tokenizer): """Test that FCFS policy processes requests in order.""" from vllm_mlx import ( @@ -396,7 +396,7 @@ async def track_completion(rid, name): class TestEdgeCases: """Test edge cases and error handling.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_empty_prompt(self, 
diff --git a/tests/test_continuous_batching.py b/tests/test_continuous_batching.py
index fd10fe808..0e196a226 100644
--- a/tests/test_continuous_batching.py
+++ b/tests/test_continuous_batching.py
@@ -53,7 +53,7 @@ def test_scheduler_config_batching_params(self):
         assert config.completion_batch_size == 32


-@pytest.mark.asyncio
+@pytest.mark.anyio
 class TestContinuousBatchingIntegration:
     """Integration tests requiring actual model loading."""

diff --git a/tests/test_server.py b/tests/test_server.py
index c0450548d..c20957211 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -629,9 +629,7 @@ def test_verify_api_key_rejects_invalid(self):

         # Should raise HTTPException with 401
         with pytest.raises(HTTPException) as exc_info:
-            asyncio.get_event_loop().run_until_complete(
-                server.verify_api_key(credentials)
-            )
+            asyncio.run(server.verify_api_key(credentials))

         assert exc_info.value.status_code == 401
         assert "Invalid API key" in str(exc_info.value.detail)
@@ -657,9 +655,7 @@ def test_verify_api_key_accepts_valid(self):
         )

         # Should not raise any exception
-        result = asyncio.get_event_loop().run_until_complete(
-            server.verify_api_key(credentials)
-        )
+        result = asyncio.run(server.verify_api_key(credentials))
         # verify_api_key returns True on success (no exception raised)
         assert result is True or result is None
     finally:
@@ -716,7 +712,7 @@ def test_rate_limiter_window_cleanup(self):
 class TestStreamChatCompletion:
     """Tests for streaming chat completion behavior."""

-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch):
         """Tool markup after </think> should emit tool_calls chunks."""
         from vllm_mlx.engine.base import GenerationOutput
@@ -837,7 +833,7 @@ def extract_tool_calls_streaming(
             "total_tokens": 10,
         }

-    @pytest.mark.asyncio
+    @pytest.mark.anyio
     async def test_reasoning_stream_skips_tool_parser_until_markup_appears(
         self, monkeypatch
     ):
diff --git a/tests/test_streaming_latency.py b/tests/test_streaming_latency.py
index cae95f5fb..116ee9dfa 100644
--- a/tests/test_streaming_latency.py
+++ b/tests/test_streaming_latency.py
@@ -206,7 +206,7 @@ async def run_benchmark(
     print(f"Throughput: {throughput:.1f} tokens/sec")


-@pytest.mark.asyncio
+@pytest.mark.anyio
 async def test_output_collector():
     """Unit test for RequestOutputCollector."""
     import sys
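A note on the mechanics of this migration: pytest's anyio plugin drives @pytest.mark.anyio tests through an anyio_backend fixture, which defaults to asyncio. A minimal conftest.py sketch (the file itself is an assumption; the patch does not touch conftest.py):

import pytest


@pytest.fixture
def anyio_backend():
    # Pin the marked tests to asyncio; returning ("trio", {}) instead would
    # run them on trio, and a parametrized fixture would run both backends.
    return "asyncio"

The asyncio.run() changes in test_server.py point the same way: asyncio.get_event_loop() is deprecated for implicit loop creation in recent Python releases, while asyncio.run() creates and closes a fresh event loop per call, which is exactly what these synchronous test bodies need.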
From 2cce3daceb6b3763ed7672e77aab6c877dac198c Mon Sep 17 00:00:00 2001
From: Thump604
Date: Mon, 13 Apr 2026 19:44:05 -0500
Subject: [PATCH 45/45] test: add tokenizer fallback regression coverage (#287)

---
 tests/test_tokenizer_utils.py | 54 +++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 tests/test_tokenizer_utils.py

diff --git a/tests/test_tokenizer_utils.py b/tests/test_tokenizer_utils.py
new file mode 100644
index 000000000..0b046e0c7
--- /dev/null
+++ b/tests/test_tokenizer_utils.py
@@ -0,0 +1,54 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for tokenizer utility helpers."""
+
+import types
+from unittest.mock import patch
+
+
+def test_load_model_with_fallback_returns_successful_load_result():
+    from vllm_mlx.utils.tokenizer import load_model_with_fallback
+
+    fake_model = object()
+    fake_tokenizer = object()
+    fake_mlx_lm = types.SimpleNamespace(
+        load=lambda *args, **kwargs: (fake_model, fake_tokenizer)
+    )
+
+    with (
+        patch("vllm_mlx.utils.tokenizer._needs_tokenizer_fallback", return_value=False),
+        patch("vllm_mlx.utils.tokenizer._needs_strict_false", return_value=False),
+        patch("vllm_mlx.utils.tokenizer._try_inject_mtp_post_load"),
+        patch.dict("sys.modules", {"mlx_lm": fake_mlx_lm}),
+    ):
+        model, tokenizer = load_model_with_fallback("mlx-community/Qwen3.5-4B")
+
+    assert model is fake_model
+    assert tokenizer is fake_tokenizer
+
+
+def test_load_model_with_fallback_uses_tokenizer_fallback_for_tokenizer_errors():
+    from vllm_mlx.utils.tokenizer import load_model_with_fallback
+
+    fake_model = object()
+    fake_tokenizer = object()
+
+    def _raise(*args, **kwargs):
+        raise ValueError("Tokenizer class Foo does not exist")
+
+    fake_mlx_lm = types.SimpleNamespace(load=_raise)
+
+    with (
+        patch("vllm_mlx.utils.tokenizer._needs_tokenizer_fallback", return_value=False),
+        patch("vllm_mlx.utils.tokenizer._needs_strict_false", return_value=False),
+        patch("vllm_mlx.utils.tokenizer._try_inject_mtp_post_load"),
+        patch(
+            "vllm_mlx.utils.tokenizer._load_with_tokenizer_fallback",
+            return_value=(fake_model, fake_tokenizer),
+        ) as fallback,
+        patch.dict("sys.modules", {"mlx_lm": fake_mlx_lm}),
+    ):
+        model, tokenizer = load_model_with_fallback("example/model")
+
+    fallback.assert_called_once_with("example/model")
+    assert model is fake_model
+    assert tokenizer is fake_tokenizer
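Together the two tests pin down the loader's contract: a clean mlx_lm.load() result passes through unchanged, and a tokenizer-class error is rerouted to the fallback loader. A rough sketch of that contract (not the real body of load_model_with_fallback; the fallback helper is stubbed here because the tests patch it rather than exercise it):

def _load_with_tokenizer_fallback(model_name):
    # Stub standing in for the helper the second test patches.
    raise NotImplementedError


def load_model_with_fallback_sketch(model_name):
    # mlx_lm is imported lazily, which is why the tests can substitute a
    # fake module via patch.dict("sys.modules", {"mlx_lm": fake_mlx_lm}).
    import mlx_lm

    try:
        return mlx_lm.load(model_name)
    except ValueError as exc:
        # "Tokenizer class Foo does not exist" is the failure the second test
        # simulates; tokenizer-class errors go to the fallback loader, and
        # anything else propagates to the caller.
        if "Tokenizer class" in str(exc):
            return _load_with_tokenizer_fallback(model_name)
        raise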