Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions tests/test_output_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,47 @@ def get_vocab(self) -> dict[str, int]:
TOKENIZER = FakeTokenizer(VOCAB)


# === Qwen3 Token IDs (representative) ===
# A small slice of the real Qwen3 vocabulary: the four special tokens the
# router must recognize, plus a handful of ordinary word tokens used by the
# routing tests below.
QWEN3_VOCAB = {
    "<|endoftext|>": 151643,
    "<|im_end|>": 151645,
    "<think>": 151667,
    "</think>": 151668,
    "Let": 5733,
    "me": 2734,
    "analyze": 28541,
    "The": 785,
    "answer": 10234,
    "is": 374,
    "42": 2983,
    ".": 13,
}

# Derive the control-token ids from the vocab so the map and the vocab can
# never drift apart (Qwen3 reuses <|endoftext|> for both bos and pad).
QWEN3_MAP = TokenMap(
    think_start=QWEN3_VOCAB["<think>"],
    think_end=QWEN3_VOCAB["</think>"],
    bos=QWEN3_VOCAB["<|endoftext|>"],
    eos=QWEN3_VOCAB["<|im_end|>"],
    pad=QWEN3_VOCAB["<|endoftext|>"],
)

QWEN3_TOKENIZER = FakeTokenizer(QWEN3_VOCAB)


@pytest.fixture
def router():
    """Fresh, reset OutputRouter wired to the Gemma4 token map."""
    instance = OutputRouter(GEMMA4_MAP, TOKENIZER)
    instance.reset()
    return instance


@pytest.fixture
def qwen3_router():
    """Fresh, reset OutputRouter wired to the Qwen3 think-tag token map."""
    instance = OutputRouter(QWEN3_MAP, QWEN3_TOKENIZER)
    instance.reset()
    return instance


class TestBasicRouting:
"""Test fundamental token routing."""

Expand Down Expand Up @@ -272,6 +306,13 @@ def test_gemma4_detected(self):
assert router is not None
assert router.map.channel_start == 100

def test_qwen3_detected(self):
    """Qwen3 tokenizer auto-detected via <think>/</think> tokens."""
    detected = OutputRouter.from_tokenizer(QWEN3_TOKENIZER)
    assert detected is not None
    # Detection must wire the think-tag ids straight from the vocab.
    tmap = detected.map
    assert (tmap.think_start, tmap.think_end) == (151667, 151668)

def test_unknown_tokenizer(self):
"""Non-Gemma tokenizer returns None."""
plain_vocab = {"hello": 1, "world": 2}
Expand Down Expand Up @@ -304,3 +345,94 @@ def test_multiple_requests(self, router):
router.reset()
e2 = router.feed(9259) # "Hello"
assert e2.channel == Channel.CONTENT


# === Qwen3 Think-Tag Tests ===


