class TestSmartNframes:
    """Check smart_nframes output: FRAME_FACTOR alignment and min/max clamping."""

    def test_basic_calculation(self):
        # 300 frames / 30 fps is a 10 s clip; sampling at 2 fps gives 20 frames
        nframes = smart_nframes(300, 30.0, target_fps=2.0)
        assert nframes == 20
        assert nframes % FRAME_FACTOR == 0

    def test_clamps_to_min(self):
        # Only 6 source frames at 30 fps: a very short clip
        nframes = smart_nframes(6, 30.0, target_fps=2.0)
        assert nframes >= MIN_FRAMES
        assert nframes % FRAME_FACTOR == 0

    def test_clamps_to_max(self):
        # 100000 source frames: a very long clip, capped through max_frames
        nframes = smart_nframes(100000, 30.0, target_fps=2.0, max_frames=64)
        assert nframes <= 64
        assert nframes % FRAME_FACTOR == 0

    def test_result_always_even(self):
        # Mix of odd and even totals; every result must be FRAME_FACTOR-aligned
        totals = (5, 7, 11, 13, 100, 999)
        for total in totals:
            result = smart_nframes(total, 30.0)
            remainder = result % FRAME_FACTOR
            assert remainder == 0, f"Odd frame count {result} for total={total}"
