diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..d9a10c0d8 --- /dev/null +++ b/LICENSE @@ -0,0 +1,176 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+   END OF TERMS AND CONDITIONS
diff --git a/README.md b/README.md
index b77d116dd..7fc61fe46 100644
--- a/README.md
+++ b/README.md
@@ -398,4 +398,5 @@ If you use vLLM-MLX in your research or project, please cite:
 - [mlx-vlm](https://github.com/Blaizzy/mlx-vlm) - Vision-language models
 - [mlx-audio](https://github.com/Blaizzy/mlx-audio) - Text-to-Speech and Speech-to-Text
 - [mlx-embeddings](https://github.com/Blaizzy/mlx-embeddings) - Text embeddings
+- [Rapid-MLX](https://github.com/raullenchai/Rapid-MLX) - Community fork of vllm-mlx
 - [vLLM](https://github.com/vllm-project/vllm) - High-throughput LLM serving
diff --git a/benchmarks/bench_reasoning_parser.py b/benchmarks/bench_reasoning_parser.py
new file mode 100644
index 000000000..c7a2ba13f
--- /dev/null
+++ b/benchmarks/bench_reasoning_parser.py
@@ -0,0 +1,55 @@
+"""Benchmark: reasoning parser streaming performance.
+
+Measures per-token overhead of extract_reasoning_streaming() at various
+output lengths. Demonstrates the difference between O(N²) accumulated
+text scanning and O(1) state-machine tracking.
+
+Usage:
+    python benchmarks/bench_reasoning_parser.py
+"""
+
+import time
+
+from vllm_mlx.reasoning.qwen3_parser import Qwen3ReasoningParser
+
+
+def bench_streaming(parser, n_tokens: int, label: str) -> float:
+    """Simulate n_tokens of streaming through the parser. Returns total ms."""
+    parser.reset_state()
+
+    # Simulate: <think> + N reasoning tokens + </think> + 10 content tokens
+    tokens = ["<think>"]
+    tokens += [f"word{i} " for i in range(n_tokens)]
+    tokens += ["</think>"]
+    tokens += [f"answer{i} " for i in range(10)]
+
+    accumulated = ""
+    start = time.perf_counter()
+    for tok in tokens:
+        prev = accumulated
+        accumulated += tok
+        parser.extract_reasoning_streaming(prev, accumulated, tok)
+    elapsed = (time.perf_counter() - start) * 1000
+
+    print(f" {label}: {n_tokens:>6} tokens -> {elapsed:>8.2f}ms "
+          f"({elapsed / (n_tokens + 11):.3f}ms/tok)")
+    return elapsed
+
+
+def main():
+    parser = Qwen3ReasoningParser()
+
+    print("Reasoning parser streaming benchmark")
+    print("=" * 60)
+    print()
+
+    for n in [50, 100, 200, 500, 1000, 2000, 5000]:
+        bench_streaming(parser, n, f"{n} tokens")
+
+    print()
+    print("At 50 tok/s, per-token budget is 20ms.")
+    print("Parser overhead should be <0.1ms/tok to be negligible.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/reference/models.md b/docs/reference/models.md
index a45550e4d..d378de003 100644
--- a/docs/reference/models.md
+++ b/docs/reference/models.md
@@ -12,7 +12,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun
 | Mistral / Devstral | 7B, Mixtral 8x7B | 4-bit, 8-bit |
 | Qwen2/Qwen3 | 0.5B to 72B | Various |
 | DeepSeek V3, R1 | 7B, 33B, 67B | 4-bit |
-| Gemma 2, 3 | 2B, 9B, 27B | 4-bit |
+| Gemma 2, 3, 4 | 2B, 9B, 27B | 4-bit |
 | GLM-4.7 | Flash, Base | 4-bit, 8-bit |
 | Kimi K2 | Various | 4-bit |
 | Phi-3 | 3.8B, 14B | 4-bit |
@@ -35,6 +35,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun
 | **Qwen-VL** | `Qwen3-VL-4B-Instruct-3bit`, `Qwen3-VL-8B-Instruct-4bit`, `Qwen2-VL-2B/7B-Instruct-4bit` |
 | **LLaVA** | `llava-1.5-7b-4bit`, `llava-v1.6-mistral-7b-4bit`, `llava-llama-3-8b-v1_1-4bit` |
 | **Idefics** | `Idefics3-8B-Llama3-4bit`, `idefics2-8b-4bit` |
+| **Gemma 4** | `gemma-4-e2b-it-mxfp4` (vision + audio) |
 | **PaliGemma** | `paligemma2-3b-mix-224-4bit`, `paligemma-3b-mix-224-8bit` |
 | **Pixtral** | `pixtral-12b-4bit`, `pixtral-12b-8bit` |
 | **Molmo** | `Molmo-7B-D-0924-4bit`, `Molmo-7B-D-0924-8bit` |
@@ -72,7 +73,7 @@
vllm-mlx auto-detects multimodal models by name patterns: - Contains "VL", "Vision", "vision" - Contains "llava", "idefics", "paligemma" - Contains "pixtral", "molmo", "deepseek-vl" -- Contains "MedGemma", "Gemma-3" (vision variants) +- Contains "MedGemma", "Gemma-3", "Gemma-4" (multimodal variants) ## Using Models diff --git a/pyproject.toml b/pyproject.toml index 6ccc45282..1191954c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "vllm-mlx" -version = "0.2.7" +version = "0.2.8" description = "vLLM-like inference for Apple Silicon - GPU-accelerated Text, Image, Video & Audio on Mac" readme = "README.md" license = {text = "Apache-2.0"} @@ -30,7 +30,7 @@ classifiers = [ dependencies = [ "mlx>=0.29.0", "mlx-lm>=0.31.0", # 0.31+ required for ArraysCache native batching (hybrid models) - "mlx-vlm>=0.1.0", # VLM support + "mlx-vlm>=0.4.3", # 0.4.3+ required for Gemma 4 support "transformers>=5.0.0", # mlx-lm 0.30.5+ requires transformers 5.0 (rc3 bug fixed in stable) "tokenizers>=0.19.0", "huggingface-hub>=0.23.0", @@ -44,7 +44,7 @@ dependencies = [ # Video processing for VLM "opencv-python>=4.8.0", # Vision processor (required for transformers AutoProcessor) - "torchvision>=0.18.0", + "torchvision>=0.21.0", # Resource monitoring "psutil>=5.9.0", # Server @@ -75,7 +75,7 @@ vllm = [ ] vision = [ "torch>=2.3.0", - "torchvision>=0.18.0", + "torchvision>=0.21.0", ] # Audio dependencies for TTS/STT (mlx-audio) audio = [ diff --git a/scripts/add_mtp_weights_qwen35.py b/scripts/add_mtp_weights_qwen35.py new file mode 100644 index 000000000..1044dc894 --- /dev/null +++ b/scripts/add_mtp_weights_qwen35.py @@ -0,0 +1,470 @@ +#!/usr/bin/env python3 +""" +Add MTP (Multi-Token Prediction) weights to an existing MLX Qwen3.5 model. + +This script: +1. Fetches the safetensors index from the original BF16 HuggingFace model +2. Identifies shards containing MTP weights (mtp.* keys) +3. Downloads only those shards via curl -C - +4. Extracts MTP weights +5. For MoE models: stacks expert weights (256×) into switch_mlp format +6. Applies norm shift (HF weight → MLX weight+1.0) for RMSNorm keys +7. Quantizes to match the MLX model's quantization scheme +8. 
Saves as mtp/weights.safetensors (subdirectory avoids mlx_vlm glob) + +Supports both: +- MoE models (Qwen3.5-122B-A10B, 35B-A3B): 256 experts, sparse MTP attention +- Dense models (Qwen3.5-27B): full MTP with k/v projections and norms + +Usage: + python add_mtp_weights_qwen35.py --mlx-model-path PATH --source-model MODEL + +Requirements: + pip install mlx +""" + +import argparse +import json +import subprocess +import sys +import tempfile +from pathlib import Path + +# Known model configurations +MODEL_CONFIGS = { + "Qwen/Qwen3.5-122B-A10B": { + "num_experts": 256, + "hidden_size": 3072, + "is_moe": True, + }, + "Qwen/Qwen3.5-35B-A3B": { + "num_experts": 256, + "hidden_size": 2048, + "is_moe": True, + }, + "Qwen/Qwen3.5-27B": { + "num_experts": 0, + "hidden_size": 5120, + "is_moe": False, + }, +} + + +def find_snapshot_dir(model_path: str) -> Path: + """Find the latest snapshot directory in HF cache structure.""" + snapshots_dir = Path(model_path) / "snapshots" + if not snapshots_dir.exists(): + if (Path(model_path) / "config.json").exists(): + return Path(model_path) + raise FileNotFoundError(f"No snapshots found in {model_path}") + snapshots = sorted(snapshots_dir.iterdir(), key=lambda p: p.stat().st_mtime) + if not snapshots: + raise FileNotFoundError(f"No snapshots in {snapshots_dir}") + return snapshots[-1] + + +def fetch_shard_index(source_model: str, download_dir: Path) -> dict: + """Fetch model.safetensors.index.json from HuggingFace.""" + index_url = f"https://huggingface.co/{source_model}/resolve/main/model.safetensors.index.json" + index_path = download_dir / "source_index.json" + + print(f"Fetching shard index from {source_model}...") + result = subprocess.run( + ["curl", "-L", "-C", "-", "-o", str(index_path), index_url], + check=False, + ) + if result.returncode != 0: + raise RuntimeError(f"Failed to fetch index: return code {result.returncode}") + + with open(index_path) as f: + return json.load(f) + + +def identify_mtp_shards(index: dict) -> tuple[dict[str, str], set[str]]: + """Identify which shards contain MTP weights. 
+ + Returns: + Tuple of (mtp_key_to_shard mapping, set of shard filenames to download) + """ + weight_map = index.get("weight_map", {}) + mtp_keys = {} + shards_needed = set() + + for key, shard in weight_map.items(): + if key.startswith("mtp."): + mtp_keys[key] = shard + shards_needed.add(shard) + + return mtp_keys, shards_needed + + +def download_shards( + shards: set[str], source_model: str, download_dir: Path +) -> dict[str, Path]: + """Download required shards using curl with resume support.""" + shard_paths = {} + for shard_name in sorted(shards): + shard_url = f"https://huggingface.co/{source_model}/resolve/main/{shard_name}" + shard_path = download_dir / shard_name + + if shard_path.exists(): + size_gb = shard_path.stat().st_size / 1e9 + print(f" {shard_name}: exists ({size_gb:.2f} GB)") + shard_paths[shard_name] = shard_path + continue + + print(f" Downloading {shard_name}...") + result = subprocess.run( + ["curl", "-L", "-C", "-", "-o", str(shard_path), shard_url], + check=False, + ) + if result.returncode != 0: + raise RuntimeError( + f"Download failed for {shard_name}: code {result.returncode}" + ) + + size_gb = shard_path.stat().st_size / 1e9 + print(f" {shard_name}: {size_gb:.2f} GB") + shard_paths[shard_name] = shard_path + + return shard_paths + + +def extract_and_quantize_mtp_weights( + mtp_keys: dict[str, str], + shard_paths: dict[str, Path], + snapshot_dir: Path, + is_moe: bool, + num_experts: int, + no_quantize: bool = False, +): + """Extract MTP weights from BF16 shards, optionally quantize, and save.""" + import mlx.core as mx + + mx.set_default_device(mx.cpu) + + # Read MLX model's quantization config + config_path = snapshot_dir / "config.json" + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + quant_config = text_config.get("quantization", config.get("quantization", {})) + bits = quant_config.get("bits", 4) + group_size = quant_config.get("group_size", 64) + if no_quantize: + print("MTP weights will be saved in BF16 (no quantization)") + else: + print(f"Target quantization: {bits}-bit, group_size={group_size}") + + # Group MTP keys by shard for efficient I/O + shard_to_keys: dict[str, list[str]] = {} + for key, shard in mtp_keys.items(): + shard_to_keys.setdefault(shard, []).append(key) + + # Load all MTP weights + print(f"\nExtracting MTP weights from {len(shard_paths)} shards...") + all_mtp_weights: dict[str, mx.array] = {} + + for shard_name, keys in sorted(shard_to_keys.items()): + shard_path = shard_paths[shard_name] + print(f" Loading {shard_name} ({len(keys)} MTP keys)...") + shard_data = mx.load(str(shard_path)) + for key in keys: + if key in shard_data: + all_mtp_weights[key] = shard_data[key] + del shard_data + + print(f"Loaded {len(all_mtp_weights)} MTP weight tensors") + + # Norm keys that need +1.0 shift (HF centered ~0 → MLX centered ~1) + norm_suffixes = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + ".q_norm.weight", + ".k_norm.weight", + ".pre_fc_norm_hidden.weight", + ".pre_fc_norm_embedding.weight", + "mtp.norm.weight", + ) + + # Keys to keep as FP (not quantize) + skip_quantize_suffixes = ( + ".input_layernorm.weight", + ".post_attention_layernorm.weight", + ".q_norm.weight", + ".k_norm.weight", + "mtp.fc.weight", + "mtp.norm.weight", + "mtp.pre_fc_norm_hidden.weight", + "mtp.pre_fc_norm_embedding.weight", + ".shared_expert_gate.weight", + ) + + def _quantize_one(key: str, weight: mx.array) -> dict[str, mx.array]: + """Quantize a single weight, apply norm 
adjustment.""" + # Norm shift: +1.0 for RMSNorm weights + if any(key.endswith(s) for s in norm_suffixes) and weight.ndim == 1: + weight = weight + 1.0 + mx.eval(weight) + print(f" Norm shift: {key}") + + if no_quantize: + print(f" BF16: {key} {weight.shape}") + return {key: weight} + elif any(key.endswith(s) for s in skip_quantize_suffixes): + print(f" Keep FP: {key} {weight.shape}") + return {key: weight} + elif weight.ndim >= 2 and weight.shape[-1] >= group_size: + q_w, q_s, q_b = mx.quantize(weight, group_size=group_size, bits=bits) + mx.eval(q_w, q_s, q_b) + print(f" Quantize {bits}-bit: {key} {q_w.shape}") + return { + key: q_w, + key.replace(".weight", ".scales"): q_s, + key.replace(".weight", ".biases"): q_b, + } + else: + print(f" Keep FP (small): {key} {weight.shape}") + return {key: weight} + + quantized_weights: dict[str, mx.array] = {} + + if is_moe and num_experts > 0: + # Stack expert weights ONE PROJECTION AT A TIME to minimize peak memory + for proj in ["up_proj", "down_proj", "gate_proj"]: + expert_keys = [ + f"mtp.layers.0.mlp.experts.{e}.{proj}.weight" + for e in range(num_experts) + ] + if all(k in all_mtp_weights for k in expert_keys): + stacked = mx.stack([all_mtp_weights.pop(k) for k in expert_keys]) + mx.eval(stacked) + stacked_key = f"mtp.layers.0.mlp.switch_mlp.{proj}.weight" + print(f" Stacked {num_experts} experts for {proj}: {stacked.shape}") + quantized_weights.update(_quantize_one(stacked_key, stacked)) + del stacked + else: + present = sum(1 for k in expert_keys if k in all_mtp_weights) + if present > 0: + print(f" WARNING: Only {present}/{num_experts} experts for {proj}") + + # Quantize remaining non-expert weights + for key in sorted(all_mtp_weights.keys()): + weight = all_mtp_weights.pop(key) + quantized_weights.update(_quantize_one(key, weight)) + del weight + del all_mtp_weights + + # Save to mtp/ subdirectory (avoids mlx_vlm glob loading all *.safetensors) + mtp_output_dir = snapshot_dir / "mtp" + mtp_output_dir.mkdir(exist_ok=True) + mtp_output_file = mtp_output_dir / "weights.safetensors" + mode_str = "BF16" if no_quantize else "quantized" + print( + f"\nSaving {len(quantized_weights)} {mode_str} MTP weights to {mtp_output_file}" + ) + mx.save_safetensors(str(mtp_output_file), quantized_weights) + + total_bytes = sum(v.nbytes for v in quantized_weights.values()) + print(f"MTP weights size: {total_bytes / 1e6:.1f} MB ({mode_str})") + + return mtp_output_file, list(quantized_weights.keys()) + + +def update_model_index(snapshot_dir: Path, mtp_keys: list[str]): + """Update model.safetensors.index.json to include MTP weight keys.""" + index_path = snapshot_dir / "model.safetensors.index.json" + if not index_path.exists(): + print(f"WARNING: No index file found at {index_path}, skipping index update") + return + + with open(index_path) as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + for key in mtp_keys: + weight_map[key] = "model-mtp.safetensors" + + index["weight_map"] = weight_map + + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + + print(f"Updated {index_path} with {len(mtp_keys)} MTP weight entries") + + +def update_config(snapshot_dir: Path): + """Update config.json to signal MTP availability. + + For Qwen3.5, mtp_num_hidden_layers already exists in text_config. + We add num_nextn_predict_layers at top level for vllm-mlx compatibility. 
+ """ + config_path = snapshot_dir / "config.json" + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + + if num_mtp > 0: + # Set num_nextn_predict_layers for vllm-mlx MTP detection + config["num_nextn_predict_layers"] = num_mtp + text_config["num_nextn_predict_layers"] = num_mtp + if "text_config" in config: + config["text_config"] = text_config + + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"Updated config: num_nextn_predict_layers={num_mtp}") + else: + print("WARNING: mtp_num_hidden_layers not found in config") + + +def main(): + parser = argparse.ArgumentParser(description="Add MTP weights to MLX Qwen3.5 model") + parser.add_argument( + "--mlx-model-path", + type=str, + required=True, + help="Path to MLX model directory (HF cache or direct path)", + ) + parser.add_argument( + "--source-model", + type=str, + required=True, + help="HuggingFace BF16 model to download MTP shards from (e.g., Qwen/Qwen3.5-122B-A10B)", + ) + parser.add_argument( + "--download-dir", + type=str, + default=None, + help="Directory to download shards to (default: temp dir)", + ) + parser.add_argument( + "--skip-download", + action="store_true", + help="Skip download (use existing shards in download-dir)", + ) + parser.add_argument( + "--keep-shards", + action="store_true", + help="Don't delete downloaded BF16 shards after extraction", + ) + parser.add_argument( + "--no-quantize", + action="store_true", + help="Save MTP weights in BF16 (no quantization). Required for correct MTP predictions.", + ) + args = parser.parse_args() + + print("=" * 60) + print("MTP Weight Addition for Qwen3.5 MLX Model") + print("=" * 60) + + # Find snapshot directory + snapshot_dir = find_snapshot_dir(args.mlx_model_path) + print(f"\nMLX model snapshot: {snapshot_dir}") + + # Read config + config_path = snapshot_dir / "config.json" + if not config_path.exists(): + print(f"ERROR: No config.json found in {snapshot_dir}") + sys.exit(1) + + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + model_type = text_config.get("model_type", config.get("model_type", "unknown")) + hidden_size = text_config.get("hidden_size", "?") + num_experts = text_config.get("num_experts", 0) + is_moe = num_experts > 0 + mtp_layers = text_config.get("mtp_num_hidden_layers", 0) + + print(f"Model type: {model_type}") + print(f"Hidden size: {hidden_size}") + print(f"Num experts: {num_experts} ({'MoE' if is_moe else 'Dense'})") + print(f"MTP layers: {mtp_layers}") + + if mtp_layers == 0: + print("ERROR: Model has no MTP layers configured (mtp_num_hidden_layers=0)") + sys.exit(1) + + # Check if MTP weights already exist + mtp_file = snapshot_dir / "mtp" / "weights.safetensors" + if mtp_file.exists(): + size_mb = mtp_file.stat().st_size / 1e6 + print(f"\nWARNING: mtp/weights.safetensors already exists ({size_mb:.1f} MB)") + print("Delete it first if you want to regenerate.") + sys.exit(0) + + # Setup download directory + if args.download_dir: + download_dir = Path(args.download_dir) + download_dir.mkdir(parents=True, exist_ok=True) + else: + download_dir = Path(tempfile.mkdtemp(prefix="qwen35_mtp_")) + print(f"\nDownload dir: {download_dir}") + + # Fetch shard index and identify MTP shards + source_index = fetch_shard_index(args.source_model, download_dir) + mtp_key_map, shards_needed = identify_mtp_shards(source_index) + + print( + f"\nFound {len(mtp_key_map)} MTP weight 
keys across {len(shards_needed)} shards:" + ) + for shard in sorted(shards_needed): + count = sum(1 for v in mtp_key_map.values() if v == shard) + print(f" {shard}: {count} keys") + + # Download shards + if not args.skip_download: + print(f"\nDownloading {len(shards_needed)} shards...") + shard_paths = download_shards(shards_needed, args.source_model, download_dir) + else: + shard_paths = {} + for shard_name in shards_needed: + p = download_dir / shard_name + if p.exists(): + shard_paths[shard_name] = p + else: + print(f"ERROR: Shard not found: {p}") + sys.exit(1) + + # Extract, optionally quantize, and save MTP weights + mtp_file, mtp_weight_keys = extract_and_quantize_mtp_weights( + mtp_key_map, + shard_paths, + snapshot_dir, + is_moe, + num_experts, + no_quantize=args.no_quantize, + ) + + # NOTE: Do NOT update model.safetensors.index.json — mlx_vlm's glob + # would try to load MTP weights and fail with strict loading. + # MTP weights are loaded separately by inject_mtp_support(). + + # Update config + update_config(snapshot_dir) + + # Cleanup downloaded shards + if not args.keep_shards and not args.skip_download: + print("\nCleaning up downloaded shards...") + for shard_path in shard_paths.values(): + shard_path.unlink(missing_ok=True) + print(f" Deleted {shard_path.name}") + + print("\n" + "=" * 60) + print("SUCCESS! MTP weights added to MLX model.") + print("=" * 60) + print(f"\nMTP weight file: {mtp_file}") + print(f"Total MTP keys: {len(mtp_weight_keys)}") + print("\nTo use MTP, start the server with --enable-mtp:") + print(f" vllm-mlx serve {args.mlx_model_path} --enable-mtp") + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index d0c7f026b..f699c08bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,3 +50,9 @@ def pytest_collection_modifyitems(config, items): def server_url(request): """Get server URL from command line.""" return request.config.getoption("--server-url") + + +@pytest.fixture(params=["asyncio"]) +def anyio_backend(request): + """Run anyio-marked tests on asyncio only (trio is not installed).""" + return request.param diff --git a/tests/test_batched_engine.py b/tests/test_batched_engine.py new file mode 100644 index 000000000..73a7e8ffa --- /dev/null +++ b/tests/test_batched_engine.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for BatchedEngine generate() output.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +class TestBatchedEngineGenerate: + """Test BatchedEngine.generate() output fields.""" + + def _make_engine(self): + """Create a BatchedEngine instance with loading bypassed.""" + from vllm_mlx.engine.batched import BatchedEngine + + with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False): + engine = BatchedEngine("test-model") + + engine._loaded = True + engine._is_mllm = False + return engine + + def _make_mock_request_output( + self, + output_text="Paris", + output_token_ids=None, + prompt_tokens=10, + completion_tokens=3, + finish_reason="stop", + ): + """Build a mock RequestOutput (as returned by AsyncEngineCore).""" + mock = MagicMock() + mock.output_text = output_text + mock.output_token_ids = ( + output_token_ids if output_token_ids is not None else [3681, 374, 279] + ) + mock.prompt_tokens = prompt_tokens + mock.completion_tokens = completion_tokens + mock.finish_reason = finish_reason + return mock + + @pytest.mark.asyncio + async def test_tokens_field_is_populated(self): + """tokens should contain the output token IDs from 
AsyncEngineCore.""" + engine = self._make_engine() + token_ids = [3681, 374, 279] + mock_output = self._make_mock_request_output(output_token_ids=token_ids) + + mock_engine = MagicMock() + mock_engine.generate = AsyncMock(return_value=mock_output) + engine._engine = mock_engine + + result = await engine.generate( + prompt="What is the capital of France?", max_tokens=10 + ) + + assert result.tokens == token_ids + + @pytest.mark.asyncio + async def test_tokens_field_empty_when_no_tokens_generated(self): + """tokens should be an empty list when output_token_ids is empty.""" + engine = self._make_engine() + mock_output = self._make_mock_request_output(output_token_ids=[]) + + mock_engine = MagicMock() + mock_engine.generate = AsyncMock(return_value=mock_output) + engine._engine = mock_engine + + result = await engine.generate(prompt="test", max_tokens=10) + + assert result.tokens == [] + + @pytest.mark.asyncio + async def test_other_output_fields_still_populated(self): + """Existing fields (text, prompt_tokens, etc.) must remain correct.""" + engine = self._make_engine() + mock_output = self._make_mock_request_output( + output_text="Paris", + output_token_ids=[3681], + prompt_tokens=7, + completion_tokens=1, + finish_reason="stop", + ) + + mock_engine = MagicMock() + mock_engine.generate = AsyncMock(return_value=mock_output) + engine._engine = mock_engine + + result = await engine.generate(prompt="Capital of France?", max_tokens=5) + + assert result.text == "Paris" + assert result.prompt_tokens == 7 + assert result.completion_tokens == 1 + assert result.finish_reason == "stop" diff --git a/tests/test_batching.py b/tests/test_batching.py index 7dc050ee3..6cb536aa5 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -7,8 +7,10 @@ """ import asyncio +import importlib import pytest from unittest.mock import MagicMock +import mlx.core as mx from vllm_mlx.request import ( Request, @@ -20,8 +22,11 @@ Scheduler, SchedulerConfig, SchedulingPolicy, + _install_chunked_prefill, ) +mlx_generate = importlib.import_module("mlx_lm.generate") + class TestRequest: """Tests for Request class.""" @@ -211,6 +216,359 @@ def mock_model(self): """Create a mock model.""" return MagicMock() + def test_chunked_prefill_accepts_prompt_checkpoints(self, monkeypatch): + """Chunked prefill must match mlx-lm's 7-field prompt tuples.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + + class FakeBatchGenerator: + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3, 4, 5], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = set() + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = None + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + 
_install_chunked_prefill(batch_gen, budget=4) + + responses = batch_gen._next() + + assert responses == [] + assert batch_gen._partial is not None + assert batch_gen._partial["prompt_checkpoint"] == 3 + assert batch_gen._partial["processed"] == 2 + + def test_chunked_prefill_invokes_checkpoint_callback(self, monkeypatch): + """prompt_checkpoint_callback must fire after finalization.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + def extract(self, idx): + return self + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + generation_tokens = 0 + + callback_payloads = [] + + from collections import namedtuple + + _Response = namedtuple( + "Response", ["uid", "token", "logprobs", "finish_reason", "cache"] + ) + + class FakeBatchGenerator: + Response = _Response + + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = set() + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = ( + lambda entries: callback_payloads.extend(entries) + ) + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + def _step(self, inputs, cache, samplers, logits_processors, tokens): + return mx.array([99]), mx.array([-1.0]) + + def _generation_step(self): + if self.active_batch is not None: + self.active_batch = None + return [] + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + batch_gen.stop_tokens = {99} + _install_chunked_prefill(batch_gen, budget=1) + + # First _next: starts partial prefill (processes 1 token) + batch_gen._next() + assert batch_gen._partial is not None + + # Second _next: finishes prefill, fires checkpoint callback, + # then runs generation step which completes (stop token). 
+ batch_gen._next() + + assert len(callback_payloads) == 1 + uid, checkpoint, _cache_gen = callback_payloads[0] + assert uid == 7 + assert checkpoint == 1 + + def test_chunked_prefill_replays_checkpoint_tail_before_step(self, monkeypatch): + """checkpoint tails >1 must be replayed after finalize before _step.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + def extract(self, idx): + return self + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + generation_tokens = 0 + + callback_payloads = [] + model_calls = [] + step_inputs = [] + + from collections import namedtuple + + _Response = namedtuple( + "Response", ["uid", "token", "logprobs", "finish_reason", "cache"] + ) + + class FakeBatchGenerator: + Response = _Response + + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3, 4, 5], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = {99} + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = ( + lambda entries: callback_payloads.extend(entries) + ) + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + + def model(self, inputs, cache=None): + model_calls.append(inputs.tolist()) + + def _step(self, inputs, cache, samplers, logits_processors, tokens): + step_inputs.append(inputs.tolist()) + return mx.array([99]), mx.array([-1.0]) + + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + _install_chunked_prefill(batch_gen, budget=2) + + # First _next: process the first chunk and leave a 3-token checkpoint tail. + batch_gen._next() + assert batch_gen._partial is not None + assert batch_gen._partial["prompt_checkpoint"] == 3 + + # Second _next: finalize, fire callback, replay the checkpoint tail, step. 
+ batch_gen._next() + + assert model_calls == [[[1, 2]], [[3, 4]]] + assert step_inputs[0] == [[3, 4, 5]] + assert len(callback_payloads) == 1 + uid, checkpoint, _cache_gen = callback_payloads[0] + assert uid == 7 + assert checkpoint == 3 + + def test_chunked_prefill_works_without_private_mlx_generate_exports( + self, monkeypatch + ): + """Chunked prefill should tolerate missing private mlx_lm.generate exports.""" + + class FakeCacheEntry: + def empty(self): + return True + + class FakePromptCache: + def __init__(self): + self.state = mx.array([0]) + + def finalize(self): + return None + + def extract(self, idx): + return self + + class FakeStats: + prompt_tokens = 0 + prompt_time = 0.0 + generation_time = 0.0 + generation_tokens = 0 + + from collections import namedtuple + + _Response = namedtuple( + "Response", ["uid", "token", "logprobs", "finish_reason", "cache"] + ) + + class FakeBatchGenerator: + Response = _Response + + def __init__(self): + self._stats = FakeStats() + self._partial = None + self.active_batch = None + self.unprocessed_prompts = [ + ( + 7, + [1, 2, 3], + 16, + [FakeCacheEntry()], + None, + [None], + 2, + ) + ] + self.prefill_batch_size = 1 + self.completion_batch_size = 1 + self.max_kv_size = None + self.stop_tokens = {99} + self.prompt_progress_callback = lambda _progress: None + self.prompt_checkpoint_callback = None + self._next = lambda: [] + self.remove = lambda _uids: None + self._process_prompts = lambda _prompts: None + self.model = lambda _inputs, cache=None: None + + def _step(self, inputs, cache, samplers, logits_processors, tokens): + return mx.array([99]), mx.array([-1.0]) + + def _generation_step(self): + if self.active_batch is not None: + self.active_batch = None + return [] + + monkeypatch.delattr(mlx_generate, "Batch", raising=False) + monkeypatch.delattr(mlx_generate, "_lazy_extract_cache", raising=False) + monkeypatch.setattr( + mlx_generate, + "_left_pad_prompts", + lambda prompts, max_length=None: mx.array(prompts), + ) + monkeypatch.setattr( + mlx_generate, + "_make_cache", + lambda _model, _padding, _max_kv_size=None: [FakePromptCache()], + ) + + batch_gen = FakeBatchGenerator() + _install_chunked_prefill(batch_gen, budget=1) + + batch_gen._next() + assert batch_gen._partial is not None + batch_gen._next() + assert batch_gen.active_batch is None + def test_scheduler_creation(self, mock_model, mock_tokenizer): """Test scheduler creation.""" scheduler = Scheduler( @@ -432,7 +790,7 @@ def test_multiple_concurrent_requests(self, model_and_tokenizer): assert len(finished) == len(prompts), f"Only {len(finished)} requests finished" -@pytest.mark.asyncio +@pytest.mark.anyio class TestEngineAsync: """Async tests for the engine.""" diff --git a/tests/test_batching_deterministic.py b/tests/test_batching_deterministic.py index 52b0fd49b..0e6072ce9 100644 --- a/tests/test_batching_deterministic.py +++ b/tests/test_batching_deterministic.py @@ -37,7 +37,7 @@ def sampling_params(): class TestDeterministicSingleRequest: """Test single request determinism.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_same_prompt_same_output(self, model_and_tokenizer, sampling_params): """Same prompt should produce same output with temp=0.""" from vllm_mlx import AsyncEngineCore, EngineConfig, SchedulerConfig @@ -68,7 +68,7 @@ async def test_same_prompt_same_output(self, model_and_tokenizer, sampling_param assert len(outputs) == 3 assert outputs[0] == outputs[1] == outputs[2], f"Outputs differ: {outputs}" - @pytest.mark.asyncio + @pytest.mark.anyio async def 
test_token_streaming_order(self, model_and_tokenizer, sampling_params): """Tokens should stream in order.""" from vllm_mlx import AsyncEngineCore @@ -94,7 +94,7 @@ async def test_token_streaming_order(self, model_and_tokenizer, sampling_params) class TestDeterministicConcurrentRequests: """Test concurrent request handling with determinism.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_concurrent_same_prompt(self, model_and_tokenizer): """Multiple concurrent requests with same prompt should get same output.""" from vllm_mlx import ( @@ -137,7 +137,7 @@ async def get_output(rid): # All should be the same assert all(r == results[0] for r in results), f"Outputs differ: {results}" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_concurrent_different_prompts(self, model_and_tokenizer): """Different prompts should get different (but deterministic) outputs.""" from vllm_mlx import ( @@ -191,7 +191,7 @@ async def get_output(rid): class TestBatchingPerformance: """Test that batching improves throughput.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_batched_faster_than_sequential(self, model_and_tokenizer): """Batched requests should be faster than sequential.""" from vllm_mlx import ( @@ -274,7 +274,7 @@ async def get_output(rid): class TestRequestManagement: """Test request lifecycle management.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_abort_request(self, model_and_tokenizer): """Test aborting a request mid-generation.""" from vllm_mlx import AsyncEngineCore, SamplingParams @@ -304,7 +304,7 @@ async def test_abort_request(self, model_and_tokenizer): stats = engine.get_stats() assert stats["active_requests"] == 0 - @pytest.mark.asyncio + @pytest.mark.anyio async def test_engine_stats(self, model_and_tokenizer): """Test engine statistics tracking.""" from vllm_mlx import ( @@ -343,7 +343,7 @@ async def test_engine_stats(self, model_and_tokenizer): class TestSchedulerPolicy: """Test scheduler policies.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_fcfs_ordering(self, model_and_tokenizer): """Test that FCFS policy processes requests in order.""" from vllm_mlx import ( @@ -396,7 +396,7 @@ async def track_completion(rid, name): class TestEdgeCases: """Test edge cases and error handling.""" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_empty_prompt(self, model_and_tokenizer): """Test handling of empty prompt.""" from vllm_mlx import AsyncEngineCore, SamplingParams @@ -414,7 +414,7 @@ async def test_empty_prompt(self, model_and_tokenizer): assert out.finished break - @pytest.mark.asyncio + @pytest.mark.anyio async def test_very_short_max_tokens(self, model_and_tokenizer): """Test with max_tokens=1.""" from vllm_mlx import AsyncEngineCore, SamplingParams @@ -436,7 +436,7 @@ async def test_very_short_max_tokens(self, model_and_tokenizer): # Should generate exactly 1 token assert token_count == 1 - @pytest.mark.asyncio + @pytest.mark.anyio async def test_multiple_start_stop(self, model_and_tokenizer): """Test starting and stopping engine multiple times.""" from vllm_mlx import AsyncEngineCore, SamplingParams diff --git a/tests/test_continuous_batching.py b/tests/test_continuous_batching.py index fd10fe808..0e196a226 100644 --- a/tests/test_continuous_batching.py +++ b/tests/test_continuous_batching.py @@ -53,7 +53,7 @@ def test_scheduler_config_batching_params(self): assert config.completion_batch_size == 32 -@pytest.mark.asyncio +@pytest.mark.anyio class TestContinuousBatchingIntegration: """Integration tests 
requiring actual model loading.""" diff --git a/tests/test_download.py b/tests/test_download.py new file mode 100644 index 000000000..9eba711bb --- /dev/null +++ b/tests/test_download.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for resumable model download with retry/timeout support.""" + +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from vllm_mlx.utils.download import ( + LLM_ALLOW_PATTERNS, + MLLM_ALLOW_PATTERNS, + DownloadConfig, + ensure_model_downloaded, +) + + +class TestLocalPath: + """Tests for local path handling.""" + + def test_local_path_skips_download(self, tmp_path): + """Existing local directory is returned without downloading.""" + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + result = ensure_model_downloaded(str(tmp_path)) + mock_download.assert_not_called() + assert result == tmp_path + + +class TestRetryLogic: + """Tests for download retry behavior.""" + + def test_retry_on_failure(self): + """Failed downloads are retried up to max_retries times.""" + config = DownloadConfig(max_retries=3, retry_backoff_base=0.01) + fake_path = "/fake/cache/path" + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = [ + ConnectionError("timeout"), + ConnectionError("timeout"), + fake_path, + ] + result = ensure_model_downloaded("org/model", config=config) + assert result == Path(fake_path) + assert mock_download.call_count == 3 + + def test_retry_exhaustion(self): + """RuntimeError is raised after all retries are exhausted.""" + config = DownloadConfig(max_retries=2, retry_backoff_base=0.01) + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = ConnectionError("timeout") + with pytest.raises(RuntimeError, match="Failed to download"): + ensure_model_downloaded("org/model", config=config) + assert mock_download.call_count == 2 + + def test_keyboard_interrupt_not_retried(self): + """KeyboardInterrupt propagates immediately without retry.""" + config = DownloadConfig(max_retries=3, retry_backoff_base=0.01) + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = KeyboardInterrupt() + with pytest.raises(KeyboardInterrupt): + ensure_model_downloaded("org/model", config=config) + assert mock_download.call_count == 1 + + +class TestOfflineMode: + """Tests for offline mode behavior.""" + + def test_offline_mode_cached(self): + """Offline mode finds cached model successfully.""" + config = DownloadConfig(offline=True) + fake_path = "/fake/cache/path" + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.return_value = fake_path + result = ensure_model_downloaded("org/model", config=config) + assert result == Path(fake_path) + mock_download.assert_called_once_with("org/model", local_files_only=True) + + def test_offline_mode_missing(self): + """Offline mode raises clear error when model is not cached.""" + config = DownloadConfig(offline=True) + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = Exception("not found locally") + with pytest.raises(RuntimeError, match="not found in local cache"): + ensure_model_downloaded("org/model", config=config) + + +class TestTimeout: + """Tests for download timeout configuration.""" + + def test_hf_timeout_env_set(self): + """HF_HUB_DOWNLOAD_TIMEOUT env var is set during download.""" + config = 
DownloadConfig(download_timeout=600, max_retries=1) + fake_path = "/fake/cache/path" + captured_timeout = {} + + original_env = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + + def capture_env(*args, **kwargs): + captured_timeout["value"] = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + return fake_path + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = capture_env + ensure_model_downloaded("org/model", config=config) + + assert captured_timeout["value"] == "600" + # Env var should be restored after download + assert os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") == original_env + + def test_hf_timeout_env_restored_on_failure(self): + """HF_HUB_DOWNLOAD_TIMEOUT is restored even after failure.""" + config = DownloadConfig( + download_timeout=999, max_retries=1, retry_backoff_base=0.01 + ) + original_env = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.side_effect = ConnectionError("fail") + with pytest.raises(RuntimeError): + ensure_model_downloaded("org/model", config=config) + + assert os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") == original_env + + +class TestAllowPatterns: + """Tests for LLM vs MLLM download patterns.""" + + def test_llm_patterns_used_by_default(self): + """LLM allow patterns are used when is_mllm=False.""" + config = DownloadConfig(max_retries=1) + fake_path = "/fake/cache/path" + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.return_value = fake_path + ensure_model_downloaded("org/model", config=config, is_mllm=False) + mock_download.assert_called_once_with( + "org/model", allow_patterns=LLM_ALLOW_PATTERNS + ) + + def test_mllm_patterns_used(self): + """MLLM allow patterns are used when is_mllm=True.""" + config = DownloadConfig(max_retries=1) + fake_path = "/fake/cache/path" + + with patch("vllm_mlx.utils.download.snapshot_download") as mock_download: + mock_download.return_value = fake_path + ensure_model_downloaded("org/model", config=config, is_mllm=True) + mock_download.assert_called_once_with( + "org/model", allow_patterns=MLLM_ALLOW_PATTERNS + ) + + +class TestCLIDownloadCommand: + """Tests for CLI download subcommand argument parsing.""" + + def test_cli_download_command(self): + """Download subcommand parses arguments correctly.""" + import argparse + + # We test argparse by calling parse_args directly + # (main() would try to actually run the command) + with patch("sys.argv", ["vllm-mlx", "download", "org/model"]): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument("model") + download_parser.add_argument("--timeout", type=int, default=300) + download_parser.add_argument("--retries", type=int, default=3) + download_parser.add_argument("--mllm", action="store_true") + + args = parser.parse_args(["download", "org/model", "--timeout", "600"]) + assert args.command == "download" + assert args.model == "org/model" + assert args.timeout == 600 + assert args.retries == 3 + assert args.mllm is False + + def test_cli_download_mllm_flag(self): + """Download subcommand parses --mllm flag.""" + import argparse + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="command") + download_parser = subparsers.add_parser("download") + download_parser.add_argument("model") + download_parser.add_argument("--timeout", type=int, default=300) + 
download_parser.add_argument("--retries", type=int, default=3) + download_parser.add_argument("--mllm", action="store_true") + + args = parser.parse_args(["download", "org/vl-model", "--mllm"]) + assert args.mllm is True diff --git a/tests/test_gemma4_openai_format.py b/tests/test_gemma4_openai_format.py new file mode 100644 index 000000000..f680c911b --- /dev/null +++ b/tests/test_gemma4_openai_format.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Integration test: verify Gemma 4 tool calls produce valid OpenAI API responses. + +Claude Code (and other OpenAI-compatible clients) expect: +- response.choices[0].message.tool_calls[0].type == "function" +- response.choices[0].message.tool_calls[0].function.name == "read_file" +- response.choices[0].message.tool_calls[0].function.arguments == '{"path":"/tmp/test.py"}' +- response.choices[0].message.content is None (not empty string) +- response.choices[0].finish_reason == "tool_calls" + +This test verifies the FULL pipeline from parser output → server wrapping → JSON response, +not just the parser in isolation. +""" + +import json + +from vllm_mlx.api.models import ( + AssistantMessage, + ChatCompletionChoice, + ChatCompletionResponse, + FunctionCall, + ToolCall, + Usage, +) +from vllm_mlx.tool_parsers.gemma4_tool_parser import Gemma4ToolParser + + +def _build_response_from_parser(parser_output, model_name="gemma-4-27b-it"): + """Simulate what server.py does at lines 1494-1511 to build the HTTP response.""" + if parser_output.tools_called: + tool_calls = [ + ToolCall( + id=tc.get("id", "call_test"), + type="function", + function=FunctionCall( + name=tc["name"], + arguments=tc["arguments"], + ), + ) + for tc in parser_output.tool_calls + ] + content = parser_output.content if parser_output.content else None + finish_reason = "tool_calls" + else: + tool_calls = None + content = parser_output.content + finish_reason = "stop" + + return ChatCompletionResponse( + model=model_name, + choices=[ + ChatCompletionChoice( + message=AssistantMessage( + content=content, + tool_calls=tool_calls, + ), + finish_reason=finish_reason, + ) + ], + usage=Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15), + ) + + +class TestGemma4OpenAIFormat: + """Verify the full response matches what Claude Code expects.""" + + def setup_method(self): + self.parser = Gemma4ToolParser() + + def test_tool_call_response_has_correct_structure(self): + """The JSON response must have the exact OpenAI structure.""" + output = '<|tool_call>call:read_file{path:<|"|>/tmp/test.py<|"|>}' + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + + # Serialize to JSON (this is what goes over the wire) + data = json.loads(response.model_dump_json(exclude_none=True)) + + choice = data["choices"][0] + msg = choice["message"] + + # finish_reason must be "tool_calls" not "stop" + assert choice["finish_reason"] == "tool_calls" + + # content must be absent or null when tool_calls present + assert msg.get("content") is None + + # tool_calls must be a list + assert isinstance(msg["tool_calls"], list) + assert len(msg["tool_calls"]) == 1 + + tc = msg["tool_calls"][0] + + # type must be "function" + assert tc["type"] == "function" + + # id must be present and non-empty + assert tc["id"] + assert isinstance(tc["id"], str) + + # function.name must be the function name + assert tc["function"]["name"] == "read_file" + + # function.arguments must be a JSON string (not a dict!) 
+ assert isinstance(tc["function"]["arguments"], str) + args = json.loads(tc["function"]["arguments"]) + assert args == {"path": "/tmp/test.py"} + + def test_multiple_tool_calls_response(self): + """Multiple tool calls in one response.""" + output = ( + "<|tool_call>" + 'call:read_file{path:<|"|>/a.py<|"|>}' + 'call:read_file{path:<|"|>/b.py<|"|>}' + "" + ) + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + tcs = data["choices"][0]["message"]["tool_calls"] + assert len(tcs) == 2 + assert tcs[0]["function"]["name"] == "read_file" + assert tcs[1]["function"]["name"] == "read_file" + # Each must have a unique id + assert tcs[0]["id"] != tcs[1]["id"] + + def test_content_before_tool_call_preserved(self): + """Text before the tool call goes in content field.""" + output = 'Let me check that.\n<|tool_call>call:read_file{path:<|"|>/tmp/x<|"|>}' + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + msg = data["choices"][0]["message"] + assert msg["content"] == "Let me check that." + assert len(msg["tool_calls"]) == 1 + + def test_no_tool_call_response(self): + """Plain text response has no tool_calls field.""" + output = "The answer is 42." + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + msg = data["choices"][0]["message"] + assert msg["content"] == "The answer is 42." + assert "tool_calls" not in msg # excluded by exclude_none + assert data["choices"][0]["finish_reason"] == "stop" + + def test_complex_arguments_serialize_correctly(self): + """Nested objects and arrays must survive JSON round-trip.""" + output = '<|tool_call>call:configure{settings:{enabled:true,tags:[<|"|>a<|"|>,<|"|>b<|"|>]}}' + result = self.parser.extract_tool_calls(output) + response = _build_response_from_parser(result) + data = json.loads(response.model_dump_json(exclude_none=True)) + + tc = data["choices"][0]["message"]["tool_calls"][0] + args = json.loads(tc["function"]["arguments"]) + assert args == {"settings": {"enabled": True, "tags": ["a", "b"]}} diff --git a/tests/test_gemma4_tool_parser.py b/tests/test_gemma4_tool_parser.py new file mode 100644 index 000000000..179915442 --- /dev/null +++ b/tests/test_gemma4_tool_parser.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for Gemma 4 tool call parser.""" + +import json + +from vllm_mlx.tool_parsers.gemma4_tool_parser import Gemma4ToolParser + + +class TestGemma4ToolParserExtract: + """Test extract_tool_calls on complete model output.""" + + def setup_method(self): + self.parser = Gemma4ToolParser() + + def test_single_tool_call_string_arg(self): + output = '<|tool_call>call:read_file{path:<|"|>/tmp/foo.py<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc["name"] == "read_file" + args = json.loads(tc["arguments"]) + assert args == {"path": "/tmp/foo.py"} + assert result.content is None + + def test_single_tool_call_numeric_arg(self): + output = "<|tool_call>call:search{limit:10,verbose:false}" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == 
{"limit": 10, "verbose": False} + + def test_mixed_types(self): + output = '<|tool_call>call:search{query:<|"|>hello world<|"|>,limit:10,verbose:false}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"query": "hello world", "limit": 10, "verbose": False} + + def test_nested_object(self): + output = '<|tool_call>call:configure{settings:{enabled:true,name:<|"|>test<|"|>}}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"settings": {"enabled": True, "name": "test"}} + + def test_array_argument(self): + output = '<|tool_call>call:tag{items:[<|"|>foo<|"|>,<|"|>bar<|"|>]}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"items": ["foo", "bar"]} + + def test_multiple_tool_calls_in_one_block(self): + output = ( + "<|tool_call>" + 'call:glob{pattern:<|"|>README*.md<|"|>}' + 'call:glob{pattern:<|"|>CONTRIBUTING.md<|"|>}' + "" + ) + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 2 + args0 = json.loads(result.tool_calls[0]["arguments"]) + args1 = json.loads(result.tool_calls[1]["arguments"]) + assert args0 == {"pattern": "README*.md"} + assert args1 == {"pattern": "CONTRIBUTING.md"} + + def test_content_before_tool_call(self): + output = 'Let me read that file for you.\n<|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert result.content == "Let me read that file for you." + assert len(result.tool_calls) == 1 + + def test_no_tool_calls(self): + output = "Hello, how can I help you today?" 
+ result = self.parser.extract_tool_calls(output) + assert result.tools_called is False + assert result.tool_calls == [] + assert result.content == output + + def test_empty_tool_call_block(self): + output = "<|tool_call>" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is False + assert result.tool_calls == [] + + def test_tool_call_id_generated(self): + output = '<|tool_call>call:read_file{path:<|"|>/tmp/a<|"|>}' + result = self.parser.extract_tool_calls(output) + tc = result.tool_calls[0] + assert "id" in tc + assert tc["id"].startswith("call_") + + def test_string_with_special_chars(self): + output = '<|tool_call>call:write{content:<|"|>line1\\nline2<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["content"] == "line1\\nline2" + + def test_deeply_nested_objects(self): + output = "<|tool_call>call:update{a:{b:{c:1,d:true}}}" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"a": {"b": {"c": 1, "d": True}}} + + def test_null_value(self): + output = "<|tool_call>call:clear{target:null}" + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"target": None} + + def test_unicode_emoji_in_args(self): + output = '<|tool_call>call:search{query:<|"|>hello world \U0001f30d \u4f60\u597d<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"query": "hello world \U0001f30d \u4f60\u597d"} + + def test_braces_inside_string_value(self): + output = '<|tool_call>call:run{code:<|"|>if (x) { return y; }<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"code": "if (x) { return y; }"} + + def test_quoted_keys(self): + output = '<|tool_call>call:read{<|"|>path<|"|>:<|"|>/tmp/foo<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"path": "/tmp/foo"} + + def test_think_tags_stripped(self): + output = 'Let me think about this...<|tool_call>call:search{query:<|"|>test<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + + def test_missing_end_delimiter(self): + """Unclosed tool call block still parses (server fallback path).""" + output = '<|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + assert len(result.tool_calls) == 1 + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"path": "/tmp/foo"} + + def test_string_with_colon(self): + """String containing colon pattern must not be corrupted by bare-key quoting.""" + output = '<|tool_call>call:connect{url:<|"|>host:8080<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"url": "host:8080"} + + def test_string_with_newline_and_quote(self): + """Real newline and double quote inside string values are JSON-escaped.""" + output = 
'<|tool_call>call:write{text:<|"|>line1\nline2 said "hello"<|"|>}' + result = self.parser.extract_tool_calls(output) + assert result.tools_called is True + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"text": 'line1\nline2 said "hello"'} + + +class TestGemma4ToolParserStreaming: + """Test streaming tool call extraction.""" + + def setup_method(self): + self.parser = Gemma4ToolParser() + self.parser.reset() + + def test_streaming_no_tool_call(self): + """Normal text passes through as content.""" + result = self.parser.extract_tool_calls_streaming( + previous_text="", + current_text="Hello", + delta_text="Hello", + ) + assert result == {"content": "Hello"} + + def test_streaming_suppresses_during_tool_call(self): + """Returns None while inside tool call block (buffering).""" + r1 = self.parser.extract_tool_calls_streaming( + previous_text="", + current_text="Sure. ", + delta_text="Sure. ", + ) + assert r1 == {"content": "Sure. "} + + r2 = self.parser.extract_tool_calls_streaming( + previous_text="Sure. ", + current_text="Sure. <|tool_call>call:read", + delta_text="<|tool_call>call:read", + ) + assert r2 is None + + r3 = self.parser.extract_tool_calls_streaming( + previous_text="Sure. <|tool_call>call:read", + current_text='Sure. <|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}', + delta_text='_file{path:<|"|>/tmp/foo<|"|>}', + ) + assert r3 is None + + def test_streaming_emits_on_close(self): + """Emits structured tool_calls when end delimiter arrives.""" + full_text = ( + 'Sure. <|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}' + ) + result = self.parser.extract_tool_calls_streaming( + previous_text='Sure. <|tool_call>call:read_file{path:<|"|>/tmp/foo<|"|>}', + current_text=full_text, + delta_text="", + ) + assert result is not None + assert "tool_calls" in result + assert len(result["tool_calls"]) == 1 + tc = result["tool_calls"][0] + assert tc["function"]["name"] == "read_file" + assert tc["type"] == "function" + assert tc["index"] == 0 + + +class TestGemma4Registration: + """Test parser registration and flags.""" + + def test_registered_in_manager(self): + from vllm_mlx.tool_parsers import ToolParserManager + + parser_cls = ToolParserManager.get_tool_parser("gemma4") + assert parser_cls is Gemma4ToolParser + + def test_native_format_false(self): + assert Gemma4ToolParser.SUPPORTS_NATIVE_TOOL_FORMAT is False + assert Gemma4ToolParser.supports_native_format() is False diff --git a/tests/test_minimax_tool_calling.py b/tests/test_minimax_tool_calling.py new file mode 100644 index 000000000..2b94f967b --- /dev/null +++ b/tests/test_minimax_tool_calling.py @@ -0,0 +1,130 @@ +"""Tests for MiniMax tool call parsing.""" + +import json +import unittest + +from vllm_mlx.api.tool_calling import parse_tool_calls + + +class TestMiniMaxToolCallParsing(unittest.TestCase): + """Test parsing of MiniMax-style tool calls.""" + + def test_single_tool_call(self): + text = """ + +Wanaka +celsius + +""" + + cleaned, tool_calls = parse_tool_calls(text) + self.assertIsNotNone(tool_calls) + self.assertEqual(len(tool_calls), 1) + self.assertEqual(tool_calls[0].function.name, "get_weather") + args = json.loads(tool_calls[0].function.arguments) + self.assertEqual(args["city"], "Wanaka") + self.assertEqual(args["units"], "celsius") + self.assertEqual(cleaned, "") + + def test_tool_call_with_surrounding_text(self): + text = """Let me check the weather for you. 
+ + +Wanaka + +""" + + cleaned, tool_calls = parse_tool_calls(text) + self.assertIsNotNone(tool_calls) + self.assertEqual(len(tool_calls), 1) + self.assertIn("Let me check", cleaned) + + def test_multiple_tool_calls(self): + text = """ + +MiniMax M2.5 + + + + +/tmp/test.txt + +""" + + cleaned, tool_calls = parse_tool_calls(text) + self.assertIsNotNone(tool_calls) + self.assertEqual(len(tool_calls), 2) + self.assertEqual(tool_calls[0].function.name, "search") + self.assertEqual(tool_calls[1].function.name, "read_file") + + def test_json_parameter_value(self): + text = """ + +Meeting +["stuart", "frida"] + +""" + + cleaned, tool_calls = parse_tool_calls(text) + self.assertIsNotNone(tool_calls) + args = json.loads(tool_calls[0].function.arguments) + self.assertEqual(args["title"], "Meeting") + self.assertEqual(args["attendees"], ["stuart", "frida"]) + + def test_numeric_parameter(self): + text = """ + +42 + +""" + + cleaned, tool_calls = parse_tool_calls(text) + args = json.loads(tool_calls[0].function.arguments) + self.assertEqual(args["value"], 42) + + def test_no_parameters(self): + text = """ + + +""" + + cleaned, tool_calls = parse_tool_calls(text) + self.assertIsNotNone(tool_calls) + self.assertEqual(tool_calls[0].function.name, "get_time") + args = json.loads(tool_calls[0].function.arguments) + self.assertEqual(args, {}) + + def test_with_think_tags_preserved(self): + text = """ +I should check the weather first. + + + +Wanaka + +""" + + cleaned, tool_calls = parse_tool_calls(text) + self.assertIsNotNone(tool_calls) + self.assertIn("", cleaned) + + def test_no_minimax_tool_calls(self): + text = "Just a regular message with no tool calls." + cleaned, tool_calls = parse_tool_calls(text) + self.assertIsNone(tool_calls) + self.assertEqual(cleaned, text) + + def test_tool_call_id_format(self): + text = """ + +1 + +""" + + _, tool_calls = parse_tool_calls(text) + self.assertTrue(tool_calls[0].id.startswith("call_")) + self.assertEqual(tool_calls[0].type, "function") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_mllm_continuous_batching.py b/tests/test_mllm_continuous_batching.py index 28b26b219..7cafb81f7 100644 --- a/tests/test_mllm_continuous_batching.py +++ b/tests/test_mllm_continuous_batching.py @@ -129,6 +129,53 @@ def test_finished_response(self): assert resp.finish_reason == "stop" + def test_error_response_skips_decoding(self): + """Error responses must not decode token=0 as content.""" + from unittest.mock import MagicMock + + from vllm_mlx.mllm_batch_generator import MLLMBatchResponse + from vllm_mlx.mllm_scheduler import MLLMScheduler + from vllm_mlx.request import RequestStatus + + # Build a minimal scheduler with mocked internals + scheduler = MLLMScheduler.__new__(MLLMScheduler) + scheduler._detokenizer_pool = {} + scheduler.uid_to_request_id = {0: "req-err"} + scheduler.total_completion_tokens = 0 + scheduler.num_requests_processed = 0 + + mock_tokenizer = MagicMock() + mock_tokenizer.decode.return_value = "" + mock_processor = MagicMock() + mock_processor.tokenizer = mock_tokenizer + scheduler.processor = mock_processor + + # Create a running request + mock_request = MagicMock() + mock_request.request_id = "req-err" + mock_request.output_tokens = [] + mock_request.num_output_tokens = 0 + mock_request.num_prompt_tokens = 10 + mock_request.status = RequestStatus.RUNNING + scheduler.running = {"req-err": mock_request} + + error_resp = MLLMBatchResponse( + uid=0, + request_id="req-err", + token=0, + logprobs=mx.array([0.0]), + finish_reason="error", + 
) + + outputs, finished = scheduler._process_batch_responses([error_resp]) + + assert "req-err" in finished + assert mock_request.status == RequestStatus.FINISHED_ABORTED + # token=0 should not have been decoded through a detokenizer + assert "req-err" not in scheduler._detokenizer_pool + assert len(outputs) == 1 + assert outputs[0].new_text == "" + class TestMLLMBatch: """Tests for MLLMBatch class.""" diff --git a/tests/test_native_tool_format.py b/tests/test_native_tool_format.py index 184116171..c4182a6f8 100644 --- a/tests/test_native_tool_format.py +++ b/tests/test_native_tool_format.py @@ -12,6 +12,7 @@ AutoToolParser, DeepSeekToolParser, FunctionaryToolParser, + Gemma4ToolParser, GraniteToolParser, HermesToolParser, KimiToolParser, @@ -53,6 +54,7 @@ def test_parsers_without_native_support(self): NemotronToolParser, xLAMToolParser, AutoToolParser, + Gemma4ToolParser, ] for parser_cls in non_native_parsers: assert ( diff --git a/tests/test_normalize_messages.py b/tests/test_normalize_messages.py new file mode 100644 index 000000000..6d88437ad --- /dev/null +++ b/tests/test_normalize_messages.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for _normalize_messages() in vllm_mlx.server. + +_normalize_messages() maps non-standard roles (developer -> system) and merges +consecutive same-role messages before chat template application. This prevents +crashes from Qwen 3.5 and Llama templates that require alternating roles. +""" + + +class TestNormalizeMessages: + """Test _normalize_messages() for handling real-world client formats.""" + + def test_merge_consecutive_system_messages(self): + """Consecutive system messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "Always respond in JSON."}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert "helpful assistant" in result[0]["content"] + assert "JSON" in result[0]["content"] + assert result[1]["role"] == "user" + assert result[1]["content"] == "Hello" + + def test_merge_consecutive_user_messages(self): + """Consecutive user messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "First part"}, + {"role": "user", "content": "Second part"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[1]["role"] == "user" + assert "First part" in result[1]["content"] + assert "Second part" in result[1]["content"] + + def test_opencode_format(self): + """OpenCode's system+system+user+user format is normalized.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "System prompt part 1"}, + {"role": "system", "content": "System prompt part 2"}, + {"role": "user", "content": "User instruction"}, + {"role": "user", "content": "User question"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_developer_role_mapped_to_system(self): + """OpenAI Responses API 'developer' role is mapped to 'system'.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "developer", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ] 
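+        # "developer" carries system-level instructions in the Responses API, so it is normalized to role "system"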
+ result = _normalize_messages(messages) + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_developer_and_system_merged(self): + """developer + system consecutive messages are merged after role mapping.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "developer", "content": "Part 1"}, + {"role": "system", "content": "Part 2"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert "Part 1" in result[0]["content"] + assert "Part 2" in result[0]["content"] + + def test_already_alternating_unchanged(self): + """Well-formed alternating messages pass through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "Bye"}, + ] + result = _normalize_messages(messages) + assert result == messages + + def test_single_message_unchanged(self): + """Single message passes through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [{"role": "user", "content": "Hello"}] + result = _normalize_messages(messages) + assert result == messages + + def test_empty_messages(self): + """Empty message list passes through.""" + from vllm_mlx.server import _normalize_messages + + assert _normalize_messages([]) == [] + + def test_multimodal_content_preserved(self): + """Messages with list content (multimodal) are not merged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "user", "content": "Describe this:"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": {"url": "http://example.com/img.png"}, + }, + ], + }, + ] + result = _normalize_messages(messages) + # List content can't be trivially merged with string - kept separate + assert len(result) >= 1 + + def test_preserves_non_content_fields(self): + """Fields other than role/content are preserved on the first merged message.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1", "name": "sys1"}, + {"role": "system", "content": "Part 2"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + + def test_null_content_not_merged(self): + """Messages with None content (tool_calls pattern) are not merged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]}, + {"role": "assistant", "content": "Follow-up"}, + ] + result = _normalize_messages(messages) + # None content can't be merged with string - kept separate + assert len(result) == 2 + + def test_three_consecutive_system_messages(self): + """Three consecutive system messages merge into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1"}, + {"role": "system", "content": "Part 2"}, + {"role": "system", "content": "Part 3"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert "Part 1" in result[0]["content"] + assert "Part 3" in result[0]["content"] diff --git a/tests/test_paged_cache.py b/tests/test_paged_cache.py index 8e3082c34..167b60944 100644 --- 
a/tests/test_paged_cache.py +++ b/tests/test_paged_cache.py @@ -725,3 +725,152 @@ def test_clear(self): stats = cache.get_stats() # After clear, null block is still allocated (vLLM style) assert stats["allocated_blocks"] == 1 # only null block + + def test_reconstructs_hybrid_cache_from_boundary_snapshot(self): + from mlx_lm.models.cache import ArraysCache, KVCache + import mlx.core as mx + + from vllm_mlx.paged_cache import PagedCacheManager + from vllm_mlx.prefix_cache import BlockAwarePrefixCache + + paged_manager = PagedCacheManager(block_size=4, max_blocks=10) + cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager) + + tokens = list(range(8)) + kv_keys = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3) + kv_values = mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3) + linear_state = [ + mx.arange(1 * 3 * 8).reshape(1, 3, 8), + mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4), + ] + extracted = [ + { + "state": (kv_keys, kv_values), + "meta_state": "", + "class_ref": KVCache, + "class_name": "KVCache", + }, + { + "state": linear_state, + "meta_state": "", + "class_ref": ArraysCache, + "class_name": "ArraysCache", + }, + ] + + block_table = cache.store_cache("req-1", tokens, extracted) + first_block = paged_manager.allocated_blocks[block_table.block_ids[0]] + last_block = paged_manager.allocated_blocks[block_table.block_ids[-1]] + + assert first_block.cache_data[0] is not None + assert first_block.cache_data[1] is None + assert last_block.cache_data[1] is not None + + reconstructed = cache.reconstruct_cache(block_table) + + assert reconstructed is not None + assert isinstance(reconstructed[0], KVCache) + assert isinstance(reconstructed[1], ArraysCache) + assert reconstructed[0].state[0].tolist() == kv_keys.tolist() + assert reconstructed[0].state[1].tolist() == kv_values.tolist() + assert reconstructed[1].state[0].tolist() == linear_state[0].tolist() + assert reconstructed[1].state[1].tolist() == linear_state[1].tolist() + + def test_rejects_hybrid_prefix_without_boundary_snapshot(self): + from mlx_lm.models.cache import ArraysCache, KVCache + import mlx.core as mx + + from vllm_mlx.paged_cache import BlockTable, PagedCacheManager + from vllm_mlx.prefix_cache import BlockAwarePrefixCache + + paged_manager = PagedCacheManager(block_size=4, max_blocks=10) + cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager) + + extracted = [ + { + "state": ( + mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3), + mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3), + ), + "meta_state": "", + "class_ref": KVCache, + "class_name": "KVCache", + }, + { + "state": [ + mx.arange(1 * 3 * 8).reshape(1, 3, 8), + mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4), + ], + "meta_state": "", + "class_ref": ArraysCache, + "class_name": "ArraysCache", + }, + ] + + block_table = cache.store_cache("req-1", list(range(8)), extracted) + prefix_table = BlockTable( + request_id="req-prefix", + block_ids=[block_table.block_ids[0]], + num_tokens=4, + ) + + assert cache.reconstruct_cache(prefix_table) is None + + def test_deduplicated_terminal_uses_correct_recurrent_snapshot(self): + """Deduplication must not leak recurrent state across sequences.""" + from mlx_lm.models.cache import ArraysCache, KVCache + import mlx.core as mx + + from vllm_mlx.paged_cache import PagedCacheManager + from vllm_mlx.prefix_cache import BlockAwarePrefixCache + + paged_manager = PagedCacheManager(block_size=4, max_blocks=20) + cache = BlockAwarePrefixCache(model=None, 
paged_cache_manager=paged_manager) + + # Request A: 8 tokens across 2 blocks. B2 is terminal. + kv_a = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3) + recurrent_a = [mx.ones((1, 3, 8)), mx.ones((1, 2, 4, 4))] + extracted_a = [ + { + "state": (kv_a, kv_a), + "meta_state": "", + "class_ref": KVCache, + "class_name": "KVCache", + }, + { + "state": recurrent_a, + "meta_state": "", + "class_ref": ArraysCache, + "class_name": "ArraysCache", + }, + ] + bt_a = cache.store_cache("req-a", list(range(8)), extracted_a) + + # Request B: 12 tokens, first 8 identical. B1/B2 deduplicated, B3 new terminal. + kv_b = mx.arange(1 * 2 * 12 * 3).reshape(1, 2, 12, 3) + recurrent_b = [mx.full((1, 3, 8), 2.0), mx.full((1, 2, 4, 4), 2.0)] + extracted_b = [ + { + "state": (kv_b, kv_b), + "meta_state": "", + "class_ref": KVCache, + "class_name": "ArraysCache", + }, + { + "state": recurrent_b, + "meta_state": "", + "class_ref": ArraysCache, + "class_name": "ArraysCache", + }, + ] + bt_b = cache.store_cache("req-b", list(range(12)), extracted_b) + + # Reconstruct A: should use A's recurrent state (ones), not B's (twos) + recon_a = cache.reconstruct_cache(bt_a) + assert recon_a is not None + assert recon_a[1].state[0].tolist() == recurrent_a[0].tolist() + + # Reconstruct B: should use B's recurrent state (twos) + recon_b = cache.reconstruct_cache(bt_b) + assert recon_b is not None + assert recon_b[1].state[0].tolist() == recurrent_b[0].tolist() diff --git a/tests/test_reasoning_parser.py b/tests/test_reasoning_parser.py index e2d0184e7..4bcb5ab3f 100644 --- a/tests/test_reasoning_parser.py +++ b/tests/test_reasoning_parser.py @@ -6,6 +6,7 @@ - Parser registry (registration, lookup, listing) - Qwen3 parser (non-streaming and streaming) - DeepSeek-R1 parser (non-streaming and streaming) +- Gemma 4 parser (channel protocol, streaming, channel name stripping) - Edge cases (no tags, partial tags, etc.) """ @@ -28,6 +29,7 @@ def test_list_parsers_includes_builtin(self): parsers = list_parsers() assert "qwen3" in parsers assert "deepseek_r1" in parsers + assert "gemma4" in parsers def test_get_parser_qwen3(self): """Should be able to get Qwen3 parser.""" @@ -920,3 +922,267 @@ def test_constrain_tokens_stripped(self, parser): reasoning, content = parser.extract_reasoning(output) assert "<|constrain|>" not in (content or "") assert "<|channel|>" not in (content or "") + + +class TestGemma4Parser: + """Tests for the Gemma 4 reasoning parser (channel-based protocol).""" + + @pytest.fixture + def parser(self): + """Create a fresh Gemma 4 parser for each test.""" + return get_parser("gemma4")() + + # --- Non-streaming tests --- + + def test_extract_standard_format(self, parser): + """Standard format: <|channel>thought...response.""" + output = ( + "<|channel>thought\nLet me think step by step.\nThe answer is 42." + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Let me think step by step." + assert content == "The answer is 42." + + def test_extract_alternative_format(self, parser): + """Alternative format: <|channel>thought...<|channel>response...""" + output = "<|channel>thought\nAnalyzing the problem.\n<|channel>response\nThe result is 7." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Analyzing the problem." + assert content == "The result is 7." 
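+    # Edge cases below: stripped channel names, missing start/end tags, empty or multiline reasoning, and Unicode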
+ + def test_extract_strips_thought_prefix(self, parser): + """Channel name 'thought' should be stripped from reasoning.""" + output = "<|channel>thought\nActual reasoning hereContent" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Actual reasoning here" + assert "thought" not in reasoning + + def test_extract_no_tags_pure_content(self, parser): + """No channel tags at all should return pure content.""" + output = "Just a regular response without thinking." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_extract_only_start_tag(self, parser): + """Only start tag means incomplete reasoning (no content yet).""" + output = "<|channel>thought\nStill thinking..." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Still thinking..." + assert content is None + + def test_extract_only_end_tag(self, parser): + """Only end tag (think injected in prompt).""" + output = "thought\nImplicit reasoningThe answer" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Implicit reasoning" + assert content == "The answer" + + def test_extract_empty_reasoning(self, parser): + """Empty reasoning should return None.""" + output = "<|channel>thought\nOnly content here." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == "Only content here." + + def test_extract_multiline_reasoning(self, parser): + """Should preserve multiline reasoning content.""" + output = ( + "<|channel>thought\n" + "Step 1: Understand the question.\n" + "Step 2: Analyze the data.\n" + "Step 3: Form conclusion.\n" + "The conclusion is clear." + ) + reasoning, content = parser.extract_reasoning(output) + assert "Step 1" in reasoning + assert "Step 2" in reasoning + assert "Step 3" in reasoning + assert content == "The conclusion is clear." + + def test_extract_unicode_reasoning(self, parser): + """Should handle Unicode in reasoning.""" + output = "<|channel>thought\n日本語テスト 🤔\n答えは42" + reasoning, content = parser.extract_reasoning(output) + assert "日本語テスト" in reasoning + assert "🤔" in reasoning + assert "42" in content + + def test_registry_includes_gemma4(self): + """gemma4 should be in the parser registry.""" + assert "gemma4" in list_parsers() + + # --- Streaming tests --- + + def test_streaming_no_tags_plain_content(self, parser): + """Streaming without any channel tags should return content.""" + parser.reset_state() + result = parser.extract_reasoning_streaming("", "Hello", "Hello") + assert result is not None + assert result.content == "Hello" + assert result.reasoning is None + + def test_streaming_standard_format(self, parser): + """Test streaming through <|channel>thought...content flow.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Let me ", + "think.", + "", + "The ", + "answer.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + # "thought\n" prefix should be stripped + assert "thought" not in full_reasoning or "thought" in "Let me think." + assert "Let me think." in full_reasoning + assert "The answer." 
in full_content + + def test_streaming_alternative_format(self, parser): + """Test streaming with <|channel>response transition.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Analyzing.", + "<|channel>response", + "\n", + "Result: ", + "42", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_content = "".join(content_parts) + assert "Result: 42" in full_content + + def test_streaming_suppresses_channel_names(self, parser): + """Channel names 'thought' and 'response' should not appear in output.""" + parser.reset_state() + + # Simulate realistic Gemma 4 output + tokens = [ + "<|channel>", + "thought", + "\n", + "Real ", + "reasoning.", + "", + "Real ", + "content.", + ] + + accumulated = "" + all_output = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + all_output.append(("r", result.reasoning)) + if result.content: + all_output.append(("c", result.content)) + + # Verify no raw "thought" token leaked as reasoning + reasoning_text = "".join(t for tag, t in all_output if tag == "r") + content_text = "".join(t for tag, t in all_output if tag == "c") + + assert "Real reasoning." in reasoning_text + assert "Real content." in content_text + + def test_streaming_token_by_token(self, parser): + """Test character-by-character streaming (worst case).""" + parser.reset_state() + + output = "<|channel>thought\nStep 1: Think\nStep 2: Analyze\nFinal answer: 42." + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for char in output: + prev = accumulated + accumulated += char + result = parser.extract_reasoning_streaming(prev, accumulated, char) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + assert "Step 1: Think" in full_reasoning + assert "Step 2: Analyze" in full_reasoning + assert "Final answer: 42." 
in full_content + + def test_streaming_long_thinking_no_end_tag(self, parser): + """When model generates long thinking without end tag, all goes to reasoning.""" + parser.reset_state() + + # Simulate model that hits max_tokens before + tokens = [ + "<|channel>", + "thought", + "\n", + "This is a very long ", + "reasoning process ", + "that continues ", + "without ending.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + assert "very long reasoning process" in full_reasoning + assert len(content_parts) == 0 diff --git a/tests/test_server.py b/tests/test_server.py index 9fb86a3e5..c20957211 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Tests for the OpenAI-compatible API server.""" +import json import platform import sys @@ -167,6 +168,41 @@ def test_basic_completion_request(self): assert request.max_tokens is None # uses _default_max_tokens when None +class TestServeCli: + """Test serve CLI argument parsing.""" + + def test_tool_call_parser_accepts_harmony_aliases(self): + """GPT-OSS/Harmony parsers should be selectable from the serve CLI.""" + from vllm_mlx.cli import create_parser + + parser = create_parser() + args = parser.parse_args( + [ + "serve", + "lmstudio-community/gpt-oss-20b-MLX-8bit", + "--enable-auto-tool-choice", + "--tool-call-parser", + "harmony", + ] + ) + + assert args.command == "serve" + assert args.tool_call_parser == "harmony" + assert args.enable_auto_tool_choice is True + + args = parser.parse_args( + [ + "serve", + "lmstudio-community/gpt-oss-20b-MLX-8bit", + "--enable-auto-tool-choice", + "--tool-call-parser", + "gpt-oss", + ] + ) + + assert args.tool_call_parser == "gpt-oss" + + # ============================================================================= # Helper Function Tests # ============================================================================= @@ -304,7 +340,7 @@ def test_rate_limiter_enforces_limit(self): # First 3 requests should be allowed for i in range(3): allowed, retry_after = limiter.is_allowed("client1") - assert allowed is True, f"Request {i+1} should be allowed" + assert allowed is True, f"Request {i + 1} should be allowed" assert retry_after == 0 # 4th request should be blocked @@ -593,9 +629,7 @@ def test_verify_api_key_rejects_invalid(self): # Should raise HTTPException with 401 with pytest.raises(HTTPException) as exc_info: - asyncio.get_event_loop().run_until_complete( - server.verify_api_key(credentials) - ) + asyncio.run(server.verify_api_key(credentials)) assert exc_info.value.status_code == 401 assert "Invalid API key" in str(exc_info.value.detail) @@ -621,9 +655,7 @@ def test_verify_api_key_accepts_valid(self): ) # Should not raise any exception - result = asyncio.get_event_loop().run_until_complete( - server.verify_api_key(credentials) - ) + result = asyncio.run(server.verify_api_key(credentials)) # verify_api_key returns True on success (no exception raised) assert result is True or result is None finally: @@ -677,6 +709,225 @@ def test_rate_limiter_window_cleanup(self): assert allowed is True +class TestStreamChatCompletion: + """Tests for streaming chat completion behavior.""" + + @pytest.mark.anyio + async 
def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch): + """Tool markup after should emit tool_calls chunks.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.reasoning import DeltaMessage + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput(text="", new_text="reasoning", finished=False), + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput( + text="", new_text='{"name":"search"}', finished=False + ), + GenerationOutput( + text="", + new_text="", + finished=True, + finish_reason="stop", + prompt_tokens=7, + completion_tokens=3, + ), + ] + for chunk in chunks: + yield chunk + + class FakeReasoningParser: + def reset_state(self): + self._in_reasoning = False + + def extract_reasoning_streaming( + self, previous_text, current_text, delta_text + ): + if delta_text == "": + self._in_reasoning = True + return None + if delta_text == "": + self._in_reasoning = False + return None + if self._in_reasoning: + return DeltaMessage(reasoning=delta_text) + return DeltaMessage(content=delta_text) + + class FakeToolParser: + def reset(self): + pass + + def extract_tool_calls_streaming( + self, previous_text, current_text, delta_text + ): + if "" in current_text: + return { + "tool_calls": [ + { + "index": 0, + "id": "call_123", + "type": "function", + "function": { + "name": "search", + "arguments": '{"q":"weather"}', + }, + } + ] + } + return None + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", FakeReasoningParser()) + monkeypatch.setattr(server, "_enable_auto_tool_choice", True) + monkeypatch.setattr(server, "_tool_call_parser", "fake") + monkeypatch.setattr(server, "_tool_parser_instance", FakeToolParser()) + + request = ChatCompletionRequest( + model="request-model", + messages=[Message(role="user", content="hi")], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + + tool_payloads = [ + payload + for payload in payloads + if payload["choices"] and payload["choices"][0]["delta"].get("tool_calls") + ] + + assert payloads[0]["choices"][0]["delta"]["role"] == "assistant" + assert payloads[1]["choices"][0]["delta"]["reasoning"] == "reasoning" + assert len(tool_payloads) == 1 + assert ( + tool_payloads[0]["choices"][0]["delta"]["tool_calls"][0]["function"]["name"] + == "search" + ) + assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls" + assert tool_payloads[0]["usage"] == { + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + } + + @pytest.mark.anyio + async def test_reasoning_stream_skips_tool_parser_until_markup_appears( + self, monkeypatch + ): + """Plain post-reasoning content should stream normally on the fast path.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.reasoning import DeltaMessage + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + 
async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput(text="", new_text="reasoning", finished=False), + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput( + text="", + new_text="final answer", + finished=True, + finish_reason="stop", + ), + ] + for chunk in chunks: + yield chunk + + class FakeReasoningParser: + def reset_state(self): + self._in_reasoning = False + + def extract_reasoning_streaming( + self, previous_text, current_text, delta_text + ): + if delta_text == "": + self._in_reasoning = True + return None + if delta_text == "": + self._in_reasoning = False + return None + if self._in_reasoning: + return DeltaMessage(reasoning=delta_text) + return DeltaMessage(content=delta_text) + + class TrackingToolParser: + def __init__(self): + self.calls = [] + + def reset(self): + self.calls.clear() + + def extract_tool_calls_streaming( + self, previous_text, current_text, delta_text + ): + self.calls.append((previous_text, current_text, delta_text)) + return {"content": delta_text} + + tool_parser = TrackingToolParser() + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", FakeReasoningParser()) + monkeypatch.setattr(server, "_enable_auto_tool_choice", True) + monkeypatch.setattr(server, "_tool_call_parser", "fake") + monkeypatch.setattr(server, "_tool_parser_instance", tool_parser) + + request = ChatCompletionRequest( + model="request-model", + messages=[Message(role="user", content="hi")], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + + assert tool_parser.calls == [] + assert payloads[1]["choices"][0]["delta"]["reasoning"] == "reasoning" + assert payloads[2]["choices"][0]["delta"]["content"] == "final answer" + assert payloads[2]["choices"][0]["finish_reason"] == "stop" + + # ============================================================================= # Integration Tests (require running server) # ============================================================================= diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index cce42bfc3..2cf4a6daf 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -6,10 +6,16 @@ import pytest +pytestmark = pytest.mark.anyio + class TestSimpleEngineConcurrency: """Test SimpleEngine lock behavior with concurrent requests.""" + @pytest.fixture + def anyio_backend(self): + return "asyncio" + @pytest.fixture def mock_model(self): """Create a mock model that tracks concurrent calls.""" @@ -36,6 +42,27 @@ def generate_side_effect(**kwargs): return result model.generate = MagicMock(side_effect=generate_side_effect) + + # stream_generate tracks concurrency the same way so tests that + # exercise SimpleEngine.generate() (which is now an accumulator + # over stream_generate) see the same serialization behavior. 
+ def stream_generate_side_effect(**kwargs): + model._concurrent_count += 1 + model._max_concurrent = max(model._max_concurrent, model._concurrent_count) + import time + + time.sleep(0.05) + model._concurrent_count -= 1 + chunk = MagicMock() + chunk.text = "test response" + chunk.tokens = [1, 2, 3] + chunk.finished = True + chunk.finish_reason = "stop" + chunk.prompt_tokens = 3 + chunk.completion_tokens = 3 + yield chunk + + model.stream_generate = MagicMock(side_effect=stream_generate_side_effect) return model @pytest.fixture @@ -65,7 +92,7 @@ def chat_side_effect(**kwargs): model.chat = MagicMock(side_effect=chat_side_effect) return model - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_generate(self, mock_model): """Test that the lock prevents concurrent generate calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -89,7 +116,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model): "The lock is not working correctly." ) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_chat(self, mock_llm_model): """Test that the lock prevents concurrent chat calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -115,7 +142,56 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model): "The lock is not working correctly." ) - @pytest.mark.asyncio + async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model): + """Tool-enabled non-stream chat should use the streaming path.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_stream_chat(*args, **kwargs): + yield MagicMock( + text="partial", + tokens=[1], + prompt_tokens=11, + completion_tokens=1, + finish_reason=None, + finished=False, + ) + yield MagicMock( + text='<|im_end|>{"name":"bash","arguments":{"command":"pwd"}}', + tokens=[7, 8, 9], + prompt_tokens=11, + completion_tokens=4, + finish_reason="stop", + finished=True, + ) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = mock_llm_model + engine._loaded = True + engine.stream_chat = fake_stream_chat # type: ignore[method-assign] + + output = await engine.chat( + messages=[{"role": "user", "content": "run pwd"}], + max_tokens=16, + tools=[ + { + "type": "function", + "function": { + "name": "bash", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + ) + + assert output.text == '{"name":"bash","arguments":{"command":"pwd"}}' + assert output.tokens == [7, 8, 9] + assert output.prompt_tokens == 11 + assert output.completion_tokens == 4 + assert output.finish_reason == "stop" + mock_llm_model.chat.assert_not_called() + + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" from vllm_mlx.engine.simple import SimpleEngine @@ -178,7 +254,7 @@ async def try_stream(): result = await stream_task assert len(result) == 3, f"Expected 3 chunks, got {len(result)}" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_engine_initialization_creates_lock(self): """Test that SimpleEngine creates a lock on initialization.""" from vllm_mlx.engine.simple import SimpleEngine @@ -189,7 +265,7 @@ async def test_engine_initialization_creates_lock(self): assert hasattr(engine, "_generation_lock") assert isinstance(engine._generation_lock, asyncio.Lock) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_requests_complete_in_order(self, mock_model): """Test that 
concurrent requests complete (may be in any order due to lock).""" from vllm_mlx.engine.simple import SimpleEngine @@ -211,3 +287,86 @@ async def test_requests_complete_in_order(self, mock_model): assert len(results) == 3 for result in results: assert result.text == "test response" + + @pytest.mark.asyncio + async def test_generate_accumulates_over_stream_generate(self): + """generate() should iterate stream_generate() and return the last + yielded GenerationOutput, forwarding per-request kwargs (including + SpecPrefill overrides) through so they reach _stream_generate_specprefill. + """ + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.engine.simple import SimpleEngine + + captured_kwargs = {} + + async def fake_stream_generate(**kwargs): + captured_kwargs.update(kwargs) + # First chunk: mid-generation + yield GenerationOutput( + text="partial", + new_text="partial", + tokens=[1, 2], + prompt_tokens=11, + completion_tokens=2, + finished=False, + finish_reason=None, + ) + # Final chunk: finished + yield GenerationOutput( + text="partial final", + new_text=" final", + tokens=[1, 2, 3], + prompt_tokens=11, + completion_tokens=3, + finished=True, + finish_reason="stop", + ) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine.stream_generate = fake_stream_generate # type: ignore[method-assign] + + output = await engine.generate( + prompt="say hi", + max_tokens=16, + temperature=0.6, + top_p=0.95, + specprefill=True, + specprefill_keep_pct=0.2, + ) + + # Accumulator returns the last GenerationOutput's fields + assert output.text == "partial final" + assert output.tokens == [1, 2, 3] + assert output.prompt_tokens == 11 + assert output.completion_tokens == 3 + assert output.finish_reason == "stop" + assert output.finished is True + + # Per-request SpecPrefill overrides reach stream_generate + assert captured_kwargs.get("prompt") == "say hi" + assert captured_kwargs.get("max_tokens") == 16 + assert captured_kwargs.get("specprefill") is True + assert captured_kwargs.get("specprefill_keep_pct") == 0.2 + + @pytest.mark.asyncio + async def test_generate_empty_stream_returns_safe_default(self): + """If stream_generate yields nothing, generate() returns an empty + stop-reason GenerationOutput rather than raising. 
+ """ + from vllm_mlx.engine.simple import SimpleEngine + + async def empty_stream_generate(**kwargs): + return + yield # unreachable; makes this a generator + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine.stream_generate = empty_stream_generate # type: ignore[method-assign] + + output = await engine.generate(prompt="anything", max_tokens=5) + + assert output.text == "" + assert output.finish_reason == "stop" diff --git a/tests/test_simple_engine_cancel_serialization.py b/tests/test_simple_engine_cancel_serialization.py new file mode 100644 index 000000000..28c25868e --- /dev/null +++ b/tests/test_simple_engine_cancel_serialization.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression test for cancellation-safe SimpleEngine serialization.""" + +from __future__ import annotations + +import asyncio +import threading +import unittest +from unittest.mock import MagicMock, patch + + +class SimpleEngineCancelSerializationTests(unittest.IsolatedAsyncioTestCase): + async def test_cancellation_does_not_release_lock_before_worker_finishes(self): + """A cancelled request must not let a second MLX worker overlap.""" + from vllm_mlx.engine.simple import SimpleEngine + + model = MagicMock() + model.tokenizer = MagicMock() + model.tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + model._concurrent_count = 0 + model._max_concurrent = 0 + + first_started = threading.Event() + release_workers = threading.Event() + call_count = 0 + call_lock = threading.Lock() + + def generate_side_effect(**kwargs): + nonlocal call_count + with call_lock: + call_count += 1 + current_call = call_count + model._concurrent_count += 1 + model._max_concurrent = max( + model._max_concurrent, model._concurrent_count + ) + if current_call == 1: + first_started.set() + + release_workers.wait(timeout=1.0) + + with call_lock: + model._concurrent_count -= 1 + + result = MagicMock() + result.text = f"response-{current_call}" + result.tokens = [1, 2, 3] + result.finish_reason = "stop" + return result + + model.generate = MagicMock(side_effect=generate_side_effect) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = model + engine._loaded = True + + task1 = asyncio.create_task(engine.generate(prompt="first", max_tokens=8)) + await asyncio.to_thread(first_started.wait, 1.0) + + task1.cancel() + task2 = asyncio.create_task(engine.generate(prompt="second", max_tokens=8)) + + await asyncio.sleep(0.05) + release_workers.set() + + with self.assertRaises(asyncio.CancelledError): + await task1 + result2 = await task2 + + self.assertEqual(result2.text, "response-2") + self.assertEqual( + model._max_concurrent, + 1, + "cancellation released the generation lock before the first worker finished", + ) + + async def test_specprefill_path_does_not_prelock_serialized_runner(self): + """Specprefill streaming must let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + self.assertFalse(engine._generation_lock.locked()) + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._model = MagicMock() + engine._model.model = MagicMock() + engine._model.tokenizer = MagicMock() + engine._draft_model = MagicMock() + engine._run_blocking_serialized = fake_serialized # 
type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_specprefill( + prompt="hello", + tokens=[1, 2, 3, 4], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0].finished) + self.assertEqual(outputs[0].completion_tokens, 0) + + async def test_text_mtp_path_does_not_prelock_serialized_runner(self): + """Text-only MTP streaming must let _run_blocking_serialized own the lock.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_serialized(func, *args, **kwargs): + self.assertFalse(engine._generation_lock.locked()) + return [] + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=True): + engine = SimpleEngine("test-model") + engine._loaded = True + engine._text_model = MagicMock() + engine._text_model.make_mtp_cache = MagicMock(return_value=[]) + engine._text_tokenizer = MagicMock() + engine._text_tokenizer.apply_chat_template = MagicMock(return_value="hello") + engine._text_tokenizer.bos_token = None + engine._draft_model = None + engine._run_blocking_serialized = fake_serialized # type: ignore[method-assign] + + outputs = [] + async for chunk in engine._stream_generate_text( + messages=[{"role": "user", "content": "hello"}], + max_tokens=4, + temperature=0.7, + top_p=0.9, + ): + outputs.append(chunk) + + self.assertEqual(len(outputs), 1) + self.assertTrue(outputs[0].finished) + self.assertEqual(outputs[0].completion_tokens, 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_specprefill_rotating_cache.py b/tests/test_specprefill_rotating_cache.py new file mode 100644 index 000000000..a944c0ee2 --- /dev/null +++ b/tests/test_specprefill_rotating_cache.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for RotatingKVCache handling in sparse_prefill.""" + +from __future__ import annotations + +import pytest + +try: + import mlx.core as mx + + HAS_MLX = True +except ImportError: + HAS_MLX = False + +pytestmark = pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +class _FakeAttention: + def __init__(self): + self.num_heads = 1 + self.q_proj = lambda x: x + + +class _FakeLayer: + def __init__(self): + self.block_type = "*" + self.mixer = _FakeAttention() + + +class _FakeModel: + def __init__(self): + self.layers = [_FakeLayer()] + self.calls: list[list[int]] = [] + + def __call__(self, x, cache=None): + self.calls.append(x.tolist()) + logits = mx.zeros((1, x.shape[1], 8), dtype=mx.float32) + return logits + + +class RotatingKVCache: + def __init__(self, max_size: int, keep: int = 0): + self.max_size = max_size + self.keep = keep + self.offset = 0 + self.state = mx.array([0], dtype=mx.float32) + + +def _run_sparse_prefill(total_tokens: int, selected_indices: list[int], max_size: int): + from vllm_mlx.specprefill import sparse_prefill + + model = _FakeModel() + tokens = list(range(total_tokens)) + cache = [RotatingKVCache(max_size=max_size, keep=0)] + sparse_prefill( + model, + tokens, + selected_indices, + cache, + step_size=64, + ) + return model.calls + + +def test_sparse_prefill_does_not_expand_tail_when_prompt_fits_window(): + calls = _run_sparse_prefill( + total_tokens=6, + selected_indices=[0, 2, 4], + max_size=8, + ) + + flattened = [token for chunk in calls for row in chunk for token in row] + assert flattened == [0, 2, 4] + + +def test_sparse_prefill_expands_tail_when_prompt_exceeds_window(): + calls = _run_sparse_prefill( + total_tokens=10, + selected_indices=[0, 
2], + max_size=8, + ) + + flattened = [token for chunk in calls for row in chunk for token in row] + assert flattened == [0, 2, 3, 4, 5, 6, 7, 8, 9] diff --git a/tests/test_streaming_latency.py b/tests/test_streaming_latency.py index cae95f5fb..116ee9dfa 100644 --- a/tests/test_streaming_latency.py +++ b/tests/test_streaming_latency.py @@ -206,7 +206,7 @@ async def run_benchmark( print(f"Throughput: {throughput:.1f} tokens/sec") -@pytest.mark.asyncio +@pytest.mark.anyio async def test_output_collector(): """Unit test for RequestOutputCollector.""" import sys diff --git a/tests/test_tokenizer_utils.py b/tests/test_tokenizer_utils.py new file mode 100644 index 000000000..0b046e0c7 --- /dev/null +++ b/tests/test_tokenizer_utils.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for tokenizer utility helpers.""" + +import types +from unittest.mock import patch + + +def test_load_model_with_fallback_returns_successful_load_result(): + from vllm_mlx.utils.tokenizer import load_model_with_fallback + + fake_model = object() + fake_tokenizer = object() + fake_mlx_lm = types.SimpleNamespace( + load=lambda *args, **kwargs: (fake_model, fake_tokenizer) + ) + + with ( + patch("vllm_mlx.utils.tokenizer._needs_tokenizer_fallback", return_value=False), + patch("vllm_mlx.utils.tokenizer._needs_strict_false", return_value=False), + patch("vllm_mlx.utils.tokenizer._try_inject_mtp_post_load"), + patch.dict("sys.modules", {"mlx_lm": fake_mlx_lm}), + ): + model, tokenizer = load_model_with_fallback("mlx-community/Qwen3.5-4B") + + assert model is fake_model + assert tokenizer is fake_tokenizer + + +def test_load_model_with_fallback_uses_tokenizer_fallback_for_tokenizer_errors(): + from vllm_mlx.utils.tokenizer import load_model_with_fallback + + fake_model = object() + fake_tokenizer = object() + + def _raise(*args, **kwargs): + raise ValueError("Tokenizer class Foo does not exist") + + fake_mlx_lm = types.SimpleNamespace(load=_raise) + + with ( + patch("vllm_mlx.utils.tokenizer._needs_tokenizer_fallback", return_value=False), + patch("vllm_mlx.utils.tokenizer._needs_strict_false", return_value=False), + patch("vllm_mlx.utils.tokenizer._try_inject_mtp_post_load"), + patch( + "vllm_mlx.utils.tokenizer._load_with_tokenizer_fallback", + return_value=(fake_model, fake_tokenizer), + ) as fallback, + patch.dict("sys.modules", {"mlx_lm": fake_mlx_lm}), + ): + model, tokenizer = load_model_with_fallback("example/model") + + fallback.assert_called_once_with("example/model") + assert model is fake_model + assert tokenizer is fake_tokenizer diff --git a/tests/test_tool_choice_none.py b/tests/test_tool_choice_none.py new file mode 100644 index 000000000..d4af223fe --- /dev/null +++ b/tests/test_tool_choice_none.py @@ -0,0 +1,65 @@ +"""Tests for tool_choice='none' handling.""" + + +class TestToolChoiceNoneParserSuppression: + """Verify tool call parsing is suppressed when tool_choice='none'.""" + + def test_parse_tool_calls_skipped_when_tool_choice_none(self): + """_parse_tool_calls_with_parser should return no tools when tool_choice='none'.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + # Text that looks like a tool call + text = '{"name": "get_weather", "arguments": {"city": "London"}}' + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "Hello"}], + tool_choice="none", + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + # With tool_choice="none", parser should be 
suppressed + assert tool_calls is None + assert cleaned == text # text returned unchanged + + def test_parse_tool_calls_works_when_tool_choice_auto(self): + """Tool parsing should work normally when tool_choice is not 'none'.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + text = "Hello, how can I help?" + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "Hello"}], + tool_choice="auto", + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + # No tool markup in text, so no tools found — but parser was NOT skipped + assert tool_calls is None + + def test_parse_tool_calls_works_when_tool_choice_absent(self): + """Tool parsing should work when tool_choice is not set.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + text = "Hello, how can I help?" + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "Hello"}], + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + assert tool_calls is None + + def test_tool_markup_ignored_when_tool_choice_none(self): + """Even Qwen bracket-style tool calls should be suppressed.""" + from vllm_mlx.api.models import ChatCompletionRequest + from vllm_mlx.server import _parse_tool_calls_with_parser + + text = '[Calling tool: get_weather({"city": "London"})]' + request = ChatCompletionRequest( + model="test", + messages=[{"role": "user", "content": "weather?"}], + tool_choice="none", + ) + cleaned, tool_calls = _parse_tool_calls_with_parser(text, request) + assert tool_calls is None + assert cleaned == text diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py index dfe2bb6a1..4f3c287d1 100644 --- a/tests/test_tool_parsers.py +++ b/tests/test_tool_parsers.py @@ -9,6 +9,7 @@ AutoToolParser, DeepSeekToolParser, FunctionaryToolParser, + Gemma4ToolParser, GraniteToolParser, HermesToolParser, KimiToolParser, @@ -39,6 +40,7 @@ def test_list_registered(self): "nemotron", "xlam", "functionary", + "gemma4", ] for p in expected: assert p in parsers, f"Parser '{p}' not found" @@ -68,6 +70,7 @@ def test_get_tool_parser_by_name(self): ("meetkai", FunctionaryToolParser), ("hermes", HermesToolParser), ("nous", HermesToolParser), + ("gemma4", Gemma4ToolParser), ] for name, expected_cls in test_cases: parser_cls = ToolParserManager.get_tool_parser(name) @@ -1160,3 +1163,177 @@ def test_streaming_bare_multi_function_blocks(self): assert len(emitted_calls) == 2 assert emitted_calls[0]["function"]["name"] == "func1" assert emitted_calls[1]["function"]["name"] == "func2" + + +class TestQwenFunctionFormat: + """Test Qwen parser's format support.""" + + @pytest.fixture + def parser(self): + return QwenToolParser() + + def test_function_format_with_parameters(self, parser): + """Test value.""" + text = "Prague" + result = parser.extract_tool_calls(text) + assert result.tools_called + assert result.tool_calls[0]["name"] == "get_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["city"] == "Prague" + + def test_function_format_with_json(self, parser): + """Test {"key": "val"}.""" + text = '{"city": "Prague"}' + result = parser.extract_tool_calls(text) + assert result.tools_called + assert result.tool_calls[0]["name"] == "get_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args["city"] == "Prague" + + def test_function_format_multiple(self, parser): + """Test 
multiple blocks.""" + text = ( + '{"path": "/a.py"}' + '{"path": "/b.py", "content": "hello"}' + ) + result = parser.extract_tool_calls(text) + assert result.tools_called + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["name"] == "read_file" + assert result.tool_calls[1]["name"] == "write_file" + + def test_function_format_with_think_tags(self, parser): + """Test with think tags.""" + text = ( + "I need to check the weather.\n" + '{"city": "Prague"}' + ) + result = parser.extract_tool_calls(text) + assert result.tools_called + assert result.tool_calls[0]["name"] == "get_weather" + + +class TestQwenStreamingBuffering: + """Test Qwen parser streaming with partial-marker buffering.""" + + @pytest.fixture + def parser(self): + return QwenToolParser() + + def test_streaming_function_format_complete(self, parser): + """Test streaming with ... format.""" + chunks = [ + "", + "Prague", + "", + ] + accumulated = "" + tool_calls_found = False + for chunk in chunks: + prev = accumulated + accumulated += chunk + r = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=accumulated, + delta_text=chunk, + ) + if r is not None and "tool_calls" in r: + tool_calls_found = True + assert r["tool_calls"][0]["function"]["name"] == "get_weather" + break + assert tool_calls_found + + def test_streaming_partial_marker_buffered(self, parser): + """Test that partial '" — not a tool marker + r = parser.extract_tool_calls_streaming( + previous_text="Hello<", + current_text="Hello
", + delta_text="div>", + ) + assert r is not None + assert "content" in r + assert "<" in r["content"] + assert "div>" in r["content"] + + def test_streaming_multiple_function_blocks(self, parser): + """Test streaming with multiple {"a": 1}', + "\n", + "", + "2", + "", + ] + accumulated = "" + emitted_calls = [] + for chunk in chunks: + prev = accumulated + accumulated += chunk + r = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=accumulated, + delta_text=chunk, + ) + if r is not None and "tool_calls" in r: + emitted_calls.extend(r["tool_calls"]) + assert len(emitted_calls) == 2 + assert emitted_calls[0]["function"]["name"] == "func1" + assert emitted_calls[1]["function"]["name"] == "func2" diff --git a/vllm_mlx/api/anthropic_adapter.py b/vllm_mlx/api/anthropic_adapter.py index dbb94200f..62c6757b5 100644 --- a/vllm_mlx/api/anthropic_adapter.py +++ b/vllm_mlx/api/anthropic_adapter.py @@ -9,6 +9,7 @@ """ import json +import re import uuid from .anthropic_models import ( @@ -60,6 +61,10 @@ def anthropic_to_openai(request: AnthropicRequest) -> ChatCompletionRequest: system_text = "\n".join(parts) else: system_text = str(request.system) + # Strip per-request billing/tracking headers injected by some + # clients (e.g. Claude Code). These contain a per-request hash + # that prevents prefix-cache reuse across turn boundaries. + system_text = re.sub(r"x-anthropic-billing-header:[^\n]*\n?", "", system_text) messages.append(Message(role="system", content=system_text)) # Convert each message diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py index a5bc6f776..e8854a5e6 100644 --- a/vllm_mlx/api/anthropic_models.py +++ b/vllm_mlx/api/anthropic_models.py @@ -84,8 +84,10 @@ class AnthropicUsage(BaseModel): class AnthropicResponseContentBlock(BaseModel): """A content block in the Anthropic response.""" - type: str # "text" or "tool_use" + type: str # "text", "thinking", or "tool_use" text: str | None = None + # thinking block + thinking: str | None = None # tool_use fields id: str | None = None name: str | None = None diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index f7bcaaaa5..8af8c9dca 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -159,6 +159,10 @@ class ChatCompletionRequest(BaseModel): messages: list[Message] temperature: float | None = None top_p: float | None = None + top_k: int | None = None + min_p: float | None = None + presence_penalty: float | None = None + repetition_penalty: float | None = None max_tokens: int | None = None stream: bool = False stream_options: StreamOptions | None = ( @@ -175,12 +179,16 @@ class ChatCompletionRequest(BaseModel): # MLLM-specific parameters video_fps: float | None = None video_max_frames: int | None = None + # Sampling penalties + repetition_penalty: float | None = None # mlx-lm style (>1.0 penalizes) # Request timeout in seconds (None = use server default) timeout: float | None = None # SpecPrefill: per-request enable/disable (None = server decides) specprefill: bool | None = None # SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default) specprefill_keep_pct: float | None = None + # Enable/disable thinking mode (None = server default, typically True) + enable_thinking: bool | None = None class AssistantMessage(BaseModel): @@ -239,11 +247,21 @@ class CompletionRequest(BaseModel): prompt: str | list[str] temperature: float | None = None top_p: float | None = None + top_k: int | None = None + min_p: float | None = None + presence_penalty: float | 
None = None + repetition_penalty: float | None = None max_tokens: int | None = None stream: bool = False stop: list[str] | None = None + # Sampling penalties + repetition_penalty: float | None = None # mlx-lm style (>1.0 penalizes) # Request timeout in seconds (None = use server default) timeout: float | None = None + # SpecPrefill: per-request enable/disable (None = server decides) + specprefill: bool | None = None + # SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default) + specprefill_keep_pct: float | None = None class CompletionChoice(BaseModel): diff --git a/vllm_mlx/api/tool_calling.py b/vllm_mlx/api/tool_calling.py index 1443c1674..364b65993 100644 --- a/vllm_mlx/api/tool_calling.py +++ b/vllm_mlx/api/tool_calling.py @@ -89,6 +89,7 @@ def parse_tool_calls( Parse tool calls from model output. Supports multiple formats: + - MiniMax: v - Qwen3 bracket: [Calling tool: function_name({"arg": "value"})] - Qwen: {"name": "...", "arguments": {...}} - Llama: {"arg": "value"} @@ -106,6 +107,47 @@ def parse_tool_calls( tool_calls = [] cleaned_text = text + # Pattern for MiniMax-style: v + minimax_pattern = r"\s*(.*?)\s*" + minimax_matches = re.findall(minimax_pattern, text, re.DOTALL) + + for invoke_block in minimax_matches: + # Parse blocks within the tool_call + invoke_pattern = r'(.*?)' + invoke_matches = re.findall(invoke_pattern, invoke_block, re.DOTALL) + + for name, params_block in invoke_matches: + # Parse value pairs + param_pattern = r'\s*(.*?)\s*' + params = re.findall(param_pattern, params_block, re.DOTALL) + arguments = {} + for p_name, p_value in params: + # Try to parse value as JSON (for nested objects/arrays/numbers) + try: + arguments[p_name] = json.loads(p_value) + except (json.JSONDecodeError, ValueError): + arguments[p_name] = p_value + + tool_calls.append( + ToolCall( + id=f"call_{uuid.uuid4().hex[:8]}", + type="function", + function=FunctionCall( + name=name.strip(), + arguments=json.dumps(arguments), + ), + ) + ) + + # Remove MiniMax tool call tags from cleaned text + if minimax_matches: + cleaned_text = re.sub( + r"\s*.*?\s*", + "", + cleaned_text, + flags=re.DOTALL, + ).strip() + # Pattern for Qwen3 bracket-style: [Calling tool: function_name({...})] bracket_pattern = r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]" bracket_matches = re.findall(bracket_pattern, text, re.DOTALL) diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 9fdbfef13..6218dce7d 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -20,7 +20,9 @@ r"<\|im_end\|>|<\|im_start\|>|<\|endoftext\|>|" r"<\|end\|>|<\|eot_id\|>|<\|start_header_id\|>|<\|end_header_id\|>|" r"<\|channel\|>|<\|message\|>|<\|start\|>|<\|return\|>|<\|call\|>|<\|constrain\|>|" - r"|||\[PAD\]|\[SEP\]|\[CLS\]" + r"|||\[PAD\]|\[SEP\]|\[CLS\]|" + r"\[e~\[|\]~b\][a-z]*|\]~!b\[|" + r"|" ) @@ -121,6 +123,7 @@ def clean_output_text(text: str) -> str: ("", ""), ("", ""), (""), + ("<|tool_call>", ""), ("[TOOL_CALL]", "[/TOOL_CALL]"), ("[Calling tool", "]\n"), # Qwen3 bracket-style: [Calling tool: func({...})]\n ] @@ -339,6 +342,8 @@ def flush(self) -> list[tuple[str, str]]: "PaliGemma", # PaliGemma "gemma-3", "gemma3", # Gemma 3 (multimodal) + "gemma-4", + "gemma4", # Gemma 4 (multimodal: vision + audio) "medgemma", "MedGemma", # MedGemma (medical multimodal with SigLIP vision encoder) "pixtral", @@ -353,6 +358,8 @@ def flush(self) -> list[tuple[str, str]]: "InternVL", # InternVL "deepseek-vl", "DeepSeek-VL", # DeepSeek-VL + "Qwen3.5-", + "qwen3_5", # Qwen3.5 MoE (natively multimodal, hybrid 
ArraysCache+KVCache) ] diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 8a90bc9be..07dd17fe1 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -37,6 +37,13 @@ def serve_command(args): print("Example: --enable-auto-tool-choice --tool-call-parser mistral") sys.exit(1) + # Validate gpu-memory-utilization range + if not (0.0 < args.gpu_memory_utilization <= 1.0): + print( + "Error: --gpu-memory-utilization must be between 0.0 (exclusive) and 1.0 (inclusive)" + ) + sys.exit(1) + # Configure server security settings server._api_key = args.api_key server._default_timeout = args.timeout @@ -105,6 +112,21 @@ def serve_command(args): print(" Reasoning: Use --reasoning-parser to enable") print("=" * 60) + # Pre-download model with retry/timeout + from .api.utils import is_mllm_model + from .utils.download import DownloadConfig, ensure_model_downloaded + + download_config = DownloadConfig( + download_timeout=args.download_timeout, + max_retries=args.download_retries, + offline=getattr(args, "offline", False), + ) + ensure_model_downloaded( + args.model, + config=download_config, + is_mllm=is_mllm_model(args.model), + ) + print(f"Loading model: {args.model}") print(f"Default max tokens: {args.max_tokens}") @@ -150,6 +172,9 @@ def serve_command(args): kv_cache_quantization_bits=args.kv_cache_quantization_bits, kv_cache_quantization_group_size=args.kv_cache_quantization_group_size, kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens, + mllm_prefill_step_size=( + args.mllm_prefill_step_size if args.mllm_prefill_step_size > 0 else None + ), ) print("Mode: Continuous batching (for multiple concurrent users)") @@ -196,7 +221,8 @@ def serve_command(args): scheduler_config=scheduler_config, stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, - force_mllm=args.mllm, + force_mllm=getattr(args, "mllm", False), + gpu_memory_utilization=args.gpu_memory_utilization, served_model_name=args.served_model_name, mtp=args.enable_mtp, prefill_step_size=args.prefill_step_size, @@ -211,6 +237,23 @@ def serve_command(args): uvicorn.run(app, host=args.host, port=args.port, log_level="info") +def download_command(args): + """Download a model to local cache without starting a server.""" + from .utils.download import DownloadConfig, ensure_model_downloaded + + config = DownloadConfig( + download_timeout=args.timeout, + max_retries=args.retries, + ) + print(f"Downloading model: {args.model}") + path = ensure_model_downloaded( + args.model, + config=config, + is_mllm=args.mllm, + ) + print(f"Model ready at: {path}") + + def bench_command(args): """Run benchmark.""" import asyncio @@ -249,6 +292,7 @@ async def run_benchmark(): kv_cache_quantization_group_size=args.kv_cache_quantization_group_size, kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens, ) + engine_config = EngineConfig( model_name=args.model, scheduler_config=scheduler_config, @@ -593,7 +637,8 @@ def bench_kv_cache_command(args): ) -def main(): +def create_parser() -> argparse.ArgumentParser: + """Build the top-level CLI parser.""" parser = argparse.ArgumentParser( description="vllm-mlx: Apple Silicon MLX backend for vLLM", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -627,6 +672,12 @@ def main(): serve_parser.add_argument( "--completion-batch-size", type=int, default=32, help="Completion batch size" ) + serve_parser.add_argument( + "--mllm-prefill-step-size", + type=int, + default=0, + help="Override MLLM prefill-step guard (0=use MLLM default: 1024)", + ) 
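As a sketch of an alternative to the in-command range check added in serve_command() above, the same (0.0, 1.0] constraint on --gpu-memory-utilization could be enforced at parse time with a custom argparse type. The helper name below is illustrative only and not part of this patch.

    import argparse

    def _memory_fraction(value: str) -> float:
        """Parse a fraction and reject values outside (0.0, 1.0]."""
        fraction = float(value)
        if not (0.0 < fraction <= 1.0):
            raise argparse.ArgumentTypeError(
                f"{value!r} must be between 0.0 (exclusive) and 1.0 (inclusive)"
            )
        return fraction

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gpu-memory-utilization", type=_memory_fraction, default=0.90
    )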
serve_parser.add_argument( "--enable-prefix-cache", action="store_true", @@ -704,6 +755,14 @@ def main(): action="store_true", help="Enable continuous batching for multiple concurrent users (slower for single user)", ) + serve_parser.add_argument( + "--gpu-memory-utilization", + type=float, + default=0.90, + help="Fraction of device memory for Metal allocation limit and emergency " + "cache clear threshold (0.0-1.0, default: 0.90). Increase to 0.95 for " + "large models (200GB+) that need more memory headroom.", + ) # Paged cache options (experimental) serve_parser.add_argument( "--use-paged-cache", @@ -832,18 +891,23 @@ def main(): "qwen3_coder", "llama", "hermes", + "harmony", + "gpt-oss", "deepseek", "kimi", "granite", "nemotron", "xlam", "functionary", + "gemma4", "glm47", + "minimax", ], help=( "Select the tool call parser for the model. Options: " "auto (auto-detect), mistral, qwen, qwen3_coder, llama, hermes, " - "deepseek, kimi, granite, nemotron, xlam, functionary, glm47. " + "harmony, gpt-oss, deepseek, gemma4, kimi, granite, nemotron, " + "xlam, functionary, glm47, minimax. " "Required for --enable-auto-tool-choice." ), ) @@ -888,6 +952,24 @@ def main(): default=None, help="Pre-load an embedding model at startup (e.g. mlx-community/embeddinggemma-300m-6bit)", ) + # Download options + serve_parser.add_argument( + "--download-timeout", + type=int, + default=300, + help="Per-file download timeout in seconds (default: 300)", + ) + serve_parser.add_argument( + "--download-retries", + type=int, + default=3, + help="Number of download retry attempts (default: 3)", + ) + serve_parser.add_argument( + "--offline", + action="store_true", + help="Offline mode — only use locally cached models", + ) # Bench command bench_parser = subparsers.add_parser("bench", help="Run benchmark") bench_parser.add_argument("model", type=str, help="Model to benchmark") @@ -1023,6 +1105,34 @@ def main(): help="Quantization group size (default: 64)", ) + # Download command + download_parser = subparsers.add_parser( + "download", help="Download a model to local cache without starting a server" + ) + download_parser.add_argument("model", type=str, help="Model to download") + download_parser.add_argument( + "--timeout", + type=int, + default=300, + help="Per-file download timeout in seconds (default: 300)", + ) + download_parser.add_argument( + "--retries", + type=int, + default=3, + help="Number of retry attempts (default: 3)", + ) + download_parser.add_argument( + "--mllm", + action="store_true", + help="Download as multimodal model (broader file patterns)", + ) + + return parser + + +def main(): + parser = create_parser() args = parser.parse_args() if args.command == "serve": @@ -1033,6 +1143,8 @@ def main(): bench_detok_command(args) elif args.command == "bench-kv-cache": bench_kv_cache_command(args) + elif args.command == "download": + download_command(args) else: parser.print_help() sys.exit(1) diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 0f0f8f0f1..cb9c8aad8 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -137,6 +137,7 @@ def __init__( scheduler_config: Any | None = None, stream_interval: int = 1, force_mllm: bool = False, + gpu_memory_utilization: float = 0.90, ): """ Initialize the batched engine. 
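A minimal sketch of how the new gpu_memory_utilization fraction is turned into a Metal allocation limit (it mirrors the _start_llm hunk later in this file and assumes only the mx.device_info() / mx.set_memory_limit() calls this patch already uses):

    import mlx.core as mx

    def apply_memory_fraction(gpu_memory_utilization: float = 0.90) -> int:
        """Cap Metal allocations at a fraction of the recommended working set."""
        info = mx.device_info()
        max_recommended = info.get(
            "max_recommended_working_set_size", info.get("memory_size", 0)
        )
        soft_limit = int(max_recommended * gpu_memory_utilization)
        if soft_limit > 0:
            mx.set_memory_limit(soft_limit)
        return soft_limit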
@@ -147,11 +148,14 @@ def __init__( scheduler_config: Optional scheduler configuration stream_interval: Tokens to batch before streaming (1=every token) force_mllm: Force loading as MLLM even if not auto-detected + gpu_memory_utilization: Fraction of device memory for Metal allocation + limit and emergency threshold (0.0-1.0, default 0.90) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._scheduler_config = scheduler_config self._stream_interval = stream_interval + self._gpu_memory_utilization = gpu_memory_utilization self._is_mllm = force_mllm or is_mllm_model(model_name) self._model = None @@ -207,6 +211,10 @@ async def _start_mllm(self) -> None: self._model = self._mllm_instance.model self._processor = self._mllm_instance.processor + # Inject MTP support if enabled + if self._scheduler_config and self._scheduler_config.enable_mtp: + self._inject_mtp_mllm() + # Create MLLM scheduler config with batch generator support if self._scheduler_config and hasattr(self._scheduler_config, "max_num_seqs"): max_num_seqs = self._scheduler_config.max_num_seqs @@ -219,12 +227,38 @@ async def _start_mllm(self) -> None: self._scheduler_config, "completion_batch_size", 16 ) + cache_memory_mb = getattr(self._scheduler_config, "cache_memory_mb", None) + enable_mtp = ( + self._scheduler_config.enable_mtp if self._scheduler_config else False + ) + mtp_num_draft = getattr(self._scheduler_config, "mtp_num_draft_tokens", 1) + kv_quant = getattr(self._scheduler_config, "kv_cache_quantization", False) + kv_bits = getattr(self._scheduler_config, "kv_cache_quantization_bits", 8) + kv_group_size = getattr( + self._scheduler_config, "kv_cache_quantization_group_size", 64 + ) + + # Forward MLLM prefill-step override only when explicitly configured. + # This keeps default behavior unchanged for MLLM (1024) unless set. 
+ prefill_step_size = getattr( + self._scheduler_config, "mllm_prefill_step_size", None + ) + mllm_extra = {} + if prefill_step_size is not None: + mllm_extra["prefill_step_size"] = prefill_step_size mllm_config = MLLMSchedulerConfig( max_num_seqs=max_num_seqs, prefill_batch_size=prefill_batch_size, completion_batch_size=completion_batch_size, enable_vision_cache=True, vision_cache_size=100, + cache_memory_mb=cache_memory_mb, + enable_mtp=enable_mtp, + mtp_num_draft_tokens=mtp_num_draft, + kv_cache_quantization=kv_quant, + kv_cache_quantization_bits=kv_bits, + kv_cache_quantization_group_size=kv_group_size, + **mllm_extra, ) # Create and start MLLM scheduler @@ -238,9 +272,58 @@ async def _start_mllm(self) -> None: logger.info( f"MLLM Scheduler started with continuous batching: " f"max_num_seqs={max_num_seqs}, prefill_batch={prefill_batch_size}, " - f"completion_batch={completion_batch_size}" + f"completion_batch={completion_batch_size}, " + f"prefill_step_size={mllm_config.prefill_step_size}" ) + def _inject_mtp_mllm(self) -> None: + """Inject MTP weights into the MLLM model's language_model.""" + import json + from pathlib import Path + + from mlx_lm.utils import _download + + model = self._model + model_path = Path(_download(self._model_name)) + config_path = model_path / "config.json" + if not config_path.exists(): + logger.warning("[MTP-MLLM] No config.json found, skipping MTP") + return + + with open(config_path) as f: + config = json.load(f) + + text_config = config.get("text_config", config) + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + if num_mtp == 0: + num_mtp = text_config.get( + "num_nextn_predict_layers", + config.get("num_nextn_predict_layers", 0), + ) + if num_mtp == 0: + logger.info("[MTP-MLLM] No MTP layers in config, skipping") + return + + # Navigate to text model + text_model = model + if hasattr(model, "language_model"): + text_model = model.language_model + if getattr(text_model, "mtp", None) is not None: + logger.info("[MTP-MLLM] Model already has MTP, skipping injection") + return + + model_type = text_config.get("model_type", config.get("model_type", "")) + if "qwen3_5" in model_type: + from ..patches.qwen3_5_mtp import inject_mtp_support + + ok = inject_mtp_support(model, model_path, config) + if ok: + logger.info("[MTP-MLLM] Qwen3.5 MTP injected successfully") + else: + logger.warning("[MTP-MLLM] Qwen3.5 MTP injection failed") + else: + logger.info(f"[MTP-MLLM] MTP not supported for model_type={model_type}") + async def _start_llm(self) -> None: """Start the LLM engine with AsyncEngineCore.""" from ..engine_core import AsyncEngineCore, EngineConfig @@ -261,9 +344,10 @@ async def _start_llm(self) -> None: # Validate MTP support if enabled if self._scheduler_config and self._scheduler_config.enable_mtp: + from ..patches.qwen3_5_mtp import validate_mtp_support as validate_35 from ..patches.qwen3_next_mtp import validate_mtp_support - if validate_mtp_support(self._model): + if validate_mtp_support(self._model) or validate_35(self._model): logger.info("[MTP] Model validated for MTP speculative decoding") else: logger.warning( @@ -283,13 +367,14 @@ async def _start_llm(self) -> None: device_info.get("memory_size", 0), ) if max_recommended > 0: - soft_limit = int(max_recommended * 0.90) + soft_limit = int(max_recommended * self._gpu_memory_utilization) mx.set_memory_limit(soft_limit) mx.set_cache_limit(32 * 1024 * 1024 * 1024) # 32GB + pct = self._gpu_memory_utilization * 100 logger.info( f"Metal memory limits set: " f"allocation_limit={soft_limit / 
1e9:.1f}GB " - f"(90% of {max_recommended / 1e9:.1f}GB), " + f"({pct:.0f}% of {max_recommended / 1e9:.1f}GB), " f"cache_limit=32GB" ) except Exception as e: @@ -301,6 +386,7 @@ async def _start_llm(self) -> None: model_name=self._model_name, scheduler_config=scheduler_config, stream_interval=self._stream_interval, + gpu_memory_utilization=self._gpu_memory_utilization, ) # Create async engine @@ -336,6 +422,7 @@ def _apply_chat_template( tools: list[dict] | None = None, num_images: int = 0, chat_template_kwargs: dict[str, Any] | None = None, + enable_thinking: bool | None = None, ) -> str: """Apply chat template to messages. @@ -364,9 +451,13 @@ def _apply_chat_template( if self._is_mllm and num_images > 0: messages = self._prepare_mllm_messages(messages) + # Per-request enable_thinking override; default: True unless coder model. + if enable_thinking is None: + enable_thinking = "coder" not in self._model_name.lower() template_kwargs = { "tokenize": False, "add_generation_prompt": True, + "enable_thinking": enable_thinking, } if chat_template_kwargs: template_kwargs.update(chat_template_kwargs) @@ -380,7 +471,7 @@ def _apply_chat_template( except TypeError as e: # Some templates don't accept extra kwargs; retry without them. logger.debug(f"Chat template TypeError, retrying without extras: {e}") - for key in ["tools", *(chat_template_kwargs or {}).keys()]: + for key in ["tools", "enable_thinking", *(chat_template_kwargs or {}).keys()]: template_kwargs.pop(key, None) return template_applicator.apply_chat_template( messages, **template_kwargs @@ -466,10 +557,15 @@ async def generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), ) return GenerationOutput( text=clean_output_text(output.output_text), + tokens=output.output_token_ids, prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens, finish_reason=output.finish_reason, @@ -482,6 +578,10 @@ async def generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), stop=stop or [], ) @@ -494,6 +594,7 @@ async def generate( return GenerationOutput( text=text, + tokens=output.output_token_ids, prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens, finish_reason=output.finish_reason, @@ -538,6 +639,10 @@ async def stream_generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), ) async for output in self._mllm_scheduler.stream_outputs(request_id): @@ -558,6 +663,10 @@ async def stream_generate( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), stop=stop or [], ) @@ -624,12 +733,16 @@ async def chat( template_tools = convert_tools_for_template(tools) if tools else None chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) + # Per-request enable_thinking override + enable_thinking = kwargs.pop("enable_thinking", None) + # 
Apply chat template prompt = self._apply_chat_template( messages, template_tools, num_images=len(all_images), chat_template_kwargs=chat_template_kwargs, + enable_thinking=enable_thinking, ) return await self.generate( @@ -748,12 +861,16 @@ async def stream_chat( template_tools = convert_tools_for_template(tools) if tools else None chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) + # Per-request enable_thinking override + enable_thinking = kwargs.pop("enable_thinking", None) + # Apply chat template prompt = self._apply_chat_template( messages, template_tools, num_images=len(all_images), chat_template_kwargs=chat_template_kwargs, + enable_thinking=enable_thinking, ) # Compute prefix boundary for cache @@ -789,14 +906,27 @@ def get_stats(self) -> dict[str, Any]: if self._mllm_scheduler: mllm_stats = self._mllm_scheduler.get_stats() stats["mllm_scheduler"] = mllm_stats - # Promote Metal memory stats to top-level for /v1/status + # Promote stats to top-level for /v1/status and monitoring for key in ( + "running", + "num_running", + "num_waiting", + "num_requests_processed", + "total_prompt_tokens", + "total_completion_tokens", "metal_active_memory_gb", "metal_peak_memory_gb", "metal_cache_memory_gb", + "memory_aware_cache", + "paged_cache", + "prefix_cache", + "requests", ): if key in mllm_stats: stats[key] = mllm_stats[key] + # MLLM engine is always "running" once loaded + if "running" not in stats: + stats["running"] = self._loaded elif self._engine: stats.update(self._engine.get_stats()) @@ -804,20 +934,28 @@ def get_stats(self) -> dict[str, Any]: def get_cache_stats(self) -> dict[str, Any] | None: """Get cache statistics.""" - if self._mllm_scheduler and self._mllm_scheduler.vision_cache: - return self._mllm_scheduler.vision_cache.get_stats() + if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + return self._mllm_scheduler.batch_generator.get_vision_cache_stats() elif self._engine: return self._engine.get_cache_stats() return None def save_cache_to_disk(self, cache_dir: str) -> bool: """Save prefix cache to disk for persistence across restarts.""" + if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + pc = self._mllm_scheduler.batch_generator.prefix_cache + if pc is not None: + return pc.save_to_disk(cache_dir) if self._engine: return self._engine.save_cache_to_disk(cache_dir) return False def load_cache_from_disk(self, cache_dir: str) -> int: """Load prefix cache from disk. Returns number of entries loaded.""" + if self._mllm_scheduler and self._mllm_scheduler.batch_generator: + pc = self._mllm_scheduler.batch_generator.prefix_cache + if pc is not None: + return pc.load_from_disk(cache_dir) if self._engine: return self._engine.load_cache_from_disk(cache_dir) return 0 diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 768a0f4ad..681118998 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -226,6 +226,24 @@ async def stop(self) -> None: self._system_kv_token_count = 0 logger.info("SimpleEngine stopped") + async def _run_blocking_serialized(self, func, /, *args, **kwargs): + """Run a blocking MLX operation under the generation lock. + + Cancellation must not release the async lock before the worker thread + finishes, or a follow-up request can enter MLX/Metal concurrently and + corrupt the command-buffer state. 
+ """ + async with self._generation_lock: + task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs)) + try: + return await asyncio.shield(task) + except asyncio.CancelledError: + try: + await task + except BaseException: + pass + raise + async def generate( self, prompt: str, @@ -238,13 +256,27 @@ async def generate( """ Generate a complete response (non-streaming). + Thin accumulator over stream_generate(). stream_generate() is the + only code path that consumes per-request SpecPrefill overrides + (`specprefill`, `specprefill_keep_pct`) and routes through + _stream_generate_specprefill() when engaged. The prior direct + self._model.generate() path silently dropped those overrides for + non-streaming /v1/completions callers, so extra_body.specprefill + was advertised by the server but had no effect on this route. + + By iterating stream_generate() and returning the last + GenerationOutput, non-streaming clients get the same SpecPrefill + engagement, accurate prompt_tokens reporting, and per-request + override support as streaming clients. + Args: prompt: Input text max_tokens: Maximum tokens to generate temperature: Sampling temperature top_p: Top-p sampling stop: Stop sequences - **kwargs: Additional model-specific parameters + **kwargs: Additional parameters forwarded to stream_generate, + including per-request `specprefill` / `specprefill_keep_pct` Returns: GenerationOutput with complete text @@ -252,30 +284,29 @@ async def generate( if not self._loaded: await self.start() - async with self._generation_lock: - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.generate, - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - **kwargs, - ) - - # Clean output text - text = clean_output_text(output.text) - - return GenerationOutput( - text=text, - tokens=getattr(output, "tokens", []), - prompt_tokens=getattr(output, "prompt_tokens", 0), - completion_tokens=getattr( - output, "completion_tokens", len(getattr(output, "tokens", [])) - ), - finish_reason=output.finish_reason, - ) + last_output: GenerationOutput | None = None + async for output in self.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ): + last_output = output + + if last_output is None: + return GenerationOutput(text="", finish_reason="stop") + + text = clean_output_text(last_output.text) + return GenerationOutput( + text=text, + tokens=list(last_output.tokens), + prompt_tokens=last_output.prompt_tokens, + completion_tokens=last_output.completion_tokens, + finish_reason=last_output.finish_reason, + finished=True, + ) async def stream_generate( self, @@ -439,61 +470,84 @@ async def chat( chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) + # mlx-lm non-streaming chat with tools can stall indefinitely on some + # local models, while the streaming path completes normally. Reuse the + # streaming implementation and aggregate its final state so both chat + # APIs share the same tool-capable execution path. 
+ if tools and not self._is_mllm: + final_output = GenerationOutput(text="") + async for output in self.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + images=images, + videos=videos, + chat_template_kwargs=chat_template_kwargs, + **kwargs, + ): + final_output = output + text = clean_output_text(final_output.text) + return GenerationOutput( + text=text, + tokens=list(final_output.tokens), + prompt_tokens=final_output.prompt_tokens, + completion_tokens=final_output.completion_tokens, + finish_reason=final_output.finish_reason, + ) + # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None - async with self._generation_lock: - if self._is_mllm: - # For MLLM, use the chat method which handles images/videos - # Run in thread pool to allow asyncio timeout to work - if chat_template_kwargs: - kwargs["chat_template_kwargs"] = chat_template_kwargs - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, - ) - text = clean_output_text(output.text) - return GenerationOutput( - text=text, - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - finish_reason=output.finish_reason, - ) - else: - # For LLM, use the chat method - # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - tools=template_tools, - chat_template_kwargs=chat_template_kwargs, - **kwargs, - ) - text = clean_output_text(output.text) - # Count prompt tokens from the full templated prompt - tokenizer = self._model.tokenizer - template_kwargs = { - "tokenize": True, - "add_generation_prompt": True, - } - if template_tools: - template_kwargs["tools"] = template_tools - prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs) - prompt_token_count = len(prompt_ids) - return GenerationOutput( - text=text, - tokens=output.tokens, - prompt_tokens=prompt_token_count, - completion_tokens=len(output.tokens), - finish_reason=output.finish_reason, - ) + if self._is_mllm: + if chat_template_kwargs: + kwargs["chat_template_kwargs"] = chat_template_kwargs + output = await self._run_blocking_serialized( + self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, + ) + text = clean_output_text(output.text) + return GenerationOutput( + text=text, + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + finish_reason=output.finish_reason, + ) + else: + output = await self._run_blocking_serialized( + self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=template_tools, + chat_template_kwargs=chat_template_kwargs, + **kwargs, + ) + text = clean_output_text(output.text) + # Preserve upstream prompt accounting while routing the blocking + # chat call through the cancellation-safe serialized runner. 
+ tokenizer = self._model.tokenizer + template_kwargs = { + "tokenize": True, + "add_generation_prompt": True, + } + if template_tools: + template_kwargs["tools"] = template_tools + prompt_ids = tokenizer.apply_chat_template(messages, **template_kwargs) + prompt_token_count = len(prompt_ids) + return GenerationOutput( + text=text, + tokens=output.tokens, + prompt_tokens=prompt_token_count, + completion_tokens=len(output.tokens), + finish_reason=output.finish_reason, + ) async def stream_chat( self, @@ -557,53 +611,53 @@ async def stream_chat( # For MLLM, use stream_chat which yields tokens incrementally. # Must hold _generation_lock to prevent concurrent Metal access # (e.g. OpenCode sends title + main request simultaneously). - async with self._generation_lock: - accumulated_text = "" - token_count = 0 - - # Run stream_chat in thread pool since it's synchronous - def run_stream(): - local_kwargs = dict(kwargs) - if chat_template_kwargs: - local_kwargs["chat_template_kwargs"] = chat_template_kwargs - return list( - self._model.stream_chat( - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **local_kwargs, - ) + accumulated_text = "" + token_count = 0 + + # Run stream_chat in thread pool since it's synchronous + def run_stream(): + local_kwargs = dict(kwargs) + if chat_template_kwargs: + local_kwargs["chat_template_kwargs"] = chat_template_kwargs + return list( + self._model.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **local_kwargs, ) + ) - chunks = await asyncio.to_thread(run_stream) + chunks = await self._run_blocking_serialized(run_stream) - for chunk in chunks: - token_count += 1 - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - accumulated_text += new_text + for chunk in chunks: + token_count += 1 + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + accumulated_text += new_text - finished = chunk.finish_reason is not None + finished = chunk.finish_reason is not None - yield GenerationOutput( - text=accumulated_text, - new_text=new_text, - prompt_tokens=getattr(chunk, "prompt_tokens", 0), - completion_tokens=token_count, - finished=finished, - finish_reason=chunk.finish_reason if finished else None, - ) + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=getattr(chunk, "prompt_tokens", 0), + completion_tokens=token_count, + finished=finished, + finish_reason=chunk.finish_reason if finished else None, + ) - if finished: - break + if finished: + break return # For LLM, apply chat template and stream tokenizer = self._model.tokenizer if hasattr(tokenizer, "apply_chat_template"): - # Disable thinking mode for coder models since it interferes - # with tool call parsing (tags leak as raw text). - enable_thinking = "coder" not in self._model_name.lower() + # Per-request enable_thinking override; default: True unless coder model. 
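The thinking-mode decision is resolved slightly differently in the chat path and the text-only MTP path; the sketch below folds both variants (per-request override, then the VLLM_MLX_ENABLE_THINKING env var, then the coder-model heuristic) into one illustrative helper whose name is not part of the patch.

    import os

    def resolve_enable_thinking(request_value: bool | None, model_name: str) -> bool:
        # 1. Explicit per-request override wins.
        if request_value is not None:
            return request_value
        # 2. Server-wide env var (read by the text-only MTP path).
        env = os.environ.get("VLLM_MLX_ENABLE_THINKING")
        if env is not None:
            return env.lower() in ("true", "1", "yes")
        # 3. Default: on, except for coder models where thinking tags
        #    interfere with tool call parsing.
        return "coder" not in model_name.lower()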
+ enable_thinking = kwargs.pop("enable_thinking", None) + if enable_thinking is None: + enable_thinking = "coder" not in self._model_name.lower() template_kwargs = { "tokenize": False, "add_generation_prompt": True, @@ -661,129 +715,125 @@ async def _stream_generate_specprefill( tokenizer = self._model.tokenizer n_tokens = len(tokens) - async with self._generation_lock: - - def _run_all(): - try: - return _run_specprefill() - except Exception as e: - logger.error( - "SpecPrefill failed, falling back to normal path: %s", e - ) - return _run_normal() + def _run_all(): + try: + return _run_specprefill() + except Exception as e: + logger.error("SpecPrefill failed, falling back to normal path: %s", e) + return _run_normal() + + def _run_specprefill(): + """Score tokens, sparse prefill, generate autoregressively.""" + import time + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) - def _run_specprefill(): - """Score tokens, sparse prefill, generate autoregressively.""" - import time - from types import SimpleNamespace + cache = make_prompt_cache(model) - from ..specprefill import ( - cleanup_rope, - score_tokens, - select_chunks, - sparse_prefill, + try: + # Phase 1: Score with draft model + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + tokens, + prefill_step_size=self._prefill_step_size, ) + t_score = time.monotonic() - t0 - cache = make_prompt_cache(model) + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] - try: - # Phase 1: Score with draft model - t0 = time.monotonic() - importance = score_tokens( - self._draft_model, - tokens, - prefill_step_size=self._prefill_step_size, - ) - t_score = time.monotonic() - t0 - - # Phase 2: Select important chunks - effective_keep = specprefill_keep_pct or self._specprefill_keep_pct - selected = select_chunks(importance, keep_pct=effective_keep) - n_selected = selected.shape[0] - - # Phase 3: Sparse prefill on target model - t0 = time.monotonic() - logits = sparse_prefill( - model, - tokens, - selected, - cache, - step_size=self._prefill_step_size, - ) - t_prefill = time.monotonic() - t0 + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + tokens, + selected, + cache, + step_size=self._prefill_step_size, + ) + t_prefill = time.monotonic() - t0 - logger.info( - "SpecPrefill: scored %d tokens in %.1fs, " - "sparse prefill %d/%d (keep=%.0f%%) in %.1fs", - n_tokens, - t_score, - n_selected, - n_tokens, - n_selected / n_tokens * 100, - t_prefill, - ) + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs", + n_tokens, + t_score, + n_selected, + n_tokens, + n_selected / n_tokens * 100, + t_prefill, + ) - # Phase 4: Generate (simple autoregressive, no MTP) - sampler = make_sampler(temp=temperature, top_p=top_p) - eos_id = tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Phase 4: Generate (simple autoregressive, no MTP) + sampler = make_sampler(temp=temperature, top_p=top_p) + eos_id = tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) - results = [] - generated_ids = [] - prev_decoded = "" + results = [] + generated_ids = [] + prev_decoded = "" - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) + for _ in range(max_tokens): + 
tok_id = y.item() + generated_ids.append(tok_id) - decoded = tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded + decoded = tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, ) + ) - if is_eos: - break + if is_eos: + break - logits = model(y.reshape(1, -1), cache=cache) - y = sampler(logits[:, -1, :]) - mx.eval(y) + logits = model(y.reshape(1, -1), cache=cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) - return results + return results - finally: - cleanup_rope(model) + finally: + cleanup_rope(model) - def _run_normal(): - """Fallback: normal generation without specprefill.""" - from types import SimpleNamespace + def _run_normal(): + """Fallback: normal generation without specprefill.""" + from types import SimpleNamespace - results = [] - for chunk in self._model.stream_generate( - prompt=prompt, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - **kwargs, - ): - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - results.append( - SimpleNamespace( - text=new_text, - finish_reason=getattr(chunk, "finish_reason", None), - ) + results = [] + for chunk in self._model.stream_generate( + prompt=prompt, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + **kwargs, + ): + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + results.append( + SimpleNamespace( + text=new_text, + finish_reason=getattr(chunk, "finish_reason", None), ) - return results + ) + return results - all_resps = await asyncio.to_thread(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" @@ -850,9 +900,11 @@ async def _stream_generate_text( specprefill_keep_pct = kwargs.pop("specprefill_keep_pct", None) chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) - # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path) - enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") - enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + # Per-request enable_thinking override; fall back to env var / default True. + enable_thinking = kwargs.pop("enable_thinking", None) + if enable_thinking is None: + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") # Apply chat template for full prompt template_kwargs = { @@ -1026,194 +1078,192 @@ async def _stream_generate_text( ) use_specprefill = False - # Run under generation lock, all Metal ops in single thread - async with self._generation_lock: + # Run all Metal ops in a single serialized thread. 
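The serialized call below routes through _run_blocking_serialized; as a standalone sketch, that pattern (shield the worker thread so cancelling the awaiting request cannot release the lock while Metal work is still in flight) looks like this:

    import asyncio

    _generation_lock = asyncio.Lock()

    async def run_blocking_serialized(func, /, *args, **kwargs):
        async with _generation_lock:
            task = asyncio.create_task(asyncio.to_thread(func, *args, **kwargs))
            try:
                return await asyncio.shield(task)
            except asyncio.CancelledError:
                # Let the worker thread finish before the lock is released.
                try:
                    await task
                except BaseException:
                    pass
                raise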
+ def _run_all(): + nonlocal backbone_cache, prompt_to_send - def _run_all(): - nonlocal backbone_cache, prompt_to_send + model = self._text_model - model = self._text_model + # Cache MISS with valid prefix: prefill system tokens and snapshot + if ( + not cache_hit + and system_token_count > 0 + and system_tokens is not None + and suffix_tokens is not None + ): + mc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + + # Prefill system tokens in chunks (matching generate_step) + step = self._prefill_step_size + while sys_arr.size > step: + model(sys_arr[:step][None], cache=mc) + mx.eval([c.state for c in mc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=mc) + mx.eval([c.state for c in mc]) + + # Snapshot backbone cache (immutable mx.arrays, safe to reuse) + snapshot = [c.state for c in mc] + mx.eval([s for pair in snapshot for s in pair]) + + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + + backbone_cache = mc + prompt_to_send = mx.array(suffix_tokens) + logger.info( + "System KV cache: stored %d-token snapshot (%.1f MB), " + "prefilling %d remaining", + system_token_count, + sum(c.nbytes for c in mc) / 1e6, + len(suffix_tokens), + ) - # Cache MISS with valid prefix: prefill system tokens and snapshot - if ( - not cache_hit - and system_token_count > 0 - and system_tokens is not None - and suffix_tokens is not None - ): - mc = make_prompt_cache(model) - sys_arr = mx.array(system_tokens) - - # Prefill system tokens in chunks (matching generate_step) - step = self._prefill_step_size - while sys_arr.size > step: - model(sys_arr[:step][None], cache=mc) - mx.eval([c.state for c in mc]) - sys_arr = sys_arr[step:] - mx.clear_cache() - if sys_arr.size > 0: - model(sys_arr[None], cache=mc) - mx.eval([c.state for c in mc]) - - # Snapshot backbone cache (immutable mx.arrays, safe to reuse) - snapshot = [c.state for c in mc] - mx.eval([s for pair in snapshot for s in pair]) - - self._system_kv_snapshot = snapshot - self._system_kv_hash = system_hash - self._system_kv_token_count = system_token_count - - backbone_cache = mc - prompt_to_send = mx.array(suffix_tokens) - logger.info( - "System KV cache: stored %d-token snapshot (%.1f MB), " - "prefilling %d remaining", - system_token_count, - sum(c.nbytes for c in mc) / 1e6, - len(suffix_tokens), + # --- SpecPrefill path (with fallback to normal on failure) --- + if use_specprefill: + try: + return _run_specprefill(model, backbone_cache) + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal MTP path: %s", + e, ) + # Discard potentially corrupted cache + backbone_cache = None + prompt_to_send = full_prompt + + # --- Normal path (MTP via mlx_lm stream_generate) --- + prompt_cache = None + if backbone_cache is not None: + # Add MTP cache on top of backbone + if hasattr(model, "make_mtp_cache"): + mtp_cache = model.make_mtp_cache() + prompt_cache = backbone_cache + mtp_cache + else: + prompt_cache = backbone_cache - # --- SpecPrefill path (with fallback to normal on failure) --- - if use_specprefill: - try: - return _run_specprefill(model, backbone_cache) - except Exception as e: - logger.error( - "SpecPrefill failed, falling back to normal MTP path: %s", - e, - ) - # Discard potentially corrupted cache - backbone_cache = None - prompt_to_send = full_prompt - - # --- Normal path (MTP via mlx_lm stream_generate) --- - prompt_cache = None - if backbone_cache is not None: - # Add MTP 
cache on top of backbone - if hasattr(model, "make_mtp_cache"): - mtp_cache = model.make_mtp_cache() - prompt_cache = backbone_cache + mtp_cache - else: - prompt_cache = backbone_cache + results = [] + gen_kwargs = dict( + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prefill_step_size=self._prefill_step_size, + ) + if prompt_cache is not None: + gen_kwargs["prompt_cache"] = prompt_cache + + for resp in mlx_stream_generate( + model, + self._text_tokenizer, + prompt=prompt_to_send, + **gen_kwargs, + ): + results.append(resp) + return results + + def _run_specprefill(model, bc): + """Score tokens, sparse prefill, generate without MTP.""" + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) - results = [] - gen_kwargs = dict( - max_tokens=max_tokens, - sampler=sampler, - mtp=True, + # Create backbone cache if not already from system KV + if bc is None: + bc = make_prompt_cache(model) + + try: + # Phase 1: Score with draft model + import time + + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + specprefill_tokens, prefill_step_size=self._prefill_step_size, ) - if prompt_cache is not None: - gen_kwargs["prompt_cache"] = prompt_cache + t_score = time.monotonic() - t0 - for resp in mlx_stream_generate( - model, - self._text_tokenizer, - prompt=prompt_to_send, - **gen_kwargs, - ): - results.append(resp) - return results + # Phase 2: Select important chunks + effective_keep = specprefill_keep_pct or self._specprefill_keep_pct + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] + n_total = len(specprefill_tokens) - def _run_specprefill(model, bc): - """Score tokens, sparse prefill, generate without MTP.""" - from types import SimpleNamespace + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + model, + specprefill_tokens, + selected, + bc, + step_size=self._prefill_step_size, + position_offset=specprefill_offset, + ) + t_prefill = time.monotonic() - t0 - from ..specprefill import ( - cleanup_rope, - score_tokens, - select_chunks, - sparse_prefill, + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " + "(offset=%d, effective_keep=%.2f)", + n_total, + t_score, + n_selected, + n_total, + n_selected / n_total * 100, + t_prefill, + specprefill_offset, + effective_keep, ) - # Create backbone cache if not already from system KV - if bc is None: - bc = make_prompt_cache(model) + # Phase 4: Generate (simple autoregressive, no MTP) + eos_id = self._text_tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) - try: - # Phase 1: Score with draft model - import time - - t0 = time.monotonic() - importance = score_tokens( - self._draft_model, - specprefill_tokens, - prefill_step_size=self._prefill_step_size, - ) - t_score = time.monotonic() - t0 - - # Phase 2: Select important chunks - effective_keep = specprefill_keep_pct or self._specprefill_keep_pct - selected = select_chunks(importance, keep_pct=effective_keep) - n_selected = selected.shape[0] - n_total = len(specprefill_tokens) - - # Phase 3: Sparse prefill on target model - t0 = time.monotonic() - logits = sparse_prefill( - model, - specprefill_tokens, - selected, - bc, - step_size=self._prefill_step_size, - position_offset=specprefill_offset, - ) - t_prefill = time.monotonic() - t0 + results = [] + generated_ids = [] + prev_decoded = "" - logger.info( - "SpecPrefill: scored %d tokens 
in %.1fs, " - "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " - "(offset=%d, effective_keep=%.2f)", - n_total, - t_score, - n_selected, - n_total, - n_selected / n_total * 100, - t_prefill, - specprefill_offset, - effective_keep, - ) + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) - # Phase 4: Generate (simple autoregressive, no MTP) - eos_id = self._text_tokenizer.eos_token_id - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Incremental text decode + decoded = self._text_tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded - results = [] - generated_ids = [] - prev_decoded = "" - - for _ in range(max_tokens): - tok_id = y.item() - generated_ids.append(tok_id) - - # Incremental text decode - decoded = self._text_tokenizer.decode(generated_ids) - new_text = decoded[len(prev_decoded) :] - prev_decoded = decoded - - is_eos = tok_id == eos_id - results.append( - SimpleNamespace( - text=new_text, - finish_reason="stop" if is_eos else None, - ) + is_eos = tok_id == eos_id + results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, ) + ) - if is_eos: - break + if is_eos: + break - # Next token - logits = model(y.reshape(1, -1), cache=bc) - y = sampler(logits[:, -1, :]) - mx.eval(y) + # Next token + logits = model(y.reshape(1, -1), cache=bc) + y = sampler(logits[:, -1, :]) + mx.eval(y) - return results + return results - finally: - cleanup_rope(model) + finally: + cleanup_rope(model) - all_resps = await asyncio.to_thread(_run_all) + all_resps = await self._run_blocking_serialized(_run_all) # Yield results as GenerationOutput accumulated_text = "" diff --git a/vllm_mlx/engine_core.py b/vllm_mlx/engine_core.py index d21928824..ae75fd39e 100644 --- a/vllm_mlx/engine_core.py +++ b/vllm_mlx/engine_core.py @@ -36,6 +36,7 @@ class EngineConfig: scheduler_config: Optional[SchedulerConfig] = None step_interval: float = 0.001 # 1ms between steps stream_interval: int = 1 # Tokens to batch before streaming (1=every token) + gpu_memory_utilization: float = 0.90 # Fraction of device memory for allocation class EngineCore: @@ -150,18 +151,12 @@ async def _engine_loop(self) -> None: stream_interval = self.config.stream_interval use_simple_streaming = stream_interval == 1 - # Emergency memory pressure threshold — use 85% of Metal's - # max recommended working set so this scales with system RAM. + # Emergency memory pressure threshold — dynamic based on gpu_memory_utilization + _gpu_mem_util = self.config.gpu_memory_utilization try: - _device_info = mx.device_info() - _max_recommended = _device_info.get( - "max_recommended_working_set_size", - _device_info.get("memory_size", 0), - ) - _memory_pressure_threshold = ( - int(_max_recommended * 0.85) - if _max_recommended > 0 - else 200 * 1024 * 1024 * 1024 + _device_mem = mx.device_info().get("memory_size", 200 * 1024 * 1024 * 1024) + _memory_pressure_threshold = int( + _device_mem * min(_gpu_mem_util + 0.05, 0.99) ) except Exception: _memory_pressure_threshold = 200 * 1024 * 1024 * 1024 diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index f43763541..2668c3cec 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -255,44 +255,121 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry: def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: - """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced. 
+ """Create copies of cache layers with the last ``trim_by`` positions removed. This is used when returning a cached KV state to the scheduler so that the last N positions are "freed" and the model will recompute them on the next forward pass (preventing duplicate KV entries). - Supports both KVCache (keys/values are arrays) and QuantizedKVCache - (keys/values are 3-tuples of arrays). - """ - from mlx_lm.models.cache import KVCache + For plain KVCache: reduces offset (surplus data beyond offset is harmless + since merge slices to ``keys[:, :, :offset, :]``). - try: - from mlx_lm.models.cache import QuantizedKVCache - except ImportError: - QuantizedKVCache = None # noqa: N806 + For RotatingKVCache: actually trims the circular buffer — reducing offset + alone breaks ``size()`` / ``_temporal_order`` invariants. + + Supports KVCache, RotatingKVCache, and _QuantizedCacheWrapper. + """ + import mlx.core as mx + from mlx_lm.models.cache import RotatingKVCache trimmed: list[Any] = [] + eval_targets: list[Any] = [] for layer_cache in cache: - if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache): - tc = QuantizedKVCache.__new__(QuantizedKVCache) + if isinstance(layer_cache, _QuantizedCacheWrapper): + # Shallow copy with reduced offset + tc = _QuantizedCacheWrapper.__new__(_QuantizedCacheWrapper) tc.keys = layer_cache.keys tc.values = layer_cache.values tc.offset = max(layer_cache.offset - trim_by, 0) - tc.group_size = layer_cache.group_size tc.bits = layer_cache.bits + tc.group_size = layer_cache.group_size + tc.orig_type = layer_cache.orig_type + tc.orig_attrs = layer_cache.orig_attrs + trimmed.append(tc) + elif isinstance(layer_cache, RotatingKVCache): + if layer_cache.keys is None or trim_by <= 0: + trimmed.append(layer_cache) + continue + # RotatingKVCache: must trim buffer, not just offset. + # The buffer stores the last min(offset, max_size) tokens in a + # circular arrangement. Trimming excess positions from the END + # means removing the newest entries (chronologically last). + old_offset = layer_cache.offset + new_offset = max(old_offset - trim_by, 0) + old_size = min(old_offset, layer_cache.max_size) + entries_to_keep = max(0, old_size - trim_by) + + orig_cls = type(layer_cache) + tc = orig_cls.__new__(orig_cls) + tc.offset = new_offset + tc.max_size = layer_cache.max_size + tc.keep = getattr(layer_cache, "keep", 0) + tc.step = getattr(layer_cache, "step", layer_cache.max_size) + + if entries_to_keep <= 0: + # All buffer content is beyond the trim point — clear + tc.keys = None + tc.values = None + tc._idx = 0 + elif entries_to_keep < old_size: + # Reorder to temporal order, keep the oldest entries + ordered_k = layer_cache._temporal_order(layer_cache.keys) + ordered_v = layer_cache._temporal_order(layer_cache.values) + kept_k = ordered_k[:, :, :entries_to_keep, :] + kept_v = ordered_v[:, :, :entries_to_keep, :] + + if new_offset >= tc.max_size: + # Invariant: when offset >= max_size, buffer must be + # full (keys.shape[2] == max_size). Left-pad with + # zeros to restore the full buffer. Zeros represent + # positions evicted long ago; _idx = max_size so + # _temporal_order returns as-is and _update_in_place + # rotates to overwrite zeros first. 
+ pad_n = tc.max_size - entries_to_keep + pad_k = mx.zeros( + (kept_k.shape[0], kept_k.shape[1], pad_n, kept_k.shape[3]), + dtype=kept_k.dtype, + ) + pad_v = mx.zeros( + (kept_v.shape[0], kept_v.shape[1], pad_n, kept_v.shape[3]), + dtype=kept_v.dtype, + ) + tc.keys = mx.concatenate([pad_k, kept_k], axis=2) + tc.values = mx.concatenate([pad_v, kept_v], axis=2) + tc._idx = tc.max_size + else: + tc.keys = kept_k + tc.values = kept_v + tc._idx = entries_to_keep + eval_targets.extend([tc.keys, tc.values]) + else: + # No entries removed (trim_by == 0 already handled above, + # this covers entries_to_keep == old_size edge case) + tc.keys = layer_cache.keys + tc.values = layer_cache.values + tc._idx = layer_cache._idx trimmed.append(tc) elif ( hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys") and not isinstance(layer_cache.keys, (list, tuple)) ): - tc = KVCache.__new__(KVCache) + orig_cls = type(layer_cache) + tc = orig_cls.__new__(orig_cls) tc.keys = layer_cache.keys tc.values = layer_cache.values tc.offset = max(layer_cache.offset - trim_by, 0) + # Preserve type-specific attrs (max_size, keep, step, _idx) + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer_cache, attr): + setattr(tc, attr, getattr(layer_cache, attr)) trimmed.append(tc) else: trimmed.append(layer_cache) + + if eval_targets: + mx.eval(*eval_targets) + return trimmed @@ -353,28 +430,72 @@ def _trim_to_offset(cache: list[Any]) -> list[Any]: return trimmed +class _QuantizedCacheWrapper: + """Lightweight wrapper storing quantized KV arrays + original cache metadata. + + Unlike ``QuantizedKVCache``, this preserves enough info to reconstruct + the *original* cache type (KVCache, RotatingKVCache, etc.) on dequantize. + """ + + __slots__ = ( + "keys", + "values", + "offset", + "bits", + "group_size", + "orig_type", + "orig_attrs", + ) + + def __init__(self, layer: Any, bits: int, group_size: int): + import mlx.core as mx + + self.keys = mx.quantize(layer.keys, group_size=group_size, bits=bits) + self.values = mx.quantize(layer.values, group_size=group_size, bits=bits) + self.offset = layer.offset + self.bits = bits + self.group_size = group_size + self.orig_type = type(layer) + # Preserve RotatingKVCache-specific attrs + self.orig_attrs = {} + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer, attr): + self.orig_attrs[attr] = getattr(layer, attr) + + def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> list[Any]: - """Quantize KVCache layers to reduce memory. Non-KVCache layers are kept as-is.""" + """Quantize KV cache layers to reduce memory. + + Only plain KVCache layers are quantized. RotatingKVCache (sliding window) + is left as-is because its internal _idx/rotation state is tightly coupled + with update_and_fetch logic and cannot survive quantize/dequantize roundtrip. + RotatingKVCache is typically small (max_size=1024) so skipping it is fine. + """ from mlx_lm.models.cache import KVCache quantized = [] for layer in cache: - if isinstance(layer, KVCache) and layer.keys is not None: - quantized.append(layer.to_quantized(group_size=group_size, bits=bits)) + if type(layer) is KVCache and getattr(layer, "keys", None) is not None: + quantized.append(_QuantizedCacheWrapper(layer, bits, group_size)) else: quantized.append(layer) return quantized def _dequantize_cache(cache: list[Any]) -> list[Any]: - """Dequantize QuantizedKVCache layers back to regular KVCache.""" + """Dequantize _QuantizedCacheWrapper layers and copy non-quantized layers. 
+ + All layers are copied (never returned by reference) so that the model's + ``update_and_fetch`` mutations don't corrupt the stored cache entry. + """ import mlx.core as mx - from mlx_lm.models.cache import KVCache, QuantizedKVCache result = [] for layer in cache: - if isinstance(layer, QuantizedKVCache) and layer.keys is not None: - kv = KVCache() + if isinstance(layer, _QuantizedCacheWrapper): + # Reconstruct original cache type from quantized data + orig_cls = layer.orig_type + kv = orig_cls.__new__(orig_cls) kv.keys = mx.dequantize( *layer.keys, group_size=layer.group_size, bits=layer.bits ) @@ -382,6 +503,21 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]: *layer.values, group_size=layer.group_size, bits=layer.bits ) kv.offset = layer.offset + # Restore type-specific attrs (max_size, keep, step, _idx) + for attr, val in layer.orig_attrs.items(): + setattr(kv, attr, val) + result.append(kv) + elif hasattr(layer, "keys") and hasattr(layer, "offset"): + # Deep-copy non-quantized cache layers (e.g. RotatingKVCache) + # so model's in-place mutations don't corrupt stored entries + orig_cls = type(layer) + kv = orig_cls.__new__(orig_cls) + kv.keys = mx.array(layer.keys) if layer.keys is not None else None + kv.values = mx.array(layer.values) if layer.values is not None else None + kv.offset = layer.offset + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer, attr): + setattr(kv, attr, getattr(layer, attr)) result.append(kv) else: result.append(layer) @@ -635,7 +771,15 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: f"layer_types={[type(lc).__name__ for lc in best_lcp_entry.cache[:3]]}" ) - if not has_non_trimmable: + if has_non_trimmable: + # Hybrid model (SSM+Attention): SSM state can't be rewound. + # Block LCP for hybrid models — use think-suffix stripping + # in the engine layer to get clean PREFIX matches instead. 
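The LCP branch above rewinds a stored entry when a new prompt diverges partway through a cached sequence. A rough pure-Python sketch of that lookup, assuming a plain dict of token-tuple keys to opaque cache entries (the real MemoryAwarePrefixCache layers LRU order and memory accounting on top of this); excess is the number of cached positions past the divergence that a _trim_cache_offset-style rewind has to drop.

from typing import Any, Dict, List, Optional, Tuple

def lcp_len(a: List[int], b: List[int]) -> int:
    """Length of the longest common prefix of two token lists."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

def fetch_lcp(store: Dict[Tuple[int, ...], Any],
              tokens: List[int],
              min_match: int = 16) -> Tuple[Optional[Any], List[int], int]:
    """Find the stored entry sharing the longest prefix with `tokens`.

    Returns (entry, remaining_tokens, excess) where `excess` counts cached
    positions beyond the match that must be trimmed and recomputed.
    """
    best_key, best_n = None, 0
    for key in store:
        n = lcp_len(list(key), tokens)
        if n > best_n:
            best_key, best_n = key, n
    if best_key is None or best_n < min_match:
        return None, tokens, 0           # miss: caller prefills from scratch
    excess = len(best_key) - best_n      # cached positions past the divergence
    return store[best_key], tokens[best_n:], excess

store = {tuple(range(100)): "kv-entry-A"}
entry, remaining, excess = fetch_lcp(store, list(range(90)) + [999, 1000])
assert entry == "kv-entry-A" and excess == 10 and remaining == [999, 1000]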
+ logger.debug( + "[cache_fetch] LCP skipped: non-trimmable cache layers " + "(hybrid model, SSM state can't be rewound)" + ) + else: trimmed_cache = _trim_cache_offset(best_lcp_entry.cache, excess) self._entries.move_to_end(best_lcp_entry.tokens) self._stats.hits += 1 diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index ee8d8da7b..a6a59afba 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -24,12 +24,21 @@ import mlx.core as mx import mlx.nn as nn +from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig, _trim_cache_offset from .multimodal_processor import MultimodalProcessor from .vision_embedding_cache import VisionEmbeddingCache logger = logging.getLogger(__name__) +class PrefillAbortedError(Exception): + """Raised when a prefill is aborted due to client disconnect.""" + + def __init__(self, request_id: str): + self.request_id = request_id + super().__init__(f"Prefill aborted for request {request_id}") + + @dataclass class MLLMBatchRequest: """ @@ -47,6 +56,10 @@ class MLLMBatchRequest: max_tokens: int = 256 temperature: float = 0.7 top_p: float = 0.9 + top_k: int = 0 + min_p: float = 0.0 + presence_penalty: float = 0.0 + repetition_penalty: float = 1.0 # Processed inputs (set after vision preprocessing) input_ids: Optional[mx.array] = None @@ -55,6 +68,9 @@ class MLLMBatchRequest: image_grid_thw: Optional[mx.array] = None extra_kwargs: Dict[str, Any] = field(default_factory=dict) + # Text-only flag (no images/videos — eligible for prefix cache) + is_text_only: bool = False + # Generation state num_tokens: int = 0 # Tokens generated so far output_tokens: List[int] = field(default_factory=list) @@ -98,6 +114,8 @@ class MLLMBatch: num_tokens: List[int] # Tokens generated per request cache: List[Any] # BatchKVCache for language model requests: List[MLLMBatchRequest] # Full request data + logits_processors: Optional[List[Optional[List[Callable]]]] = None + samplers: Optional[List[Optional[Callable]]] = None def __len__(self) -> int: return len(self.uids) @@ -115,6 +133,10 @@ def filter(self, keep_idx: List[int]) -> None: self.max_tokens = [self.max_tokens[k] for k in keep_idx] self.num_tokens = [self.num_tokens[k] for k in keep_idx] self.requests = [self.requests[k] for k in keep_idx] + if self.logits_processors is not None: + self.logits_processors = [self.logits_processors[k] for k in keep_idx] + if self.samplers is not None: + self.samplers = [self.samplers[k] for k in keep_idx] keep_idx_array = mx.array(keep_idx, mx.int32) self.y = self.y[keep_idx_array] @@ -139,32 +161,73 @@ def extend(self, other: "MLLMBatch") -> None: self.max_tokens.extend(other.max_tokens) self.requests.extend(other.requests) - # Extend cache - handle None and incompatible caches + # Extend logits_processors + if self.logits_processors is not None or other.logits_processors is not None: + # At this point self.uids already includes other.uids from extend above + self_len = len(self.uids) - len(other.uids) + self_lp = self.logits_processors or [None] * self_len + other_lp = other.logits_processors or [None] * len(other.uids) + self.logits_processors = list(self_lp) + list(other_lp) + + # Extend samplers + if self.samplers is not None or other.samplers is not None: + self_len = len(self.uids) - len(other.uids) + self_s = self.samplers or [None] * self_len + other_s = other.samplers or [None] * len(other.uids) + self.samplers = list(self_s) + list(other_s) + + # Extend cache - handle both BatchKVCache (.keys/.values) and + # 
ArraysCache (.cache list) from hybrid models like Qwen3.5 for c, o in zip(self.cache, other.cache): if c is not None and o is not None and hasattr(c, "extend"): try: - # Only extend if both caches have valid keys - if ( - hasattr(c, "keys") - and c.keys is not None - and hasattr(o, "keys") - and o.keys is not None - ): + has_kv = hasattr(c, "keys") and c.keys is not None + has_arrays = hasattr(c, "cache") + if has_kv or has_arrays: c.extend(o) except Exception as e: logger.warning(f"Failed to extend cache: {e}") def extract_cache(self, idx: int) -> List[Any]: """ - Extract cache for a single request (for caching). - - Args: - idx: Index of request in batch + Extract cache for a single request (for prefix caching). - Returns: - Cache state for that request + Handles BatchRotatingKVCache negative left_padding bug: + during generation with rotation, left_padding becomes negative, + causing extract() to use Python negative indexing and truncate + the buffer to only generation tokens instead of the full window. """ - return [c.extract(idx) if hasattr(c, "extract") else None for c in self.cache] + from mlx_lm.models.cache import ( + BatchRotatingKVCache, + RotatingKVCache, + ) + + result = [] + for c in self.cache: + if not hasattr(c, "extract"): + result.append(None) + elif isinstance(c, BatchRotatingKVCache): + # Custom extraction: clamp left_padding to >= 0 + cache = RotatingKVCache(c.max_size) + padding = max(0, c.left_padding[idx].item()) + offset = c.offset[idx].item() + cache.keys = c.keys[idx : idx + 1] + cache.values = c.values[idx : idx + 1] + cache._idx = c._idx + if c.rotated: + cache.keys = mx.roll(cache.keys, -c._idx, axis=2) + cache.values = mx.roll(cache.values, -c._idx, axis=2) + cache._idx = c.max_size + cache.keys = mx.contiguous(cache.keys[:, :, padding : cache._idx]) + cache.values = mx.contiguous(cache.values[:, :, padding : cache._idx]) + cache.offset = offset + cache._idx = cache.keys.shape[2] + cache.step = getattr(c, "step", c.max_size) + cache.keep = getattr(c, "keep", 0) + result.append(cache) + else: + result.append(c.extract(idx)) + return result class MLLMBatchStats: @@ -205,32 +268,6 @@ def to_dict(self) -> Dict[str, Any]: } -def _make_batch_cache(model: nn.Module, left_padding: List[int]) -> List[Any]: - """ - Create batch-aware KV cache for the language model. - - Args: - model: The language model (model.language_model from VLM) - left_padding: Padding amounts for left-padded prompts - - Returns: - List of BatchKVCache objects for each layer - """ - from mlx_lm.models.cache import BatchKVCache, KVCache - - def to_batch_cache(c): - if isinstance(c, KVCache): - return BatchKVCache(left_padding) - else: - raise ValueError(f"{type(c)} does not yet support batching") - - if hasattr(model, "make_cache"): - cache = model.make_cache() - return [to_batch_cache(c) for c in cache] - else: - return [BatchKVCache(left_padding) for _ in model.layers] - - def _left_pad_prompts( prompts: List[List[int]], max_length: Optional[int] = None ) -> mx.array: @@ -289,6 +326,7 @@ def __init__( prefill_step_size: int = 1024, enable_vision_cache: bool = True, vision_cache_size: int = 100, + prefix_cache_config: Optional[MemoryCacheConfig] = None, ): """ Initialize MLLM batch generator. 
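A list-based sketch of the extract_cache fix described above, with a hypothetical name (extract_window); the real code performs the same roll-then-slice on MLX arrays. The point is that an unclamped negative left_padding silently truncates the window through Python's negative indexing.

from typing import List

def extract_window(buf: List[int], idx: int, rotated: bool,
                   left_padding: int) -> List[int]:
    """Return one request's usable KV window from a shared rotating buffer."""
    padding = max(0, left_padding)          # the fix: never index negatively
    if rotated:
        buf = buf[idx:] + buf[:idx]         # back to oldest-to-newest order
        idx = len(buf)
    return buf[padding:idx]

# Rotated 4-entry buffer; padding drifted to -3 during generation.
window = extract_window([7, 8, 5, 6], idx=2, rotated=True, left_padding=-3)
assert window == [5, 6, 7, 8]                  # clamped: full window kept
assert [5, 6, 7, 8][-3:4] == [6, 7, 8]         # unclamped: oldest entry lost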
@@ -305,6 +343,7 @@ def __init__( prefill_step_size: Tokens to process per prefill step enable_vision_cache: Enable vision embedding caching vision_cache_size: Max entries in vision cache + prefix_cache_config: Config for KV prefix cache (text-only requests) """ self.model = model self.processor = processor @@ -324,6 +363,13 @@ def __init__( "MLLMBatchGenerator: Model does not have language_model, using model directly" ) + # Patch attention for BatchKVCache compatibility + from .patches.qwen3_5_mllm import patch_qwen35_attention_for_batching + from .patches.gemma4_mllm import patch_gemma4_attention_for_batching + + patch_qwen35_attention_for_batching() + patch_gemma4_attention_for_batching() + self.max_tokens = max_tokens self.stop_tokens = stop_tokens or set() self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) @@ -340,6 +386,18 @@ def __init__( # Statistics self._stats = MLLMBatchStats() + # Error responses for requests that failed during preprocessing + self._pending_error_responses: List[MLLMBatchResponse] = [] + + # Per-request prefill progress: request_id → (processed_tokens, total_tokens) + self._prefill_progress: Dict[str, Tuple[int, int]] = {} + + # Aborted request IDs — checked between prefill chunks to allow + # early termination when a client disconnects during long prefill. + # Set operations are GIL-protected, safe across event-loop and + # executor threads. + self._aborted_request_ids: set = set() + # Vision embedding cache for repeated images self.vision_cache = VisionEmbeddingCache( max_pixel_entries=vision_cache_size, @@ -351,6 +409,33 @@ def __init__( f"MLLMBatchGenerator: Vision cache enabled (size={vision_cache_size})" ) + # KV prefix cache for text-only requests + self.prefix_cache: Optional[MemoryAwarePrefixCache] = None + if prefix_cache_config is not None: + self.prefix_cache = MemoryAwarePrefixCache( + model=self.language_model, + config=prefix_cache_config, + ) + logger.info("MLLMBatchGenerator: KV prefix cache enabled") + + # Normalize chat template for prefix-cache stability. + # Qwen3.5 chat template retroactively changes formatting of earlier + # assistant messages based on last_query_index (position of last + # non-tool user message). When a user text message is appended, + # last_query_index jumps forward, removing blocks from + # earlier assistant turns — shifting tokens mid-sequence and + # breaking prefix match. Fix: always use plain format for + # historical assistant turns (thinking is still added by the + # generation prompt at the end). + self._normalize_chat_template_for_prefix_cache() + + # Compute think-suffix length for prefix cache key stripping. + # Models with enable_thinking=True add \n to the generation + # prompt. This breaks prefix cache (stored key ends with + # but next request has actual response at that position). + # Stripping the suffix from cache keys enables clean PREFIX match. + self._think_suffix_len = self._compute_think_suffix_len() + # Generation stream if MLLMBatchGenerator._stream is None: MLLMBatchGenerator._stream = mx.new_stream(mx.default_device()) @@ -362,6 +447,132 @@ def __init__( mx.device_info()["max_recommended_working_set_size"] ) + def _normalize_chat_template_for_prefix_cache(self) -> None: + """Patch chat template so historical assistant turns are prefix-stable. + + Qwen3.5's chat template computes ``last_query_index`` — the position + of the last non-tool-response user message — and conditionally wraps + assistant turns after that index in ``...\\n\\n\\n``. 
+ When a new user text message is appended, ``last_query_index`` jumps + forward, retroactively removing these ```` wrappers from + earlier assistant turns. This shifts tokens mid-sequence and breaks + prefix cache. + + Fix: replace the conditional with the plain (ELSE) branch so ALL + historical assistant messages use ``<|im_start|>assistant\\ncontent`` + without any injected ```` block. The generation prompt still + adds ``\\n`` at the very end, so the model generates thinking. + """ + if self.prefix_cache is None: + return # No prefix cache — no need to normalize + + # Find the chat template. VLM processors (e.g. Qwen3VLProcessor) + # keep a SEPARATE copy of chat_template from their tokenizer — both + # must be patched. The processor's copy is used by + # BatchedEngine._apply_chat_template() (text rendering), while the + # tokenizer's copy is used by _compute_think_suffix_len(). + tokenizer = getattr(self.processor, "tokenizer", self.processor) + # Prefer the processor's own template (it's the one used for rendering) + template = getattr(self.processor, "chat_template", None) + if not template: + template = getattr(tokenizer, "chat_template", None) + if not template or "last_query_index" not in template: + return # Not affected + + import re + + # The pattern in Qwen3.5 template: + # {%- if loop.index0 > ns.last_query_index %} + # {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + # {%- else %} + # {{- '<|im_start|>' + message.role + '\n' + content }} + # {%- endif %} + # + # Replace with just the ELSE branch (always plain format). + pattern = ( + r"\{%-\s*if\s+loop\.index0\s*>\s*ns\.last_query_index\s*%\}" + r".*?" + r"\{%-\s*else\s*%\}" + r"\s*(\{\{-.*?content.*?\}\})" + r"\s*\{%-\s*endif\s*%\}" + ) + new_template = re.sub(pattern, r"\1", template, flags=re.DOTALL) + if new_template != template: + # Patch ALL copies: processor, tokenizer, and any dict variants. + if hasattr(self.processor, "chat_template"): + self.processor.chat_template = new_template + tokenizer.chat_template = new_template + logger.info( + "[prefix_cache] Normalized chat template: removed " + "last_query_index conditional for prefix-stable assistant turns" + ) + else: + logger.debug( + "[prefix_cache] Chat template has last_query_index but " + "regex did not match — template may use a different pattern" + ) + + def _compute_think_suffix_len(self) -> int: + """Compute how many extra tokens enable_thinking=True adds at the END. + + Compares the generation prompt suffix with and without + ``enable_thinking`` to find the think-tag suffix length + (typically ``\\n`` = 2 tokens for Qwen3/Qwen3.5). + + Returns 0 if the template doesn't support ``enable_thinking``. + """ + try: + # Find something with apply_chat_template + applicator = None + for candidate in [ + getattr(self.processor, "tokenizer", None), + self.processor, + ]: + if candidate is not None and hasattr(candidate, "apply_chat_template"): + applicator = candidate + break + + if applicator is None: + return 0 + + dummy = [{"role": "user", "content": "x"}] + + try: + text_with = applicator.apply_chat_template( + dummy, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + text_without = applicator.apply_chat_template( + dummy, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + except TypeError: + return 0 + + # Check if enable_thinking adds a known think tag at the end. 
+ # enable_thinking may also change the system prompt, so we can't + # simply compare lengths — we look at the ending instead. + for tag in ["\n", ""]: + if text_with.endswith(tag) and not text_without.endswith(tag): + tokenizer = getattr(self.processor, "tokenizer", self.processor) + suffix_tokens = tokenizer.encode(tag) + base_tokens = tokenizer.encode("") + suffix_len = len(suffix_tokens) - len(base_tokens) + if suffix_len > 0: + logger.info( + f"[think_suffix] Detected think tag " + f"'{tag.strip()}' = {suffix_len} token(s)" + ) + return max(0, suffix_len) + + return 0 + except Exception: + return 0 + def close(self) -> None: """Release resources and reset wired limit.""" if self._old_wired_limit is not None: @@ -369,6 +580,16 @@ def close(self) -> None: mx.set_wired_limit(self._old_wired_limit) self._old_wired_limit = None + def abort_prefill(self, request_id: str) -> None: + """Signal that a request's prefill should be aborted. + + Called from the event loop thread when a client disconnects. + The prefill loop checks this set between chunks and raises + PrefillAbortedError to exit early. + """ + self._aborted_request_ids.add(request_id) + logger.info(f"[abort_prefill] Marked {request_id} for prefill abort") + def __del__(self): try: self.close() @@ -545,12 +766,81 @@ def _preprocess_request(self, request: MLLMBatchRequest) -> None: self._stats.num_images_processed += len(all_images) self._stats.vision_encoding_time += processing_time + # Mark text-only requests (eligible for prefix cache) + request.is_text_only = not bool(all_images) + logger.debug( f"Preprocessed request {request.request_id}: " f"{len(all_images)} images, {request.input_ids.size if request.input_ids is not None else 0} tokens " f"({processing_time:.2f}s)" ) + def _run_chunked_text_prefill( + self, request: MLLMBatchRequest, cache: List[Any] + ) -> mx.array: + """ + Run prefill in chunks for text-only requests, reporting real progress. + + Processes input_ids in prefill_step_size chunks through the language + model, updating ``_prefill_progress`` after each chunk so the status + endpoint can report accurate prefill percentage. + + Returns: + Logits from the last chunk (same contract as _run_vision_encoding). + """ + input_ids = request.input_ids + if input_ids.ndim == 1: + input_ids = input_ids[None, :] + + total = input_ids.shape[1] + step = self.prefill_step_size + + # Short prompt — process in one shot (no chunking overhead) + if total <= step: + self._prefill_progress[request.request_id] = (total, total) + output = self.language_model(input_ids, cache=cache) + request.vision_encoded = True + if hasattr(output, "logits"): + return output.logits + return output + + # Process all chunks except the last + processed = 0 + chunk_count = 0 + while processed + step < total: + # Check for abort between chunks (client disconnect) + if request.request_id in self._aborted_request_ids: + self._aborted_request_ids.discard(request.request_id) + logger.info( + f"[chunked_prefill] Aborted {request.request_id} at " + f"{processed}/{total} tokens" + ) + raise PrefillAbortedError(request.request_id) + + chunk = input_ids[:, processed : processed + step] + self.language_model(chunk, cache=cache) + mx.eval([c.state for c in cache]) + processed += step + chunk_count += 1 + self._prefill_progress[request.request_id] = (processed, total) + + # Release Metal buffer pool periodically. Full-attention layers + # produce attention score buffers that grow each chunk (1024 × + # growing_context). 
Old smaller buffers can't be reused, so the + # pool accumulates O(N²) memory without clearing. + if chunk_count % 4 == 0: + mx.clear_cache() + + # Last chunk — return logits for sampling + last_chunk = input_ids[:, processed:] + output = self.language_model(last_chunk, cache=cache) + request.vision_encoded = True + self._prefill_progress[request.request_id] = (total, total) + + if hasattr(output, "logits"): + return output.logits + return output + def _run_vision_encoding( self, request: MLLMBatchRequest, cache: Optional[List[Any]] = None ) -> mx.array: @@ -613,68 +903,305 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: tic = time.perf_counter() - # Preprocess all requests + # Preprocess all requests (per-request error handling) + failed_requests = [] for req in requests: - self._preprocess_request(req) + try: + self._preprocess_request(req) + except Exception as e: + logger.error( + f"Failed to preprocess request {req.request_id}: " + f"{type(e).__name__}: {e}" + ) + failed_requests.append(req) + + # Remove failed requests from batch and create error responses + if failed_requests: + for req in failed_requests: + requests.remove(req) + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + if not requests: + # All requests failed + return None total_prompt_tokens = sum( req.input_ids.size if req.input_ids is not None else 1 for req in requests ) self._stats.prompt_tokens += total_prompt_tokens - # Guard against excessive memory usage during cache merge. - # Each token in the batch requires KV entries across all layers. + # Log large prompts for monitoring (was previously a hard check that + # caused infinite retry loops when requests exceeded the limit). max_batch_tokens = self.prefill_step_size * len(requests) if total_prompt_tokens > max_batch_tokens: - raise ValueError( - f"Total prompt tokens ({total_prompt_tokens}) exceeds safe limit " - f"({max_batch_tokens}) for {len(requests)} requests. " - f"Reduce prompt length or batch size." + logger.warning( + f"Large batch prefill: {total_prompt_tokens} tokens " + f"(step_size={self.prefill_step_size}, requests={len(requests)}). " + f"Processing may be slow." ) # Run vision encoding for each request with its own KVCache. # Vision encoding cannot be batched because each request may have # different images/pixel values. We pass a per-request KVCache to # the VLM so the language model writes its KV state directly into it. + # + # For text-only requests, we check the prefix cache first. If there's + # a hit, we skip the full VLM forward and run only the language model + # on the remaining (uncached) tokens. 
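Before the per-request loop below, here is a stripped-down sketch of the _run_chunked_text_prefill loop from above, with a stub model callable instead of MLX: the abort set and progress dict mirror _aborted_request_ids and _prefill_progress, and the periodic buffer release is left as a comment.

from typing import Callable, Dict, List, Set, Tuple

class PrefillAborted(Exception):
    pass

def chunked_prefill(model: Callable[[List[int]], List[float]],
                    tokens: List[int],
                    request_id: str,
                    step: int,
                    progress: Dict[str, Tuple[int, int]],
                    aborted: Set[str]) -> List[float]:
    """Feed `tokens` to `model` in `step`-sized chunks, reporting progress."""
    total = len(tokens)
    if total <= step:                      # short prompt: single shot
        progress[request_id] = (total, total)
        return model(tokens)

    done, chunks = 0, 0
    while done + step < total:             # all chunks except the last
        if request_id in aborted:          # client went away mid-prefill
            aborted.discard(request_id)
            raise PrefillAborted(request_id)
        model(tokens[done:done + step])    # KV state grows inside the model
        done += step
        chunks += 1
        progress[request_id] = (done, total)
        if chunks % 4 == 0:
            pass                           # real code: mx.clear_cache()
    progress[request_id] = (total, total)
    return model(tokens[done:])            # last chunk yields sampling logits

calls: List[int] = []
stub = lambda chunk: (calls.append(len(chunk)) or [0.0])
progress: Dict[str, Tuple[int, int]] = {}
chunked_prefill(stub, list(range(2500)), "req-1", 1024, progress, set())
assert calls == [1024, 1024, 452] and progress["req-1"] == (2500, 2500)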
first_tokens = [] all_logprobs = [] per_request_caches = [] + aborted_requests = [] for req in requests: - # Create a fresh KVCache for this request's language model prefill - request_cache = make_prompt_cache(self.language_model) - - with mx.stream(MLLMBatchGenerator._stream): - # Run VLM forward pass — cache= flows through to language_model - logits = self._run_vision_encoding(req, cache=request_cache) - - # Extract last token logits and sample - last_logits = logits[:, -1, :] - logprobs = last_logits - mx.logsumexp( - last_logits, axis=-1, keepdims=True - ) - sampled = self.sampler(logprobs) - - mx.eval(sampled, logprobs) + try: + # Check abort before starting prefill + if req.request_id in self._aborted_request_ids: + self._aborted_request_ids.discard(req.request_id) + raise PrefillAbortedError(req.request_id) + + # Try prefix cache for all requests (text-only and multimodal). + # VLM forward writes the same KV state as language model forward + # for text tokens, so cached KV from a previous VLM run is valid. + # However, if the remaining (uncached) tokens contain image + # placeholders, we must fall back to VLM forward instead of + # running them through the language model alone. + cached_kv = None + remaining_ids = None + if self.prefix_cache is not None and req.input_ids is not None: + input_ids_list = req.input_ids.reshape(-1).tolist() + # Strip think suffix from lookup key so stored entries + # (also stripped) match as clean PREFIX. + S = self._think_suffix_len + lookup_ids = input_ids_list[:-S] if S > 0 else input_ids_list + cached_kv, remaining_ids = self.prefix_cache.fetch(lookup_ids) + # Append think suffix back to remaining so the model + # sees the full generation prompt (\n). + if cached_kv is not None and S > 0: + remaining_ids = list(remaining_ids) + input_ids_list[-S:] + + # If remaining tokens contain image placeholders, the + # language-model-only path cannot handle them — clear the + # cache hit so we fall through to full VLM forward. 
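The suffix arithmetic in the fetch path above is easy to get wrong, so here it is in isolation with illustrative names and token IDs (fake_fetch stands in for prefix_cache.fetch): the lookup key drops the S trailing think-tag tokens, and the suffix is re-appended to whatever still has to be prefilled so the model sees the full generation prompt.

from typing import List, Optional, Tuple

def fetch_with_think_suffix(
    input_ids: List[int],
    think_suffix_len: int,
    fetch,  # callable: List[int] -> (cache | None, remaining tokens)
) -> Tuple[Optional[object], List[int]]:
    """Strip the think-tag suffix for the cache key, then restore it."""
    S = think_suffix_len
    lookup = input_ids[:-S] if S > 0 else input_ids
    cached, remaining = fetch(lookup)
    if cached is not None and S > 0:
        # The model must still see the think-tag tokens that were stripped
        # from the lookup key.
        remaining = list(remaining) + input_ids[-S:]
    return cached, remaining

# Stored entry covers tokens 0..95; new prompt is 0..97 plus a 2-token suffix.
def fake_fetch(key: List[int]):
    stored = list(range(96))
    if key[: len(stored)] == stored:
        return "kv", key[len(stored):]     # remaining = uncached tail
    return None, key

prompt = list(range(98)) + [151667, 198]   # suffix token IDs are illustrative
cached, remaining = fetch_with_think_suffix(prompt, 2, fake_fetch)
assert cached == "kv" and remaining == [96, 97, 151667, 198]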
+ if cached_kv is not None and remaining_ids: + img_tok = getattr( + getattr(self.model, "config", None), + "image_token_index", + None, + ) + if img_tok is not None and img_tok in remaining_ids: + cached_kv = None + remaining_ids = None + + if cached_kv is not None and remaining_ids: + # Prefix/LCP match — run language model on remaining tokens + request_cache = cached_kv + remaining = mx.array(remaining_ids)[None, :] + cached_count = len(input_ids_list) - len(remaining_ids) + total_tokens = len(input_ids_list) + remaining_count = len(remaining_ids) + + with mx.stream(MLLMBatchGenerator._stream): + step = self.prefill_step_size + if remaining_count <= step: + # Short remaining — process in one shot + self._prefill_progress[req.request_id] = ( + total_tokens, + total_tokens, + ) + logits = self.language_model(remaining, cache=request_cache) + else: + # Chunked prefill on remaining tokens + self._prefill_progress[req.request_id] = ( + cached_count, + total_tokens, + ) + processed = 0 + chunk_count = 0 + while processed + step < remaining_count: + # Check for abort between chunks + if req.request_id in self._aborted_request_ids: + self._aborted_request_ids.discard(req.request_id) + logger.info( + f"[chunked_prefill] Aborted {req.request_id} " + f"at {cached_count + processed}/{total_tokens} tokens" + ) + raise PrefillAbortedError(req.request_id) + + chunk = remaining[:, processed : processed + step] + self.language_model(chunk, cache=request_cache) + mx.eval([c.state for c in request_cache]) + processed += step + chunk_count += 1 + self._prefill_progress[req.request_id] = ( + cached_count + processed, + total_tokens, + ) + if chunk_count % 4 == 0: + mx.clear_cache() + # Last chunk — return logits + remaining = remaining[:, processed:] + logits = self.language_model(remaining, cache=request_cache) + self._prefill_progress[req.request_id] = ( + total_tokens, + total_tokens, + ) + + if hasattr(logits, "logits"): + logits = logits.logits + + last_logits = logits[:, -1, :] + logprobs = last_logits - mx.logsumexp( + last_logits, axis=-1, keepdims=True + ) + sampled = self.sampler(logprobs) + mx.eval(sampled, logprobs) + + first_tokens.append(sampled.item()) + all_logprobs.append(logprobs.squeeze(0)) + + per_request_caches.append(request_cache) + req.vision_encoded = True + logger.debug( + f"Prefix cache hit for {req.request_id}: " + f"cached={cached_count}, " + f"remaining={remaining_count}" + ) - first_tokens.append(sampled.item()) - all_logprobs.append(logprobs.squeeze(0)) + elif cached_kv is not None and not remaining_ids: + # Exact/supersequence match — cache has all tokens, + # but we still need logits for the last token. + # fetch() with trim-by-1 store always returns remaining=[last_token]. + # If we get here (empty remaining), re-run on last token. 
+ request_cache = cached_kv + last_token = req.input_ids[:, -1:] + total_tokens = len(input_ids_list) + self._prefill_progress[req.request_id] = ( + total_tokens, + total_tokens, + ) - per_request_caches.append(request_cache) + with mx.stream(MLLMBatchGenerator._stream): + logits = self.language_model(last_token, cache=request_cache) + if hasattr(logits, "logits"): + logits = logits.logits + + last_logits = logits[:, -1, :] + logprobs = last_logits - mx.logsumexp( + last_logits, axis=-1, keepdims=True + ) + sampled = self.sampler(logprobs) + mx.eval(sampled, logprobs) + + first_tokens.append(sampled.item()) + all_logprobs.append(logprobs.squeeze(0)) + + per_request_caches.append(request_cache) + req.vision_encoded = True + logger.debug( + f"Prefix cache exact hit for {req.request_id}: " + f"all {total_tokens} tokens cached" + ) - # Merge per-request KVCaches into a single BatchKVCache. - # KVCache.merge() creates a BatchKVCache with proper left-padding - # alignment, so all requests share a single batched cache for - # subsequent generation steps. - from mlx_lm.models.cache import KVCache + else: + # Cache miss — full forward pass + request_cache = make_prompt_cache(self.language_model) + + with mx.stream(MLLMBatchGenerator._stream): + # Text-only: chunked prefill with real progress tracking + # Multimodal: atomic VLM forward (vision encoder needs full input) + if req.is_text_only: + logits = self._run_chunked_text_prefill( + req, cache=request_cache + ) + else: + logits = self._run_vision_encoding(req, cache=request_cache) + + # Extract last token logits and sample + last_logits = logits[:, -1, :] + logprobs = last_logits - mx.logsumexp( + last_logits, axis=-1, keepdims=True + ) + sampled = self.sampler(logprobs) + + mx.eval(sampled, logprobs) + + first_tokens.append(sampled.item()) + all_logprobs.append(logprobs.squeeze(0)) + + per_request_caches.append(request_cache) + + except PrefillAbortedError: + aborted_requests.append(req) + self._prefill_progress.pop(req.request_id, None) + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="abort", + ) + ) - sample_cache = per_request_caches[0][0] - if not isinstance(sample_cache, KVCache): - raise ValueError( - f"MLLM continuous batching requires standard KVCache but got " - f"{type(sample_cache).__name__}. Disable --kv-cache-quantization " - f"when using multimodal models with --continuous-batching." - ) + # Remove aborted requests — they have no entries in the parallel + # lists (first_tokens, all_logprobs, per_request_caches) + if aborted_requests: + for req in aborted_requests: + requests.remove(req) + mx.clear_cache() + if not requests: + return None + + # Merge per-request caches into batched caches. + # Both KVCache.merge() and ArraysCache.merge() produce batch-aware + # caches that support filter/extend/extract for continuous batching. + # + # Fix: RotatingKVCache._update_concat does NOT trim on first call — + # if prompt length > max_size, the buffer grows beyond max_size. + # BatchRotatingKVCache.merge() then hits a shape mismatch when + # copying via _temporal_order (full buffer) into a max_size slice. + # Trim buffer to max_size before merging. 
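A list-based sketch of the pre-merge normalization described above, with illustrative names: clamp a buffer that outgrew max_size during the first prefill, then make the write index agree with the number of entries actually stored before merge() copies them.

from typing import List, Tuple

def normalize_for_merge(buf: List[int], idx: int,
                        max_size: int) -> Tuple[List[int], int]:
    """Clamp an over-grown rotating buffer and align idx with its length."""
    # (1) First prompt longer than the window: keep the newest max_size entries.
    if len(buf) > max_size:
        buf = buf[len(buf) - max_size:]
        idx = max_size
    # (2) After rotation or a prefix-cache trim, idx may not equal the number
    # of entries actually stored; merging expects idx == len(buf).
    if idx != len(buf) and len(buf) > 0:
        buf = buf[idx:] + buf[:idx]        # back to temporal order
        idx = len(buf)
    return buf, idx

buf, idx = normalize_for_merge(list(range(10)), idx=10, max_size=4)
assert (buf, idx) == ([6, 7, 8, 9], 4)
buf, idx = normalize_for_merge([30, 40, 10, 20], idx=2, max_size=4)
assert (buf, idx) == ([10, 20, 30, 40], 4)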
+ from mlx_lm.models.cache import RotatingKVCache + + for rc in per_request_caches: + for layer_cache in rc: + if isinstance(layer_cache, RotatingKVCache): + if layer_cache.keys is not None: + buf_len = layer_cache.keys.shape[2] + if buf_len > layer_cache.max_size: + trim_size = buf_len - layer_cache.max_size + layer_cache.keys = layer_cache._trim( + trim_size, layer_cache.keys + ) + layer_cache.values = layer_cache._trim( + trim_size, layer_cache.values + ) + layer_cache._idx = layer_cache.max_size + # Normalize wrapped rotating cache for merge: + # after rotation _idx wraps around but merge() + # expects _idx == actual buffer size. + # Use keys.shape[2] (actual entries) NOT size() + # which can be inconsistent after prefix cache trim + # (size() = min(offset, max_size) but buffer may + # have fewer entries when trimmed). + actual_buf = layer_cache.keys.shape[2] + if layer_cache._idx != actual_buf and actual_buf > 0: + layer_cache.keys = layer_cache._temporal_order( + layer_cache.keys + ) + layer_cache.values = layer_cache._temporal_order( + layer_cache.values + ) + layer_cache._idx = actual_buf try: batch_cache = [ @@ -684,14 +1211,61 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: for layer_idx in range(len(per_request_caches[0])) ] except Exception as e: + sample_type = type(per_request_caches[0][0]).__name__ logger.error( - f"Failed to merge per-request KV caches: {type(e).__name__}: {e}" + f"Failed to merge per-request caches ({sample_type}): " + f"{type(e).__name__}: {e}" ) raise # Create initial y (first generated tokens) y = mx.array(first_tokens) + # Build per-request logits processors (repetition_penalty, presence_penalty) + from mlx_lm.sample_utils import make_logits_processors, make_sampler + + batch_logits_processors = [] + has_any_lp = False + for req in requests: + need_rep = req.repetition_penalty and req.repetition_penalty != 1.0 + need_pres = req.presence_penalty and req.presence_penalty != 0.0 + if need_rep or need_pres: + lp_kwargs = {} + if need_rep: + lp_kwargs["repetition_penalty"] = req.repetition_penalty + if need_pres: + lp_kwargs["presence_penalty"] = req.presence_penalty + lp = make_logits_processors(**lp_kwargs) + batch_logits_processors.append(lp) + has_any_lp = True + logger.info( + f"[sampling] request={req.request_id[:12]} " + f"rep_penalty={req.repetition_penalty} " + f"pres_penalty={req.presence_penalty}" + ) + else: + batch_logits_processors.append(None) + + # Build per-request samplers for top_k/min_p + batch_samplers = [] + has_any_sampler = False + for req in requests: + if req.top_k != 0 or req.min_p != 0.0: + s = make_sampler( + temp=req.temperature, + top_p=req.top_p, + top_k=req.top_k, + min_p=req.min_p, + ) + batch_samplers.append(s) + has_any_sampler = True + logger.info( + f"[sampling] request={req.request_id[:12]} " + f"top_k={req.top_k} min_p={req.min_p}" + ) + else: + batch_samplers.append(None) + self._stats.prompt_time += time.perf_counter() - tic return MLLMBatch( @@ -703,10 +1277,17 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: num_tokens=[0] * len(requests), cache=batch_cache, requests=requests, + logits_processors=batch_logits_processors if has_any_lp else None, + samplers=batch_samplers if has_any_sampler else None, ) def _step( - self, input_tokens: mx.array, cache: List[Any] + self, + input_tokens: mx.array, + cache: List[Any], + logits_processors: Optional[List[Optional[List[Callable]]]] = None, + output_tokens: Optional[List[List[int]]] = None, + samplers: 
Optional[List[Optional[Callable]]] = None, ) -> Tuple[mx.array, List[mx.array]]: """ Run one generation step through the language model. @@ -714,6 +1295,9 @@ def _step( Args: input_tokens: Input tokens [batch_size, 1] or [batch_size] cache: BatchKVCache for the language model + logits_processors: Per-request logits processors (e.g. repetition penalty) + output_tokens: Per-request generated tokens so far (needed by processors) + samplers: Per-request sampler functions (for top_k/min_p) Returns: Tuple of (sampled tokens, logprobs list) @@ -733,9 +1317,29 @@ def _step( logits = logits[:, -1, :] - # Sample + # Apply per-request logits processors (repetition penalty etc.) + if logits_processors and output_tokens and any(logits_processors): + processed_logits = [] + for e in range(logits.shape[0]): + sample_logits = logits[e : e + 1] + if logits_processors[e]: + for processor in logits_processors[e]: + sample_logits = processor( + mx.array(output_tokens[e]), sample_logits + ) + processed_logits.append(sample_logits) + logits = mx.concatenate(processed_logits, axis=0) + + # Sample — per-request samplers for top_k/min_p support logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - sampled = self.sampler(logprobs) + if samplers and any(samplers): + sampled_list = [] + for e in range(logprobs.shape[0]): + s = samplers[e] if samplers[e] else self.sampler + sampled_list.append(s(logprobs[e : e + 1])) + sampled = mx.concatenate(sampled_list, axis=0) + else: + sampled = self.sampler(logprobs) return sampled, list(logprobs) @@ -757,6 +1361,8 @@ def _next(self) -> List[MLLMBatchResponse]: # merged into a single BatchKVCache. Merging into an active batch # mid-generation would cause shape mismatches in attention layers, # so queued requests wait until the current batch finishes. + # Exception: text-only requests can be extended into an active batch + # via the elif branch below (they skip vision encoding entirely). if num_active == 0: requests = self.unprocessed_requests[: self.completion_batch_size] @@ -764,18 +1370,100 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None return [] - new_batch = self._process_prompts(requests) - self.unprocessed_requests = self.unprocessed_requests[len(requests) :] - self.active_batch = new_batch - prompt_processing = True + try: + # Save count before _process_prompts which modifies + # `requests` in-place via .remove() for failed items. + num_to_consume = len(requests) + new_batch = self._process_prompts(requests) + self.unprocessed_requests = self.unprocessed_requests[num_to_consume:] + self.active_batch = new_batch + prompt_processing = True + except Exception as e: + logger.error( + f"Failed to process batch of {len(requests)} prompts: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + # Remove failed requests to avoid infinite retry loop + self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + for req in requests: + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + # Mid-batch extend: text-only requests can join an active batch + # without vision encoding (no shape mismatch risk). + elif self.unprocessed_requests: + text_only = [ + r for r in self.unprocessed_requests if not r.images and not r.videos + ][: self.completion_batch_size] + + if text_only: + try: + # Capture UIDs before _process_prompts modifies + # text_only in-place via .remove() for failed items. 
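The per-request sampling path added to _step above reduces to slicing the batch row by row and applying that request's processors/sampler when present, falling back to the shared default otherwise. A rough sketch over plain Python lists (the real code slices MLX logits and builds callables with mlx_lm's make_sampler / make_logits_processors):

from typing import Callable, List, Optional, Sequence

Row = List[float]

def batched_sample(
    logits: List[Row],                                   # one row per request
    history: List[List[int]],                            # generated tokens so far
    processors: Optional[List[Optional[List[Callable]]]],
    samplers: Optional[List[Optional[Callable]]],
    default_sampler: Callable[[Row], int],
) -> List[int]:
    out: List[int] = []
    for i, row in enumerate(logits):
        if processors and processors[i]:
            for proc in processors[i]:                   # e.g. repetition penalty
                row = proc(history[i], row)
        sampler = samplers[i] if samplers and samplers[i] else default_sampler
        out.append(sampler(row))
    return out

def repetition_penalty(factor: float) -> Callable:
    def proc(prev: Sequence[int], row: Row) -> Row:
        row = list(row)
        for t in prev:                                   # dampen repeated tokens
            row[t] = row[t] / factor if row[t] > 0 else row[t] * factor
        return row
    return proc

argmax = lambda row: max(range(len(row)), key=row.__getitem__)
logits = [[0.1, 5.0, 0.2], [0.1, 5.0, 0.2]]
tokens = batched_sample(
    logits, history=[[1], []],
    processors=[[repetition_penalty(100.0)], None],
    samplers=None, default_sampler=argmax,
)
assert tokens == [2, 1]   # request 0 penalized token 1; request 1 untouched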
+ all_uids = {r.uid for r in text_only} + new_batch = self._process_prompts(text_only) + # Remove ALL requested (both successful and failed) + self.unprocessed_requests = [ + r for r in self.unprocessed_requests if r.uid not in all_uids + ] + if new_batch is not None: + batch.extend(new_batch) + prompt_processing = True + except Exception as e: + logger.warning( + f"Failed to extend batch with text-only requests: " + f"{type(e).__name__}: {e}" + ) + # Remove failed requests to avoid infinite retry loop + processed_uids = {r.uid for r in text_only} + self.unprocessed_requests = [ + r + for r in self.unprocessed_requests + if r.uid not in processed_uids + ] + for req in text_only: + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + # Collect any pending error responses (from failed preprocessing) + error_responses = [] + if self._pending_error_responses: + error_responses = list(self._pending_error_responses) + self._pending_error_responses.clear() # Generate next token for active batch batch = self.active_batch if batch is None: - return [] + return error_responses y, logprobs = batch.y, batch.logprobs - batch.y, batch.logprobs = self._step(y[:, None], batch.cache) + output_tokens = ( + [req.output_tokens for req in batch.requests] + if batch.logits_processors + else None + ) + batch.y, batch.logprobs = self._step( + y[:, None], + batch.cache, + batch.logits_processors, + output_tokens, + batch.samplers, + ) mx.async_eval(batch.y, batch.logprobs) y = y.tolist() @@ -821,6 +1509,8 @@ def _next(self) -> List[MLLMBatchResponse]: if finish_reason is not None: # Extract cache for this request cache_fn = lambda idx=i: batch.extract_cache(idx) + # Cleanup prefill progress tracking + self._prefill_progress.pop(request_id, None) responses.append( MLLMBatchResponse( @@ -833,6 +1523,9 @@ def _next(self) -> List[MLLMBatchResponse]: ) ) + # Store caches for finished text-only requests BEFORE filtering + self._maybe_store_prefix_cache(batch, end_idx) + # Remove finished requests from batch if end_idx: if keep_idx: @@ -841,7 +1534,7 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None self._stats.generation_tokens += len(responses) - return responses + return error_responses + responses def next(self) -> List[MLLMBatchResponse]: """ @@ -863,10 +1556,404 @@ def stats(self) -> MLLMBatchStats: self._stats.peak_memory = mx.get_peak_memory() / 1e9 return self._stats + def _maybe_store_prefix_cache( + self, batch: MLLMBatch, end_indices: List[int] + ) -> None: + """Store KV caches for finished text-only requests into prefix cache. + + Must be called BEFORE batch.filter() so that indices are still valid. + """ + if self.prefix_cache is None or not end_indices: + return + for i in end_indices: + req = batch.requests[i] + if req.input_ids is not None: + try: + extracted = batch.extract_cache(i) + input_ids_list = req.input_ids.reshape(-1).tolist() + # Store prompt-only KV (trim output tokens + 1 so next + # fetch returns remaining=[last_prompt_token] at minimum). + # Also strip think suffix from key so next request's + # (also stripped) key matches as a clean PREFIX. 
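The trim count in _maybe_store_prefix_cache above is worth spelling out: the extracted cache covers prompt plus generated tokens, and the stored entry must drop the generated tokens, one extra prompt position (so an exact match still has a token left to re-run for logits), and the think suffix (so the key matches stripped lookup keys). A small arithmetic sketch with illustrative names:

from typing import List, Tuple

def prefix_store_plan(prompt_ids: List[int],
                      num_output_tokens: int,
                      think_suffix_len: int) -> Tuple[List[int], int]:
    """Return (cache_key, positions_to_trim_from_extracted_cache)."""
    S = think_suffix_len
    trim = num_output_tokens + 1 + S
    key = prompt_ids[:-S] if S > 0 else prompt_ids
    return key, trim

prompt = list(range(50)) + [7, 8]          # 50 prompt tokens + 2-token suffix
key, trim = prefix_store_plan(prompt, num_output_tokens=20, think_suffix_len=2)
assert key == list(range(50)) and trim == 23
# Stored KV covers len(prompt) + 20 - trim == 49 positions: exactly the key
# minus its final token, so a later exact hit always has one token to re-run.
assert len(prompt) + 20 - trim == len(key) - 1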
+ output_count = batch.num_tokens[i] + S = self._think_suffix_len + total_trim = output_count + 1 + S + prompt_cache = _trim_cache_offset(extracted, total_trim) + cache_key = input_ids_list[:-S] if S > 0 else input_ids_list + self.prefix_cache.store(cache_key, prompt_cache) + except Exception as e: + logger.warning( + f"Failed to store prefix cache for {req.request_id}: {type(e).__name__}: {e}" + ) + + def get_prefill_progress(self, request_id: str) -> Optional[Tuple[int, int]]: + """Return (processed_tokens, total_tokens) or None.""" + return self._prefill_progress.get(request_id) + def get_vision_cache_stats(self) -> Dict[str, Any]: """Get vision cache statistics.""" return self.vision_cache.get_stats() + def get_prefix_cache_stats(self) -> Dict[str, Any]: + """Get KV prefix cache statistics.""" + if self.prefix_cache is not None: + return self.prefix_cache.get_stats() + return { + "hits": 0, + "misses": 0, + "hit_rate": 0.0, + "evictions": 0, + "tokens_saved": 0, + "current_memory_mb": 0.0, + "max_memory_mb": 0.0, + "memory_utilization": 0.0, + "entry_count": 0, + } + def has_pending(self) -> bool: """Check if there are pending or active requests.""" return bool(self.unprocessed_requests or self.active_batch) + + +def install_mtp_mllm( + batch_gen: "MLLMBatchGenerator", + language_model: Any, + num_draft_tokens: int = 1, +) -> None: + """Install MTP (Multi-Token Prediction) on an MLLMBatchGenerator. + + Adapts the always-advance MTP strategy from scheduler._install_mtp + for the MLLM batched generation path. Handles hybrid model caches + (BatchKVCache for attention + ArraysCache for recurrent layers). + + Flow per generation step: + 1. Use skip_state logits/hidden OR run model forward -> sample primary + 2. MTP head drafts one token + 3. Verify [primary, draft] in one model call (always advances cache) + 4. Accept: skip_state from pos 1, defer draft for next step emission + Reject: trim KV by 2 + restore RNN state + re-advance with primary + 5. 
Draft is emitted in the NEXT generation step after primary + """ + from .scheduler import make_sampler + + _orig_step = batch_gen._step + _draft_sampler = make_sampler(temp=0.0) + + # Skip state: stored logits + hidden from verify pass + _skip_state: list = [None] + + # Deferred drafts keyed by UID + _deferred_drafts: Dict[int, dict] = {} + + # MTP stats + _mtp_stats = {"accepted": 0, "rejected": 0, "errors": 0} + + def _mtp_step( + input_tokens: mx.array, + cache: List[Any], + logits_processors: Optional[List[Optional[List[Callable]]]] = None, + output_tokens: Optional[List[List[int]]] = None, + samplers: Optional[List[Optional[Callable]]] = None, + ) -> Tuple[mx.array, List[mx.array]]: + """Extended _step with MTP always-advance strategy.""" + batch_size = input_tokens.shape[0] + + # Prefill guard: skip MTP for multi-token input or when no active batch + # Also skip MTP when batch has multiple active requests (MTP overhead + # hurts aggregate throughput in concurrent scenarios) + if ( + input_tokens.shape[1] > 1 + or batch_gen.active_batch is None + or len(batch_gen.active_batch) > 1 + ): + _skip_state[0] = None + return _orig_step( + input_tokens, cache, logits_processors, output_tokens, samplers + ) + + # Check skip state + skip = _skip_state[0] + if skip is not None and skip["logits"].shape[0] != batch_size: + skip = None + _skip_state[0] = None + + if skip is not None: + logits = skip["logits"] + hidden_states = skip["hidden"] + _skip_state[0] = None + else: + # Normal forward with return_hidden + model_output = language_model(input_tokens, cache=cache, return_hidden=True) + if isinstance(model_output, tuple): + logits, hidden_states = model_output + else: + return _orig_step( + input_tokens, cache, logits_processors, output_tokens, samplers + ) + logits = logits[:, -1, :] + + # Apply logits processors before sampling + if logits_processors and output_tokens and any(logits_processors): + processed_logits = [] + for e in range(batch_size): + sample_logits = logits[e : e + 1] + if logits_processors[e]: + for processor in logits_processors[e]: + sample_logits = processor( + mx.array(output_tokens[e]), sample_logits + ) + processed_logits.append(sample_logits) + logits = mx.concatenate(processed_logits, axis=0) + + # Sample primary (use per-request sampler if available) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + if samplers and any(samplers): + sampled_list = [] + for e in range(logprobs.shape[0]): + s = samplers[e] if samplers[e] else batch_gen.sampler + sampled_list.append(s(logprobs[e : e + 1])) + primary_tokens = mx.concatenate(sampled_list, axis=0) + else: + primary_tokens = batch_gen.sampler(logprobs) + + current_uids = list(batch_gen.active_batch.uids) + + # MTP draft + always-advance verify + try: + draft_logits = language_model.mtp_forward( + hidden_states[:, -1:, :], + primary_tokens[:, None], + mtp_cache=None, + ) + draft_logits = draft_logits[:, -1, :] + draft_logprobs = draft_logits - mx.logsumexp( + draft_logits, axis=-1, keepdims=True + ) + draft_tokens = _draft_sampler(draft_logprobs) + + # Snapshot RNN state for hybrid models + _rnn_snapshots = {} + for _ci, _c in enumerate(cache): + if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()): + if hasattr(_c, "state"): + _rnn_snapshots[_ci] = [ + mx.array(s) if s is not None else None for s in _c.state + ] + + # Verify [primary, draft] + verify_input = mx.concatenate( + [primary_tokens[:, None], draft_tokens[:, None]], axis=1 + ) + verify_output = language_model( + verify_input, cache=cache, 
return_hidden=True + ) + if isinstance(verify_output, tuple): + verify_logits, verify_hidden = verify_output + else: + verify_logits = verify_output + verify_hidden = None + + # Verified mode: check if draft matches verify prediction + verify_pred = mx.argmax(verify_logits[:, 0, :], axis=-1) + mx.eval(verify_pred, draft_tokens) + pred_list = verify_pred.tolist() + draft_list = draft_tokens.tolist() + all_accepted = pred_list == draft_list + + if all_accepted and verify_hidden is not None: + # ACCEPT + _skip_state[0] = { + "logits": verify_logits[:, 1, :], + "hidden": verify_hidden[:, -1:, :], + } + mx.async_eval(_skip_state[0]["logits"], _skip_state[0]["hidden"]) + verify_lp = verify_logits[:, 0, :] - mx.logsumexp( + verify_logits[:, 0, :], axis=-1, keepdims=True + ) + for e in range(batch_size): + uid = current_uids[e] + _deferred_drafts[uid] = { + "token": draft_list[e], + "logprobs": verify_lp[e], + } + _mtp_stats["accepted"] += 1 + + else: + # REJECT + if _rnn_snapshots: + # Hybrid model: undo entire verify, re-advance with primary + for c in cache: + if ( + hasattr(c, "is_trimmable") + and c.is_trimmable() + and hasattr(c, "trim") + ): + c.trim(2) + for _ci, _snap in _rnn_snapshots.items(): + cache[_ci].state = _snap + rerun_out = language_model( + primary_tokens[:, None], + cache=cache, + return_hidden=True, + ) + if isinstance(rerun_out, tuple): + rerun_logits, rerun_hidden = rerun_out + else: + rerun_logits = rerun_out + rerun_hidden = None + if rerun_hidden is not None: + _skip_state[0] = { + "logits": rerun_logits[:, -1, :], + "hidden": rerun_hidden[:, -1:, :], + } + mx.async_eval( + _skip_state[0]["logits"], + _skip_state[0]["hidden"], + ) + else: + _skip_state[0] = None + else: + # Pure attention model: simple trim + for c in cache: + if ( + hasattr(c, "is_trimmable") + and c.is_trimmable() + and hasattr(c, "trim") + ): + c.trim(1) + if verify_hidden is not None: + _skip_state[0] = { + "logits": verify_logits[:, 0, :], + "hidden": verify_hidden[:, 0:1, :], + } + mx.async_eval( + _skip_state[0]["logits"], + _skip_state[0]["hidden"], + ) + else: + _skip_state[0] = None + for uid in current_uids: + _deferred_drafts.pop(uid, None) + _mtp_stats["rejected"] += 1 + + except Exception as e: + logger.warning(f"[MTP-MLLM] draft/verify failed: {e}") + _skip_state[0] = None + _mtp_stats["errors"] += 1 + + # Log MTP stats every 50 steps + total = _mtp_stats["accepted"] + _mtp_stats["rejected"] + _mtp_stats["errors"] + if total > 0 and total % 50 == 0: + acc = _mtp_stats["accepted"] + rej = _mtp_stats["rejected"] + err = _mtp_stats["errors"] + rate = acc / (acc + rej) * 100 if (acc + rej) > 0 else 0 + logger.info( + f"[MTP-MLLM] stats: accepted={acc} rejected={rej} " + f"errors={err} acceptance={rate:.0f}%" + ) + + return primary_tokens, list(logprobs) + + # Wrap _next to emit deferred MTP drafts + batch_gen._inner_next = batch_gen._next + + def _mtp_next() -> List[MLLMBatchResponse]: + """Wrapper around _next that emits deferred MTP draft tokens.""" + if batch_gen.active_batch is None: + _skip_state[0] = None + _deferred_drafts.clear() + + # Save deferred drafts from previous step + prev_deferred: Dict[int, dict] = {} + if batch_gen.active_batch is not None: + for uid in batch_gen.active_batch.uids: + if uid in _deferred_drafts: + prev_deferred[uid] = _deferred_drafts.pop(uid) + + responses = batch_gen._inner_next() + + if not prev_deferred or not responses: + return responses + + # Augment responses with deferred drafts + augmented: List[MLLMBatchResponse] = [] + draft_end_uids: set = 
set() + + for r in responses: + uid = r.uid + augmented.append(r) + + if r.finish_reason is not None: + _deferred_drafts.pop(uid, None) + prev_deferred.pop(uid, None) + continue + + if uid in prev_deferred: + draft_info = prev_deferred.pop(uid) + draft_t = draft_info["token"] + draft_lp = draft_info["logprobs"] + + if draft_t in batch_gen.stop_tokens: + augmented.append( + MLLMBatchResponse( + uid=uid, + request_id=r.request_id, + token=draft_t, + logprobs=draft_lp, + finish_reason="stop", + ) + ) + draft_end_uids.add(uid) + else: + draft_finish = None + batch = batch_gen.active_batch + if batch is not None: + for e, bu in enumerate(batch.uids): + if bu == uid: + batch.num_tokens[e] += 1 + batch.requests[e].output_tokens.append(draft_t) + if batch.num_tokens[e] >= batch.max_tokens[e]: + draft_finish = "length" + draft_end_uids.add(uid) + break + + augmented.append( + MLLMBatchResponse( + uid=uid, + request_id=r.request_id, + token=draft_t, + logprobs=draft_lp, + finish_reason=draft_finish, + ) + ) + + # Store prefix caches for draft-ended sequences BEFORE filtering + if draft_end_uids and batch_gen.active_batch is not None: + end_indices = [ + e + for e, u in enumerate(batch_gen.active_batch.uids) + if u in draft_end_uids + ] + batch_gen._maybe_store_prefix_cache(batch_gen.active_batch, end_indices) + + keep = [ + e + for e, u in enumerate(batch_gen.active_batch.uids) + if u not in draft_end_uids + ] + if keep: + batch_gen.active_batch.filter(keep) + else: + batch_gen.active_batch = None + + return augmented + + batch_gen._step = _mtp_step + batch_gen._next = _mtp_next + + total = _mtp_stats + logger.info( + f"[MTP-MLLM] installed with num_draft_tokens={num_draft_tokens}, " + f"always-advance verified mode" + ) diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 555b230f2..04c7cac2a 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -19,6 +19,7 @@ """ import asyncio +import concurrent.futures import logging import time import uuid @@ -35,7 +36,6 @@ MLLMBatchRequest, MLLMBatchResponse, ) -from .mllm_cache import MLLMCacheManager from .multimodal_processor import MultimodalProcessor from .request import RequestOutput, RequestStatus, SamplingParams @@ -62,8 +62,22 @@ class MLLMSchedulerConfig: default_max_tokens: int = 256 # Default video FPS for frame extraction default_video_fps: float = 2.0 + # KV cache memory limit (from --cache-memory-mb) + cache_memory_mb: Optional[int] = None # Maximum video frames max_video_frames: int = 128 + # Enable MTP speculative decoding + enable_mtp: bool = False + # Number of draft tokens for MTP + mtp_num_draft_tokens: int = 1 + # Enable KV prefix cache for text-only requests + enable_prefix_cache: bool = True + # Memory limit for prefix cache (None = auto-detect) + prefix_cache_memory_mb: Optional[int] = None + # KV cache quantization for prefix cache store/fetch + kv_cache_quantization: bool = False + kv_cache_quantization_bits: int = 8 + kv_cache_quantization_group_size: int = 64 @dataclass @@ -94,6 +108,9 @@ class MLLMRequest: num_prompt_tokens: int = 0 num_output_tokens: int = 0 + # Timing + first_token_time: Optional[float] = None + @dataclass class MLLMSchedulerOutput: @@ -176,13 +193,6 @@ def __init__( config=self.model_config, ) - # Vision cache for repeated images - self.vision_cache: Optional[MLLMCacheManager] = None - if self.config.enable_vision_cache: - self.vision_cache = MLLMCacheManager( - max_entries=self.config.vision_cache_size - ) - # Get stop tokens from tokenizer self.stop_tokens = 
self._get_stop_tokens() @@ -218,8 +228,12 @@ def __init__( self.total_prompt_tokens = 0 self.total_completion_tokens = 0 + # Memory management: periodic mx.clear_cache() to free Metal buffers + self._step_count = 0 + self._clear_cache_interval = 32 + def _get_stop_tokens(self) -> Set[int]: - """Get stop token IDs from tokenizer.""" + """Get stop token IDs from tokenizer and generation_config.json.""" stop_tokens = set() tokenizer = ( self.processor.tokenizer @@ -239,6 +253,25 @@ def _get_stop_tokens(self) -> Set[int]: else: stop_tokens.add(tokenizer.eos_token_ids) + # Also read generation_config.json which may have additional EOS tokens + # (e.g., Gemma 4 has =106, <|tool_response>=50 as EOS) + model_path = getattr(tokenizer, "name_or_path", None) + if model_path: + import json + from pathlib import Path + + gc_path = Path(model_path) / "generation_config.json" + if gc_path.exists(): + try: + gc = json.loads(gc_path.read_text()) + gc_eos = gc.get("eos_token_id") + if isinstance(gc_eos, list): + stop_tokens.update(gc_eos) + elif gc_eos is not None: + stop_tokens.add(gc_eos) + except Exception: + pass + return stop_tokens def _ensure_batch_generator(self) -> None: @@ -246,9 +279,24 @@ def _ensure_batch_generator(self) -> None: if self.batch_generator is None: from mlx_lm.sample_utils import make_sampler + from .memory_cache import MemoryCacheConfig + # Default sampler (can be overridden per-request in future) sampler = make_sampler(temp=0.7, top_p=0.9) + # Configure KV prefix cache for text-only requests + # KV cache quantization reduces prefix cache memory ~4x (BF16→Q8). + # Quantization happens on store(), dequantization on fetch() — + # the model always receives normal KVCache with plain arrays. + prefix_cache_config = None + if self.config.enable_prefix_cache: + prefix_cache_config = MemoryCacheConfig( + max_memory_mb=self.config.prefix_cache_memory_mb, + kv_quantize=self.config.kv_cache_quantization, + kv_bits=self.config.kv_cache_quantization_bits, + kv_group_size=self.config.kv_cache_quantization_group_size, + ) + self.batch_generator = MLLMBatchGenerator( model=self.model, processor=self.processor, @@ -259,8 +307,21 @@ def _ensure_batch_generator(self) -> None: prefill_batch_size=self.config.prefill_batch_size, completion_batch_size=self.config.completion_batch_size, prefill_step_size=self.config.prefill_step_size, + prefix_cache_config=prefix_cache_config, ) + # Install MTP if enabled and language model supports it + if self.config.enable_mtp: + lm = self.batch_generator.language_model + if hasattr(lm, "mtp") and lm.mtp is not None: + from .mllm_batch_generator import install_mtp_mllm + + install_mtp_mllm( + self.batch_generator, + lm, + num_draft_tokens=self.config.mtp_num_draft_tokens, + ) + # ========== Sync API (step-based) ========== def add_request( @@ -297,6 +358,10 @@ def add_request( max_tokens=max_tokens, temperature=temperature, top_p=top_p, + top_k=kwargs.pop("top_k", 0), + min_p=kwargs.pop("min_p", 0.0), + presence_penalty=kwargs.pop("presence_penalty", 0.0), + repetition_penalty=kwargs.pop("repetition_penalty", 1.0), ) request = MLLMRequest( @@ -307,6 +372,19 @@ def add_request( sampling_params=sampling_params, ) + # Estimate prompt token count for monitoring (text tokens only; + # vision tokens are added during prefill but this gives a useful + # approximation for the status endpoint). 
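+ # The estimate is best-effort: processors without a usable text
+ # tokenizer may fail to encode, in which case num_prompt_tokens
+ # simply stays at 0.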
+ tokenizer = ( + self.processor.tokenizer + if hasattr(self.processor, "tokenizer") + else self.processor + ) + try: + request.num_prompt_tokens = len(tokenizer.encode(prompt)) + except Exception: + pass + self.requests[request_id] = request self.waiting.append(request) @@ -331,6 +409,12 @@ def abort_request(self, request_id: str) -> bool: if request is None: return False + # Signal batch generator to abort any in-progress prefill for this + # request. The prefill loop checks _aborted_request_ids between + # chunks and raises PrefillAbortedError to exit early. + if self.batch_generator is not None: + self.batch_generator.abort_prefill(request_id) + # Remove from waiting queue if request.status == RequestStatus.WAITING: try: @@ -403,6 +487,10 @@ def _schedule_waiting(self) -> List[MLLMRequest]: max_tokens=request.sampling_params.max_tokens, temperature=request.sampling_params.temperature, top_p=request.sampling_params.top_p, + top_k=request.sampling_params.top_k, + min_p=request.sampling_params.min_p, + presence_penalty=request.sampling_params.presence_penalty, + repetition_penalty=request.sampling_params.repetition_penalty, ) batch_requests.append(batch_req) @@ -453,21 +541,41 @@ def _process_batch_responses( if request is None: continue + # Handle error responses from failed preprocessing + if response.finish_reason == "error": + output = RequestOutput( + request_id=request_id, + new_token_ids=[], + new_text="", + output_token_ids=[], + prompt_tokens=0, + completion_tokens=0, + finished=True, + finish_reason="error", + ) + request.status = RequestStatus.FINISHED_ABORTED + request.output_text = "" + request.finish_reason = "error" + finished_ids.add(request_id) + self.num_requests_processed += 1 + logger.warning(f"Request {request_id} failed during preprocessing") + outputs.append(output) + continue + # Append token to request request.output_tokens.append(response.token) request.num_output_tokens = len(request.output_tokens) + if request.first_token_time is None and request.num_output_tokens > 0: + request.first_token_time = time.time() + # Decode the new token using streaming detokenizer (UTF-8 safe). # Skip stop tokens — they are not content. 
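+ # The stop token itself was already appended to output_tokens above;
+ # only its text is suppressed so clients never receive the raw EOS.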
if response.finish_reason == "stop": new_text = "" else: if request_id not in self._detokenizer_pool: - if hasattr(tokenizer, "detokenizer"): - detok = tokenizer.detokenizer - else: - detok = NaiveStreamingDetokenizer(tokenizer) - detok.reset() + detok = NaiveStreamingDetokenizer(tokenizer) self._detokenizer_pool[request_id] = detok detok = self._detokenizer_pool[request_id] detok.add_token(response.token) @@ -495,7 +603,7 @@ def _process_batch_responses( finished_ids.add(request_id) # Finalize streaming detokenizer and get full output - detok = self._detokenizer_pool.get(request_id) + detok = self._detokenizer_pool.pop(request_id, None) if detok is not None: detok.finalize() output.output_text = detok.text @@ -503,7 +611,6 @@ def _process_batch_responses( output.output_text = tokenizer.decode(request.output_tokens) request.output_text = output.output_text request.finish_reason = response.finish_reason - self._detokenizer_pool.pop(request_id, None) self.total_completion_tokens += request.num_output_tokens self.num_requests_processed += 1 @@ -524,6 +631,9 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: if request_id in self.running: del self.running[request_id] + # Drain from requests dict to prevent linear memory growth + self.requests.pop(request_id, None) + # Remove UID mappings if request_id in self.request_id_to_uid: uid = self.request_id_to_uid[request_id] @@ -531,10 +641,17 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: del self.uid_to_request_id[uid] del self.request_id_to_uid[request_id] + # Clean up detokenizer pool (handles abort/timeout cases) + self._detokenizer_pool.pop(request_id, None) + # Track as finished self.finished_req_ids.add(request_id) self.requests.pop(request_id, None) + # Clear Metal buffer pool after cleanup to release memory + if finished_ids: + mx.clear_cache() + def step(self) -> MLLMSchedulerOutput: """ Execute one scheduling step. @@ -634,14 +751,33 @@ async def stop(self) -> None: logger.info("MLLM Scheduler stopped") async def _process_loop(self) -> None: - """Main async processing loop.""" + """Main async processing loop. + + Uses a thread pool executor for steps that involve prefill + (waiting requests or partial prefill in progress) so that the + event loop stays responsive for health checks and other HTTP + endpoints. Decode-only steps are fast (<3 ms) and run inline. 
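+
+ A single worker (max_workers=1) plus the await on run_in_executor keeps
+ step() calls strictly serialized, so offloaded prefill steps and inline
+ decode steps never run concurrently.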
+ """ + _executor = concurrent.futures.ThreadPoolExecutor( + max_workers=1, thread_name_prefix="mllm-step" + ) + loop = asyncio.get_running_loop() + while self._running: try: if self.has_requests(): - # Run one step - self.step() - # Yield to other tasks - await asyncio.sleep(0) + has_waiting = self.get_num_waiting() > 0 + has_partial = ( + self.batch_generator is not None + and getattr(self.batch_generator, "_partial", None) is not None + ) + needs_executor = has_waiting or has_partial + + if needs_executor: + await loop.run_in_executor(_executor, self.step) + else: + self.step() + await asyncio.sleep(0) else: # No work, wait a bit await asyncio.sleep(0.01) @@ -649,7 +785,7 @@ async def _process_loop(self) -> None: except asyncio.CancelledError: break except Exception as e: - logger.error(f"Error in MLLM process loop: {e}") + logger.error(f"Error in MLLM process loop: {e}", exc_info=True) await asyncio.sleep(0.1) async def add_request_async( @@ -778,6 +914,77 @@ async def generate( # ========== Stats and utilities ========== + def get_running_requests_info(self) -> List[Dict[str, Any]]: + """Per-request details for status endpoint.""" + now = time.time() + result = [] + + # Waiting requests + for req in self.waiting: + result.append( + { + "request_id": req.request_id, + "status": "waiting", + "phase": "queued", + "elapsed_s": round(now - req.arrival_time, 2), + "prompt_tokens": req.num_prompt_tokens, + "completion_tokens": 0, + "max_tokens": req.sampling_params.max_tokens, + "progress": 0.0, + "tokens_per_second": None, + "ttft_s": None, + "cache_hit_type": None, + "cached_tokens": 0, + } + ) + + # Running requests + for req in self.running.values(): + n_out = req.num_output_tokens + elapsed = now - req.arrival_time + + if n_out == 0: + phase = "prefill" + else: + phase = "generation" + + tok_s = None + ttft = None + if req.first_token_time is not None: + ttft = round(req.first_token_time - req.arrival_time, 3) + gen_elapsed = now - req.first_token_time + if gen_elapsed > 0 and n_out > 0: + tok_s = round(n_out / gen_elapsed, 1) + + max_tokens = req.sampling_params.max_tokens + if phase == "prefill" and self.batch_generator is not None: + pp = self.batch_generator.get_prefill_progress(req.request_id) + if pp is not None: + progress = round(pp[0] / pp[1], 3) if pp[1] > 0 else 0.0 + else: + progress = 0.0 + else: + progress = round(n_out / max_tokens, 3) if max_tokens > 0 else 0.0 + + result.append( + { + "request_id": req.request_id, + "status": "running", + "phase": phase, + "elapsed_s": round(elapsed, 2), + "prompt_tokens": req.num_prompt_tokens, + "completion_tokens": n_out, + "max_tokens": max_tokens, + "progress": min(progress, 1.0), + "tokens_per_second": tok_s, + "ttft_s": ttft, + "cache_hit_type": None, + "cached_tokens": 0, + } + ) + + return result + def get_stats(self) -> Dict[str, Any]: """Get scheduler statistics.""" stats = { @@ -787,27 +994,45 @@ def get_stats(self) -> Dict[str, Any]: "num_requests_processed": self.num_requests_processed, "total_prompt_tokens": self.total_prompt_tokens, "total_completion_tokens": self.total_completion_tokens, + "requests": self.get_running_requests_info(), } if self.batch_generator is not None: batch_stats = self.batch_generator.stats() stats["batch_generator"] = batch_stats.to_dict() - # Add vision embedding cache stats from batch generator - stats["vision_embedding_cache"] = ( - self.batch_generator.get_vision_cache_stats() - ) - - if self.vision_cache: - stats["vision_cache"] = self.vision_cache.get_stats() + # Vision embedding cache 
stats from batch generator + vec_stats = self.batch_generator.get_vision_cache_stats() + stats["vision_embedding_cache"] = vec_stats # Include Metal memory stats try: if mx.metal.is_available(): - stats["metal_active_memory_gb"] = round(mx.get_active_memory() / 1e9, 2) - stats["metal_peak_memory_gb"] = round(mx.get_peak_memory() / 1e9, 2) - stats["metal_cache_memory_gb"] = round(mx.get_cache_memory() / 1e9, 2) + active_gb = round(mx.get_active_memory() / 1e9, 2) + peak_gb = round(mx.get_peak_memory() / 1e9, 2) + cache_gb = round(mx.get_cache_memory() / 1e9, 2) + stats["metal_active_memory_gb"] = active_gb + stats["metal_peak_memory_gb"] = peak_gb + stats["metal_cache_memory_gb"] = cache_gb except Exception: - pass + active_gb = 0 + cache_gb = 0 + + # KV prefix cache stats for /v1/status and monitoring UI. + if self.batch_generator is not None: + prefix_stats = self.batch_generator.get_prefix_cache_stats() + else: + prefix_stats = { + "hits": 0, + "misses": 0, + "hit_rate": 0.0, + "evictions": 0, + "tokens_saved": 0, + "current_memory_mb": 0.0, + "max_memory_mb": 0.0, + "memory_utilization": 0.0, + "entry_count": 0, + } + stats["memory_aware_cache"] = prefix_stats return stats diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 811a6d4da..46e26a744 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -111,6 +111,8 @@ def _create_sampler( self, temperature: float = 0.7, top_p: float = 0.9, + top_k: int = 0, + min_p: float = 0.0, ): """Create a sampler for text generation.""" from mlx_lm.sample_utils import make_sampler @@ -118,16 +120,38 @@ def _create_sampler( return make_sampler( temp=temperature, top_p=top_p, + top_k=top_k, + min_p=min_p, ) + def _create_logits_processors( + self, + presence_penalty: float = 0.0, + repetition_penalty: float = 1.0, + ): + """Create logits processors for penalty-based sampling.""" + from mlx_lm.sample_utils import make_logits_processors + + processors = make_logits_processors( + repetition_penalty=( + repetition_penalty if repetition_penalty != 1.0 else None + ), + presence_penalty=presence_penalty if presence_penalty != 0.0 else None, + ) + return processors if processors else None + def generate( self, prompt: str, max_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9, + top_k: int = 0, + min_p: float = 0.0, + presence_penalty: float = 0.0, repetition_penalty: float = 1.0, stop: list[str] | None = None, + **kwargs, ) -> GenerationOutput: """ Generate text from a prompt. 
@@ -137,7 +161,10 @@ def generate( max_tokens: Maximum number of tokens to generate temperature: Sampling temperature (0 = greedy) top_p: Top-p (nucleus) sampling parameter - repetition_penalty: Penalty for repeating tokens + top_k: Top-k sampling (0 = disabled) + min_p: Minimum probability threshold + presence_penalty: Additive penalty for token presence + repetition_penalty: Multiplicative penalty for repeating tokens stop: List of stop sequences Returns: @@ -148,8 +175,11 @@ def generate( from mlx_lm import generate - # Create sampler with parameters - sampler = self._create_sampler(temperature, top_p) + # Create sampler and logits processors with full Unsloth params + sampler = self._create_sampler(temperature, top_p, top_k, min_p) + logits_processors = self._create_logits_processors( + presence_penalty, repetition_penalty + ) # Generate text output_text = generate( @@ -158,6 +188,7 @@ def generate( prompt=prompt, max_tokens=max_tokens, sampler=sampler, + logits_processors=logits_processors, verbose=False, ) @@ -179,8 +210,13 @@ def stream_generate( max_tokens: int = 256, temperature: float = 0.7, top_p: float = 0.9, + top_k: int = 0, + min_p: float = 0.0, + presence_penalty: float = 0.0, repetition_penalty: float = 1.0, stop: list[str] | None = None, + logits_processors: list | None = None, + **kwargs, ) -> Iterator[StreamingOutput]: """ Stream text generation token by token. @@ -190,7 +226,10 @@ def stream_generate( max_tokens: Maximum number of tokens to generate temperature: Sampling temperature (0 = greedy) top_p: Top-p (nucleus) sampling parameter - repetition_penalty: Penalty for repeating tokens + top_k: Top-k sampling (0 = disabled) + min_p: Minimum probability threshold + presence_penalty: Additive penalty for token presence + repetition_penalty: Multiplicative penalty for repeating tokens stop: List of stop sequences Yields: @@ -201,8 +240,15 @@ def stream_generate( from mlx_lm import stream_generate - # Create sampler with parameters - sampler = self._create_sampler(temperature, top_p) + # Create sampler and logits processors with full Unsloth params + sampler = self._create_sampler(temperature, top_p, top_k, min_p) + penalty_processors = self._create_logits_processors( + presence_penalty, repetition_penalty + ) + # Merge any externally-provided logits_processors with penalty processors + all_processors = None + if penalty_processors or logits_processors: + all_processors = (logits_processors or []) + (penalty_processors or []) # Count prompt tokens once upfront num_prompt_tokens = len(self.tokenizer.encode(prompt)) @@ -220,6 +266,7 @@ def stream_generate( prompt=prompt, max_tokens=max_tokens, sampler=sampler, + logits_processors=all_processors, **mtp_kwargs, ): token_count += 1 diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index fcf3537f4..5a3551eb1 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -465,8 +465,9 @@ def save_base64_image(base64_string: str) -> str: """Save base64 image to temp file and return path. Caches identical images.""" import hashlib - # Hash the base64 string to check cache - image_hash = hashlib.md5(base64_string.encode()).hexdigest() + # Hash the full base64 string to prevent collisions between images + # with identical headers (e.g. 
JPEG images sharing first 1000 chars) + image_hash = hashlib.sha256(base64_string.encode()).hexdigest() # Return cached path if available and file still exists if image_hash in _base64_image_cache: @@ -1328,6 +1329,7 @@ def chat( video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) tools = kwargs.pop("tools", None) use_cache = kwargs.pop("use_cache", True) + enable_thinking = kwargs.pop("enable_thinking", True) # Collect video inputs from messages _msg_video_inputs = self._collect_video_inputs(messages) @@ -1453,11 +1455,11 @@ def chat( template_extra_kwargs["tools"] = tools try: - # Use get_chat_template directly since messages are already properly formatted formatted_prompt = get_chat_template( self.processor, chat_messages, add_generation_prompt=True, + enable_thinking=enable_thinking, **template_extra_kwargs, ) except Exception as e: @@ -1724,6 +1726,7 @@ def stream_chat( video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) tools = kwargs.pop("tools", None) use_cache = kwargs.pop("use_cache", True) + enable_thinking = kwargs.pop("enable_thinking", True) # Collect video inputs from messages _msg_video_inputs = self._collect_video_inputs(messages) @@ -1838,6 +1841,7 @@ def stream_chat( self.processor, chat_messages, add_generation_prompt=True, + enable_thinking=enable_thinking, **template_extra_kwargs, ) except Exception as e: diff --git a/vllm_mlx/patches/gemma4_mllm.py b/vllm_mlx/patches/gemma4_mllm.py new file mode 100644 index 000000000..dc041cf31 --- /dev/null +++ b/vllm_mlx/patches/gemma4_mllm.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime patch for mlx-vlm's Gemma 4 attention to support BatchKVCache. + +Gemma 4 Attention reads cache.offset into a local variable before calling +update_and_fetch, then uses the same variable later for RoPE on queries: + + offset = cache.offset # reference to mx.array([22]) + keys = self.rope(keys, offset=offset) + keys, values = cache.update_and_fetch(keys, values) + # ^^^ self.offset += 1 mutates the SAME mx.array in-place! + queries = self.rope(queries, offset=offset) # offset is now 23! + +For KVCache, cache.offset is a Python int (immutable), so the local copy +is unaffected. For BatchKVCache, cache.offset is an mx.array and +mx.array.__iadd__ is *in-place*, so the local reference is silently +mutated by update_and_fetch, giving queries the wrong RoPE position. + +This patch replaces Gemma4 Attention.__call__ with a version that +snapshots cache.offset as a defensive copy before any mutation can occur. +The mx.array copy preserves per-sequence offsets needed for correct RoPE +in continuous batching (unlike int conversion which would lose this info). +""" + +import logging +from typing import Any, Optional + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +def _snapshot_cache_offset(cache): + """Snapshot cache offset, making a defensive copy if it's an mx.array. + + BatchKVCache stores offset as mx.array (per-batch-item). + mx.array.__iadd__ is in-place, so update_and_fetch mutates the original. + We return a copy to preserve the pre-update value for RoPE on queries. + """ + if cache is None: + return 0 + off = cache.offset + if isinstance(off, int): + return off + if isinstance(off, mx.array): + return off + 0 # defensive copy — new array, same values + return off + + +def patch_gemma4_attention_for_batching() -> bool: + """Monkey-patch Gemma4 Attention.__call__ to snapshot offset before update. 
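+
+ The patch is idempotent: a _batch_patched flag on the class guards
+ against double-wrapping, and repeated calls simply return True.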
+ + Returns True if patch was applied, False if mlx-vlm is not installed + or Gemma 4 module not available. + """ + try: + from mlx_vlm.models.gemma4.language import Attention as Gemma4Attention + from mlx_vlm.models.base import scaled_dot_product_attention + except ImportError: + logger.debug("[Gemma4 patch] mlx-vlm Gemma4 module not available") + return False + + if getattr(Gemma4Attention, "_batch_patched", False): + logger.debug("[Gemma4 patch] Already patched") + return True + + _orig_call = Gemma4Attention.__call__ + + def _patched_call( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, _ = x.shape + + queries = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim) + queries = self.q_norm(queries) + + # Snapshot offset BEFORE update_and_fetch can mutate it in-place. + # Preserves per-sequence mx.array offsets for correct batched RoPE. + offset = _snapshot_cache_offset(cache) + + if self.is_kv_shared_layer and cache is not None: + state = cache.state + keys, values = state[0], state[1] + else: + keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + if self.use_k_eq_v: + values = keys + else: + values = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + keys = self.k_norm(keys) + values = self.v_norm(values) + values = values.transpose(0, 2, 1, 3) + + keys = keys.transpose(0, 2, 1, 3) + keys = self.rope(keys, offset=offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + queries = queries.transpose(0, 2, 1, 3) + queries = self.rope(queries, offset=offset) + + if mask is not None and isinstance(mask, mx.array): + if mask.shape[-1] != keys.shape[-2]: + mask = mask[..., -keys.shape[-2] :] + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + Gemma4Attention.__call__ = _patched_call + Gemma4Attention._batch_patched = True + logger.info("[Gemma4 patch] Attention patched for BatchKVCache support") + return True diff --git a/vllm_mlx/patches/qwen3_5_mllm.py b/vllm_mlx/patches/qwen3_5_mllm.py new file mode 100644 index 000000000..c592928da --- /dev/null +++ b/vllm_mlx/patches/qwen3_5_mllm.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime patch for mlx-vlm's Qwen3.5 attention to support BatchKVCache. + +mlx-vlm's Qwen3_5Attention uses cache.offset directly for kv_seq_len +computation and mask slicing. BatchKVCache stores offset as mx.array +(per-batch-item), not int, causing: + + mask = mask[..., :kv_seq_len] + ValueError: Slice indices must be integers or None. + +This patch replaces Qwen3_5Attention.__call__ with a version that +converts cache.offset to int before using it for arithmetic/slicing, +while leaving the actual cache.offset untouched so update_and_fetch +still works correctly with per-batch offsets. +""" + +import logging +from typing import Optional + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +def _cache_offset_to_int(cache) -> int: + """Extract cache offset as int, handling BatchKVCache mx.array offset.""" + if cache is None: + return 0 + off = cache.offset + if isinstance(off, int): + return off + if isinstance(off, mx.array): + return int(off.max().item()) if off.ndim > 0 else int(off.item()) + return int(off) + + +def patch_qwen35_attention_for_batching() -> bool: + """Monkey-patch Qwen3_5Attention.__call__ to handle BatchKVCache. 
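+
+ Only local arithmetic uses the int conversion; cache.offset itself is
+ never modified, so update_and_fetch still sees per-batch offsets. For a
+ batched offset such as mx.array([7, 12]) the helper slices the mask
+ using the max value (12).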
+ + Returns True if patch was applied, False if mlx-vlm is not installed + or Qwen3.5 module not available. + """ + try: + from mlx_vlm.models.qwen3_5.language import ( + Qwen3_5Attention, + apply_multimodal_rotary_pos_emb, + ) + from mlx_lm.models.base import scaled_dot_product_attention + except ImportError: + logger.debug("[Qwen3.5 patch] mlx-vlm Qwen3.5 module not available") + return False + + if getattr(Qwen3_5Attention, "_batch_patched", False): + logger.debug("[Qwen3.5 patch] Already patched") + return True + + def _patched_call( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache=None, + position_ids: Optional[mx.array] = None, + ) -> mx.array: + B, L, D = x.shape + + q_proj_output = self.q_proj(x) + queries, gate = mx.split( + q_proj_output.reshape(B, L, self.num_attention_heads, -1), + 2, + axis=-1, + ) + gate = gate.reshape(B, L, -1) + + keys, values = self.k_proj(x), self.v_proj(x) + + queries = self.q_norm(queries).transpose(0, 2, 1, 3) + keys = self.k_norm(keys.reshape(B, L, self.num_key_value_heads, -1)).transpose( + 0, 2, 1, 3 + ) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( + 0, 2, 1, 3 + ) + + kv_seq_len = keys.shape[-2] + + # Convert cache.offset to int for slice compatibility. + # BatchKVCache stores offset as mx.array (per-batch-item), + # but kv_seq_len must be int for mask[..., :kv_seq_len]. + _offset = _cache_offset_to_int(cache) + + if position_ids is None: + kv_seq_len += _offset + 1 + position_ids = mx.arange(_offset, _offset + L) + position_ids = mx.expand_dims(position_ids, axis=0) + position_ids = mx.tile(position_ids, (3, 1, 1)) + else: + kv_seq_len += _offset + 1 if cache is not None else 0 + + cos, sin = self.rotary_emb(values, position_ids) + + if mask is not None and isinstance(mask, mx.array): + mask = mask[..., :kv_seq_len] + + queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output * mx.sigmoid(gate)) + + Qwen3_5Attention.__call__ = _patched_call + Qwen3_5Attention._batch_patched = True + logger.info("[Qwen3.5 patch] Attention patched for BatchKVCache support") + return True diff --git a/vllm_mlx/patches/qwen3_5_mtp.py b/vllm_mlx/patches/qwen3_5_mtp.py new file mode 100644 index 000000000..3d5f3e632 --- /dev/null +++ b/vllm_mlx/patches/qwen3_5_mtp.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime MTP (Multi-Token Prediction) support for Qwen3.5 models. + +Qwen3.5 models may include a built-in MTP head that predicts token n+2 +from hidden states + token n+1. MTP weights are added to the quantized +MLX model via scripts/add_mtp_weights_qwen35.py. + +Since mlx_lm's qwen3_5.py does NOT define MTP module/methods, this +module provides: + - inject_mtp_support(): dynamically creates MTP module, loads weights, + and monkey-patches the model class with return_hidden, mtp_forward, + and make_mtp_cache + - validate_mtp_support(): checks whether a loaded model has working MTP + +Supports both Dense (27B) and MoE (122B-A10B, 35B-A3B) architectures. 
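+
+ Typical flow (sketch): load the model via mlx_lm/mlx_vlm, parse
+ config.json, call inject_mtp_support(model, model_path, config), then
+ confirm with validate_mtp_support(model) before enabling MTP decoding.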
+ +The actual MTP scheduling logic lives in: + - vllm_mlx/scheduler.py (_install_mtp, _mtp_step, _mtp_next) +""" + +import logging +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +def _fixup_moe_mtp(mtp, inner_model, loaded_keys: set, mx) -> None: + """Fix missing weights in MoE MTP module. + + MoE MTP checkpoints (122B, 35B) only contain: fc, q_proj, o_proj, + shared_expert.*, and per-expert weights. Missing: + - k_proj, v_proj → zero out (attention becomes no-op) + - gate, shared_expert_gate → copy from main model's last full-attn layer + - norms → already at identity (weight=1.0), no action needed + """ + import mlx.utils + + mtp_layer = mtp.layers[0] + + # Find last full-attention layer in main model for gate weights + last_fa_layer = None + for layer in reversed(inner_model.layers): + if not layer.is_linear: + last_fa_layer = layer + break + + if last_fa_layer is None: + logger.warning("[MTP fixup] No full-attention layer found in main model") + return + + # Copy expert routing gate if not in checkpoint + if "layers.0.mlp.gate.weight" not in loaded_keys: + src = getattr(last_fa_layer.mlp, "gate", None) + dst = getattr(mtp_layer.mlp, "gate", None) + if src is not None and dst is not None: + src_params = mlx.utils.tree_flatten(src.parameters()) + dst.load_weights(src_params) + mx.eval(dst.parameters()) + logger.info("[MTP fixup] Copied mlp.gate from main model last layer") + + # Copy shared_expert_gate if not in checkpoint + if "layers.0.mlp.shared_expert_gate.weight" not in loaded_keys: + src = getattr(last_fa_layer.mlp, "shared_expert_gate", None) + dst = getattr(mtp_layer.mlp, "shared_expert_gate", None) + if src is not None and dst is not None: + src_params = mlx.utils.tree_flatten(src.parameters()) + dst.load_weights(src_params) + mx.eval(dst.parameters()) + logger.info( + "[MTP fixup] Copied shared_expert_gate from main model last layer" + ) + + # Zero out k_proj and v_proj → attention becomes no-op + attn = getattr(mtp_layer, "self_attn", None) + if attn is None: + return + + for proj_name in ("k_proj", "v_proj"): + key = f"layers.0.self_attn.{proj_name}.weight" + if key not in loaded_keys: + proj = getattr(attn, proj_name, None) + if proj is None: + continue + # For quantized layers: zero scales+biases → dequantized = 0 + if hasattr(proj, "scales"): + proj.scales = mx.zeros_like(proj.scales) + proj.biases = mx.zeros_like(proj.biases) + else: + proj.weight = mx.zeros_like(proj.weight) + mx.eval(proj.parameters()) + logger.info(f"[MTP fixup] Zeroed {proj_name} (not in checkpoint)") + + +def inject_mtp_support(model: Any, model_path, config: dict) -> bool: + """Inject MTP module into a loaded Qwen3.5 model. + + mlx_lm's qwen3_5.py does not define MTP layers, so we: + 1. Create MTP module matching the weight structure + 2. Quantize it to match the base model + 3. Load MTP weights from model-mtp.safetensors + 4. Monkey-patch Model with return_hidden, mtp_forward, make_mtp_cache + + Args: + model: A model loaded via mlx_lm (strict=False, MTP weights ignored) + model_path: Path to model directory (contains model-mtp.safetensors) + config: Parsed config.json dict + + Returns: + True if MTP was successfully injected, False otherwise. 
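+
+ Returning False is non-fatal: it happens when the config declares no
+ MTP layers or when the MTP weights file is missing, and in both cases
+ the model is left unchanged.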
+ """ + import mlx.core as mx + import mlx.nn as nn + + # Navigate nested config: text_config for VLM wrappers + text_config = config.get("text_config", config) + num_mtp_layers = text_config.get("mtp_num_hidden_layers", 0) + if num_mtp_layers == 0: + # Fallback: check flat config for num_nextn_predict_layers + num_mtp_layers = text_config.get( + "num_nextn_predict_layers", + config.get("num_nextn_predict_layers", 0), + ) + if num_mtp_layers == 0: + logger.info("[MTP inject] No MTP layers configured, skipping") + return False + + model_path = Path(model_path) + # Look for MTP weights in mtp/ subdirectory first (avoids mlx_vlm glob), + # then fall back to model-mtp.safetensors in model dir. + mtp_file = model_path / "mtp" / "weights.safetensors" + if not mtp_file.exists(): + mtp_file = model_path / "model-mtp.safetensors" + if not mtp_file.exists(): + logger.warning(f"[MTP inject] MTP weights not found in {model_path}") + return False + + # Get model args — navigate VLM wrapper if needed + # Model hierarchy: Model → language_model (TextModel) → model (Qwen3_5TextModel) + text_model = model + if hasattr(model, "language_model"): + text_model = model.language_model + + args = text_model.args + + # When loaded via mlx_vlm, args may be a TextConfig object missing fields + # that mlx_lm's TextModelArgs defines (rope_theta, partial_rotary_factor, + # rope_scaling, etc.). Build a proper TextModelArgs from the config dict. + from mlx_lm.models.qwen3_5 import TextModelArgs + + if not isinstance(args, TextModelArgs): + logger.info("[MTP inject] Building TextModelArgs from config dict") + args = TextModelArgs.from_dict(text_config) + + # Detect MoE vs Dense from args + num_experts = getattr(args, "num_experts", 0) + is_moe = num_experts > 0 + + # Import model components + from mlx_lm.models.base import create_attention_mask, create_ssm_mask + from mlx_lm.models.cache import KVCache + from mlx_lm.models.qwen3_5 import DecoderLayer + + logger.info( + f"[MTP inject] Creating MTP module ({num_mtp_layers} layers, " + f"{'MoE' if is_moe else 'Dense'})" + ) + + # MTP decoder uses full attention (not GatedDeltaNet). + # layer_idx = full_attention_interval - 1 ensures is_linear=False. + fa_idx = args.full_attention_interval - 1 + + class _MTPModule(nn.Module): + def __init__(self, args, n_layers): + super().__init__() + self.pre_fc_norm_hidden = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.pre_fc_norm_embedding = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + self.layers = [ + DecoderLayer(args, layer_idx=fa_idx) for _ in range(n_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + mtp = _MTPModule(args, num_mtp_layers) + + # --- Load MTP weights in BF16 (no quantization) --- + # MTP head is extremely sensitive to quantization — even 4-bit destroys + # prediction quality (0% acceptance). Keep MTP in full precision. 
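+ # Quantized MTP checkpoints are therefore dequantized on load; bits and
+ # group_size below are read from the base model's quantization config
+ # purely for that dequantize step, while BF16 tensors load as-is.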
+ # See: https://github.com/vllm-project/vllm/issues/36331 + quant_config = text_config.get("quantization", config.get("quantization", {})) + bits = quant_config.get("bits", 4) if quant_config else 4 + group_size = quant_config.get("group_size", 64) if quant_config else 64 + + logger.info( + f"[MTP inject] Loading weights from {mtp_file.name} (BF16, no quantization)" + ) + raw = mx.load(str(mtp_file)) + raw_mtp = { + k.removeprefix("mtp."): v for k, v in raw.items() if k.startswith("mtp.") + } + del raw + + # Dequantize any quantized weight triplets (weight + scales + biases) + mtp_weights: dict[str, mx.array] = {} + processed = set() + for key in sorted(raw_mtp.keys()): + if key in processed: + continue + if key.endswith(".scales") or key.endswith(".biases"): + continue + + scales_key = key.replace(".weight", ".scales") + biases_key = key.replace(".weight", ".biases") + + if scales_key in raw_mtp and biases_key in raw_mtp: + # Quantized triplet → dequantize to BF16 + dq = mx.dequantize( + raw_mtp[key], + raw_mtp[scales_key], + raw_mtp[biases_key], + group_size=group_size, + bits=bits, + ) + mtp_weights[key] = dq + processed.update([key, scales_key, biases_key]) + else: + # Already FP (norms, fc, shared_expert_gate) + mtp_weights[key] = raw_mtp[key] + processed.add(key) + del raw_mtp + + mtp.load_weights(list(mtp_weights.items()), strict=False) + mx.eval(mtp.parameters()) + + dq_count = sum(1 for k in mtp_weights if not k.endswith((".scales", ".biases"))) + has_quantized = any(k.endswith(".scales") for k in processed) + mode = "dequantized from quantized" if has_quantized else "native BF16" + logger.info(f"[MTP inject] Loaded {dq_count} MTP weight tensors ({mode})") + + # --- Step 4: Fix missing MoE MTP weights --- + # MoE checkpoints lack: k_proj, v_proj, gate, shared_expert_gate, norms. + # Norms default to identity (weight=1.0) which is correct. + # k_proj/v_proj: zero out → attention becomes no-op, MLP does prediction. + # gate/shared_expert_gate: copy from main model's last full-attention layer. 
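+ # Dense MTP checkpoints ship the full weight set, so no fixup is needed.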
+ if is_moe: + loaded_key_set = set(mtp_weights.keys()) + _fixup_moe_mtp(mtp, text_model.model, loaded_key_set, mx) + + # --- Attach MTP and monkey-patch model class --- + text_model.mtp = mtp + + original_class = text_model.__class__ + + class _Qwen3_5MTP(original_class): + """Qwen3.5 with MTP support (injected at runtime).""" + + def __call__( + self, + inputs, + cache=None, + return_hidden: bool = False, + input_embeddings=None, + **kwargs, + ): + inner = self.model + if input_embeddings is not None: + hidden_states = input_embeddings + else: + hidden_states = inner.embed_tokens(inputs) + + if cache is None: + cache = [None] * len(inner.layers) + + fa_mask = create_attention_mask(hidden_states, cache[inner.fa_idx]) + ssm_mask = create_ssm_mask(hidden_states, cache[inner.ssm_idx]) + + for layer, c in zip(inner.layers, cache): + mask = ssm_mask if layer.is_linear else fa_mask + hidden_states = layer(hidden_states, mask=mask, cache=c) + + normed = inner.norm(hidden_states) + + if self.args.tie_word_embeddings: + out = inner.embed_tokens.as_linear(normed) + else: + out = self.lm_head(normed) + + if return_hidden: + return out, normed # post-norm hidden states (MTP expects post-norm) + return out + + def mtp_forward( + self, + hidden_states, + next_token_ids, + cache=None, + mtp_cache=None, + ): + """Run MTP head: predict token n+2 from hidden states + token n+1.""" + input_embeds = self.model.embed_tokens(next_token_ids) + e = self.mtp.pre_fc_norm_embedding(input_embeds) + h = self.mtp.pre_fc_norm_hidden(hidden_states) + x = self.mtp.fc(mx.concatenate([e, h], axis=-1)) + + layer = self.mtp.layers[0] + c = mtp_cache[0] if mtp_cache else None + mask = create_attention_mask(x, c) + x = layer(x, mask=mask, cache=c) + + x = self.mtp.norm(x) + + if self.args.tie_word_embeddings: + return self.model.embed_tokens.as_linear(x) + return self.lm_head(x) + + def make_mtp_cache(self): + """Create KV cache for MTP layers.""" + if self.mtp is None: + return None + return [KVCache() for _ in self.mtp.layers] + + text_model.__class__ = _Qwen3_5MTP + logger.info("[MTP inject] Model class patched with MTP support") + + # If we patched the inner language_model, also expose MTP on the outer Model + if hasattr(model, "language_model") and model.language_model is text_model: + model.mtp = mtp + + return True + + +def validate_mtp_support(model: Any) -> bool: + """Validate that a loaded model has working MTP support. + + Checks: + 1. model.mtp exists and is not None + 2. model.mtp has layers with loaded weights + 3. model has return_hidden support in __call__ + 4. model has mtp_forward method + 5. model has make_mtp_cache method + + Args: + model: A model loaded via mlx_lm.load() + + Returns: + True if MTP is fully functional, False otherwise. + """ + # Navigate to text model if VLM wrapper + text_model = model + if hasattr(model, "language_model"): + text_model = model.language_model + + mtp = getattr(text_model, "mtp", None) + if mtp is None: + args = getattr(text_model, "args", None) + if args is not None: + num_mtp = getattr(args, "mtp_num_hidden_layers", 0) + if num_mtp == 0: + num_mtp = getattr(args, "num_nextn_predict_layers", 0) + if num_mtp > 0: + logger.warning( + "[MTP] Model config has MTP layers=%d but model.mtp is None. 
" + "Run scripts/add_mtp_weights_qwen35.py to add weights.", + num_mtp, + ) + return False + + mtp_layers = getattr(mtp, "layers", []) + if not mtp_layers: + logger.warning("[MTP] model.mtp exists but has no layers.") + return False + + import inspect + + call_sig = inspect.signature(type(text_model).__call__) + if "return_hidden" not in call_sig.parameters: + logger.warning("[MTP] Model.__call__ does not accept return_hidden parameter.") + return False + + if not hasattr(text_model, "mtp_forward") or not callable(text_model.mtp_forward): + logger.warning("[MTP] Model does not have mtp_forward() method.") + return False + + if not hasattr(text_model, "make_mtp_cache") or not callable( + text_model.make_mtp_cache + ): + logger.warning("[MTP] Model does not have make_mtp_cache() method.") + return False + + logger.info( + "[MTP] Qwen3.5 model has working MTP support: %d MTP layer(s)", + len(mtp_layers), + ) + return True diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index e8f47a324..a419f3973 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -586,7 +586,7 @@ def store_cache( # Extract and store actual tensor slices for this block if is_tensor_data and HAS_MLX: block_kv_data = self._extract_block_tensor_slice( - cache_data, global_start, global_end + cache_data, global_start, global_end, len(tokens) ) if block_kv_data: block.cache_data = block_kv_data @@ -629,56 +629,122 @@ def _extract_block_tensor_slice( cache_data: List[Dict[str, Any]], start_idx: int, end_idx: int, - ) -> Optional[List[Tuple[Any, Any]]]: + total_tokens: int, + ) -> Optional[List[Optional[Dict[str, Any]]]]: """ - Extract tensor slices for a single block from cache data. + Extract per-layer cache data for a single block. Args: - cache_data: List of layer states, each containing 'state': (keys, values) + cache_data: List of extracted layer states start_idx: Start token index in the sequence end_idx: End token index in the sequence + total_tokens: Total number of tokens covered by cache_data Returns: - List of (keys_slice, values_slice) for each layer, or None on failure + Per-layer block cache state, or None on failure """ if not HAS_MLX or not cache_data: return None try: - block_slices = [] + block_slices: List[Optional[Dict[str, Any]]] = [] for layer_state in cache_data: if "state" not in layer_state: + block_slices.append(None) continue - keys, values = layer_state["state"] + state = layer_state["state"] + meta_state = layer_state.get("meta_state") + class_ref = layer_state.get("class_ref") + class_name = layer_state.get("class_name") - # KV cache shape: (batch, n_kv_heads, seq_len, head_dim) - # Slice along seq_len dimension (axis 2) - seq_len = keys.shape[2] if hasattr(keys, "shape") else 0 + if self._can_concatenate_cache_state(state): + state_slice = self._slice_concat_cache_state( + state, start_idx, end_idx + ) + block_slices.append( + { + "state": state_slice, + "meta_state": meta_state, + "class_ref": class_ref, + "class_name": class_name, + "storage": "concat", + "seq_axis": 2, + } + ) + continue - if end_idx > seq_len: - # Requested range extends beyond available data - logger.debug( - f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}" + if end_idx == total_tokens: + block_slices.append( + { + "state": state, + "meta_state": meta_state, + "class_ref": class_ref, + "class_name": class_name, + "storage": "latest", + } ) - # Use whatever is available - actual_end = min(end_idx, seq_len) - if start_idx >= actual_end: - continue - keys_slice = keys[:, :, 
start_idx:actual_end, :] - values_slice = values[:, :, start_idx:actual_end, :] else: - keys_slice = keys[:, :, start_idx:end_idx, :] - values_slice = values[:, :, start_idx:end_idx, :] + block_slices.append(None) - block_slices.append((keys_slice, values_slice)) - - return block_slices if block_slices else None + return ( + block_slices + if any(entry is not None for entry in block_slices) + else None + ) except Exception as e: logger.warning(f"Failed to extract block tensor slice: {e}") return None + def _can_concatenate_cache_state(self, state: Any) -> bool: + """Return True when cache state can be concatenated block-by-block.""" + if not isinstance(state, (list, tuple)) or not state: + return False + return all( + tensor is not None and hasattr(tensor, "shape") and len(tensor.shape) == 4 + for tensor in state + ) + + def _slice_concat_cache_state( + self, + state: Tuple[Any, ...] | List[Any], + start_idx: int, + end_idx: int, + ) -> Tuple[Any, ...] | List[Any]: + """Slice a sequence-backed cache state across the token axis.""" + seq_len = state[0].shape[2] + actual_end = min(end_idx, seq_len) + if start_idx >= actual_end: + raise ValueError( + f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}" + ) + + def _slice_tensor(tensor: Any) -> Any: + slices = [slice(None)] * len(tensor.shape) + slices[2] = slice(start_idx, actual_end) + return tensor[tuple(slices)] + + sliced = [_slice_tensor(tensor) for tensor in state] + return tuple(sliced) if isinstance(state, tuple) else sliced + + def _concat_cache_states( + self, + states: List[Tuple[Any, ...] | List[Any]], + seq_axis: int, + ) -> Optional[Tuple[Any, ...] | List[Any]]: + """Concatenate state fragments for a sequence-backed cache layer.""" + if not states: + return None + arity = len(states[0]) + concatenated = [] + for idx in range(arity): + parts = [state[idx] for state in states] + if any(part is None for part in parts): + return None + concatenated.append(mx.concatenate(parts, axis=seq_axis)) + return tuple(concatenated) if isinstance(states[0], tuple) else concatenated + def get_cache_for_generation( self, request_id: str, @@ -763,10 +829,11 @@ def reconstruct_cache( block_table: BlockTable, ) -> Optional[List[Any]]: """ - Reconstruct KVCache objects from stored block tensor data. + Reconstruct cache objects from stored block tensor data. - This method concatenates tensor slices from all blocks and - creates new KVCache objects that can be used for inference. + Sequence-backed caches are concatenated block-by-block. Recurrent + caches such as ArraysCache are restored from the latest sequence + boundary snapshot that was actually stored. 
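+
+ Block data recorded from a BatchKVCache is rebuilt as a plain KVCache
+ (keys/values with the offset set to the restored sequence length).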
Args: block_table: BlockTable containing block IDs to reconstruct from @@ -800,67 +867,62 @@ def reconstruct_cache( if not all_block_data: return None - # Get number of layers from first block - num_layers = len(all_block_data[0]) + # Get number of layers from the richest block + num_layers = max(len(block_data) for block_data in all_block_data) if num_layers == 0: return None - # Concatenate tensors for each layer reconstructed_caches = [] - for layer_idx in range(num_layers): - layer_keys = [] - layer_values = [] + layer_entries = [ + block_data[layer_idx] + for block_data in all_block_data + if layer_idx < len(block_data) + ] + layer_entries = [entry for entry in layer_entries if entry is not None] + if not layer_entries: + return None - for block_data in all_block_data: - if layer_idx < len(block_data): - keys_slice, values_slice = block_data[layer_idx] - layer_keys.append(keys_slice) - layer_values.append(values_slice) + layer_meta = layer_entries[-1] + state = layer_meta["state"] + if layer_meta["storage"] == "concat": + state = self._concat_cache_states( + [entry["state"] for entry in layer_entries], + layer_meta["seq_axis"], + ) + elif layer_meta["storage"] == "latest": + state = layer_entries[-1]["state"] - if not layer_keys: - continue + if state is None: + return None - # Concatenate along sequence dimension (axis 2) - # Shape: (batch, n_kv_heads, seq_len, head_dim) - concat_keys = mx.concatenate(layer_keys, axis=2) - concat_values = mx.concatenate(layer_values, axis=2) + cache_cls = layer_meta.get("class_ref") + meta_state = layer_meta.get("meta_state") - # Create KVCache object - # Try to use mlx_lm's KVCache.from_state if available - try: + if cache_cls is not None and hasattr(cache_cls, "from_state"): + from mlx_lm.models.cache import ( + BatchKVCache as _BatchKVCache, + KVCache as _KVCache, + ) + + if cache_cls is _BatchKVCache: + keys, values = state[0], state[1] + cache = _KVCache() + cache.keys = keys + cache.values = values + cache.offset = keys.shape[2] + else: + cache = cache_cls.from_state(state, meta_state) + else: from mlx_lm.models.cache import KVCache - # Create new cache and set its state + if len(state) != 2: + return None cache = KVCache() - seq_len = concat_keys.shape[2] - - # Set internal state directly - # KVCache stores keys/values and offset - cache.keys = concat_keys - cache.values = concat_values - cache.offset = seq_len - - reconstructed_caches.append(cache) - - except ImportError: - # Fallback: create a simple cache-like object - class SimpleKVCache: - def __init__(self, keys, values): - self.keys = keys - self.values = values - self.offset = keys.shape[2] - - @property - def state(self): - return (self.keys, self.values) - - @property - def meta_state(self): - return (str(self.offset),) - - cache = SimpleKVCache(concat_keys, concat_values) - reconstructed_caches.append(cache) + cache.keys, cache.values = state + cache.offset = cache.keys.shape[2] + + reconstructed_caches.append(cache) if not reconstructed_caches: return None diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py index f138796ff..49d13a26b 100644 --- a/vllm_mlx/reasoning/__init__.py +++ b/vllm_mlx/reasoning/__init__.py @@ -76,6 +76,7 @@ def list_parsers() -> list[str]: def _register_builtin_parsers(): """Register built-in parsers.""" from .deepseek_r1_parser import DeepSeekR1ReasoningParser + from .gemma4_parser import Gemma4ReasoningParser from .gpt_oss_parser import GptOssReasoningParser from .harmony_parser import HarmonyReasoningParser from .qwen3_parser 
import Qwen3ReasoningParser @@ -84,6 +85,7 @@ def _register_builtin_parsers(): register_parser("deepseek_r1", DeepSeekR1ReasoningParser) register_parser("gpt_oss", GptOssReasoningParser) register_parser("harmony", HarmonyReasoningParser) + register_parser("gemma4", Gemma4ReasoningParser) # Register built-in parsers on module load diff --git a/vllm_mlx/reasoning/gemma4_parser.py b/vllm_mlx/reasoning/gemma4_parser.py new file mode 100644 index 000000000..8b6dd8149 --- /dev/null +++ b/vllm_mlx/reasoning/gemma4_parser.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser for Gemma 4 models. + +Gemma 4 uses a channel-based protocol for reasoning: + + <|channel>thought + ...thinking content... + + ...response content... + +Where: + <|channel> = token 100 (channel switch marker) + = token 101 (end-of-channel marker) + +The channel names "thought" and "response" appear as text after the +special tokens and should be stripped from the output. + +Some model variants may use <|channel>response instead of +to transition from thinking to response mode. This parser handles both. + +When thinking is disabled or not triggered, output contains no tags. +""" + +from .base import DeltaMessage +from .think_parser import BaseThinkingReasoningParser + +# Channel names that follow <|channel> — stripped from output +_THOUGHT_PREFIX = "thought" +_RESPONSE_MARKER = "<|channel>response" + + +def _strip_channel_name(text: str, prefix: str) -> str: + """Strip channel name and leading whitespace/newline from text start.""" + if text.startswith(prefix): + text = text[len(prefix) :] + return text.lstrip("\n") + + +class Gemma4ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for Gemma 4 models. + + Handles two transition formats: + 1. <|channel>thought...response (standard: token 100 + 101) + 2. <|channel>thought...<|channel>response (alternative: token 100 + 100) + + Channel names ("thought", "response") are stripped from output. + + Example: + Input: "<|channel>thought\\nLet me think...The answer is 42." + Output: reasoning="Let me think...", content="The answer is 42." + + When no tags are present, the entire output is treated as content. + """ + + @property + def start_token(self) -> str: + return "<|channel>" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning from complete output. + + Handles both and <|channel>response as transition markers. + Strips channel names ("thought", "response") from output. + """ + text = model_output + + # Try standard format first: <|channel>thought...response + if self.start_token in text and self.end_token in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Try alternative format: <|channel>thought...<|channel>response... 
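+ # e.g. "<|channel>thought\nPlan the steps...<|channel>response\nDone." →
+ #      reasoning="Plan the steps...", content="Done."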
+ if text.count(self.start_token) >= 2 and _RESPONSE_MARKER in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(_RESPONSE_MARKER) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.lstrip("\n").strip() + return reasoning or None, content or None + + # Only closing tag (think injected in prompt) + if self.end_token in text: + reasoning, _, content = text.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Only start tag (incomplete reasoning, no end yet) + if self.start_token in text: + _, _, reasoning = text.partition(self.start_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + return reasoning or None, None + + # No tags at all — pure content + return None, model_output + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """ + Extract reasoning from streaming delta. + + Handles: + - No tags: treat as content (Gemma 4 doesn't inject tags in prompt) + - <|channel>thought: enter reasoning mode, strip channel name + - or <|channel>response: transition to content mode + """ + # No channel tokens at all — plain content + if self.start_token not in current_text and self.end_token not in current_text: + return DeltaMessage(content=delta_text) + + # Check for alternative transition: <|channel>response + if _RESPONSE_MARKER in current_text: + if _RESPONSE_MARKER not in previous_text: + # Transition happening in this delta + # Find what (if any) content comes after the marker + marker_pos = current_text.find(_RESPONSE_MARKER) + after_marker = current_text[marker_pos + len(_RESPONSE_MARKER) :] + after_marker = after_marker.lstrip("\n") + if after_marker: + return DeltaMessage(content=after_marker) + return None # Suppress the marker itself + else: + # Already past transition — pure content + # But we need to only emit the NEW text (delta) + return DeltaMessage(content=delta_text) + + # Delegate to base class for standard <|channel>/ handling + result = super().extract_reasoning_streaming( + previous_text, current_text, delta_text + ) + + # Strip "thought" channel name from initial reasoning + if result is not None and result.reasoning is not None: + r = result.reasoning + # First reasoning delta after <|channel> will be "thought" or "thought\n" + if self.start_token in current_text: + # Check if this is the very first reasoning content + after_channel = current_text.split(self.start_token, 1)[1] + if after_channel.startswith(_THOUGHT_PREFIX): + # Remove "thought" prefix from the accumulated reasoning so far + clean = after_channel[len(_THOUGHT_PREFIX) :].lstrip("\n") + # Compute what portion of clean text is in this delta + prev_after = "" + if self.start_token in previous_text: + prev_after = previous_text.split(self.start_token, 1)[1] + if prev_after.startswith(_THOUGHT_PREFIX): + prev_after = prev_after[len(_THOUGHT_PREFIX) :].lstrip("\n") + # The new reasoning text is clean minus what was already emitted + new_reasoning = clean[len(prev_after) :] + if new_reasoning: + return DeltaMessage(reasoning=new_reasoning) + return None # Suppress channel name token + + return result diff --git a/vllm_mlx/reasoning/think_parser.py b/vllm_mlx/reasoning/think_parser.py index 136348206..a2e9cb727 100644 --- a/vllm_mlx/reasoning/think_parser.py +++ 
b/vllm_mlx/reasoning/think_parser.py @@ -9,6 +9,12 @@ 1. Both tags in output: reasoningcontent 2. Only closing tag (think injected in prompt): reasoningcontent 3. No tags: pure content + +Performance: The streaming parser uses a simple state machine to track the +current phase (pre-think / thinking / content). Tag completion is detected +against the accumulated text for correctness when `` / `` are +split across delta boundaries, but phase tracking still avoids the old +whole-output rescanning behavior. """ from abc import abstractmethod @@ -27,8 +33,12 @@ class BaseThinkingReasoningParser(ReasoningParser): and only appears in the model output. This is common with AI agents like OpenCode that force models to reason by injecting thinking tags. - The parser tracks state during streaming to correctly separate reasoning - from content as tokens arrive incrementally. + The streaming parser uses a state machine with three phases: + + pre_think -> thinking -> content + + Transitions are tracked by parser state. Accumulated text is consulted only + to detect when a start/end tag has completed across delta boundaries. """ @property @@ -43,6 +53,12 @@ def end_token(self) -> str: def __init__(self, tokenizer=None): super().__init__(tokenizer) + # Streaming state — reset per request via reset_state() + self._phase: str = "pre_think" # "pre_think" | "thinking" | "content" + + def reset_state(self): + """Reset state machine for a new streaming request.""" + self._phase = "pre_think" def extract_reasoning( self, @@ -66,14 +82,11 @@ def extract_reasoning( # Case 1: Both tags present (normal case) if self.start_token in text and self.end_token in text: - # Get everything after start token _, _, after_start = text.partition(self.start_token) - # Split on end token reasoning, _, content = after_start.partition(self.end_token) return reasoning.strip() or None, content.strip() or None # Case 2: Only closing tag (think was injected in prompt) - # Everything before is reasoning if self.end_token in text: reasoning, _, content = text.partition(self.end_token) return reasoning.strip() or None, content.strip() or None @@ -83,7 +96,7 @@ def extract_reasoning( _, _, reasoning = text.partition(self.start_token) return reasoning.strip() or None, None - # Case 4: No tags at all - pure content + # Case 4: No tags at all — pure content return None, model_output def extract_reasoning_streaming( @@ -93,123 +106,99 @@ def extract_reasoning_streaming( delta_text: str, ) -> DeltaMessage | None: """ - Extract reasoning from streaming delta using text-based detection. + Extract reasoning from a streaming delta using state-machine tracking. + + Instead of rescanning the full accumulated text on every token, this + method tracks the current phase (pre_think / thinking / content) and + only consults accumulated text to detect completed start/end tags that + were split across delta boundaries. - Handles implicit reasoning mode where was in the prompt - and only appears in the output. + Handles three scenarios: + 1. Explicit ... in model output + 2. Implicit mode ( in prompt, only in output) + 3. No tags at all (pure content after first token with no reasoning) Args: previous_text: Text accumulated before this delta. current_text: Text including this delta. - delta_text: Just the new text. + delta_text: Just the new text in this chunk. Returns: - DeltaMessage with reasoning/content, or None to skip. + DeltaMessage with reasoning and/or content, or None to skip. 
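As a rough, standalone illustration of the pre_think -> thinking -> content transitions this docstring describes (the "<think>"/"</think>" strings and the sample deltas are assumptions; this is not the method below):

# Toy phase tracker mirroring the transitions described above.
def run_phases(deltas: list[str]) -> list[tuple[str, str]]:
    phase, text, out = "pre_think", "", []
    for d in deltas:
        text += d
        if phase == "pre_think" and "<think>" in text:
            phase = "thinking"
            d = d.split("<think>", 1)[-1]          # drop the opening tag from this delta
        if phase == "thinking" and "</think>" in text:
            phase = "content"
            reasoning, _, d = d.partition("</think>")
            if reasoning:
                out.append(("reasoning", reasoning))
        out.append(("reasoning" if phase == "thinking" else "content", d))
    return [piece for piece in out if piece[1]]

print(run_phases(["<think>plan ", "steps</think>", "final answer"]))
# -> [('reasoning', 'plan '), ('reasoning', 'steps'), ('content', 'final answer')]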
""" - # Skip if delta is just the special tokens themselves - stripped_delta = delta_text.strip() - if stripped_delta == self.start_token: - return None - if stripped_delta == self.end_token: + if not delta_text: return None - # Check token positions in text (stateless text-based detection) - start_in_prev = self.start_token in previous_text - start_in_current = self.start_token in current_text - end_in_prev = self.end_token in previous_text - end_in_delta = self.end_token in delta_text - - # Case 1: Explicit found in text - standard behavior - if start_in_current: - return self._handle_explicit_think( - previous_text, delta_text, start_in_prev, end_in_prev, end_in_delta - ) - - # Case 2: No but found - implicit reasoning mode - # This handles when was injected in the prompt - if self.end_token in current_text: - return self._handle_implicit_think(delta_text, end_in_prev, end_in_delta) - - # Case 3: No think tags seen yet - # We can't know if was in the prompt, so we must make a choice: - # - Treat as content (safe, but loses reasoning if think was in prompt) - # - Treat as reasoning (risky, wrong if no thinking at all) - # We choose to treat as reasoning IF we haven't seen yet, - # because if think was in prompt, we want to capture the reasoning. - # This will be corrected once is seen. - return DeltaMessage(reasoning=delta_text) - - def _handle_explicit_think( - self, - previous_text: str, - delta_text: str, - start_in_prev: bool, - end_in_prev: bool, - end_in_delta: bool, - ) -> DeltaMessage | None: - """Handle case where tag is explicitly in the output.""" - start_in_delta = self.start_token in delta_text - - if start_in_prev: - # We're after the start token - if end_in_delta: - # Transition: end token in this delta - idx = delta_text.find(self.end_token) - reasoning_part = delta_text[:idx] - content_part = delta_text[idx + len(self.end_token) :] + start_tok = self.start_token + end_tok = self.end_token + + # ── Phase: pre_think ────────────────────────────────────── + # Haven't seen a completed tag yet. Could be: + # - About to see (explicit reasoning) + # - Already inside implicit reasoning (think was in prompt) + # - No reasoning at all (pure content model) + if self._phase == "pre_think": + if start_tok in current_text: + self._phase = "thinking" + idx = delta_text.find(start_tok) + after = delta_text[idx + len(start_tok) :] if idx >= 0 else delta_text + + if end_tok in after: + self._phase = "content" + eidx = after.find(end_tok) + reasoning = after[:eidx] + content = after[eidx + len(end_tok) :] + if not reasoning and not content: + return None + return DeltaMessage( + reasoning=reasoning or None, + content=content or None, + ) + return DeltaMessage(reasoning=after) if after else None + + # Implicit mode: completed without an explicit . 
+ if end_tok in current_text: + self._phase = "content" + idx = delta_text.find(end_tok) + if idx >= 0: + reasoning = delta_text[:idx] + content = delta_text[idx + len(end_tok) :] + else: + reasoning = None + content = delta_text + if not reasoning and not content: + return None return DeltaMessage( - reasoning=reasoning_part if reasoning_part else None, - content=content_part if content_part else None, + reasoning=reasoning or None, + content=content or None, ) - elif end_in_prev: - # Already past reasoning phase - pure content - return DeltaMessage(content=delta_text) - else: - # Still in reasoning phase - return DeltaMessage(reasoning=delta_text) - - elif start_in_delta: - # Start token is in this delta - start_idx = delta_text.find(self.start_token) - - if end_in_delta: - # Both tokens in this delta - end_idx = delta_text.find(self.end_token) - reasoning_part = delta_text[start_idx + len(self.start_token) : end_idx] - content_part = delta_text[end_idx + len(self.end_token) :] - return DeltaMessage( - reasoning=reasoning_part if reasoning_part else None, - content=content_part if content_part else None, - ) - else: - # Only start token - beginning of reasoning - reasoning_part = delta_text[start_idx + len(self.start_token) :] + + # No tags — default to reasoning (implicit mode assumption). + # If the model doesn't use thinking at all, the server's + # non-parser path handles it. This path only activates when + # a reasoning parser is explicitly configured. + return DeltaMessage(reasoning=delta_text) + + # ── Phase: thinking ─────────────────────────────────────── + # Inside a reasoning block, waiting for end tag. + if self._phase == "thinking": + if end_tok in current_text and end_tok not in previous_text: + self._phase = "content" + idx = delta_text.find(end_tok) + if idx >= 0: + reasoning = delta_text[:idx] + content = delta_text[idx + len(end_tok) :] + else: + reasoning = delta_text + content = None + if not reasoning and not content: + return None return DeltaMessage( - reasoning=reasoning_part if reasoning_part else None + reasoning=reasoning or None, + content=content or None, ) + return DeltaMessage(reasoning=delta_text) - # Fallback - treat as content + # ── Phase: content ──────────────────────────────────────── + # Past the reasoning block — everything is content. 
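A hedged sketch of driving the streaming API for the implicit case handled above (the opening tag lives in the prompt, so only the closing tag appears in the output). The import path, tag strings, and the throwaway subclass are assumptions; in practice the constructor may also need a tokenizer.

from vllm_mlx.reasoning.think_parser import BaseThinkingReasoningParser  # assumed import path

class SimpleThinkParser(BaseThinkingReasoningParser):
    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

parser = SimpleThinkParser()
parser.reset_state()
previous = ""
for delta in ["Check the docs ", "first.</think>", "Run with --force."]:
    current = previous + delta
    msg = parser.extract_reasoning_streaming(previous, current, delta)
    previous = current
    if msg is not None:
        print(msg.reasoning, msg.content)
# First two deltas surface as reasoning ("Check the docs ", "first.");
# the third, after the closing tag, surfaces as content ("Run with --force.").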
return DeltaMessage(content=delta_text) - - def _handle_implicit_think( - self, - delta_text: str, - end_in_prev: bool, - end_in_delta: bool, - ) -> DeltaMessage | None: - """Handle case where was in prompt (only in output).""" - if end_in_delta: - # Transition: end token in this delta - idx = delta_text.find(self.end_token) - reasoning_part = delta_text[:idx] - content_part = delta_text[idx + len(self.end_token) :] - return DeltaMessage( - reasoning=reasoning_part if reasoning_part else None, - content=content_part if content_part else None, - ) - elif end_in_prev: - # Already past reasoning phase - pure content - return DeltaMessage(content=delta_text) - else: - # Still in implicit reasoning phase - return DeltaMessage(reasoning=delta_text) diff --git a/vllm_mlx/request.py b/vllm_mlx/request.py index 41679c0ba..f18b238d8 100644 --- a/vllm_mlx/request.py +++ b/vllm_mlx/request.py @@ -57,6 +57,7 @@ class SamplingParams: top_p: float = 0.9 top_k: int = 0 # 0 means disabled min_p: float = 0.0 + presence_penalty: float = 0.0 repetition_penalty: float = 1.0 stop: Optional[List[str]] = None stop_token_ids: Optional[List[int]] = None diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index ec4684049..c706c85b5 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -19,7 +19,7 @@ import mlx.core as mx from mlx_lm.generate import BatchGenerator -from mlx_lm.sample_utils import make_sampler +from mlx_lm.sample_utils import make_logits_processors, make_sampler from mlx_lm.tokenizer_utils import NaiveStreamingDetokenizer from .memory_cache import MemoryAwarePrefixCache, MemoryCacheConfig @@ -62,6 +62,8 @@ class SchedulerConfig: prefill_batch_size: int = 8 completion_batch_size: int = 32 prefill_step_size: int = 2048 + # Optional override for MLLM prefill guard (None = use MLLM default). 
+ mllm_prefill_step_size: Optional[int] = None # Prefix cache settings enable_prefix_cache: bool = True @@ -102,6 +104,10 @@ class SchedulerConfig: mtp_num_draft_tokens: int = 1 # Number of draft tokens from MTP head mtp_optimistic: bool = False # Skip acceptance check for max speed + def __post_init__(self) -> None: + if self.mllm_prefill_step_size is not None and self.mllm_prefill_step_size <= 0: + raise ValueError("mllm_prefill_step_size must be > 0 when provided") + @dataclass class SchedulerOutput: @@ -148,13 +154,66 @@ def _install_chunked_prefill( import time as _time from mlx_lm.generate import ( - Batch, _left_pad_prompts, _make_cache, _merge_caches, _right_pad_prompts, ) + try: + from mlx_lm.generate import _lazy_extract_cache + except ImportError: + + def _lazy_extract_cache(cache, idx): + return (c.extract(idx) for c in cache) + + try: + from mlx_lm.generate import Batch as _batch_cls + except ImportError: + + @dataclass + class _batch_cls: + uids: List[int] + y: Any + logprobs: List[Any] + max_tokens: List[int] + num_tokens: List[int] + cache: List[Any] + samplers: List[Any] + logits_processors: List[Any] + tokens: List[Any] + + def __len__(self): + return len(self.uids) + + def filter(self, keep_idx: List[int]): + self.uids = [self.uids[k] for k in keep_idx] + self.logprobs = [self.logprobs[k] for k in keep_idx] + self.max_tokens = [self.max_tokens[k] for k in keep_idx] + self.num_tokens = [self.num_tokens[k] for k in keep_idx] + self.samplers = [self.samplers[k] for k in keep_idx] + self.logits_processors = [self.logits_processors[k] for k in keep_idx] + self.tokens = [self.tokens[k] for k in keep_idx] + keep_idx_mx = mx.array(keep_idx, mx.int32) + self.y = self.y[keep_idx_mx] + for c in self.cache: + c.filter(keep_idx_mx) + + def extend(self, other): + self.uids.extend(other.uids) + self.y = mx.concatenate([self.y, other.y]) + self.logprobs.extend(other.logprobs) + self.num_tokens.extend(other.num_tokens) + self.max_tokens.extend(other.max_tokens) + self.samplers.extend(other.samplers) + self.logits_processors.extend(other.logits_processors) + self.tokens.extend(other.tokens) + for c, o in zip(self.cache, other.cache): + c.extend(o) + + def extract_cache(self, idx): + return [c.extract(idx) for c in self.cache] + # Keep references to originals _orig_next = batch_gen._next _orig_remove = batch_gen.remove @@ -201,6 +260,10 @@ def _generation_step(self=batch_gen): batch.tokens, ) mx.async_eval(batch.y, batch.logprobs) + # Evaluate accumulated tokens to prevent Metal buffer buildup + # from lazy mx.concatenate() chains holding AGXAllocation handles + if batch.tokens: + mx.async_eval(*batch.tokens) y = y.tolist() self._stats.generation_time += _time.perf_counter() - tic_gen @@ -268,8 +331,13 @@ def _chunked_next(self=batch_gen): # noqa: C901 inputs = partial["inputs"] prompt_cache = partial["cache"] remaining = inputs.shape[1] + prompt_checkpoint = max(1, int(partial.get("prompt_checkpoint", 1))) - n_to_process = min(budget, remaining - 1) if remaining > 1 else 0 + n_to_process = ( + min(budget, remaining - prompt_checkpoint) + if remaining > prompt_checkpoint + else 0 + ) if n_to_process > 0: self.model(mx.contiguous(inputs[:, :n_to_process]), cache=prompt_cache) @@ -294,8 +362,8 @@ def _chunked_next(self=batch_gen): # noqa: C901 if partial.get("is_cached"): mx.clear_cache() - # Check if prefill is done (only 1 token left or 0) - if inputs.shape[1] <= 1: + # Check if prefill is done once only the checkpoint tail remains. 
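The __post_init__ guard above is easy to exercise directly; a quick sketch, assuming the remaining SchedulerConfig fields keep their defaults and the import path matches this module:

from vllm_mlx.scheduler import SchedulerConfig  # assumed import path

SchedulerConfig(mllm_prefill_step_size=512)      # positive override: accepted
try:
    SchedulerConfig(mllm_prefill_step_size=0)    # non-positive override: rejected
except ValueError as exc:
    print(exc)  # mllm_prefill_step_size must be > 0 when provided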
+ if inputs.shape[1] <= prompt_checkpoint: # Finalize if partial.get("is_cached"): mx.eval([c.state for c in prompt_cache]) @@ -303,8 +371,31 @@ def _chunked_next(self=batch_gen): # noqa: C901 for c in prompt_cache: c.finalize() + + if self.prompt_checkpoint_callback is not None: + self.prompt_checkpoint_callback( + [ + ( + uid, + prompt_checkpoint, + _lazy_extract_cache(prompt_cache, i), + ) + for i, uid in enumerate(partial["uids"]) + ] + ) mx.clear_cache() + # Mirror upstream BatchGenerator semantics: after finalize() and + # the checkpoint callback, replay the remaining checkpoint tail + # except for the final token, which _step() consumes. + if prompt_checkpoint > 1: + self.model( + mx.contiguous(inputs[:, : prompt_checkpoint - 1]), + cache=prompt_cache, + ) + mx.eval([c.state for c in prompt_cache]) + mx.clear_cache() + y, logprobs = self._step( inputs, prompt_cache, @@ -314,10 +405,10 @@ def _chunked_next(self=batch_gen): # noqa: C901 ) mx.async_eval(y, logprobs) - new_batch = Batch( + new_batch = _batch_cls( list(partial["uids"]), y, - logprobs, + list(logprobs), list(partial["max_tokens"]), [0] * len(partial["uids"]), prompt_cache, @@ -393,12 +484,20 @@ def _chunked_next(self=batch_gen): # noqa: C901 caches, samplers, logits_processors, - _prompt_checkpoints, + prompt_checkpoints, ) = zip(*batch_prompts) lengths = [len(p) for p in inputs_raw] max_length = max(lengths) padding = [max_length - ln for ln in lengths] tokens = [mx.array(inp) for inp in inputs_raw] + # Match mlx-lm's prompt_checkpoint contract: positive values + # name the checkpoint token position in the prompt, while + # non-positive values already encode an offset from the end. + checkpoint_offsets = [ + (ln - pc if pc > 0 else -pc) + for ln, pc in zip(lengths, prompt_checkpoints) + ] + prompt_checkpoint = max(1, max(checkpoint_offsets)) is_cached = not all(c[0].empty() for c in caches) self._stats.prompt_tokens += sum(lengths) @@ -409,12 +508,14 @@ def _chunked_next(self=batch_gen): # noqa: C901 self.model, padding, self.max_kv_size ) else: - last_inputs = mx.array([p[-1:] for p in inputs_raw]) + last_inputs = mx.array( + [p[-prompt_checkpoint:] for p in inputs_raw] + ) padded = _right_pad_prompts(inputs_raw, max_length=max_length) prompt_cache = _merge_caches(caches) for c in prompt_cache: c.prepare( - lengths=[ln - 1 for ln in lengths], + lengths=[ln - prompt_checkpoint for ln in lengths], right_padding=padding, ) @@ -437,9 +538,11 @@ def _chunked_next(self=batch_gen): # noqa: C901 _pb = getattr(_req0, "prefix_boundary", 0) if _req0 else 0 _cached = getattr(_req0, "cached_tokens", 0) if _req0 else 0 _adjusted_pb = _pb - _cached - if 0 < _adjusted_pb < padded.shape[1]: + if 0 < _adjusted_pb < padded.shape[1] - prompt_checkpoint + 1: _first_chunk = _adjusted_pb - n_to_process = min(_first_chunk, padded.shape[1] - 1) + n_to_process = min( + _first_chunk, padded.shape[1] - prompt_checkpoint + ) if n_to_process > 0: self.model( mx.contiguous(padded[:, :n_to_process]), @@ -458,6 +561,7 @@ def _chunked_next(self=batch_gen): # noqa: C901 "max_tokens": list(max_tokens_list), "samplers": list(samplers), "logits_processors": list(logits_processors), + "prompt_checkpoint": prompt_checkpoint, "processed": n_to_process, "total": max_length, "is_cached": is_cached, @@ -648,6 +752,10 @@ def _mtp_step( # --- Apply logits processors + sample primary --- if any(logits_processors): + logger.debug( + f"[logits_proc] applying {sum(len(lp) for lp in logits_processors)} " + f"processors to batch_size={batch_size}" + ) processed_logits = 
[] for e in range(batch_size): sample_logits = logits[e : e + 1] @@ -698,12 +806,13 @@ def _mtp_step( # RNN snapshot, then re-advance with just P so both cache # types end up consistent at [..., P]. _rnn_snapshots = {} - for _ci, _c in enumerate(prompt_cache): - if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()): - if hasattr(_c, "state"): - _rnn_snapshots[_ci] = [ - s.copy() if s is not None else None for s in _c.state - ] + if not optimistic: + for _ci, _c in enumerate(prompt_cache): + if not (hasattr(_c, "is_trimmable") and _c.is_trimmable()): + if hasattr(_c, "state"): + _rnn_snapshots[_ci] = [ + s.copy() if s is not None else None for s in _c.state + ] verify_input = mx.concatenate( [primary_tokens[:, None], draft_tokens[:, None]], axis=1 @@ -1094,11 +1203,7 @@ def _decode_tokens(self, token_ids: List[int]) -> str: def _get_detokenizer(self, request_id: str) -> Any: """Get or create a streaming detokenizer for a request.""" if request_id not in self._detokenizer_pool: - if hasattr(self.tokenizer, "detokenizer"): - detok = self.tokenizer.detokenizer - else: - detok = NaiveStreamingDetokenizer(self._actual_tokenizer) - detok.reset() + detok = NaiveStreamingDetokenizer(self._actual_tokenizer) self._detokenizer_pool[request_id] = detok return self._detokenizer_pool[request_id] @@ -1158,15 +1263,25 @@ def _prefill_progress(progress_list): prefill_batch_size=self.config.prefill_batch_size, completion_batch_size=self.config.completion_batch_size, prefill_step_size=self.config.prefill_step_size, - prompt_progress_callback=_prefill_progress, ) + # Set callback as attribute — used by _install_chunked_prefill + # monkey-patch. Not a BatchGenerator constructor parameter. + bg.prompt_progress_callback = _prefill_progress # Install chunked prefill when explicitly configured OR when # memory-aware cache is active (needed for prefix_boundary saves # in agentic multi-turn workloads with hybrid Mamba+Transformer models). chunked_budget = self.config.chunked_prefill_tokens need_chunked = chunked_budget > 0 or self.memory_aware_cache is not None - if need_chunked: + + # The chunked prefill monkey-patch relies on BatchGenerator internals + # (_process_prompts, active_batch, _step, etc.) that were refactored + # in mlx-lm 0.31.x. Skip gracefully when the required API is absent. + chunked_compatible = hasattr(bg, "_process_prompts") and hasattr( + bg, "active_batch" + ) + + if need_chunked and chunked_compatible: if chunked_budget <= 0: # No explicit budget — use a very large value so normal # prompts pass through unchanged. Prefix boundary splits @@ -1189,6 +1304,12 @@ def _prefill_progress(progress_list): uid_to_request_id=self.uid_to_request_id, requests=self.requests, ) + elif need_chunked and not chunked_compatible: + logger.warning( + "Chunked prefill disabled: mlx-lm BatchGenerator lacks required " + "internals (_process_prompts, active_batch). Upgrade mlx-lm or " + "check compatibility." 
+ ) # Install MTP if the model supports it if self.config.enable_mtp: @@ -1791,15 +1912,30 @@ def _schedule_waiting(self) -> List[Request]: request.remaining_tokens = request.prompt_token_ids tokens_to_process = request.prompt_token_ids + # Build per-request logits_processors from repetition_penalty + rep_penalty = request.sampling_params.repetition_penalty + lp = None + if rep_penalty and rep_penalty != 1.0: + lp = make_logits_processors(repetition_penalty=rep_penalty) + logger.info( + f"[rep_penalty] request={request.request_id[:12]} " + f"penalty={rep_penalty} processors={len(lp)}" + ) + # Insert into BatchGenerator with optional cache. # Wrap in try/except: if cache shapes are incompatible # (e.g. stale entry after BatchGenerator recreation), # fall back to no-cache insert instead of crashing. + insert_kwargs = { + "max_tokens": [request.sampling_params.max_tokens], + "caches": [cache_to_use] if cache_to_use else None, + } + if lp: + insert_kwargs["logits_processors"] = [lp] try: uids = self.batch_generator.insert( [tokens_to_process], - max_tokens=[request.sampling_params.max_tokens], - caches=[cache_to_use] if cache_to_use else None, + **insert_kwargs, ) except Exception as e: if cache_to_use is not None: @@ -1812,10 +1948,10 @@ def _schedule_waiting(self) -> List[Request]: request.cached_tokens = 0 request.remaining_tokens = request.prompt_token_ids tokens_to_process = request.prompt_token_ids + insert_kwargs["caches"] = None uids = self.batch_generator.insert( [tokens_to_process], - max_tokens=[request.sampling_params.max_tokens], - caches=None, + **insert_kwargs, ) else: raise @@ -1836,11 +1972,16 @@ def _schedule_waiting(self) -> List[Request]: else "" ) tokens_to_prefill = len(tokens_to_process) + rep_info = ( + f" rep_penalty={rep_penalty}" + if rep_penalty and rep_penalty != 1.0 + else "" + ) logger.info( f"[schedule] request={request.request_id[:12]} uid={uid} " f"prompt_tokens={request.num_prompt_tokens} " f"tokens_to_prefill={tokens_to_prefill}{cache_info} " - f"max_tokens={request.sampling_params.max_tokens} " + f"max_tokens={request.sampling_params.max_tokens}{rep_info} " f"running={len(self.running)} waiting={len(self.waiting)}" ) @@ -2216,9 +2357,16 @@ def step(self, max_retries: int = 1) -> SchedulerOutput: # Run generation step if we have running requests if self.batch_generator is not None and self.running: - responses = self.batch_generator.next() + result = self.batch_generator.next() output.has_work = True + # mlx-lm >=0.31.x returns (prompt_responses, generation_responses); + # older versions returned a flat list. 
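To show how the per-request repetition penalty above plugs together, a small sketch of the wiring: a non-default penalty becomes an mlx-lm logits-processor list that is only passed to insert() when present. The penalty value and the commented-out call are assumptions for the example.

from mlx_lm.sample_utils import make_logits_processors

rep_penalty = 1.15
insert_kwargs = {"max_tokens": [256], "caches": None}
if rep_penalty and rep_penalty != 1.0:
    # One processor list per request in the batch, as in _schedule_waiting above.
    insert_kwargs["logits_processors"] = [make_logits_processors(repetition_penalty=rep_penalty)]
# uids = batch_generator.insert([prompt_token_ids], **insert_kwargs)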
+ if isinstance(result, tuple): + responses = result[1] # generation_responses only + else: + responses = result + if responses: outputs, finished_ids = self._process_batch_responses(responses) output.outputs = outputs @@ -2285,6 +2433,7 @@ def step(self, max_retries: int = 1) -> SchedulerOutput: # Evaluate batch tokens to collapse lazy concatenation chains if ( self.batch_generator is not None + and hasattr(self.batch_generator, "active_batch") and self.batch_generator.active_batch is not None and hasattr(self.batch_generator.active_batch, "tokens") ): diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 18af96438..6a749fffb 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -42,6 +42,7 @@ import json import logging import os +import re import secrets import tempfile import threading @@ -57,8 +58,13 @@ # Import from new modular API # Re-export for backwards compatibility with tests -from .api.anthropic_adapter import anthropic_to_openai, openai_to_anthropic -from .api.anthropic_models import AnthropicRequest +from .api.anthropic_adapter import anthropic_to_openai +from .api.anthropic_models import ( + AnthropicRequest, + AnthropicResponse, + AnthropicResponseContentBlock, + AnthropicUsage, +) from .api.models import ( AssistantMessage, # noqa: F401 ChatCompletionChoice, # noqa: F401 @@ -98,8 +104,6 @@ ) from .api.utils import ( SPECIAL_TOKENS_PATTERN, - StreamingThinkRouter, - StreamingToolCallFilter, clean_output_text, extract_multimodal_content, is_mllm_model, # noqa: F401 @@ -163,6 +167,11 @@ def _resolve_top_p(request_value: float | None) -> float: _tool_call_parser: str | None = None # Parser name: auto, mistral, qwen, llama, hermes _tool_parser_instance = None # Instantiated parser +# Pattern to strip leaked tool call markup from content output. +# Safety net: the tool parser should consume these, but if it doesn't +# (e.g. malformed JSON, stray closing tags), strip them before emitting. +_TOOL_MARKUP_PATTERN = re.compile(r"|") + def _load_prefix_cache_from_disk() -> None: """Load prefix cache from disk during startup.""" @@ -343,6 +352,53 @@ def get_engine() -> BaseEngine: return _engine +def _coerce_tool_arguments( + arguments_json: str, tool_name: str, tools: list[dict] | None +) -> str: + """ + Coerce tool call arguments to match the tool schema. + + If a schema field expects "string" but the model produced an object/array, + JSON-stringify the value. This fixes a common LLM failure mode where models + output raw JSON objects instead of JSON strings for file content, etc. 
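An illustration of the coercion described above, with a made-up tool schema and arguments (the import assumes the helper is reachable outside the running server):

from vllm_mlx.server import _coerce_tool_arguments  # assumed import

tools = [{
    "type": "function",
    "function": {
        "name": "write_file",
        "parameters": {
            "type": "object",
            "properties": {"path": {"type": "string"}, "content": {"type": "string"}},
        },
    },
}]
raw = '{"path": "notes.md", "content": {"title": "Plan", "items": ["a", "b"]}}'
print(_coerce_tool_arguments(raw, "write_file", tools))
# "content" expected a string but the model produced an object, so the object is
# re-serialized as a JSON string; "path" is left untouched.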
+ """ + if not tools: + return arguments_json + + # Find the schema for this tool + schema = None + for tool in tools: + if isinstance(tool, dict) and tool.get("function", {}).get("name") == tool_name: + schema = tool["function"].get("parameters", {}) + break + + if not schema or "properties" not in schema: + return arguments_json + + try: + arguments = json.loads(arguments_json) + except (json.JSONDecodeError, TypeError): + return arguments_json + + if not isinstance(arguments, dict): + return arguments_json + + properties = schema.get("properties", {}) + changed = False + + for key, value in arguments.items(): + if key in properties: + expected_type = properties[key].get("type") + if expected_type == "string" and isinstance(value, (dict, list)): + arguments[key] = json.dumps(value, ensure_ascii=False, indent=2) + changed = True + + if changed: + return json.dumps(arguments, ensure_ascii=False) + + return arguments_json + + def _validate_model_name(request_model: str) -> None: """Validate that the request model name matches the served model.""" if _model_name and request_model != _model_name: @@ -373,6 +429,14 @@ def _parse_tool_calls_with_parser( request_dict = request.model_dump() if request else None + # tool_choice="none" means never return tool calls — skip all parsing + if request is not None: + tool_choice = getattr(request, "tool_choice", None) + if tool_choice is None and request_dict: + tool_choice = request_dict.get("tool_choice") + if tool_choice == "none": + return output_text, None + # If auto tool choice is not enabled, use the generic parser if not _enable_auto_tool_choice or not _tool_call_parser: return parse_tool_calls(output_text, request_dict) @@ -400,13 +464,16 @@ def _parse_tool_calls_with_parser( _tool_parser_instance.reset() result = _tool_parser_instance.extract_tool_calls(output_text, request_dict) if result.tools_called: + tools = request_dict.get("tools") if request_dict else None tool_calls = [ ToolCall( id=tc.get("id", f"call_{uuid.uuid4().hex[:8]}"), type="function", function=FunctionCall( name=tc["name"], - arguments=tc["arguments"], + arguments=_coerce_tool_arguments( + tc["arguments"], tc["name"], tools + ), ), ) for tc in result.tool_calls @@ -485,6 +552,7 @@ def load_model( stream_interval: int = 1, max_tokens: int = 32768, force_mllm: bool = False, + gpu_memory_utilization: float = 0.90, served_model_name: str | None = None, mtp: bool = False, prefill_step_size: int = 2048, @@ -528,6 +596,7 @@ def load_model( scheduler_config=scheduler_config, stream_interval=stream_interval, force_mllm=force_mllm, + gpu_memory_utilization=gpu_memory_utilization, ) # BatchedEngine will be started in lifespan (uvicorn's event loop) # Just log for now @@ -591,14 +660,11 @@ async def health(): "tools_available": len(_mcp_manager.get_all_tools()), } - engine_stats = _engine.get_stats() if _engine else {} - return { "status": "healthy", "model_loaded": _engine is not None, "model_name": _model_name, "model_type": "mllm" if (_engine and _engine.is_mllm) else "llm", - "engine_type": engine_stats.get("engine_type", "unknown"), "mcp": mcp_info, } @@ -1027,15 +1093,19 @@ async def _disconnect_guard( generator: AsyncIterator[str], raw_request: Request, poll_interval: float = 0.5, + heartbeat_interval: float = 5.0, ) -> AsyncIterator[str]: """Wrap streaming generator to abort on client disconnect. Uses asyncio racing: each __anext__() on the inner generator is - raced against a disconnect poller. 
This catches disconnects even - during prefill when no chunks are being yielded for tens of seconds. - - On disconnect, aclose() propagates down the generator chain to - engine_core.stream_outputs() finally-block → abort_request(). + raced against a disconnect poller. When neither completes within + ``heartbeat_interval`` seconds, an SSE comment is yielded as a + heartbeat. This forces an ASGI write which triggers broken-pipe + detection — without heartbeats, ``is_disconnected()`` stays False + during long prefill because no data is written to the socket. + + On disconnect, the cancellation propagates to stream_outputs() + finally-block → abort_request() → abort_prefill(). """ import time as _time @@ -1044,7 +1114,9 @@ async def _disconnect_guard( def _elapsed(): return f"{_time.monotonic() - _t0:.1f}s" - logger.info(f"[disconnect_guard] START poll_interval={poll_interval}s") + logger.info( + f"[disconnect_guard] START poll={poll_interval}s heartbeat={heartbeat_interval}s" + ) async def _wait_disconnect(): poll_count = 0 @@ -1061,21 +1133,28 @@ async def _wait_disconnect(): return chunk_count = 0 + heartbeat_count = 0 disconnect_task: asyncio.Task | None = None anext_task: asyncio.Task | None = None try: aiter = generator.__aiter__() disconnect_task = asyncio.create_task(_wait_disconnect()) + anext_task = None while True: - anext_task = asyncio.ensure_future(aiter.__anext__()) + if anext_task is None: + anext_task = asyncio.ensure_future(aiter.__anext__()) + done, _ = await asyncio.wait( [anext_task, disconnect_task], return_when=asyncio.FIRST_COMPLETED, + timeout=heartbeat_interval, ) + if disconnect_task in done: logger.info( f"[disconnect_guard] CLIENT DISCONNECTED after " - f"{chunk_count} chunks, elapsed={_elapsed()}" + f"{chunk_count} chunks, {heartbeat_count} heartbeats, " + f"elapsed={_elapsed()}" ) anext_task.cancel() try: @@ -1083,20 +1162,32 @@ async def _wait_disconnect(): except (asyncio.CancelledError, StopAsyncIteration): pass break - try: - chunk = anext_task.result() - except StopAsyncIteration: - logger.info( - f"[disconnect_guard] generator exhausted normally, " - f"{chunk_count} chunks, elapsed={_elapsed()}" - ) - break - chunk_count += 1 - if chunk_count == 1: - logger.info( - f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}" - ) - yield chunk + + if anext_task in done: + try: + chunk = anext_task.result() + except StopAsyncIteration: + logger.info( + f"[disconnect_guard] generator exhausted normally, " + f"{chunk_count} chunks, elapsed={_elapsed()}" + ) + break + chunk_count += 1 + if chunk_count == 1: + logger.info( + f"[disconnect_guard] first chunk arrived, elapsed={_elapsed()}" + ) + yield chunk + anext_task = None + continue + + # Timeout — no chunk and no disconnect detected yet. + # Send SSE comment as heartbeat to force an ASGI write. + # If the client has disconnected, this write will fail and + # the next is_disconnected() poll will return True. 
+ heartbeat_count += 1 + yield ": heartbeat\n\n" + except GeneratorExit: logger.info( f"[disconnect_guard] GeneratorExit after {chunk_count} chunks, elapsed={_elapsed()}" @@ -1116,7 +1207,8 @@ async def _wait_disconnect(): # anext_task.cancel() → CancelledError in stream_outputs() # → finally block → abort_request() → request removed from scheduler logger.info( - f"[disconnect_guard] CLEANUP done, {chunk_count} chunks total, elapsed={_elapsed()}" + f"[disconnect_guard] CLEANUP done, {chunk_count} chunks, " + f"{heartbeat_count} heartbeats, elapsed={_elapsed()}" ) @@ -1218,13 +1310,24 @@ async def create_completion(request: CompletionRequest, raw_request: Request): logger.info( f"[REQUEST] POST /v1/completions stream={request.stream} " f"max_tokens={request.max_tokens} temp={request.temperature} " + f"top_p={request.top_p} top_k={request.top_k} min_p={request.min_p} " + f"presence_penalty={request.presence_penalty} " + f"repetition_penalty={request.repetition_penalty} " f"prompt_chars={prompt_len} prompt_preview={prompt_preview!r}" ) + # Resolve repetition penalty for completions + comp_rep_penalty = request.repetition_penalty + if request.stream: return StreamingResponse( _disconnect_guard( - stream_completion(engine, prompts[0], request), + stream_completion( + engine, + prompts[0], + request, + repetition_penalty=comp_rep_penalty, + ), raw_request, ), media_type="text/event-stream", @@ -1238,14 +1341,25 @@ async def create_completion(request: CompletionRequest, raw_request: Request): total_prompt_tokens = 0 for i, prompt in enumerate(prompts): + generate_kwargs = { + "prompt": prompt, + "max_tokens": request.max_tokens or _default_max_tokens, + "temperature": _resolve_temperature(request.temperature), + "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, + "stop": request.stop, + } + if comp_rep_penalty is not None: + generate_kwargs["repetition_penalty"] = comp_rep_penalty + if request.specprefill is not None: + generate_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + output = await _wait_with_disconnect( - engine.generate( - prompt=prompt, - max_tokens=request.max_tokens or _default_max_tokens, - temperature=_resolve_temperature(request.temperature), - top_p=_resolve_top_p(request.top_p), - stop=request.stop, - ), + engine.generate(**generate_kwargs), raw_request, timeout=timeout, ) @@ -1345,7 +1459,11 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re logger.info( f"[REQUEST] POST /v1/chat/completions stream={request.stream} " f"model={request.model!r} max_tokens={request.max_tokens} " - f"temp={request.temperature} msgs={n_msgs} roles={msg_roles} " + f"temp={request.temperature} top_p={request.top_p} " + f"top_k={request.top_k} min_p={request.min_p} " + f"presence_penalty={request.presence_penalty} " + f"repetition_penalty={request.repetition_penalty} " + f"msgs={n_msgs} roles={msg_roles} " f"total_chars={total_chars} tools={n_tools} " f"response_format={request.response_format}" ) @@ -1367,12 +1485,14 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re messages.append(msg_dict) images, videos = [], [] # MLLM extracts these from messages logger.debug(f"MLLM: Processing {len(messages)} messages") + messages = _normalize_messages(messages) else: # For LLM, extract text, images, and videos separately 
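Client-side, the new sampling fields wired through above can be exercised like this (URL, model name, and parameter values are assumptions for a locally running server):

import json
import urllib.request

payload = {
    "model": "my-mlx-model",
    "prompt": "Write a haiku about batching.",
    "max_tokens": 64,
    "temperature": 0.7,
    "top_k": 40,
    "min_p": 0.05,
    "presence_penalty": 0.3,
    "repetition_penalty": 1.1,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])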
messages, images, videos = extract_multimodal_content( request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) has_media = bool(images or videos) if engine.is_mllm and not has_media: @@ -1401,12 +1521,21 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re # Inject JSON instruction into messages messages = _inject_json_instruction(messages, json_instruction) + # Resolve repetition penalty + rep_penalty = request.repetition_penalty + # Prepare kwargs chat_kwargs = { "max_tokens": request.max_tokens or _default_max_tokens, "temperature": _resolve_temperature(request.temperature), "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, + "repetition_penalty": request.repetition_penalty or 1.0, } + if rep_penalty is not None: + chat_kwargs["repetition_penalty"] = rep_penalty # Add multimodal content if has_media: @@ -1425,8 +1554,12 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re if request.chat_template_kwargs: chat_kwargs["chat_template_kwargs"] = dict(request.chat_template_kwargs) + # Enable/disable thinking mode per request + if request.enable_thinking is not None: + chat_kwargs["enable_thinking"] = request.enable_thinking + # Add tools if provided - if request.tools: + if request.tools and request.tool_choice != "none": chat_kwargs["tools"] = convert_tools_for_template(request.tools) if request.stream: @@ -1460,8 +1593,9 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re cleaned_text, tool_calls = _parse_tool_calls_with_parser(output.text, request) # Extract reasoning content FIRST (strips channel tokens before JSON extraction) + # Skip reasoning parser when enable_thinking=False (no think tags expected) reasoning_text = None - if _reasoning_parser and not tool_calls: + if _reasoning_parser and not tool_calls and request.enable_thinking is not False: text_to_parse = cleaned_text or output.text reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning( text_to_parse @@ -1500,6 +1634,64 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) +def _normalize_messages(messages: list[dict]) -> list[dict]: + """Normalize message roles and merge consecutive same-role messages. + + 1. Maps non-standard roles to standard ones (e.g. ``developer`` -> ``system``). + 2. Merges consecutive same-role messages to satisfy chat template constraints + (Qwen 3.5, Llama, etc. require alternating roles). + + Only merges when both messages have string content. Messages with list + content (multimodal) are left as-is to preserve image/video attachments. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Returns: + New list with normalized roles and consecutive same-role messages merged. + """ + # OpenAI Responses API uses "developer" instead of "system". + # Map it so chat templates don't fail and fall back to raw prefill. 
+ _ROLE_MAP = {"developer": "system"} + + if not messages: + return messages + + merged = [messages[0].copy()] + if merged[0]["role"] in _ROLE_MAP: + merged[0]["role"] = _ROLE_MAP[merged[0]["role"]] + for msg in messages[1:]: + prev = merged[-1] + role = _ROLE_MAP.get(msg["role"], msg["role"]) + if ( + role == prev["role"] + and isinstance(prev.get("content"), str) + and isinstance(msg.get("content"), str) + ): + # Merge string content with double newline separator + prev["content"] = prev["content"] + "\n\n" + msg["content"] + logger.debug( + f"Merged consecutive {role} messages " + f"({len(prev['content'])} chars total)" + ) + else: + copy = msg.copy() + copy["role"] = role + merged.append(copy) + + mapped_roles = sum(1 for m in messages if m["role"] in _ROLE_MAP) + merged_count = len(messages) - len(merged) + if mapped_roles or merged_count: + parts = [] + if mapped_roles: + parts.append(f"mapped {mapped_roles} role(s)") + if merged_count: + parts.append(f"merged {len(messages)} -> {len(merged)}") + logger.info(f"Normalized messages: {', '.join(parts)}") + + return merged + + def _inject_json_instruction(messages: list, instruction: str) -> list: """ Inject JSON instruction into messages. @@ -1537,6 +1729,17 @@ def _inject_json_instruction(messages: list, instruction: str) -> list: # ============================================================================= +def _convert_anthropic_stop_reason(openai_reason: str | None) -> str: + """Convert OpenAI finish_reason to Anthropic stop_reason.""" + mapping = { + "stop": "end_turn", + "tool_calls": "tool_use", + "length": "max_tokens", + "content_filter": "end_turn", + } + return mapping.get(openai_reason or "", "end_turn") + + @app.post("/v1/messages") async def create_anthropic_message( request: Request, @@ -1551,8 +1754,19 @@ async def create_anthropic_message( """ engine = get_engine() - # Parse the raw body to handle Anthropic request format - body = await request.json() + # Parse the raw body to handle Anthropic request format. + # Some clients (e.g. Claude Code) may send JSON with invalid escape + # sequences like \s, \d in regex patterns within tool definitions. + # Python's json.loads is strict per RFC 8259 and rejects these. 
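For illustration, the normalization above behaves roughly like this on made-up input (import path assumed):

from vllm_mlx.server import _normalize_messages  # assumed import

messages = [
    {"role": "developer", "content": "You are terse."},
    {"role": "system", "content": "Answer in English."},
    {"role": "user", "content": "Hi"},
    {"role": "user", "content": "Actually, in French please."},
]
print(_normalize_messages(messages))
# "developer" maps to "system", then consecutive same-role string messages merge:
# [{"role": "system", "content": "You are terse.\n\nAnswer in English."},
#  {"role": "user", "content": "Hi\n\nActually, in French please."}]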
+ try: + body = await request.json() + except json.JSONDecodeError as e: + if "Invalid \\escape" in str(e): + raw = await request.body() + # Replace lone backslashes (not valid JSON escapes) with \\ + body = json.loads(re.sub(rb'\\(?!["\\/bfnrtu])', rb"\\\\", raw)) + else: + raise anthropic_request = AnthropicRequest(**body) _validate_model_name(anthropic_request.model) @@ -1597,14 +1811,19 @@ async def create_anthropic_message( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens, "temperature": openai_request.temperature, "top_p": openai_request.top_p, + "top_k": openai_request.top_k or 0, + "min_p": openai_request.min_p or 0.0, + "presence_penalty": openai_request.presence_penalty or 0.0, + "repetition_penalty": openai_request.repetition_penalty or 1.0, } - if openai_request.tools: + if openai_request.tools and openai_request.tool_choice != "none": chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) start_time = time.perf_counter() @@ -1629,35 +1848,63 @@ async def create_anthropic_message( output.text, openai_request ) + # Extract reasoning if parser is configured + reasoning_text = None + if _reasoning_parser and not tool_calls: + text_to_parse = cleaned_text or output.text + reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning( + text_to_parse + ) + # Clean output text final_content = None if cleaned_text: final_content = clean_output_text(cleaned_text) - # Determine finish reason - finish_reason = "tool_calls" if tool_calls else output.finish_reason + # Build Anthropic content blocks directly (with thinking support) + content_blocks = [] - # Build OpenAI response to convert - openai_response = ChatCompletionResponse( - model=_model_name, - choices=[ - ChatCompletionChoice( - message=AssistantMessage( - content=final_content, - tool_calls=tool_calls, - ), - finish_reason=finish_reason, + if reasoning_text: + content_blocks.append( + AnthropicResponseContentBlock(type="thinking", thinking=reasoning_text) + ) + + if final_content: + content_blocks.append( + AnthropicResponseContentBlock(type="text", text=final_content) + ) + + if tool_calls: + for tc in tool_calls: + try: + tool_input = json.loads(tc.function.arguments) + except (json.JSONDecodeError, AttributeError): + tool_input = {} + content_blocks.append( + AnthropicResponseContentBlock( + type="tool_use", + id=tc.id, + name=tc.function.name, + input=tool_input, + ) ) - ], - usage=Usage( - prompt_tokens=output.prompt_tokens, - completion_tokens=output.completion_tokens, - total_tokens=output.prompt_tokens + output.completion_tokens, - ), + + if not content_blocks: + content_blocks.append(AnthropicResponseContentBlock(type="text", text="")) + + stop_reason = _convert_anthropic_stop_reason( + "tool_calls" if tool_calls else output.finish_reason ) - # Convert to Anthropic response - anthropic_response = openai_to_anthropic(openai_response, _model_name) + anthropic_response = AnthropicResponse( + model=_model_name, + content=content_blocks, + stop_reason=stop_reason, + usage=AnthropicUsage( + input_tokens=output.prompt_tokens, + output_tokens=output.completion_tokens, + ), + ) return Response( content=anthropic_response.model_dump_json(exclude_none=True), media_type="application/json", @@ -1798,6 +2045,10 @@ async def _stream_anthropic_messages( Converts OpenAI streaming chunks to Anthropic event format: message_start -> content_block_start -> 
content_block_delta* -> content_block_stop -> message_delta -> message_stop + + When a reasoning parser is active, emits a ``thinking`` content block + (index 0) for reasoning tokens and a ``text`` content block (index 1) + for the actual response, matching the Anthropic extended thinking format. """ msg_id = f"msg_{uuid.uuid4().hex[:24]}" start_time = time.perf_counter() @@ -1807,14 +2058,19 @@ async def _stream_anthropic_messages( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens, "temperature": openai_request.temperature, "top_p": openai_request.top_p, + "top_k": openai_request.top_k or 0, + "min_p": openai_request.min_p or 0.0, + "presence_penalty": openai_request.presence_penalty or 0.0, + "repetition_penalty": openai_request.repetition_penalty or 1.0, } - if openai_request.tools: + if openai_request.tools and openai_request.tool_choice != "none": chat_kwargs["tools"] = convert_tools_for_template(openai_request.tools) # Emit message_start @@ -1836,115 +2092,171 @@ async def _stream_anthropic_messages( } yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" - # Stream pipeline: raw text → tool call filter → think router → emit - # - Tool call filter strips tool call markup (emitted as structured blocks later) - # - Think router separates content into Anthropic thinking blocks + use_reasoning = _reasoning_parser is not None + + if use_reasoning: + _reasoning_parser.reset_state() + + # Block index tracking: with reasoning parser we use index 0 for + # thinking and index 1 for text; without parser, index 0 for text. + thinking_block_started = False + text_block_started = False + thinking_index = 0 + text_index = 1 if use_reasoning else 0 + + if not use_reasoning: + # No reasoning parser — start text block immediately + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + text_block_started = True + + # Stream content deltas accumulated_text = "" - tool_filter = StreamingToolCallFilter() - # Detect if the model's chat template injects into the - # generation prompt. If so, the model starts in thinking mode and - # the opening tag never appears in the output stream. - _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None - _chat_template = "" - if _tokenizer and hasattr(_tokenizer, "chat_template"): - _chat_template = _tokenizer.chat_template or "" - _starts_thinking = ( - "" in _chat_template and "add_generation_prompt" in _chat_template - ) - think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking) - prompt_tokens = 0 completion_tokens = 0 - # Track which content blocks we've started - current_block_type = None # "thinking" or "text" - block_index = 0 + # Tool call streaming suppression — prevents raw tool markup from leaking + # as text_delta events. Mirrors the OpenAI streaming path logic. 
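For reference, the event order this path aims to produce for a response that contains both thinking and text (an illustrative transcript with payloads abbreviated; the exact JSON bodies follow from the code below):

event: message_start         {"type": "message_start", "message": {...}}
event: content_block_start   {"index": 0, "content_block": {"type": "thinking", "thinking": ""}}
event: content_block_delta   {"index": 0, "delta": {"type": "thinking_delta", ...}}   (repeated)
event: content_block_stop    {"index": 0}
event: content_block_start   {"index": 1, "content_block": {"type": "text", "text": ""}}
event: content_block_delta   {"index": 1, "delta": {"type": "text_delta", ...}}       (repeated)
event: content_block_stop    {"index": 1}
event: message_delta         {"delta": {"stop_reason": "end_turn", ...}, "usage": {...}}
event: message_stop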
+ global _tool_parser_instance + tool_parser = None + tool_accumulated_text = "" + tool_markup_possible = False + tool_choice = getattr(openai_request, "tool_choice", None) + if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": + if _tool_parser_instance is None: + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = None + if _engine is not None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + _tool_parser_instance = parser_cls(tokenizer) + except Exception: + pass + if _tool_parser_instance is not None: + tool_parser = _tool_parser_instance + tool_parser.reset() async for output in engine.stream_chat(messages=messages, **chat_kwargs): delta_text = output.new_text # Track token counts - if hasattr(output, "prompt_tokens") and output.prompt_tokens: - prompt_tokens = output.prompt_tokens if hasattr(output, "completion_tokens") and output.completion_tokens: completion_tokens = output.completion_tokens - if delta_text: - # Accumulate raw text BEFORE special token cleaning for tool parsing - accumulated_text += delta_text + if not delta_text: + continue - # Filter special tokens for display - content = SPECIAL_TOKENS_PATTERN.sub("", delta_text) + # Filter special tokens + filtered = SPECIAL_TOKENS_PATTERN.sub("", delta_text) + if not filtered: + continue - if content: - # Stage 1: strip tool call markup - filtered = tool_filter.process(content) - if not filtered: - continue - # Stage 2: route thinking vs text - pieces = think_router.process(filtered) - events, current_block_type, block_index = _emit_content_pieces( - pieces, current_block_type, block_index - ) - for event in events: - yield event - - # Flush remaining from both filters - remaining = tool_filter.flush() - if remaining: - events, current_block_type, block_index = _emit_content_pieces( - think_router.process(remaining), current_block_type, block_index - ) - for event in events: - yield event + if not use_reasoning: + # Simple path — no reasoning parsing + accumulated_text += filtered + content_to_emit = filtered - flush_pieces = think_router.flush() - if flush_pieces: - events, current_block_type, block_index = _emit_content_pieces( - flush_pieces, current_block_type, block_index + # Filter tool call markup during streaming + if tool_parser and content_to_emit: + if not tool_markup_possible and "<" not in content_to_emit: + tool_accumulated_text += content_to_emit + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += content_to_emit + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, content_to_emit + ) + if tool_result is None or "tool_calls" in tool_result: + # Inside tool markup or tool calls detected — suppress + continue + content_to_emit = tool_result.get("content", "") + if content_to_emit: + content_to_emit = _TOOL_MARKUP_PATTERN.sub("", content_to_emit) + if not content_to_emit: + continue + + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': content_to_emit}})}\n\n" + continue + + # Reasoning parser path + previous_text = accumulated_text + accumulated_text += filtered + delta_msg = _reasoning_parser.extract_reasoning_streaming( + previous_text, accumulated_text, filtered ) - for event in events: - yield event - # Close final content block - if current_block_type is not None: - yield f"event: content_block_stop\ndata: 
{json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n" - block_index += 1 + if delta_msg is None: + continue + + if delta_msg.reasoning: + if not thinking_block_started: + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': thinking_index, 'content_block': {'type': 'thinking', 'thinking': ''}})}\n\n" + thinking_block_started = True + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': thinking_index, 'delta': {'type': 'thinking_delta', 'thinking': delta_msg.reasoning}})}\n\n" + + if delta_msg.content: + content_to_emit = delta_msg.content + + # Filter tool call markup during streaming + if tool_parser and content_to_emit: + if not tool_markup_possible and "<" not in content_to_emit: + tool_accumulated_text += content_to_emit + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += content_to_emit + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, content_to_emit + ) + if tool_result is None or "tool_calls" in tool_result: + # Inside tool markup or tool calls detected — suppress + continue + content_to_emit = tool_result.get("content", "") + if content_to_emit: + content_to_emit = _TOOL_MARKUP_PATTERN.sub("", content_to_emit) + if not content_to_emit: + continue + + if thinking_block_started and not text_block_started: + # Close thinking block, open text block + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': thinking_index})}\n\n" + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + text_block_started = True + elif not text_block_started: + # No thinking was emitted, start text block at index 0 + text_index = 0 + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + text_block_started = True + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': text_index, 'delta': {'type': 'text_delta', 'text': content_to_emit}})}\n\n" + + # Close any open thinking block that was never followed by text + if thinking_block_started and not text_block_started: + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': thinking_index})}\n\n" + # Emit empty text block so response always has text content + text_index = thinking_index + 1 + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': text_index, 'content_block': {'type': 'text', 'text': ''}})}\n\n" + text_block_started = True # Check for tool calls in accumulated text _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request) + # Close text block + if text_block_started: + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': text_index})}\n\n" + # If there are tool calls, emit tool_use blocks + next_index = (text_index + 1) if text_block_started else 0 if tool_calls: for i, tc in enumerate(tool_calls): - tool_index = block_index + i + tool_index = next_index + i try: tool_input = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError): tool_input = {} - # content_block_start for tool_use - tool_block_start = { - "type": "content_block_start", - "index": tool_index, 
- "content_block": { - "type": "tool_use", - "id": tc.id, - "name": tc.function.name, - "input": {}, - }, - } - yield f"event: content_block_start\ndata: {json.dumps(tool_block_start)}\n\n" - - # Send input as a single delta - input_json = json.dumps(tool_input) - input_delta = { - "type": "content_block_delta", - "index": tool_index, - "delta": {"type": "input_json_delta", "partial_json": input_json}, - } - yield f"event: content_block_delta\ndata: {json.dumps(input_delta)}\n\n" - - # content_block_stop + yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': tool_index, 'content_block': {'type': 'tool_use', 'id': tc.id, 'name': tc.function.name, 'input': {}}})}\n\n" + yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': tool_index, 'delta': {'type': 'input_json_delta', 'partial_json': json.dumps(tool_input)}})}\n\n" yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': tool_index})}\n\n" # Determine stop reason @@ -1954,7 +2266,7 @@ async def _stream_anthropic_messages( message_delta = { "type": "message_delta", "delta": {"stop_reason": stop_reason, "stop_sequence": None}, - "usage": {"input_tokens": prompt_tokens, "output_tokens": completion_tokens}, + "usage": {"output_tokens": completion_tokens}, } yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n" @@ -1962,7 +2274,7 @@ async def _stream_anthropic_messages( elapsed = time.perf_counter() - start_time tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 logger.info( - f"Anthropic messages (stream): prompt={prompt_tokens} + completion={completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" ) # Emit message_stop @@ -1978,15 +2290,27 @@ async def stream_completion( engine: BaseEngine, prompt: str, request: CompletionRequest, + repetition_penalty: float | None = None, ) -> AsyncIterator[str]: """Stream completion response.""" - async for output in engine.stream_generate( - prompt=prompt, - max_tokens=request.max_tokens or _default_max_tokens, - temperature=_resolve_temperature(request.temperature), - top_p=_resolve_top_p(request.top_p), - stop=request.stop, - ): + generate_kwargs = { + "prompt": prompt, + "max_tokens": request.max_tokens or _default_max_tokens, + "temperature": _resolve_temperature(request.temperature), + "top_p": _resolve_top_p(request.top_p), + "top_k": request.top_k or 0, + "min_p": request.min_p or 0.0, + "presence_penalty": request.presence_penalty or 0.0, + "stop": request.stop, + } + if repetition_penalty is not None: + generate_kwargs["repetition_penalty"] = repetition_penalty + if request.specprefill is not None: + generate_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + generate_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + + async for output in engine.stream_generate(**generate_kwargs): data = { "id": f"cmpl-{uuid.uuid4().hex[:8]}", "object": "text_completion", @@ -2057,7 +2381,8 @@ async def stream_chat_completion( tool_accumulated_text = "" tool_calls_detected = False tool_markup_possible = False # Fast path: skip parsing until '<' seen - if _enable_auto_tool_choice and _tool_call_parser: + tool_choice = getattr(request, "tool_choice", None) + if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": # Initialize parser if needed (same as 
_parse_tool_calls_with_parser) if _tool_parser_instance is None: try: @@ -2084,8 +2409,8 @@ async def stream_chat_completion( if hasattr(output, "completion_tokens") and output.completion_tokens: completion_tokens = output.completion_tokens - # Use reasoning parser if enabled - if _reasoning_parser and delta_text: + # Use reasoning parser if enabled (skip when enable_thinking=False) + if _reasoning_parser and delta_text and request.enable_thinking is not False: previous_text = accumulated_text accumulated_text += delta_text delta_msg = _reasoning_parser.extract_reasoning_streaming( @@ -2096,16 +2421,115 @@ async def stream_chat_completion( # Skip this chunk (e.g., token itself) continue + content = delta_msg.content + reasoning = delta_msg.reasoning + + # Some models (e.g. MiniMax) wrap tool calls in + # blocks, so reasoning parser captures tool call XML as + # reasoning while content stays None. Redirect reasoning + # to the content stream so the tool parser can handle it. + if tool_parser and reasoning and not content: + _check = tool_accumulated_text + reasoning + if ( + "" in _check + or "" in _check + or ' never arrived - incomplete tool call) + # (e.g., never arrived, or " in tool_accumulated_text + and ( + "" in tool_accumulated_text + or "<|tool_call>" in tool_accumulated_text + or " 0: + if max_rotating_size > 0 and M > max_rotating_size: tail_start = max(0, M - max_rotating_size) tail_indices = set(range(tail_start, M)) existing = set(selected_indices.tolist()) diff --git a/vllm_mlx/text_model_from_vlm.py b/vllm_mlx/text_model_from_vlm.py index b1130fdc5..082ccf43b 100644 --- a/vllm_mlx/text_model_from_vlm.py +++ b/vllm_mlx/text_model_from_vlm.py @@ -94,15 +94,27 @@ def _class_predicate(path, module): else: logger.warning("No MTP weights found in %s", model_path.name) - # Verify MTP is functional + # Inject MTP if TextModel doesn't have native MTP support. + # mlx_lm's qwen3_5.TextModel strips MTP weights in sanitize(), + # so we inject MTP module + methods at runtime. 
+ if not hasattr(text_model, "mtp") or text_model.mtp is None: + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + if num_mtp == 0: + num_mtp = text_config.get("num_nextn_predict_layers", 0) + if num_mtp > 0: + from .patches.qwen3_5_mtp import inject_mtp_support + + inject_mtp_support(text_model, model_path, config) + if hasattr(text_model, "mtp") and text_model.mtp is not None: mx.eval(text_model.mtp.parameters()) - logger.info( - "TextModel built with MTP support (%d layers)", - args.mtp_num_hidden_layers, + num_mtp = text_config.get( + "mtp_num_hidden_layers", + text_config.get("num_nextn_predict_layers", 0), ) + logger.info("TextModel built with MTP support (%d layers)", num_mtp) else: - logger.info("TextModel built without MTP (mtp_num_hidden_layers=0)") + logger.info("TextModel built without MTP") return text_model diff --git a/vllm_mlx/tool_parsers/__init__.py b/vllm_mlx/tool_parsers/__init__.py index 16f744080..cd76ad418 100644 --- a/vllm_mlx/tool_parsers/__init__.py +++ b/vllm_mlx/tool_parsers/__init__.py @@ -10,6 +10,7 @@ - mistral: Mistral models ([TOOL_CALLS] format) - qwen/qwen3: Qwen models ( and [Calling tool:] formats) - llama/llama3/llama4: Llama models ( format) +- gemma4/gemma_4: Google Gemma 4 models (<|tool_call>call:name{} format) - hermes/nous: Hermes/NousResearch models - deepseek/deepseek_v3/deepseek_r1: DeepSeek models (unicode tokens) - kimi/kimi_k2/moonshot: Kimi/Moonshot models @@ -19,6 +20,7 @@ - functionary/meetkai: MeetKai Functionary models - glm47/glm4: GLM-4.7 and GLM-4.7-Flash models - harmony/gpt-oss: GPT-OSS models (Harmony format with channels) +- minimax: MiniMax-M2 models Usage: from vllm_mlx.tool_parsers import ToolParserManager @@ -47,6 +49,7 @@ from .auto_tool_parser import AutoToolParser from .deepseek_tool_parser import DeepSeekToolParser from .functionary_tool_parser import FunctionaryToolParser +from .gemma4_tool_parser import Gemma4ToolParser from .granite_tool_parser import GraniteToolParser from .hermes_tool_parser import HermesToolParser from .kimi_tool_parser import KimiToolParser @@ -57,6 +60,7 @@ from .xlam_tool_parser import xLAMToolParser from .glm47_tool_parser import Glm47ToolParser from .harmony_tool_parser import HarmonyToolParser +from .minimax_tool_parser import MiniMaxToolParser __all__ = [ # Base classes @@ -65,6 +69,7 @@ "ExtractedToolCallInformation", # Specific parsers "AutoToolParser", + "Gemma4ToolParser", "MistralToolParser", "QwenToolParser", "LlamaToolParser", @@ -77,4 +82,5 @@ "FunctionaryToolParser", "Glm47ToolParser", "HarmonyToolParser", + "MiniMaxToolParser", ] diff --git a/vllm_mlx/tool_parsers/auto_tool_parser.py b/vllm_mlx/tool_parsers/auto_tool_parser.py index fc02d8fc6..37ab10d74 100644 --- a/vllm_mlx/tool_parsers/auto_tool_parser.py +++ b/vllm_mlx/tool_parsers/auto_tool_parser.py @@ -16,6 +16,7 @@ ToolParser, ToolParserManager, ) +from .gemma4_tool_parser import Gemma4ToolParser def generate_tool_id() -> str: @@ -29,12 +30,13 @@ class AutoToolParser(ToolParser): Auto-detecting tool call parser. Tries multiple formats in order: - 1. Mistral: [TOOL_CALLS] ... - 2. Qwen bracket: [Calling tool: func_name({...})] - 3. Qwen/Hermes XML: {"name": "...", "arguments": {...}} - 4. Llama: {"arg": "value"} - 5. Nemotron: ... - 6. Raw JSON: {"name": "...", "arguments": {...}} + 1. Gemma 4: <|tool_call>call:name{...} + 2. Mistral: [TOOL_CALLS] ... + 3. Qwen bracket: [Calling tool: func_name({...})] + 4. Qwen/Hermes XML: {"name": "...", "arguments": {...}} + 5. Llama: {"arg": "value"} + 6. Nemotron: ... + 7. 
Raw JSON: {"name": "...", "arguments": {...}} This is the default parser when no specific parser is selected. """ @@ -63,7 +65,14 @@ def extract_tool_calls( tool_calls: list[dict[str, Any]] = [] cleaned_text = model_output - # 1. Try Mistral format + # 1. Try Gemma 4 format (most distinctive marker) + if "<|tool_call>" in model_output: + gemma_parser = Gemma4ToolParser() + result = gemma_parser.extract_tool_calls(model_output, request) + if result.tools_called: + return result + + # 2. Try Mistral format if self.MISTRAL_TOKEN in model_output: parts = model_output.split(self.MISTRAL_TOKEN) content = parts[0].strip() @@ -113,7 +122,7 @@ def extract_tool_calls( content=content if content else None, ) - # 2. Try Qwen bracket pattern + # 3. Try Qwen bracket pattern bracket_matches = self.QWEN_BRACKET_PATTERN.findall(model_output) for name, args_str in bracket_matches: try: @@ -141,7 +150,7 @@ def extract_tool_calls( if bracket_matches: cleaned_text = self.QWEN_BRACKET_PATTERN.sub("", cleaned_text).strip() - # 3. Try Nemotron pattern (before Qwen XML as it's more specific) + # 4. Try Nemotron pattern (before Qwen XML as it's more specific) nemotron_matches = self.NEMOTRON_PATTERN.findall(cleaned_text) for name, params_block in nemotron_matches: params = self.NEMOTRON_PARAM_PATTERN.findall(params_block) @@ -157,7 +166,7 @@ def extract_tool_calls( if nemotron_matches: cleaned_text = self.NEMOTRON_PATTERN.sub("", cleaned_text).strip() - # 4. Try Qwen/Hermes XML pattern + # 5. Try Qwen/Hermes XML pattern xml_matches = self.QWEN_XML_PATTERN.findall(cleaned_text) for match in xml_matches: try: @@ -182,7 +191,7 @@ def extract_tool_calls( if xml_matches: cleaned_text = self.QWEN_XML_PATTERN.sub("", cleaned_text).strip() - # 5. Try Llama pattern + # 6. Try Llama pattern llama_matches = self.LLAMA_PATTERN.findall(cleaned_text) for name, args_str in llama_matches: try: @@ -210,7 +219,7 @@ def extract_tool_calls( if llama_matches: cleaned_text = self.LLAMA_PATTERN.sub("", cleaned_text).strip() - # 6. Fallback: Try raw JSON + # 7. Fallback: Try raw JSON if not tool_calls: raw_calls = self._parse_raw_json_tool_calls(cleaned_text) if raw_calls: @@ -327,6 +336,7 @@ def extract_tool_calls_streaming( """ # Check for any tool call markers markers = [ + "<|tool_call>", self.MISTRAL_TOKEN, "[Calling tool:", "", @@ -339,7 +349,7 @@ def extract_tool_calls_streaming( return {"content": delta_text} # Check for completion markers - end_markers = ["", "", ")]"] + end_markers = ["", "", "", ")]"] if any(m in delta_text for m in end_markers): result = self.extract_tool_calls(current_text) if result.tools_called: diff --git a/vllm_mlx/tool_parsers/gemma4_tool_parser.py b/vllm_mlx/tool_parsers/gemma4_tool_parser.py new file mode 100644 index 000000000..a32fd90cf --- /dev/null +++ b/vllm_mlx/tool_parsers/gemma4_tool_parser.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Gemma 4 tool call parser for vllm-mlx. 
+ +Handles Gemma 4's native tool call format: + <|tool_call>call:func_name{<|"|>key<|"|>: <|"|>value<|"|>, num: 42} + +Gemma 4 uses special tokens instead of JSON: +- <|tool_call> / delimit tool call blocks +- <|"|> replaces " for string values +- Keys are unquoted bare identifiers +- Multiple call:name{...} can appear in a single block + +Reference: mlx-lm PR #1105, vllm PR #38837 +""" + +import json +import logging +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from .abstract_tool_parser import ( + ExtractedToolCallInformation, + ToolParser, + ToolParserManager, +) + +logger = logging.getLogger(__name__) + +# Delimiters +TOOL_CALL_START = "<|tool_call>" +TOOL_CALL_END = "" + +# Placeholder token used during <|"|> extraction. Matches \x00 + digits + \x00. +_PLACEHOLDER_RE = re.compile(r"\x00(\d+)\x00") + +# Pattern to extract <|"|>-delimited strings (non-greedy, supports multiline) +_STRING_DELIM_RE = re.compile(r'<\|"\|>(.*?)<\|"\|>', re.DOTALL) + +# Pattern to match call:name followed by a { (we extract balanced braces manually) +_CALL_PREFIX = re.compile(r"call:(\w+)\s*\{") + +# Pattern to quote bare keys: word followed by : at start or after , or { +_BARE_KEY = re.compile(r"(?<=[{,])\s*(\w+)\s*:") + +# Max arg block length to prevent runaway parsing on malformed input (1 MB) +_MAX_ARG_BLOCK_LEN = 1_048_576 + + +def _find_balanced_brace(text: str, start: int) -> int: + """Find the index of the closing } that balances the { at `start`. + + Before counting braces, <|"|>-delimited strings are conceptually opaque -- + we skip over <|"|>...<|"|> regions so that braces inside string values + (e.g. code snippets) don't affect depth counting. + + Args: + text: The string to search (may contain <|"|> tokens) + start: Index of the opening { + + Returns: + Index of the matching } in the ORIGINAL text, or -1 if not found + """ + if len(text) - start > _MAX_ARG_BLOCK_LEN: + return -1 + + depth = 0 + i = start + in_string = False + while i < len(text): + if text.startswith('<|"|>', i): + in_string = not in_string + i += 5 + continue + if not in_string: + if text[i] == "{": + depth += 1 + elif text[i] == "}": + depth -= 1 + if depth == 0: + return i + i += 1 + return -1 + + +def _gemma4_args_to_json(text: str) -> str: + """Convert Gemma 4 tool call args to valid JSON. + + Three-step conversion (ORDER MATTERS): + 1. Extract <|"|>-delimited strings into numbered \\x00N\\x00 placeholders. + This protects string contents from step 2's bare-key quoting -- without + this, a string value like "key: value" would be corrupted. + 2. Quote bare keys (word: -> "word":) now that strings are safe. + 3. Restore placeholders as properly JSON-escaped strings via json.dumps(). + Uses a single re.sub pass (O(len(text))) instead of per-placeholder replace. 
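To make the three conversion steps concrete, here is a small walkthrough (an editorial illustration, not part of the patch) using the sample argument block from the module docstring above:

    # Input args block in Gemma 4 token form:
    #   {<|"|>key<|"|>: <|"|>value<|"|>, num: 42}
    # Step 1 - extract the <|"|>-delimited strings into placeholders:
    #   {\x000\x00: \x001\x00, num: 42}        with strings == ["key", "value"]
    # Step 2 - quote the one bare key ("num" follows a comma):
    #   {\x000\x00: \x001\x00,"num": 42}
    # Step 3 - restore placeholders via json.dumps():
    #   {"key": "value","num": 42}             which json.loads() accepts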
+ """ + strings: list[str] = [] + + def _capture(m: re.Match) -> str: + strings.append(m.group(1)) + return f"\x00{len(strings) - 1}\x00" + + # Step 1: Extract <|"|>-delimited strings + text = _STRING_DELIM_RE.sub(_capture, text) + + # Step 2: Quote bare keys + text = _BARE_KEY.sub(r'"\1":', text) + + # Step 3: Restore captured strings as properly escaped JSON strings + def _restore(m: re.Match) -> str: + idx = int(m.group(1)) + return json.dumps(strings[idx]) if idx < len(strings) else m.group(0) + + text = _PLACEHOLDER_RE.sub(_restore, text) + + return text + + +def generate_tool_id() -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:8]}" + + +@ToolParserManager.register_module("gemma4") +class Gemma4ToolParser(ToolParser): + """ + Tool call parser for Gemma 4 models. + + Parses: <|tool_call>call:func{<|"|>key<|"|>: <|"|>val<|"|>} + + Used when --enable-auto-tool-choice --tool-call-parser gemma4 are set. + """ + + def extract_tool_calls( + self, model_output: str, request: dict[str, Any] | None = None + ) -> ExtractedToolCallInformation: + """Extract tool calls from a complete Gemma 4 model response.""" + cleaned = self.strip_think_tags(model_output) + + start_idx = cleaned.find(TOOL_CALL_START) + if start_idx == -1: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + content_before = cleaned[:start_idx].strip() or None + + block_start = start_idx + len(TOOL_CALL_START) + end_idx = cleaned.find(TOOL_CALL_END, block_start) + if end_idx == -1: + block = cleaned[block_start:] + else: + block = cleaned[block_start:end_idx] + + tool_calls: list[dict[str, Any]] = [] + + pos = 0 + while pos < len(block): + m = _CALL_PREFIX.search(block, pos) + if not m: + break + + func_name = m.group(1) + brace_start = m.end() - 1 + + brace_end = _find_balanced_brace(block, brace_start) + if brace_end == -1: + pos = m.end() + continue + + args_raw = block[brace_start : brace_end + 1] + try: + args_json = _gemma4_args_to_json(args_raw) + json.loads(args_json) + tool_calls.append( + { + "id": generate_tool_id(), + "name": func_name, + "arguments": args_json, + } + ) + except (json.JSONDecodeError, ValueError) as e: + logger.warning( + f"Gemma 4 tool parser: failed to parse args for " + f"call:{func_name}: {e}" + ) + + pos = brace_end + 1 + + if tool_calls: + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content_before, + ) + else: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int] | None = None, + current_token_ids: Sequence[int] | None = None, + delta_token_ids: Sequence[int] | None = None, + request: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + """Extract tool calls from streaming Gemma 4 model output.""" + has_start = TOOL_CALL_START in current_text + + if not has_start: + return {"content": delta_text} + + if TOOL_CALL_END in delta_text: + result = self.extract_tool_calls(current_text) + if result.tools_called: + return { + "tool_calls": [ + { + "index": i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(result.tool_calls) + ] + } + + return None diff --git a/vllm_mlx/tool_parsers/minimax_tool_parser.py b/vllm_mlx/tool_parsers/minimax_tool_parser.py new file mode 100644 index 
000000000..7459fe97f --- /dev/null +++ b/vllm_mlx/tool_parsers/minimax_tool_parser.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +MiniMax tool call parser for vllm-mlx. + +Parses the MiniMax-M2 native XML tool call format: + + +param-value + + +""" + +import json +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from .abstract_tool_parser import ( + ExtractedToolCallInformation, + ToolParser, + ToolParserManager, +) + + +def generate_tool_id() -> str: + return f"call_{uuid.uuid4().hex[:8]}" + + +@ToolParserManager.register_module(["minimax", "minimax_m2"]) +class MiniMaxToolParser(ToolParser): + """ + Parser for MiniMax-M2 tool call format. + + Format: + + + value + + + """ + + TOOL_CALL_BLOCK = re.compile( + r"(.*?)", re.DOTALL + ) + INVOKE_PATTERN = re.compile(r'(.*?)', re.DOTALL) + PARAM_PATTERN = re.compile( + r'(.*?)', re.DOTALL + ) + THINK_PATTERN = re.compile(r".*?", re.DOTALL) + + def _extract_invokes(self, text: str) -> list[dict[str, Any]]: + """Extract tool calls from invoke elements, with or without wrapper.""" + tool_calls: list[dict[str, Any]] = [] + invokes = self.INVOKE_PATTERN.findall(text) + for func_name, params_block in invokes: + params = self.PARAM_PATTERN.findall(params_block) + # Skip bare tags without parameters (hallucinated junk) + if not params: + continue + arguments = {} + for p_name, p_value in params: + p_value = p_value.strip() + try: + arguments[p_name] = json.loads(p_value) + except (json.JSONDecodeError, ValueError): + arguments[p_name] = p_value + + tool_calls.append( + { + "id": generate_tool_id(), + "name": func_name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + return tool_calls + + def extract_tool_calls( + self, model_output: str, request: dict[str, Any] | None = None + ) -> ExtractedToolCallInformation: + # Try wrapped format first: ...... 
+ blocks = self.TOOL_CALL_BLOCK.findall(model_output) + if blocks: + tool_calls: list[dict[str, Any]] = [] + for block in blocks: + tool_calls.extend(self._extract_invokes(block)) + + cleaned = self.TOOL_CALL_BLOCK.sub("", model_output).strip() + cleaned = self.THINK_PATTERN.sub("", cleaned).strip() + cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip() + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=cleaned if cleaned else None, + ) + + # Fallback: bare without wrapper + # (model sometimes emits tool calls inside without wrapper) + tool_calls = self._extract_invokes(model_output) + if tool_calls: + # Strip matched invoke blocks and thinking from content + cleaned = self.INVOKE_PATTERN.sub("", model_output).strip() + cleaned = self.THINK_PATTERN.sub("", cleaned).strip() + cleaned = re.sub(r"\[e~\[.*$", "", cleaned).strip() + # Remove leftover closing tags + cleaned = cleaned.replace("", "").strip() + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=cleaned if cleaned else None, + ) + + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def _has_tool_start(self, text: str) -> bool: + """Check if text contains the start of a tool call block.""" + return "" in text or ( + '" in current: + return ( + "" in current + and "" not in previous + ) + # Bare invoke: just appeared + if "" in current and "" not in previous: + return True + return False + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int] | None = None, + current_token_ids: Sequence[int] | None = None, + delta_token_ids: Sequence[int] | None = None, + request: dict[str, Any] | None = None, + ) -> dict[str, Any] | None: + # Not inside a tool call block yet — pass content through + if not self._has_tool_start(current_text): + return {"content": delta_text} + + # Tool call block just completed + if self._has_tool_end(current_text, previous_text): + result = self.extract_tool_calls(current_text) + if result.tools_called: + return { + "tool_calls": [ + { + "index": i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(result.tool_calls) + ] + } + + # Inside tool call block but not yet complete — suppress output + return None diff --git a/vllm_mlx/tool_parsers/qwen_tool_parser.py b/vllm_mlx/tool_parsers/qwen_tool_parser.py index fd69b96c0..e235a3c7d 100644 --- a/vllm_mlx/tool_parsers/qwen_tool_parser.py +++ b/vllm_mlx/tool_parsers/qwen_tool_parser.py @@ -5,8 +5,10 @@ Handles Qwen's tool calling formats: - XML style: {"name": "func", "arguments": {...}} - Bracket style: [Calling tool: func_name({"arg": "value"})] +- Function style: value """ +import ast import json import re import uuid @@ -20,6 +22,24 @@ ) +def _parse_param_value(val: str) -> Any: + """Parse a parameter value, handling JSON literals and plain strings.""" + try: + return json.loads(val) + except (json.JSONDecodeError, ValueError): + pass + try: + python_val = ast.literal_eval(val) + if isinstance(python_val, set): + python_val = sorted(python_val, key=str) + if isinstance(python_val, (complex, bytes)): + return val + json.dumps(python_val) + return python_val + except (ValueError, SyntaxError, TypeError): + return val + + def generate_tool_id() -> str: """Generate a unique tool call ID.""" return f"call_{uuid.uuid4().hex[:8]}" @@ -33,6 +53,7 
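The coercion chain in _parse_param_value is easiest to see on a few inputs; an illustrative sketch of what the new helper returns, based on the fallback order shown in the hunk above:

    from vllm_mlx.tool_parsers.qwen_tool_parser import _parse_param_value

    _parse_param_value("42")           # -> 42            (JSON literal)
    _parse_param_value("true")         # -> True
    _parse_param_value("{'a': 1}")     # -> {'a': 1}      (Python literal via ast.literal_eval)
    _parse_param_value("{1, 2}")       # -> [1, 2]        (sets become sorted lists)
    _parse_param_value("hello world")  # -> "hello world" (falls through as a plain string)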
@@ class QwenToolParser(ToolParser): Supports multiple Qwen tool call formats: - XML: {"name": "func", "arguments": {...}} - Bracket: [Calling tool: func_name({"arg": "value"})] + - Function: value Used when --enable-auto-tool-choice --tool-call-parser qwen are set. """ @@ -43,6 +64,12 @@ class QwenToolParser(ToolParser): # Pattern for bracket-style: [Calling tool: func_name({...})] BRACKET_PATTERN = re.compile(r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]", re.DOTALL) + # Pattern for function-style: ... + FUNCTION_PATTERN = re.compile(r"]+)>(.*?)", re.DOTALL) + + # Pattern for parameter extraction: value + PARAM_PATTERN = re.compile(r"]+)>\s*(.*?)\s*", re.DOTALL) + def extract_tool_calls( self, model_output: str, request: dict[str, Any] | None = None ) -> ExtractedToolCallInformation: @@ -101,6 +128,41 @@ def extract_tool_calls( if xml_matches: cleaned_text = self.XML_PATTERN.sub("", cleaned_text).strip() + # Try function-style: value + # Qwen3.5 generates this format natively. + if not tool_calls: + func_matches = self.FUNCTION_PATTERN.findall(cleaned_text) + for name, params_block in func_matches: + # Try JSON arguments first (e.g. {"key": "val"}) + params_block_stripped = params_block.strip() + if params_block_stripped.startswith("{"): + try: + arguments = json.loads(params_block_stripped) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + continue + except json.JSONDecodeError: + pass + # Parse value tags + params = self.PARAM_PATTERN.findall(params_block) + arguments = {} + for p_name, p_value in params: + arguments[p_name.strip()] = _parse_param_value(p_value.strip()) + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + if func_matches: + cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip() + if tool_calls: return ExtractedToolCallInformation( tools_called=True, @@ -112,6 +174,30 @@ def extract_tool_calls( tools_called=False, tool_calls=[], content=model_output ) + # Partial marker prefixes — when current_text ends with one of these, + # we suppress output until the next token confirms or denies a tool call. + # These are long enough to avoid false positives on normal text. + _PARTIAL_MARKERS = (" bool: + """Check if text ends with an incomplete tool call marker prefix.""" + return self._get_partial_marker_len(text) > 0 + + def _get_partial_marker_len(self, text: str) -> int: + """Return the length of a partial tool call marker suffix at end of text.""" + tail = text[-20:] + best = 0 + for marker in self._PARTIAL_MARKERS: + for length in range(len(marker), 0, -1): + if tail.endswith(marker[:length]) and length > best: + best = length + break + return best + + def _was_buffering(self, previous_text: str) -> bool: + """Check if the previous call was buffering a partial marker.""" + return self._has_partial_marker(previous_text) + def extract_tool_calls_streaming( self, previous_text: str, @@ -125,14 +211,67 @@ def extract_tool_calls_streaming( """ Extract tool calls from streaming Qwen model output. """ - # Check for tool call markers + # Check for complete tool call markers has_tool_marker = ( - "" in current_text or "[Calling tool:" in current_text + "" in current_text + or "[Calling tool:" in current_text + or "... 
(Qwen3.5 native format) + if "") + prev_func_close = previous_text.count("") + + if current_text.count(" func_close_count: + # Inside an incomplete function block, suppress output + return None + + if func_close_count > prev_func_close: + # New function block(s) completed + result = self.extract_tool_calls(current_text) + if result.tools_called: + new_calls = result.tool_calls[prev_func_close:] + if new_calls: + return { + "tool_calls": [ + { + "index": prev_func_close + i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(new_calls) + ] + } + + return None + # If we're in a tool call, accumulate and parse at the end # For simplicity, return None during accumulation if "" in delta_text or ")]" in delta_text: diff --git a/vllm_mlx/utils/__init__.py b/vllm_mlx/utils/__init__.py index e808515ad..14d5de5c8 100644 --- a/vllm_mlx/utils/__init__.py +++ b/vllm_mlx/utils/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """Utility modules for vllm-mlx.""" +from .download import DownloadConfig, ensure_model_downloaded from .tokenizer import load_model_with_fallback -__all__ = ["load_model_with_fallback"] +__all__ = ["DownloadConfig", "ensure_model_downloaded", "load_model_with_fallback"] diff --git a/vllm_mlx/utils/download.py b/vllm_mlx/utils/download.py new file mode 100644 index 000000000..39941c7af --- /dev/null +++ b/vllm_mlx/utils/download.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Resumable model download with retry/timeout support. + +Pre-downloads models via huggingface_hub.snapshot_download() with +configurable timeout and retry logic before passing to mlx-lm/mlx-vlm. +""" + +import logging +import os +import time +from dataclasses import dataclass +from pathlib import Path + +from huggingface_hub import snapshot_download + +logger = logging.getLogger(__name__) + +# Mirrors mlx_lm.utils._download() default allow_patterns +LLM_ALLOW_PATTERNS = [ + "*.json", + "model*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + "tiktoken.model", + "*.txt", + "*.jsonl", + "*.jinja", +] + +# Mirrors mlx_vlm.utils.get_model_path() allow_patterns +MLLM_ALLOW_PATTERNS = [ + "*.json", + "*.safetensors", + "*.py", + "*.model", + "*.tiktoken", + "*.txt", + "*.jinja", +] + + +@dataclass +class DownloadConfig: + """Configuration for model download behavior.""" + + download_timeout: int = 300 + max_retries: int = 3 + retry_backoff_base: float = 2.0 + offline: bool = False + + +def ensure_model_downloaded( + model_name: str, + config: DownloadConfig | None = None, + is_mllm: bool = False, +) -> Path: + """ + Ensure a model is available locally, downloading with retry if needed. + + Args: + model_name: HuggingFace model name or local path. + config: Download configuration. Uses defaults if None. + is_mllm: If True, use MLLM download patterns (broader file set). + + Returns: + Path to the local model directory. + + Raises: + RuntimeError: If download fails after all retries. + KeyboardInterrupt: Propagated immediately without retry. 
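As a quick check of the retry schedule implied by the DownloadConfig defaults above (max_retries=3, retry_backoff_base=2.0), an editorial note:

    # attempt 1 fails -> sleep 2.0**1 = 2 s and retry
    # attempt 2 fails -> sleep 2.0**2 = 4 s and retry
    # attempt 3 fails -> no further sleep; RuntimeError is raised with the last error
    # a KeyboardInterrupt at any point is re-raised immediately, without retrying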
+ """ + if config is None: + config = DownloadConfig() + + model_path = Path(model_name) + if model_path.exists(): + logger.info(f"Model found at local path: {model_path}") + return model_path + + if config.offline: + logger.info(f"Offline mode: looking for cached {model_name}") + try: + result = Path(snapshot_download(model_name, local_files_only=True)) + logger.info(f"Found cached model at {result}") + return result + except Exception as e: + raise RuntimeError( + f"Model '{model_name}' not found in local cache. " + f"Download it first without --offline flag." + ) from e + + allow_patterns = MLLM_ALLOW_PATTERNS if is_mllm else LLM_ALLOW_PATTERNS + + # Set HF download timeout via environment variable + old_timeout = os.environ.get("HF_HUB_DOWNLOAD_TIMEOUT") + os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = str(config.download_timeout) + + last_error = None + try: + for attempt in range(1, config.max_retries + 1): + try: + logger.info( + f"Downloading model {model_name} " + f"(attempt {attempt}/{config.max_retries}, " + f"timeout={config.download_timeout}s)" + ) + result = Path( + snapshot_download( + model_name, + allow_patterns=allow_patterns, + ) + ) + logger.info(f"Model downloaded successfully to {result}") + return result + except KeyboardInterrupt: + logger.warning("Download interrupted by user.") + raise + except Exception as e: + last_error = e + if attempt < config.max_retries: + wait = config.retry_backoff_base**attempt + logger.warning( + f"Download attempt {attempt} failed: {e}. " + f"Retrying in {wait:.0f}s..." + ) + time.sleep(wait) + else: + logger.error( + f"Download failed after {config.max_retries} attempts." + ) + + raise RuntimeError( + f"Failed to download '{model_name}' after {config.max_retries} " + f"attempts. Last error: {last_error}\n" + f"Run the same command again to resume the download." + ) + finally: + # Restore original env var + if old_timeout is None: + os.environ.pop("HF_HUB_DOWNLOAD_TIMEOUT", None) + else: + os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = old_timeout diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index a50883951..9d200ab9f 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -28,6 +28,27 @@ def _needs_tokenizer_fallback(model_name: str) -> bool: return any(pattern.lower() in model_lower for pattern in FALLBACK_MODELS) +def _needs_strict_false(model_name: str) -> bool: + """Check if model needs strict=False loading (VLM models with extra weights). + + VLM models (e.g., Qwen3.5) have vision_tower weights that don't match + the text-only model class. Loading with strict=True fails and wastes + memory by loading all weights (~100 GB) before raising ValueError. + Detect these models up-front to avoid the double-load penalty. + """ + from mlx_lm.utils import _download, load_config + + try: + model_path = _download(model_name) + config = load_config(model_path) + except Exception: + return False + # VLM models have vision_config or text_config with a separate model_type + if "vision_config" in config and "text_config" in config: + return True + return False + + def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): """ Load model and tokenizer with fallback for non-standard tokenizers. @@ -50,6 +71,15 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): ) return _load_with_tokenizer_fallback(model_name) + # VLM models (e.g., Qwen3.5) have extra vision weights that cause + # strict=True to fail. 
Skip the first load attempt to avoid loading + # ~100 GB of weights twice (which can cause OOM on 256 GB systems). + if _needs_strict_false(model_name): + logger.info( + f"Model {model_name} detected as VLM, loading directly with strict=False" + ) + return _load_strict_false(model_name, tokenizer_config) + try: model, tokenizer = load(model_name, tokenizer_config=tokenizer_config) except ValueError as e: @@ -59,42 +89,89 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): return _load_with_tokenizer_fallback(model_name) # Fallback for models with extra weights (e.g., vision tower, MTP layers). # Retry with strict=False to discard extra weights. - if "parameters not in model" in str(e): + elif "parameters not in model" in str(e): logger.warning( f"Extra parameters found (e.g., vision tower / MTP weights), " f"retrying with strict=False: {e}" ) + # Clear traceback references to free memory from the failed first load. + # Without this, large models (200GB+) cause OOM during retry because + # the traceback holds references to the first load's weight tensors. + e.__traceback__ = None + del e + import gc + + gc.collect() return _load_strict_false(model_name, tokenizer_config) - raise + else: + raise + # After successful load, check if MTP weights exist but were stripped by sanitize() + _try_inject_mtp_post_load(model, model_name) + return model, tokenizer -def _load_strict_false(model_name: str, tokenizer_config: dict = None): - """Load model with strict=False to discard extra weights (e.g., vision tower, MTP).""" - from mlx_lm.utils import load_model, load_tokenizer - local_path = Path(model_name) - if local_path.is_dir(): - model_path = local_path - else: - from huggingface_hub import snapshot_download +def _load_strict_false(model_name: str, tokenizer_config: dict = None): + """Load model with strict=False to discard extra weights. - model_path = Path(snapshot_download(model_name)) + Handles models with extra parameters that the text-only model class + doesn't define (e.g., vision tower weights in VLM models like Qwen3.5, + or MTP layers). The model's own sanitize() handles key remapping + (e.g., language_model.* prefix), and strict=False silently drops + unmatched keys. 
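For reference, the up-front VLM detection that routes a model into this strict=False path only looks at the shape of config.json; a hypothetical config that would trigger it (key names taken from the check above, values invented):

    config = {
        "model_type": "qwen3_5_vl",              # invented value, for illustration only
        "vision_config": {"hidden_size": 1152},  # invented contents
        "text_config": {"model_type": "qwen3_5"},
    }
    # _needs_strict_false() returns True because both "vision_config" and
    # "text_config" are present, so load_model_with_fallback() skips the
    # strict=True attempt and calls _load_strict_false() directly.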
+ """ + import mlx.core as mx + from mlx_lm.utils import _download, load_model, load_tokenizer + model_path = _download(model_name) model, config = load_model(model_path, strict=False) + + # Verify weights loaded correctly + from mlx.utils import tree_flatten + + params = tree_flatten(model.parameters()) + total_params = len(params) + zero_params = sum(1 for _, v in params if mx.all(v == 0).item()) + logger.info( + f"[strict=False] Loaded {total_params} parameters, " + f"{zero_params} all-zero tensors" + ) + # Spot-check embedding weights + if hasattr(model, "language_model"): + emb = model.language_model.model.embed_tokens.weight + logger.info( + f"[strict=False] embed_tokens: shape={emb.shape}, " + f"dtype={emb.dtype}, mean={mx.mean(emb.astype(mx.float32)).item():.4f}" + ) + tokenizer = load_tokenizer( model_path, tokenizer_config or {}, eos_token_ids=config.get("eos_token_id", None), ) - # Inject MTP support if model has MTP config + weights _try_inject_mtp(model, model_path, config) return model, tokenizer def _try_inject_mtp(model, model_path, config): """Inject MTP support if model has MTP config + weights.""" + # Qwen3-Next: flat num_nextn_predict_layers if config.get("num_nextn_predict_layers", 0) > 0: - from ..patches.qwen3_next_mtp import inject_mtp_support + # Detect Qwen3.5 vs Qwen3-Next by checking text_config or model_type + text_config = config.get("text_config", config) + model_type = text_config.get("model_type", config.get("model_type", "")) + if "qwen3_5" in model_type: + from ..patches.qwen3_5_mtp import inject_mtp_support + else: + from ..patches.qwen3_next_mtp import inject_mtp_support + inject_mtp_support(model, model_path, config) + return + + # Qwen3.5: mtp_num_hidden_layers in text_config + text_config = config.get("text_config", config) + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + if num_mtp > 0: + from ..patches.qwen3_5_mtp import inject_mtp_support inject_mtp_support(model, model_path, config) @@ -111,13 +188,21 @@ def _try_inject_mtp_post_load(model, model_name): return with open(config_path) as f: config = json.load(f) - # Also check text_config for nested configs + # Check for MTP in flat config and nested text_config + text_config = config.get("text_config", {}) num_mtp = config.get("num_nextn_predict_layers", 0) if num_mtp == 0: - text_config = config.get("text_config", {}) num_mtp = text_config.get("num_nextn_predict_layers", 0) - if num_mtp > 0 and getattr(model, "mtp", None) is None: - mtp_file = Path(model_path) / "model-mtp.safetensors" + if num_mtp == 0: + num_mtp = text_config.get("mtp_num_hidden_layers", 0) + # Also check mtp attribute on language_model for VLM wrappers + check_model = model + if hasattr(model, "language_model"): + check_model = model.language_model + if num_mtp > 0 and getattr(check_model, "mtp", None) is None: + mtp_file = Path(model_path) / "mtp" / "weights.safetensors" + if not mtp_file.exists(): + mtp_file = Path(model_path) / "model-mtp.safetensors" if mtp_file.exists(): logger.info( f"[MTP] Found MTP config (layers={num_mtp}) and weights, injecting..." @@ -126,7 +211,7 @@ def _try_inject_mtp_post_load(model, model_name): else: logger.info( f"[MTP] Config has num_nextn_predict_layers={num_mtp} " - "but model-mtp.safetensors not found, skipping MTP." + "but MTP weights not found, skipping MTP." 
) @@ -134,16 +219,12 @@ def _load_with_tokenizer_fallback(model_name: str): """Load model with fallback tokenizer for non-standard models like Nemotron.""" from mlx_lm.utils import load_model - logger.info("Loading with tokenizer fallback...") + from .download import ensure_model_downloaded - # Get model path - use local path if it exists, otherwise download from Hub - local_path = Path(model_name) - if local_path.is_dir(): - model_path = local_path - else: - from huggingface_hub import snapshot_download + logger.info("Loading with tokenizer fallback...") - model_path = Path(snapshot_download(model_name)) + # Get model path (with retry/timeout support) + model_path = ensure_model_downloaded(model_name, is_mllm=False) # Load model model, _ = load_model(model_path)
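Finally, the resumable download helper introduced above can also be driven directly; a usage sketch (the model name is a placeholder, not a real repository):

    from vllm_mlx.utils import DownloadConfig, ensure_model_downloaded

    # Longer per-file timeout and more retries than the defaults.
    cfg = DownloadConfig(download_timeout=600, max_retries=5)
    model_path = ensure_model_downloaded("mlx-community/example-model-4bit", config=cfg)

    # Offline mode only resolves models already present in the local HF cache
    # and raises RuntimeError otherwise.
    cached = ensure_model_downloaded(
        "mlx-community/example-model-4bit",
        config=DownloadConfig(offline=True),
    )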