Skip to content
129 changes: 129 additions & 0 deletions tests/test_mllm_mtp_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for MLLM + MTP per-request routing."""


def test_has_media_content_text_only():
    """A conversation with only plain-string text content has no media."""
    from vllm_mlx.engine.simple import _has_media_content

    history = [{"role": "user", "content": "Hello"}]
    assert _has_media_content(history) is False


def test_has_media_content_with_image():
    """An image_url part inside a content list is detected as media."""
    from vllm_mlx.engine.simple import _has_media_content

    text_part = {"type": "text", "text": "What's this?"}
    image_part = {
        "type": "image_url",
        "image_url": {"url": "data:image/png;base64,..."},
    }
    history = [{"role": "user", "content": [text_part, image_part]}]
    assert _has_media_content(history) is True


def test_has_media_content_with_video():
    """A video_url part inside a content list is detected as media."""
    from vllm_mlx.engine.simple import _has_media_content

    video_part = {"type": "video_url", "video_url": {"url": "file:///tmp/v.mp4"}}
    history = [{"role": "user", "content": [video_part]}]
    assert _has_media_content(history) is True


def test_has_media_content_empty():
    """An empty message list trivially contains no media."""
    from vllm_mlx.engine.simple import _has_media_content

    no_messages = []
    assert _has_media_content(no_messages) is False


def test_has_media_content_string_content():
    """String content (not list) should return False."""
    from vllm_mlx.engine.simple import _has_media_content

    history = [{"role": "user", "content": "Just text"}]
    assert _has_media_content(history) is False


def test_has_media_content_audio():
    """An audio_url part inside a content list is detected as media."""
    from vllm_mlx.engine.simple import _has_media_content

    audio_part = {
        "type": "audio_url",
        "audio_url": {"url": "data:audio/wav;base64,..."},
    }
    history = [{"role": "user", "content": [audio_part]}]
    assert _has_media_content(history) is True


def test_has_media_content_multi_turn():
    """Media in earlier turns should still be detected."""
    from vllm_mlx.engine.simple import _has_media_content

    # Only the first turn carries an image; the follow-ups are text-only.
    first_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Look at this"},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,..."},
            },
        ],
    }
    conversation = [
        first_turn,
        {"role": "assistant", "content": "I see an image."},
        {"role": "user", "content": "Tell me more about it."},
    ]
    assert _has_media_content(conversation) is True


def test_has_media_content_text_list():
    """List content with only text parts should return False."""
    from vllm_mlx.engine.simple import _has_media_content

    text_parts = [
        {"type": "text", "text": "Hello"},
        {"type": "text", "text": "World"},
    ]
    history = [{"role": "user", "content": text_parts}]
    assert _has_media_content(history) is False


# --- MLXMultimodalLM extraction method tests ---

from unittest.mock import MagicMock


def test_get_language_model():
    """get_language_model should unwrap the nested model.language_model."""
    from vllm_mlx.models.mllm import MLXMultimodalLM

    language_model = MagicMock()
    fake_mllm = MagicMock(spec=MLXMultimodalLM)
    fake_mllm.model = MagicMock()
    fake_mllm.model.language_model = language_model
    # Call unbound so the mock stands in for self.
    assert MLXMultimodalLM.get_language_model(fake_mllm) is language_model


def test_get_tokenizer():
    """get_tokenizer should unwrap the nested processor.tokenizer."""
    from vllm_mlx.models.mllm import MLXMultimodalLM

    tokenizer = MagicMock()
    fake_mllm = MagicMock(spec=MLXMultimodalLM)
    fake_mllm.processor = MagicMock()
    fake_mllm.processor.tokenizer = tokenizer
    # Call unbound so the mock stands in for self.
    assert MLXMultimodalLM.get_tokenizer(fake_mllm) is tokenizer
140 changes: 140 additions & 0 deletions tests/test_text_model_from_vlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for building mlx_lm TextModel from mlx_vlm-loaded weights."""

import json
from pathlib import Path

import pytest

from vllm_mlx.text_model_from_vlm import build_text_model

# VLM+MTP model (created by merging mlx-community VLM + our MTP weights)
VLM_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-VLM-MTP-8bit"

