diff --git a/tests/test_streaming_pipeline_integration.py b/tests/test_streaming_pipeline_integration.py
new file mode 100644
index 000000000..7ddcf39f8
--- /dev/null
+++ b/tests/test_streaming_pipeline_integration.py
@@ -0,0 +1,240 @@
+"""Integration test for the Anthropic streaming pipeline.
+
+Tests the full flow: raw model output → StreamingToolCallFilter → StreamingThinkRouter
+→ Anthropic SSE events, verifying block transitions, tool call extraction, and
+prompt_tokens tracking work together correctly.
+"""
+
+import json
+import unittest
+
+from vllm_mlx.api.utils import StreamingToolCallFilter, StreamingThinkRouter
+from vllm_mlx.server import _emit_content_pieces
+
+
+class TestEmitContentPieces(unittest.TestCase):
+ """Test the refactored _emit_content_pieces helper."""
+
+ def test_single_text_block(self):
+ events, block_type, index = _emit_content_pieces([("text", "hello")], None, 0)
+ assert len(events) == 2 # block_start + delta
+ assert block_type == "text"
+ assert index == 0
+ # Verify block_start
+ start_data = json.loads(events[0].split("data: ")[1])
+ assert start_data["type"] == "content_block_start"
+ assert start_data["content_block"]["type"] == "text"
+ # Verify delta
+ delta_data = json.loads(events[1].split("data: ")[1])
+ assert delta_data["delta"]["text"] == "hello"
+
+ def test_single_thinking_block(self):
+ events, block_type, index = _emit_content_pieces(
+ [("thinking", "reasoning")], None, 0
+ )
+ assert block_type == "thinking"
+ delta_data = json.loads(events[1].split("data: ")[1])
+ assert delta_data["delta"]["thinking"] == "reasoning"
+
+ def test_transition_thinking_to_text(self):
+ events, block_type, index = _emit_content_pieces(
+ [("thinking", "reason"), ("text", "answer")], None, 0
+ )
+ assert block_type == "text"
+ assert index == 1 # incremented on block transition
+ # Should have: start_thinking, delta_thinking, stop_thinking, start_text, delta_text
+ assert len(events) == 5
+ stop_data = json.loads(events[2].split("data: ")[1])
+ assert stop_data["type"] == "content_block_stop"
+
+ def test_continues_existing_block(self):
+ """If current_block_type matches, no start/stop emitted."""
+ events, block_type, index = _emit_content_pieces([("text", "more")], "text", 0)
+ assert len(events) == 1 # just delta, no start
+ assert block_type == "text"
+
+ def test_empty_pieces(self):
+ events, block_type, index = _emit_content_pieces([], None, 0)
+ assert events == []
+ assert block_type is None
+ assert index == 0
+
+
+class TestStreamingPipelineIntegration(unittest.TestCase):
+ """Integration test for the full streaming pipeline."""
+
+ def _run_pipeline(self, deltas, start_in_thinking=False):
+ """Run deltas through tool_filter → think_router → emit, return events."""
+ tool_filter = StreamingToolCallFilter()
+ think_router = StreamingThinkRouter(start_in_thinking=start_in_thinking)
+ current_block_type = None
+ block_index = 0
+ all_events = []
+ accumulated_text = ""
+
+ for delta in deltas:
+ accumulated_text += delta
+ filtered = tool_filter.process(delta)
+ if not filtered:
+ continue
+ pieces = think_router.process(filtered)
+ events, current_block_type, block_index = _emit_content_pieces(
+ pieces, current_block_type, block_index
+ )
+ all_events.extend(events)
+
+ # Flush
+ remaining = tool_filter.flush()
+ if remaining:
+ pieces = think_router.process(remaining)
+ events, current_block_type, block_index = _emit_content_pieces(
+ pieces, current_block_type, block_index
+ )
+ all_events.extend(events)
+
+ flush_pieces = think_router.flush()
+ if flush_pieces:
+ events, current_block_type, block_index = _emit_content_pieces(
+ flush_pieces, current_block_type, block_index
+ )
+ all_events.extend(events)
+
+ # Close final block
+ if current_block_type is not None:
+ all_events.append(
+ f"event: content_block_stop\ndata: "
+ f"{json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+ )
+ block_index += 1
+
+ return all_events, accumulated_text, block_index
+
+ def _parse_events(self, events):
+ """Parse SSE events into structured data."""
+ parsed = []
+ for event in events:
+ data_line = event.split("data: ", 1)[1].split("\n")[0]
+ parsed.append(json.loads(data_line))
+ return parsed
+
+ def test_pure_text_response(self):
+ """Simple text response - one text block."""
+ events, _, block_index = self._run_pipeline(["Hello ", "world!"])
+ parsed = self._parse_events(events)
+
+ # block_start, 2 deltas, block_stop
+ types = [p["type"] for p in parsed]
+ assert types[0] == "content_block_start"
+ assert parsed[0]["content_block"]["type"] == "text"
+ assert types[-1] == "content_block_stop"
+ assert block_index == 1
+
+ def test_thinking_then_text(self):
+ """Model thinks then responds."""
+ events, _, block_index = self._run_pipeline(
+            ["<think>Let me think", " about this</think>", "The answer is 42"]
+ )
+ parsed = self._parse_events(events)
+
+ block_starts = [p for p in parsed if p["type"] == "content_block_start"]
+ assert len(block_starts) == 2
+ assert block_starts[0]["content_block"]["type"] == "thinking"
+ assert block_starts[1]["content_block"]["type"] == "text"
+ assert block_index == 2
+
+ def test_start_in_thinking_then_text(self):
+ """Model starts in thinking mode (template injects )."""
+ events, _, _ = self._run_pipeline(
+            ["reasoning here", "</think>", "The answer"],
+ start_in_thinking=True,
+ )
+ parsed = self._parse_events(events)
+
+ block_starts = [p for p in parsed if p["type"] == "content_block_start"]
+ assert len(block_starts) == 2
+ assert block_starts[0]["content_block"]["type"] == "thinking"
+ assert block_starts[1]["content_block"]["type"] == "text"
+
+ def test_text_then_tool_call(self):
+ """Text followed by tool call - tool markup suppressed from text."""
+ events, accumulated, _ = self._run_pipeline(
+ [
+ "I'll search for that. ",
+                "<minimax:tool_call>",
+                '<invoke name="bash">',
+                '<parameter name="command">ls /tmp</parameter>',
+                "</invoke>",
+                "</minimax:tool_call>",
+ ]
+ )
+ parsed = self._parse_events(events)
+
+ # Only text block should appear (tool call is suppressed from streaming)
+ text_deltas = [
+ p
+ for p in parsed
+ if p["type"] == "content_block_delta"
+ and p["delta"].get("type") == "text_delta"
+ ]
+ text_content = "".join(d["delta"]["text"] for d in text_deltas)
+ assert "I'll search for that." in text_content
+        assert "<minimax:tool_call>" not in text_content
+
+ # But accumulated text has the full tool call for parsing
+        assert "<minimax:tool_call>" in accumulated
+
+ def test_thinking_then_tool_call(self):
+ """Thinking followed by tool call - both properly routed."""
+ events, accumulated, _ = self._run_pipeline(
+ [
+                "<think>I need to search</think>",
+                "<minimax:tool_call>",
+                '<invoke name="search">',
+                '<parameter name="query">test</parameter>',
+                "</invoke>",
+                "</minimax:tool_call>",
+ ]
+ )
+ parsed = self._parse_events(events)
+
+ block_starts = [p for p in parsed if p["type"] == "content_block_start"]
+ # Only thinking block (tool call is suppressed)
+ assert len(block_starts) == 1
+ assert block_starts[0]["content_block"]["type"] == "thinking"
+
+ def test_mixed_thinking_text_and_tool_call(self):
+ """Full scenario: thinking → text → tool call."""
+ events, accumulated, block_index = self._run_pipeline(
+ [
+                "<think>analyzing request</think>",
+                "Let me help. ",
+                "<minimax:tool_call>",
+                '<invoke name="bash">echo hi</invoke>',
+                "</minimax:tool_call>",
+ ]
+ )
+ parsed = self._parse_events(events)
+
+ block_starts = [p for p in parsed if p["type"] == "content_block_start"]
+ # thinking block + text block (tool call suppressed)
+ assert len(block_starts) == 2
+ assert block_starts[0]["content_block"]["type"] == "thinking"
+ assert block_starts[1]["content_block"]["type"] == "text"
+
+ # Accumulated has everything for post-stream tool parsing
+        assert "<minimax:tool_call>" in accumulated
+
+ def test_block_index_increments_correctly(self):
+ """Block indices should increment on each transition."""
+ events, _, final_index = self._run_pipeline(
+            ["<think>t1</think>text<think>t2</think>end"]
+ )
+ parsed = self._parse_events(events)
+
+ starts = [p for p in parsed if p["type"] == "content_block_start"]
+ assert [s["index"] for s in starts] == [0, 1, 2, 3]
+ assert final_index == 4
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_streaming_think_router.py b/tests/test_streaming_think_router.py
new file mode 100644
index 000000000..d5272cdca
--- /dev/null
+++ b/tests/test_streaming_think_router.py
@@ -0,0 +1,168 @@
+"""Tests for StreamingThinkRouter - routes blocks to Anthropic thinking content blocks."""
+
+import unittest
+
+from vllm_mlx.api.utils import StreamingThinkRouter
+
+
+class TestStreamingThinkRouter(unittest.TestCase):
+ """Unit tests for StreamingThinkRouter."""
+
+ # --- Basic routing ---
+
+ def test_plain_text_routes_as_text(self):
+ r = StreamingThinkRouter()
+ assert r.process("Hello world") == [("text", "Hello world")]
+
+ def test_think_block_routes_as_thinking(self):
+ r = StreamingThinkRouter()
+        assert r.process("<think>reasoning</think>") == [("thinking", "reasoning")]
+
+ def test_text_then_think_then_text(self):
+ r = StreamingThinkRouter()
+        result = r.process("before<think>middle</think>after")
+ assert result == [("text", "before"), ("thinking", "middle"), ("text", "after")]
+
+ # --- start_in_thinking mode ---
+
+ def test_start_in_thinking_mode(self):
+ """When model injects into prompt, output starts in thinking mode."""
+ r = StreamingThinkRouter(start_in_thinking=True)
+ result = r.process("reasoning here")
+ assert result == [("thinking", "reasoning here")]
+
+ def test_start_in_thinking_then_close(self):
+ """Thinking closes with , then text follows."""
+ r = StreamingThinkRouter(start_in_thinking=True)
+        result = r.process("reasoning</think>answer")
+ assert result == [("thinking", "reasoning"), ("text", "answer")]
+
+ def test_start_in_thinking_close_across_deltas(self):
+ """ split across multiple deltas."""
+ r = StreamingThinkRouter(start_in_thinking=True)
+        p1 = r.process("thinking stuff</th")
+        p2 = r.process("ink>now text")
+        # First delta should hold back the partial close tag
+        assert ("thinking", "thinking stuff") in p1
+        # Second delta should complete the tag and transition
+        all_pieces = p1 + p2
+ types = [t for t, _ in all_pieces]
+ assert "text" in types
+
+ # --- Partial tag handling ---
+
+    def test_partial_open_tag_held_back(self):
+        """Partial <think> tag at a delta boundary is held back."""
+        r = StreamingThinkRouter()
+        p1 = r.process("Hello <th")
+        p2 = r.process("ink>reasoning")
+        # p1 should emit "Hello " but hold back "<th"
+        assert ("text", "Hello ") in p1
+        # p2 should complete the tag and emit thinking content
+        assert ("thinking", "reasoning") in p2
+
+    def test_partial_close_tag_held_back(self):
+        """Partial </think> tag at a delta boundary is held back."""
+        r = StreamingThinkRouter(start_in_thinking=True)
+        p1 = r.process("deep thought</thi")
+        p2 = r.process("nk>answer")
+        # p1 should emit thinking content but hold back the partial tag
+        assert ("thinking", "deep thought") in p1
+        # p2 should transition to text
+        assert ("text", "answer") in p2
+
+ def test_partial_tag_false_alarm(self):
+ """Partial match that turns out not to be a tag."""
+ r = StreamingThinkRouter()
+        p1 = r.process("Hello <t")
+        p2 = r.process("o you")
+        # After p2, the held-back "<t" should emit as text
+        all_text = "".join(t for bt, t in p1 + p2 if bt == "text")
+        assert "Hello <to you" == all_text
+
+ # --- Multiple think blocks ---
+
+ def test_multiple_think_blocks(self):
+ r = StreamingThinkRouter()
+        result = r.process("<think>first</think>middle<think>second</think>end")
+ assert result == [
+ ("thinking", "first"),
+ ("text", "middle"),
+ ("thinking", "second"),
+ ("text", "end"),
+ ]
+
+ # --- Streaming across deltas ---
+
+ def test_streaming_token_by_token(self):
+ """Simulate character-by-character streaming."""
+ r = StreamingThinkRouter()
+        text = "<think>abc</think>xyz"
+ all_pieces = []
+ for ch in text:
+ all_pieces.extend(r.process(ch))
+ all_pieces.extend(r.flush())
+ thinking = "".join(t for bt, t in all_pieces if bt == "thinking")
+ text_out = "".join(t for bt, t in all_pieces if bt == "text")
+ assert thinking == "abc"
+ assert text_out == "xyz"
+
+ def test_streaming_with_start_in_thinking(self):
+ """Token-by-token with start_in_thinking."""
+ r = StreamingThinkRouter(start_in_thinking=True)
+        text = "reasoning</think>the answer"
+ all_pieces = []
+ for ch in text:
+ all_pieces.extend(r.process(ch))
+ all_pieces.extend(r.flush())
+ thinking = "".join(t for bt, t in all_pieces if bt == "thinking")
+ text_out = "".join(t for bt, t in all_pieces if bt == "text")
+ assert thinking == "reasoning"
+ assert text_out == "the answer"
+
+ # --- Flush behavior ---
+
+ def test_flush_emits_remaining_text(self):
+ """Text without partial tags is emitted by process(), flush() is empty."""
+ r = StreamingThinkRouter()
+ pieces = r.process("partial text")
+ assert pieces == [("text", "partial text")]
+ assert r.flush() == []
+
+ def test_flush_emits_remaining_thinking(self):
+ """Thinking without partial tags is emitted by process(), flush() is empty."""
+ r = StreamingThinkRouter(start_in_thinking=True)
+ pieces = r.process("unfinished thought")
+ assert pieces == [("thinking", "unfinished thought")]
+ assert r.flush() == []
+
+    def test_flush_with_held_back_partial(self):
+        """Flush should emit held-back partial tag as content."""
+        r = StreamingThinkRouter()
+        pieces = r.process("text<thi")
+        assert pieces == [("text", "text")]
+        assert r.flush() == [("text", "<thi")]
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_streaming_tool_call_filter.py b/tests/test_streaming_tool_call_filter.py
new file mode 100644
--- /dev/null
+++ b/tests/test_streaming_tool_call_filter.py
@@ -0,0 +1,138 @@
+"""Tests for StreamingToolCallFilter - suppresses tool call markup during streaming."""
+
+import unittest
+
+from vllm_mlx.api.utils import StreamingToolCallFilter
+
+
+class TestStreamingToolCallFilter(unittest.TestCase):
+    """Unit tests for StreamingToolCallFilter."""
+
+    def test_tool_call_split_across_deltas_suppressed(self):
+        f = StreamingToolCallFilter()
+        f.process("<minimax:tool_call>")
+        f.process('<invoke name="read">')
+        f.process('<parameter name="path">/tmp/test.txt</parameter>')
+        f.process("</invoke>")
+        result = f.process("</minimax:tool_call>")
+        assert result == ""
+
+ def test_text_after_tool_call_emits(self):
+ f = StreamingToolCallFilter()
+        f.process("<minimax:tool_call>content</minimax:tool_call>")
+ assert f.process("After") == "After"
+
+ def test_text_before_and_after_same_delta(self):
+ f = StreamingToolCallFilter()
+        result = f.process("Before <minimax:tool_call>inside</minimax:tool_call>After")
+ assert result == "Before After"
+
+ def test_split_across_deltas(self):
+ f = StreamingToolCallFilter()
+        r1 = f.process("Before <minimax:tool")
+        r2 = f.process("_call>inside</minimax:tool_call>After")
+        assert r1 + r2 == "Before After"
+
+ def test_qwen_format_suppressed(self):
+ f = StreamingToolCallFilter()
+        result = f.process('Text <tool_call>{"name":"fn"}</tool_call>more')
+ assert result == "Text more"
+
+ def test_multiple_tool_calls(self):
+ f = StreamingToolCallFilter()
+ result = f.process(
+            "A <minimax:tool_call>x</minimax:tool_call>"
+            "B <minimax:tool_call>y</minimax:tool_call>C"
+ )
+ assert result == "A B C"
+
+    def test_flush_partial_tag_emits(self):
+        """Text before an unterminated tool call block is emitted; the block is discarded on flush."""
+        f = StreamingToolCallFilter()
+        r = f.process("text <minimax:tool_call>partial content")
+        assert r == "text "
+        assert f.flush() == ""
+
+ def test_large_tool_call_content(self):
+ """Simulates a Read tool returning a large file."""
+ f = StreamingToolCallFilter()
+ big = "x" * 10000
+        result = f.process(f"Before <minimax:tool_call>{big}</minimax:tool_call>After")
+ assert result == "Before After"
+
+ def test_think_tags_not_filtered(self):
+ f = StreamingToolCallFilter()
+        result = f.process("<think>reasoning here</think>answer")
+        assert "<think>" in result
+ assert "reasoning here" in result
+
+ def test_mixed_think_and_tool_call(self):
+ f = StreamingToolCallFilter()
+ result = f.process(
+            "<think>thinking</think>"
+            "<minimax:tool_call>tool stuff</minimax:tool_call>"
+ "final answer"
+ )
+ assert "thinking" in result
+ assert "tool stuff" not in result
+ assert "final answer" in result
+
+ def test_gradual_token_by_token(self):
+ """Simulate token-by-token streaming."""
+ f = StreamingToolCallFilter()
+ parts = [
+ "Hello ",
+ "<",
+ "mini",
+ "max:",
+ "tool_call",
+ ">",
+            '<invoke name="test">',
+            "</invoke>",
+            "</minimax:tool_call>",
+            "world",
+ ]
+ result = ""
+ for part in parts:
+ result += f.process(part)
+ result += f.flush()
+ assert result == "Hello world", f"Got: {result!r}"
+
+ def test_empty_deltas(self):
+ f = StreamingToolCallFilter()
+ assert f.process("") == ""
+ assert f.process("text") == "text"
+ assert f.process("") == ""
+
+ def test_calling_tool_bracket_suppressed(self):
+ """Qwen3 bracket-style: [Calling tool: func({...})]\n"""
+ f = StreamingToolCallFilter()
+ result = f.process('[Calling tool: search({"q": "test"})]\n')
+ assert result == ""
+
+ def test_calling_tool_multiline_json(self):
+ """Multi-line JSON args in bracket-style tool call."""
+ f = StreamingToolCallFilter()
+ r1 = f.process('[Calling tool: search({"q": "test",')
+ r2 = f.process(' "limit": 5})]\n')
+ r3 = f.process("After")
+ assert r1 + r2 + r3 == "After"
+
+ def test_buffer_cap_on_unclosed_block(self):
+ """Buffer should be capped if tool call block never closes."""
+ from vllm_mlx.api.utils import _MAX_TOOL_BUFFER_BYTES
+
+ f = StreamingToolCallFilter()
+        f.process("<minimax:tool_call>")
+ # Feed data exceeding the cap
+ chunk = "x" * 10000
+ for _ in range(_MAX_TOOL_BUFFER_BYTES // 10000 + 2):
+ f.process(chunk)
+ # After cap, filter should have exited the block
+ assert not f._in_block
+ # New text should pass through
+ assert f.process("after") == "after"
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/vllm_mlx/api/__init__.py b/vllm_mlx/api/__init__.py
index cfb62f452..62dcb6919 100644
--- a/vllm_mlx/api/__init__.py
+++ b/vllm_mlx/api/__init__.py
@@ -61,6 +61,8 @@
extract_multimodal_content,
MLLM_PATTERNS,
SPECIAL_TOKENS_PATTERN,
+ StreamingToolCallFilter,
+ StreamingThinkRouter,
)
from .tool_calling import (
@@ -118,6 +120,8 @@
"extract_multimodal_content",
"MLLM_PATTERNS",
"SPECIAL_TOKENS_PATTERN",
+ "StreamingToolCallFilter",
+ "StreamingThinkRouter",
# Tool calling
"parse_tool_calls",
"convert_tools_for_template",
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index 39cf58391..9fdbfef13 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -3,10 +3,13 @@
Utility functions for text processing and model detection.
"""
+import logging
import re
from .models import Message
+logger = logging.getLogger(__name__)
+
# =============================================================================
# Special Token Patterns
# =============================================================================
@@ -101,6 +104,224 @@ def clean_output_text(text: str) -> str:
return text
+# =============================================================================
+# Streaming Tool Call Filter
+# =============================================================================
+
+# Safety cap for tool call buffer (bytes). If a tool call block never closes,
+# the buffer is capped to prevent unbounded memory growth. In practice, the
+# buffer is bounded by max_tokens (~100KB at 32768 tokens), but this cap
+# protects against pathological cases.
+_MAX_TOOL_BUFFER_BYTES = 1_048_576 # 1 MB
+
+# Tags that delimit tool call blocks in streaming output.
+# Content inside these tags should be suppressed during streaming because
+# it will be re-emitted as structured tool_use blocks after parsing.
+_TOOL_CALL_TAGS = [
+ ("", ""),
+ ("", ""),
+ (""),
+ ("[TOOL_CALL]", "[/TOOL_CALL]"),
+ ("[Calling tool", "]\n"), # Qwen3 bracket-style: [Calling tool: func({...})]\n
+]
+
+
+class StreamingToolCallFilter:
+ """Buffer streaming text to suppress tool call markup.
+
+    Tool call XML (e.g. <minimax:tool_call>...</minimax:tool_call>) arrives
+ split across multiple streaming deltas. This filter detects entry into a
+ tool call block, suppresses all output until the block closes, and emits
+ only non-tool-call text.
+
+ The full unfiltered text is still accumulated separately for tool call
+ parsing at stream end.
+ """
+
+ def __init__(self):
+ self._buffer = ""
+ self._in_block = False
+ self._close_tag = ""
+ # Longest open tag - used to determine how much buffer to hold back
+ self._max_open_len = max(len(t[0]) for t in _TOOL_CALL_TAGS)
+
+ def process(self, delta: str) -> str:
+ """Process a streaming delta. Returns text to emit (may be empty)."""
+ self._buffer += delta
+
+ if self._in_block:
+ return self._consume_block()
+ else:
+ return self._scan_for_open()
+
+ def _scan_for_open(self) -> str:
+ """Scan buffer for tool call open tags. Emit safe text."""
+ # Check for complete open tags
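+        # Tags are tried in _TOOL_CALL_TAGS order: the first listed tag found
+        # wins, even if a later-listed tag occurs earlier in the buffer.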
+ for open_tag, close_tag in _TOOL_CALL_TAGS:
+ idx = self._buffer.find(open_tag)
+ if idx >= 0:
+ # Found an open tag - emit text before it, enter block mode
+ emit = self._buffer[:idx]
+ self._buffer = self._buffer[idx + len(open_tag) :]
+ self._in_block = True
+ self._close_tag = close_tag
+ # Process remainder in case close tag is already in buffer
+ after = self._consume_block()
+ return emit + after
+
+ # No complete open tag found. Check if buffer ends with a partial
+ # match of any open tag - hold that back to avoid emitting a fragment.
+ hold_back = 0
+ for open_tag, _ in _TOOL_CALL_TAGS:
+ for prefix_len in range(min(len(open_tag), len(self._buffer)), 0, -1):
+ if self._buffer.endswith(open_tag[:prefix_len]):
+ hold_back = max(hold_back, prefix_len)
+ break
+
+ if hold_back > 0:
+ emit = self._buffer[:-hold_back]
+ self._buffer = self._buffer[-hold_back:]
+ return emit
+
+ # No partial match - safe to emit everything
+ emit = self._buffer
+ self._buffer = ""
+ return emit
+
+ def _consume_block(self) -> str:
+ """Consume content inside a tool call block. Returns empty string
+ unless the block closes and there's text after it."""
+ idx = self._buffer.find(self._close_tag)
+ if idx >= 0:
+ # Block closed - discard content up to and including close tag
+ self._buffer = self._buffer[idx + len(self._close_tag) :]
+ self._in_block = False
+ self._close_tag = ""
+ # Process remainder - might have more text or another tool call
+ if self._buffer:
+ return self._scan_for_open()
+ return ""
+ # Still inside block - suppress everything but cap buffer size
+ if len(self._buffer) > _MAX_TOOL_BUFFER_BYTES:
+ logger.warning(
+ f"Tool call buffer exceeded {_MAX_TOOL_BUFFER_BYTES} bytes, "
+ f"discarding and exiting block"
+ )
+ self._buffer = ""
+ self._in_block = False
+ self._close_tag = ""
+ return ""
+
+ def flush(self) -> str:
+ """Flush remaining buffer at end of stream."""
+ if self._in_block:
+ # Unterminated tool call block - discard
+ self._buffer = ""
+ self._in_block = False
+ return ""
+ emit = self._buffer
+ self._buffer = ""
+ return emit
+
+
+# =============================================================================
+# Streaming Think Block Router
+# =============================================================================
+
+
+class StreamingThinkRouter:
+ """Route ... content to separate Anthropic thinking blocks.
+
+ Instead of emitting thinking content as plain text (where it's
+ indistinguishable from the response), this router yields tagged
+ pieces that the streaming handler can emit as proper Anthropic
+ content block types.
+
+ Each call to process() returns a list of (block_type, text) tuples:
+ - ("thinking", text) for content inside ...
+ - ("text", text) for content outside think blocks
+
+ Args:
+ start_in_thinking: If True, assume the model starts in thinking
+            mode (e.g. MiniMax adds <think> to the generation prompt,
+ so the tag never appears in the output stream).
+ """
+
+ def __init__(self, start_in_thinking: bool = False):
+ self._buffer = ""
+ self._in_think = start_in_thinking
+
+ def process(self, delta: str) -> list[tuple[str, str]]:
+ """Process a delta. Returns list of (block_type, text) pieces."""
+ self._buffer += delta
+ pieces = []
+ self._extract_pieces(pieces)
+ return pieces
+
+ def _extract_pieces(self, pieces: list[tuple[str, str]]) -> None:
+ """Extract all complete pieces from the buffer."""
+ while True:
+ if self._in_think:
+                idx = self._buffer.find("</think>")
+ if idx >= 0:
+ # Emit thinking content, exit think mode
+ thinking = self._buffer[:idx]
+                    self._buffer = self._buffer[idx + len("</think>") :]
+ self._in_think = False
+ if thinking:
+ pieces.append(("thinking", thinking))
+ continue # Process remainder
+ else:
+ # Check for partial close tag at end
+                    for plen in range(min(len("</think>"), len(self._buffer)), 0, -1):
+                        if self._buffer.endswith("</think>"[:plen]):
+ # Hold back partial match
+ emit = self._buffer[:-plen]
+ self._buffer = self._buffer[-plen:]
+ if emit:
+ pieces.append(("thinking", emit))
+ return
+ # No partial match - emit all as thinking
+ if self._buffer:
+ pieces.append(("thinking", self._buffer))
+ self._buffer = ""
+ return
+ else:
+                idx = self._buffer.find("<think>")
+ if idx >= 0:
+ # Emit text before tag, enter think mode
+ before = self._buffer[:idx]
+                    self._buffer = self._buffer[idx + len("<think>") :]
+ self._in_think = True
+ if before:
+ pieces.append(("text", before))
+ continue # Process remainder
+ else:
+ # Check for partial open tag at end
+                    for plen in range(min(len("<think>"), len(self._buffer)), 0, -1):
+                        if self._buffer.endswith("<think>"[:plen]):
+ emit = self._buffer[:-plen]
+ self._buffer = self._buffer[-plen:]
+ if emit:
+ pieces.append(("text", emit))
+ return
+ # No partial match - emit all as text
+ if self._buffer:
+ pieces.append(("text", self._buffer))
+ self._buffer = ""
+ return
+
+ def flush(self) -> list[tuple[str, str]]:
+ """Flush remaining buffer at end of stream."""
+ pieces = []
+ if self._buffer:
+ block_type = "thinking" if self._in_think else "text"
+ pieces.append((block_type, self._buffer))
+ self._buffer = ""
+ self._in_think = False
+ return pieces
+
+
# =============================================================================
# Model Detection
# =============================================================================
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index a0038d5f8..af10e7341 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -98,6 +98,8 @@
)
from .api.utils import (
SPECIAL_TOKENS_PATTERN,
+ StreamingThinkRouter,
+ StreamingToolCallFilter,
clean_output_text,
extract_multimodal_content,
is_mllm_model, # noqa: F401
@@ -1730,6 +1732,59 @@ async def count_anthropic_tokens(request: Request):
return {"input_tokens": total_tokens}
+def _emit_content_pieces(
+ pieces: list[tuple[str, str]],
+ current_block_type: str | None,
+ block_index: int,
+) -> tuple[list[str], str | None, int]:
+ """Emit Anthropic SSE events for content pieces from the think router.
+
+ Handles block type transitions (thinking <-> text), emitting
+ content_block_start/stop/delta events as needed.
+
+ Args:
+ pieces: List of (block_type, text) from StreamingThinkRouter
+ current_block_type: Current open block type, or None
+ block_index: Current block index
+
+ Returns:
+ Tuple of (events, updated_block_type, updated_block_index)
+ """
+ events = []
+ for block_type, text in pieces:
+ if block_type != current_block_type:
+ # Close previous block if open
+ if current_block_type is not None:
+ events.append(
+ f"event: content_block_stop\ndata: "
+ f"{json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+ )
+ block_index += 1
+ # Start new block
+ current_block_type = block_type
+ content_block = (
+ {"type": block_type, "text": ""}
+ if block_type == "text"
+ else {"type": block_type, "thinking": ""}
+ )
+ events.append(
+ f"event: content_block_start\ndata: "
+ f"{json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': content_block})}\n\n"
+ )
+ # Emit delta
+ delta_key = "thinking" if block_type == "thinking" else "text"
+ delta_type = "thinking_delta" if block_type == "thinking" else "text_delta"
+ delta_event = {
+ "type": "content_block_delta",
+ "index": block_index,
+ "delta": {"type": delta_type, delta_key: text},
+ }
+ events.append(
+ f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
+ )
+ return events, current_block_type, block_index
+
+
async def _stream_anthropic_messages(
engine: BaseEngine,
openai_request: ChatCompletionRequest,
@@ -1779,48 +1834,87 @@ async def _stream_anthropic_messages(
}
yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
- # Emit content_block_start for text
- content_block_start = {
- "type": "content_block_start",
- "index": 0,
- "content_block": {"type": "text", "text": ""},
- }
- yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n"
-
- # Stream content deltas
+ # Stream pipeline: raw text → tool call filter → think router → emit
+ # - Tool call filter strips tool call markup (emitted as structured blocks later)
+ # - Think router separates content into Anthropic thinking blocks
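+    #   e.g. "<think>hm</think>Hi <tool_call>..." streams as a thinking block
+    #   ("hm") and a text block ("Hi "), with the tool call section suppressed
+    #   (illustrative sketch of the stage ordering).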
accumulated_text = ""
+ tool_filter = StreamingToolCallFilter()
+    # Detect if the model's chat template injects <think> into the
+ # generation prompt. If so, the model starts in thinking mode and
+ # the opening tag never appears in the output stream.
+ _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None
+ _chat_template = ""
+ if _tokenizer and hasattr(_tokenizer, "chat_template"):
+ _chat_template = _tokenizer.chat_template or ""
+ _starts_thinking = (
+        "<think>" in _chat_template and "add_generation_prompt" in _chat_template
+ )
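+    # Heuristic sketch: this matches templates whose generation prompt opens a
+    # think block, e.g. "...{%- if add_generation_prompt %}<think>{%- endif %}"
+    # (illustrative fragment, not taken from any specific model's template).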
+ think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking)
+ prompt_tokens = 0
completion_tokens = 0
+ # Track which content blocks we've started
+ current_block_type = None # "thinking" or "text"
+ block_index = 0
+
async for output in engine.stream_chat(messages=messages, **chat_kwargs):
delta_text = output.new_text
# Track token counts
+ if hasattr(output, "prompt_tokens") and output.prompt_tokens:
+ prompt_tokens = output.prompt_tokens
if hasattr(output, "completion_tokens") and output.completion_tokens:
completion_tokens = output.completion_tokens
if delta_text:
- # Filter special tokens
+ # Accumulate raw text BEFORE special token cleaning for tool parsing
+ accumulated_text += delta_text
+
+ # Filter special tokens for display
content = SPECIAL_TOKENS_PATTERN.sub("", delta_text)
if content:
- accumulated_text += content
- delta_event = {
- "type": "content_block_delta",
- "index": 0,
- "delta": {"type": "text_delta", "text": content},
- }
- yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
+ # Stage 1: strip tool call markup
+ filtered = tool_filter.process(content)
+ if not filtered:
+ continue
+ # Stage 2: route thinking vs text
+ pieces = think_router.process(filtered)
+ events, current_block_type, block_index = _emit_content_pieces(
+ pieces, current_block_type, block_index
+ )
+ for event in events:
+ yield event
+
+ # Flush remaining from both filters
+ remaining = tool_filter.flush()
+ if remaining:
+ events, current_block_type, block_index = _emit_content_pieces(
+ think_router.process(remaining), current_block_type, block_index
+ )
+ for event in events:
+ yield event
+
+ flush_pieces = think_router.flush()
+ if flush_pieces:
+ events, current_block_type, block_index = _emit_content_pieces(
+ flush_pieces, current_block_type, block_index
+ )
+ for event in events:
+ yield event
+
+ # Close final content block
+ if current_block_type is not None:
+ yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+ block_index += 1
# Check for tool calls in accumulated text
_, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request)
- # Emit content_block_stop for text block
- yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
-
# If there are tool calls, emit tool_use blocks
if tool_calls:
for i, tc in enumerate(tool_calls):
- tool_index = i + 1
+ tool_index = block_index + i
try:
tool_input = json.loads(tc.function.arguments)
except (json.JSONDecodeError, AttributeError):
@@ -1858,7 +1952,7 @@ async def _stream_anthropic_messages(
message_delta = {
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
- "usage": {"output_tokens": completion_tokens},
+ "usage": {"input_tokens": prompt_tokens, "output_tokens": completion_tokens},
}
yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n"
@@ -1866,7 +1960,7 @@ async def _stream_anthropic_messages(
elapsed = time.perf_counter() - start_time
tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0
logger.info(
- f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
+ f"Anthropic messages (stream): prompt={prompt_tokens} + completion={completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
)
# Emit message_stop