result = self._extract_video_inputs(messages) + assert 0 in result + assert result[0] == ["https://example.com/video.mp4"] + + def test_video_url_string_format(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": "https://example.com/video.mp4"}, + {"type": "text", "text": "Describe"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert result[0] == ["https://example.com/video.mp4"] + + def test_video_type(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": "/path/to/video.mp4"}, + {"type": "text", "text": "Describe"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert result[0] == ["/path/to/video.mp4"] + + def test_no_video(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + ], + } + ] + result = self._extract_video_inputs(messages) + assert len(result) == 0 + + def test_mixed_media(self): + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "https://example.com/img.jpg"}, + }, + { + "type": "video_url", + "video_url": {"url": "https://example.com/vid.mp4"}, + }, + {"type": "text", "text": "Compare"}, + ], + } + ] + result = self._extract_video_inputs(messages) + # Only video extracted, not image + assert result[0] == ["https://example.com/vid.mp4"] + + def test_multi_message_videos(self): + """Videos in different messages should be keyed by message index.""" + messages = [ + { + "role": "user", + "content": [ + {"type": "video", "video": "/path/first.mp4"}, + {"type": "text", "text": "First"}, + ], + }, + {"role": "assistant", "content": "OK"}, + { + "role": "user", + "content": [ + {"type": "video", "video": "/path/second.mp4"}, + {"type": "text", "text": "Second"}, + ], + }, + ] + result = self._extract_video_inputs(messages) + assert result[0] == ["/path/first.mp4"] + assert result[2] == ["/path/second.mp4"] + assert 1 not in result + + def 
class TestTranslateMessages:
    """Check translation of OpenAI-style messages into process_vision_info format."""

    def _make_model(self):
        """Build a bare MLXMultimodalLM (``__init__`` skipped) flagged as video-native."""
        instance = MLXMultimodalLM.__new__(MLXMultimodalLM)
        instance._loaded = False
        instance._video_native = True
        return instance

    def _assert_video_translated(self, build_item, fps, max_frames):
        """Translate one user message (video part + text part) and check the result.

        ``build_item`` receives the path of a temporary placeholder file and
        returns the video content part to place in the message.
        """
        import os
        import tempfile

        # A 100-byte placeholder file stands in for a real video
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as handle:
            handle.write(b"\x00" * 100)
            video_path = handle.name

        try:
            model = self._make_model()
            messages = [
                {
                    "role": "user",
                    "content": [
                        build_item(video_path),
                        {"type": "text", "text": "Describe"},
                    ],
                }
            ]
            translated = model._translate_messages_for_native_video(
                messages, fps, max_frames
            )
            content = translated[0]["content"]

            # Both the video part and the text part must survive translation
            seen_types = [part["type"] for part in content]
            assert "video" in seen_types
            assert "text" in seen_types

            # The video part carries the sampling parameters that were passed in
            video_item = next(part for part in content if part["type"] == "video")
            assert video_item["fps"] == fps
            assert video_item["max_frames"] == max_frames
        finally:
            os.unlink(video_path)

    def test_text_only_passthrough(self):
        messages = [{"role": "user", "content": "Hello"}]
        translated = self._make_model()._translate_messages_for_native_video(
            messages, 2.0, 128
        )
        assert translated[0]["content"] == "Hello"

    def test_video_url_translated(self):
        self._assert_video_translated(
            lambda path: {"type": "video", "video": path}, 2.0, 128
        )

    def test_video_url_type_translated(self):
        self._assert_video_translated(
            lambda path: {"type": "video_url", "video_url": {"url": path}}, 1.0, 64
        )
class TestToolForwarding:
    """Verify tools are popped from kwargs before native video path."""

    def test_tools_not_in_kwargs_after_pop(self):
        """Pin the pop-before-dispatch contract on a plain kwargs dict.

        This replays the pop sequence on a local dict; it does not call
        chat() itself. Every popped value is checked and the dict must end up
        empty, so nothing could be forwarded through ``**kwargs``.
        """
        tools = [{"type": "function", "function": {"name": "test"}}]
        kwargs = {"tools": tools, "video_fps": 2.0, "video_max_frames": 64}

        video_fps = kwargs.pop("video_fps", 2.0)
        video_max_frames = kwargs.pop("video_max_frames", 128)
        popped_tools = kwargs.pop("tools", None)

        assert video_fps == 2.0
        assert video_max_frames == 64
        assert popped_tools == tools
        assert "tools" not in kwargs
        assert not kwargs

    def test_generate_native_video_accepts_tools_param(self):
        """Verify _generate_native_video signature accepts tools kwarg."""
        import inspect

        sig = inspect.signature(MLXMultimodalLM._generate_native_video)
        assert "tools" in sig.parameters

    def test_prepare_native_video_inputs_accepts_tools(self):
        """Verify preprocessing helper also accepts tools."""
        import inspect

        sig = inspect.signature(MLXMultimodalLM._prepare_native_video_inputs)
        assert "tools" in sig.parameters

    def test_generate_imports_from_video_generate(self):
        """Verify _generate_native_video uses mlx_vlm.video_generate.generate."""
        import inspect

        # Source-text check: tied to the exact import statement spelling
        source = inspect.getsource(MLXMultimodalLM._generate_native_video)
        assert "from mlx_vlm.video_generate import generate" in source
def _collect_video_inputs(self, messages: list[dict]) -> dict[int, list]:
    """Collect video sources from chat messages, keyed by message index.

    Recognised content parts:

    * ``{"type": "video", "video": <source>}`` (falls back to a ``"url"`` key)
    * ``{"type": "video_url", "video_url": <str>}``
    * ``{"type": "video_url", "video_url": {"url": <str>}}``

    Parts exposing ``model_dump()`` or ``dict()`` are converted to plain
    dicts first, with ``None`` fields dropped. Parts that are not dicts after
    conversion, and messages whose content is not a list, are ignored.

    Empty or missing sources are skipped for every shape, so a message that
    only carries blank video entries does not appear in the result.

    Args:
        messages: Chat messages; each is read through ``.get("content")``.

    Returns:
        Mapping of message index to the list of video sources found in that
        message, in order of appearance. Messages without a usable video
        source have no entry.
    """
    video_inputs: dict[int, list] = {}
    for msg_idx, msg in enumerate(messages):
        content = msg.get("content", "")
        if not isinstance(content, list):
            continue
        for item in content:
            if hasattr(item, "model_dump"):
                item = item.model_dump(exclude_none=True)
            elif hasattr(item, "dict"):
                item = {k: v for k, v in item.dict().items() if v is not None}

            if not isinstance(item, dict):
                continue

            item_type = item.get("type", "")
            if item_type == "video":
                source = item.get("video", item.get("url", ""))
            elif item_type == "video_url":
                vid_url = item.get("video_url", {})
                if isinstance(vid_url, dict):
                    source = vid_url.get("url", "")
                elif isinstance(vid_url, str):
                    source = vid_url
                else:
                    # Neither a string nor a {"url": ...} dict: nothing usable
                    continue
            else:
                continue

            # Same rule for all shapes: a blank source is not a video input
            if source:
                video_inputs.setdefault(msg_idx, []).append(source)
    return video_inputs
" + "Upgrade with: pip install --upgrade mlx-vlm" + ) + + # Translate OpenAI API messages into process_vision_info format + native_messages = self._translate_messages_for_native_video( + messages, video_fps, video_max_frames + ) + + # Use HF processor's chat template (handles timestamp interleaving) + template_kwargs: dict = {} + if tools: + template_kwargs["tools"] = tools + + text = self.processor.apply_chat_template( + native_messages, + tokenize=False, + add_generation_prompt=True, + **template_kwargs, + ) + + # Extract vision inputs via mlx-vlm's process_vision_info + image_inputs, video_inputs, fps_info = process_vision_info( + native_messages, return_video_kwargs=True + ) + + # Process through HF processor to get input_ids, pixel_values, grid_thw + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + input_ids = mx.array(inputs["input_ids"]) + pixel_values = inputs.get( + "pixel_values_videos", inputs.get("pixel_values", None) + ) + if pixel_values is not None: + pixel_values = mx.array(pixel_values) + mask = mx.array(inputs["attention_mask"]) + + gen_kwargs: dict = {} + if inputs.get("video_grid_thw", None) is not None: + gen_kwargs["video_grid_thw"] = mx.array(inputs["video_grid_thw"]) + if inputs.get("image_grid_thw", None) is not None: + gen_kwargs["image_grid_thw"] = mx.array(inputs["image_grid_thw"]) + + gen_kwargs["input_ids"] = input_ids + gen_kwargs["pixel_values"] = pixel_values + gen_kwargs["mask"] = mask + + grid_thw_info = gen_kwargs.get("video_grid_thw") + logger.info( + f"Native video: {input_ids.size} input tokens, " + f"video_grid_thw={grid_thw_info.tolist() if grid_thw_info is not None else None}" + ) + + return text, gen_kwargs + + def _generate_native_video( + self, + messages: list[dict], + max_tokens: int = 256, + temperature: float = 0.7, + video_fps: float = DEFAULT_FPS, + video_max_frames: int = MAX_FRAMES, + tools: list | None = None, + **kwargs, + ) -> 
def _translate_messages_for_native_video(
    self,
    messages: list[dict],
    video_fps: float,
    video_max_frames: int,
) -> list[dict]:
    """Translate OpenAI API format messages to process_vision_info format.

    String content is passed through unchanged, ``None`` content becomes an
    empty string (instead of the literal text ``"None"``), and any other
    non-list content is stringified. For list content each part is handled
    as follows:

    * objects exposing ``model_dump()`` / ``dict()`` are converted to dicts
      with ``None`` fields dropped;
    * parts that are still not dicts become
      ``{"type": "text", "text": str(part)}``;
    * ``image_url`` / ``image`` parts are passed through
      ``process_image_input`` and emitted as ``{"type": "image", ...}``;
    * ``video`` / ``video_url`` parts are passed through
      ``process_video_input`` and emitted as ``{"type": "video", ...}``
      carrying ``fps`` and ``max_frames``; parts with an empty or
      unrecognised source are dropped;
    * every other part (including ``text``) is forwarded untouched.

    Args:
        messages: Chat messages; ``role`` defaults to ``"user"`` when absent.
        video_fps: Value stored under ``fps`` on every emitted video part.
        video_max_frames: Value stored under ``max_frames`` on every emitted
            video part.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts, one per input
        message. Only ``role`` and ``content`` are carried over.
    """
    translated = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        # An explicit null content (e.g. an OpenAI-style assistant turn that
        # only carries tool calls) must not be rendered as the text "None".
        if content is None:
            content = ""

        if isinstance(content, str):
            translated.append({"role": role, "content": content})
            continue

        if not isinstance(content, list):
            translated.append({"role": role, "content": str(content)})
            continue

        new_content = []
        for item in content:
            if hasattr(item, "model_dump"):
                item = item.model_dump(exclude_none=True)
            elif hasattr(item, "dict"):
                item = {k: v for k, v in item.dict().items() if v is not None}

            if not isinstance(item, dict):
                new_content.append({"type": "text", "text": str(item)})
                continue

            item_type = item.get("type", "")

            if item_type == "text":
                new_content.append(item)

            elif item_type == "image_url":
                img_url = item.get("image_url", {})
                url = (
                    img_url.get("url", img_url)
                    if isinstance(img_url, dict)
                    else img_url
                )
                # Resolve the image source before handing it to process_vision_info
                local_path = process_image_input(url)
                new_content.append({"type": "image", "image": local_path})

            elif item_type == "image":
                img = item.get("image", item.get("url", ""))
                local_path = process_image_input(img)
                new_content.append({"type": "image", "image": local_path})

            elif item_type in ("video", "video_url"):
                # Extract the video source from the supported shapes
                if item_type == "video_url":
                    vid_url = item.get("video_url", {})
                    if isinstance(vid_url, str):
                        video_source = vid_url
                    elif isinstance(vid_url, dict):
                        video_source = vid_url.get("url", "")
                    else:
                        continue
                else:
                    video_source = item.get("video", item.get("url", ""))

                # Blank sources are dropped rather than passed to the resolver
                if not video_source:
                    continue

                video_path = process_video_input(video_source)
                new_content.append(
                    {
                        "type": "video",
                        "video": video_path,
                        "fps": video_fps,
                        "max_frames": video_max_frames,
                    }
                )

            else:
                new_content.append(item)

        translated.append({"role": role, "content": new_content})

    return translated
prompt: str, @@ -1060,12 +1319,47 @@ def chat( # Extract text and images from messages # Build chat_messages for multi-turn support WITH proper image tokens per message all_image_urls = [] # Raw URLs/paths to process later - videos = [] chat_messages = [] # List of properly formatted messages for chat template logger.info(f"MLLM.chat() called with {len(messages)} messages") - for msg in messages: + # Pop params early so they don't leak into mlx_vlm.generate() + video_fps = kwargs.pop("video_fps", DEFAULT_FPS) + video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) + tools = kwargs.pop("tools", None) + use_cache = kwargs.pop("use_cache", True) + + # Collect video inputs from messages + _msg_video_inputs = self._collect_video_inputs(messages) + + # Use native video pipeline for supported models + if self._video_native and _msg_video_inputs: + return self._generate_native_video( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + video_fps=video_fps, + video_max_frames=video_max_frames, + tools=tools, + **kwargs, + ) + + # Fallback: extract frames and treat as individual images + _msg_video_frame_counts: dict[int, int] = {} + all_video_frames: list[str] = [] + for msg_idx, vid_inputs in _msg_video_inputs.items(): + total_frames = 0 + for vid_input in vid_inputs: + frames = self._prepare_video( + vid_input, fps=video_fps, max_frames=video_max_frames + ) + all_video_frames.extend(frames) + total_frames += len(frames) + logger.info(f"Added {len(frames)} frames from video: {vid_input}") + _msg_video_frame_counts[msg_idx] = total_frames + + # Second pass: build chat messages with image counts that include video frames + for msg_idx, msg in enumerate(messages): role = msg.get("role", "user") content = msg.get("content", "") msg_text = "" # Text content for this message @@ -1107,8 +1401,8 @@ def chat( ) msg_image_count += 1 - elif item_type == "video": - videos.append(item.get("video", item.get("url", ""))) + # Add video frame count to image 
count for this message + msg_image_count += _msg_video_frame_counts.get(msg_idx, 0) # Build properly structured message for Qwen3-VL-MoE # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "text", "text": "..."}]} @@ -1140,16 +1434,8 @@ def chat( all_images = [] if all_image_urls: all_images.extend(self._prepare_images(all_image_urls)) - - # Process videos - video_fps = kwargs.pop("video_fps", DEFAULT_FPS) - video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) - for video_path in videos: - frames = self._prepare_video( - video_path, fps=video_fps, max_frames=video_max_frames - ) - all_images.extend(frames) - logger.info(f"Added {len(frames)} frames from video: {video_path}") + # Append pre-processed video frames + all_images.extend(all_video_frames) # Apply chat template directly - messages are already properly structured logger.info( @@ -1161,8 +1447,7 @@ def chat( f" Chat msg {i}: role={cm['role']}, content={content_preview}..." ) - # Pop tools so they don't leak into mlx_vlm.generate()/stream_generate() - tools = kwargs.pop("tools", None) + # Build template kwargs for tool definitions (tools already popped above) template_extra_kwargs = {} if tools: template_extra_kwargs["tools"] = tools @@ -1196,10 +1481,10 @@ def chat( # Prefix caching with vision embedding support # Following LMCache approach: cache vision embeddings to skip encoder on hit - from mlx_vlm.models import cache as vlm_cache import time - use_cache = kwargs.pop("use_cache", True) + from mlx_vlm.models import cache as vlm_cache + cache_entry = None prefix_match_len = 0 vision_embeddings = None @@ -1316,6 +1601,7 @@ def chat( ): try: import copy + import mlx.core as mx # Get prompt token count (before generation) @@ -1431,26 +1717,64 @@ def stream_chat( # Extract text and images from messages # Build chat_messages for multi-turn support WITH proper image tokens per message all_image_urls = [] # Raw URLs/paths to process later - videos = [] chat_messages = [] # List of 
properly formatted messages for chat template - for msg in messages: + # Pop params early so they don't leak into mlx_vlm.generate() + video_fps = kwargs.pop("video_fps", DEFAULT_FPS) + video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) + tools = kwargs.pop("tools", None) + use_cache = kwargs.pop("use_cache", True) + + # Collect video inputs from messages + _msg_video_inputs = self._collect_video_inputs(messages) + + # Use native video pipeline for supported models. + # NOTE: Native video yields a single chunk (not incremental streaming) + # because mlx_vlm.video_generate has no streaming API. The event loop + # is NOT blocked at the server level — SimpleEngine wraps this in + # asyncio.to_thread(). True token-level streaming requires upstream + # mlx-vlm support for video stream_generate. + if self._video_native and _msg_video_inputs: + output = self._generate_native_video( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + video_fps=video_fps, + video_max_frames=video_max_frames, + tools=tools, + **kwargs, + ) + yield output + return + + # Fallback: frames as images + _msg_video_frame_counts: dict[int, int] = {} + all_video_frames: list[str] = [] + for msg_idx, vid_inputs in _msg_video_inputs.items(): + total_frames = 0 + for vid_input in vid_inputs: + frames = self._prepare_video( + vid_input, fps=video_fps, max_frames=video_max_frames + ) + all_video_frames.extend(frames) + total_frames += len(frames) + logger.info(f"Added {len(frames)} frames from video: {vid_input}") + _msg_video_frame_counts[msg_idx] = total_frames + + for msg_idx, msg in enumerate(messages): role = msg.get("role", "user") content = msg.get("content", "") - msg_text = "" # Text content for this message - msg_image_count = 0 # Number of images in THIS message + msg_text = "" + msg_image_count = 0 if isinstance(content, str): msg_text = content elif isinstance(content, list): - # OpenAI multimodal format - extract text and count images for THIS message for 
item in content: if isinstance(item, str): msg_text += item continue - # Convert Pydantic models to dicts, excluding None fields - # to avoid null keys like image_url: null on text parts if hasattr(item, "model_dump"): item = item.model_dump(exclude_none=True) elif hasattr(item, "dict"): @@ -1476,14 +1800,10 @@ def stream_chat( ) msg_image_count += 1 - elif item_type == "video": - videos.append(item.get("video", item.get("url", ""))) + msg_image_count += _msg_video_frame_counts.get(msg_idx, 0) - # Build properly structured message for Qwen3-VL-MoE - # Format: {"role": "...", "content": [{"type": "image"}, ..., {"type": "text", "text": "..."}]} if msg_text or msg_image_count > 0: if role == "user" and msg_image_count > 0: - # User message WITH images - build content array with image tokens FIRST content_list = [] for _ in range(msg_image_count): content_list.append({"type": "image"}) @@ -1492,10 +1812,8 @@ def stream_chat( ) chat_messages.append({"role": role, "content": content_list}) elif role == "assistant": - # Assistant messages - just text content (not array) chat_messages.append({"role": role, "content": msg_text}) else: - # User/system message WITHOUT images - still use content array format chat_messages.append( { "role": role, @@ -1505,29 +1823,17 @@ def stream_chat( } ) - # Process images all_images = [] if all_image_urls: all_images.extend(self._prepare_images(all_image_urls)) + all_images.extend(all_video_frames) - # Process videos - video_fps = kwargs.pop("video_fps", DEFAULT_FPS) - video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) - for video_path in videos: - frames = self._prepare_video( - video_path, fps=video_fps, max_frames=video_max_frames - ) - all_images.extend(frames) - - # Apply chat template directly - messages are already properly structured - # Pop tools so they don't leak into mlx_vlm.generate()/stream_generate() - tools = kwargs.pop("tools", None) + # Build template kwargs for tool definitions (tools already popped above) 
template_extra_kwargs = {} if tools: template_extra_kwargs["tools"] = tools try: - # Use get_chat_template directly since messages are already properly formatted formatted_prompt = get_chat_template( self.processor, chat_messages, @@ -1558,7 +1864,6 @@ def stream_chat( prompt_cache = None cache_hit = False - use_cache = kwargs.pop("use_cache", True) if use_cache and self._cache_manager is not None and all_images: prompt_cache, cache_hit = self._cache_manager.fetch_cache( diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index f599ddc8..bfb7a062 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -755,8 +755,7 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: elapsed = time.perf_counter() - start_time logger.info( - f"Embeddings: {len(texts)} inputs, {prompt_tokens} tokens " - f"in {elapsed:.2f}s" + f"Embeddings: {len(texts)} inputs, {prompt_tokens} tokens in {elapsed:.2f}s" ) # Build OpenAI-compatible response with ordered indices @@ -777,8 +776,7 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: raise HTTPException( status_code=503, detail=( - "mlx-embeddings not installed. " - "Install with: pip install mlx-embeddings" + "mlx-embeddings not installed. Install with: pip install mlx-embeddings" ), ) except HTTPException: @@ -1355,6 +1353,23 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) has_media = bool(images or videos) + if engine.is_mllm and not has_media: + # MLLM extracts media from messages directly, so images/videos are + # always empty. Check message content for video/image types instead. 
+ for msg in request.messages: + content = msg.content if hasattr(msg, "content") else msg.get("content", "") + if isinstance(content, list): + for item in content: + item_type = ( + item.type + if hasattr(item, "type") + else (item.get("type", "") if isinstance(item, dict) else "") + ) + if item_type in ("image_url", "image", "video", "video_url"): + has_media = True + break + if has_media: + break # Handle response_format - inject system prompt if needed response_format = request.response_format