Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions benchmarks/bench_reasoning_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Benchmark: reasoning parser streaming performance.

Measures per-token overhead of extract_reasoning_streaming() at various
output lengths. Demonstrates the difference between O(N²) accumulated
text scanning and O(1) state-machine tracking.

Usage:
python benchmarks/bench_reasoning_parser.py
"""

import time

from vllm_mlx.reasoning.qwen3_parser import Qwen3ReasoningParser


def bench_streaming(parser, n_tokens: int, label: str) -> float:
    """Simulate n_tokens of streaming through the parser. Returns total ms.

    Args:
        parser: A reasoning parser exposing reset_state() and
            extract_reasoning_streaming(previous_text, current_text, delta).
        n_tokens: Number of simulated reasoning tokens between the tags.
        label: Human-readable label used in the printed report line.

    Returns:
        Total elapsed wall-clock time in milliseconds for the full stream.
    """
    parser.reset_state()

    # Simulate: <think> + N reasoning tokens + </think> + 10 content tokens
    tokens = ["<think>"]
    tokens += [f"word{i} " for i in range(n_tokens)]
    tokens += ["</think>"]
    tokens += [f"answer{i} " for i in range(10)]

    accumulated = ""
    start = time.perf_counter()
    for tok in tokens:
        prev = accumulated
        accumulated += tok
        parser.extract_reasoning_streaming(prev, accumulated, tok)
    elapsed = (time.perf_counter() - start) * 1000

    # Per-token average over the actual stream length: len(tokens) is
    # n_tokens + 12 (<think>, n words, </think>, 10 answers).  The old
    # hard-coded "+ 11" divisor was off by one.
    print(f"  {label}: {n_tokens:>6} tokens -> {elapsed:>8.2f}ms "
          f"({elapsed / len(tokens):.3f}ms/tok)")
    return elapsed


def main():
    """Run the streaming benchmark over a range of output lengths."""
    parser = Qwen3ReasoningParser()

    print("Reasoning parser streaming benchmark")
    print("=" * 60)
    print()

    # Sweep from short to long outputs to expose any superlinear scaling.
    sizes = (50, 100, 200, 500, 1000, 2000, 5000)
    for size in sizes:
        bench_streaming(parser, size, f"{size} tokens")

    print()
    print("At 50 tok/s, per-token budget is 20ms.")
    print("Parser overhead should be <0.1ms/tok to be negligible.")


if __name__ == "__main__":
main()
209 changes: 99 additions & 110 deletions vllm_mlx/reasoning/think_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
1. Both tags in output: <think>reasoning</think>content
2. Only closing tag (think injected in prompt): reasoning</think>content
3. No tags: pure content

Performance: The streaming parser uses a simple state machine to track the
current phase (pre-think / thinking / content). Tag completion is detected
against the accumulated text for correctness when `<think>` / `</think>` are
split across delta boundaries, but phase tracking still avoids the old
whole-output rescanning behavior.
"""

from abc import abstractmethod
Expand All @@ -27,8 +33,12 @@ class BaseThinkingReasoningParser(ReasoningParser):
and only </think> appears in the model output. This is common with AI agents
like OpenCode that force models to reason by injecting thinking tags.

The parser tracks state during streaming to correctly separate reasoning
from content as tokens arrive incrementally.
The streaming parser uses a state machine with three phases:

pre_think -> thinking -> content

Transitions are tracked by parser state. Accumulated text is consulted only
to detect when a start/end tag has completed across delta boundaries.
"""

@property
Expand All @@ -43,6 +53,12 @@ def end_token(self) -> str:

def __init__(self, tokenizer=None):
super().__init__(tokenizer)
# Streaming state — reset per request via reset_state()
self._phase: str = "pre_think" # "pre_think" | "thinking" | "content"

def reset_state(self):
"""Reset state machine for a new streaming request."""
self._phase = "pre_think"

def extract_reasoning(
self,
Expand All @@ -66,14 +82,11 @@ def extract_reasoning(

# Case 1: Both tags present (normal case)
if self.start_token in text and self.end_token in text:
# Get everything after start token
_, _, after_start = text.partition(self.start_token)
# Split on end token
reasoning, _, content = after_start.partition(self.end_token)
return reasoning.strip() or None, content.strip() or None

# Case 2: Only closing tag (think was injected in prompt)
# Everything before </think> is reasoning
if self.end_token in text:
reasoning, _, content = text.partition(self.end_token)
return reasoning.strip() or None, content.strip() or None
Expand All @@ -83,7 +96,7 @@ def extract_reasoning(
_, _, reasoning = text.partition(self.start_token)
return reasoning.strip() or None, None

# Case 4: No tags at all - pure content
# Case 4: No tags at all pure content
return None, model_output

def extract_reasoning_streaming(
Expand All @@ -93,123 +106,99 @@ def extract_reasoning_streaming(
delta_text: str,
) -> DeltaMessage | None:
"""
Extract reasoning from streaming delta using text-based detection.
Extract reasoning from a streaming delta using state-machine tracking.

Instead of rescanning the full accumulated text on every token, this
method tracks the current phase (pre_think / thinking / content) and
only consults accumulated text to detect completed start/end tags that
were split across delta boundaries.

Handles implicit reasoning mode where <think> was in the prompt
and only </think> appears in the output.
Handles three scenarios:
1. Explicit <think>...</think> in model output
2. Implicit mode (<think> in prompt, only </think> in output)
3. No tags at all (pure content after first token with no reasoning)

Args:
previous_text: Text accumulated before this delta.
current_text: Text including this delta.
delta_text: Just the new text.
delta_text: Just the new text in this chunk.

