diff --git a/tests/test_anthropic_stream_scrubber.py b/tests/test_anthropic_stream_scrubber.py new file mode 100644 index 000000000..466673867 --- /dev/null +++ b/tests/test_anthropic_stream_scrubber.py @@ -0,0 +1,1504 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for _AnthropicStreamScrubber – stateful tag stripping for Anthropic streaming. + +Tests the scrubber introduced in commit 6805baf which strips , , +, and markup from streamed text deltas on the +Anthropic /v1/messages endpoint. + +These are pure logic tests with no MLX dependency. +""" + +import pytest + +from vllm_mlx.server import _AnthropicStreamScrubber + +# ============================================================================= +# Basic Construction / Initial State +# ============================================================================= + + +class TestScrubberInitialState: + """Test scrubber creation and initial state.""" + + def test_initial_mode_is_text(self): + scrubber = _AnthropicStreamScrubber() + assert scrubber.mode == "TEXT" + + def test_initial_carry_is_empty(self): + scrubber = _AnthropicStreamScrubber() + assert scrubber.carry == "" + + def test_class_constants_exist(self): + """Verify key class-level constants are defined.""" + assert _AnthropicStreamScrubber.THINK_OPEN == "" + assert _AnthropicStreamScrubber.THINK_CLOSE == "" + assert _AnthropicStreamScrubber.TOOL_OPEN == "" + assert _AnthropicStreamScrubber.TOOL_CLOSE == "" + assert _AnthropicStreamScrubber.FUNC_CLOSE == "" + assert _AnthropicStreamScrubber.PARAM_CLOSE == "" + assert _AnthropicStreamScrubber.FUNC_PREFIX == "= _AnthropicStreamScrubber.MAX_TAG - 1 + + +# ============================================================================= +# Plain Text (no tags) – passthrough +# ============================================================================= + + +class TestScrubberPlainText: + """Scrubber should pass through normal text unchanged.""" + + def test_empty_string(self): + scrubber = _AnthropicStreamScrubber() + assert scrubber.feed("") == "" + + def test_none_delta(self): + """feed(None) should not crash.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed(None) + assert result == "" + + def test_short_text(self): + scrubber = _AnthropicStreamScrubber() + # Short text with no '<' should emit immediately (zero carry) + result = scrubber.feed("Hi") + assert result == "Hi" + assert scrubber.carry == "" + + def test_long_plain_text(self): + """Text longer than CARRY_N should emit most of it immediately.""" + scrubber = _AnthropicStreamScrubber() + text = ( + "Hello, this is a long sentence with no markup at all, just ordinary text." + ) + result = scrubber.feed(text) + flushed = scrubber.flush() + assert result + flushed == text + + def test_multiple_plain_deltas(self): + """Multiple consecutive plain-text deltas should reconstruct fully.""" + scrubber = _AnthropicStreamScrubber() + parts = ["Hello ", "world, ", "how ", "are ", "you?"] + collected = "" + for p in parts: + collected += scrubber.feed(p) + collected += scrubber.flush() + assert collected == "".join(parts) + + def test_flush_in_text_mode_returns_carry(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("abc") + flushed = scrubber.flush() + # After flush, carry should be empty + assert scrubber.carry == "" + + +# ============================================================================= +# ... suppression +# ============================================================================= + + +class TestScrubberThinkTags: + """Test suppression of ... blocks.""" + + def test_think_block_in_single_delta(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("Hello internal reasoning world") + result += scrubber.flush() + assert "" not in result + assert "internal reasoning" not in result + assert "" not in result + assert "Hello " in result + assert " world" in result + + def test_think_block_removes_content(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("some reasoningAfter thought") + result += scrubber.flush() + assert "some reasoning" not in result + assert "After thought" in result + + def test_think_block_split_across_deltas(self): + """Tag split across multiple feed() calls.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("Before secret reasoning here") + collected += scrubber.feed(" After") + collected += scrubber.flush() + assert "secret reasoning" not in collected + assert "" not in collected + assert "" not in collected + assert "Before" in collected + assert "After" in collected + + def test_think_block_close_tag_split(self): + """Closing tag split across deltas.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("reasoningvisible text") + collected += scrubber.flush() + assert "reasoning" not in collected + assert "visible text" in collected + + def test_multiple_think_blocks(self): + scrubber = _AnthropicStreamScrubber() + text = "Ar1Br2C" + result = scrubber.feed(text) + result += scrubber.flush() + assert "r1" not in result + assert "r2" not in result + assert "A" in result + assert "B" in result + assert "C" in result + + def test_think_with_newlines(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("Hi\nStep 1\nStep 2\n Done") + result += scrubber.flush() + assert "Step 1" not in result + assert "Step 2" not in result + assert "Hi" in result + assert "Done" in result + + def test_think_at_end_of_stream_flushed_away(self): + """If stream ends inside a block, flush discards carry.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("Hello unclosed reasoning") + flushed = scrubber.flush() + # In suppression mode, flush returns "" + assert "unclosed reasoning" not in collected + flushed + assert "Hello" in collected + flushed + + +# ============================================================================= +# ... suppression +# ============================================================================= + + +class TestScrubberToolCallTags: + """Test suppression of ... blocks.""" + + def test_tool_call_in_single_delta(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed('Before {"name":"fn"} After') + result += scrubber.flush() + assert "" not in result + assert '{"name":"fn"}' not in result + assert "" not in result + assert "Before" in result + assert "After" in result + + def test_tool_call_split_across_deltas(self): + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("Text {"name":"search","args":{}} rest') + collected += scrubber.flush() + assert '{"name":"search"' not in collected + assert "" not in collected + assert "Text" in collected + assert "rest" in collected + + def test_tool_call_close_split(self): + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("datavisible") + collected += scrubber.flush() + assert "data" not in collected + assert "visible" in collected + + def test_multiple_tool_calls(self): + scrubber = _AnthropicStreamScrubber() + text = "Acall1Bcall2C" + result = scrubber.feed(text) + result += scrubber.flush() + assert "call1" not in result + assert "call2" not in result + assert "A" in result + assert "B" in result + assert "C" in result + + def test_tool_call_at_end_of_stream(self): + """Unclosed tool_call at end of stream – suppressed by flush.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("Prefix unclosed") + flushed = scrubber.flush() + assert "unclosed" not in collected + flushed + assert "Prefix" in collected + flushed + + +# ============================================================================= +# ... suppression +# ============================================================================= + + +class TestScrubberFunctionTags: + """Test suppression of ... (Llama-style).""" + + def test_function_tag_single_delta(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed('Hello {"q":"test"} world') + result += scrubber.flush() + assert "" not in result + assert "Hello" in result + assert "world" in result + + def test_function_tag_split_across_deltas(self): + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("Text {"city":"NYC"} done') + collected += scrubber.flush() + assert "get_weather" not in collected + assert '{"city":"NYC"}' not in collected + assert "Text" in collected + assert "done" in collected + + def test_function_tag_with_complex_name(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("bodyafter") + result += scrubber.flush() + assert "my_long_function_name" not in result + assert "body" not in result + assert "after" in result + + def test_stray_function_close_suppressed(self): + """A stray outside a function block should be suppressed.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("textmore") + result += scrubber.flush() + assert "" not in result + assert "text" in result + assert "more" in result + + +# ============================================================================= +# ... suppression +# ============================================================================= + + +class TestScrubberParameterTags: + """Test suppression of ... (Llama-style).""" + + def test_parameter_tag_single_delta(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("Before NYC After") + # + result += scrubber.flush() + assert " tag should be stripped.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("textmore") + result += scrubber.flush() + assert "" not in result + assert "text" in result + assert "more" in result + + def test_parameter_inside_function_block(self): + """Parameter tags typically appear inside function blocks.""" + scrubber = _AnthropicStreamScrubber() + text = "testdone" + result = scrubber.feed(text) + result += scrubber.flush() + # Everything inside ... should be suppressed + assert "query" not in result + assert "test" not in result + assert "done" in result + + +# ============================================================================= +# Stray Closing Tags +# ============================================================================= + + +class TestScrubberStrayClosingTags: + """Test that stray closing tags outside their context are consumed.""" + + def test_stray_think_close(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("helloworld") + result += scrubber.flush() + assert "" not in result + assert "hello" in result + assert "world" in result + + def test_stray_tool_call_close(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("helloworld") + result += scrubber.flush() + assert "" not in result + assert "hello" in result + assert "world" in result + + def test_stray_function_close(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("helloworld") + result += scrubber.flush() + assert "" not in result + assert "hello" in result + assert "world" in result + + def test_stray_parameter_close(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("helloworld") + result += scrubber.flush() + assert "" not in result + assert "hello" in result + assert "world" in result + + def test_multiple_stray_closing_tags(self): + scrubber = _AnthropicStreamScrubber() + text = "abcde" + result = scrubber.feed(text) + result += scrubber.flush() + assert "" not in result + assert "" not in result + assert "" not in result + assert "" not in result + assert "a" in result + assert "b" in result + assert "c" in result + assert "d" in result + assert "e" in result + + +# ============================================================================= +# Mixed Scenarios +# ============================================================================= + + +class TestScrubberMixedContent: + """Test combinations of tags and text.""" + + def test_think_then_tool_call(self): + scrubber = _AnthropicStreamScrubber() + text = ( + 'BeforereasoningMiddle{"fn":"x"}After' + ) + result = scrubber.feed(text) + result += scrubber.flush() + assert "reasoning" not in result + assert '{"fn":"x"}' not in result + assert "Before" in result + assert "Middle" in result + assert "After" in result + + def test_tool_call_then_think(self): + scrubber = _AnthropicStreamScrubber() + text = "datatextthoughtend" + result = scrubber.feed(text) + result += scrubber.flush() + assert "data" not in result + assert "thought" not in result + assert "text" in result + assert "end" in result + + def test_think_with_function_inside_tool_call(self): + """Nested-looking tags – only outer suppression matters.""" + scrubber = _AnthropicStreamScrubber() + text = "outernestedvisible" + result = scrubber.feed(text) + result += scrubber.flush() + assert "outer" not in result + assert "nested" not in result + assert "visible" in result + + def test_interleaved_text_and_tags(self): + scrubber = _AnthropicStreamScrubber() + parts = [ + "Hello ", + "", + "Let me think about this...", + "", + " Here's my answer.", + ] + collected = "" + for p in parts: + collected += scrubber.feed(p) + collected += scrubber.flush() + assert "Let me think about this" not in collected + assert "Hello" in collected + assert "Here's my answer." in collected + + def test_realistic_streaming_scenario(self): + """Simulate a realistic token-by-token streaming scenario.""" + scrubber = _AnthropicStreamScrubber() + # Model outputs: "Let me checkThe weather is sunny." + # Split into small token-like deltas + tokens = [ + "<", + "think", + ">", + "Let", + " me", + " check", + "", + "The", + " weather", + " is", + " sunny", + ".", + ] + collected = "" + for tok in tokens: + collected += scrubber.feed(tok) + collected += scrubber.flush() + assert "Let me check" not in collected + assert "" not in collected + assert "" not in collected + assert "The weather is sunny." in collected + + def test_realistic_tool_call_streaming(self): + """Simulate tool-call markup arriving token by token.""" + scrubber = _AnthropicStreamScrubber() + tokens = [ + "I'll ", + "search", + " for ", + "that.", + "", + '{"name', + '":"', + "search", + '","', + "arguments", + '":{"', + "q", + '":"', + "weather", + '"}}', + "", + ] + collected = "" + for tok in tokens: + collected += scrubber.feed(tok) + collected += scrubber.flush() + assert "" not in collected + assert "" not in collected + assert '"name"' not in collected + assert "I'll search for that." in collected + + +# ============================================================================= +# Tag Split Across Boundaries (carry buffer tests) +# ============================================================================= + + +class TestScrubberCarryBuffer: + """Test the carry buffer behavior for tags split across deltas.""" + + def test_tag_split_at_every_character(self): + """Split one char at a time.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + for ch in "beforehiddenafter": + collected += scrubber.feed(ch) + collected += scrubber.flush() + assert "hidden" not in collected + assert "before" in collected + assert "after" in collected + + def test_close_tag_split_at_every_character(self): + """Split one char at a time.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + for ch in "suppressedvisible": + collected += scrubber.feed(ch) + collected += scrubber.flush() + assert "suppressed" not in collected + assert "visible" in collected + + def test_tool_call_tag_split_at_every_character(self): + """Split one char at a time.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + for ch in "prebodypost": + collected += scrubber.feed(ch) + collected += scrubber.flush() + assert "body" not in collected + assert "pre" in collected + assert "post" in collected + + def test_carry_cleared_after_flush(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("some text") + scrubber.flush() + assert scrubber.carry == "" + + def test_carry_cleared_after_full_consumption(self): + """When all content is consumed, carry should be empty.""" + scrubber = _AnthropicStreamScrubber() + # Feed a complete tag that consumes everything + scrubber.feed("x") + result = scrubber.flush() + assert scrubber.carry == "" + + +# ============================================================================= +# flush() Behavior +# ============================================================================= + + +class TestScrubberFlush: + """Test flush() method at end of stream.""" + + def test_flush_emits_remaining_text(self): + scrubber = _AnthropicStreamScrubber() + # "hi" has no '<' so is emitted immediately by feed(). + result = scrubber.feed("hi") + assert result == "hi" + # flush() should return empty since carry is empty. + flushed = scrubber.flush() + assert flushed == "" + + def test_flush_emits_carry_with_angle_bracket(self): + """Text ending with '<' is held in carry; flush emits it.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("text<") + assert scrubber.carry == "<" + flushed = scrubber.flush() + # '<' alone is not a valid tag, flush strips nothing extra + assert flushed == "<" + + def test_flush_in_think_mode_discards(self): + """If stream ends while inside , flush returns empty.""" + scrubber = _AnthropicStreamScrubber() + scrubber.feed("unfinished") + assert scrubber.mode == "IN_THINK" + flushed = scrubber.flush() + assert flushed == "" + assert scrubber.carry == "" + + def test_flush_in_toolcall_mode_discards(self): + """If stream ends while inside , flush returns empty.""" + scrubber = _AnthropicStreamScrubber() + scrubber.feed("unfinished") + assert scrubber.mode == "IN_TOOLCALL" + flushed = scrubber.flush() + assert flushed == "" + + def test_flush_in_function_mode_discards(self): + """If stream ends while inside , flush returns empty.""" + scrubber = _AnthropicStreamScrubber() + scrubber.feed("unfinished") + assert scrubber.mode == "IN_FUNCTION" + flushed = scrubber.flush() + assert flushed == "" + + def test_flush_strips_residual_exact_tags(self): + """flush() in TEXT mode strips any leftover exact tags from carry.""" + scrubber = _AnthropicStreamScrubber() + # Manually set carry to simulate leftover tag fragment + scrubber.carry = "textleftover" + scrubber.mode = "TEXT" + flushed = scrubber.flush() + assert "" not in flushed + assert "text" in flushed + + def test_flush_strips_residual_function_tags(self): + """flush() strips residual from carry.""" + scrubber = _AnthropicStreamScrubber() + scrubber.carry = "beforeafter" + scrubber.mode = "TEXT" + flushed = scrubber.flush() + assert " from carry.""" + scrubber = _AnthropicStreamScrubber() + scrubber.carry = "beforeafter" + scrubber.mode = "TEXT" + flushed = scrubber.flush() + assert "") + assert scrubber.mode == "IN_THINK" + + def test_think_back_to_text(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("content") + # After consuming , should be back in TEXT + assert scrubber.mode == "TEXT" + + def test_text_to_toolcall(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("text") + assert scrubber.mode == "IN_TOOLCALL" + + def test_toolcall_back_to_text(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("data") + assert scrubber.mode == "TEXT" + + def test_text_to_function(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("text") + assert scrubber.mode == "IN_FUNCTION" + + def test_function_back_to_text(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("body") + assert scrubber.mode == "TEXT" + + def test_text_to_parameter_enters_parameter_mode(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("text") + assert scrubber.mode == "IN_PARAMETER" + + def test_parameter_back_to_text(self): + scrubber = _AnthropicStreamScrubber() + scrubber.feed("body") + assert scrubber.mode == "TEXT" + + def test_stray_close_stays_in_text(self): + """Stray closing tags should not change mode.""" + scrubber = _AnthropicStreamScrubber() + scrubber.feed("textmore") + assert scrubber.mode == "TEXT" + + scrubber2 = _AnthropicStreamScrubber() + scrubber2.feed("textmore") + assert scrubber2.mode == "TEXT" + + +# ============================================================================= +# Edge Cases +# ============================================================================= + + +class TestScrubberEdgeCases: + """Edge cases and boundary conditions.""" + + def test_empty_think_block(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("beforeafter") + result += scrubber.flush() + assert "before" in result + assert "after" in result + assert "" not in result + + def test_empty_tool_call_block(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("beforeafter") + result += scrubber.flush() + assert "before" in result + assert "after" in result + + def test_angle_brackets_in_plain_text(self): + """Plain < and > that aren't tags should eventually be emitted.""" + scrubber = _AnthropicStreamScrubber() + # These don't form valid tags so should pass through + result = scrubber.feed("x < y and a > b are normal math expressions here") + result += scrubber.flush() + assert "x < y" in result or ("x" in result and "< y" in result) + assert "a > b" in result or ("a" in result and "> b" in result) + + def test_partial_tag_that_is_not_a_tag(self): + """Something like '' shouldn't be treated as .""" + scrubber = _AnthropicStreamScrubber() + # "" is not the same as ">" + # Actually is an exact match so has inside it + # The scrubber will find at the start... let's see + # Actually "" contains "" exactly + # Let me re-check: "" – scanning for "" won't match + # because the 7th char 'a' != '>' + result = scrubber.feed("atest") + result += scrubber.flush() + # is not a recognized tag, should pass through + assert "test" in result + + def test_only_tags_no_text(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("hidden") + result += scrubber.flush() + assert result == "" + + def test_consecutive_think_blocks_no_gap(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("aaabbbvisible") + result += scrubber.flush() + assert "aaa" not in result # suppressed + assert "bbb" not in result # suppressed + assert "visible" in result + + def test_very_long_suppressed_content(self): + """Test with large content inside tags.""" + scrubber = _AnthropicStreamScrubber() + long_content = "x" * 10000 + text = f"before{long_content}after" + result = scrubber.feed(text) + result += scrubber.flush() + assert long_content not in result + assert "before" in result + assert "after" in result + + def test_unicode_text_preserved(self): + """Unicode text outside tags should be preserved.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("Héllo wörld 你好 🌍secret done") + result += scrubber.flush() + assert "secret" not in result + assert "Héllo" in result + assert "done" in result + + def test_newlines_between_tags(self): + scrubber = _AnthropicStreamScrubber() + text = "line1\nhidden\n\nline2" + result = scrubber.feed(text) + result += scrubber.flush() + assert "hidden" not in result + assert "line1" in result + assert "line2" in result + + def test_back_to_back_different_tags(self): + scrubber = _AnthropicStreamScrubber() + text = "ttcfcend" + result = scrubber.feed(text) + result += scrubber.flush() + assert "end" in result + # All tagged content suppressed + for s in [ + "", + "", + "", + "", + "", + ]: + assert s not in result + + +# ============================================================================= +# _find_earliest_marker Internal Method +# ============================================================================= + + +class TestFindEarliestMarker: + """Test the _find_earliest_marker helper directly.""" + + def test_no_markers(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber._find_earliest_marker("hello world", 0) + assert result is None + + def test_finds_think_open(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber._find_earliest_marker("textmore", 0) + assert result is not None + pos, marker, consume = result + assert pos == 4 + assert marker == "" + assert consume == len("") + + def test_finds_earliest_of_multiple(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber._find_earliest_marker("text", 0) + assert result is not None + pos, marker, _ = result + assert pos == 0 + assert marker == "" + + def test_finds_function_prefix(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber._find_earliest_marker("textmore", 0) + assert result is not None + pos, marker, consume = result + assert pos == 4 + assert marker == "") + + def test_function_prefix_missing_close_angle(self): + """If '>' is missing for a prefix tag, consume should be -1.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber._find_earliest_marker("textmore", 7) + assert result is not None + pos, _, _ = result + assert pos == 11 # second + + def test_parameter_prefix(self): + scrubber = _AnthropicStreamScrubber() + result = scrubber._find_earliest_marker("val", 0) + assert result is not None + pos, marker, consume = result + assert pos == 0 + assert marker == "") + + +# ============================================================================= +# Integration with Scrubber in Streaming Context +# ============================================================================= + + +class TestScrubberStreamingIntegration: + """Simulate real streaming patterns to ensure correctness end-to-end.""" + + def _stream_through(self, scrubber, deltas): + """Feed a list of deltas through the scrubber, return collected output.""" + collected = "" + for d in deltas: + collected += scrubber.feed(d) + collected += scrubber.flush() + return collected + + def test_clean_text_passthrough(self): + """Normal text with no tags should come through unchanged.""" + scrubber = _AnthropicStreamScrubber() + text = "The weather today is sunny and warm." + words = text.split(" ") + deltas = [w + " " for w in words[:-1]] + [words[-1]] + result = self._stream_through(scrubber, deltas) + assert result == text + + def test_think_then_answer_streaming(self): + """Model thinks, then answers.""" + scrubber = _AnthropicStreamScrubber() + deltas = [ + "", + "Let me reason about this...\n", + "The user wants weather info.\n", + "", + "The weather ", + "is sunny ", + "today.", + ] + result = self._stream_through(scrubber, deltas) + assert "reason" not in result + assert "The weather is sunny today." in result + + def test_tool_call_json_streaming(self): + """Model emits tool call JSON in small chunks.""" + scrubber = _AnthropicStreamScrubber() + deltas = [ + "Let me look that up.", + "", + '{"', + 'name": "', + "get_weather", + '", "arguments', + '": {"city": "', + "San Francisco", + '"}}', + "", + ] + result = self._stream_through(scrubber, deltas) + assert "Let me look that up." in result + assert "get_weather" not in result + assert "San Francisco" not in result + + def test_think_then_tool_call_streaming(self): + """Model reasons then makes a tool call.""" + scrubber = _AnthropicStreamScrubber() + deltas = [ + "", + "I need to ", + "search for this.", + "", + "I'll help with that.", + "", + '{"name":"search"}', + "", + ] + result = self._stream_through(scrubber, deltas) + assert "I need to" not in result + assert "I'll help with that." in result + assert "search" not in result or result == "I'll help with that." + + def test_scrubber_reuse_not_recommended(self): + """After flush, feeding more data should still work (even if atypical).""" + scrubber = _AnthropicStreamScrubber() + r1 = scrubber.feed("first") + r1 += scrubber.flush() + # Reuse + r2 = scrubber.feed("second") + r2 += scrubber.flush() + assert "first" in r1 + assert "second" in r2 + + +# ============================================================================= +# Zero-Latency Carry: plain text should not be held back +# ============================================================================= + + +class TestScrubberZeroLatencyCarry: + """Verify that the conditional carry buffer doesn't stall plain text.""" + + def test_plain_text_emits_immediately(self): + """No '<' in text → carry should be empty, all text emitted.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("Hello world") + assert result == "Hello world" + assert scrubber.carry == "" + + def test_plain_deltas_no_carry(self): + """Multiple plain deltas should each emit fully.""" + scrubber = _AnthropicStreamScrubber() + for word in ["The ", "quick ", "brown ", "fox."]: + result = scrubber.feed(word) + assert result == word + assert scrubber.carry == "" + + def test_angle_bracket_at_end_triggers_carry(self): + """A '<' near the end should be held in carry (could be tag start).""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("text<") + assert "<" not in result + assert scrubber.carry == "<" + + def test_angle_bracket_resolved_next_delta(self): + """Carry '<' is resolved when next delta shows it's not a tag.""" + scrubber = _AnthropicStreamScrubber() + r1 = scrubber.feed("value < ") + r2 = scrubber.feed("other") + r2 += scrubber.flush() + full = r1 + r2 + assert "value < other" in full + + def test_angle_bracket_resolved_as_tag(self): + """Carry '<' is resolved when next delta completes a tag.""" + scrubber = _AnthropicStreamScrubber() + r1 = scrubber.feed("before<") + r2 = scrubber.feed("think>hiddenafter") + r2 += scrubber.flush() + full = r1 + r2 + assert "before" in full + assert "after" in full + assert "hidden" not in full + + def test_first_token_emits_immediately(self): + """The very first token should not be stalled by carry buffer.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("Hi") + assert result == "Hi" + + def test_long_plain_text_no_carry(self): + """Long text with no '<' should all be emitted, carry empty.""" + scrubber = _AnthropicStreamScrubber() + text = "A" * 500 + result = scrubber.feed(text) + assert result == text + assert scrubber.carry == "" + + +# ============================================================================= +# Carry Buffer Cap (unbounded growth prevention) +# ============================================================================= + + +class TestScrubberCarryCap: + """Test that the carry buffer is capped to prevent unbounded growth.""" + + def test_max_carry_constant_exists(self): + """MAX_CARRY constant should be defined and reasonable.""" + assert hasattr(_AnthropicStreamScrubber, "MAX_CARRY") + assert _AnthropicStreamScrubber.MAX_CARRY > _AnthropicStreamScrubber.CARRY_N + + def test_carry_cap_emits_as_literal(self): + """If a prefix tag never closes (>), carry cap emits content as literal text.""" + scrubber = _AnthropicStreamScrubber() + collected = "" + collected += scrubber.feed("text '.""" + + def test_flush_strips_incomplete_function_prefix(self): + """flush() strips incomplete .""" + scrubber = _AnthropicStreamScrubber() + scrubber.carry = "text.""" + scrubber = _AnthropicStreamScrubber() + scrubber.carry = "text, flush returns empty.""" + scrubber = _AnthropicStreamScrubber() + scrubber.feed("unfinished") + assert scrubber.mode == "IN_PARAMETER" + flushed = scrubber.flush() + assert flushed == "" + + +# ============================================================================= +# IN_PARAMETER State (separate from IN_FUNCTION) +# ============================================================================= + + +class TestScrubberParameterState: + """Test the IN_PARAMETER state closes on (not ).""" + + def test_parameter_closes_on_parameter_tag(self): + """... correctly suppresses and closes.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed("beforeNYCafter") + result += scrubber.flush() + assert "NYC" not in result + assert "before" in result + assert "after" in result + assert scrubber.mode == "TEXT" + + def test_parameter_does_not_close_on_function_tag(self): + """IN_PARAMETER mode should NOT close on .""" + scrubber = _AnthropicStreamScrubber() + scrubber.feed("contentstill_suppressed") + # Should still be in IN_PARAMETER since doesn't close it + assert scrubber.mode == "IN_PARAMETER" + + def test_standalone_parameter_with_text_after(self): + """Standalone value followed by text.""" + scrubber = _AnthropicStreamScrubber() + result = scrubber.feed( + "search term Here is my response." + ) + result += scrubber.flush() + assert "search term" not in result + assert "" not in result + assert "Here is my response." in result + + +# ============================================================================= +# Router Inherits Scrubber Behavior +# ============================================================================= + + +class TestRouterInheritsScrubber: + """Test that the Router (subclass) inherits all Scrubber capabilities.""" + + def test_router_is_subclass(self): + from vllm_mlx.server import _AnthropicStreamRouter + + assert issubclass(_AnthropicStreamRouter, _AnthropicStreamScrubber) + + def test_router_has_carry_cap(self): + from vllm_mlx.server import _AnthropicStreamRouter + + router = _AnthropicStreamRouter() + assert hasattr(router, "MAX_CARRY") + assert router.MAX_CARRY == _AnthropicStreamScrubber.MAX_CARRY + + def test_router_parameter_mode(self): + """Router should handle IN_PARAMETER the same as scrubber.""" + from vllm_mlx.server import _AnthropicStreamRouter + + router = _AnthropicStreamRouter() + pieces = router.feed("textsuppressedafter") + pieces += router.flush() + text = "".join(t for k, t in pieces if k == "text") + assert "suppressed" not in text + assert "text" in text + assert "after" in text + + +# ============================================================================= +# _AnthropicStreamRouter Tests +# ============================================================================= + +from vllm_mlx.server import _AnthropicStreamRouter, _is_thinking_enabled + + +class TestRouterInitialState: + """Test router creation and initial state.""" + + def test_initial_mode_is_text(self): + router = _AnthropicStreamRouter() + assert router.mode == "TEXT" + + def test_initial_carry_is_empty(self): + router = _AnthropicStreamRouter() + assert router.carry == "" + + +class TestRouterPlainText: + """Router should pass through plain text as ('text', ...) pieces.""" + + def test_plain_text_emits_text_piece(self): + router = _AnthropicStreamRouter() + pieces = router.feed("Hello world") + assert len(pieces) == 1 + assert pieces[0] == ("text", "Hello world") + + def test_empty_string(self): + router = _AnthropicStreamRouter() + pieces = router.feed("") + assert pieces == [] + + def test_multiple_plain_deltas(self): + router = _AnthropicStreamRouter() + all_text = "" + for word in ["The ", "quick ", "brown ", "fox."]: + for kind, text in router.feed(word): + assert kind == "text" + all_text += text + for kind, text in router.flush(): + if kind == "text": + all_text += text + assert all_text == "The quick brown fox." + + +class TestRouterThinkingBlocks: + """Router should emit thinking_start/thinking/thinking_stop for blocks.""" + + def test_think_block_single_delta(self): + router = _AnthropicStreamRouter() + pieces = router.feed("Hello reasoning world") + pieces += router.flush() + + kinds = [k for k, _ in pieces] + assert "thinking_start" in kinds + assert "thinking" in kinds + assert "thinking_stop" in kinds + + # Collect text and thinking separately + text = "".join(t for k, t in pieces if k == "text") + thinking = "".join(t for k, t in pieces if k == "thinking") + assert "Hello" in text + assert "world" in text + assert "reasoning" in thinking + + def test_think_block_split_across_deltas(self): + router = _AnthropicStreamRouter() + all_pieces = [] + for delta in ["Before secret After"]: + all_pieces.extend(router.feed(delta)) + all_pieces.extend(router.flush()) + + text = "".join(t for k, t in all_pieces if k == "text") + thinking = "".join(t for k, t in all_pieces if k == "thinking") + assert "Before" in text + assert "After" in text + assert "secret" in thinking + + def test_think_then_text_streaming(self): + """Simulate realistic think-then-answer streaming.""" + router = _AnthropicStreamRouter() + all_pieces = [] + for delta in ["", "Let me reason.", "", "The answer."]: + all_pieces.extend(router.feed(delta)) + all_pieces.extend(router.flush()) + + kinds = [k for k, _ in all_pieces] + text = "".join(t for k, t in all_pieces if k == "text") + thinking = "".join(t for k, t in all_pieces if k == "thinking") + + assert "thinking_start" in kinds + assert "thinking_stop" in kinds + assert "Let me reason." in thinking + assert "The answer." in text + + def test_multiple_think_blocks(self): + router = _AnthropicStreamRouter() + pieces = router.feed("Ar1Br2C") + pieces += router.flush() + + text = "".join(t for k, t in pieces if k == "text") + thinking = "".join(t for k, t in pieces if k == "thinking") + starts = sum(1 for k, _ in pieces if k == "thinking_start") + stops = sum(1 for k, _ in pieces if k == "thinking_stop") + + assert "A" in text + assert "B" in text + assert "C" in text + assert "r1" in thinking + assert "r2" in thinking + assert starts == 2 + assert stops == 2 + + def test_unclosed_think_at_end(self): + """Unclosed at end should flush remaining as thinking.""" + router = _AnthropicStreamRouter() + pieces = router.feed("unfinished") + pieces += router.flush() + + kinds = [k for k, _ in pieces] + thinking = "".join(t for k, t in pieces if k == "thinking") + assert "thinking_start" in kinds + assert "thinking_stop" in kinds # flush closes it + assert "unfinished" in thinking + + +class TestRouterToolCallSuppression: + """Router should suppress tool_call/function/parameter content (no pieces).""" + + def test_tool_call_suppressed(self): + router = _AnthropicStreamRouter() + pieces = router.feed('Before {"fn":"x"} After') + pieces += router.flush() + + text = "".join(t for k, t in pieces if k == "text") + all_content = "".join(t for _, t in pieces) + assert "Before" in text + assert "After" in text + assert '{"fn":"x"}' not in all_content + + def test_function_tag_suppressed(self): + router = _AnthropicStreamRouter() + pieces = router.feed("bodyafter") + pieces += router.flush() + + text = "".join(t for k, t in pieces if k == "text") + assert "after" in text + assert "body" not in "".join(t for _, t in pieces) + + +class TestRouterMixedContent: + """Test router with think + tool_call combined.""" + + def test_think_then_tool_call(self): + router = _AnthropicStreamRouter() + text_input = "reasoningvisibledataend" + pieces = router.feed(text_input) + pieces += router.flush() + + text = "".join(t for k, t in pieces if k == "text") + thinking = "".join(t for k, t in pieces if k == "thinking") + + assert "reasoning" in thinking + assert "visible" in text + assert "end" in text + assert "data" not in text + assert "data" not in thinking + + def test_realistic_streaming(self): + """Token-by-token streaming with thinking.""" + router = _AnthropicStreamRouter() + tokens = [ + "<", + "think", + ">", + "Let", + " me", + " check", + "", + "The", + " answer", + " is", + " 42", + ".", + ] + all_pieces = [] + for tok in tokens: + all_pieces.extend(router.feed(tok)) + all_pieces.extend(router.flush()) + + text = "".join(t for k, t in all_pieces if k == "text") + thinking = "".join(t for k, t in all_pieces if k == "thinking") + + assert "Let me check" in thinking + assert "The answer is 42." in text + + +class TestRouterFlush: + """Test router flush() behavior.""" + + def test_flush_text_mode(self): + router = _AnthropicStreamRouter() + router.feed("text<") # '<' held in carry + pieces = router.flush() + # Should emit the '<' as text + text = "".join(t for k, t in pieces if k == "text") + assert "<" in text + + def test_flush_in_think_mode(self): + router = _AnthropicStreamRouter() + router.feed("leftover") + pieces = router.flush() + kinds = [k for k, _ in pieces] + assert "thinking" in kinds + assert "thinking_stop" in kinds + + def test_flush_in_toolcall_mode(self): + router = _AnthropicStreamRouter() + router.feed("stuff") + pieces = router.flush() + # Should discard (tool_call content suppressed) + assert pieces == [] + + +# ============================================================================= +# _is_thinking_enabled Helper +# ============================================================================= + + +class TestIsThinkingEnabled: + """Test the _is_thinking_enabled helper function.""" + + def test_none_thinking(self): + from vllm_mlx.api.anthropic_models import AnthropicRequest, AnthropicMessage + + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + thinking=None, + ) + assert _is_thinking_enabled(req) is False + + def test_no_thinking_field(self): + from vllm_mlx.api.anthropic_models import AnthropicRequest, AnthropicMessage + + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + ) + assert _is_thinking_enabled(req) is False + + def test_thinking_enabled_dict(self): + from vllm_mlx.api.anthropic_models import AnthropicRequest, AnthropicMessage + + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + thinking={"type": "enabled", "budget_tokens": 5000}, + ) + assert _is_thinking_enabled(req) is True + + def test_thinking_disabled_dict(self): + from vllm_mlx.api.anthropic_models import AnthropicRequest, AnthropicMessage + + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + thinking={"type": "disabled"}, + ) + assert _is_thinking_enabled(req) is False + + def test_thinking_enabled_model(self): + from vllm_mlx.api.anthropic_models import ( + AnthropicRequest, + AnthropicMessage, + AnthropicThinkingConfig, + ) + + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + thinking=AnthropicThinkingConfig(type="enabled", budget_tokens=8000), + ) + assert _is_thinking_enabled(req) is True + + def test_thinking_disabled_model(self): + from vllm_mlx.api.anthropic_models import ( + AnthropicRequest, + AnthropicMessage, + AnthropicThinkingConfig, + ) + + req = AnthropicRequest( + model="test", + messages=[AnthropicMessage(role="user", content="hi")], + max_tokens=100, + thinking=AnthropicThinkingConfig(type="disabled"), + ) + assert _is_thinking_enabled(req) is False diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py index a5bc6f776..0831916bf 100644 --- a/vllm_mlx/api/anthropic_models.py +++ b/vllm_mlx/api/anthropic_models.py @@ -50,6 +50,13 @@ class AnthropicToolDef(BaseModel): input_schema: dict | None = None +class AnthropicThinkingConfig(BaseModel): + """Configuration for extended thinking (Anthropic streaming).""" + + type: str = "enabled" # "enabled" or "disabled" + budget_tokens: int | None = None + + class AnthropicRequest(BaseModel): """Request for Anthropic Messages API.""" @@ -65,6 +72,7 @@ class AnthropicRequest(BaseModel): tool_choice: dict | None = None metadata: dict | None = None top_k: int | None = None + thinking: AnthropicThinkingConfig | dict | None = None # ============================================================================= diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index f0328d4e6..e176834ed 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -42,6 +42,7 @@ import json import logging import os +import re import secrets import tempfile import threading @@ -1666,6 +1667,333 @@ async def count_anthropic_tokens(request: Request): return {"input_tokens": total_tokens} +class _AnthropicStreamScrubber: + """Stateful scrubber that strips model tool-call and reasoning markup + from streamed text deltas on the Anthropic ``/v1/messages`` endpoint. + + Suppressed patterns: + + * ``...`` – internal reasoning blocks + * ``...`` – Qwen/Hermes-style tool calls + * ``...`` – Llama-style tool calls + * ``...`` – Llama-style parameters + * Stray closing tags (````, ````, ````, + ````) appearing outside their expected context + + Handles tags that may be split across multiple token boundaries by + maintaining a small carry buffer. The carry buffer always retains + the last ``CARRY_N`` characters so that a tag split across two + consecutive deltas can still be detected. A cap (``MAX_CARRY``) + prevents unbounded growth when a prefix tag like ```` never arrives. + + The scrubber operates as a simple state machine: + + * **TEXT** – emit characters; scan for opening/stray-closing tags. + * **IN_THINK** – suppress until ````. + * **IN_TOOLCALL** – suppress until ````. + * **IN_FUNCTION** – suppress until ````. + * **IN_PARAMETER** – suppress until ````. + """ + + # --- Fixed (exact-match) tags ---------------------------------------- + THINK_OPEN = "" + THINK_CLOSE = "" + TOOL_OPEN = "" + TOOL_CLOSE = "" + FUNC_CLOSE = "" + PARAM_CLOSE = "" + + # Exact tags to scan for in TEXT mode. Order doesn't matter – we + # always pick the earliest match. + _EXACT_TAGS = [ + THINK_OPEN, + THINK_CLOSE, + TOOL_OPEN, + TOOL_CLOSE, + FUNC_CLOSE, + PARAM_CLOSE, + ] + + # --- Prefix (variable-length) opening tags --------------------------- + # These look like ```` or ```` where + # the name varies. We detect the prefix then scan forward for ``>``. + FUNC_PREFIX = "`` never arrives, carry keeps accumulating each + # new delta. Cap it and emit as literal text to prevent unbounded + # growth. + MAX_CARRY = CARRY_N + 256 + + # Map from opening signal → suppression mode + _MODE_MAP = { + THINK_OPEN: "IN_THINK", + TOOL_OPEN: "IN_TOOLCALL", + FUNC_PREFIX: "IN_FUNCTION", + PARAM_PREFIX: "IN_PARAMETER", + } + + # Map from suppression mode → closing tag + _CLOSE_MAP = { + "IN_THINK": THINK_CLOSE, + "IN_TOOLCALL": TOOL_CLOSE, + "IN_FUNCTION": FUNC_CLOSE, + "IN_PARAMETER": PARAM_CLOSE, + } + + def __init__(self) -> None: + self.mode: str = "TEXT" + self.carry: str = "" + + # ----------------------------------------------------------------- + # Internal helpers + # ----------------------------------------------------------------- + + def _find_earliest_marker(self, s: str, start: int) -> tuple[int, str, int] | None: + """Find the earliest opening or stray-closing tag in *s* from *start*. + + Returns ``(position, marker, consume_length)`` or ``None``. + *consume_length* is how many characters to skip past the marker + (for exact tags this equals ``len(marker)``; for prefix tags it + extends to the closing ``>``). + """ + best: tuple[int, str, int] | None = None + + # Check exact tags. + for tag in self._EXACT_TAGS: + pos = s.find(tag, start) + if pos != -1 and (best is None or pos < best[0]): + best = (pos, tag, len(tag)) + + # Check prefix tags (e.g. ````). + for prefix in self._PREFIX_TAGS: + pos = s.find(prefix, start) + if pos != -1 and (best is None or pos < best[0]): + # Need to find the closing '>' to know full tag length. + gt = s.find(">", pos + len(prefix)) + if gt != -1: + consume = gt + 1 - pos # e.g. len("") + best = (pos, prefix, consume) + else: + # '>' not yet in buffer – treat as a partial tag. + # consume = -1 signals "truncated". + best = (pos, prefix, -1) + + return best + + # ----------------------------------------------------------------- + # Core processing – returns typed pieces + # ----------------------------------------------------------------- + + def _feed_pieces(self, delta: str) -> list[tuple[str, str]]: + """Core processing: returns ``(kind, text)`` pieces. + + *kind* is one of ``"text"``, ``"thinking_start"``, + ``"thinking"``, ``"thinking_stop"``. Tool / function / + parameter content is silently dropped (no pieces emitted). + + Used directly by :class:`_AnthropicStreamRouter`; the + scrubber's public :meth:`feed` extracts only ``"text"`` pieces. + """ + s = self.carry + (delta or "") + pieces: list[tuple[str, str]] = [] + slen = len(s) + i = 0 + + while i < slen: + if self.mode == "TEXT": + hit = self._find_earliest_marker(s, i) + + if hit is None: + # No marker anywhere. Only retain a carry suffix if + # there is a '<' near the tail that could be the start + # of a split tag. Otherwise emit everything immediately + # so plain text streams with zero latency. + tail = s[max(i, slen - self.CARRY_N) :] + lt_pos = tail.rfind("<") + if lt_pos != -1: + # Keep from the '<' onward as carry. + carry_start = max(i, slen - self.CARRY_N) + lt_pos + if carry_start > i: + pieces.append(("text", s[i:carry_start])) + self.carry = s[carry_start:] + else: + # No '<' in tail – emit everything. + if slen > i: + pieces.append(("text", s[i:])) + self.carry = "" + return pieces + + pos, marker, consume = hit + + if consume < 0: + # Prefix tag found but closing '>' missing – truncated. + if pos > i: + pieces.append(("text", s[i:pos])) + # Cap carry growth: if the partial prefix tag region + # exceeds MAX_CARRY, treat it as plain text. + if slen - pos > self.MAX_CARRY: + pieces.append(("text", s[pos:])) + self.carry = "" + else: + self.carry = s[pos:] + return pieces + + tag_end = pos + consume + if tag_end > slen: + # Full tag not in buffer yet. + if pos > i: + pieces.append(("text", s[i:pos])) + self.carry = s[pos:] + return pieces + + # Emit text before the tag. + if pos > i: + pieces.append(("text", s[i:pos])) + + # Consume the tag. + i = tag_end + + # Determine new mode (if any). + new_mode = self._MODE_MAP.get(marker) + if new_mode: + self.mode = new_mode + if new_mode == "IN_THINK": + pieces.append(("thinking_start", "")) + # else: stray closing tag – consumed and suppressed, stay TEXT. + + elif self.mode == "IN_THINK": + # Find closing . + close_pos = s.find(self.THINK_CLOSE, i) + if close_pos == -1: + # Emit thinking content up to carry boundary. + safe_end = max(i, slen - self.CARRY_N) + if safe_end > i: + pieces.append(("thinking", s[i:safe_end])) + self.carry = s[safe_end:] + return pieces + # Emit thinking content before closing tag. + if close_pos > i: + pieces.append(("thinking", s[i:close_pos])) + pieces.append(("thinking_stop", "")) + i = close_pos + len(self.THINK_CLOSE) + self.mode = "TEXT" + + else: + # IN_TOOLCALL, IN_FUNCTION, or IN_PARAMETER – suppress. + close_tag = self._CLOSE_MAP[self.mode] + close_pos = s.find(close_tag, i) + if close_pos == -1: + # Closing tag not yet in buffer. + self.carry = s[max(i, slen - self.CARRY_N) :] + return pieces + i = close_pos + len(close_tag) + self.mode = "TEXT" + + # Entire buffer consumed. + self.carry = "" + return pieces + + def _flush_pieces(self) -> list[tuple[str, str]]: + """Core flush processing: returns ``(kind, text)`` pieces.""" + pieces: list[tuple[str, str]] = [] + if self.mode == "IN_THINK": + # Emit any remaining thinking content. + if self.carry: + pieces.append(("thinking", self.carry)) + pieces.append(("thinking_stop", "")) + elif self.mode == "TEXT" and self.carry: + result = self.carry + # Strip residual exact tags. + for tag in self._EXACT_TAGS: + result = result.replace(tag, "") + # Strip complete prefix tags (e.g. ````). + result = re.sub(r"]*>", "", result) + result = re.sub(r"]*>", "", result) + # Strip incomplete prefix tags without closing ``>``. + result = re.sub(r"]*$", "", result) + result = re.sub(r"]*$", "", result) + if result: + pieces.append(("text", result)) + # IN_TOOLCALL / IN_FUNCTION / IN_PARAMETER – discard. + self.carry = "" + self.mode = "TEXT" + return pieces + + # ----------------------------------------------------------------- + # Public API + # ----------------------------------------------------------------- + + def feed(self, delta: str) -> str: + """Process a new text delta and return only the safe-to-emit portion. + + Think content is suppressed (dropped). + """ + pieces = self._feed_pieces(delta) + return "".join(text for kind, text in pieces if kind == "text") + + def flush(self) -> str: + """Flush remaining carry buffer at end of stream. + + Emits leftover text only in TEXT mode (stripping any stray tags); + discards carry if inside a suppressed region. + """ + pieces = self._flush_pieces() + return "".join(text for kind, text in pieces if kind == "text") + + +class _AnthropicStreamRouter(_AnthropicStreamScrubber): + """Stream router that translates ```` regions into Anthropic + ``thinking_delta`` events while still suppressing tool-call markup. + + Subclasses :class:`_AnthropicStreamScrubber` to reuse tag detection, + carry-buffer management, and the core state-machine loop. Instead + of dropping ```` content, the router exposes the typed + ``(kind, text)`` pieces from ``_feed_pieces`` / ``_flush_pieces`` + so they can be emitted on separate content-block channels. + + ``feed()`` returns a list of ``(kind, text)`` tuples: + + * ``("text", "...")`` – normal text for ``text_delta`` + * ``("thinking_start", "")`` – signals start of a thinking block + * ``("thinking", "...")`` – thinking content for ``thinking_delta`` + * ``("thinking_stop", "")`` – signals end of a thinking block + * Tool-call / function / parameter content is silently suppressed. + """ + + def __init__(self, start_in_thinking: bool = False) -> None: + super().__init__() + if start_in_thinking: + self.mode = "IN_THINK" + + def feed(self, delta: str) -> list[tuple[str, str]]: + """Process a delta and return a list of ``(kind, text)`` pieces.""" + return self._feed_pieces(delta) + + def flush(self) -> list[tuple[str, str]]: + """Flush at end of stream.""" + return self._flush_pieces() + + +def _is_thinking_enabled(anthropic_request: AnthropicRequest) -> bool: + """Check if the client has requested extended thinking.""" + thinking = getattr(anthropic_request, "thinking", None) + if thinking is None: + return False + if isinstance(thinking, dict): + return thinking.get("type") == "enabled" + return getattr(thinking, "type", None) == "enabled" + + async def _stream_anthropic_messages( engine: BaseEngine, openai_request: ChatCompletionRequest, @@ -1677,6 +2005,10 @@ async def _stream_anthropic_messages( Converts OpenAI streaming chunks to Anthropic event format: message_start -> content_block_start -> content_block_delta* -> content_block_stop -> message_delta -> message_stop + + A streaming scrubber always filters out ... and + ... markup that the model may emit, so + clients only see clean text and structured tool_use blocks. """ msg_id = f"msg_{uuid.uuid4().hex[:24]}" start_time = time.perf_counter() @@ -1715,13 +2047,64 @@ async def _stream_anthropic_messages( } yield f"event: message_start\ndata: {json.dumps(message_start)}\n\n" - # Emit content_block_start for text - content_block_start = { - "type": "content_block_start", - "index": 0, - "content_block": {"type": "text", "text": ""}, - } - yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n" + # Determine whether the client requested extended thinking. + thinking_enabled = _is_thinking_enabled(anthropic_request) + + # Content block index tracking. When thinking is enabled the + # thinking block is emitted first (index 0) and the text block + # follows (index 1). Otherwise only the text block exists (index 0). + # These values are updated dynamically as blocks are opened. + next_block_index = 0 + thinking_block_index: int | None = None + thinking_block_open = False + text_block_index: int | None = None + text_block_open = False + + if thinking_enabled: + # Use the stream router which yields typed (kind, text) pieces + # that separate thinking content from user-facing text. + router: _AnthropicStreamRouter | None = _AnthropicStreamRouter( + start_in_thinking=False + ) + scrubber: _AnthropicStreamScrubber | None = None + + # Open both content blocks upfront so clients know the layout: + # index 0 = thinking block + # index 1 = text block + thinking_block_index = next_block_index + next_block_index += 1 + thinking_block_open = True + ev = { + "type": "content_block_start", + "index": thinking_block_index, + "content_block": {"type": "thinking", "thinking": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(ev)}\n\n" + + text_block_index = next_block_index + next_block_index += 1 + text_block_open = True + ev = { + "type": "content_block_start", + "index": text_block_index, + "content_block": {"type": "text", "text": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(ev)}\n\n" + else: + # Use the scrubber which simply strips all content. + router = None + scrubber = _AnthropicStreamScrubber() + + # Only text block (index 0). + text_block_index = next_block_index + next_block_index += 1 + text_block_open = True + content_block_start = { + "type": "content_block_start", + "index": text_block_index, + "content_block": {"type": "text", "text": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n" # Stream content deltas accumulated_text = "" @@ -1739,24 +2122,122 @@ async def _stream_anthropic_messages( content = SPECIAL_TOKENS_PATTERN.sub("", delta_text) if content: + # Always accumulate the raw (unfiltered) content for tool-call + # parsing at the end of the stream. accumulated_text += content - delta_event = { + + if router is not None: + # ---- Thinking-enabled path (stream router) ---- + # Both blocks are opened upfront (thinking=0, text=1). + for kind, text in router.feed(content): + if kind == "thinking_start": + # Block already opened upfront – nothing to do. + pass + + elif kind == "thinking" and text: + ev = { + "type": "content_block_delta", + "index": thinking_block_index, + "delta": {"type": "thinking_delta", "thinking": text}, + } + yield f"event: content_block_delta\ndata: {json.dumps(ev)}\n\n" + + elif kind == "thinking_stop": + if thinking_block_open: + ev = { + "type": "content_block_stop", + "index": thinking_block_index, + } + yield f"event: content_block_stop\ndata: {json.dumps(ev)}\n\n" + thinking_block_open = False + + elif kind == "text" and text: + ev = { + "type": "content_block_delta", + "index": text_block_index, + "delta": {"type": "text_delta", "text": text}, + } + yield f"event: content_block_delta\ndata: {json.dumps(ev)}\n\n" + + elif scrubber is not None: + # ---- Scrubber path (thinking suppressed) ---- + content = scrubber.feed(content) + if content: + delta_event = { + "type": "content_block_delta", + "index": text_block_index, + "delta": {"type": "text_delta", "text": content}, + } + yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n" + + # Flush remaining carry buffer + if router is not None: + for kind, text in router.flush(): + if kind == "thinking" and text: + if thinking_block_open: + ev = { + "type": "content_block_delta", + "index": thinking_block_index, + "delta": {"type": "thinking_delta", "thinking": text}, + } + yield f"event: content_block_delta\ndata: {json.dumps(ev)}\n\n" + elif kind == "thinking_stop": + if thinking_block_open: + ev = {"type": "content_block_stop", "index": thinking_block_index} + yield f"event: content_block_stop\ndata: {json.dumps(ev)}\n\n" + thinking_block_open = False + elif kind == "text" and text: + if text_block_index is None: + text_block_index = next_block_index + next_block_index += 1 + text_block_open = True + ev = { + "type": "content_block_start", + "index": text_block_index, + "content_block": {"type": "text", "text": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(ev)}\n\n" + ev = { "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": content}, + "index": text_block_index, + "delta": {"type": "text_delta", "text": text}, } - yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n" + yield f"event: content_block_delta\ndata: {json.dumps(ev)}\n\n" + elif scrubber is not None: + flushed = scrubber.flush() + if flushed: + delta_event = { + "type": "content_block_delta", + "index": text_block_index, + "delta": {"type": "text_delta", "text": flushed}, + } + yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n" + + # Ensure text block was opened (even if model produced no text) + if text_block_index is None: + text_block_index = next_block_index + next_block_index += 1 + text_block_open = True + ev = { + "type": "content_block_start", + "index": text_block_index, + "content_block": {"type": "text", "text": ""}, + } + yield f"event: content_block_start\ndata: {json.dumps(ev)}\n\n" # Check for tool calls in accumulated text _, tool_calls = _parse_tool_calls_with_parser(accumulated_text, openai_request) - # Emit content_block_stop for text block - yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0})}\n\n" + # Close any remaining open blocks + if thinking_block_open: + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': thinking_block_index})}\n\n" + if text_block_open: + yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': text_block_index})}\n\n" # If there are tool calls, emit tool_use blocks if tool_calls: for i, tc in enumerate(tool_calls): - tool_index = i + 1 + tool_index = next_block_index + i try: tool_input = json.loads(tc.function.arguments) except (json.JSONDecodeError, AttributeError):