Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions tests/test_mllm_mtp_routing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for MLLM + MTP per-request routing."""


def test_has_media_content_text_only():
from vllm_mlx.engine.simple import _has_media_content

assert _has_media_content([{"role": "user", "content": "Hello"}]) is False


def test_has_media_content_with_image():
from vllm_mlx.engine.simple import _has_media_content

messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "What's this?"},
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,..."},
},
],
}
]
assert _has_media_content(messages) is True


def test_has_media_content_with_video():
from vllm_mlx.engine.simple import _has_media_content

messages = [
{
"role": "user",
"content": [
{"type": "video_url", "video_url": {"url": "file:///tmp/v.mp4"}}
],
}
]
assert _has_media_content(messages) is True


def test_has_media_content_empty():
from vllm_mlx.engine.simple import _has_media_content

assert _has_media_content([]) is False


def test_has_media_content_string_content():
"""String content (not list) should return False."""
from vllm_mlx.engine.simple import _has_media_content

assert _has_media_content([{"role": "user", "content": "Just text"}]) is False


def test_has_media_content_audio():
from vllm_mlx.engine.simple import _has_media_content

messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}}
],
}
]
assert _has_media_content(messages) is True


def test_has_media_content_multi_turn():
"""Media in earlier turns should still be detected."""
from vllm_mlx.engine.simple import _has_media_content

messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Look at this"},
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,..."},
},
],
},
{"role": "assistant", "content": "I see an image."},
{"role": "user", "content": "Tell me more about it."},
]
assert _has_media_content(messages) is True


def test_has_media_content_text_list():
"""List content with only text parts should return False."""
from vllm_mlx.engine.simple import _has_media_content

messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
],
}
]
assert _has_media_content(messages) is False


# --- MLXMultimodalLM extraction method tests ---

from unittest.mock import MagicMock


def test_get_language_model():
from vllm_mlx.models.mllm import MLXMultimodalLM

mllm = MagicMock(spec=MLXMultimodalLM)
inner_lm = MagicMock()
mllm.model = MagicMock()
mllm.model.language_model = inner_lm
assert MLXMultimodalLM.get_language_model(mllm) is inner_lm


def test_get_tokenizer():
from vllm_mlx.models.mllm import MLXMultimodalLM

mllm = MagicMock(spec=MLXMultimodalLM)
inner_tok = MagicMock()
mllm.processor = MagicMock()
mllm.processor.tokenizer = inner_tok
assert MLXMultimodalLM.get_tokenizer(mllm) is inner_tok
140 changes: 140 additions & 0 deletions tests/test_text_model_from_vlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for building mlx_lm TextModel from mlx_vlm-loaded weights."""

import json
from pathlib import Path

import pytest

from vllm_mlx.text_model_from_vlm import build_text_model

# VLM+MTP model (created by merging mlx-community VLM + our MTP weights)
VLM_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-VLM-MTP-8bit"

# Text-only MTP model (no vision tower — can't test VLM loading)
TEXT_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-8bit"


def test_build_text_model_no_config():
"""Returns None when model path has no config.json."""
result = build_text_model(None, "/nonexistent/path")
assert result is None


def test_build_text_model_none_vlm():
"""Returns None when vlm_model is None."""
result = build_text_model(None, TEXT_MTP_MODEL)
assert result is None


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_build_text_model_moe():
"""build_text_model creates a TextModel with shared weights and MTP (MoE)."""
import runtime_patches

runtime_patches.apply()

from mlx_vlm import load as vlm_load

vlm_model, processor = vlm_load(str(VLM_MTP_MODEL))
text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

assert text_model is not None, "build_text_model returned None"

# TextModel should have MTP (config has mtp_num_hidden_layers=1)
assert hasattr(text_model, "mtp"), "TextModel missing .mtp attribute"
assert text_model.mtp is not None, "TextModel.mtp is None"
assert hasattr(text_model, "mtp_forward"), "TextModel missing mtp_forward method"
assert hasattr(
text_model, "make_mtp_cache"
), "TextModel missing make_mtp_cache method"

