diff --git a/tests/test_output_router.py b/tests/test_output_router.py
index da2b0c3..10f22f9 100644
--- a/tests/test_output_router.py
+++ b/tests/test_output_router.py
@@ -68,6 +68,33 @@ def get_vocab(self) -> dict[str, int]:
 TOKENIZER = FakeTokenizer(VOCAB)
 
 
+# === Qwen3 Token IDs (representative) ===
+QWEN3_VOCAB = {
+    "<|endoftext|>": 151643,
+    "<|im_end|>": 151645,
+    "<think>": 151667,
+    "</think>": 151668,
+    "Let": 5733,
+    "me": 2734,
+    "analyze": 28541,
+    "The": 785,
+    "answer": 10234,
+    "is": 374,
+    "42": 2983,
+    ".": 13,
+}
+
+QWEN3_MAP = TokenMap(
+    think_start=151667,  # <think>
+    think_end=151668,  # </think>
+    bos=151643,  # <|endoftext|>
+    eos=151645,  # <|im_end|>
+    pad=151643,  # <|endoftext|>
+)
+
+QWEN3_TOKENIZER = FakeTokenizer(QWEN3_VOCAB)
+
+
 @pytest.fixture
 def router():
     r = OutputRouter(GEMMA4_MAP, TOKENIZER)
@@ -75,6 +102,13 @@ def router():
     return r
 
 
+@pytest.fixture
+def qwen3_router():
+    r = OutputRouter(QWEN3_MAP, QWEN3_TOKENIZER)
+    r.reset()
+    return r
+
+
 class TestBasicRouting:
     """Test fundamental token routing."""
 
@@ -272,6 +306,13 @@ def test_gemma4_detected(self):
         assert router is not None
         assert router.map.channel_start == 100
 
+    def test_qwen3_detected(self):
+        """Qwen3 tokenizer auto-detected via <think>/</think> tokens."""
+        router = OutputRouter.from_tokenizer(QWEN3_TOKENIZER)
+        assert router is not None
+        assert router.map.think_start == 151667
+        assert router.map.think_end == 151668
+
     def test_unknown_tokenizer(self):
         """Non-Gemma tokenizer returns None."""
         plain_vocab = {"hello": 1, "world": 2}
@@ -304,3 +345,94 @@ def test_multiple_requests(self, router):
         router.reset()
         e2 = router.feed(9259)  # "Hello"
         assert e2.channel == Channel.CONTENT
