diff --git a/tests/test_mllm_continuous_batching.py b/tests/test_mllm_continuous_batching.py index 28b26b219..13ce30a49 100644 --- a/tests/test_mllm_continuous_batching.py +++ b/tests/test_mllm_continuous_batching.py @@ -225,6 +225,53 @@ def test_tps_zero_time(self): assert stats.prompt_tps == 0 assert stats.generation_tps == 0 + def test_process_prompts_accepts_rotating_kv_cache(self, monkeypatch): + """MLLM prompt prefill should accept native sliding-window caches.""" + import mlx_lm.models.cache as cache_mod + from mlx_lm.models.cache import BatchRotatingKVCache, RotatingKVCache + from vllm_mlx.mllm_batch_generator import MLLMBatchGenerator, MLLMBatchRequest + + class DummyModel: + layers = [object()] + + def fake_preprocess(request): + request.input_ids = mx.array([1, 2, 3], dtype=mx.int32) + request.pixel_values = None + request.attention_mask = None + request.image_grid_thw = None + request.extra_kwargs = {} + + def fake_make_prompt_cache(_model): + cache = RotatingKVCache(max_size=16, keep=0) + cache.keys = mx.arange(8, dtype=mx.float32).reshape(1, 1, 4, 2) + cache.values = (mx.arange(8, dtype=mx.float32) + 1).reshape(1, 1, 4, 2) + cache.offset = 4 + cache._idx = 4 + return [cache] + + generator = MLLMBatchGenerator( + DummyModel(), + MagicMock(), + enable_vision_cache=False, + ) + + monkeypatch.setattr(generator, "_preprocess_request", fake_preprocess) + monkeypatch.setattr(cache_mod, "make_prompt_cache", fake_make_prompt_cache) + monkeypatch.setattr( + generator, + "_run_vision_encoding", + lambda request, cache=None: mx.zeros((1, 3, 8), dtype=mx.float32), + ) + + req = MLLMBatchRequest(uid=0, request_id="req-rot", prompt="Describe this") + batch = generator._process_prompts([req]) + + assert len(batch.cache) == 1 + assert isinstance(batch.cache[0], BatchRotatingKVCache) + assert batch.request_ids == ["req-rot"] + + generator.close() + class TestMLLMSchedulerConfig: """Tests for MLLMSchedulerConfig.""" diff --git a/tests/test_reasoning_parser.py b/tests/test_reasoning_parser.py index e2d0184e7..4bcb5ab3f 100644 --- a/tests/test_reasoning_parser.py +++ b/tests/test_reasoning_parser.py @@ -6,6 +6,7 @@ - Parser registry (registration, lookup, listing) - Qwen3 parser (non-streaming and streaming) - DeepSeek-R1 parser (non-streaming and streaming) +- Gemma 4 parser (channel protocol, streaming, channel name stripping) - Edge cases (no tags, partial tags, etc.) """ @@ -28,6 +29,7 @@ def test_list_parsers_includes_builtin(self): parsers = list_parsers() assert "qwen3" in parsers assert "deepseek_r1" in parsers + assert "gemma4" in parsers def test_get_parser_qwen3(self): """Should be able to get Qwen3 parser.""" @@ -920,3 +922,267 @@ def test_constrain_tokens_stripped(self, parser): reasoning, content = parser.extract_reasoning(output) assert "<|constrain|>" not in (content or "") assert "<|channel|>" not in (content or "") + + +class TestGemma4Parser: + """Tests for the Gemma 4 reasoning parser (channel-based protocol).""" + + @pytest.fixture + def parser(self): + """Create a fresh Gemma 4 parser for each test.""" + return get_parser("gemma4")() + + # --- Non-streaming tests --- + + def test_extract_standard_format(self, parser): + """Standard format: <|channel>thought...response.""" + output = ( + "<|channel>thought\nLet me think step by step.\nThe answer is 42." + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Let me think step by step." + assert content == "The answer is 42." + + def test_extract_alternative_format(self, parser): + """Alternative format: <|channel>thought...<|channel>response...""" + output = "<|channel>thought\nAnalyzing the problem.\n<|channel>response\nThe result is 7." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Analyzing the problem." + assert content == "The result is 7." + + def test_extract_strips_thought_prefix(self, parser): + """Channel name 'thought' should be stripped from reasoning.""" + output = "<|channel>thought\nActual reasoning hereContent" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Actual reasoning here" + assert "thought" not in reasoning + + def test_extract_no_tags_pure_content(self, parser): + """No channel tags at all should return pure content.""" + output = "Just a regular response without thinking." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == output + + def test_extract_only_start_tag(self, parser): + """Only start tag means incomplete reasoning (no content yet).""" + output = "<|channel>thought\nStill thinking..." + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Still thinking..." + assert content is None + + def test_extract_only_end_tag(self, parser): + """Only end tag (think injected in prompt).""" + output = "thought\nImplicit reasoningThe answer" + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Implicit reasoning" + assert content == "The answer" + + def test_extract_empty_reasoning(self, parser): + """Empty reasoning should return None.""" + output = "<|channel>thought\nOnly content here." + reasoning, content = parser.extract_reasoning(output) + assert reasoning is None + assert content == "Only content here." + + def test_extract_multiline_reasoning(self, parser): + """Should preserve multiline reasoning content.""" + output = ( + "<|channel>thought\n" + "Step 1: Understand the question.\n" + "Step 2: Analyze the data.\n" + "Step 3: Form conclusion.\n" + "The conclusion is clear." + ) + reasoning, content = parser.extract_reasoning(output) + assert "Step 1" in reasoning + assert "Step 2" in reasoning + assert "Step 3" in reasoning + assert content == "The conclusion is clear." + + def test_extract_unicode_reasoning(self, parser): + """Should handle Unicode in reasoning.""" + output = "<|channel>thought\n日本語テスト 🤔\n答えは42" + reasoning, content = parser.extract_reasoning(output) + assert "日本語テスト" in reasoning + assert "🤔" in reasoning + assert "42" in content + + def test_registry_includes_gemma4(self): + """gemma4 should be in the parser registry.""" + assert "gemma4" in list_parsers() + + # --- Streaming tests --- + + def test_streaming_no_tags_plain_content(self, parser): + """Streaming without any channel tags should return content.""" + parser.reset_state() + result = parser.extract_reasoning_streaming("", "Hello", "Hello") + assert result is not None + assert result.content == "Hello" + assert result.reasoning is None + + def test_streaming_standard_format(self, parser): + """Test streaming through <|channel>thought...content flow.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Let me ", + "think.", + "", + "The ", + "answer.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + # "thought\n" prefix should be stripped + assert "thought" not in full_reasoning or "thought" in "Let me think." + assert "Let me think." in full_reasoning + assert "The answer." in full_content + + def test_streaming_alternative_format(self, parser): + """Test streaming with <|channel>response transition.""" + parser.reset_state() + + tokens = [ + "<|channel>", + "thought", + "\n", + "Analyzing.", + "<|channel>response", + "\n", + "Result: ", + "42", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_content = "".join(content_parts) + assert "Result: 42" in full_content + + def test_streaming_suppresses_channel_names(self, parser): + """Channel names 'thought' and 'response' should not appear in output.""" + parser.reset_state() + + # Simulate realistic Gemma 4 output + tokens = [ + "<|channel>", + "thought", + "\n", + "Real ", + "reasoning.", + "", + "Real ", + "content.", + ] + + accumulated = "" + all_output = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + all_output.append(("r", result.reasoning)) + if result.content: + all_output.append(("c", result.content)) + + # Verify no raw "thought" token leaked as reasoning + reasoning_text = "".join(t for tag, t in all_output if tag == "r") + content_text = "".join(t for tag, t in all_output if tag == "c") + + assert "Real reasoning." in reasoning_text + assert "Real content." in content_text + + def test_streaming_token_by_token(self, parser): + """Test character-by-character streaming (worst case).""" + parser.reset_state() + + output = "<|channel>thought\nStep 1: Think\nStep 2: Analyze\nFinal answer: 42." + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for char in output: + prev = accumulated + accumulated += char + result = parser.extract_reasoning_streaming(prev, accumulated, char) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + full_content = "".join(content_parts) + + assert "Step 1: Think" in full_reasoning + assert "Step 2: Analyze" in full_reasoning + assert "Final answer: 42." in full_content + + def test_streaming_long_thinking_no_end_tag(self, parser): + """When model generates long thinking without end tag, all goes to reasoning.""" + parser.reset_state() + + # Simulate model that hits max_tokens before + tokens = [ + "<|channel>", + "thought", + "\n", + "This is a very long ", + "reasoning process ", + "that continues ", + "without ending.", + ] + + accumulated = "" + reasoning_parts = [] + content_parts = [] + + for token in tokens: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result: + if result.reasoning: + reasoning_parts.append(result.reasoning) + if result.content: + content_parts.append(result.content) + + full_reasoning = "".join(reasoning_parts) + assert "very long reasoning process" in full_reasoning + assert len(content_parts) == 0 diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index ee8d8da7b..b26a89427 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -324,6 +324,11 @@ def __init__( "MLLMBatchGenerator: Model does not have language_model, using model directly" ) + # Patch attention for BatchKVCache compatibility + from .patches.gemma4_mllm import patch_gemma4_attention_for_batching + + patch_gemma4_attention_for_batching() + self.max_tokens = max_tokens self.stop_tokens = stop_tokens or set() self.sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) @@ -340,6 +345,9 @@ def __init__( # Statistics self._stats = MLLMBatchStats() + # Error responses for requests that failed during preprocessing + self._pending_error_responses: List[MLLMBatchResponse] = [] + # Vision embedding cache for repeated images self.vision_cache = VisionEmbeddingCache( max_pixel_entries=vision_cache_size, @@ -666,16 +674,37 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: # KVCache.merge() creates a BatchKVCache with proper left-padding # alignment, so all requests share a single batched cache for # subsequent generation steps. - from mlx_lm.models.cache import KVCache + from mlx_lm.models.cache import KVCache, RotatingKVCache sample_cache = per_request_caches[0][0] - if not isinstance(sample_cache, KVCache): + if not isinstance(sample_cache, (KVCache, RotatingKVCache)): raise ValueError( - f"MLLM continuous batching requires standard KVCache but got " - f"{type(sample_cache).__name__}. Disable --kv-cache-quantization " - f"when using multimodal models with --continuous-batching." + f"MLLM continuous batching requires standard KVCache or " + f"RotatingKVCache but got {type(sample_cache).__name__}. " + f"Disable --kv-cache-quantization when using multimodal " + f"models with --continuous-batching." ) + # Fix: RotatingKVCache._update_concat does NOT trim on first call — + # if prompt length > max_size, the buffer grows beyond max_size. + # BatchRotatingKVCache.merge() then hits a shape mismatch when + # copying via _temporal_order (full buffer) into a max_size slice. + # Trim buffer to max_size before merging. + for rc in per_request_caches: + for layer_cache in rc: + if isinstance(layer_cache, RotatingKVCache): + if layer_cache.keys is not None: + buf_len = layer_cache.keys.shape[2] + if buf_len > layer_cache.max_size: + trim_size = buf_len - layer_cache.max_size + layer_cache.keys = layer_cache._trim( + trim_size, layer_cache.keys + ) + layer_cache.values = layer_cache._trim( + trim_size, layer_cache.values + ) + layer_cache._idx = layer_cache.max_size + try: batch_cache = [ per_request_caches[0][layer_idx].merge( @@ -764,15 +793,40 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None return [] - new_batch = self._process_prompts(requests) - self.unprocessed_requests = self.unprocessed_requests[len(requests) :] - self.active_batch = new_batch - prompt_processing = True + try: + new_batch = self._process_prompts(requests) + self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + self.active_batch = new_batch + prompt_processing = True + except Exception as e: + logger.error( + f"Failed to process batch of {len(requests)} prompts: " + f"{type(e).__name__}: {e}", + exc_info=True, + ) + # Remove failed requests to avoid infinite retry loop + self.unprocessed_requests = self.unprocessed_requests[len(requests) :] + for req in requests: + self._pending_error_responses.append( + MLLMBatchResponse( + uid=req.uid, + request_id=req.request_id, + token=0, + logprobs=mx.zeros(1), + finish_reason="error", + ) + ) + + # Collect any pending error responses (from failed preprocessing) + error_responses = [] + if self._pending_error_responses: + error_responses = list(self._pending_error_responses) + self._pending_error_responses.clear() # Generate next token for active batch batch = self.active_batch if batch is None: - return [] + return error_responses y, logprobs = batch.y, batch.logprobs batch.y, batch.logprobs = self._step(y[:, None], batch.cache) @@ -841,7 +895,7 @@ def _next(self) -> List[MLLMBatchResponse]: self.active_batch = None self._stats.generation_tokens += len(responses) - return responses + return error_responses + responses def next(self) -> List[MLLMBatchResponse]: """ diff --git a/vllm_mlx/mllm_scheduler.py b/vllm_mlx/mllm_scheduler.py index 555b230f2..9623ca27f 100644 --- a/vllm_mlx/mllm_scheduler.py +++ b/vllm_mlx/mllm_scheduler.py @@ -219,7 +219,7 @@ def __init__( self.total_completion_tokens = 0 def _get_stop_tokens(self) -> Set[int]: - """Get stop token IDs from tokenizer.""" + """Get stop token IDs from tokenizer and generation_config.json.""" stop_tokens = set() tokenizer = ( self.processor.tokenizer @@ -239,6 +239,25 @@ def _get_stop_tokens(self) -> Set[int]: else: stop_tokens.add(tokenizer.eos_token_ids) + # Also read generation_config.json which may have additional EOS tokens + # (e.g., Gemma 4 has =106, <|tool_response>=50 as EOS) + model_path = getattr(tokenizer, "name_or_path", None) + if model_path: + import json + from pathlib import Path + + gc_path = Path(model_path) / "generation_config.json" + if gc_path.exists(): + try: + gc = json.loads(gc_path.read_text()) + gc_eos = gc.get("eos_token_id") + if isinstance(gc_eos, list): + stop_tokens.update(gc_eos) + elif gc_eos is not None: + stop_tokens.add(gc_eos) + except Exception: + pass + return stop_tokens def _ensure_batch_generator(self) -> None: diff --git a/vllm_mlx/patches/gemma4_mllm.py b/vllm_mlx/patches/gemma4_mllm.py new file mode 100644 index 000000000..dc041cf31 --- /dev/null +++ b/vllm_mlx/patches/gemma4_mllm.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime patch for mlx-vlm's Gemma 4 attention to support BatchKVCache. + +Gemma 4 Attention reads cache.offset into a local variable before calling +update_and_fetch, then uses the same variable later for RoPE on queries: + + offset = cache.offset # reference to mx.array([22]) + keys = self.rope(keys, offset=offset) + keys, values = cache.update_and_fetch(keys, values) + # ^^^ self.offset += 1 mutates the SAME mx.array in-place! + queries = self.rope(queries, offset=offset) # offset is now 23! + +For KVCache, cache.offset is a Python int (immutable), so the local copy +is unaffected. For BatchKVCache, cache.offset is an mx.array and +mx.array.__iadd__ is *in-place*, so the local reference is silently +mutated by update_and_fetch, giving queries the wrong RoPE position. + +This patch replaces Gemma4 Attention.__call__ with a version that +snapshots cache.offset as a defensive copy before any mutation can occur. +The mx.array copy preserves per-sequence offsets needed for correct RoPE +in continuous batching (unlike int conversion which would lose this info). +""" + +import logging +from typing import Any, Optional + +import mlx.core as mx + +logger = logging.getLogger(__name__) + + +def _snapshot_cache_offset(cache): + """Snapshot cache offset, making a defensive copy if it's an mx.array. + + BatchKVCache stores offset as mx.array (per-batch-item). + mx.array.__iadd__ is in-place, so update_and_fetch mutates the original. + We return a copy to preserve the pre-update value for RoPE on queries. + """ + if cache is None: + return 0 + off = cache.offset + if isinstance(off, int): + return off + if isinstance(off, mx.array): + return off + 0 # defensive copy — new array, same values + return off + + +def patch_gemma4_attention_for_batching() -> bool: + """Monkey-patch Gemma4 Attention.__call__ to snapshot offset before update. + + Returns True if patch was applied, False if mlx-vlm is not installed + or Gemma 4 module not available. + """ + try: + from mlx_vlm.models.gemma4.language import Attention as Gemma4Attention + from mlx_vlm.models.base import scaled_dot_product_attention + except ImportError: + logger.debug("[Gemma4 patch] mlx-vlm Gemma4 module not available") + return False + + if getattr(Gemma4Attention, "_batch_patched", False): + logger.debug("[Gemma4 patch] Already patched") + return True + + _orig_call = Gemma4Attention.__call__ + + def _patched_call( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, _ = x.shape + + queries = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim) + queries = self.q_norm(queries) + + # Snapshot offset BEFORE update_and_fetch can mutate it in-place. + # Preserves per-sequence mx.array offsets for correct batched RoPE. + offset = _snapshot_cache_offset(cache) + + if self.is_kv_shared_layer and cache is not None: + state = cache.state + keys, values = state[0], state[1] + else: + keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + if self.use_k_eq_v: + values = keys + else: + values = self.v_proj(x).reshape(B, L, self.n_kv_heads, self.head_dim) + + keys = self.k_norm(keys) + values = self.v_norm(values) + values = values.transpose(0, 2, 1, 3) + + keys = keys.transpose(0, 2, 1, 3) + keys = self.rope(keys, offset=offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + queries = queries.transpose(0, 2, 1, 3) + queries = self.rope(queries, offset=offset) + + if mask is not None and isinstance(mask, mx.array): + if mask.shape[-1] != keys.shape[-2]: + mask = mask[..., -keys.shape[-2] :] + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + Gemma4Attention.__call__ = _patched_call + Gemma4Attention._batch_patched = True + logger.info("[Gemma4 patch] Attention patched for BatchKVCache support") + return True diff --git a/vllm_mlx/reasoning/__init__.py b/vllm_mlx/reasoning/__init__.py index f138796ff..49d13a26b 100644 --- a/vllm_mlx/reasoning/__init__.py +++ b/vllm_mlx/reasoning/__init__.py @@ -76,6 +76,7 @@ def list_parsers() -> list[str]: def _register_builtin_parsers(): """Register built-in parsers.""" from .deepseek_r1_parser import DeepSeekR1ReasoningParser + from .gemma4_parser import Gemma4ReasoningParser from .gpt_oss_parser import GptOssReasoningParser from .harmony_parser import HarmonyReasoningParser from .qwen3_parser import Qwen3ReasoningParser @@ -84,6 +85,7 @@ def _register_builtin_parsers(): register_parser("deepseek_r1", DeepSeekR1ReasoningParser) register_parser("gpt_oss", GptOssReasoningParser) register_parser("harmony", HarmonyReasoningParser) + register_parser("gemma4", Gemma4ReasoningParser) # Register built-in parsers on module load diff --git a/vllm_mlx/reasoning/gemma4_parser.py b/vllm_mlx/reasoning/gemma4_parser.py new file mode 100644 index 000000000..8b6dd8149 --- /dev/null +++ b/vllm_mlx/reasoning/gemma4_parser.py @@ -0,0 +1,170 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Reasoning parser for Gemma 4 models. + +Gemma 4 uses a channel-based protocol for reasoning: + + <|channel>thought + ...thinking content... + + ...response content... + +Where: + <|channel> = token 100 (channel switch marker) + = token 101 (end-of-channel marker) + +The channel names "thought" and "response" appear as text after the +special tokens and should be stripped from the output. + +Some model variants may use <|channel>response instead of +to transition from thinking to response mode. This parser handles both. + +When thinking is disabled or not triggered, output contains no tags. +""" + +from .base import DeltaMessage +from .think_parser import BaseThinkingReasoningParser + +# Channel names that follow <|channel> — stripped from output +_THOUGHT_PREFIX = "thought" +_RESPONSE_MARKER = "<|channel>response" + + +def _strip_channel_name(text: str, prefix: str) -> str: + """Strip channel name and leading whitespace/newline from text start.""" + if text.startswith(prefix): + text = text[len(prefix) :] + return text.lstrip("\n") + + +class Gemma4ReasoningParser(BaseThinkingReasoningParser): + """ + Reasoning parser for Gemma 4 models. + + Handles two transition formats: + 1. <|channel>thought...response (standard: token 100 + 101) + 2. <|channel>thought...<|channel>response (alternative: token 100 + 100) + + Channel names ("thought", "response") are stripped from output. + + Example: + Input: "<|channel>thought\\nLet me think...The answer is 42." + Output: reasoning="Let me think...", content="The answer is 42." + + When no tags are present, the entire output is treated as content. + """ + + @property + def start_token(self) -> str: + return "<|channel>" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, + model_output: str, + ) -> tuple[str | None, str | None]: + """ + Extract reasoning from complete output. + + Handles both and <|channel>response as transition markers. + Strips channel names ("thought", "response") from output. + """ + text = model_output + + # Try standard format first: <|channel>thought...response + if self.start_token in text and self.end_token in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Try alternative format: <|channel>thought...<|channel>response... + if text.count(self.start_token) >= 2 and _RESPONSE_MARKER in text: + _, _, after_start = text.partition(self.start_token) + reasoning, _, content = after_start.partition(_RESPONSE_MARKER) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.lstrip("\n").strip() + return reasoning or None, content or None + + # Only closing tag (think injected in prompt) + if self.end_token in text: + reasoning, _, content = text.partition(self.end_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + content = content.strip() + return reasoning or None, content or None + + # Only start tag (incomplete reasoning, no end yet) + if self.start_token in text: + _, _, reasoning = text.partition(self.start_token) + reasoning = _strip_channel_name(reasoning.strip(), _THOUGHT_PREFIX) + return reasoning or None, None + + # No tags at all — pure content + return None, model_output + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """ + Extract reasoning from streaming delta. + + Handles: + - No tags: treat as content (Gemma 4 doesn't inject tags in prompt) + - <|channel>thought: enter reasoning mode, strip channel name + - or <|channel>response: transition to content mode + """ + # No channel tokens at all — plain content + if self.start_token not in current_text and self.end_token not in current_text: + return DeltaMessage(content=delta_text) + + # Check for alternative transition: <|channel>response + if _RESPONSE_MARKER in current_text: + if _RESPONSE_MARKER not in previous_text: + # Transition happening in this delta + # Find what (if any) content comes after the marker + marker_pos = current_text.find(_RESPONSE_MARKER) + after_marker = current_text[marker_pos + len(_RESPONSE_MARKER) :] + after_marker = after_marker.lstrip("\n") + if after_marker: + return DeltaMessage(content=after_marker) + return None # Suppress the marker itself + else: + # Already past transition — pure content + # But we need to only emit the NEW text (delta) + return DeltaMessage(content=delta_text) + + # Delegate to base class for standard <|channel>/ handling + result = super().extract_reasoning_streaming( + previous_text, current_text, delta_text + ) + + # Strip "thought" channel name from initial reasoning + if result is not None and result.reasoning is not None: + r = result.reasoning + # First reasoning delta after <|channel> will be "thought" or "thought\n" + if self.start_token in current_text: + # Check if this is the very first reasoning content + after_channel = current_text.split(self.start_token, 1)[1] + if after_channel.startswith(_THOUGHT_PREFIX): + # Remove "thought" prefix from the accumulated reasoning so far + clean = after_channel[len(_THOUGHT_PREFIX) :].lstrip("\n") + # Compute what portion of clean text is in this delta + prev_after = "" + if self.start_token in previous_text: + prev_after = previous_text.split(self.start_token, 1)[1] + if prev_after.startswith(_THOUGHT_PREFIX): + prev_after = prev_after[len(_THOUGHT_PREFIX) :].lstrip("\n") + # The new reasoning text is clean minus what was already emitted + new_reasoning = clean[len(prev_after) :] + if new_reasoning: + return DeltaMessage(reasoning=new_reasoning) + return None # Suppress channel name token + + return result