# Verify MoE layer exists in MTP
mtp_layer = text_model.mtp.layers[0]
assert hasattr(mtp_layer, "mlp"), "MTP layer missing mlp"


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_text_model_mtp_forward():
"""TextModel.mtp_forward returns logits of correct vocab_size shape."""
import mlx.core as mx
import runtime_patches

runtime_patches.apply()

from mlx_vlm import load as vlm_load

vlm_model, _ = vlm_load(str(VLM_MTP_MODEL))
text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

config = json.loads((VLM_MTP_MODEL / "config.json").read_text())
text_config = config.get("text_config", config)

mtp_cache = text_model.make_mtp_cache()
assert len(mtp_cache) > 0

hidden = mx.zeros((1, 1, text_config["hidden_size"]))
next_ids = mx.array([[0]])
logits = text_model.mtp_forward(hidden, next_ids, mtp_cache)

assert (
logits.shape[-1] == text_config["vocab_size"]
), f"Expected vocab_size={text_config['vocab_size']}, got {logits.shape[-1]}"


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_text_model_return_hidden():
"""TextModel supports return_hidden=True (required by mtp_generate_step)."""
import mlx.core as mx
import runtime_patches

runtime_patches.apply()

from mlx_vlm import load as vlm_load

vlm_model, _ = vlm_load(str(VLM_MTP_MODEL))
text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

config = json.loads((VLM_MTP_MODEL / "config.json").read_text())
text_config = config.get("text_config", config)

cache = text_model.make_cache()
tokens = mx.array([[1, 2, 3]]) # Dummy token IDs

# return_hidden=True should return (logits, hidden_states)
result = text_model(tokens, cache=cache, return_hidden=True)

# Should be a tuple of (logits, hidden)
assert isinstance(result, tuple), f"Expected tuple, got {type(result)}"
logits, hidden = result
assert logits.shape[-1] == text_config["vocab_size"]
assert hidden.shape[-1] == text_config["hidden_size"]


@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk")
def test_weight_sharing():
"""Backbone weights are shared (zero-copy) between vlm and TextModel."""
import mlx.core as mx
import runtime_patches

runtime_patches.apply()

from mlx_vlm import load as vlm_load

vlm_model, _ = vlm_load(str(VLM_MTP_MODEL))
text_model = build_text_model(vlm_model, VLM_MTP_MODEL)

# Compare a backbone weight reference.
# Layer 0 may be linear_attn (GatedDeltaNet) on MoE models, so find a layer
# with self_attn (full attention layers are at indices 11, 15, 19, 23, 27).
for i in range(len(vlm_model.language_model.model.layers)):
layer = vlm_model.language_model.model.layers[i]
if hasattr(layer, "self_attn"):
vlm_weight = layer.self_attn.q_proj.weight
tm_weight = text_model.model.layers[i].self_attn.q_proj.weight
assert mx.array_equal(
vlm_weight, tm_weight
), f"Weights at layer {i} should be identical"
break
else:
pytest.fail("No layer with self_attn found")
5 changes: 5 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def serve_command(args):
print(f"Prefix cache: max_entries={args.prefix_cache_size}")
else:
print("Mode: Simple (maximum throughput)")
if args.enable_mtp:
print("MTP: enabled (native speculative decoding)")
if args.enable_mtp and getattr(args, "mllm", False):
print("MTP + MLLM: per-request routing (text-only → MTP, media → MLLM)")

# Load model with unified server
load_model(
Expand All @@ -187,6 +191,7 @@ def serve_command(args):
stream_interval=args.stream_interval if args.continuous_batching else 1,
max_tokens=args.max_tokens,
force_mllm=args.mllm,
mtp=args.enable_mtp,
)

# Start server
Expand Down
Loading
Loading