Returns:
DeltaMessage with reasoning/content, or None to skip.
DeltaMessage with reasoning and/or content, or None to skip.
"""
# Skip if delta is just the special tokens themselves
stripped_delta = delta_text.strip()
if stripped_delta == self.start_token:
return None
if stripped_delta == self.end_token:
if not delta_text:
return None

# Check token positions in text (stateless text-based detection)
start_in_prev = self.start_token in previous_text
start_in_current = self.start_token in current_text
end_in_prev = self.end_token in previous_text
end_in_delta = self.end_token in delta_text

# Case 1: Explicit <think> found in text - standard behavior
if start_in_current:
return self._handle_explicit_think(
previous_text, delta_text, start_in_prev, end_in_prev, end_in_delta
)

# Case 2: No <think> but </think> found - implicit reasoning mode
# This handles when <think> was injected in the prompt
if self.end_token in current_text:
return self._handle_implicit_think(delta_text, end_in_prev, end_in_delta)

# Case 3: No think tags seen yet
# We can't know if <think> was in the prompt, so we must make a choice:
# - Treat as content (safe, but loses reasoning if think was in prompt)
# - Treat as reasoning (risky, wrong if no thinking at all)
# We choose to treat as reasoning IF we haven't seen </think> yet,
# because if think was in prompt, we want to capture the reasoning.
# This will be corrected once </think> is seen.
return DeltaMessage(reasoning=delta_text)

def _handle_explicit_think(
self,
previous_text: str,
delta_text: str,
start_in_prev: bool,
end_in_prev: bool,
end_in_delta: bool,
) -> DeltaMessage | None:
"""Handle case where <think> tag is explicitly in the output."""
start_in_delta = self.start_token in delta_text

if start_in_prev:
# We're after the start token
if end_in_delta:
# Transition: end token in this delta
idx = delta_text.find(self.end_token)
reasoning_part = delta_text[:idx]
content_part = delta_text[idx + len(self.end_token) :]
start_tok = self.start_token
end_tok = self.end_token

# ── Phase: pre_think ──────────────────────────────────────
# Haven't seen a completed tag yet. Could be:
# - About to see <think> (explicit reasoning)
# - Already inside implicit reasoning (think was in prompt)
# - No reasoning at all (pure content model)
if self._phase == "pre_think":
if start_tok in current_text:
self._phase = "thinking"
idx = delta_text.find(start_tok)
after = delta_text[idx + len(start_tok) :] if idx >= 0 else delta_text

if end_tok in after:
self._phase = "content"
eidx = after.find(end_tok)
reasoning = after[:eidx]
content = after[eidx + len(end_tok) :]
if not reasoning and not content:
return None
return DeltaMessage(
reasoning=reasoning or None,
content=content or None,
)
return DeltaMessage(reasoning=after) if after else None

# Implicit mode: </think> completed without an explicit <think>.
if end_tok in current_text:
self._phase = "content"
idx = delta_text.find(end_tok)
if idx >= 0:
reasoning = delta_text[:idx]
content = delta_text[idx + len(end_tok) :]
else:
reasoning = None
content = delta_text
if not reasoning and not content:
return None
return DeltaMessage(
reasoning=reasoning_part if reasoning_part else None,
content=content_part if content_part else None,
reasoning=reasoning or None,
content=content or None,
)
elif end_in_prev:
# Already past reasoning phase - pure content
return DeltaMessage(content=delta_text)
else:
# Still in reasoning phase
return DeltaMessage(reasoning=delta_text)

elif start_in_delta:
# Start token is in this delta
start_idx = delta_text.find(self.start_token)

if end_in_delta:
# Both tokens in this delta
end_idx = delta_text.find(self.end_token)
reasoning_part = delta_text[start_idx + len(self.start_token) : end_idx]
content_part = delta_text[end_idx + len(self.end_token) :]
return DeltaMessage(
reasoning=reasoning_part if reasoning_part else None,
content=content_part if content_part else None,
)
else:
# Only start token - beginning of reasoning
reasoning_part = delta_text[start_idx + len(self.start_token) :]

# No tags — default to reasoning (implicit mode assumption).
# If the model doesn't use thinking at all, the server's
# non-parser path handles it. This path only activates when
# a reasoning parser is explicitly configured.
return DeltaMessage(reasoning=delta_text)

# ── Phase: thinking ───────────────────────────────────────
# Inside a reasoning block, waiting for end tag.
if self._phase == "thinking":
if end_tok in current_text and end_tok not in previous_text:
self._phase = "content"
idx = delta_text.find(end_tok)
if idx >= 0:
reasoning = delta_text[:idx]
content = delta_text[idx + len(end_tok) :]
else:
reasoning = delta_text
content = None
if not reasoning and not content:
return None
return DeltaMessage(
reasoning=reasoning_part if reasoning_part else None
reasoning=reasoning or None,
content=content or None,
)
return DeltaMessage(reasoning=delta_text)

# Fallback - treat as content
# ── Phase: content ────────────────────────────────────────
# Past the reasoning block — everything is content.
return DeltaMessage(content=delta_text)

def _handle_implicit_think(
self,
delta_text: str,
end_in_prev: bool,
end_in_delta: bool,
) -> DeltaMessage | None:
"""Handle case where <think> was in prompt (only </think> in output)."""
if end_in_delta:
# Transition: end token in this delta
idx = delta_text.find(self.end_token)
reasoning_part = delta_text[:idx]
content_part = delta_text[idx + len(self.end_token) :]
return DeltaMessage(
reasoning=reasoning_part if reasoning_part else None,
content=content_part if content_part else None,
)
elif end_in_prev:
# Already past reasoning phase - pure content
return DeltaMessage(content=delta_text)
else:
# Still in implicit reasoning phase
return DeltaMessage(reasoning=delta_text)
Loading