class TestQwen3ThinkRouting:
    """Test Qwen3 <think>/</think> tag routing.

    Token ids used throughout (see QWEN3_VOCAB):
    151667=<think>, 151668=</think>, 151643=<|endoftext|>, 151645=<|im_end|>,
    5733="Let", 2734="me", 28541="analyze", 785="The", 10234="answer",
    374="is", 2983="42", 13=".".
    """

    def test_think_start_enters_thinking(self, qwen3_router):
        """<think> token enters THINKING state and is suppressed."""
        assert qwen3_router.feed(151667) is None  # <think> suppressed
        assert qwen3_router.state == RouterState.THINKING

    def test_think_end_enters_content(self, qwen3_router):
        """</think> token switches to CONTENT state and is suppressed."""
        qwen3_router.feed(151667)  # <think>
        assert qwen3_router.feed(151668) is None  # </think> suppressed
        assert qwen3_router.state == RouterState.CONTENT

    def test_thinking_tokens_routed_to_reasoning(self, qwen3_router):
        """Tokens between <think> and </think> go to REASONING channel."""
        qwen3_router.feed(151667)  # <think>
        # "Let" "me" "analyze"
        events = [qwen3_router.feed(tid) for tid in (5733, 2734, 28541)]
        assert all(e.channel == Channel.REASONING for e in events)

    def test_content_after_think_end(self, qwen3_router):
        """Tokens after </think> go to CONTENT channel."""
        qwen3_router.feed(151667)   # <think>
        qwen3_router.feed(5733)     # "Let" (reasoning)
        qwen3_router.feed(151668)   # </think>

        event = qwen3_router.feed(785)  # "The"
        assert event.channel == Channel.CONTENT

    def test_full_think_content_sequence(self, qwen3_router):
        """Full <think>reasoning</think>content sequence."""
        tokens = [
            151667,                      # <think>
            5733, 2734,                  # "Let" "me" (reasoning)
            151668,                      # </think>
            785, 10234, 374, 2983, 13,   # "The" "answer" "is" "42" "."
        ]
        result = qwen3_router.feed_sequence(tokens)
        # Fixed: the previous assertion compared against "Letmeanalyze" even
        # though token 28541 ("analyze") is never fed here, and the trailing
        # `or "Let" in ...` made the equality check a no-op. Assert on what
        # is actually fed instead.
        assert result["reasoning"] is not None
        assert "Let" in result["reasoning"]
        assert "me" in result["reasoning"]
        assert "analyze" not in result["reasoning"]
        assert result["content"] is not None
        assert "42" in result["content"]
        assert result["tool_calls"] is None

    def test_implicit_think_only_end_tag(self, qwen3_router):
        """Implicit thinking: no <think>, only </think> in output.

        When <think> was injected in the prompt, the model output starts
        in INIT state (content). Tokens before </think> are content.
        After </think>, tokens are also content. This matches the
        expected behavior: without <think>, router stays in INIT/CONTENT.
        """
        # No <think> token — router starts in INIT
        e1 = qwen3_router.feed(5733)  # "Let" — INIT defaults to CONTENT
        assert e1.channel == Channel.CONTENT

        qwen3_router.feed(151668)  # </think> — switches to CONTENT (already content)

        e2 = qwen3_router.feed(785)  # "The"
        assert e2.channel == Channel.CONTENT

    def test_no_tags_pure_content(self, qwen3_router):
        """No think tags at all: everything is content, no reasoning."""
        tokens = [785, 10234, 374, 2983, 13]  # "The" "answer" "is" "42" "."
        result = qwen3_router.feed_sequence(tokens)
        assert result["content"] is not None
        assert result["reasoning"] is None

    def test_bos_eos_suppressed(self, qwen3_router):
        """Qwen3 control tokens are suppressed (no event emitted)."""
        assert qwen3_router.feed(151643) is None  # <|endoftext|> (bos/pad)
        assert qwen3_router.feed(151645) is None  # <|im_end|> (eos)

    def test_reset_after_thinking(self, qwen3_router):
        """Reset clears thinking state for Qwen3."""
        qwen3_router.feed(151667)  # <think>
        assert qwen3_router.state == RouterState.THINKING

        qwen3_router.reset()
        assert qwen3_router.state == RouterState.INIT

        # After reset, plain tokens route to CONTENT again.
        e = qwen3_router.feed(785)  # "The"
        assert e.channel == Channel.CONTENT
25 changes: 22 additions & 3 deletions vllm_mlx/output_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ def feed(self, token_id: int) -> RouterEvent | None:
m.tool_start, m.tool_end):
return None

# === Think tags (Qwen3/DeepSeek style): <think> / </think> ===
if token_id == m.think_start:
self.state = RouterState.THINKING
return None # suppress <think>
if token_id == m.think_end:
self.state = RouterState.CONTENT
return None # suppress </think>

# === Channel start: transition to AWAITING_CHANNEL_TYPE ===
if token_id == m.channel_start:
self.state = RouterState.AWAITING_CHANNEL_TYPE
Expand Down Expand Up @@ -253,8 +261,19 @@ def from_tokenizer(cls, tokenizer: Any) -> "OutputRouter | None":
return cls(token_map, tokenizer)

# Qwen/DeepSeek detection: look for <think> and </think>
# TODO: implement when migrating existing parsers
# if "<think>" in vocab and "</think>" in vocab:
# ...
if "<think>" in vocab and "</think>" in vocab:
token_map = TokenMap(
think_start=vocab["<think>"],
think_end=vocab["</think>"],
bos=vocab.get("<|endoftext|>") or vocab.get("<bos>"),
eos=vocab.get("<|im_end|>") or vocab.get("<|endoftext|>") or vocab.get("<eos>"),
pad=vocab.get("<|endoftext|>") or vocab.get("<pad>"),
)
logger.info(
"[OutputRouter] Qwen3/think-tag format detected: "
"think_start=%d, think_end=%d",
token_map.think_start, token_map.think_end,
)
return cls(token_map, tokenizer)

return None # unsupported model format