diff --git a/tests/test_streaming_pipeline_integration.py b/tests/test_streaming_pipeline_integration.py
new file mode 100644
index 000000000..7ddcf39f8
--- /dev/null
+++ b/tests/test_streaming_pipeline_integration.py
@@ -0,0 +1,240 @@
+"""Integration test for the Anthropic streaming pipeline.
+
+Tests the full flow: raw model output → StreamingToolCallFilter → StreamingThinkRouter
+→ Anthropic SSE events, verifying block transitions, tool call extraction, and
+prompt_tokens tracking work together correctly.
+"""
+
+import json
+import unittest
+
+from vllm_mlx.api.utils import StreamingToolCallFilter, StreamingThinkRouter
+from vllm_mlx.server import _emit_content_pieces
+
+
+class TestEmitContentPieces(unittest.TestCase):
+    """Test the refactored _emit_content_pieces helper."""
+
+    def test_single_text_block(self):
+        events, block_type, index = _emit_content_pieces([("text", "hello")], None, 0)
+        assert len(events) == 2  # block_start + delta
+        assert block_type == "text"
+        assert index == 0
+        # Verify block_start
+        start_data = json.loads(events[0].split("data: ")[1])
+        assert start_data["type"] == "content_block_start"
+        assert start_data["content_block"]["type"] == "text"
+        # Verify delta
+        delta_data = json.loads(events[1].split("data: ")[1])
+        assert delta_data["delta"]["text"] == "hello"
+
+    def test_single_thinking_block(self):
+        events, block_type, index = _emit_content_pieces(
+            [("thinking", "reasoning")], None, 0
+        )
+        assert block_type == "thinking"
+        delta_data = json.loads(events[1].split("data: ")[1])
+        assert delta_data["delta"]["thinking"] == "reasoning"
+
+    def test_transition_thinking_to_text(self):
+        events, block_type, index = _emit_content_pieces(
+            [("thinking", "reason"), ("text", "answer")], None, 0
+        )
+        assert block_type == "text"
+        assert index == 1  # incremented on block transition
+        # Should have: start_thinking, delta_thinking, stop_thinking, start_text, delta_text
+        assert len(events) == 5
+        stop_data = json.loads(events[2].split("data: ")[1])
+        assert stop_data["type"] == "content_block_stop"
+
+    def test_continues_existing_block(self):
+        """If current_block_type matches, no start/stop emitted."""
+        events, block_type, index = _emit_content_pieces([("text", "more")], "text", 0)
+        assert len(events) == 1  # just delta, no start
+        assert block_type == "text"
+
+    def test_empty_pieces(self):
+        events, block_type, index = _emit_content_pieces([], None, 0)
+        assert events == []
+        assert block_type is None
+        assert index == 0
+
+
+class TestStreamingPipelineIntegration(unittest.TestCase):
+    """Integration test for the full streaming pipeline."""
+
+    def _run_pipeline(self, deltas, start_in_thinking=False):
+        """Run deltas through tool_filter → think_router → emit, return events."""
+        tool_filter = StreamingToolCallFilter()
+        think_router = StreamingThinkRouter(start_in_thinking=start_in_thinking)
+        current_block_type = None
+        block_index = 0
+        all_events = []
+        accumulated_text = ""
+
+        for delta in deltas:
+            accumulated_text += delta
+            filtered = tool_filter.process(delta)
+            if not filtered:
+                continue
+            pieces = think_router.process(filtered)
+            events, current_block_type, block_index = _emit_content_pieces(
+                pieces, current_block_type, block_index
+            )
+            all_events.extend(events)
+
+        # Flush
+        remaining = tool_filter.flush()
+        if remaining:
+            pieces = think_router.process(remaining)
+            events, current_block_type, block_index = _emit_content_pieces(
+                pieces, current_block_type, block_index
+            )
+            all_events.extend(events)
+
+        flush_pieces = think_router.flush()
+        if flush_pieces:
+            events, current_block_type, block_index = _emit_content_pieces(
+                flush_pieces, current_block_type, block_index
+            )
+            all_events.extend(events)
+
+        # Close final block
+        if current_block_type is not None:
+            all_events.append(
+                f"event: content_block_stop\ndata: "
+                f"{json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+            )
+            block_index += 1
+
+        return all_events, accumulated_text, block_index
+
+    def _parse_events(self, events):
+        """Parse SSE events into structured data."""
+        parsed = []
+        for event in events:
+            data_line = event.split("data: ", 1)[1].split("\n")[0]
+            parsed.append(json.loads(data_line))
+        return parsed
+
+    def test_pure_text_response(self):
+        """Simple text response - one text block."""
+        events, _, block_index = self._run_pipeline(["Hello ", "world!"])
+        parsed = self._parse_events(events)
+
+        # block_start, 2 deltas, block_stop
+        types = [p["type"] for p in parsed]
+        assert types[0] == "content_block_start"
+        assert parsed[0]["content_block"]["type"] == "text"
+        assert types[-1] == "content_block_stop"
+        assert block_index == 1
+
+    def test_thinking_then_text(self):
+        """Model thinks then responds."""
+        events, _, block_index = self._run_pipeline(
+            ["<think>Let me think", " about this</think>", "The answer is 42"]
+        )
+        parsed = self._parse_events(events)
+
+        block_starts = [p for p in parsed if p["type"] == "content_block_start"]
+        assert len(block_starts) == 2
+        assert block_starts[0]["content_block"]["type"] == "thinking"
+        assert block_starts[1]["content_block"]["type"] == "text"
+        assert block_index == 2
+
+    def test_start_in_thinking_then_text(self):
+        """Model starts in thinking mode (template injects <think>)."""
+        events, _, _ = self._run_pipeline(
+            ["reasoning here", "</think>", "The answer"],
+            start_in_thinking=True,
+        )
+        parsed = self._parse_events(events)
+
+        block_starts = [p for p in parsed if p["type"] == "content_block_start"]
+        assert len(block_starts) == 2
+        assert block_starts[0]["content_block"]["type"] == "thinking"
+        assert block_starts[1]["content_block"]["type"] == "text"
+
+    def test_text_then_tool_call(self):
+        """Text followed by tool call - tool markup suppressed from text."""
+        events, accumulated, _ = self._run_pipeline(
+            [
+                "I'll search for that. ",
+                "<minimax:tool_call>",
+                '<invoke name="bash">',
+                '<parameter name="command">ls /tmp</parameter>',
+                "</invoke>",
+                "</minimax:tool_call>",
+            ]
+        )
+        parsed = self._parse_events(events)
+
+        # Only text block should appear (tool call is suppressed from streaming)
+        text_deltas = [
+            p
+            for p in parsed
+            if p["type"] == "content_block_delta"
+            and p["delta"].get("type") == "text_delta"
+        ]
+        text_content = "".join(d["delta"]["text"] for d in text_deltas)
+        assert "I'll search for that." in text_content
+        assert "<minimax:tool_call>" not in text_content
+
+        # But accumulated text has the full tool call for parsing
+        assert "<minimax:tool_call>" in accumulated
+
+    def test_thinking_then_tool_call(self):
+        """Thinking followed by tool call - both properly routed."""
+        events, accumulated, _ = self._run_pipeline(
+            [
+                "<think>I need to search",
+                "</think>",
+                '<minimax:tool_call><invoke name="search">',
+                '<parameter name="q">test</parameter>',
+                "</invoke>",
+                "</minimax:tool_call>",
+            ]
+        )
+        parsed = self._parse_events(events)
+
+        block_starts = [p for p in parsed if p["type"] == "content_block_start"]
+        # Only thinking block (tool call is suppressed)
+        assert len(block_starts) == 1
+        assert block_starts[0]["content_block"]["type"] == "thinking"
+
+    def test_mixed_thinking_text_and_tool_call(self):
+        """Full scenario: thinking → text → tool call."""
+        events, accumulated, block_index = self._run_pipeline(
+            [
+                "<think>analyzing request</think>",
+                "Let me help. ",
", + "", + 'echo hi', + "", + ] + ) + parsed = self._parse_events(events) + + block_starts = [p for p in parsed if p["type"] == "content_block_start"] + # thinking block + text block (tool call suppressed) + assert len(block_starts) == 2 + assert block_starts[0]["content_block"]["type"] == "thinking" + assert block_starts[1]["content_block"]["type"] == "text" + + # Accumulated has everything for post-stream tool parsing + assert "" in accumulated + + def test_block_index_increments_correctly(self): + """Block indices should increment on each transition.""" + events, _, final_index = self._run_pipeline( + ["t1textt2end"] + ) + parsed = self._parse_events(events) + + starts = [p for p in parsed if p["type"] == "content_block_start"] + assert [s["index"] for s in starts] == [0, 1, 2, 3] + assert final_index == 4 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_streaming_think_router.py b/tests/test_streaming_think_router.py new file mode 100644 index 000000000..d5272cdca --- /dev/null +++ b/tests/test_streaming_think_router.py @@ -0,0 +1,168 @@ +"""Tests for StreamingThinkRouter - routes blocks to Anthropic thinking content blocks.""" + +import unittest + +from vllm_mlx.api.utils import StreamingThinkRouter + + +class TestStreamingThinkRouter(unittest.TestCase): + """Unit tests for StreamingThinkRouter.""" + + # --- Basic routing --- + + def test_plain_text_routes_as_text(self): + r = StreamingThinkRouter() + assert r.process("Hello world") == [("text", "Hello world")] + + def test_think_block_routes_as_thinking(self): + r = StreamingThinkRouter() + assert r.process("reasoning") == [("thinking", "reasoning")] + + def test_text_then_think_then_text(self): + r = StreamingThinkRouter() + result = r.process("beforemiddleafter") + assert result == [("text", "before"), ("thinking", "middle"), ("text", "after")] + + # --- start_in_thinking mode --- + + def test_start_in_thinking_mode(self): + """When model injects into prompt, output starts in thinking mode.""" + r = StreamingThinkRouter(start_in_thinking=True) + result = r.process("reasoning here") + assert result == [("thinking", "reasoning here")] + + def test_start_in_thinking_then_close(self): + """Thinking closes with , then text follows.""" + r = StreamingThinkRouter(start_in_thinking=True) + result = r.process("reasoninganswer") + assert result == [("thinking", "reasoning"), ("text", "answer")] + + def test_start_in_thinking_close_across_deltas(self): + """ split across multiple deltas.""" + r = StreamingThinkRouter(start_in_thinking=True) + p1 = r.process("thinking stuffnow text") + # First delta should hold back partial + assert ("thinking", "thinking stuff") in p1 + # Second delta should transition + all_pieces = p1 + p2 + types = [t for t, _ in all_pieces] + assert "text" in types + + # --- Partial tag handling --- + + def test_partial_open_tag_held_back(self): + """Partial reasoning") + # p1 should emit "Hello " but hold back "answer") + # p1 should emit thinking content but hold back partial + assert ("thinking", "deep thought") in p1 + # p2 should transition to text + assert ("text", "answer") in p2 + + def test_partial_tag_false_alarm(self): + """Partial match that turns out not to be a tag.""" + r = StreamingThinkRouter() + p1 = r.process("Hello ") + # After p2, the held-back "" should emit as text + all_text = "".join(t for bt, t in p1 + p2 if bt == "text") + assert "Hello " == all_text + + # --- Multiple think blocks --- + + def test_multiple_think_blocks(self): + r = StreamingThinkRouter() + result = 
r.process("firstmiddlesecondend") + assert result == [ + ("thinking", "first"), + ("text", "middle"), + ("thinking", "second"), + ("text", "end"), + ] + + # --- Streaming across deltas --- + + def test_streaming_token_by_token(self): + """Simulate character-by-character streaming.""" + r = StreamingThinkRouter() + text = "abcxyz" + all_pieces = [] + for ch in text: + all_pieces.extend(r.process(ch)) + all_pieces.extend(r.flush()) + thinking = "".join(t for bt, t in all_pieces if bt == "thinking") + text_out = "".join(t for bt, t in all_pieces if bt == "text") + assert thinking == "abc" + assert text_out == "xyz" + + def test_streaming_with_start_in_thinking(self): + """Token-by-token with start_in_thinking.""" + r = StreamingThinkRouter(start_in_thinking=True) + text = "reasoningthe answer" + all_pieces = [] + for ch in text: + all_pieces.extend(r.process(ch)) + all_pieces.extend(r.flush()) + thinking = "".join(t for bt, t in all_pieces if bt == "thinking") + text_out = "".join(t for bt, t in all_pieces if bt == "text") + assert thinking == "reasoning" + assert text_out == "the answer" + + # --- Flush behavior --- + + def test_flush_emits_remaining_text(self): + """Text without partial tags is emitted by process(), flush() is empty.""" + r = StreamingThinkRouter() + pieces = r.process("partial text") + assert pieces == [("text", "partial text")] + assert r.flush() == [] + + def test_flush_emits_remaining_thinking(self): + """Thinking without partial tags is emitted by process(), flush() is empty.""" + r = StreamingThinkRouter(start_in_thinking=True) + pieces = r.process("unfinished thought") + assert pieces == [("thinking", "unfinished thought")] + assert r.flush() == [] + + def test_flush_with_held_back_partial(self): + """Flush should emit held-back partial tag as content.""" + r = StreamingThinkRouter() + r.process("text") + f.process('') + f.process('/tmp/test.txt') + f.process("") + result = f.process("") + assert result == "" + + def test_text_after_tool_call_emits(self): + f = StreamingToolCallFilter() + f.process("content") + assert f.process("After") == "After" + + def test_text_before_and_after_same_delta(self): + f = StreamingToolCallFilter() + result = f.process("Before insideAfter") + assert result == "Before After" + + def test_split_across_deltas(self): + f = StreamingToolCallFilter() + r1 = f.process("Before insideAfter") + assert r1 + r2 == "Before After" + + def test_qwen_format_suppressed(self): + f = StreamingToolCallFilter() + result = f.process('Text {"name":"fn"} more') + assert result == "Text more" + + def test_multiple_tool_calls(self): + f = StreamingToolCallFilter() + result = f.process( + "A x" + " B y C" + ) + assert result == "A B C" + + def test_flush_partial_tag_emits(self): + f = StreamingToolCallFilter() + r = f.process("text partial content") + assert f.flush() == "" + + def test_large_tool_call_content(self): + """Simulates a Read tool returning a large file.""" + f = StreamingToolCallFilter() + big = "x" * 10000 + result = f.process(f"Before {big}After") + assert result == "Before After" + + def test_think_tags_not_filtered(self): + f = StreamingToolCallFilter() + result = f.process("reasoning hereanswer") + assert "" in result + assert "reasoning here" in result + + def test_mixed_think_and_tool_call(self): + f = StreamingToolCallFilter() + result = f.process( + "thinking" + "tool stuff" + "final answer" + ) + assert "thinking" in result + assert "tool stuff" not in result + assert "final answer" in result + + def test_gradual_token_by_token(self): + 
"""Simulate token-by-token streaming.""" + f = StreamingToolCallFilter() + parts = [ + "Hello ", + "<", + "mini", + "max:", + "tool_call", + ">", + '', + "", + "", + " world", + ] + result = "" + for part in parts: + result += f.process(part) + result += f.flush() + assert result == "Hello world", f"Got: {result!r}" + + def test_empty_deltas(self): + f = StreamingToolCallFilter() + assert f.process("") == "" + assert f.process("text") == "text" + assert f.process("") == "" + + def test_calling_tool_bracket_suppressed(self): + """Qwen3 bracket-style: [Calling tool: func({...})]\n""" + f = StreamingToolCallFilter() + result = f.process('[Calling tool: search({"q": "test"})]\n') + assert result == "" + + def test_calling_tool_multiline_json(self): + """Multi-line JSON args in bracket-style tool call.""" + f = StreamingToolCallFilter() + r1 = f.process('[Calling tool: search({"q": "test",') + r2 = f.process(' "limit": 5})]\n') + r3 = f.process("After") + assert r1 + r2 + r3 == "After" + + def test_buffer_cap_on_unclosed_block(self): + """Buffer should be capped if tool call block never closes.""" + from vllm_mlx.api.utils import _MAX_TOOL_BUFFER_BYTES + + f = StreamingToolCallFilter() + f.process("") + # Feed data exceeding the cap + chunk = "x" * 10000 + for _ in range(_MAX_TOOL_BUFFER_BYTES // 10000 + 2): + f.process(chunk) + # After cap, filter should have exited the block + assert not f._in_block + # New text should pass through + assert f.process("after") == "after" + + +if __name__ == "__main__": + unittest.main() diff --git a/vllm_mlx/api/__init__.py b/vllm_mlx/api/__init__.py index cfb62f452..62dcb6919 100644 --- a/vllm_mlx/api/__init__.py +++ b/vllm_mlx/api/__init__.py @@ -61,6 +61,8 @@ extract_multimodal_content, MLLM_PATTERNS, SPECIAL_TOKENS_PATTERN, + StreamingToolCallFilter, + StreamingThinkRouter, ) from .tool_calling import ( @@ -118,6 +120,8 @@ "extract_multimodal_content", "MLLM_PATTERNS", "SPECIAL_TOKENS_PATTERN", + "StreamingToolCallFilter", + "StreamingThinkRouter", # Tool calling "parse_tool_calls", "convert_tools_for_template", diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 39cf58391..9fdbfef13 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -3,10 +3,13 @@ Utility functions for text processing and model detection. """ +import logging import re from .models import Message +logger = logging.getLogger(__name__) + # ============================================================================= # Special Token Patterns # ============================================================================= @@ -101,6 +104,224 @@ def clean_output_text(text: str) -> str: return text +# ============================================================================= +# Streaming Tool Call Filter +# ============================================================================= + +# Safety cap for tool call buffer (bytes). If a tool call block never closes, +# the buffer is capped to prevent unbounded memory growth. In practice, the +# buffer is bounded by max_tokens (~100KB at 32768 tokens), but this cap +# protects against pathological cases. +_MAX_TOOL_BUFFER_BYTES = 1_048_576 # 1 MB + +# Tags that delimit tool call blocks in streaming output. +# Content inside these tags should be suppressed during streaming because +# it will be re-emitted as structured tool_use blocks after parsing. 
diff --git a/vllm_mlx/api/__init__.py b/vllm_mlx/api/__init__.py
index cfb62f452..62dcb6919 100644
--- a/vllm_mlx/api/__init__.py
+++ b/vllm_mlx/api/__init__.py
@@ -61,6 +61,8 @@
     extract_multimodal_content,
     MLLM_PATTERNS,
     SPECIAL_TOKENS_PATTERN,
+    StreamingToolCallFilter,
+    StreamingThinkRouter,
 )
 
 from .tool_calling import (
@@ -118,6 +120,8 @@
     "extract_multimodal_content",
     "MLLM_PATTERNS",
     "SPECIAL_TOKENS_PATTERN",
+    "StreamingToolCallFilter",
+    "StreamingThinkRouter",
     # Tool calling
     "parse_tool_calls",
     "convert_tools_for_template",
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index 39cf58391..9fdbfef13 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -3,10 +3,13 @@
 Utility functions for text processing and model detection.
 """
 
+import logging
 import re
 
 from .models import Message
 
+logger = logging.getLogger(__name__)
+
 # =============================================================================
 # Special Token Patterns
 # =============================================================================
@@ -101,6 +104,224 @@ def clean_output_text(text: str) -> str:
     return text
 
 
+# =============================================================================
+# Streaming Tool Call Filter
+# =============================================================================
+
+# Safety cap for tool call buffer (bytes). If a tool call block never closes,
+# the buffer is capped to prevent unbounded memory growth. In practice, the
+# buffer is bounded by max_tokens (~100KB at 32768 tokens), but this cap
+# protects against pathological cases.
+_MAX_TOOL_BUFFER_BYTES = 1_048_576  # 1 MB
+
+# Tags that delimit tool call blocks in streaming output.
+# Content inside these tags should be suppressed during streaming because
+# it will be re-emitted as structured tool_use blocks after parsing.
+_TOOL_CALL_TAGS = [
+    ("<minimax:tool_call>", "</minimax:tool_call>"),
+    ("<tool_call>", "</tool_call>"),
+    ("<function=", "</function>"),
+    ("[TOOL_CALL]", "[/TOOL_CALL]"),
+    ("[Calling tool", "]\n"),  # Qwen3 bracket-style: [Calling tool: func({...})]\n
+]
+
+
+class StreamingToolCallFilter:
+    """Buffer streaming text to suppress tool call markup.
+
+    Tool call XML (e.g. <minimax:tool_call>...</minimax:tool_call>) arrives
+    split across multiple streaming deltas. This filter detects entry into a
+    tool call block, suppresses all output until the block closes, and emits
+    only non-tool-call text.
+
+    The full unfiltered text is still accumulated separately for tool call
+    parsing at stream end.
+    """
+
+    def __init__(self):
+        self._buffer = ""
+        self._in_block = False
+        self._close_tag = ""
+        # Longest open tag - used to determine how much buffer to hold back
+        self._max_open_len = max(len(t[0]) for t in _TOOL_CALL_TAGS)
+
+    def process(self, delta: str) -> str:
+        """Process a streaming delta. Returns text to emit (may be empty)."""
+        self._buffer += delta
+
+        if self._in_block:
+            return self._consume_block()
+        else:
+            return self._scan_for_open()
+
+    def _scan_for_open(self) -> str:
+        """Scan buffer for tool call open tags. Emit safe text."""
+        # Check for complete open tags
+        for open_tag, close_tag in _TOOL_CALL_TAGS:
+            idx = self._buffer.find(open_tag)
+            if idx >= 0:
+                # Found an open tag - emit text before it, enter block mode
+                emit = self._buffer[:idx]
+                self._buffer = self._buffer[idx + len(open_tag) :]
+                self._in_block = True
+                self._close_tag = close_tag
+                # Process remainder in case close tag is already in buffer
+                after = self._consume_block()
+                return emit + after
+
+        # No complete open tag found. Check if buffer ends with a partial
+        # match of any open tag - hold that back to avoid emitting a fragment.
+        hold_back = 0
+        for open_tag, _ in _TOOL_CALL_TAGS:
+            for prefix_len in range(min(len(open_tag), len(self._buffer)), 0, -1):
+                if self._buffer.endswith(open_tag[:prefix_len]):
+                    hold_back = max(hold_back, prefix_len)
+                    break
+
+        if hold_back > 0:
+            emit = self._buffer[:-hold_back]
+            self._buffer = self._buffer[-hold_back:]
+            return emit
+
+        # No partial match - safe to emit everything
+        emit = self._buffer
+        self._buffer = ""
+        return emit
+
+    def _consume_block(self) -> str:
+        """Consume content inside a tool call block. Returns empty string
+        unless the block closes and there's text after it."""
+        idx = self._buffer.find(self._close_tag)
+        if idx >= 0:
+            # Block closed - discard content up to and including close tag
+            self._buffer = self._buffer[idx + len(self._close_tag) :]
+            self._in_block = False
+            self._close_tag = ""
+            # Process remainder - might have more text or another tool call
+            if self._buffer:
+                return self._scan_for_open()
+            return ""
+        # Still inside block - suppress everything but cap buffer size
+        if len(self._buffer) > _MAX_TOOL_BUFFER_BYTES:
+            logger.warning(
+                f"Tool call buffer exceeded {_MAX_TOOL_BUFFER_BYTES} bytes, "
+                f"discarding and exiting block"
+            )
+            self._buffer = ""
+            self._in_block = False
+            self._close_tag = ""
+        return ""
+
+    def flush(self) -> str:
+        """Flush remaining buffer at end of stream."""
+        if self._in_block:
+            # Unterminated tool call block - discard
+            self._buffer = ""
+            self._in_block = False
+            return ""
+        emit = self._buffer
+        self._buffer = ""
+        return emit
+
+
+# =============================================================================
+# Streaming Think Block Router
+# =============================================================================
+
+
+class StreamingThinkRouter:
+    """Route <think>...</think> content to separate Anthropic thinking blocks.
+
+    Instead of emitting thinking content as plain text (where it's
+    indistinguishable from the response), this router yields tagged
+    pieces that the streaming handler can emit as proper Anthropic
+    content block types.
+
+    Each call to process() returns a list of (block_type, text) tuples:
+    - ("thinking", text) for content inside <think>...</think>
+    - ("text", text) for content outside think blocks
+
+    Args:
+        start_in_thinking: If True, assume the model starts in thinking
+            mode (e.g. MiniMax adds <think> to the generation prompt,
+            so the tag never appears in the output stream).
+    """
+
+    def __init__(self, start_in_thinking: bool = False):
+        self._buffer = ""
+        self._in_think = start_in_thinking
+
+    def process(self, delta: str) -> list[tuple[str, str]]:
+        """Process a delta. Returns list of (block_type, text) pieces."""
+        self._buffer += delta
+        pieces = []
+        self._extract_pieces(pieces)
+        return pieces
+
+    def _extract_pieces(self, pieces: list[tuple[str, str]]) -> None:
+        """Extract all complete pieces from the buffer."""
+        while True:
+            if self._in_think:
+                idx = self._buffer.find("</think>")
+                if idx >= 0:
+                    # Emit thinking content, exit think mode
+                    thinking = self._buffer[:idx]
+                    self._buffer = self._buffer[idx + len("</think>") :]
+                    self._in_think = False
+                    if thinking:
+                        pieces.append(("thinking", thinking))
+                    continue  # Process remainder
+                else:
+                    # Check for partial close tag at end
+                    for plen in range(min(len("</think>"), len(self._buffer)), 0, -1):
+                        if self._buffer.endswith("</think>"[:plen]):
+                            # Hold back partial match
+                            emit = self._buffer[:-plen]
+                            self._buffer = self._buffer[-plen:]
+                            if emit:
+                                pieces.append(("thinking", emit))
+                            return
+                    # No partial match - emit all as thinking
+                    if self._buffer:
+                        pieces.append(("thinking", self._buffer))
+                        self._buffer = ""
+                    return
+            else:
+                idx = self._buffer.find("<think>")
+                if idx >= 0:
+                    # Emit text before tag, enter think mode
+                    before = self._buffer[:idx]
+                    self._buffer = self._buffer[idx + len("<think>") :]
+                    self._in_think = True
+                    if before:
+                        pieces.append(("text", before))
+                    continue  # Process remainder
+                else:
+                    # Check for partial open tag at end
+                    for plen in range(min(len("<think>"), len(self._buffer)), 0, -1):
+                        if self._buffer.endswith("<think>"[:plen]):
+                            emit = self._buffer[:-plen]
+                            self._buffer = self._buffer[-plen:]
+                            if emit:
+                                pieces.append(("text", emit))
+                            return
+                    # No partial match - emit all as text
+                    if self._buffer:
+                        pieces.append(("text", self._buffer))
+                        self._buffer = ""
+                    return
+
+    def flush(self) -> list[tuple[str, str]]:
+        """Flush remaining buffer at end of stream."""
+        pieces = []
+        if self._buffer:
+            block_type = "thinking" if self._in_think else "text"
+            pieces.append((block_type, self._buffer))
+            self._buffer = ""
+        self._in_think = False
+        return pieces
+
+
 # =============================================================================
 # Model Detection
 # =============================================================================
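Aside (illustrative, not in the patch): the same hold-back idea in
StreamingThinkRouter, assuming the class as defined above. A "</think>" split
across deltas never leaks into the emitted pieces.

    from vllm_mlx.api.utils import StreamingThinkRouter

    r = StreamingThinkRouter(start_in_thinking=True)
    assert r.process("step one</th") == [("thinking", "step one")]  # "</th" held
    assert r.process("ink>result") == [("text", "result")]          # tag completed
    assert r.flush() == []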
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index a0038d5f8..af10e7341 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -98,6 +98,8 @@
 )
 from .api.utils import (
     SPECIAL_TOKENS_PATTERN,
+    StreamingThinkRouter,
+    StreamingToolCallFilter,
     clean_output_text,
     extract_multimodal_content,
     is_mllm_model,  # noqa: F401
@@ -1730,6 +1732,59 @@ async def count_anthropic_tokens(request: Request):
     return {"input_tokens": total_tokens}
 
 
+def _emit_content_pieces(
+    pieces: list[tuple[str, str]],
+    current_block_type: str | None,
+    block_index: int,
+) -> tuple[list[str], str | None, int]:
+    """Emit Anthropic SSE events for content pieces from the think router.
+
+    Handles block type transitions (thinking <-> text), emitting
+    content_block_start/stop/delta events as needed.
+
+    Args:
+        pieces: List of (block_type, text) from StreamingThinkRouter
+        current_block_type: Current open block type, or None
+        block_index: Current block index
+
+    Returns:
+        Tuple of (events, updated_block_type, updated_block_index)
+    """
+    events = []
+    for block_type, text in pieces:
+        if block_type != current_block_type:
+            # Close previous block if open
+            if current_block_type is not None:
+                events.append(
+                    f"event: content_block_stop\ndata: "
+                    f"{json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+                )
+                block_index += 1
+            # Start new block
+            current_block_type = block_type
+            content_block = (
+                {"type": block_type, "text": ""}
+                if block_type == "text"
+                else {"type": block_type, "thinking": ""}
+            )
+            events.append(
+                f"event: content_block_start\ndata: "
+                f"{json.dumps({'type': 'content_block_start', 'index': block_index, 'content_block': content_block})}\n\n"
+            )
+        # Emit delta
+        delta_key = "thinking" if block_type == "thinking" else "text"
+        delta_type = "thinking_delta" if block_type == "thinking" else "text_delta"
+        delta_event = {
+            "type": "content_block_delta",
+            "index": block_index,
+            "delta": {"type": delta_type, delta_key: text},
+        }
+        events.append(
+            f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
+        )
+    return events, current_block_type, block_index
+
+
 async def _stream_anthropic_messages(
     engine: BaseEngine,
     openai_request: ChatCompletionRequest,
@@ -1779,48 +1834,87 @@
     }
     yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n"
 
-    # Emit content_block_start for text
-    content_block_start = {
-        "type": "content_block_start",
-        "index": 0,
-        "content_block": {"type": "text", "text": ""},
-    }
-    yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n"
-
-    # Stream content deltas
+    # Stream pipeline: raw text → tool call filter → think router → emit
+    # - Tool call filter strips tool call markup (emitted as structured blocks later)
+    # - Think router separates <think> content into Anthropic thinking blocks
     accumulated_text = ""
+    tool_filter = StreamingToolCallFilter()
+    # Detect if the model's chat template injects <think> into the
+    # generation prompt. If so, the model starts in thinking mode and
+    # the opening tag never appears in the output stream.
+    _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None
+    _chat_template = ""
+    if _tokenizer and hasattr(_tokenizer, "chat_template"):
+        _chat_template = _tokenizer.chat_template or ""
+    _starts_thinking = (
+        "<think>" in _chat_template and "add_generation_prompt" in _chat_template
+    )
+    think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking)
+
+    prompt_tokens = 0
     completion_tokens = 0
+    # Track which content blocks we've started
+    current_block_type = None  # "thinking" or "text"
+    block_index = 0
+
     async for output in engine.stream_chat(messages=messages, **chat_kwargs):
         delta_text = output.new_text
 
         # Track token counts
+        if hasattr(output, "prompt_tokens") and output.prompt_tokens:
+            prompt_tokens = output.prompt_tokens
         if hasattr(output, "completion_tokens") and output.completion_tokens:
            completion_tokens = output.completion_tokens
 
         if delta_text:
-            # Filter special tokens
+            # Accumulate raw text BEFORE special token cleaning for tool parsing
+            accumulated_text += delta_text
+
+            # Filter special tokens for display
             content = SPECIAL_TOKENS_PATTERN.sub("", delta_text)
             if content:
-                accumulated_text += content
-                delta_event = {
-                    "type": "content_block_delta",
-                    "index": 0,
-                    "delta": {"type": "text_delta", "text": content},
-                }
-                yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"
+                # Stage 1: strip tool call markup
+                filtered = tool_filter.process(content)
+                if not filtered:
+                    continue
+                # Stage 2: route thinking vs text
+                pieces = think_router.process(filtered)
+                events, current_block_type, block_index = _emit_content_pieces(
+                    pieces, current_block_type, block_index
+                )
+                for event in events:
+                    yield event
+
+    # Flush remaining from both filters
+    remaining = tool_filter.flush()
+    if remaining:
+        events, current_block_type, block_index = _emit_content_pieces(
+            think_router.process(remaining), current_block_type, block_index
+        )
+        for event in events:
+            yield event
+
+    flush_pieces = think_router.flush()
+    if flush_pieces:
+        events, current_block_type, block_index = _emit_content_pieces(
+            flush_pieces, current_block_type, block_index
+        )
+        for event in events:
+            yield event
+
+    # Close final content block
+    if current_block_type is not None:
+        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': block_index})}\n\n"
+        block_index += 1
 
     # Check for tool calls in accumulated text
     _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request)
 
-    # Emit content_block_stop for text block
-    yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n"
-
     # If there are tool calls, emit tool_use blocks
     if tool_calls:
         for i, tc in enumerate(tool_calls):
-            tool_index = i + 1
+            tool_index = block_index + i
             try:
                 tool_input = json.loads(tc.function.arguments)
             except (json.JSONDecodeError, AttributeError):
@@ -1858,7 +1952,7 @@
         message_delta = {
             "type": "message_delta",
             "delta": {"stop_reason": stop_reason, "stop_sequence": None},
-            "usage": {"output_tokens": completion_tokens},
+            "usage": {"input_tokens": prompt_tokens, "output_tokens": completion_tokens},
         }
         yield f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n"
 
@@ -1866,7 +1960,7 @@
     elapsed = time.perf_counter() - start_time
     tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0
     logger.info(
-        f"Anthropic messages (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)"
f"Anthropic messages (stream): prompt={prompt_tokens} + completion={completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" ) # Emit message_stop