+
+
+# === Qwen3 Think-Tag Tests ===
+
+
+class TestQwen3ThinkRouting:
+    """Test Qwen3 <think>/</think> tag routing."""
+
+    def test_think_start_enters_thinking(self, qwen3_router):
+        """<think> token enters THINKING state."""
+        assert qwen3_router.feed(151667) is None  # suppressed
+        assert qwen3_router.state == RouterState.THINKING
+
+    def test_think_end_enters_content(self, qwen3_router):
+        """</think> token switches to CONTENT state."""
+        qwen3_router.feed(151667)  # <think>
+        assert qwen3_router.feed(151668) is None  # suppressed
+        assert qwen3_router.state == RouterState.CONTENT
+
+    def test_thinking_tokens_routed_to_reasoning(self, qwen3_router):
+        """Tokens between <think> and </think> go to REASONING channel."""
+        qwen3_router.feed(151667)  # <think>
+        e1 = qwen3_router.feed(5733)  # "Let"
+        e2 = qwen3_router.feed(2734)  # "me"
+        e3 = qwen3_router.feed(28541)  # "analyze"
+        assert e1.channel == Channel.REASONING
+        assert e2.channel == Channel.REASONING
+        assert e3.channel == Channel.REASONING
+
+    def test_content_after_think_end(self, qwen3_router):
+        """Tokens after </think> go to CONTENT channel."""
+        qwen3_router.feed(151667)  # <think>
+        qwen3_router.feed(5733)  # "Let" (reasoning)
+        qwen3_router.feed(151668)  # </think>
+
+        e = qwen3_router.feed(785)  # "The"
+        assert e.channel == Channel.CONTENT
+
+    def test_full_think_content_sequence(self, qwen3_router):
+        """Full <think>reasoning</think>content sequence."""
+        tokens = [
+            151667,  # <think>
+            5733, 2734,  # "Let" "me" (reasoning)
+            151668,  # </think>
+            785, 10234, 374, 2983, 13,  # "The" "answer" "is" "42" "."
+        ]
+        result = qwen3_router.feed_sequence(tokens)
+        assert result["reasoning"] == "Letmeanalyze" or "Let" in result["reasoning"]
+        assert result["content"] is not None
+        assert "42" in result["content"]
+        assert result["tool_calls"] is None
+
+    def test_implicit_think_only_end_tag(self, qwen3_router):
+        """Implicit thinking: no <think>, only </think> in output.
+
+        When <think> was injected in the prompt, the model output starts
+        in INIT state (content). Tokens before </think> are content.
+        After </think>, tokens are also content. This matches the
+        expected behavior: without <think>, router stays in INIT/CONTENT.
+        """
+        # No <think> token — router starts in INIT
+        e1 = qwen3_router.feed(5733)  # "Let" — INIT defaults to CONTENT
+        assert e1.channel == Channel.CONTENT
+
+        qwen3_router.feed(151668)  # </think> — switches to CONTENT (already content)
+
+        e2 = qwen3_router.feed(785)  # "The"
+        assert e2.channel == Channel.CONTENT
+
+    def test_no_tags_pure_content(self, qwen3_router):
+        """No think tags at all: everything is content."""
+        tokens = [785, 10234, 374, 2983, 13]  # "The" "answer" "is" "42" "."
+        result = qwen3_router.feed_sequence(tokens)
+        assert result["content"] is not None
+        assert result["reasoning"] is None
+
+    def test_bos_eos_suppressed(self, qwen3_router):
+        """Qwen3 control tokens are suppressed."""
+        assert qwen3_router.feed(151643) is None  # <|endoftext|> (bos/pad)
+        assert qwen3_router.feed(151645) is None  # <|im_end|> (eos)
+
+    def test_reset_after_thinking(self, qwen3_router):
+        """Reset clears thinking state for Qwen3."""
+        qwen3_router.feed(151667)  # <think>
+        assert qwen3_router.state == RouterState.THINKING
+
+        qwen3_router.reset()
+        assert qwen3_router.state == RouterState.INIT
+
+        e = qwen3_router.feed(785)  # "The"
+        assert e.channel == Channel.CONTENT
diff --git a/vllm_mlx/output_router.py b/vllm_mlx/output_router.py
index 9accb1e..bda04af 100644
--- a/vllm_mlx/output_router.py
+++ b/vllm_mlx/output_router.py
@@ -133,6 +133,14 @@ def feed(self, token_id: int) -> RouterEvent | None:
                         m.tool_start, m.tool_end):
             return None
 
+        # === Think tags (Qwen3/DeepSeek style): <think> / </think> ===
+        if token_id == m.think_start:
+            self.state = RouterState.THINKING
+            return None  # suppress
+        if token_id == m.think_end:
+            self.state = RouterState.CONTENT
+            return None  # suppress
+
         # === Channel start: transition to AWAITING_CHANNEL_TYPE ===
         if token_id == m.channel_start:
             self.state = RouterState.AWAITING_CHANNEL_TYPE
@@ -253,8 +261,19 @@ def from_tokenizer(cls, tokenizer: Any) -> "OutputRouter | None":
             return cls(token_map, tokenizer)
 
         # Qwen/DeepSeek detection: look for <think> and </think>
-        # TODO: implement when migrating existing parsers
-        # if "<think>" in vocab and "</think>" in vocab:
-        #     ...
+        if "<think>" in vocab and "</think>" in vocab:
+            token_map = TokenMap(
+                think_start=vocab["<think>"],
+                think_end=vocab["</think>"],
+                bos=vocab.get("<|endoftext|>") or vocab.get("<bos>"),
+                eos=vocab.get("<|im_end|>") or vocab.get("<|endoftext|>") or vocab.get("<eos>"),
+                pad=vocab.get("<|endoftext|>") or vocab.get("<pad>"),
+            )
+            logger.info(
+                "[OutputRouter] Qwen3/think-tag format detected: "
+                "think_start=%d, think_end=%d",
+                token_map.think_start, token_map.think_end,
+            )
+            return cls(token_map, tokenizer)
 
         return None  # unsupported model format