diff --git a/README.md b/README.md index de6924c..9cd1f8b 100644 --- a/README.md +++ b/README.md @@ -299,11 +299,14 @@ All 17 parsers include automatic recovery — if a quantized model outputs broke | **GPT-OSS 20B** | **127** tok/s · 100% tools | 79 (mlx-lm serve) | **1.6x** | | **Qwen3.5-9B** | **108** tok/s | 46 (Ollama) | **2.3x** | | **Kimi-Linear-48B** | **94** tok/s · 100% tools | — (only engine) | — | +| 🆕 **Gemma 4 26B-A4B** | **94** tok/s · 100% tools | — (day-0, only engine) | — | +| 🆕 **Gemma 4 E4B** | **83** tok/s · 100% tools | — (day-0, only engine) | — | | **Qwen3.5-35B-A3B** | **83** tok/s · 100% tools | 75 (oMLX) | **1.1x** | | **Qwen3-Coder 80B** | **74** tok/s · 100% tools | 69 (mlx-lm serve) | **1.1x** | | **Qwen3.5-122B** | **44** tok/s · 100% tools | 43 (mlx-lm serve) | ~1.0x | +| 🆕 **Gemma 4 31B** | **31** tok/s · 100% tools | 10.9 (mlx-vlm bf16) | **2.8x** | -*Full benchmark data with all 18 models, TTFT tables, DeltaNet snapshots, and engine comparison below.* +*Full benchmark data with all models, TTFT tables, DeltaNet snapshots, and engine comparison below.*
TTFT — Prompt Cache Advantage @@ -325,6 +328,9 @@ Prompt cache keeps multi-turn conversations fast. For standard transformers, KV | Qwen3-Coder-Next 80B | **0.16s** | 0.27s | 1.7x | | GPT-OSS 20B | **0.16s** | 0.27s | 1.7x | | Qwen3.5-9B | **0.22s** | 0.26s | 1.2x | +| 🆕 Gemma 4 E4B | **0.25s** | — (day-0) | — | +| 🆕 Gemma 4 26B-A4B | **0.25s** | — (day-0) | — | +| 🆕 Gemma 4 31B | **0.34s** | 0.57s (mlx-vlm bf16) | **1.7x** | **DeltaNet state snapshots (hybrid RNN + attention):** @@ -368,7 +374,7 @@ Qwen3.5 uses Gated DeltaNet (75% RNN) + full attention (25% KV). Other engines r | **DeltaNet state snapshots** | Deep-copy RNN state at prefix boundary, restore in ~0.1ms | Qwen3.5 (4B, 9B, 27B, 35B, 122B), Qwen3-Coder-Next | | **Hybrid cache sync** | Keep trimmable KV + non-trimmable RNN layers in sync | Qwen3.5 (Gated DeltaNet + attention) | | **Tool logits bias** | Jump-forward decoding — bias logits toward structured tokens | All models with `--enable-tool-logits-bias` | -| **Auto tool recovery** | Detect broken text-format tool calls, convert to structured | All 17 parser formats | +| **Auto tool recovery** | Detect broken text-format tool calls, convert to structured | All 18 parser formats (incl. Gemma 4) | | **Speculative decoding** | Draft model generates candidates, main model verifies | Any model + `--draft-model` | | **KV quantization** | 4/8-bit KV cache for longer contexts in less memory | All models with `--kv-bits` | | **Prefill chunking** | Configurable step size for large-prompt throughput | All models | @@ -379,10 +385,13 @@ Qwen3.5 uses Gated DeltaNet (75% RNN) + full attention (25% KV). Other engines r
Eval benchmarks (17 models, 4 suites) -17 models across tool calling (30 scenarios), coding (HumanEval+), reasoning (MATH-500), and general knowledge (MMLU-Pro). All with `enable_thinking: false` on M3 Ultra. +19 models across tool calling (30 scenarios), coding (HumanEval+), reasoning (MATH-500), and general knowledge (MMLU-Pro). All with `enable_thinking: false` on M3 Ultra. 🆕 = Gemma 4 (day-0 support). | Model | Quant | RAM | Decode | Tools | Code | Reason | General | Avg | |-------|-------|-----|--------|-------|------|--------|---------|-----| +| 🆕 Gemma 4 26B-A4B | 4bit | 14.4 GB | 94 t/s | **100%** | — | — | — | — | +| 🆕 Gemma 4 E4B | 4bit | 6.4 GB | 83 t/s | **100%** | — | — | — | — | +| 🆕 Gemma 4 31B | 4bit | 17.0 GB | 31 t/s | **100%** | — | — | — | — | | Qwen3.5-122B-A10B | 8bit | 129.8 GB | 44 t/s | 87% | **90%** | **90%** | **90%** | **89%** | | Qwen3.5-122B-A10B | mxfp4 | 65.0 GB | 57 t/s | **90%** | **90%** | 80% | **90%** | 88% | | Qwen3.5-35B-A3B | 8bit | 36.9 GB | 83 t/s | **90%** | **90%** | 80% | 80% | 85% | diff --git a/pyproject.toml b/pyproject.toml index 53434ee..e6ccdf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "rapid-mlx" -version = "0.2.7" +version = "0.4.0" description = "Rapid-MLX — AI inference for Apple Silicon. Drop-in OpenAI API, 2-4x faster than Ollama." readme = "README.md" license = {text = "Apache-2.0"} @@ -31,7 +31,7 @@ dependencies = [ # Core — these are all you need for `rapid-mlx serve ` "mlx>=0.29.0", "mlx-lm>=0.31.0", # 0.31+ required for ArraysCache native batching (hybrid models) - "mlx-vlm>=0.1.0", # VLM support + "mlx-vlm>=0.4.4", # 0.4.4+ required for Gemma 4 support "transformers>=5.0.0", # mlx-lm 0.30.5+ requires transformers 5.0 (rc3 bug fixed in stable) "tokenizers>=0.19.0", "huggingface-hub>=0.23.0", diff --git a/reports/benchmarks/gemma4-26b-a4b-4bit.json b/reports/benchmarks/gemma4-26b-a4b-4bit.json new file mode 100644 index 0000000..58935ea --- /dev/null +++ b/reports/benchmarks/gemma4-26b-a4b-4bit.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/mlx-models/gemma-4-26b-a4b-it-4bit", + "short_decode_tps": { + "mean": 91.99284441095028, + "median": 91.9997184259893, + "min": 91.90088449538918, + "max": 92.07793031147239 + }, + "short_prefill_tps": { + "median": 104.6567741055699 + }, + "long_decode_tps": { + "mean": 90.20737344249109, + "median": 90.15377328174793, + "min": 90.10046743781253, + "max": 90.36787960791281 + }, + "long_prefill_tps": { + "median": 404.29261432785694 + }, + "ttft_cold_s": 0.6939874159870669, + "ttft_cached_s": 0.25729072900139727, + "multi_turn_ttft_cold_s": 0.3698056659777649, + "multi_turn_ttft_cached_s": 0.257844791514799, + "peak_ram_mb": 14697.5, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/gemma4-31b-4bit.json b/reports/benchmarks/gemma4-31b-4bit.json new file mode 100644 index 0000000..8af36bc --- /dev/null +++ b/reports/benchmarks/gemma4-31b-4bit.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/mlx-models/gemma-4-31b-it-4bit-local", + "short_decode_tps": { + "mean": 30.658969382957626, + "median": 30.650768843603235, + "min": 30.636982911649486, + "max": 30.68915639362016 + }, + "short_prefill_tps": { + "median": 77.82051641509116 + }, + "long_decode_tps": { + "mean": 29.81875147521854, + "median": 29.834616923274258, + "min": 29.772306347596366, + "max": 29.849331154785002 + }, + "long_prefill_tps": { + "median": 318.28190673479276 + }, + "ttft_cold_s": 9.772502000036184, + "ttft_cached_s": 0.34381089551607147, + "multi_turn_ttft_cold_s": 0.7450880000251345, + "multi_turn_ttft_cached_s": 0.34492891700938344, + "peak_ram_mb": 17363.453125, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/gemma4-31b-bf16-mllm.json b/reports/benchmarks/gemma4-31b-bf16-mllm.json new file mode 100644 index 0000000..d2c8654 --- /dev/null +++ b/reports/benchmarks/gemma4-31b-bf16-mllm.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/mlx-models/gemma-4-31b-it-bf16", + "short_decode_tps": { + "mean": 7.684495219859486, + "median": 7.685015108337882, + "min": 7.683350416504045, + "max": 7.685120134736532 + }, + "short_prefill_tps": { + "median": 49.61073354493354 + }, + "long_decode_tps": { + "mean": 6.150148014216069, + "median": 6.149420465554755, + "min": 6.148029410337342, + "max": 6.152994166756111 + }, + "long_prefill_tps": { + "median": 130.33556741428563 + }, + "ttft_cold_s": 0.8671420829778071, + "ttft_cached_s": 0.503123354021227, + "multi_turn_ttft_cold_s": 0.878063625015784, + "multi_turn_ttft_cached_s": 0.8742528125003446, + "peak_ram_mb": 60796.328125, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": false, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/gemma4-31b-bf16.json b/reports/benchmarks/gemma4-31b-bf16.json new file mode 100644 index 0000000..95705f8 --- /dev/null +++ b/reports/benchmarks/gemma4-31b-bf16.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/mlx-models/gemma-4-31b-it-bf16", + "short_decode_tps": { + "mean": 10.877661903575952, + "median": 10.881409537294747, + "min": 10.8682413908779, + "max": 10.883334782555206 + }, + "short_prefill_tps": { + "median": 46.99511568078357 + }, + "long_decode_tps": { + "mean": 10.730247908489643, + "median": 10.733271737703564, + "min": 10.722421680460178, + "max": 10.735050307305189 + }, + "long_prefill_tps": { + "median": 186.58741680810584 + }, + "ttft_cold_s": 76.44581962499069, + "ttft_cached_s": 0.5739832909894176, + "multi_turn_ttft_cold_s": 1.105832208006177, + "multi_turn_ttft_cached_s": 0.5784412914945278, + "peak_ram_mb": 59444.0625, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/reports/benchmarks/gemma4-e4b-4bit.json b/reports/benchmarks/gemma4-e4b-4bit.json new file mode 100644 index 0000000..e7f49e0 --- /dev/null +++ b/reports/benchmarks/gemma4-e4b-4bit.json @@ -0,0 +1,34 @@ +[ + { + "engine": "Rapid-MLX", + "model": "/Volumes/Extreme SSD/mlx-models/gemma-4-e4b-it-4bit-local", + "short_decode_tps": { + "mean": 82.22621304400961, + "median": 82.2157561516956, + "min": 82.17578086740563, + "max": 82.28710211292758 + }, + "short_prefill_tps": { + "median": 101.84400488869173 + }, + "long_decode_tps": { + "mean": 79.74346950172897, + "median": 80.09642988741999, + "min": 78.86214671758383, + "max": 80.27183190018309 + }, + "long_prefill_tps": { + "median": 349.3339133508353 + }, + "ttft_cold_s": 2.396504874981474, + "ttft_cached_s": 0.2615705000353046, + "multi_turn_ttft_cold_s": 0.3181427090312354, + "multi_turn_ttft_cached_s": 0.25800837500719354, + "peak_ram_mb": 6519.265625, + "tool_call_rate": 1.0, + "recovery_rate": 1.0, + "leak_rate": 0.0, + "vision": true, + "audio": false + } +] \ No newline at end of file diff --git a/tests/test_output_router.py b/tests/test_output_router.py new file mode 100644 index 0000000..da2b0c3 --- /dev/null +++ b/tests/test_output_router.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for the token-level OutputRouter. + +Uses real Gemma 4 token IDs (from tokenizer vocabulary) to verify +routing correctness without any text-level matching. +""" + +import pytest + +from vllm_mlx.output_router import Channel, OutputRouter, RouterState, TokenMap + + +# === Gemma 4 Token IDs (from tokenizer) === +GEMMA4_MAP = TokenMap( + channel_start=100, # <|channel> + channel_end=101, # + thought_word=45518, # "thought" + content_word=3955, # "content" + final_word=10218, # "final" + turn_start=105, # <|turn> + turn_end=106, # + tool_call_start=48, # <|tool_call> + tool_call_end=49, # + tool_quote=52, # <|"|> + tool_start=46, # <|tool> + tool_end=47, # + tool_response_start=50, # <|tool_response> + tool_response_end=51, # + bos=2, + eos=1, + pad=0, +) + + +class FakeTokenizer: + """Minimal tokenizer that maps token IDs to text.""" + + def __init__(self, vocab: dict[str, int]): + self._id_to_text = {v: k for k, v in vocab.items()} + self._vocab = vocab + + def decode(self, ids: list[int]) -> str: + return "".join(self._id_to_text.get(i, f"") for i in ids) + + def get_vocab(self) -> dict[str, int]: + return self._vocab + + +# Gemma 4 vocabulary (subset for testing) +VOCAB = { + "": 0, "": 1, "": 2, + "<|tool>": 46, "": 47, + "<|tool_call>": 48, "": 49, + "<|tool_response>": 50, "": 51, + '<|"|>': 52, + "<|channel>": 100, "": 101, + "<|turn>": 105, "": 106, + "\n": 107, + "thought": 45518, "content": 3955, "final": 10218, + "Hello": 9259, "Four": 73440, "call": 6639, ":": 236787, + "get": 828, "_": 236779, "weather": 19323, + "{": 236782, "}": 236783, "city": 13319, + "Tokyo": 89265, " ": 235248, + "The": 651, "user": 2364, "wants": 10388, +} + +TOKENIZER = FakeTokenizer(VOCAB) + + +@pytest.fixture +def router(): + r = OutputRouter(GEMMA4_MAP, TOKENIZER) + r.reset() + return r + + +class TestBasicRouting: + """Test fundamental token routing.""" + + def test_content_passthrough(self, router): + """Plain content tokens go to CONTENT channel.""" + event = router.feed(9259) # "Hello" + assert event is not None + assert event.channel == Channel.CONTENT + assert event.text == "Hello" + + def test_bos_eos_pad_suppressed(self, router): + """Control tokens are suppressed.""" + assert router.feed(0) is None # pad + assert router.feed(1) is None # eos + assert router.feed(2) is None # bos + + def test_turn_tokens_suppressed(self, router): + """Turn markers are suppressed.""" + assert router.feed(105) is None # <|turn> + assert router.feed(106) is None # + + +class TestThinkingChannel: + """Test thought channel routing.""" + + def test_thought_channel_detected(self, router): + """<|channel> + thought → THINKING state, tokens go to REASONING.""" + assert router.feed(100) is None # <|channel> suppressed + assert router.feed(45518) is None # "thought" suppressed + assert router.state == RouterState.THINKING + + event = router.feed(651) # "The" + assert event is not None + assert event.channel == Channel.REASONING + + def test_thought_ends_at_channel_close(self, router): + """ ends thinking, switches to CONTENT.""" + router.feed(100) # <|channel> + router.feed(45518) # thought + router.feed(651) # "The" (reasoning) + + assert router.feed(101) is None # suppressed + assert router.state == RouterState.CONTENT + + event = router.feed(73440) # "Four" + assert event is not None + assert event.channel == Channel.CONTENT + + def test_thought_then_content_channel(self, router): + """Full cycle: thought → → content channel → answer.""" + # Thought channel + router.feed(100) # <|channel> + router.feed(45518) # thought + e1 = router.feed(651) # "The" + assert e1.channel == Channel.REASONING + + # End thought + router.feed(101) # + + # Content channel + router.feed(100) # <|channel> + router.feed(3955) # content + e2 = router.feed(73440) # "Four" + assert e2.channel == Channel.CONTENT + + def test_implicit_content_after_thought(self, router): + """After , if no new <|channel>, tokens are content.""" + router.feed(100) # <|channel> + router.feed(45518) # thought + router.feed(651) # reasoning + router.feed(101) # + + # No explicit content channel — should still be content + e = router.feed(73440) # "Four" + assert e.channel == Channel.CONTENT + + +class TestToolCallRouting: + """Test tool call accumulation.""" + + def test_tool_call_accumulated(self, router): + """Tokens between <|tool_call> and are accumulated.""" + assert router.feed(48) is None # <|tool_call> → accumulate + assert router.feed(6639) is None # "call" → accumulate + assert router.feed(236787) is None # ":" → accumulate + + def test_tool_call_emitted_on_close(self, router): + """Complete tool call emitted as TOOL_CALL event.""" + router.feed(48) # <|tool_call> + router.feed(6639) # call + router.feed(236787) # : + router.feed(828) # get + router.feed(236779) # _ + router.feed(19323) # weather + router.feed(236782) # { + router.feed(13319) # city + router.feed(236787) # : + router.feed(52) # <|"|> + router.feed(89265) # Tokyo + router.feed(52) # <|"|> + router.feed(236783) # } + + event = router.feed(49) # + assert event is not None + assert event.channel == Channel.TOOL_CALL + assert "get_weather" in event.text + assert "Tokyo" in event.text + + def test_content_after_tool_call(self, router): + """After tool call completes, back to content mode.""" + router.feed(48) # <|tool_call> + router.feed(6639) # call + router.feed(49) # + + event = router.feed(9259) # "Hello" + assert event.channel == Channel.CONTENT + + +class TestOrphanTokens: + """Test handling of orphaned/leaked special tokens.""" + + def test_orphan_tool_call_end_suppressed(self, router): + """ without <|tool_call> should be suppressed.""" + assert router.feed(49) is None # orphan + + def test_orphan_tool_response_suppressed(self, router): + """Tool response markers should always be suppressed.""" + assert router.feed(50) is None # <|tool_response> + assert router.feed(51) is None # + + def test_orphan_tool_markers_suppressed(self, router): + """<|tool> and should be suppressed.""" + assert router.feed(46) is None # <|tool> + assert router.feed(47) is None # + + def test_content_after_orphan_tokens(self, router): + """Content after orphan tokens routes correctly.""" + router.feed(49) # orphan + router.feed(51) # orphan + event = router.feed(9259) # "Hello" + assert event is not None + assert event.channel == Channel.CONTENT + + +class TestFeedSequence: + """Test batch processing of token sequences.""" + + def test_thought_then_content(self, router): + """Process a full thought→content sequence.""" + tokens = [ + 100, 45518, 107, # <|channel> thought \n + 651, 2364, # "The" "user" (reasoning) + 101, # + 100, 3955, 107, # <|channel> content \n + 73440, # "Four" (content) + 101, # + ] + result = router.feed_sequence(tokens) + assert result["content"] == "Four" + assert "The" in result["reasoning"] + assert result["tool_calls"] is None + + def test_tool_call_sequence(self, router): + """Process a tool call sequence.""" + tokens = [ + 48, # <|tool_call> + 6639, 236787, # call : + 828, 236779, 19323, # get _ weather + 236782, # { + 13319, 236787, # city : + 52, 89265, 52, # <|"|> Tokyo <|"|> + 236783, # } + 49, # + ] + result = router.feed_sequence(tokens) + assert result["tool_calls"] is not None + assert len(result["tool_calls"]) == 1 + assert "Tokyo" in result["tool_calls"][0] + + def test_plain_content(self, router): + """No special tokens → all content.""" + tokens = [9259, 235248, 73440] # Hello Four + result = router.feed_sequence(tokens) + assert result["content"] is not None + assert result["reasoning"] is None + assert result["tool_calls"] is None + + +class TestFromTokenizer: + """Test auto-detection from tokenizer.""" + + def test_gemma4_detected(self): + """Gemma 4 tokenizer auto-detected.""" + router = OutputRouter.from_tokenizer(TOKENIZER) + assert router is not None + assert router.map.channel_start == 100 + + def test_unknown_tokenizer(self): + """Non-Gemma tokenizer returns None.""" + plain_vocab = {"hello": 1, "world": 2} + plain_tok = FakeTokenizer(plain_vocab) + router = OutputRouter.from_tokenizer(plain_tok) + assert router is None + + +class TestStateReset: + """Test state management.""" + + def test_reset_clears_state(self, router): + """Reset returns to INIT state.""" + router.feed(100) # enter channel + router.feed(45518) # thinking + assert router.state == RouterState.THINKING + + router.reset() + assert router.state == RouterState.INIT + assert router._tool_tokens == [] + + def test_multiple_requests(self, router): + """Router works correctly across multiple reset cycles.""" + # Request 1 + router.feed(100); router.feed(45518) # thinking + e1 = router.feed(651) + assert e1.channel == Channel.REASONING + + # Request 2 + router.reset() + e2 = router.feed(9259) # "Hello" + assert e2.channel == Channel.CONTENT diff --git a/vllm_mlx/api/__init__.py b/vllm_mlx/api/__init__.py index 666962a..302d085 100644 --- a/vllm_mlx/api/__init__.py +++ b/vllm_mlx/api/__init__.py @@ -71,6 +71,7 @@ extract_multimodal_content, is_mllm_model, is_vlm_model, + sanitize_output, strip_special_tokens, ) @@ -119,6 +120,7 @@ "extract_multimodal_content", "MLLM_PATTERNS", "SPECIAL_TOKENS_PATTERN", + "sanitize_output", "strip_special_tokens", "StreamingToolCallFilter", "StreamingThinkRouter", diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 069e5ea..6bf5a74 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -21,6 +21,7 @@ r"<\|end\|>|<\|eot_id\|>|<\|eom_id\|>|<\|python_tag\|>|" r"<\|start_header_id\|>|<\|end_header_id\|>|" r"<\|channel\|>|<\|message\|>|<\|start\|>|<\|return\|>|<\|call\|>|<\|constrain\|>|" + r"<\|turn>||" r"|||\[PAD\]|\[SEP\]|\[CLS\]|" r"\[e~\[|\]~b\][a-z]*|\]~!b\[" ) @@ -43,6 +44,44 @@ def strip_special_tokens(text: str) -> str: return text +# ============================================================================= +# Final sanitizer — last-mile catch-all before content reaches the client. +# Catches ANY remaining markup that earlier layers missed, including: +# - All <|..> and <..|> asymmetric tokens (Gemma 4 style) +# - All <|..|> symmetric tokens (Qwen, GPT-OSS style) +# - [Calling tool:...] text-format tool calls +# - Stray , , etc. +# ============================================================================= + +_FINAL_SANITIZER = re.compile( + # Any <|...> or <...|> token (Gemma 4 asymmetric: <|channel>, , etc.) + r"<\|[a-z_\"]+>|<[a-z_\"]+\|>" + # Any <|...|> token (symmetric: <|im_end|>, <|channel|>, etc.) + r"|<\|[a-z_]+\|>" + # [Calling tool:...] or [Calling tool="..."] + r"|\[Calling\s+tool[=:][^\]]*\]" + # Stray closing tags + r"||" +) + + +def sanitize_output(text: str) -> str: + """Final catch-all sanitizer for client-facing content. + + This is the LAST defense against markup leakage. Runs after all + parsers and filters. Strips anything that looks like a special token + or internal markup pattern. + + Designed to be aggressive — better to over-strip than to leak. + """ + if not text: + return text + for ch in text: + if ch in _SPECIAL_TOKEN_CHARS: + return _FINAL_SANITIZER.sub("", text).strip() + return text + + # Regex for matching final channel marker with optional constrain token: # <|channel|>final<|message|> # <|channel|>final <|constrain|>JSON<|message|> @@ -123,8 +162,14 @@ def clean_output_text(text: str) -> str: return text -# Pattern to match ... blocks -THINK_PATTERN = re.compile(r"[\s\S]*?\s*", re.DOTALL) +# Pattern to match thinking blocks: +# - ... (Qwen, DeepSeek, etc.) +# - <|channel>thought\n... (Gemma 4) +THINK_PATTERN = re.compile( + r"[\s\S]*?\s*" + r"|<\|channel>thought\n[\s\S]*?\s*", + re.DOTALL, +) def strip_thinking_tags(text: str) -> str: diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index 328dc87..3fbb2cd 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -1063,12 +1063,13 @@ def main(): "minimax", "harmony", "gpt-oss", + "gemma4", ], help=( "Select the tool call parser for the model. Options: " "auto (auto-detect), mistral, qwen, qwen3_coder, llama, hermes, " "deepseek, kimi, granite, nemotron, xlam, functionary, glm47, minimax, " - "harmony/gpt-oss. " + "harmony/gpt-oss, gemma4. " "Required for --enable-auto-tool-choice." ), ) diff --git a/vllm_mlx/engine/base.py b/vllm_mlx/engine/base.py index d05bb93..22692fd 100644 --- a/vllm_mlx/engine/base.py +++ b/vllm_mlx/engine/base.py @@ -27,6 +27,8 @@ class GenerationOutput: finished: bool = True # Per-token logprobs (mx.array of shape [vocab_size] for current token) logprobs: Any = None + # Semantic channel: "content", "reasoning", "tool_call", or None + channel: str | None = None class BaseEngine(ABC): diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 5d10e94..0a7044a 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -460,6 +460,7 @@ async def stream_generate( finished=finished, finish_reason=finish_reason, logprobs=getattr(chunk, "logprobs", None), + channel=getattr(chunk, "channel", None), ) # Yield to event loop periodically so the server can @@ -555,16 +556,20 @@ async def chat( if self._is_mllm: # For MLLM with media, use the chat method which handles images/videos # Run in thread pool to allow asyncio timeout to work - output = await asyncio.to_thread( - self._model.chat, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - stop=stop, - tools=template_tools, - **kwargs, - ) + try: + output = await asyncio.to_thread( + self._model.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stop=stop, + tools=template_tools, + **kwargs, + ) + except Exception as e: + logger.error("MLLM chat() failed: %s", e, exc_info=True) + raise return GenerationOutput( text=output.text, prompt_tokens=output.prompt_tokens, @@ -744,6 +749,18 @@ async def stream_chat( chunk = await asyncio.to_thread(next, sync_gen) except StopIteration: break + except Exception as e: + # Some VLM models (e.g. Gemma 4) raise during + # generator cleanup after generation completes. + # If we already have output, treat as finished. + if token_count > 0: + logger.warning( + "MLLM stream_chat error after %d tokens " + "(likely post-generation cleanup): %s", + token_count, e, + ) + break + raise token_count += 1 new_text = chunk.text if hasattr(chunk, "text") else str(chunk) diff --git a/vllm_mlx/model_auto_config.py b/vllm_mlx/model_auto_config.py index bd87333..b28475a 100644 --- a/vllm_mlx/model_auto_config.py +++ b/vllm_mlx/model_auto_config.py @@ -72,7 +72,12 @@ class ModelConfig: tool_call_parser="hermes", reasoning_parser=None, )), - # Gemma + # Gemma 4 (native tool format) + (re.compile(r"gemma[-_]?4", re.IGNORECASE), ModelConfig( + tool_call_parser="gemma4", + reasoning_parser="gemma4", + )), + # Gemma 2/3 (hermes format) (re.compile(r"gemma", re.IGNORECASE), ModelConfig( tool_call_parser="hermes", reasoning_parser=None, diff --git a/vllm_mlx/models/gemma4_text.py b/vllm_mlx/models/gemma4_text.py new file mode 100644 index 0000000..b7290a5 --- /dev/null +++ b/vllm_mlx/models/gemma4_text.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Gemma 4 text-only model loader for the LLM path. + +mlx-lm doesn't support gemma4 yet, but mlx-vlm does. This module loads +just the language model portion from mlx-vlm and wraps it to be compatible +with mlx-lm's generate_step() interface, enabling: +- Prompt cache (KV reuse across requests) +- DeltaNet state snapshots (if applicable) +- All LLM-path optimizations + +The wrapper is thin: it just ensures model(input_ids, cache=cache) returns +a raw logits tensor instead of LanguageModelOutput. + +TODO: Remove once mlx-lm adds native gemma4 support. +""" + +import json +import logging +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn + +logger = logging.getLogger(__name__) + + +def is_gemma4_model(model_path: str | Path) -> bool: + """Check if the model at the given path is a Gemma 4 model.""" + p = Path(model_path) + config_path = p / "config.json" if p.is_dir() else None + if config_path is None or not config_path.exists(): + # Try HF cache + try: + from huggingface_hub import snapshot_download + p = Path(snapshot_download(str(model_path))) + config_path = p / "config.json" + except Exception: + return False + if not config_path.exists(): + return False + try: + config = json.loads(config_path.read_text()) + model_type = config.get("model_type", "") + return "gemma4" in model_type + except Exception: + return False + + +class Gemma4TextWrapper(nn.Module): + """Wraps mlx-vlm's Gemma4 LanguageModel for mlx-lm compatibility. + + mlx-lm's generate_step() expects model(input_ids, cache=cache) -> logits. + mlx-vlm's LanguageModel returns LanguageModelOutput(logits=...). + This wrapper extracts .logits so the interface matches. + """ + + def __init__(self, language_model): + super().__init__() + self.language_model = language_model + # Expose config for mlx-lm compatibility + self.config = language_model.config + self.model = language_model.model + self.model_type = getattr(language_model, "model_type", "gemma4") + + def __call__(self, input_ids, cache=None, **kwargs): + out = self.language_model(input_ids, cache=cache, **kwargs) + # LanguageModelOutput -> raw logits tensor + return out.logits if hasattr(out, "logits") else out + + def sanitize(self, weights): + """Strip language_model. prefix from VLM-format weights.""" + sanitized = {} + for k, v in weights.items(): + new_key = k + # Strip top-level "model." wrapper + if new_key.startswith("model."): + new_key = new_key[len("model."):] + # Strip "language_model." to get bare model weights, + # then re-add "language_model." for our wrapper structure + if new_key.startswith("language_model."): + pass # keep as-is — our wrapper has .language_model attribute + elif not any(new_key.startswith(p) for p in + ["vision_tower", "audio_tower", "embed_vision", "embed_audio"]): + new_key = "language_model." + new_key + else: + continue # skip vision/audio weights + # Skip rotary embeddings (computed dynamically) + if "rotary_emb" in new_key: + continue + # Skip clipping params (vision-only) + if any(s in new_key for s in ["input_max", "input_min", "output_max", "output_min"]): + continue + sanitized[new_key] = v + return sanitized + + def make_cache(self): + """Delegate to LanguageModel for proper sliding window + full attention cache.""" + return self.language_model.make_cache() + + @property + def layers(self): + return self.language_model.layers + + @property + def head_dim(self): + return self.language_model.head_dim + + @property + def n_kv_heads(self): + return self.language_model.n_kv_heads + + +def load_gemma4_text(model_path: str | Path, tokenizer_config: dict = None): + """Load Gemma 4 as a text-only model via the LLM path. + + Returns (model, tokenizer) compatible with mlx-lm's generate_step(). + """ + from mlx_lm.utils import load_tokenizer + + p = Path(model_path) + if not p.is_dir(): + from huggingface_hub import snapshot_download + p = Path(snapshot_download(str(model_path))) + + config = json.loads((p / "config.json").read_text()) + text_config = config.get("text_config", config) + + # Build the language model from mlx-vlm + from mlx_vlm.models.gemma4.config import TextConfig + from mlx_vlm.models.gemma4.language import LanguageModel + + tc = TextConfig.from_dict(text_config) + language_model = LanguageModel(tc) + + # Wrap for mlx-lm compatibility + model = Gemma4TextWrapper(language_model) + + # Apply quantization config if present (converts Linear → QuantizedLinear) + quant_config = config.get("quantization", config.get("quantization_config")) + if quant_config: + default_bits = quant_config.get("bits", 4) + default_gs = quant_config.get("group_size", 64) + + # Build per-layer override map from config (mixed quantization) + # Keys like "language_model.model.layers.0.mlp.gate_proj" → {bits:8, group_size:64} + overrides = {} + for k, v in quant_config.items(): + if isinstance(v, dict) and "bits" in v: + overrides[k] = { + kk: vv for kk, vv in v.items() + if kk in ("bits", "group_size", "mode") + } + + if overrides: + logger.info( + "[gemma4] Mixed quantization: %d-bit default, %d overrides (8-bit MLP)", + default_bits, len(overrides), + ) + + def _class_predicate(path, module): + if not hasattr(module, "to_quantized"): + return False + # Check per-layer overrides + # Override keys use "language_model.model.layers..." but nn.quantize + # sees "model.layers..." (relative to wrapper). Match by suffix. + for override_path, override_cfg in overrides.items(): + # Strip common prefixes for matching + suffix = override_path.split("language_model.model.")[-1] + if path.endswith(suffix): + return override_cfg + return {"bits": default_bits, "group_size": default_gs} + + nn.quantize(model, class_predicate=_class_predicate) + else: + logger.info("[gemma4] Applying %d-bit quantization (group_size=%d)", default_bits, default_gs) + nn.quantize(model, bits=default_bits, group_size=default_gs) + + # Load weights + weight_files = sorted( + f for f in p.glob("*.safetensors") + if not f.name.startswith("._") + ) + if not weight_files: + raise FileNotFoundError(f"No .safetensors files in {p}") + raw_weights = {} + for wf in weight_files: + raw_weights.update(mx.load(str(wf))) + + # Sanitize and load + sanitized = model.sanitize(raw_weights) + model.load_weights(list(sanitized.items()), strict=False) + + # Verify weights loaded + test_param = model.language_model.model.embed_tokens + if hasattr(test_param, "scales") and mx.all(test_param.scales == 0).item(): + logger.warning("[gemma4] Embedding scales are zero — quantized model may have issues") + + # Load tokenizer + tokenizer_config = tokenizer_config or {} + eos_token_ids = config.get("eos_token_id", text_config.get("eos_token_id")) + tokenizer = load_tokenizer(p, tokenizer_config, eos_token_ids=eos_token_ids) + + logger.info("[gemma4] Loaded text-only model via LLM path (%d layers)", len(model.layers)) + return model, tokenizer diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 632d57d..ea9f026 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -35,6 +35,7 @@ class StreamingOutput: token: int finished: bool = False finish_reason: str | None = None + channel: str | None = None # "content", "reasoning", "tool_call", or None (unrouted) logprobs: Any = None # mx.array of shape [vocab_size] from mlx-lm prompt_tokens: int = 0 @@ -98,6 +99,9 @@ def __init__( self._cached_token_ids: list[int] = [] self._cache_lock = False # Simple guard against concurrent use + # Token-level output router (set in load() if model supports it) + self._output_router = None + # DeltaNet/hybrid cache snapshot for prefix reuse self._rnn_state_snapshot: list | None = None # deep-copied ArraysCache states self._snapshot_prefix_ids: list[int] = [] # token IDs at snapshot time @@ -152,6 +156,13 @@ def load(self) -> None: self._loaded = True logger.info(f"Model loaded successfully: {self.model_name}") + # Initialize token-level output router (if model supports it) + from ..output_router import OutputRouter + + self._output_router = OutputRouter.from_tokenizer(self.tokenizer) + if self._output_router: + logger.info("Token-level OutputRouter enabled") + except ImportError: raise ImportError( "mlx-lm is required for LLM inference. Install with: pip install mlx-lm" @@ -594,6 +605,9 @@ def stream_generate( token_count = 0 accumulated_text = "" + # Reset output router for new request + if self._output_router: + self._output_router.reset() # Use IncrementalDecoder with skip_special_tokens=False to preserve # control tokens (e.g. Harmony's <|channel|>, <|call|>) that tool # parsers need. Also handles multi-byte chars (emoji, CJK) safely. @@ -670,6 +684,29 @@ def _make_generator(): raise for response in itertools.chain([first_response], gen): + token_id = response.token if hasattr(response, "token") else 0 + + # Token-level routing (if router available) + channel = None + if self._output_router: + try: + event = self._output_router.feed(token_id) + except Exception as router_err: + logger.warning("OutputRouter.feed failed (%s), disabling", router_err) + self._output_router = None + event = None + if self._output_router and event is None: + # Control token — suppress entirely, don't count + continue + if event: + new_text = event.text + channel = event.channel.name.lower() + else: + new_text = decoder.add_token(token_id) + else: + new_text = decoder.add_token(token_id) + + # Count only visible (non-suppressed) tokens token_count += 1 if token_count == 1: t_first_token = _time.perf_counter() @@ -680,8 +717,7 @@ def _make_generator(): f"(prompt={len(full_token_ids)} tokens, " f"prefilled={len(prompt_to_send)} tokens)" ) - token_id = response.token if hasattr(response, "token") else 0 - new_text = decoder.add_token(token_id) + accumulated_text += new_text # Check for stop sequences — truncate at the stop point @@ -693,7 +729,6 @@ def _make_generator(): idx = accumulated_text.find(stop_seq) if idx != -1: should_stop = True - # Truncate new_text so accumulated ends just before the stop seq stop_truncate_text = new_text[: len(new_text) - (len(accumulated_text) - idx)] accumulated_text = accumulated_text[:idx] break @@ -710,10 +745,6 @@ def _make_generator(): finish_reason = getattr(response, "finish_reason", "stop") else: finish_reason = "length" - # Save cache BEFORE yielding the finished chunk. - # The caller may break/abandon this generator after - # receiving the finished chunk, so code after yield - # would never execute. self._save_cache_snapshot(full_token_ids) cache_saved = True @@ -724,6 +755,7 @@ def _make_generator(): finish_reason=finish_reason, logprobs=getattr(response, "logprobs", None), prompt_tokens=len(full_token_ids), + channel=channel, ) if finished: diff --git a/vllm_mlx/output_router.py b/vllm_mlx/output_router.py new file mode 100644 index 0000000..9accb1e --- /dev/null +++ b/vllm_mlx/output_router.py @@ -0,0 +1,260 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Token-level output router for LLM generation. + +Routes model output tokens into semantic channels (thinking, content, tool_calls) +based on special token IDs read from the tokenizer. No regex, no text matching. + +Architecture: + 1. Read special token IDs from tokenizer vocabulary (config-driven) + 2. As tokens stream in, a state machine routes each to the correct channel + 3. Text is decoded only AFTER routing, so partial-token issues are impossible + +Usage: + router = OutputRouter.from_tokenizer(tokenizer) + for token_id in generation: + event = router.feed(token_id) + if event.channel == "content": + yield event.text + elif event.channel == "reasoning": + yield_reasoning(event.text) + elif event.channel == "tool_call": + accumulate_tool_call(event.text) + +Designed to replace the fragile regex-based strip_special_tokens + +reasoning_parser + tool_call_parser chain with a single unified router. + +Currently implements Gemma 4 format. Other models can be added by defining +their token mappings in MODEL_TOKEN_MAPS. +""" + +import logging +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import Any + +logger = logging.getLogger(__name__) + + +class Channel(Enum): + """Output channel for a token.""" + CONTENT = auto() + REASONING = auto() + TOOL_CALL = auto() + CONTROL = auto() # special tokens that should be suppressed + + +@dataclass +class RouterEvent: + """A single routed token.""" + channel: Channel + token_id: int + text: str # decoded text for this token + + +@dataclass +class TokenMap: + """Special token ID mappings for a model family.""" + # Channel control (Gemma 4 style) + channel_start: int | None = None # <|channel> = 100 + channel_end: int | None = None # = 101 + thought_word: int | None = None # "thought" = 45518 + content_word: int | None = None # "content" = 3955 + final_word: int | None = None # "final" = 10218 + + # Turn control + turn_start: int | None = None # <|turn> = 105 + turn_end: int | None = None # = 106 + + # Tool call (Gemma 4 style) + tool_call_start: int | None = None # <|tool_call> = 48 + tool_call_end: int | None = None # = 49 + tool_quote: int | None = None # <|"|> = 52 + tool_start: int | None = None # <|tool> = 46 + tool_end: int | None = None # = 47 + tool_response_start: int | None = None # <|tool_response> = 50 + tool_response_end: int | None = None # = 51 + + # Think tags (Qwen/DeepSeek style) — for future migration + think_start: int | None = None # token ID + think_end: int | None = None # token ID + + # Standard control + bos: int | None = None + eos: int | None = None + pad: int | None = None + + + +class RouterState(Enum): + """State machine states.""" + INIT = auto() + THINKING = auto() # inside thought channel + CONTENT = auto() # inside content/final channel + TOOL_CALL = auto() # inside tool call + AWAITING_CHANNEL_TYPE = auto() # saw <|channel>, waiting for thought/content/final + + +class OutputRouter: + """ + Token-level output router with state machine. + + Processes token IDs one at a time, routing each to the appropriate + semantic channel without any text-level regex matching. + """ + + def __init__(self, token_map: TokenMap, tokenizer: Any): + self.map = token_map + self.tokenizer = tokenizer + self.state = RouterState.INIT + self._tool_tokens: list[int] = [] # accumulated tool call token IDs + + def reset(self): + """Reset state for a new request.""" + self.state = RouterState.INIT + self._tool_tokens = [] + + def feed(self, token_id: int) -> RouterEvent | None: + """ + Feed a single token and get the routing decision. + + Returns RouterEvent with the channel assignment, or None if the + token should be suppressed entirely (control tokens). + """ + m = self.map + + # === Control tokens: always suppress (no decode needed) === + if token_id in (m.bos, m.eos, m.pad): + return None + if token_id == m.turn_start or token_id == m.turn_end: + return None + # Suppress tool-related markers that may appear without proper nesting + if token_id in (m.tool_response_start, m.tool_response_end, + m.tool_start, m.tool_end): + return None + + # === Channel start: transition to AWAITING_CHANNEL_TYPE === + if token_id == m.channel_start: + self.state = RouterState.AWAITING_CHANNEL_TYPE + return None # suppress <|channel> + + # === Channel type word: set state based on which channel === + if self.state == RouterState.AWAITING_CHANNEL_TYPE: + if token_id == m.thought_word: + self.state = RouterState.THINKING + return None # suppress "thought" + elif token_id == m.content_word or token_id == m.final_word: + self.state = RouterState.CONTENT + return None # suppress "content" / "final" + else: + # Unknown channel type — treat as content + self.state = RouterState.CONTENT + text = self.tokenizer.decode([token_id]) + return RouterEvent(Channel.CONTENT, token_id, text) + + # === Channel end: transition back === + if token_id == m.channel_end: + if self.state == RouterState.THINKING: + self.state = RouterState.CONTENT + return None # suppress + + # === Orphan tool call end (no matching start): suppress === + if token_id == m.tool_call_end and self.state != RouterState.TOOL_CALL: + return None + + # === Tool call start === + if token_id == m.tool_call_start: + self.state = RouterState.TOOL_CALL + self._tool_tokens = [token_id] + return None + + # === Inside tool call: accumulate (no per-token decode) === + if self.state == RouterState.TOOL_CALL: + self._tool_tokens.append(token_id) + if token_id == m.tool_call_end: + full_text = self.tokenizer.decode(self._tool_tokens) + self.state = RouterState.CONTENT + self._tool_tokens = [] + return RouterEvent(Channel.TOOL_CALL, token_id, full_text) + return None + + # === Default: decode and route based on current state === + text = self.tokenizer.decode([token_id]) + if self.state == RouterState.THINKING: + return RouterEvent(Channel.REASONING, token_id, text) + else: + return RouterEvent(Channel.CONTENT, token_id, text) + + def feed_sequence(self, token_ids: list[int]) -> dict[str, str]: + """ + Feed a complete token sequence and return separated channels. + + Returns: + {"content": "...", "reasoning": "...", "tool_calls": [...]} + """ + content = "" + reasoning = "" + tool_calls = [] + + for tid in token_ids: + event = self.feed(tid) + if event is None: + continue + if event.channel == Channel.CONTENT: + content += event.text + elif event.channel == Channel.REASONING: + reasoning += event.text + elif event.channel == Channel.TOOL_CALL: + tool_calls.append(event.text) + + return { + "content": content.strip() or None, + "reasoning": reasoning.strip() or None, + "tool_calls": tool_calls or None, + } + + @classmethod + def from_tokenizer(cls, tokenizer: Any) -> "OutputRouter | None": + """ + Create an OutputRouter from a tokenizer by reading its vocabulary. + + Returns None if the tokenizer doesn't have the expected special tokens + (i.e., the model doesn't use a supported format). + """ + vocab = tokenizer.get_vocab() + + # Gemma 4 detection: look for <|channel> and <|tool_call> + if "<|channel>" in vocab and "<|tool_call>" in vocab: + token_map = TokenMap( + channel_start=vocab.get("<|channel>"), + channel_end=vocab.get(""), + thought_word=vocab.get("thought"), + content_word=vocab.get("content"), + final_word=vocab.get("final"), + turn_start=vocab.get("<|turn>"), + turn_end=vocab.get(""), + tool_call_start=vocab.get("<|tool_call>"), + tool_call_end=vocab.get(""), + tool_quote=vocab.get('<|"|>'), + tool_start=vocab.get("<|tool>"), + tool_end=vocab.get(""), + tool_response_start=vocab.get("<|tool_response>"), + tool_response_end=vocab.get(""), + bos=vocab.get(""), + eos=vocab.get(""), + pad=vocab.get(""), + ) + logger.info( + "[OutputRouter] Gemma 4 format detected: " + "channel=%d/%d, tool=%d/%d", + token_map.channel_start, token_map.channel_end, + token_map.tool_call_start, token_map.tool_call_end, + ) + return cls(token_map, tokenizer) + + # Qwen/DeepSeek detection: look for and + # TODO: implement when migrating existing parsers + # if "" in vocab and "" in vocab: + # ... + + return None # unsupported model format diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py index 069dcc9..e8dfb90 100644 --- a/vllm_mlx/reasoning/__init__.py +++ b/vllm_mlx/reasoning/__init__.py @@ -78,9 +78,11 @@ def _register_builtin_parsers(): from .deepseek_r1_parser import DeepSeekR1ReasoningParser from .gpt_oss_parser import GptOssReasoningParser from .harmony_parser import HarmonyReasoningParser + from .gemma4_parser import Gemma4ReasoningParser from .minimax_parser import MiniMaxReasoningParser from .qwen3_parser import Qwen3ReasoningParser + register_parser("gemma4", Gemma4ReasoningParser) register_parser("qwen3", Qwen3ReasoningParser) register_parser("deepseek_r1", DeepSeekR1ReasoningParser) register_parser("gpt_oss", GptOssReasoningParser) diff --git a/vllm_mlx/reasoning/gemma4_parser.py b/vllm_mlx/reasoning/gemma4_parser.py new file mode 100644 index 0000000..6253a55 --- /dev/null +++ b/vllm_mlx/reasoning/gemma4_parser.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Gemma 4 reasoning parser. + +Gemma 4 uses channel tokens for thinking: + <|channel>thought\n...reasoning... + <|channel>content\n...answer... + +The parser separates thinking from content by tracking the active channel. +""" + +import re + +from .base import DeltaMessage, ReasoningParser + +# Match full thought blocks in complete text +_THOUGHT_BLOCK = re.compile( + r"<\|channel>thought\n[\s\S]*?\s*", re.DOTALL +) +# Match content channel markers +_CONTENT_START = re.compile(r"<\|channel>(?:content|final)\n?") +_CHANNEL_END = re.compile(r"") +_TURN_END = re.compile(r"") + + +class Gemma4ReasoningParser(ReasoningParser): + """Parser for Gemma 4's channel-based thinking format.""" + + def __init__(self, tokenizer=None): + super().__init__(tokenizer) + self._in_thought = False + self._in_content = False + self._saw_any_channel = False + + def reset_state(self): + super().reset_state() + self._in_thought = False + self._in_content = False + self._saw_any_channel = False + + def extract_reasoning(self, model_output: str) -> tuple[str | None, str | None]: + """Extract reasoning from complete output.""" + if not model_output: + return None, model_output + + # Extract thought blocks as reasoning + thought_blocks = _THOUGHT_BLOCK.findall(model_output) + if not thought_blocks: + # No thinking tags — all content + cleaned = _CONTENT_START.sub("", model_output) + cleaned = _CHANNEL_END.sub("", cleaned) + cleaned = _TURN_END.sub("", cleaned).strip() + return None, cleaned + + # Reasoning = thought block contents (strip markers) + reasoning = "" + for block in thought_blocks: + inner = block.replace("<|channel>thought\n", "").replace("", "").strip() + reasoning += inner + + # Content = everything after thought blocks, strip markers + content = _THOUGHT_BLOCK.sub("", model_output) + content = _CONTENT_START.sub("", content) + content = _CHANNEL_END.sub("", content) + content = _TURN_END.sub("", content).strip() + + return reasoning or None, content or None + + def extract_reasoning_streaming( + self, previous_text: str, current_text: str, delta_text: str + ) -> DeltaMessage | None: + """Extract reasoning from streaming delta.""" + if not delta_text: + return None + + # Track channel state based on accumulated text + # Check if we just entered thought channel + if "<|channel>thought" in current_text and not self._in_content: + self._in_thought = True + self._saw_any_channel = True + + # Check if we just entered content channel + if "<|channel>content" in current_text or "<|channel>final" in current_text: + self._in_thought = False + self._in_content = True + + # Check if thought ended (first after thought start) + if self._in_thought and "" in current_text: + thought_starts = current_text.count("<|channel>thought") + channel_ends = current_text.count("") + if channel_ends >= thought_starts: + self._in_thought = False + # If no explicit content channel follows, switch to content mode + if "<|channel>content" not in current_text and "<|channel>final" not in current_text: + self._in_content = True + + # Filter out channel markers from delta + clean = delta_text + for marker in ["<|channel>", "", "<|turn>", "", + "thought\n", "content\n", "final\n"]: + clean = clean.replace(marker, "") + + if not clean: + return None # pure marker token, skip + + if self._in_thought: + return DeltaMessage(reasoning=clean) + elif self._in_content: + return DeltaMessage(content=clean) + elif not self._saw_any_channel: + # No channel tokens seen — plain content (no thinking) + return DeltaMessage(content=clean) + else: + # Between channels — treat as reasoning + return DeltaMessage(reasoning=clean) + + def finalize_streaming(self, accumulated_text: str) -> DeltaMessage | None: + """Handle end of stream — emit any remaining content.""" + return None diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 820e955..3e6eb2e 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -111,6 +111,7 @@ extract_json_from_response, extract_multimodal_content, is_mllm_model, # noqa: F401 + sanitize_output, strip_special_tokens, strip_thinking_tags, ) @@ -402,6 +403,18 @@ def configure_cors(origins: list[str]) -> None: security = HTTPBearer(auto_error=False) +@app.exception_handler(Exception) +async def _global_exception_handler(request: Request, exc: Exception): + """Catch unhandled exceptions so they return JSON 500 instead of killing + the connection. This keeps the server alive for subsequent requests.""" + logger.error("Unhandled exception on %s %s: %s", request.method, request.url.path, exc, exc_info=True) + from starlette.responses import JSONResponse + return JSONResponse( + status_code=500, + content={"error": {"message": str(exc), "type": type(exc).__name__}}, + ) + + class RateLimiter: """Simple in-memory rate limiter using sliding window.""" @@ -2075,7 +2088,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re if request.stream: # Validate chat template eagerly so template errors return 400, # not a broken SSE stream. build_prompt is cheap (no generation). - if hasattr(engine, "build_prompt"): + if hasattr(engine, "build_prompt") and not engine.is_mllm: try: engine.build_prompt( messages, @@ -2910,6 +2923,11 @@ def _fast_sse_chunk( no finish_reason), this avoids constructing Pydantic objects and calling model_dump_json(). """ + # Final sanitizer: last-mile defense against markup leakage + if field == "content": + text = sanitize_output(text) + if not text: + return "" # suppressed entirely escaped = json.dumps(text) # Handles escaping return f'{_sse_prefix}"{field}":{escaped}{_sse_suffix}' @@ -2984,7 +3002,110 @@ def _fast_sse_chunk( if hasattr(output, "completion_tokens") and output.completion_tokens: completion_tokens = output.completion_tokens - # Use reasoning parser if enabled + # Token-level routing (OutputRouter): if channel is set, use it directly + # instead of text-based reasoning parser. This is the new architecture + # for models like Gemma 4 that have token-level channel markers. + if output.channel and delta_text: + if output.channel == "reasoning": + content = None + reasoning = delta_text + elif output.channel == "tool_call": + # Tool call handled by tool parser below + content = delta_text + reasoning = None + else: # "content" + content = delta_text + reasoning = None + + # Tool call parsing on content + if tool_parser and content: + if not tool_markup_possible and "<" not in content and "[" not in content: + tool_accumulated_text += content + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += content + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, content + ) + if tool_result is None: + continue + if "tool_calls" in tool_result: + tool_calls_detected = True + chunk = ChatCompletionChunk( + id=response_id, + model=_model_name or request.model, + choices=[ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=tool_result["tool_calls"] + ), + finish_reason="tool_calls" if output.finished else None, + )], + usage=get_usage(output) if output.finished else None, + ) + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + continue + content = tool_result.get("content", "") + + if tool_calls_detected: + if output.finished: + chunk = ChatCompletionChunk( + id=response_id, + model=_model_name or request.model, + choices=[ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(), + finish_reason="tool_calls", + )], + usage=get_usage(output), + ) + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + continue + + # Filter special tokens + if content: + content = strip_special_tokens(content) + if reasoning: + reasoning = strip_special_tokens(reasoning) + + # Skip empty deltas + finish_reason = ( + "tool_calls" if (output.finished and tool_calls_detected) + else (output.finish_reason if output.finished else None) + ) + if not content and not reasoning and not finish_reason: + continue + + # Fast SSE path + if not finish_reason and not want_logprobs and not output.finished: + if content and not reasoning: + _sse = _fast_sse_chunk(content, "content") + if _sse: + yield _sse + continue + if reasoning and not content: + yield _fast_sse_chunk(reasoning, "reasoning") + continue + + chunk = ChatCompletionChunk( + id=response_id, + model=_model_name or request.model, + choices=[ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + content=sanitize_output(content) if content else None, + reasoning=reasoning if reasoning else None, + ), + finish_reason=finish_reason, + )], + usage=get_usage(output) if output.finished else None, + ) + sse_line = f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"[SSE] {sse_line.strip()[:300]}") + yield sse_line + continue + + # Legacy text-based reasoning parser (for non-router models) if _reasoning_parser and delta_text: previous_text = accumulated_text accumulated_text += delta_text @@ -3019,7 +3140,7 @@ def _fast_sse_chunk( # Tool call parsing on content portion if tool_parser and content: - if not tool_markup_possible and "<" not in content: + if not tool_markup_possible and "<" not in content and "[" not in content: tool_accumulated_text += content else: if not tool_markup_possible: @@ -3105,7 +3226,9 @@ def _fast_sse_chunk( and not output.finished ): if content and not reasoning: - yield _fast_sse_chunk(content, "content") + _sse = _fast_sse_chunk(content, "content") + if _sse: + yield _sse continue if reasoning and not content: yield _fast_sse_chunk(reasoning, "reasoning") @@ -3117,7 +3240,7 @@ def _fast_sse_chunk( choices=[ ChatCompletionChunkChoice( delta=ChatCompletionChunkDelta( - content=content if content else None, + content=sanitize_output(content) if content else None, reasoning=reasoning if reasoning else None, ), finish_reason=finish_reason, @@ -3148,7 +3271,7 @@ def _fast_sse_chunk( # Fast path: skip full parsing until '<' is seen in the stream, # which could start tool markup (e.g. ). This avoids # per-token string scanning on the growing accumulated text. - if not tool_markup_possible and "<" not in delta_text: + if not tool_markup_possible and "<" not in delta_text and "[" not in delta_text: tool_accumulated_text += delta_text # No tool markup yet, fall through to normal chunk emission else: @@ -3239,7 +3362,9 @@ def _fast_sse_chunk( and not want_logprobs and not output.finished ): - yield _fast_sse_chunk(content, "content") + _sse = _fast_sse_chunk(content, "content") + if _sse: + yield _sse continue chunk = ChatCompletionChunk( @@ -3248,7 +3373,7 @@ def _fast_sse_chunk( choices=[ ChatCompletionChunkChoice( delta=ChatCompletionChunkDelta( - content=content if content else None + content=sanitize_output(content) if content else None ), finish_reason=finish_reason, logprobs=_build_chunk_logprobs(output), diff --git a/vllm_mlx/tool_parsers/__init__.py b/vllm_mlx/tool_parsers/__init__.py index d2c83e6..4aa48ab 100644 --- a/vllm_mlx/tool_parsers/__init__.py +++ b/vllm_mlx/tool_parsers/__init__.py @@ -51,6 +51,7 @@ from .deepseek_tool_parser import DeepSeekToolParser from .deepseekv31_tool_parser import DeepSeekV31ToolParser from .functionary_tool_parser import FunctionaryToolParser +from .gemma4_tool_parser import Gemma4ToolParser from .glm47_tool_parser import Glm47ToolParser from .granite_tool_parser import GraniteToolParser from .harmony_tool_parser import HarmonyToolParser @@ -82,6 +83,7 @@ "NemotronToolParser", "xLAMToolParser", "FunctionaryToolParser", + "Gemma4ToolParser", "Glm47ToolParser", "HarmonyToolParser", "MiniMaxToolParser", diff --git a/vllm_mlx/tool_parsers/gemma4_tool_parser.py b/vllm_mlx/tool_parsers/gemma4_tool_parser.py new file mode 100644 index 0000000..3293997 --- /dev/null +++ b/vllm_mlx/tool_parsers/gemma4_tool_parser.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Gemma 4 tool call parser for vllm-mlx. + +Handles Gemma 4's native tool calling format: + <|tool_call>call:FUNC_NAME{key:<|"|>value<|"|>,...} +""" + +import json +import re +import uuid +from collections.abc import Sequence +from typing import Any + +from .abstract_tool_parser import ( + ExtractedToolCallInformation, + ToolParser, + ToolParserManager, +) + +# Match: <|tool_call>call:name{...} +GEMMA4_TOOL_PATTERN = re.compile( + r"<\|tool_call>call:(\w+)\{(.*?)\}", re.DOTALL +) + +# Match key-value pairs inside {}: key:<|"|>value<|"|> +GEMMA4_KV_PATTERN = re.compile( + r'(\w+):<\|"\|>(.*?)<\|"\|>', re.DOTALL +) + + +def _parse_gemma4_args(args_str: str) -> dict[str, Any]: + """Parse Gemma 4's key:<|"|>value<|"|> format into a dict.""" + result = {} + for match in GEMMA4_KV_PATTERN.finditer(args_str): + key = match.group(1) + value = match.group(2) + # Try to parse as JSON value (number, bool, etc.) + try: + result[key] = json.loads(value) + except (json.JSONDecodeError, ValueError): + result[key] = value + return result + + +def _generate_tool_id() -> str: + return f"call_{uuid.uuid4().hex[:8]}" + + +@ToolParserManager.register_module(["gemma4", "gemma_4"]) +class Gemma4ToolParser(ToolParser): + """ + Tool call parser for Gemma 4 models. + + Format: <|tool_call>call:func_name{key:<|"|>value<|"|>} + """ + + def __init__(self, tokenizer=None): + super().__init__(tokenizer) + self._emitted_tool_count = 0 + + def reset(self): + """Reset state for a new request.""" + super().reset() + self._emitted_tool_count = 0 + + def extract_tool_calls( + self, model_output: str, request: Any = None + ) -> ExtractedToolCallInformation: + matches = list(GEMMA4_TOOL_PATTERN.finditer(model_output)) + + if not matches: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + tool_calls = [] + for match in matches: + func_name = match.group(1) + args_str = match.group(2) + args = _parse_gemma4_args(args_str) + + tool_calls.append( + { + "id": _generate_tool_id(), + "name": func_name, + "arguments": json.dumps(args), + } + ) + + # Content is everything outside the tool calls + content = GEMMA4_TOOL_PATTERN.sub("", model_output).strip() or None + + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence = (), + current_token_ids: Sequence = (), + delta_token_ids: Sequence = (), + ) -> dict | None: + # Check if we're inside a tool call + if "<|tool_call>" in current_text: + # Count completed tool calls so far + completed = current_text.count("") + open_count = current_text.count("<|tool_call>") + + # Still accumulating an incomplete tool call + if completed < open_count: + return None # suppress output while inside tool markup + + # Only emit newly completed tool calls (dedup) + if completed <= self._emitted_tool_count: + return None + + result = self.extract_tool_calls(current_text) + if result.tools_called: + # Only emit tool calls we haven't sent yet + new_calls = result.tool_calls[self._emitted_tool_count:] + self._emitted_tool_count = len(result.tool_calls) + + if new_calls: + return { + "tool_calls": [ + { + "index": self._emitted_tool_count + - len(new_calls) + + i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(new_calls) + ] + } + + # Text-format tool call recovery: catch [Calling tool: name({...})] + # Models degrade to this format after multiple tool rounds at low quant + from .abstract_tool_parser import TEXT_TOOL_CALL_FN_PATTERN, TEXT_TOOL_CALL_ANY + + if TEXT_TOOL_CALL_ANY.search(current_text): + # Check if we have a complete text tool call + matches = list(TEXT_TOOL_CALL_FN_PATTERN.finditer(current_text)) + new_matches = matches[self._emitted_tool_count:] + if new_matches: + self._emitted_tool_count = len(matches) + return { + "tool_calls": [ + { + "index": self._emitted_tool_count - len(new_matches) + i, + "id": _generate_tool_id(), + "type": "function", + "function": { + "name": m.group(1), + "arguments": m.group(2), + }, + } + for i, m in enumerate(new_matches) + ] + } + # Already emitted or partial — suppress + return None + + # No tool call markup — pass through as content + return {"content": delta_text} diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index 4cf433a..46dcc26 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -50,6 +50,15 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): ) return _load_with_tokenizer_fallback(model_name) + # Gemma 4: mlx-lm doesn't support it yet, load via our text-only wrapper + from ..models.gemma4_text import is_gemma4_model + + if is_gemma4_model(model_name): + from ..models.gemma4_text import load_gemma4_text + + logger.info("Gemma 4 detected — loading as text-only via LLM path") + return load_gemma4_text(model_name, tokenizer_config) + try: model, tokenizer = load(model_name, tokenizer_config=tokenizer_config) return model, tokenizer