diff --git a/tests/test_mllm_mtp_routing.py b/tests/test_mllm_mtp_routing.py new file mode 100644 index 00000000..e2394cf6 --- /dev/null +++ b/tests/test_mllm_mtp_routing.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for MLLM + MTP per-request routing.""" + + +def test_has_media_content_text_only(): + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([{"role": "user", "content": "Hello"}]) is False + + +def test_has_media_content_with_image(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's this?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,..."}, + }, + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_with_video(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": "file:///tmp/v.mp4"}} + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_empty(): + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([]) is False + + +def test_has_media_content_string_content(): + """String content (not list) should return False.""" + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([{"role": "user", "content": "Just text"}]) is False + + +def test_has_media_content_audio(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}} + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_multi_turn(): + """Media in earlier turns should still be detected.""" + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Look at this"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,..."}, + }, + ], + }, + {"role": "assistant", "content": "I see an image."}, + {"role": "user", "content": "Tell me more about it."}, + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_text_list(): + """List content with only text parts should return False.""" + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": "World"}, + ], + } + ] + assert _has_media_content(messages) is False + + +# --- MLXMultimodalLM extraction method tests --- + +from unittest.mock import MagicMock + + +def test_get_language_model(): + from vllm_mlx.models.mllm import MLXMultimodalLM + + mllm = MagicMock(spec=MLXMultimodalLM) + inner_lm = MagicMock() + mllm.model = MagicMock() + mllm.model.language_model = inner_lm + assert MLXMultimodalLM.get_language_model(mllm) is inner_lm + + +def test_get_tokenizer(): + from vllm_mlx.models.mllm import MLXMultimodalLM + + mllm = MagicMock(spec=MLXMultimodalLM) + inner_tok = MagicMock() + mllm.processor = MagicMock() + mllm.processor.tokenizer = inner_tok + assert MLXMultimodalLM.get_tokenizer(mllm) is inner_tok diff --git a/tests/test_text_model_from_vlm.py b/tests/test_text_model_from_vlm.py new file mode 100644 index 00000000..037ff810 --- /dev/null +++ b/tests/test_text_model_from_vlm.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for building mlx_lm TextModel from mlx_vlm-loaded weights.""" + +import json +from pathlib import Path + +import pytest + +from vllm_mlx.text_model_from_vlm import build_text_model + +# VLM+MTP model (created by merging mlx-community VLM + our MTP weights) +VLM_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-VLM-MTP-8bit" + +# Text-only MTP model (no vision tower — can't test VLM loading) +TEXT_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-8bit" + + +def test_build_text_model_no_config(): + """Returns None when model path has no config.json.""" + result = build_text_model(None, "/nonexistent/path") + assert result is None + + +def test_build_text_model_none_vlm(): + """Returns None when vlm_model is None.""" + result = build_text_model(None, TEXT_MTP_MODEL) + assert result is None + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_build_text_model_moe(): + """build_text_model creates a TextModel with shared weights and MTP (MoE).""" + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, processor = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + assert text_model is not None, "build_text_model returned None" + + # TextModel should have MTP (config has mtp_num_hidden_layers=1) + assert hasattr(text_model, "mtp"), "TextModel missing .mtp attribute" + assert text_model.mtp is not None, "TextModel.mtp is None" + assert hasattr(text_model, "mtp_forward"), "TextModel missing mtp_forward method" + assert hasattr( + text_model, "make_mtp_cache" + ), "TextModel missing make_mtp_cache method" + + # Verify MoE layer exists in MTP + mtp_layer = text_model.mtp.layers[0] + assert hasattr(mtp_layer, "mlp"), "MTP layer missing mlp" + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_text_model_mtp_forward(): + """TextModel.mtp_forward returns logits of correct vocab_size shape.""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + config = json.loads((VLM_MTP_MODEL / "config.json").read_text()) + text_config = config.get("text_config", config) + + mtp_cache = text_model.make_mtp_cache() + assert len(mtp_cache) > 0 + + hidden = mx.zeros((1, 1, text_config["hidden_size"])) + next_ids = mx.array([[0]]) + logits = text_model.mtp_forward(hidden, next_ids, mtp_cache) + + assert ( + logits.shape[-1] == text_config["vocab_size"] + ), f"Expected vocab_size={text_config['vocab_size']}, got {logits.shape[-1]}" + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_text_model_return_hidden(): + """TextModel supports return_hidden=True (required by mtp_generate_step).""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + config = json.loads((VLM_MTP_MODEL / "config.json").read_text()) + text_config = config.get("text_config", config) + + cache = text_model.make_cache() + tokens = mx.array([[1, 2, 3]]) # Dummy token IDs + + # return_hidden=True should return (logits, hidden_states) + result = text_model(tokens, cache=cache, return_hidden=True) + + # Should be a tuple of (logits, hidden) + assert isinstance(result, tuple), f"Expected tuple, got {type(result)}" + logits, hidden = result + assert logits.shape[-1] == text_config["vocab_size"] + assert hidden.shape[-1] == text_config["hidden_size"] + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_weight_sharing(): + """Backbone weights are shared (zero-copy) between vlm and TextModel.""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + # Compare a backbone weight reference. + # Layer 0 may be linear_attn (GatedDeltaNet) on MoE models, so find a layer + # with self_attn (full attention layers are at indices 11, 15, 19, 23, 27). + for i in range(len(vlm_model.language_model.model.layers)): + layer = vlm_model.language_model.model.layers[i] + if hasattr(layer, "self_attn"): + vlm_weight = layer.self_attn.q_proj.weight + tm_weight = text_model.model.layers[i].self_attn.q_proj.weight + assert mx.array_equal( + vlm_weight, tm_weight + ), f"Weights at layer {i} should be identical" + break + else: + pytest.fail("No layer with self_attn found") diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index dcbee8ac..4051d579 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -178,6 +178,10 @@ def serve_command(args): print(f"Prefix cache: max_entries={args.prefix_cache_size}") else: print("Mode: Simple (maximum throughput)") + if args.enable_mtp: + print("MTP: enabled (native speculative decoding)") + if args.enable_mtp and getattr(args, "mllm", False): + print("MTP + MLLM: per-request routing (text-only → MTP, media → MLLM)") # Load model with unified server load_model( @@ -187,6 +191,7 @@ def serve_command(args): stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, force_mllm=args.mllm, + mtp=args.enable_mtp, ) # Start server diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 3a7e14e2..89bbd9c3 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -18,6 +18,29 @@ logger = logging.getLogger(__name__) +_MEDIA_TYPES = frozenset( + { + "image_url", + "video_url", + "audio_url", + "image", + "video", + "audio", + } +) + + +def _has_media_content(messages: list) -> bool: + """Check if any message contains media content (images, video, audio).""" + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") in _MEDIA_TYPES: + return True + return False + + class SimpleEngine(BaseEngine): """ Simple engine for direct model calls. @@ -32,6 +55,7 @@ def __init__( trust_remote_code: bool = True, enable_cache: bool = True, force_mllm: bool = False, + mtp: bool = False, ): """ Initialize the simple engine. @@ -41,18 +65,29 @@ def __init__( trust_remote_code: Whether to trust remote code enable_cache: Enable VLM cache for multimodal models force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable native MTP speculative decoding (model must have MTP head) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._enable_cache = enable_cache self._is_mllm = force_mllm or is_mllm_model(model_name) + self._mtp = mtp self._model = None self._loaded = False + # Per-request routing state (MLLM+MTP mode) + self._text_model = None + self._text_tokenizer = None + # Lock to serialize MLX operations (prevents Metal command buffer conflicts) self._generation_lock = asyncio.Lock() + # System prompt KV cache (reduces repeated prefill across requests) + self._system_kv_snapshot = None # List of (keys, values) per backbone layer + self._system_kv_hash = None # Hash of system prefix text + self._system_kv_token_count = 0 # Tokens in cached prefix + @property def model_name(self) -> str: """Get the model name.""" @@ -91,16 +126,62 @@ async def start(self) -> None: self._model = MLXLanguageModel( self._model_name, trust_remote_code=self._trust_remote_code, + mtp=self._mtp, ) self._model.load() self._loaded = True - logger.info(f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm})") + + # Build parallel mlx_lm TextModel for text-only MTP routing + if self._is_mllm and self._mtp: + try: + from ..text_model_from_vlm import build_text_model + + self._text_model = build_text_model(self._model.model, self._model_name) + + if ( + self._text_model is not None + and hasattr(self._text_model, "mtp") + and self._text_model.mtp is not None + ): + self._text_tokenizer = self._model.get_tokenizer() + + # Apply Qwen3.5 eos_token fix (matches MLXLanguageModel.load) + if "qwen3" in self._model_name.lower(): + self._text_tokenizer.eos_token = "<|im_end|>" + self._text_tokenizer.eos_token_id = ( + self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + ) + + logger.info( + "MLLM+MTP routing: text-only → mlx_lm TextModel (MTP=True), " + "media → mlx_vlm" + ) + else: + logger.warning( + "TextModel built but no MTP — text-only requests won't use MTP" + ) + self._text_model = None + + except Exception as e: + logger.error("MLLM+MTP routing setup failed: %s", e) + self._text_model = None + self._text_tokenizer = None + + mtp_info = f", MTP={self._mtp}" if self._mtp else "" + routing = ", routing=per-request" if self._text_model is not None else "" + logger.info( + f"SimpleEngine loaded: {self._model_name} " + f"(MLLM={self._is_mllm}{mtp_info}{routing})" + ) async def stop(self) -> None: """Stop the engine and cleanup resources.""" self._model = None self._loaded = False + self._system_kv_snapshot = None + self._system_kv_hash = None + self._system_kv_token_count = 0 logger.info("SimpleEngine stopped") async def generate( @@ -339,44 +420,67 @@ async def stream_chat( # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + # Per-request routing: text-only through mlx_lm with MTP + if ( + self._is_mllm + and self._text_model is not None + and not _has_media_content(messages) + ): + logger.info("Text-only request → LLM path (MTP=True)") + async for chunk in self._stream_generate_text( + messages, + max_tokens, + temperature, + top_p, + tools=template_tools, + **kwargs, + ): + yield chunk + return + # Build prompt using tokenizer if self._is_mllm: - # For MLLM, use stream_chat which yields tokens incrementally - accumulated_text = "" - token_count = 0 - - # Run stream_chat in thread pool since it's synchronous - def run_stream(): - return list( - self._model.stream_chat( - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, + if self._text_model is not None: + logger.info("Media request → MLLM path") + # For MLLM, use stream_chat which yields tokens incrementally. + # Must hold _generation_lock to prevent concurrent Metal access + # (e.g. OpenCode sends title + main request simultaneously). + async with self._generation_lock: + accumulated_text = "" + token_count = 0 + + # Run stream_chat in thread pool since it's synchronous + def run_stream(): + return list( + self._model.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, + ) ) - ) - chunks = await asyncio.to_thread(run_stream) + chunks = await asyncio.to_thread(run_stream) - for chunk in chunks: - token_count += 1 - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - accumulated_text += new_text + for chunk in chunks: + token_count += 1 + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + accumulated_text += new_text - finished = chunk.finish_reason is not None + finished = chunk.finish_reason is not None - yield GenerationOutput( - text=accumulated_text, - new_text=new_text, - prompt_tokens=getattr(chunk, "prompt_tokens", 0), - completion_tokens=token_count, - finished=finished, - finish_reason=chunk.finish_reason if finished else None, - ) + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=getattr(chunk, "prompt_tokens", 0), + completion_tokens=token_count, + finished=finished, + finish_reason=chunk.finish_reason if finished else None, + ) - if finished: - break + if finished: + break return # For LLM, apply chat template and stream @@ -415,6 +519,272 @@ def run_stream(): ): yield output + async def _stream_generate_text( + self, + messages: list[dict[str, Any]], + max_tokens: int, + temperature: float, + top_p: float, + tools: list | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Text-only generation via mlx_lm TextModel with MTP. + + Used when MLLM+MTP routing is active and the request has no media. + Runs the full generation in a single thread to maintain Metal safety. + + System prompt KV caching: on the first request, prefills system tokens + and snapshots backbone KV state. Subsequent requests with the same + system prompt restore the snapshot and only prefill the suffix tokens. + """ + import hashlib + import os + + import mlx.core as mx + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path) + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + + # Apply chat template for full prompt + template_kwargs = { + "tokenize": False, + "add_generation_prompt": True, + "enable_thinking": enable_thinking, + } + if tools: + template_kwargs["tools"] = tools + + try: + full_prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + except TypeError: + # Template doesn't accept tools= or enable_thinking= + template_kwargs.pop("tools", None) + template_kwargs.pop("enable_thinking", None) + full_prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + + # Build sampler + sampler = make_sampler(temp=temperature, top_p=top_p) + max_tokens = max_tokens or 4096 + + # --- System prompt KV caching --- + prompt_cache = None + prompt_to_send = full_prompt # Default: send full prompt text + cache_hit = False + system_token_count = 0 + full_token_count = 0 + system_hash = None + system_tokens = None + suffix_tokens = None + + # Extract system messages for caching + has_system = any(m.get("role") == "system" for m in messages) + + if has_system and self._text_model is not None: + # Find system prefix boundary in full prompt text. + # ChatML format: system section ends where first non-system message begins. + # Works with tools (rendered inside system section by Qwen templates). + system_prefix_end = -1 + for marker in ("<|im_start|>user\n", "<|im_start|>assistant\n"): + idx = full_prompt.find(marker) + if idx > 0: + system_prefix_end = idx + break + + if system_prefix_end > 0: + system_prefix_text = full_prompt[:system_prefix_end] + system_hash = hashlib.sha256(system_prefix_text.encode()).hexdigest()[ + :16 + ] + + # Tokenize both (matching stream_generate's tokenization logic) + tokenizer = self._text_tokenizer + add_special = tokenizer.bos_token is None or not full_prompt.startswith( + tokenizer.bos_token + ) + full_tokens_list = tokenizer.encode( + full_prompt, add_special_tokens=add_special + ) + full_token_count = len(full_tokens_list) + + system_tokens_list = tokenizer.encode( + system_prefix_text, add_special_tokens=add_special + ) + system_token_count = len(system_tokens_list) + + # Verify system tokens are a proper prefix of full tokens + prefix_valid = ( + len(full_tokens_list) > system_token_count + and full_tokens_list[:system_token_count] == system_tokens_list + ) + + if prefix_valid: + system_tokens = system_tokens_list + suffix_tokens = full_tokens_list[system_token_count:] + + if ( + system_hash == self._system_kv_hash + and self._system_kv_snapshot is not None + and system_token_count == self._system_kv_token_count + ): + # Cache HIT — restore KV state into fresh cache objects + model_cache = make_prompt_cache(self._text_model) + for i, saved_state in enumerate(self._system_kv_snapshot): + model_cache[i].state = saved_state + + # Fresh MTP cache (not populated during prefill) + if hasattr(self._text_model, "make_mtp_cache"): + mtp_cache = self._text_model.make_mtp_cache() + prompt_cache = model_cache + mtp_cache + else: + prompt_cache = model_cache + + prompt_to_send = mx.array(suffix_tokens) + cache_hit = True + logger.info( + "System KV cache HIT: reusing %d cached tokens, " + "prefilling %d new tokens (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) + else: + # Cache MISS — will prefill system tokens and snapshot + logger.info( + "System KV cache MISS: will prefill %d system tokens, " + "%d suffix tokens (hash=%s)", + system_token_count, + len(suffix_tokens), + system_hash, + ) + else: + logger.debug( + "System KV cache: prefix token validation failed, " + "using full prompt (%d tokens)", + len(full_tokens_list), + ) + system_token_count = 0 + + # Run under generation lock, all Metal ops in single thread + async with self._generation_lock: + + def _run_all(): + nonlocal prompt_cache, prompt_to_send + + # Cache MISS with valid prefix: prefill system tokens and snapshot + if ( + not cache_hit + and system_token_count > 0 + and system_tokens is not None + and suffix_tokens is not None + ): + model = self._text_model + mc = make_prompt_cache(model) + sys_arr = mx.array(system_tokens) + + # Prefill system tokens in chunks (matching generate_step) + step = ( + self._prefill_step_size + if hasattr(self, "_prefill_step_size") + else 2048 + ) + while sys_arr.size > step: + model(sys_arr[:step][None], cache=mc) + mx.eval([c.state for c in mc]) + sys_arr = sys_arr[step:] + mx.clear_cache() + if sys_arr.size > 0: + model(sys_arr[None], cache=mc) + mx.eval([c.state for c in mc]) + + # Snapshot backbone cache (immutable mx.arrays, safe to reuse) + snapshot = [c.state for c in mc] + mx.eval([s for pair in snapshot for s in pair]) + + self._system_kv_snapshot = snapshot + self._system_kv_hash = system_hash + self._system_kv_token_count = system_token_count + + # Build prompt_cache with MTP + if hasattr(model, "make_mtp_cache"): + mtp_cache = model.make_mtp_cache() + prompt_cache = mc + mtp_cache + else: + prompt_cache = mc + + prompt_to_send = mx.array(suffix_tokens) + logger.info( + "System KV cache: stored %d-token snapshot (%.1f MB), " + "prefilling %d remaining", + system_token_count, + sum(c.nbytes for c in mc) / 1e6, + len(suffix_tokens), + ) + + # Generate + results = [] + gen_kwargs = dict( + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + ) + if hasattr(self, "_prefill_step_size"): + gen_kwargs["prefill_step_size"] = self._prefill_step_size + if prompt_cache is not None: + gen_kwargs["prompt_cache"] = prompt_cache + + for resp in mlx_stream_generate( + self._text_model, + self._text_tokenizer, + prompt=prompt_to_send, + **gen_kwargs, + ): + results.append(resp) + return results + + all_resps = await asyncio.to_thread(_run_all) + + # Yield results as GenerationOutput + accumulated_text = "" + token_count = 0 + finished = False + for i, resp in enumerate(all_resps): + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text + + is_last = i == len(all_resps) - 1 + finished = is_last or token_count >= max_tokens + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=full_token_count or 0, + completion_tokens=token_count, + finished=finished, + finish_reason="stop" if finished else None, + ) + + if finished: + break + + if not finished: + yield GenerationOutput( + text=accumulated_text, + new_text="", + prompt_tokens=full_token_count or 0, + completion_tokens=token_count, + finished=True, + finish_reason="length", + ) + def get_stats(self) -> dict[str, Any]: """Get engine statistics.""" stats = { @@ -424,6 +794,15 @@ def get_stats(self) -> dict[str, Any]: "loaded": self._loaded, } + # System KV cache stats + if self._system_kv_snapshot is not None: + cache_bytes = sum(k.nbytes + v.nbytes for k, v in self._system_kv_snapshot) + stats["system_kv_cache"] = { + "tokens": self._system_kv_token_count, + "hash": self._system_kv_hash, + "memory_mb": round(cache_bytes / 1e6, 1), + } + # Include Metal memory stats try: import mlx.core as mx diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 092c060e..72182037 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -50,6 +50,7 @@ def __init__( model_name: str, tokenizer_name: str | None = None, trust_remote_code: bool = False, + mtp: bool = False, ): """ Initialize the MLX language model. @@ -58,10 +59,12 @@ def __init__( model_name: HuggingFace model name or local path tokenizer_name: Optional separate tokenizer name trust_remote_code: Whether to trust remote code + mtp: Enable native MTP speculative decoding (model must have MTP head) """ self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name self.trust_remote_code = trust_remote_code + self._mtp = mtp self.model = None self.tokenizer = None @@ -203,12 +206,17 @@ def stream_generate( token_count = 0 accumulated_text = "" + mtp_kwargs = {} + if self._mtp: + mtp_kwargs["mtp"] = True + for response in stream_generate( self.model, self.tokenizer, prompt=prompt, max_tokens=max_tokens, sampler=sampler, + **mtp_kwargs, ): token_count += 1 # response.text is the new token text (not accumulated) diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index 22b36963..3a9090b1 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -740,6 +740,14 @@ def load(self) -> None: logger.error(f"Failed to load MLLM: {e}") raise + def get_language_model(self): + """Extract the underlying language model for mlx_lm TextModel construction.""" + return self.model.language_model + + def get_tokenizer(self): + """Get the text tokenizer (not the multimodal processor).""" + return self.processor.tokenizer + def _prepare_images(self, images: list) -> list[str]: """Process image inputs and return local file paths.""" processed = [] diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index f0328d4e..53b5ac6e 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -467,6 +467,7 @@ def load_model( stream_interval: int = 1, max_tokens: int = 32768, force_mllm: bool = False, + mtp: bool = False, ): """ Load a model (auto-detects MLLM vs LLM). @@ -478,6 +479,7 @@ def load_model( stream_interval: Tokens to batch before streaming (batched mode only) max_tokens: Default max tokens for generation force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable native MTP speculative decoding (SimpleEngine only) """ global _engine, _model_name, _default_max_tokens, _tool_parser_instance @@ -502,7 +504,7 @@ def load_model( logger.info(f"Model loaded (batched mode): {model_name}") else: logger.info(f"Loading model with SimpleEngine: {model_name}") - _engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm) + _engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm, mtp=mtp) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) loop = asyncio.new_event_loop() diff --git a/vllm_mlx/text_model_from_vlm.py b/vllm_mlx/text_model_from_vlm.py new file mode 100644 index 00000000..b1130fdc --- /dev/null +++ b/vllm_mlx/text_model_from_vlm.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Construct an mlx_lm TextModel from mlx_vlm-loaded model weights. + +When mlx_vlm loads a model, it strips MTP weights in sanitize(). +This module builds a parallel mlx_lm TextModel that: +1. Shares backbone + lm_head weights with the vlm model (zero-copy) +2. Loads MTP weights from safetensors on disk +3. Provides full mlx_lm API: return_hidden, n_confirmed, mtp_forward, make_mtp_cache +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + +import mlx.core as mx +import mlx.nn as nn +import mlx.utils + +logger = logging.getLogger(__name__) + + +def build_text_model(vlm_model: Any, model_path: str | Path) -> Any | None: + """Build an mlx_lm TextModel from a vlm-loaded model's weights. + + Args: + vlm_model: The mlx_vlm-loaded model (has .language_model attribute) + model_path: Path to the model directory (contains config.json + safetensors) + + Returns: + mlx_lm TextModel with MTP support, or None on failure. + """ + if vlm_model is None: + return None + + model_path = Path(model_path) if model_path else None + if model_path is None or not (model_path / "config.json").exists(): + return None + + try: + config = json.loads((model_path / "config.json").read_text()) + text_config = config.get("text_config", config) + + # Always import from qwen3_5 — TextModel and TextModelArgs handle both + # dense and MoE natively (MTPDecoderLayer auto-selects SparseMoeBlock + # when args.num_experts > 0). qwen3_5_moe.py does NOT export these. + from mlx_lm.models.qwen3_5 import TextModel, TextModelArgs + + # Build args with proper __post_init__ (handles partial_rotary_factor, + # rope_scaling, head_dim derivation) + args = TextModelArgs.from_dict(text_config) + text_model = TextModel(args) + + # Collect all weights first: backbone from vlm + MTP from safetensors + vlm_lm = vlm_model.language_model + vlm_weights = mlx.utils.tree_flatten(vlm_lm.parameters()) + mtp_weights = _load_mtp_weights(model_path) + + all_weight_names = set(name for name, _ in vlm_weights) + all_weight_names.update(name for name, _ in mtp_weights) + + # Quantize the TextModel skeleton to match source weights. + # Use a predicate that only quantizes layers that have .scales in source. + # This prevents quantizing layers like mtp.fc which are BF16. + quantization = text_config.get("quantization", config.get("quantization", None)) + if quantization is not None: + + def _class_predicate(path, module): + if not hasattr(module, "to_quantized"): + return False + return f"{path}.scales" in all_weight_names + + nn.quantize( + text_model, + group_size=quantization.get("group_size", 64), + bits=quantization.get("bits", 8), + class_predicate=_class_predicate, + ) + + # Transfer backbone + lm_head weights from vlm language_model (zero-copy). + # strict=False because TextModel has MTP params that vlm doesn't have yet. + text_model.load_weights(vlm_weights, strict=False) + + logger.info( + "Transferred %d weight arrays from vlm language_model", len(vlm_weights) + ) + + # Load MTP weights from safetensors + if mtp_weights: + text_model.load_weights(mtp_weights, strict=False) + logger.info("Loaded %d MTP weights from safetensors", len(mtp_weights)) + else: + logger.warning("No MTP weights found in %s", model_path.name) + + # Verify MTP is functional + if hasattr(text_model, "mtp") and text_model.mtp is not None: + mx.eval(text_model.mtp.parameters()) + logger.info( + "TextModel built with MTP support (%d layers)", + args.mtp_num_hidden_layers, + ) + else: + logger.info("TextModel built without MTP (mtp_num_hidden_layers=0)") + + return text_model + + except ImportError as e: + logger.error("Cannot import mlx_lm TextModel (need PR #990): %s", e) + return None + except Exception as e: + logger.error("Failed to build TextModel from vlm: %s", e) + return None + + +def _load_mtp_weights(model_path: Path) -> list[tuple[str, mx.array]]: + """Load MTP weights from safetensors, stripping the language_model. prefix. + + mlx_vlm's sanitize() strips mtp.* keys during model loading, + but the weights are still on disk in the safetensors files. + """ + index_file = model_path / "model.safetensors.index.json" + if not index_file.exists(): + return [] + + index = json.loads(index_file.read_text()) + weight_map = index.get("weight_map", {}) + + # Find MTP keys and their shard files + mtp_keys: dict[str, tuple[str, str]] = {} + for key, shard in weight_map.items(): + if ".mtp." in key: + # Strip "language_model." prefix to match mlx_lm namespace + clean = ( + key.replace("language_model.", "", 1) + if key.startswith("language_model.") + else key + ) + mtp_keys[key] = (clean, shard) + + if not mtp_keys: + return [] + + # Group by shard to minimize I/O + shards: dict[str, list[tuple[str, str]]] = {} + for orig, (clean, shard) in mtp_keys.items(): + shards.setdefault(shard, []).append((orig, clean)) + + weights = [] + for shard_file, key_pairs in shards.items(): + shard_path = model_path / shard_file + if not shard_path.exists(): + logger.warning("MTP shard not found: %s", shard_file) + continue + shard_data = mx.load(str(shard_path)) + for orig, clean in key_pairs: + if orig in shard_data: + weights.append((clean, shard_data[orig])) + + return weights