diff --git a/tests/test_batched_mtp_routing.py b/tests/test_batched_mtp_routing.py
new file mode 100644
index 000000000..61f7ecacc
--- /dev/null
+++ b/tests/test_batched_mtp_routing.py
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for BatchedEngine MLLM + MTP per-request routing."""
+
+
+def test_has_media_content_text_only():
+    from vllm_mlx.engine.batched import _has_media_content
+
+    assert _has_media_content([{"role": "user", "content": "Hello"}]) is False
+
+
+def test_has_media_content_with_image():
+    from vllm_mlx.engine.batched import _has_media_content
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "What's this?"},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "data:image/png;base64,..."},
+                },
+            ],
+        }
+    ]
+    assert _has_media_content(messages) is True
+
+
+def test_has_media_content_with_video():
+    from vllm_mlx.engine.batched import _has_media_content
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "video_url", "video_url": {"url": "file:///tmp/v.mp4"}}
+            ],
+        }
+    ]
+    assert _has_media_content(messages) is True
+
+
+def test_has_media_content_empty():
+    from vllm_mlx.engine.batched import _has_media_content
+
+    assert _has_media_content([]) is False
+
+
+def test_has_media_content_string_content():
+    """String content (not list) should return False."""
+    from vllm_mlx.engine.batched import _has_media_content
+
+    assert _has_media_content([{"role": "user", "content": "Just text"}]) is False
+
+
+def test_has_media_content_audio():
+    from vllm_mlx.engine.batched import _has_media_content
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "audio_url",
+                    "audio_url": {"url": "data:audio/wav;base64,..."},
+                }
+            ],
+        }
+    ]
+    assert _has_media_content(messages) is True
+
+
+def test_has_media_content_multi_turn():
+    """Media in earlier turns should still be detected."""
+    from vllm_mlx.engine.batched import _has_media_content
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Look at this"},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": "data:image/png;base64,..."},
+                },
+            ],
+        },
+        {"role": "assistant", "content": "I see an image."},
+        {"role": "user", "content": "Tell me more about it."},
+    ]
+    assert _has_media_content(messages) is True
+
+
+def test_has_media_content_text_list():
+    """List content with only text parts should return False."""
+    from vllm_mlx.engine.batched import _has_media_content
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Hello"},
+                {"type": "text", "text": "World"},
+            ],
+        }
+    ]
+    assert _has_media_content(messages) is False
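The tests above only exercise the routing predicate in isolation. As a rough end-to-end sketch of the behaviour added to the engine in the diff below (assumptions: BatchedEngine is the class defined in vllm_mlx/engine/batched.py and importable from that module, the checkpoint id is a placeholder, and the loaded model is a multimodal model whose text tower exposes MTP so the TextModel path actually gets built), a text-only request should take the TextModel/MTP branch while a request with an image_url part stays on the MLLM path; prefill_step_size can also be raised here for long prompts and falls back to 2048 when omitted:

    import asyncio

    # Assumed export, matching the helper import used in the tests above.
    from vllm_mlx.engine.batched import BatchedEngine


    async def main() -> None:
        # mtp=True enables the per-request routing added in this diff;
        # prefill_step_size defaults to 2048 when left unset.
        engine = BatchedEngine(
            model_name="org/some-mtp-capable-vlm",  # placeholder checkpoint id
            mtp=True,
            prefill_step_size=4096,
        )
        await engine.start()

        # No media parts -> _has_media_content(...) is False -> TextModel (MTP) path.
        out = await engine.chat(
            [{"role": "user", "content": "Give me one fun fact."}],
            max_tokens=64,
        )
        print(out.text, out.completion_tokens)

        # The same call with an image_url content part would stay on the MLLM path.
        await engine.stop()


    asyncio.run(main())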
""" +import asyncio import logging from collections.abc import AsyncIterator from typing import Any @@ -21,6 +22,28 @@ logger = logging.getLogger(__name__) +_MEDIA_TYPES = frozenset( + { + "image_url", + "video_url", + "audio_url", + "image", + "video", + "audio", + } +) + + +def _has_media_content(messages: list) -> bool: + """Check if any message contains media content (images, video, audio).""" + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") in _MEDIA_TYPES: + return True + return False + def _extract_media_from_messages(messages: list[dict[str, Any]]) -> tuple: """ @@ -137,6 +160,8 @@ def __init__( scheduler_config: Any | None = None, stream_interval: int = 1, force_mllm: bool = False, + mtp: bool = False, + prefill_step_size: int | None = None, ): """ Initialize the batched engine. @@ -147,12 +172,16 @@ def __init__( scheduler_config: Optional scheduler configuration stream_interval: Tokens to batch before streaming (1=every token) force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable MTP per-request routing (text-only → TextModel, media → MLLM) + prefill_step_size: Chunk size for prompt prefill (default 2048) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._scheduler_config = scheduler_config self._stream_interval = stream_interval self._is_mllm = force_mllm or is_mllm_model(model_name) + self._mtp = mtp + self._prefill_step_size = prefill_step_size or 2048 self._model = None self._processor = None # For MLLM @@ -162,6 +191,11 @@ def __init__( self._mllm_instance = None # MLXMultimodalLM instance self._loaded = False + # Per-request MTP routing state (MLLM+MTP mode) + self._text_model = None + self._text_tokenizer = None + self._text_generation_lock = asyncio.Lock() + @property def model_name(self) -> str: """Get the model name.""" @@ -241,6 +275,43 @@ async def _start_mllm(self) -> None: f"completion_batch={completion_batch_size}" ) + # Build TextModel for MTP per-request routing (text-only → MTP, media → MLLM) + if self._mtp: + try: + from ..text_model_from_vlm import build_text_model + + self._text_model = build_text_model( + self._mllm_instance.model, self._model_name + ) + if ( + self._text_model is not None + and hasattr(self._text_model, "mtp") + and self._text_model.mtp is not None + ): + self._text_tokenizer = self._mllm_instance.get_tokenizer() + + # Apply Qwen3.5 eos_token fix (matches SimpleEngine pattern) + if "qwen3" in self._model_name.lower(): + self._text_tokenizer.eos_token = "<|im_end|>" + self._text_tokenizer.eos_token_id = ( + self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + ) + + logger.info( + "BatchedEngine MLLM+MTP routing: " + "text-only → TextModel (MTP), media → MLLM" + ) + else: + logger.warning( + "TextModel built but no MTP — " + "text-only requests won't use MTP" + ) + self._text_model = None + except Exception as e: + logger.error("MTP TextModel build failed: %s", e) + self._text_model = None + self._text_tokenizer = None + async def _start_llm(self) -> None: """Start the LLM engine with AsyncEngineCore.""" from ..engine_core import AsyncEngineCore, EngineConfig @@ -327,6 +398,8 @@ async def stop(self) -> None: self._tokenizer = None self._processor = None self._mllm_instance = None + self._text_model = None + self._text_tokenizer = None self._loaded = False logger.info("BatchedEngine stopped") @@ -612,6 +685,17 @@ async def chat( if not self._loaded: await self.start() + # 
@@ -612,6 +685,17 @@ async def chat(
         if not self._loaded:
             await self.start()
 
+        # Per-request MTP routing: text-only → TextModel, media → MLLM
+        if self._text_model is not None and not _has_media_content(messages):
+            return await self._chat_text_model(
+                messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                tools=tools,
+                **kwargs,
+            )
+
         # Extract images/videos from messages (OpenAI multimodal format)
         # Note: We only use extracted media here, messages are already processed by server
         _, extracted_images, extracted_videos = extract_multimodal_content(messages)
@@ -723,6 +807,19 @@ async def stream_chat(
         if not self._loaded:
             await self.start()
 
+        # Per-request MTP routing: text-only → TextModel, media → MLLM
+        if self._text_model is not None and not _has_media_content(messages):
+            async for output in self._stream_chat_text_model(
+                messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                tools=tools,
+                **kwargs,
+            ):
+                yield output
+            return
+
         # Extract images/videos from messages (OpenAI multimodal format)
         # Note: We only use extracted media here, messages are already processed by server
         _, extracted_images, extracted_videos = extract_multimodal_content(messages)
@@ -755,6 +852,128 @@ async def stream_chat(
         ):
             yield output
 
+    async def _chat_text_model(
+        self,
+        messages: list[dict[str, Any]],
+        max_tokens: int = 256,
+        temperature: float = 0.7,
+        top_p: float = 0.9,
+        tools: list[dict] | None = None,
+        **kwargs,
+    ) -> GenerationOutput:
+        """Non-streaming text-only generation via mlx_lm TextModel with MTP."""
+        logger.info("Text-only request → TextModel (MTP) [non-streaming]")
+        accumulated_text = ""
+        last_chunk = None
+        async for chunk in self._stream_chat_text_model(
+            messages,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            tools=tools,
+            **kwargs,
+        ):
+            accumulated_text = chunk.text
+            last_chunk = chunk
+        if last_chunk is not None:
+            return GenerationOutput(
+                text=accumulated_text,
+                prompt_tokens=last_chunk.prompt_tokens,
+                completion_tokens=last_chunk.completion_tokens,
+                finish_reason=last_chunk.finish_reason,
+            )
+        return GenerationOutput(text="", finish_reason="stop")
+
+    async def _stream_chat_text_model(
+        self,
+        messages: list[dict[str, Any]],
+        max_tokens: int = 256,
+        temperature: float = 0.7,
+        top_p: float = 0.9,
+        tools: list[dict] | None = None,
+        **kwargs,
+    ) -> AsyncIterator[GenerationOutput]:
+        """Streaming text-only generation via mlx_lm TextModel with MTP.
+
+        Used when MLLM+MTP routing is active and the request has no media.
+        Runs generation under a lock to serialize Metal operations.
+ """ + import os + + import mlx.core as mx + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.sample_utils import make_sampler + + # Read enable_thinking from env (set by runtime_patches) + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + + # Convert tools for template + template_tools = convert_tools_for_template(tools) if tools else None + + # Apply chat template + template_kwargs = { + "tokenize": False, + "add_generation_prompt": True, + "enable_thinking": enable_thinking, + } + if template_tools: + template_kwargs["tools"] = template_tools + + try: + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + except TypeError: + template_kwargs.pop("tools", None) + template_kwargs.pop("enable_thinking", None) + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + + sampler = make_sampler(temp=temperature, top_p=top_p) + max_tokens = max_tokens or 4096 + + async with self._text_generation_lock: + + def _run_generation(): + model = self._text_model + + # Let mlx_lm handle cache creation (MTP-aware in feat/mtp-native) + results = [] + for response in mlx_stream_generate( + model=model, + tokenizer=self._text_tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + prefill_step_size=self._prefill_step_size, + ): + results.append(response) + + mx.clear_cache() + return results + + loop = asyncio.get_event_loop() + results = await loop.run_in_executor(None, _run_generation) + + # Yield results + accumulated_text = "" + prompt_tokens = 0 + for i, response in enumerate(results): + token_text = response.text + accumulated_text += token_text + prompt_tokens = getattr(response, "prompt_tokens", 0) + + yield GenerationOutput( + text=accumulated_text, + new_text=token_text, + prompt_tokens=prompt_tokens, + completion_tokens=i + 1, + finished=i == len(results) - 1, + finish_reason="stop" if i == len(results) - 1 else None, + ) + def get_stats(self) -> dict[str, Any]: """Get engine statistics.""" stats = {