diff --git a/README.md b/README.md index b1f2ea22..8db90799 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,71 @@ For full documentation, see the [docs](docs/) directory: See [benchmarks](docs/benchmarks/) for detailed results. +## Gemma 3 Support + +This fork includes patches for Gemma 3 vision support. Gemma 3 is a multimodal model, but image input only works when the server detects it as an MLLM (multimodal LLM) rather than a text-only LLM. + +### Usage + +```bash +# Start server with Gemma 3 +vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 + +# Verify it loaded as MLLM (not LLM) +curl http://localhost:8000/health +# Should show: "model_type": "mllm" +``` + +### Long Context Patch (mlx-vlm) + +Gemma 3's default `sliding_window=1024` limits context to ~10K tokens on Apple Silicon (Metal GPU timeout at higher context). To enable longer context (up to ~50K tokens), patch mlx-vlm as shown below. Note that `GEMMA3_SLIDING_WINDOW=0` is intended for text-only workloads: it may break multimodal generation with longer prompts. + +**Location:** `~/.../site-packages/mlx_vlm/models/gemma3/language.py` + +Find the `make_cache` method and replace with: + +```python +def make_cache(self): + import os + # Set GEMMA3_SLIDING_WINDOW=8192 for ~40K context + # Set GEMMA3_SLIDING_WINDOW=0 for ~50K context (full KVCache) + sliding_window = int(os.environ.get('GEMMA3_SLIDING_WINDOW', self.config.sliding_window)) + + caches = [] + for i in range(self.config.num_hidden_layers): + if ( + i % self.config.sliding_window_pattern + == self.config.sliding_window_pattern - 1 + ): + caches.append(KVCache()) + elif sliding_window == 0: + caches.append(KVCache()) # Full context for all layers + else: + caches.append(RotatingKVCache(max_size=sliding_window, keep=0)) + return caches +``` + +**Usage:** + +```bash +# Default (~10K max context) +vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 + +# Extended context (~40K max) +GEMMA3_SLIDING_WINDOW=8192 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 + +# Maximum context (~50K max) +GEMMA3_SLIDING_WINDOW=0 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000 +``` + +**Benchmark Results (M4 Max 128GB):** + +| Setting | 
Max Context | Memory | +|---------|-------------|--------| +| Default (1024) | ~10K tokens | ~16GB | +| `GEMMA3_SLIDING_WINDOW=8192` | ~40K tokens | ~25GB | +| `GEMMA3_SLIDING_WINDOW=0` | ~50K tokens | ~35GB | + ## Contributing We welcome contributions! See [Contributing Guide](docs/development/contributing.md) for details. diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index ac9f8167..a1279fc1 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -59,6 +59,7 @@ def clean_output_text(text: str) -> str: "llava", "LLaVA", # LLaVA models "idefics", "Idefics", # Idefics models "paligemma", "PaliGemma", # PaliGemma + "gemma-3", "gemma3", # Gemma 3 (multimodal) - FIXED in mlx_vlm/models/gemma3/language.py "pixtral", "Pixtral", # Pixtral "molmo", "Molmo", # Molmo "phi3-vision", "phi-3-vision", # Phi-3 Vision diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 8be77113..e05e1755 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -4,8 +4,15 @@ This engine wraps AsyncEngineCore to provide continuous batching for better throughput when serving multiple concurrent requests. + +For MLLM models, this engine supports a hybrid approach: +- Text-only requests: Use BatchGenerator for continuous batching +- Multimodal requests (with images/videos): Fall back to MLLM.chat() for correct processing + +This is necessary because BatchGenerator only supports token IDs, not pixel_values. """ +import asyncio import logging from typing import Any, AsyncIterator, Dict, List, Optional @@ -16,6 +23,100 @@ logger = logging.getLogger(__name__) +def _extract_media_from_messages(messages: List[Dict[str, Any]]) -> tuple: + """ + Extract images and videos from OpenAI-format messages. 
+ + Returns: + Tuple of (has_media, images_list, videos_list) + """ + images = [] + videos = [] + + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + + for item in content: + # Handle Pydantic models + if hasattr(item, "model_dump"): + item = item.model_dump() + elif hasattr(item, "dict"): + item = item.dict() + + if not isinstance(item, dict): + continue + + item_type = item.get("type", "") + + if item_type == "image_url": + img_url = item.get("image_url", {}) + if isinstance(img_url, str): + images.append(img_url) + elif isinstance(img_url, dict): + url = img_url.get("url", "") + if url: + images.append(url) + + elif item_type == "image": + img = item.get("image") or item.get("url", "") + if img: + images.append(img) + + elif item_type == "video_url": + vid_url = item.get("video_url", {}) + if isinstance(vid_url, str): + videos.append(vid_url) + elif isinstance(vid_url, dict): + url = vid_url.get("url", "") + if url: + videos.append(url) + + elif item_type == "video": + vid = item.get("video") or item.get("url", "") + if vid: + videos.append(vid) + + has_media = bool(images or videos) + return has_media, images, videos + + +class MLLMModelWrapper: + """ + Wrapper for MLLM models to make them compatible with BatchGenerator. + + BatchGenerator expects model output to be subscriptable (logits array), + but MLLM models return LanguageModelOutput objects. This wrapper extracts + the logits from the output. + + Also handles Gemma 3's required pixel_values argument by injecting None + for text-only requests. 
+ """ + + def __init__(self, model): + self._model = model + # Detect if this is a Gemma 3 model (requires pixel_values as positional arg) + self._is_gemma3 = hasattr(model, 'model_type') and 'gemma3' in str(getattr(model, 'model_type', '')).lower() + + def __call__(self, *args, **kwargs): + """Call the model and extract logits from LanguageModelOutput.""" + # Gemma 3 requires pixel_values as a positional argument, unlike Qwen + # which makes it optional. Inject pixel_values=None for text-only requests. + if self._is_gemma3 and 'pixel_values' not in kwargs: + kwargs['pixel_values'] = None + + output = self._model(*args, **kwargs) + # If output has logits attribute, return just the logits + if hasattr(output, 'logits'): + return output.logits + return output + + def __getattr__(self, name): + """Forward all other attributes to the wrapped model.""" + return getattr(self._model, name) + + class BatchedEngine(BaseEngine): """ Batched engine for continuous batching. @@ -49,6 +150,7 @@ def __init__( self._model = None self._tokenizer = None self._engine = None + self._mllm = None # Keep reference to MLLM for multimodal requests self._loaded = False @property @@ -73,6 +175,27 @@ async def start(self) -> None: from ..engine_core import EngineConfig, AsyncEngineCore from ..scheduler import SchedulerConfig + import os + + # Note on Gemma 3 sliding window configuration: + # - Default sliding_window=1024 works for multimodal (image+text) + # - GEMMA3_SLIDING_WINDOW=0 (full KVCache) enables extended text context + # but BREAKS multimodal generation with longer prompts (~1300+ tokens) + # + # Do NOT auto-set GEMMA3_SLIDING_WINDOW=0 for MLLM models. 
+ # Users who need extended text-only context can manually set: + # GEMMA3_SLIDING_WINDOW=0 (but avoid multimodal with long prompts) + if ("gemma-3" in self._model_name.lower() or "gemma3" in self._model_name.lower()): + sliding_window = os.environ.get("GEMMA3_SLIDING_WINDOW") + if sliding_window is not None: + logger.info( + f"Gemma 3: Using GEMMA3_SLIDING_WINDOW={sliding_window} " + f"(Note: value 0 may cause issues with multimodal + long prompts)" + ) + else: + logger.info( + "Gemma 3: Using default sliding_window=1024 (optimal for multimodal)" + ) # Load model and tokenizer if self._is_mllm: @@ -82,7 +205,12 @@ async def start(self) -> None: trust_remote_code=self._trust_remote_code, ) mllm.load() - self._model = mllm.model + # Keep reference to MLLM for multimodal requests + # (BatchGenerator can't handle pixel_values, so we use MLLM.chat() for images) + self._mllm = mllm + # Wrap MLLM model so BatchGenerator can use it for text-only requests + # (MLLM returns LanguageModelOutput, BatchGenerator expects logits) + self._model = MLLMModelWrapper(mllm.model) self._tokenizer = mllm.processor else: from ..utils.tokenizer import load_model_with_fallback @@ -125,6 +253,7 @@ async def stop(self) -> None: self._engine.engine.close() self._engine = None self._model = None + self._mllm = None self._tokenizer = None self._loaded = False logger.info("BatchedEngine stopped") @@ -271,6 +400,10 @@ async def chat( """ Chat completion (non-streaming). + For MLLM models with images/videos, uses the native MLLM.chat() method + which properly processes multimodal content through the vision encoder. + For text-only requests, uses BatchGenerator for continuous batching. 
+ Args: messages: List of chat messages max_tokens: Maximum tokens to generate @@ -287,6 +420,39 @@ async def chat( if not self._loaded: await self.start() + # Check for multimodal content in messages + has_media, extracted_images, extracted_videos = _extract_media_from_messages(messages) + + # Also check explicit images/videos parameters + if images: + extracted_images.extend(images) + has_media = True + if videos: + extracted_videos.extend(videos) + has_media = True + + # For MLLM with multimodal content, use native MLLM.chat() for correct processing + # BatchGenerator doesn't support pixel_values, so we can't batch multimodal requests + if self._is_mllm and has_media and self._mllm is not None: + logger.debug(f"Routing multimodal request to MLLM.chat() ({len(extracted_images)} images, {len(extracted_videos)} videos)") + + # Run MLLM.chat() in thread pool to avoid blocking + output = await asyncio.to_thread( + self._mllm.chat, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + **kwargs, + ) + + return GenerationOutput( + text=clean_output_text(output.text), + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + finish_reason=output.finish_reason or "stop", + ) + + # For text-only requests, use BatchGenerator for continuous batching # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None @@ -315,6 +481,10 @@ async def stream_chat( """ Stream chat completion token by token. + For MLLM models with images/videos, uses the native MLLM.stream_chat() method + which properly processes multimodal content through the vision encoder. + For text-only requests, uses BatchGenerator for continuous batching. 
+ + Args: messages: List of chat messages max_tokens: Maximum tokens to generate @@ -331,6 +501,82 @@ async def stream_chat( if not self._loaded: await self.start() + # Check for multimodal content in messages + has_media, extracted_images, extracted_videos = _extract_media_from_messages(messages) + + # Also check explicit images/videos parameters + if images: + extracted_images.extend(images) + has_media = True + if videos: + extracted_videos.extend(videos) + has_media = True + + # For MLLM with multimodal content, use native MLLM.stream_chat() for correct processing + if self._is_mllm and has_media and self._mllm is not None: + logger.debug(f"Routing multimodal streaming request to MLLM.stream_chat() ({len(extracted_images)} images)") + + # Run MLLM.stream_chat() in thread pool, yielding results + import queue + import threading + + result_queue = queue.Queue() + error_holder = [None] + + def stream_worker(): + try: + for chunk in self._mllm.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + **kwargs, + ): + result_queue.put(chunk) + result_queue.put(None) # Signal completion + except Exception as e: + error_holder[0] = e + result_queue.put(None) + + thread = threading.Thread(target=stream_worker) + thread.start() + + accumulated_text = "" + last_chunk = None # Last real chunk; `chunk` itself is always None once the loop exits + while True: + # Use asyncio.to_thread for non-blocking queue get + chunk = await asyncio.to_thread(result_queue.get) + if chunk is None: + if error_holder[0]: + raise error_holder[0] + break + + last_chunk = chunk + new_text = chunk.text + accumulated_text += new_text + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=chunk.prompt_tokens, + completion_tokens=chunk.completion_tokens, + finished=False, + finish_reason=None, + ) + + thread.join() + + # Final yield with finished=True + yield GenerationOutput( + text=clean_output_text(accumulated_text), + new_text="", + prompt_tokens=last_chunk.prompt_tokens if last_chunk else 0, + completion_tokens=last_chunk.completion_tokens if last_chunk else 
0, + finished=True, + finish_reason="stop", + ) + return + + # For text-only requests, use BatchGenerator for continuous batching # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index b4a72c3a..66feb91c 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -305,18 +305,39 @@ async def stream_chat( # Build prompt using tokenizer if self._is_mllm: - # For MLLM, fall back to non-streaming chat - output = await self.chat( - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - tools=tools, - images=images, - videos=videos, - **kwargs, - ) - yield output + # For MLLM, use stream_chat which yields tokens incrementally + accumulated_text = "" + token_count = 0 + + # Run stream_chat in thread pool since it's synchronous + def run_stream(): + return list(self._model.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + **kwargs, + )) + + chunks = await asyncio.to_thread(run_stream) + + for chunk in chunks: + token_count += 1 + new_text = chunk.text if hasattr(chunk, 'text') else str(chunk) + accumulated_text += new_text + + finished = chunk.finish_reason is not None + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=getattr(chunk, 'prompt_tokens', 0), + completion_tokens=token_count, + finished=finished, + finish_reason=chunk.finish_reason if finished else None, + ) + + if finished: + break return # For LLM, apply chat template and stream diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index fc558876..52cf4fae 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -1025,6 +1025,13 @@ def chat( images = [] videos = [] text_prompt = "" + + logger.info(f"MLLM.chat() called with {len(messages)} messages") + for i, msg in enumerate(messages): + logger.info(f" Message {i}: role={msg.get('role')}, content 
type={type(msg.get('content'))}") + if isinstance(msg.get('content'), list): + for j, item in enumerate(msg.get('content', [])): + logger.info(f" Item {j}: type={item.get('type') if isinstance(item, dict) else type(item)}") for msg in messages: role = msg.get("role", "user") @@ -1121,6 +1128,150 @@ def chat( completion_tokens=generation_tokens, ) + def stream_chat( + self, + messages: list[dict], + max_tokens: int = 256, + temperature: float = 0.7, + **kwargs, + ) -> Iterator[MLLMOutput]: + """ + Stream chat with OpenAI-compatible message format. + + Supports multimodal content in messages: + - {"type": "text", "text": "..."} + - {"type": "image_url", "image_url": {"url": "..."}} + - {"type": "image_url", "image_url": {"url": "data:image/...;base64,..."}} + + Args: + messages: List of chat messages (OpenAI format) + max_tokens: Maximum tokens to generate + temperature: Sampling temperature + **kwargs: Additional parameters + + Yields: + MLLMOutput with incremental text chunks + """ + if not self._loaded: + self.load() + + try: + from mlx_vlm import stream_generate + from mlx_vlm.prompt_utils import apply_chat_template + except ImportError: + # Fallback to non-streaming if stream_generate not available + output = self.chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + **kwargs, + ) + yield output + return + + # Extract text and images from messages (same logic as chat()) + images = [] + videos = [] + text_prompt = "" + + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + + if isinstance(content, str): + if role == "user": + text_prompt = content + elif isinstance(content, list): + # OpenAI multimodal format + for item in content: + if isinstance(item, str): + text_prompt = item + continue + + # Convert Pydantic models to dicts + if hasattr(item, "model_dump"): + item = item.model_dump() + elif hasattr(item, "dict"): + item = item.dict() + + if isinstance(item, dict): + item_type = item.get("type", 
"") + + if item_type == "text": + text_prompt = item.get("text", "") + + elif item_type == "image_url": + img_url = item.get("image_url", {}) + if isinstance(img_url, str): + images.append(img_url) + else: + images.append(img_url.get("url", "")) + + elif item_type == "image": + images.append(item.get("image", item.get("url", ""))) + + elif item_type == "video": + videos.append(item.get("video", item.get("url", ""))) + + # Process images + all_images = [] + if images: + all_images.extend(self._prepare_images(images)) + + # Process videos + video_fps = kwargs.pop("video_fps", DEFAULT_FPS) + video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES) + for video_path in videos: + frames = self._prepare_video( + video_path, fps=video_fps, max_frames=video_max_frames + ) + all_images.extend(frames) + + # Apply chat template + try: + formatted_prompt = apply_chat_template( + self.processor, + self.config, + text_prompt, + num_images=len(all_images), + ) + except Exception as e: + logger.warning(f"Failed to apply chat template: {e}, using raw prompt") + formatted_prompt = text_prompt + + # Stream generate tokens + accumulated_text = "" + token_count = 0 + + for chunk in stream_generate( + self.model, + self.processor, + formatted_prompt, + all_images if all_images else None, + max_tokens=max_tokens, + temp=temperature, + **kwargs, + ): + token_count += 1 + # chunk is a GenerationResult with .text attribute containing the new token + new_text = chunk.text if hasattr(chunk, 'text') else str(chunk) + accumulated_text += new_text + + yield MLLMOutput( + text=new_text, # Just the new token for streaming + finish_reason=None, + prompt_tokens=getattr(chunk, 'prompt_tokens', 0), + completion_tokens=token_count, + ) + + # Final yield with finish_reason + yield MLLMOutput( + text="", + finish_reason="stop", + prompt_tokens=getattr(chunk, 'prompt_tokens', 0) if 'chunk' in dir() else 0, + completion_tokens=token_count, + ) + def describe_image( self, image: str, diff --git 
a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 1bc86474..1a23a720 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -123,6 +123,9 @@ def __init__( self.tokenizer = tokenizer self.config = config or SchedulerConfig() + # Detect if tokenizer is a processor (MLLM) and get the actual tokenizer + self._actual_tokenizer = self._get_actual_tokenizer(tokenizer) + # Request management - following vLLM's design self.waiting: deque[Request] = deque() # Waiting queue (FCFS) self.running: Dict[str, Request] = {} # Running requests by ID @@ -172,20 +175,46 @@ def __init__( self.total_prompt_tokens = 0 self.total_completion_tokens = 0 + def _get_actual_tokenizer(self, tokenizer: Any) -> Any: + """ + Get the actual tokenizer from a processor or tokenizer. + + MLLM models use processors (e.g., Qwen3VLProcessor) which wrap + the tokenizer. This method extracts the actual tokenizer. + """ + # If it has encode method, it's already a tokenizer + if hasattr(tokenizer, 'encode') and callable(tokenizer.encode): + return tokenizer + # If it's a processor, get the wrapped tokenizer + if hasattr(tokenizer, 'tokenizer'): + return tokenizer.tokenizer + # Fallback to the original + return tokenizer + + def _decode_tokens(self, token_ids: List[int]) -> str: + """ + Decode token IDs to text, handling both tokenizers and processors. 
+ """ + return self._actual_tokenizer.decode(token_ids) + def _get_stop_tokens(self) -> Set[int]: - """Get stop token IDs from tokenizer.""" + """Get stop token IDs from tokenizer or processor.""" stop_tokens = set() - if hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: - if isinstance(self.tokenizer.eos_token_id, list): - stop_tokens.update(self.tokenizer.eos_token_id) - else: - stop_tokens.add(self.tokenizer.eos_token_id) - if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids is not None: - if isinstance(self.tokenizer.eos_token_ids, (list, set, tuple)): - stop_tokens.update(self.tokenizer.eos_token_ids) - else: - # Handle case where eos_token_ids is a single int - stop_tokens.add(self.tokenizer.eos_token_ids) + # Check both the processor/tokenizer and the actual tokenizer + for tok in [self.tokenizer, self._actual_tokenizer]: + if tok is None: + continue + if hasattr(tok, 'eos_token_id') and tok.eos_token_id is not None: + if isinstance(tok.eos_token_id, list): + stop_tokens.update(tok.eos_token_id) + else: + stop_tokens.add(tok.eos_token_id) + if hasattr(tok, 'eos_token_ids') and tok.eos_token_ids is not None: + if isinstance(tok.eos_token_ids, (list, set, tuple)): + stop_tokens.update(tok.eos_token_ids) + else: + # Handle case where eos_token_ids is a single int + stop_tokens.add(tok.eos_token_ids) return stop_tokens def _create_batch_generator(self, sampling_params: SamplingParams) -> BatchGenerator: @@ -325,7 +354,17 @@ def add_request(self, request: Request) -> None: # Tokenize if needed if request.prompt_token_ids is None: if isinstance(request.prompt, str): - request.prompt_token_ids = self.tokenizer.encode(request.prompt) + # Handle both tokenizers and processors (for MLLM models) + if hasattr(self.tokenizer, 'encode'): + request.prompt_token_ids = self.tokenizer.encode(request.prompt) + elif hasattr(self.tokenizer, 'tokenizer') and hasattr(self.tokenizer.tokenizer, 'encode'): + # Processor 
wraps tokenizer (e.g., Qwen3VLProcessor) + request.prompt_token_ids = self.tokenizer.tokenizer.encode(request.prompt) + else: + raise AttributeError( + f"Tokenizer {type(self.tokenizer)} has no 'encode' method. " + "Continuous batching requires a tokenizer with encode support." + ) else: request.prompt_token_ids = list(request.prompt) request.num_prompt_tokens = len(request.prompt_token_ids) @@ -533,7 +572,7 @@ def _process_batch_responses( request.append_output_token(response.token) # Decode the new token - new_text = self.tokenizer.decode([response.token]) + new_text = self._decode_tokens([response.token]) # Create output output = RequestOutput( @@ -557,7 +596,7 @@ def _process_batch_responses( finished_ids.add(request_id) # Decode full output - output.output_text = self.tokenizer.decode(request.output_token_ids) + output.output_text = self._decode_tokens(request.output_token_ids) request.output_text = output.output_text # Extract cache for future reuse diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index efb8f220..0fb31e39 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -633,8 +633,30 @@ async def create_chat_completion(request: ChatCompletionRequest): """ engine = get_engine() - # Extract text, images, and videos from messages - messages, images, videos = extract_multimodal_content(request.messages) + # For MLLM models, keep original messages with embedded images + # (MLLM.chat() extracts images from message content internally) + if engine.is_mllm: + # Convert Pydantic messages to dicts preserving full content + messages = [] + for msg in request.messages: + msg_dict = msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg) + messages.append(msg_dict) + images, videos = [], [] # MLLM extracts these from messages + # Debug: log message structure + import logging + _logger = logging.getLogger(__name__) + _logger.info(f"MLLM: Processing {len(messages)} 
messages") + for i, m in enumerate(messages): + c = m.get('content') + if isinstance(c, list): + _logger.info(f" Msg {i}: role={m.get('role')}, content is list with {len(c)} items") + for j, item in enumerate(c): + _logger.info(f" Item {j}: {item.get('type') if isinstance(item, dict) else type(item)}") + else: + _logger.info(f" Msg {i}: role={m.get('role')}, content is {type(c).__name__}") + else: + # For LLM, extract text, images, and videos separately + messages, images, videos = extract_multimodal_content(request.messages) has_media = bool(images or videos)