Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions docs/reference/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun
| Mistral / Devstral | 7B, Mixtral 8x7B | 4-bit, 8-bit |
| Qwen2/Qwen3 | 0.5B to 72B | Various |
| DeepSeek V3, R1 | 7B, 33B, 67B | 4-bit |
| Gemma 2, 3 | 2B, 9B, 27B | 4-bit |
| Gemma 2, 3, 4 | 2B, 9B, 27B | 4-bit |
| GLM-4.7 | Flash, Base | 4-bit, 8-bit |
| Kimi K2 | Various | 4-bit |
| Phi-3 | 3.8B, 14B | 4-bit |
Expand All @@ -35,6 +35,7 @@ Browse thousands of pre-optimized models at: **https://huggingface.co/mlx-commun
| **Qwen-VL** | `Qwen3-VL-4B-Instruct-3bit`, `Qwen3-VL-8B-Instruct-4bit`, `Qwen2-VL-2B/7B-Instruct-4bit` |
| **LLaVA** | `llava-1.5-7b-4bit`, `llava-v1.6-mistral-7b-4bit`, `llava-llama-3-8b-v1_1-4bit` |
| **Idefics** | `Idefics3-8B-Llama3-4bit`, `idefics2-8b-4bit` |
| **Gemma 4** | `gemma-4-e2b-it-mxfp4` (vision + audio) |
| **PaliGemma** | `paligemma2-3b-mix-224-4bit`, `paligemma-3b-mix-224-8bit` |
| **Pixtral** | `pixtral-12b-4bit`, `pixtral-12b-8bit` |
| **Molmo** | `Molmo-7B-D-0924-4bit`, `Molmo-7B-D-0924-8bit` |
Expand Down Expand Up @@ -72,7 +73,7 @@ vllm-mlx auto-detects multimodal models by name patterns:
- Contains "VL", "Vision", "vision"
- Contains "llava", "idefics", "paligemma"
- Contains "pixtral", "molmo", "deepseek-vl"
- Contains "MedGemma", "Gemma-3" (vision variants)
- Contains "MedGemma", "Gemma-3", "Gemma-4" (multimodal variants)