# Text-only MTP model (no vision tower — can't test VLM loading)
TEXT_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-8bit"


def test_build_text_model_no_config():
    """Returns None when model path has no config.json.

    Consistency fix: every other call site in this module passes a
    ``pathlib.Path`` (VLM_MTP_MODEL / TEXT_MTP_MODEL), so pass one here
    too rather than a raw string.
    """
    result = build_text_model(None, Path("/nonexistent/path"))
    assert result is None


def test_build_text_model_none_vlm():
    """Returns None when vlm_model is None.

    NOTE(review): there is no skipif guard, so TEXT_MTP_MODEL may not exist
    on this machine — presumably build_text_model checks vlm_model before
    reading the path; confirm against its implementation.
    """
    assert build_text_model(None, TEXT_MTP_MODEL) is None


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_build_text_model_moe():
    """build_text_model creates a TextModel with shared weights and MTP (MoE)."""
    import runtime_patches

    runtime_patches.apply()

    from mlx_vlm import load as vlm_load

    vlm_model, _processor = vlm_load(str(VLM_MTP_MODEL))
    text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

    assert text_model is not None, "build_text_model returned None"

    # The MTP head must exist (config has mtp_num_hidden_layers=1).
    assert hasattr(text_model, "mtp"), "TextModel missing .mtp attribute"
    assert text_model.mtp is not None, "TextModel.mtp is None"
    for method_name in ("mtp_forward", "make_mtp_cache"):
        assert hasattr(
            text_model, method_name
        ), f"TextModel missing {method_name} method"

    # The first MTP layer should expose an MoE mlp.
    first_mtp_layer = text_model.mtp.layers[0]
    assert hasattr(first_mtp_layer, "mlp"), "MTP layer missing mlp"


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_text_model_mtp_forward():
    """TextModel.mtp_forward returns logits of correct vocab_size shape."""
    import mlx.core as mx
    import runtime_patches

    runtime_patches.apply()

    from mlx_vlm import load as vlm_load

    vlm_model = vlm_load(str(VLM_MTP_MODEL))[0]
    text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

    # The config may nest text settings under "text_config".
    raw_config = json.loads((VLM_MTP_MODEL / "config.json").read_text())
    text_config = raw_config.get("text_config", raw_config)

    mtp_cache = text_model.make_mtp_cache()
    assert len(mtp_cache) > 0

    # Single dummy token + zero hidden state — only the output shape matters.
    hidden_state = mx.zeros((1, 1, text_config["hidden_size"]))
    token_ids = mx.array([[0]])
    logits = text_model.mtp_forward(hidden_state, token_ids, mtp_cache)

    assert logits.shape[-1] == text_config["vocab_size"], (
        f"Expected vocab_size={text_config['vocab_size']}, got {logits.shape[-1]}"
    )


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_text_model_return_hidden():
    """TextModel supports return_hidden=True (required by mtp_generate_step)."""
    import mlx.core as mx
    import runtime_patches

    runtime_patches.apply()

    from mlx_vlm import load as vlm_load

    vlm_model = vlm_load(str(VLM_MTP_MODEL))[0]
    text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

    # The config may nest text settings under "text_config".
    raw_config = json.loads((VLM_MTP_MODEL / "config.json").read_text())
    text_config = raw_config.get("text_config", raw_config)

    kv_cache = text_model.make_cache()
    dummy_tokens = mx.array([[1, 2, 3]])

    # With return_hidden=True the model returns a (logits, hidden) pair.
    result = text_model(dummy_tokens, cache=kv_cache, return_hidden=True)
    assert isinstance(result, tuple), f"Expected tuple, got {type(result)}"

    logits, hidden = result
    assert logits.shape[-1] == text_config["vocab_size"]
    assert hidden.shape[-1] == text_config["hidden_size"]


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_weight_sharing():
    """Backbone weights are shared (zero-copy) between vlm and TextModel."""
    import mlx.core as mx
    import runtime_patches

    runtime_patches.apply()

    from mlx_vlm import load as vlm_load

    vlm_model, _ = vlm_load(str(VLM_MTP_MODEL))
    text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

    # Layer 0 may be linear_attn (GatedDeltaNet) on MoE models, so locate the
    # first layer carrying self_attn (full attention layers sit at indices
    # 11, 15, 19, 23, 27).
    vlm_layers = vlm_model.language_model.model.layers
    attn_idx = next(
        (i for i, layer in enumerate(vlm_layers) if hasattr(layer, "self_attn")),
        None,
    )
    if attn_idx is None:
        pytest.fail("No layer with self_attn found")

    vlm_weight = vlm_layers[attn_idx].self_attn.q_proj.weight
    tm_weight = text_model.model.layers[attn_idx].self_attn.q_proj.weight
    assert mx.array_equal(vlm_weight, tm_weight), (
        f"Weights at layer {attn_idx} should be identical"
    )
