diff --git a/docs/reference/models.md b/docs/reference/models.md index a45550e4d..d378de003 100644 --- a/docs/reference/models.md +++ b/docs/reference/models.md @@ -12,7 +12,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun | Mistral / Devstral | 7B, Mixtral 8x7B | 4-bit, 8-bit | | Qwen2/Qwen3 | 0.5B to 72B | Various | | DeepSeek V3, R1 | 7B, 33B, 67B | 4-bit | -| Gemma 2, 3 | 2B, 9B, 27B | 4-bit | +| Gemma 2, 3, 4 | 2B, 9B, 27B | 4-bit | | GLM-4.7 | Flash, Base | 4-bit, 8-bit | | Kimi K2 | Various | 4-bit | | Phi-3 | 3.8B, 14B | 4-bit | @@ -35,6 +35,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun | **Qwen-VL** | `Qwen3-VL-4B-Instruct-3bit`, `Qwen3-VL-8B-Instruct-4bit`, `Qwen2-VL-2B/7B-Instruct-4bit` | | **LLaVA** | `llava-1.5-7b-4bit`, `llava-v1.6-mistral-7b-4bit`, `llava-llama-3-8b-v1_1-4bit` | | **Idefics** | `Idefics3-8B-Llama3-4bit`, `idefics2-8b-4bit` | +| **Gemma 4** | `gemma-4-e2b-it-mxfp4` (vision + audio) | | **PaliGemma** | `paligemma2-3b-mix-224-4bit`, `paligemma-3b-mix-224-8bit` | | **Pixtral** | `pixtral-12b-4bit`, `pixtral-12b-8bit` | | **Molmo** | `Molmo-7B-D-0924-4bit`, `Molmo-7B-D-0924-8bit` | @@ -72,7 +73,7 @@ vllm-mlx auto-detects multimodal models by name patterns: - Contains "VL", "Vision", "vision" - Contains "llava", "idefics", "paligemma" - Contains "pixtral", "molmo", "deepseek-vl" -- Contains "MedGemma", "Gemma-3" (vision variants) +- Contains "MedGemma", "Gemma-3", "Gemma-4" (multimodal variants) ## Using Models diff --git a/pyproject.toml b/pyproject.toml index 6ccc45282..87b1974df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ classifiers = [ dependencies = [ "mlx>=0.29.0", "mlx-lm>=0.31.0", # 0.31+ required for ArraysCache native batching (hybrid models) - "mlx-vlm>=0.1.0", # VLM support + "mlx-vlm>=0.4.3", # 0.4.3+ required for Gemma 4 support "transformers>=5.0.0", # mlx-lm 0.30.5+ requires transformers 5.0 (rc3 bug fixed in stable) "tokenizers>=0.19.0", "huggingface-hub>=0.23.0", diff --git a/tests/test_mllm_continuous_batching.py b/tests/test_mllm_continuous_batching.py index 28b26b219..3ff525dc6 100644 --- a/tests/test_mllm_continuous_batching.py +++ b/tests/test_mllm_continuous_batching.py @@ -129,6 +129,53 @@ def test_finished_response(self): assert resp.finish_reason == "stop" + def test_error_response_skips_decoding(self): + """Error responses must not decode token=0 as content.""" + from unittest.mock import MagicMock, PropertyMock + + from vllm_mlx.mllm_batch_generator import MLLMBatchResponse + from vllm_mlx.mllm_scheduler import MLLMScheduler + from vllm_mlx.request import RequestStatus + + # Build a minimal scheduler with mocked internals + scheduler = MLLMScheduler.__new__(MLLMScheduler) + scheduler._detokenizer_pool = {} + scheduler.uid_to_request_id = {0: "req-err"} + scheduler.total_completion_tokens = 0 + scheduler.num_requests_processed = 0 + + mock_tokenizer = MagicMock() + mock_tokenizer.decode.return_value = "" + mock_processor = MagicMock() + mock_processor.tokenizer = mock_tokenizer + scheduler.processor = mock_processor + + # Create a running request + mock_request = MagicMock() + mock_request.request_id = "req-err" + mock_request.output_tokens = [] + mock_request.num_output_tokens = 0 + mock_request.num_prompt_tokens = 10 + mock_request.status = RequestStatus.RUNNING + scheduler.running = {"req-err": mock_request} + + error_resp = MLLMBatchResponse( + uid=0, + request_id="req-err", + token=0, + logprobs=mx.array([0.0]), + 
finish_reason="error", + ) + + outputs, finished = scheduler._process_batch_responses([error_resp]) + + assert "req-err" in finished + assert mock_request.status == RequestStatus.FINISHED_ABORTED + # token=0 should not have been decoded through a detokenizer + assert "req-err" not in scheduler._detokenizer_pool + assert len(outputs) == 1 + assert outputs[0].new_text == "" + class TestMLLMBatch: """Tests for MLLMBatch class.""" diff --git a/tests/test_reasoning_parser.py b/tests/test_reasoning_parser.py index e2d0184e7..4bcb5ab3f 100644 --- a/tests/test_reasoning_parser.py +++ b/tests/test_reasoning_parser.py @@ -6,6 +6,7 @@ - Parser registry (registration, lookup, listing) - Qwen3 parser (non-streaming and streaming) - DeepSeek-R1 parser (non-streaming and streaming) +- Gemma 4 parser (channel protocol, streaming, channel name stripping) - Edge cases (no tags, partial tags, etc.) """ @@ -28,6 +29,7 @@ def test_list_parsers_includes_builtin(self): parsers = list_parsers() assert "qwen3" in parsers assert "deepseek_r1" in parsers + assert "gemma4" in parsers def test_get_parser_qwen3(self): """Should be able to get Qwen3 parser.""" @@ -920,3 +922,267 @@ def test_constrain_tokens_stripped(self, parser): reasoning, content = parser.extract_reasoning(output) assert "<|constrain|>" not in (content or "") assert "<|channel|>" not in (content or "") + + +class TestGemma4Parser: + """Tests for the Gemma 4 reasoning parser (channel-based protocol).""" + + @pytest.fixture + def parser(self): + """Create a fresh Gemma 4 parser for each test.""" + return get_parser("gemma4")() + + # --- Non-streaming tests --- + + def test_extract_standard_format(self, parser): + """Standard format: <|channel>thought...response.""" + output = ( + "<|channel>thought\nLet me think step by step.\nThe answer is 42." + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Let me think step by step." + assert content == "The answer is 42." + + def test_extract_alternative_format(self, parser): + """Alternative format: <|channel>thought...<|channel>response...""" + output = "<|channel>thought\nAnalyzing the problem.\n<|channel>response\nThe result is 7." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Analyzing the problem." + assert content == "The result is 7." + + def test_extract_strips_thought_prefix(self, parser): + """Channel name 'thought' should be stripped from reasoning.""" + output = "<|channel>thought\nActual reasoning hereContent" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Actual reasoning here" + assert "thought" not in reasoning + + def test_extract_no_tags_pure_content(self, parser): + """No channel tags at all should return pure content.""" + output = "Just a regular response without thinking." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_extract_only_start_tag(self, parser): + """Only start tag means incomplete reasoning (no content yet).""" + output = "<|channel>thought\nStill thinking..." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Still thinking..." 
+ assert content is None + + def test_extract_only_end_tag(self, parser): + """Only end tag (think injected in prompt).""" + output = "thought\nImplicit reasoningThe answer" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Implicit reasoning" + assert content == "The answer" + + def test_extract_empty_reasoning(self, parser): + """Empty reasoning should return None.""" + output = "<|channel>thought\nOnly content here." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == "Only content here." + + def test_extract_multiline_reasoning(self, parser): + """Should preserve multiline reasoning content.""" + output = ( + "<|channel>thought\n" + "Step 1: Understand the question.\n" + "Step 2: Analyze the data.\n" + "Step 3: Form conclusion.\n" + "The conclusion is clear." + ) + reasoning, content = parser.extract_reasoning(output) + assert "Step 1" in reasoning + assert "Step 2" in reasoning + assert "Step 3" in reasoning + assert content == "The conclusion is clear." + + def test_extract_unicode_reasoning(self, parser): + """Should handle Unicode in reasoning.""" + output = "<|channel>thought\n日本語テスト 🤔\n答えは42" + reasoning, content = parser.extract_reasoning(output) + assert "日本語テスト" in reasoning + assert "🤔" in reasoning + assert "42" in content + + def test_registry_includes_gemma4(self): + """gemma4 should be in the parser registry.""" + assert "gemma4" in list_parsers() + + # --- Streaming tests --- + + def test_streaming_no_tags_plain_content(self, parser): + """Streaming without any channel tags should return content.""" + parser.reset_state() + result = parser.extract_reasoning_streaming("", "Hello", "Hello") + assert result is not None + assert result.content == "Hello" + assert result.reasoning is None + + def test_streaming_standard_format(self, parser): + """Test streaming through <|channel>thought...content flow.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Let me ", + "think.", + "", + "The ", + "answer.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + # "thought\n" prefix should be stripped + assert "thought" not in full_reasoning or "thought" in "Let me think." + assert "Let me think." in full_reasoning + assert "The answer." 
in full_content + + def test_streaming_alternative_format(self, parser): + """Test streaming with <|channel>response transition.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Analyzing.", + "<|channel>response", + "\n", + "Result: ", + "42", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_content = "".join(content_parts) + assert "Result: 42" in full_content + + def test_streaming_suppresses_channel_names(self, parser): + """Channel names 'thought' and 'response' should not appear in output.""" + parser.reset_state() + + # Simulate realistic Gemma 4 output + tokens = [ + "<|channel>", + "thought", + "\n", + "Real ", + "reasoning.", + "", + "Real ", + "content.", + ] + + accumulated = "" + all_output = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + all_output.append(("r", result.reasoning)) + if result.content: + all_output.append(("c", result.content)) + + # Verify no raw "thought" token leaked as reasoning + reasoning_text = "".join(t for tag, t in all_output if tag == "r") + content_text = "".join(t for tag, t in all_output if tag == "c") + + assert "Real reasoning." in reasoning_text + assert "Real content." in content_text + + def test_streaming_token_by_token(self, parser): + """Test character-by-character streaming (worst case).""" + parser.reset_state() + + output = "<|channel>thought\nStep 1: Think\nStep 2: Analyze\nFinal answer: 42." + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for char in output: + prev = accumulated + accumulated += char + result = parser.extract_reasoning_streaming(prev, accumulated, char) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + assert "Step 1: Think" in full_reasoning + assert "Step 2: Analyze" in full_reasoning + assert "Final answer: 42." 
in full_content + + def test_streaming_long_thinking_no_end_tag(self, parser): + """When model generates long thinking without end tag, all goes to reasoning.""" + parser.reset_state() + + # Simulate model that hits max_tokens before + tokens = [ + "<|channel>", + "thought", + "\n", + "This is a very long ", + "reasoning process ", + "that continues ", + "without ending.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + assert "very long reasoning process" in full_reasoning + assert len(content_parts) == 0 diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 9fdbfef13..6dea67150 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -339,6 +339,8 @@ def flush(self) -> list[tuple[str, str]]: "PaliGemma", # PaliGemma "gemma-3", "gemma3", # Gemma 3 (multimodal) + "gemma-4", + "gemma4", # Gemma 4 (multimodal: vision + audio) "medgemma", "MedGemma", # MedGemma (medical multimodal with SigLIP vision encoder) "pixtral", diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 3ac52b4b0..e47cd4fc6 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -89,23 +89,23 @@ class MLLMModelWrapper: but MLLM models return LanguageModelOutput objects. This wrapper extracts the logits from the output. - Also handles Gemma 3's required pixel_values argument by injecting None + Also handles Gemma 3/4's required pixel_values argument by injecting None for text-only requests. """ def __init__(self, model): self._model = model - # Detect if this is a Gemma 3 model (requires pixel_values as positional arg) - self._is_gemma3 = ( - hasattr(model, "model_type") - and "gemma3" in str(getattr(model, "model_type", "")).lower() + # Detect if this is a Gemma 3/4 model (requires pixel_values as positional arg) + model_type = str(getattr(model, "model_type", "")).lower() + self._is_gemma_multimodal = hasattr(model, "model_type") and ( + "gemma3" in model_type or "gemma4" in model_type ) def __call__(self, *args, **kwargs): """Call the model and extract logits from LanguageModelOutput.""" - # Gemma 3 requires pixel_values as a positional argument, unlike Qwen + # Gemma 3/4 requires pixel_values as a positional argument, unlike Qwen # which makes it optional. Inject pixel_values=None for text-only requests. 
- if self._is_gemma3 and "pixel_values" not in kwargs: + if self._is_gemma_multimodal and "pixel_values" not in kwargs: kwargs["pixel_values"] = None output = self._model(*args, **kwargs) diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index ee8d8da7b..a8845c5e8 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -324,6 +324,11 @@ def __init__( "MLLMBatchGenerator: Model does not have language_model, using model directly" ) + # Patch attention for BatchKVCache compatibility + from .patches.gemma4_mllm import patch_gemma4_attention_for_batching + + patch_gemma4_attention_for_batching() + self.max_tokens = max_tokens self.stop_tokens = stop_tokens or set() self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) @@ -340,6 +345,9 @@ def __init__( # Statistics self._stats = MLLMBatchStats() + # Error responses for requests that failed during preprocessing + self._pending_error_responses: List[MLLMBatchResponse] = [] + # Vision embedding cache for repeated images self.vision_cache = VisionEmbeddingCache( max_pixel_entries=vision_cache_size, @@ -666,16 +674,37 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: # KVCache.merge() creates a BatchKVCache with proper left-padding # alignment, so all requests share a single batched cache for # subsequent generation steps. - from mlx_lm.models.cache import KVCache + from mlx_lm.models.cache import KVCache, RotatingKVCache sample_cache = per_request_caches[0][0] - if not isinstance(sample_cache, KVCache): + if not isinstance(sample_cache, (KVCache, RotatingKVCache)): raise ValueError( - f"MLLM continuous batching requires standard KVCache but got " - f"{type(sample_cache).__name__}. Disable --kv-cache-quantization " - f"when using multimodal models with --continuous-batching." + f"MLLM continuous batching requires KVCache or RotatingKVCache " + f"but got {type(sample_cache).__name__}. Disable " + f"--kv-cache-quantization when using multimodal models with " + f"--continuous-batching." ) + # Fix: RotatingKVCache._update_concat does NOT trim on first call — + # if prompt length > max_size, the buffer grows beyond max_size. + # BatchRotatingKVCache.merge() then hits a shape mismatch when + # copying via _temporal_order (full buffer) into a max_size slice. + # Trim buffer to max_size before merging. 
+ for rc in per_request_caches: + for layer_cache in rc: + if isinstance(layer_cache, RotatingKVCache): + if layer_cache.keys is not None: + buf_len = layer_cache.keys.shape[2] + if buf_len > layer_cache.max_size: + trim_size = buf_len - layer_cache.max_size + layer_cache.keys = layer_cache._trim( + trim_size, layer_cache.keys + ) + layer_cache.values = layer_cache._trim( + trim_size, layer_cache.values + ) + layer_cache._idx = layer_cache.max_size + try: batch_cache = [ per_request_caches[0][layer_idx].merge( @@ -764,15 +793,40 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None return [] - new_batch = self._process_prompts(requests) - self.unprocessed_requests = self.unprocessed_requests[len(requests) :] - self.active_batch = new_batch - prompt_processing = True + try: + new_batch = self._process_prompts(requests) + self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + self.active_batch = new_batch + prompt_processing = True + except Exception as e: + logger.error( + f"Failed to process batch of {len(requests)} prompts: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + # Remove failed requests to avoid infinite retry loop + self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + for req in requests: + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + # Collect any pending error responses (from failed preprocessing) + error_responses = [] + if self._pending_error_responses: + error_responses = list(self._pending_error_responses) + self._pending_error_responses.clear() # Generate next token for active batch batch = self.active_batch if batch is None: - return [] + return error_responses y, logprobs = batch.y, batch.logprobs batch.y, batch.logprobs = self._step(y[:, None], batch.cache) @@ -841,7 +895,7 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None self._stats.generation_tokens += len(responses) - return responses + return error_responses + responses def next(self) -> List[MLLMBatchResponse]: """ diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 555b230f2..945992045 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -219,7 +219,7 @@ def __init__( self.total_completion_tokens = 0 def _get_stop_tokens(self) -> Set[int]: - """Get stop token IDs from tokenizer.""" + """Get stop token IDs from tokenizer and generation_config.json.""" stop_tokens = set() tokenizer = ( self.processor.tokenizer @@ -239,6 +239,25 @@ def _get_stop_tokens(self) -> Set[int]: else: stop_tokens.add(tokenizer.eos_token_ids) + # Also read generation_config.json which may have additional EOS tokens + # (e.g., Gemma 4 has =106, <|tool_response>=50 as EOS) + model_path = getattr(tokenizer, "name_or_path", None) + if model_path: + import json + from pathlib import Path + + gc_path = Path(model_path) / "generation_config.json" + if gc_path.exists(): + try: + gc = json.loads(gc_path.read_text()) + gc_eos = gc.get("eos_token_id") + if isinstance(gc_eos, list): + stop_tokens.update(gc_eos) + elif gc_eos is not None: + stop_tokens.add(gc_eos) + except Exception: + pass + return stop_tokens def _ensure_batch_generator(self) -> None: @@ -458,8 +477,8 @@ def _process_batch_responses( request.num_output_tokens = len(request.output_tokens) # Decode the new token using streaming detokenizer (UTF-8 safe). - # Skip stop tokens — they are not content. 
- if response.finish_reason == "stop": + # Skip stop tokens and error placeholders — they are not content. + if response.finish_reason in ("stop", "error"): new_text = "" else: if request_id not in self._detokenizer_pool: @@ -489,6 +508,8 @@ def _process_batch_responses( request.status = RequestStatus.FINISHED_STOPPED elif response.finish_reason == "length": request.status = RequestStatus.FINISHED_LENGTH_CAPPED + elif response.finish_reason == "error": + request.status = RequestStatus.FINISHED_ABORTED output.finished = True output.finish_reason = response.finish_reason diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index fcf3537f4..a6c67226e 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -2091,6 +2091,8 @@ def is_mllm_model(model_name: str) -> bool: "PaliGemma", "gemma-3", "gemma3", # Gemma 3 (multimodal) + "gemma-4", + "gemma4", # Gemma 4 (multimodal: vision + audio) "medgemma", "MedGemma", # MedGemma (medical multimodal) "pixtral", diff --git a/vllm_mlx/multimodal_processor.py b/vllm_mlx/multimodal_processor.py index a5c861216..2905e9abb 100644 --- a/vllm_mlx/multimodal_processor.py +++ b/vllm_mlx/multimodal_processor.py @@ -147,7 +147,7 @@ def process( logger.warning(f"Failed to process video: {e}") # Determine add_special_tokens based on model type - if self.config and self.config.model_type in ["gemma3", "gemma3n"]: + if self.config and self.config.model_type in ["gemma3", "gemma3n", "gemma4"]: add_special_tokens = not hasattr(self.processor, "chat_template") # Prepare inputs using mlx_vlm diff --git a/vllm_mlx/patches/gemma4_mllm.py b/vllm_mlx/patches/gemma4_mllm.py new file mode 100644 index 000000000..dc041cf31 --- /dev/null +++ b/vllm_mlx/patches/gemma4_mllm.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime patch for mlx-vlm's Gemma 4 attention to support BatchKVCache. + +Gemma 4 Attention reads cache.offset into a local variable before calling +update_and_fetch, then uses the same variable later for RoPE on queries: + + offset = cache.offset # reference to mx.array([22]) + keys = self.rope(keys, offset=offset) + keys, values = cache.update_and_fetch(keys, values) + # ^^^ self.offset += 1 mutates the SAME mx.array in-place! + queries = self.rope(queries, offset=offset) # offset is now 23! + +For KVCache, cache.offset is a Python int (immutable), so the local copy +is unaffected. For BatchKVCache, cache.offset is an mx.array and +mx.array.__iadd__ is *in-place*, so the local reference is silently +mutated by update_and_fetch, giving queries the wrong RoPE position. + +This patch replaces Gemma4 Attention.__call__ with a version that +snapshots cache.offset as a defensive copy before any mutation can occur. +The mx.array copy preserves per-sequence offsets needed for correct RoPE +in continuous batching (unlike int conversion which would lose this info). +""" + +import logging +from typing import Any, Optional + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +def _snapshot_cache_offset(cache): + """Snapshot cache offset, making a defensive copy if it's an mx.array. + + BatchKVCache stores offset as mx.array (per-batch-item). + mx.array.__iadd__ is in-place, so update_and_fetch mutates the original. + We return a copy to preserve the pre-update value for RoPE on queries. 
+ """ + if cache is None: + return 0 + off = cache.offset + if isinstance(off, int): + return off + if isinstance(off, mx.array): + return off + 0 # defensive copy — new array, same values + return off + + +def patch_gemma4_attention_for_batching() -> bool: + """Monkey-patch Gemma4 Attention.__call__ to snapshot offset before update. + + Returns True if patch was applied, False if mlx-vlm is not installed + or Gemma 4 module not available. + """ + try: + from mlx_vlm.models.gemma4.language import Attention as Gemma4Attention + from mlx_vlm.models.base import scaled_dot_product_attention + except ImportError: + logger.debug("[Gemma4 patch] mlx-vlm Gemma4 module not available") + return False + + if getattr(Gemma4Attention, "_batch_patched", False): + logger.debug("[Gemma4 patch] Already patched") + return True + + _orig_call = Gemma4Attention.__call__ + + def _patched_call( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, _ = x.shape + + queries = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim) + queries = self.q_norm(queries) + + # Snapshot offset BEFORE update_and_fetch can mutate it in-place. + # Preserves per-sequence mx.array offsets for correct batched RoPE. + offset = _snapshot_cache_offset(cache) + + if self.is_kv_shared_layer and cache is not None: + state = cache.state + keys, values = state[0], state[1] + else: + keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + if self.use_k_eq_v: + values = keys + else: + values = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + keys = self.k_norm(keys) + values = self.v_norm(values) + values = values.transpose(0, 2, 1, 3) + + keys = keys.transpose(0, 2, 1, 3) + keys = self.rope(keys, offset=offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + queries = queries.transpose(0, 2, 1, 3) + queries = self.rope(queries, offset=offset) + + if mask is not None and isinstance(mask, mx.array): + if mask.shape[-1] != keys.shape[-2]: + mask = mask[..., -keys.shape[-2] :] + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + Gemma4Attention.__call__ = _patched_call + Gemma4Attention._batch_patched = True + logger.info("[Gemma4 patch] Attention patched for BatchKVCache support") + return True diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py index f138796ff..49d13a26b 100644 --- a/vllm_mlx/reasoning/__init__.py +++ b/vllm_mlx/reasoning/__init__.py @@ -76,6 +76,7 @@ def list_parsers() -> list[str]: def _register_builtin_parsers(): """Register built-in parsers.""" from .deepseek_r1_parser import DeepSeekR1ReasoningParser + from .gemma4_parser import Gemma4ReasoningParser from .gpt_oss_parser import GptOssReasoningParser from .harmony_parser import HarmonyReasoningParser from .qwen3_parser import Qwen3ReasoningParser @@ -84,6 +85,7 @@ def _register_builtin_parsers(): register_parser("deepseek_r1", DeepSeekR1ReasoningParser) register_parser("gpt_oss", GptOssReasoningParser) register_parser("harmony", HarmonyReasoningParser) + register_parser("gemma4", Gemma4ReasoningParser) # Register built-in parsers on module load diff --git a/vllm_mlx/reasoning/gemma4_parser.py b/vllm_mlx/reasoning/gemma4_parser.py new file mode 100644 index 000000000..8b6dd8149 --- /dev/null +++ b/vllm_mlx/reasoning/gemma4_parser.py @@ -0,0 +1,170 @@ +# 
SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser for Gemma 4 models. + +Gemma 4 uses a channel-based protocol for reasoning: + + <|channel>thought + ...thinking content... + + ...response content... + +Where: + <|channel> = token 100 (channel switch marker) + = token 101 (end-of-channel marker) + +The channel names "thought" and "response" appear as text after the +special tokens and should be stripped from the output. + +Some model variants may use <|channel>response instead of +to transition from thinking to response mode. This parser handles both. + +When thinking is disabled or not triggered, output contains no tags. +""" + +from .base import DeltaMessage +from .think_parser import BaseThinkingReasoningParser + +# Channel names that follow <|channel> — stripped from output +_THOUGHT_PREFIX = "thought" +_RESPONSE_MARKER = "<|channel>response" + + +def _strip_channel_name(text: str, prefix: str) -> str: + """Strip channel name and leading whitespace/newline from text start.""" + if text.startswith(prefix): + text = text[len(prefix) :] + return text.lstrip("\n") + + +class Gemma4ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for Gemma 4 models. + + Handles two transition formats: + 1. <|channel>thought...response (standard: token 100 + 101) + 2. <|channel>thought...<|channel>response (alternative: token 100 + 100) + + Channel names ("thought", "response") are stripped from output. + + Example: + Input: "<|channel>thought\\nLet me think...The answer is 42." + Output: reasoning="Let me think...", content="The answer is 42." + + When no tags are present, the entire output is treated as content. + """ + + @property + def start_token(self) -> str: + return "<|channel>" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning from complete output. + + Handles both and <|channel>response as transition markers. + Strips channel names ("thought", "response") from output. + """ + text = model_output + + # Try standard format first: <|channel>thought...response + if self.start_token in text and self.end_token in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Try alternative format: <|channel>thought...<|channel>response... 
+ if text.count(self.start_token) >= 2 and _RESPONSE_MARKER in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(_RESPONSE_MARKER) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.lstrip("\n").strip() + return reasoning or None, content or None + + # Only closing tag (think injected in prompt) + if self.end_token in text: + reasoning, _, content = text.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Only start tag (incomplete reasoning, no end yet) + if self.start_token in text: + _, _, reasoning = text.partition(self.start_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + return reasoning or None, None + + # No tags at all — pure content + return None, model_output + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """ + Extract reasoning from streaming delta. + + Handles: + - No tags: treat as content (Gemma 4 doesn't inject tags in prompt) + - <|channel>thought: enter reasoning mode, strip channel name + - or <|channel>response: transition to content mode + """ + # No channel tokens at all — plain content + if self.start_token not in current_text and self.end_token not in current_text: + return DeltaMessage(content=delta_text) + + # Check for alternative transition: <|channel>response + if _RESPONSE_MARKER in current_text: + if _RESPONSE_MARKER not in previous_text: + # Transition happening in this delta + # Find what (if any) content comes after the marker + marker_pos = current_text.find(_RESPONSE_MARKER) + after_marker = current_text[marker_pos + len(_RESPONSE_MARKER) :] + after_marker = after_marker.lstrip("\n") + if after_marker: + return DeltaMessage(content=after_marker) + return None # Suppress the marker itself + else: + # Already past transition — pure content + # But we need to only emit the NEW text (delta) + return DeltaMessage(content=delta_text) + + # Delegate to base class for standard <|channel>/ handling + result = super().extract_reasoning_streaming( + previous_text, current_text, delta_text + ) + + # Strip "thought" channel name from initial reasoning + if result is not None and result.reasoning is not None: + r = result.reasoning + # First reasoning delta after <|channel> will be "thought" or "thought\n" + if self.start_token in current_text: + # Check if this is the very first reasoning content + after_channel = current_text.split(self.start_token, 1)[1] + if after_channel.startswith(_THOUGHT_PREFIX): + # Remove "thought" prefix from the accumulated reasoning so far + clean = after_channel[len(_THOUGHT_PREFIX) :].lstrip("\n") + # Compute what portion of clean text is in this delta + prev_after = "" + if self.start_token in previous_text: + prev_after = previous_text.split(self.start_token, 1)[1] + if prev_after.startswith(_THOUGHT_PREFIX): + prev_after = prev_after[len(_THOUGHT_PREFIX) :].lstrip("\n") + # The new reasoning text is clean minus what was already emitted + new_reasoning = clean[len(prev_after) :] + if new_reasoning: + return DeltaMessage(reasoning=new_reasoning) + return None # Suppress channel name token + + return result diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index a50883951..0cb4b5d82 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -67,6 +67,8 @@ def 
load_model_with_fallback(model_name: str, tokenizer_config: dict = None): return _load_strict_false(model_name, tokenizer_config) raise + return model, tokenizer + def _load_strict_false(model_name: str, tokenizer_config: dict = None): """Load model with strict=False to discard extra weights (e.g., vision tower, MTP)."""
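Below, for reference, is a minimal sketch of how the new "gemma4" reasoning parser is exercised; it mirrors the non-streaming and streaming tests added in this patch. The import path vllm_mlx.reasoning is assumed from the package layout shown above, and the example uses the <|channel>response transition form because that marker appears verbatim in the parser.

from vllm_mlx.reasoning import get_parser

parser = get_parser("gemma4")()  # registry name added by this patch

# Non-streaming: alternative transition <|channel>thought ... <|channel>response ...
output = "<|channel>thought\nAnalyzing the problem.\n<|channel>response\nThe result is 7."
reasoning, content = parser.extract_reasoning(output)
# reasoning == "Analyzing the problem."  (channel name "thought" stripped)
# content   == "The result is 7."

# Streaming: same loop pattern as the new streaming tests
parser.reset_state()
accumulated = ""
for token in ["<|channel>", "thought", "\n", "Analyzing.",
              "<|channel>response", "\n", "Result: ", "42"]:
    prev, accumulated = accumulated, accumulated + token
    delta = parser.extract_reasoning_streaming(prev, accumulated, token)
    if delta and delta.reasoning:
        pass  # reasoning deltas; the "thought" channel name is already stripped
    if delta and delta.content:
        pass  # content deltas; "Result: 42" arrives here, markers suppressed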
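A small illustration of the name-pattern detection extended in vllm_mlx/models/mllm.py; the full model id is an assumption based on the mlx-community entry added to docs/reference/models.md.

from vllm_mlx.models.mllm import is_mllm_model

# "gemma-4" / "gemma4" are now among the multimodal name patterns
assert is_mllm_model("mlx-community/gemma-4-e2b-it-mxfp4")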
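And a toy demonstration of the aliasing hazard that vllm_mlx/patches/gemma4_mllm.py works around, assuming, as its docstring states, that BatchKVCache stores offset as an mx.array and that __iadd__ updates it in place:

import mlx.core as mx

offset = mx.array([22])   # stands in for BatchKVCache.offset
local = offset            # the local reference Gemma 4 attention keeps
safe = offset + 0         # the defensive copy made by _snapshot_cache_offset

offset += 1               # what update_and_fetch effectively does
# If the update is in place, `local` now reads 23 (the wrong RoPE position
# for the queries), while `safe` still reads the pre-update value 22.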