## Using Models

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ classifiers = [
dependencies = [
"mlx>=0.29.0",
"mlx-lm>=0.31.0", # 0.31+ required for ArraysCache native batching (hybrid models)
"mlx-vlm>=0.1.0", # VLM support
"mlx-vlm>=0.4.3", # 0.4.3+ required for Gemma 4 support
"transformers>=5.0.0", # mlx-lm 0.30.5+ requires transformers 5.0 (rc3 bug fixed in stable)
"tokenizers>=0.19.0",
"huggingface-hub>=0.23.0",
Expand Down
47 changes: 47 additions & 0 deletions tests/test_mllm_continuous_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,53 @@ def test_finished_response(self):

assert resp.finish_reason == "stop"

def test_error_response_skips_decoding(self):
    """Error responses must not decode token=0 as content.

    A batch response carrying ``finish_reason="error"`` uses ``token=0`` as a
    placeholder; the scheduler must abort the request without routing that
    placeholder token through a detokenizer (which would emit spurious text).
    """
    # NOTE: PropertyMock was previously imported here but never used.
    from unittest.mock import MagicMock

    from vllm_mlx.mllm_batch_generator import MLLMBatchResponse
    from vllm_mlx.mllm_scheduler import MLLMScheduler
    from vllm_mlx.request import RequestStatus

    # Build a minimal scheduler with mocked internals. __new__ bypasses
    # __init__ so no real model/processor is loaded.
    scheduler = MLLMScheduler.__new__(MLLMScheduler)
    scheduler._detokenizer_pool = {}
    scheduler.uid_to_request_id = {0: "req-err"}
    scheduler.total_completion_tokens = 0
    scheduler.num_requests_processed = 0

    mock_tokenizer = MagicMock()
    mock_tokenizer.decode.return_value = ""
    mock_processor = MagicMock()
    mock_processor.tokenizer = mock_tokenizer
    scheduler.processor = mock_processor

    # Create a running request that the error response will abort.
    mock_request = MagicMock()
    mock_request.request_id = "req-err"
    mock_request.output_tokens = []
    mock_request.num_output_tokens = 0
    mock_request.num_prompt_tokens = 10
    mock_request.status = RequestStatus.RUNNING
    scheduler.running = {"req-err": mock_request}

    error_resp = MLLMBatchResponse(
        uid=0,
        request_id="req-err",
        token=0,  # placeholder token that must NOT be decoded
        logprobs=mx.array([0.0]),
        finish_reason="error",
    )

    outputs, finished = scheduler._process_batch_responses([error_resp])

    assert "req-err" in finished
    assert mock_request.status == RequestStatus.FINISHED_ABORTED
    # token=0 should not have been decoded through a detokenizer
    assert "req-err" not in scheduler._detokenizer_pool
    assert len(outputs) == 1
    assert outputs[0].new_text == ""


class TestMLLMBatch:
"""Tests for MLLMBatch class."""
Expand Down
266 changes: 266 additions & 0 deletions tests/test_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
- Parser registry (registration, lookup, listing)
- Qwen3 parser (non-streaming and streaming)
- DeepSeek-R1 parser (non-streaming and streaming)
- Gemma 4 parser (channel protocol, streaming, channel name stripping)
- Edge cases (no tags, partial tags, etc.)
"""

Expand All @@ -28,6 +29,7 @@ def test_list_parsers_includes_builtin(self):
parsers = list_parsers()
assert "qwen3" in parsers
assert "deepseek_r1" in parsers
assert "gemma4" in parsers

def test_get_parser_qwen3(self):
"""Should be able to get Qwen3 parser."""
Expand Down Expand Up @@ -920,3 +922,267 @@ def test_constrain_tokens_stripped(self, parser):
reasoning, content = parser.extract_reasoning(output)
assert "<|constrain|>" not in (content or "")
assert "<|channel|>" not in (content or "")


class TestGemma4Parser:
    """Tests for the Gemma 4 reasoning parser (channel-based protocol).

    Covers both the non-streaming ``extract_reasoning`` API and the
    incremental ``extract_reasoning_streaming`` API, including the two
    observed channel formats:

    - standard:    ``<|channel>thought ... <channel|>content``
    - alternative: ``<|channel>thought ... <|channel>response\\ncontent``
    """

    @pytest.fixture
    def parser(self):
        """Create a fresh Gemma 4 parser for each test."""
        return get_parser("gemma4")()

    # --- Non-streaming tests ---

    def test_extract_standard_format(self, parser):
        """Standard format: <|channel>thought...<channel|>response."""
        output = (
            "<|channel>thought\nLet me think step by step.\n<channel|>The answer is 42."
        )
        reasoning, content = parser.extract_reasoning(output)
        assert reasoning == "Let me think step by step."
        assert content == "The answer is 42."

    def test_extract_alternative_format(self, parser):
        """Alternative format: <|channel>thought...<|channel>response..."""
        output = "<|channel>thought\nAnalyzing the problem.\n<|channel>response\nThe result is 7."
        reasoning, content = parser.extract_reasoning(output)
        assert reasoning == "Analyzing the problem."
        assert content == "The result is 7."

    def test_extract_strips_thought_prefix(self, parser):
        """Channel name 'thought' should be stripped from reasoning."""
        output = "<|channel>thought\nActual reasoning here<channel|>Content"
        reasoning, content = parser.extract_reasoning(output)
        assert reasoning == "Actual reasoning here"
        assert "thought" not in reasoning

    def test_extract_no_tags_pure_content(self, parser):
        """No channel tags at all should return pure content."""
        output = "Just a regular response without thinking."
        reasoning, content = parser.extract_reasoning(output)
        assert reasoning is None
        assert content == output

    def test_extract_only_start_tag(self, parser):
        """Only start tag means incomplete reasoning (no content yet)."""
        output = "<|channel>thought\nStill thinking..."
        reasoning, content = parser.extract_reasoning(output)
        assert reasoning == "Still thinking..."
        assert content is None

    def test_extract_only_end_tag(self, parser):
        """Only end tag (think injected in prompt)."""
        output = "thought\nImplicit reasoning<channel|>The answer"
        reasoning, content = parser.extract_reasoning(output)
        assert reasoning == "Implicit reasoning"
        assert content == "The answer"

    def test_extract_empty_reasoning(self, parser):
        """Empty reasoning should return None."""
        output = "<|channel>thought\n<channel|>Only content here."
        reasoning, content = parser.extract_reasoning(output)
        assert reasoning is None
        assert content == "Only content here."

    def test_extract_multiline_reasoning(self, parser):
        """Should preserve multiline reasoning content."""
        output = (
            "<|channel>thought\n"
            "Step 1: Understand the question.\n"
            "Step 2: Analyze the data.\n"
            "Step 3: Form conclusion.\n"
            "<channel|>The conclusion is clear."
        )
        reasoning, content = parser.extract_reasoning(output)
        assert "Step 1" in reasoning
        assert "Step 2" in reasoning
        assert "Step 3" in reasoning
        assert content == "The conclusion is clear."

    def test_extract_unicode_reasoning(self, parser):
        """Should handle Unicode in reasoning."""
        output = "<|channel>thought\n日本語テスト 🤔\n<channel|>答えは42"
        reasoning, content = parser.extract_reasoning(output)
        assert "日本語テスト" in reasoning
        assert "🤔" in reasoning
        assert "42" in content

    def test_registry_includes_gemma4(self):
        """gemma4 should be in the parser registry."""
        assert "gemma4" in list_parsers()

    # --- Streaming tests ---

    def test_streaming_no_tags_plain_content(self, parser):
        """Streaming without any channel tags should return content."""
        parser.reset_state()
        result = parser.extract_reasoning_streaming("", "Hello", "Hello")
        assert result is not None
        assert result.content == "Hello"
        assert result.reasoning is None

    def test_streaming_standard_format(self, parser):
        """Test streaming through <|channel>thought...<channel|>content flow."""
        parser.reset_state()

        tokens = [
            "<|channel>",
            "thought",
            "\n",
            "Let me ",
            "think.",
            "<channel|>",
            "The ",
            "answer.",
        ]

        accumulated = ""
        reasoning_parts = []
        content_parts = []

        for token in tokens:
            prev = accumulated
            accumulated += token
            result = parser.extract_reasoning_streaming(prev, accumulated, token)
            if result:
                if result.reasoning:
                    reasoning_parts.append(result.reasoning)
                if result.content:
                    content_parts.append(result.content)

        full_reasoning = "".join(reasoning_parts)
        full_content = "".join(content_parts)

        # "thought\n" prefix should be stripped.
        # (Previously: `... or "thought" in "Let me think."` — that disjunct
        # is a constant False, so the assertion reduces to this form.)
        assert "thought" not in full_reasoning
        assert "Let me think." in full_reasoning
        assert "The answer." in full_content

    def test_streaming_alternative_format(self, parser):
        """Test streaming with <|channel>response transition."""
        parser.reset_state()

        tokens = [
            "<|channel>",
            "thought",
            "\n",
            "Analyzing.",
            "<|channel>response",
            "\n",
            "Result: ",
            "42",
        ]

        accumulated = ""
        reasoning_parts = []
        content_parts = []

        for token in tokens:
            prev = accumulated
            accumulated += token
            result = parser.extract_reasoning_streaming(prev, accumulated, token)
            if result:
                if result.reasoning:
                    reasoning_parts.append(result.reasoning)
                if result.content:
                    content_parts.append(result.content)

        full_content = "".join(content_parts)
        assert "Result: 42" in full_content

    def test_streaming_suppresses_channel_names(self, parser):
        """Channel names 'thought' and 'response' should not appear in output."""
        parser.reset_state()

        # Simulate realistic Gemma 4 output
        tokens = [
            "<|channel>",
            "thought",
            "\n",
            "Real ",
            "reasoning.",
            "<channel|>",
            "Real ",
            "content.",
        ]

        accumulated = ""
        all_output = []

        for token in tokens:
            prev = accumulated
            accumulated += token
            result = parser.extract_reasoning_streaming(prev, accumulated, token)
            if result:
                if result.reasoning:
                    all_output.append(("r", result.reasoning))
                if result.content:
                    all_output.append(("c", result.content))

        # Verify no raw "thought" token leaked as reasoning
        reasoning_text = "".join(t for tag, t in all_output if tag == "r")
        content_text = "".join(t for tag, t in all_output if tag == "c")

        assert "Real reasoning." in reasoning_text
        assert "Real content." in content_text
        # The channel names themselves must be suppressed, not streamed out
        # (previously this test only checked the positive content above).
        assert "thought" not in reasoning_text
        assert "response" not in content_text

    def test_streaming_token_by_token(self, parser):
        """Test character-by-character streaming (worst case)."""
        parser.reset_state()

        output = "<|channel>thought\nStep 1: Think\nStep 2: Analyze\n<channel|>Final answer: 42."

        accumulated = ""
        reasoning_parts = []
        content_parts = []

        for char in output:
            prev = accumulated
            accumulated += char
            result = parser.extract_reasoning_streaming(prev, accumulated, char)
            if result:
                if result.reasoning:
                    reasoning_parts.append(result.reasoning)
                if result.content:
                    content_parts.append(result.content)

        full_reasoning = "".join(reasoning_parts)
        full_content = "".join(content_parts)

        assert "Step 1: Think" in full_reasoning
        assert "Step 2: Analyze" in full_reasoning
        assert "Final answer: 42." in full_content

    def test_streaming_long_thinking_no_end_tag(self, parser):
        """When model generates long thinking without end tag, all goes to reasoning."""
        parser.reset_state()

        # Simulate model that hits max_tokens before <channel|>
        tokens = [
            "<|channel>",
            "thought",
            "\n",
            "This is a very long ",
            "reasoning process ",
            "that continues ",
            "without ending.",
        ]

        accumulated = ""
        reasoning_parts = []
        content_parts = []

        for token in tokens:
            prev = accumulated
            accumulated += token
            result = parser.extract_reasoning_streaming(prev, accumulated, token)
            if result:
                if result.reasoning:
                    reasoning_parts.append(result.reasoning)
                if result.content:
                    content_parts.append(result.content)

        full_reasoning = "".join(reasoning_parts)
        assert "very long reasoning process" in full_reasoning
        assert len(content_parts) == 0
2 changes: 2 additions & 0 deletions vllm_mlx/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ def flush(self) -> list[tuple[str, str]]:
"PaliGemma", # PaliGemma
"gemma-3",
"gemma3", # Gemma 3 (multimodal)
"gemma-4",
"gemma4", # Gemma 4 (multimodal: vision + audio)
"medgemma",
"MedGemma", # MedGemma (medical multimodal with SigLIP vision encoder)
"pixtral",
Expand Down
Loading