4 changes: 4 additions & 0 deletions vllm_mlx/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ class ChatCompletionRequest(BaseModel):
video_max_frames: int | None = None
# Request timeout in seconds (None = use server default)
timeout: float | None = None
# SpecPrefill: per-request enable/disable (None = server decides)
specprefill: bool | None = None
# SpecPrefill: per-request keep percentage (0.0-1.0, None = use server default)
specprefill_keep_pct: float | None = None


class AssistantMessage(BaseModel):
Expand Down
54 changes: 54 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,16 @@ def serve_command(args):
print(f"Prefix cache: max_entries={args.prefix_cache_size}")
else:
print("Mode: Simple (maximum throughput)")
if args.enable_mtp:
print("MTP: enabled (native speculative decoding)")
if args.enable_mtp and getattr(args, "mllm", False):
print("MTP + MLLM: per-request routing (text-only → MTP, media → MLLM)")
if args.specprefill and args.specprefill_draft_model:
print(
f"SpecPrefill: enabled (draft={args.specprefill_draft_model}, "
f"threshold={args.specprefill_threshold}, "
f"keep={args.specprefill_keep_pct*100:.0f}%)"
)

# Load model with unified server
load_model(
Expand All @@ -187,6 +197,12 @@ def serve_command(args):
stream_interval=args.stream_interval if args.continuous_batching else 1,
max_tokens=args.max_tokens,
force_mllm=args.mllm,
mtp=args.enable_mtp,
prefill_step_size=args.prefill_step_size,
specprefill_enabled=args.specprefill,
specprefill_threshold=args.specprefill_threshold,
specprefill_keep_pct=args.specprefill_keep_pct,
specprefill_draft_model=args.specprefill_draft_model,
)

# Start server
Expand Down Expand Up @@ -728,6 +744,44 @@ def main():
help="Skip MTP acceptance check for maximum speed. "
"~5-10%% wrong tokens. Best for chat, not for code.",
)
# Prefill step size
serve_parser.add_argument(
"--prefill-step-size",
type=int,
default=2048,
help="Chunk size for prompt prefill processing. Larger values use more memory "
"but can improve prefill throughput. (default: 2048)",
)
# SpecPrefill (attention-based sparse prefill using draft model)
serve_parser.add_argument(
"--specprefill",
action="store_true",
default=False,
help="Enable SpecPrefill: use a small draft model to score token importance, "
"then sparse-prefill only the important tokens on the target model. "
"Reduces TTFT on long prompts. Requires --specprefill-draft-model.",
)
serve_parser.add_argument(
"--specprefill-threshold",
type=int,
default=8192,
help="Minimum suffix tokens to trigger SpecPrefill (default: 8192). "
"Shorter prompts use full prefill (scoring overhead > savings).",
)
serve_parser.add_argument(
"--specprefill-keep-pct",
type=float,
default=0.3,
help="Fraction of tokens to keep during sparse prefill (default: 0.3). "
"Lower = faster prefill but more quality loss.",
)
serve_parser.add_argument(
"--specprefill-draft-model",
type=str,
default=None,
help="Path to small draft model for SpecPrefill importance scoring. "
"Must share the same tokenizer as the target model.",
)
# MCP options
serve_parser.add_argument(
"--mcp-config",
Expand Down
Loading
Loading