diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 3b2e973e..840bb069 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -5,11 +5,10 @@ This engine wraps AsyncEngineCore to provide continuous batching for better throughput when serving multiple concurrent requests. -For MLLM models, this engine supports a hybrid approach: -- Text-only requests: Use BatchGenerator for continuous batching -- Multimodal requests (with images/videos): Fall back to MLLM.chat() for correct processing - -This is necessary because BatchGenerator only supports token IDs, not pixel_values. +For MLLM models, all requests (text-only and multimodal) are routed through +the MLLMScheduler, which handles vision encoding and batched generation via +MLLMBatchGenerator. MLLM models only initialise the MLLM scheduler (not the +LLM engine), so text-only requests must also be routed through it. """ import logging @@ -325,70 +324,93 @@ def _apply_chat_template( tools: list[dict] | None = None, num_images: int = 0, ) -> str: - """Apply chat template to messages.""" - tokenizer = self.tokenizer + """Apply chat template to messages. - if self._is_mllm and self._processor: - # Use mlx_vlm's chat template for MLLM - try: - from mlx_vlm.prompt_utils import apply_chat_template - from mlx_vlm.utils import load_config - - config = getattr(self._model, "config", None) - if config is None: - config = load_config(self._model_name) - - # Extract text from last user message - text_prompt = "" - for msg in reversed(messages): - if msg.get("role") == "user": - content = msg.get("content", "") - if isinstance(content, str): - text_prompt = content - elif isinstance(content, list): - for item in content: - if isinstance(item, str): - text_prompt = item - break - elif ( - isinstance(item, dict) - and item.get("type") == "text" - ): - text_prompt = item.get("text", "") - break - break - - return apply_chat_template( - self._processor, - config, - text_prompt, - num_images=num_images, - ) - except Exception as e: - logger.warning(f"Failed to apply MLLM chat template: {e}") - # Fall through to standard template + Uses the processor's (or tokenizer's) apply_chat_template with the + full message list so that system prompts and conversation history + are preserved. The previous implementation extracted only the last + user message text via mlx_vlm.prompt_utils.apply_chat_template, + which dropped system prompts and all prior turns. + """ + # Choose the best template applicator. + # For MLLM models, the processor handles special vision tokens. + # For text-only models, the tokenizer is sufficient. + template_applicator = None + if ( + self._is_mllm + and self._processor + and hasattr(self._processor, "apply_chat_template") + ): + template_applicator = self._processor + elif hasattr(self.tokenizer, "apply_chat_template"): + template_applicator = self.tokenizer + + if template_applicator is not None: + # Convert OpenAI image_url content parts to HuggingFace format + # so the processor can insert the correct vision placeholder tokens. + if self._is_mllm and num_images > 0: + messages = self._prepare_mllm_messages(messages) - if hasattr(tokenizer, "apply_chat_template"): - enable_thinking = "coder" not in self._model_name.lower() template_kwargs = { "tokenize": False, "add_generation_prompt": True, - "enable_thinking": enable_thinking, } if tools: template_kwargs["tools"] = tools try: - return tokenizer.apply_chat_template(messages, **template_kwargs) - except TypeError: - for key in ["tools", "enable_thinking"]: + return template_applicator.apply_chat_template( + messages, **template_kwargs + ) + except TypeError as e: + # Some templates don't accept 'tools'; retry without them. + logger.debug(f"Chat template TypeError, retrying without extras: {e}") + for key in ["tools"]: if key in template_kwargs: del template_kwargs[key] - return tokenizer.apply_chat_template(messages, **template_kwargs) + return template_applicator.apply_chat_template( + messages, **template_kwargs + ) else: + # Fallback for models without apply_chat_template prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages) return prompt + "\nassistant:" + @staticmethod + def _prepare_mllm_messages( + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Convert OpenAI-style image_url content to HuggingFace format. + + The OpenAI API uses ``{"type": "image_url", "image_url": {"url": ...}}`` + while HuggingFace processors expect ``{"type": "image"}``. + + Args: + messages: List of chat messages in OpenAI format. Each message is a + dict with at least ``role`` and ``content`` keys. + + Returns: + A new list of messages with ``image_url`` parts replaced by + ``{"type": "image"}`` entries for the HuggingFace processor. + """ + prepared = [] + for msg in messages: + if not isinstance(msg, dict): + continue + content = msg.get("content") + if isinstance(content, list): + new_content = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "image_url": + new_content.append({"type": "image"}) + elif isinstance(part, (dict, str)): + new_content.append(part) + # skip non-dict/non-str parts to avoid passing unexpected types + prepared.append({**msg, "content": new_content}) + else: + prepared.append(msg) + return prepared + async def generate( self, prompt: str, @@ -419,8 +441,10 @@ async def generate( if not self._loaded: await self.start() - if self._is_mllm and self._mllm_scheduler and (images or videos): - # Use MLLM scheduler for multimodal + if self._is_mllm and self._mllm_scheduler: + # Use MLLM scheduler for all requests when model is multimodal. + # MLLM models only initialise the _mllm_scheduler (not _engine), + # so text-only requests must also be routed here. output = await self._mllm_scheduler.generate( prompt=prompt, images=images, @@ -437,7 +461,7 @@ async def generate( finish_reason=output.finish_reason, ) - # Use LLM engine for text-only + # Use LLM engine for text-only (non-MLLM models) from ..request import SamplingParams sampling_params = SamplingParams( @@ -491,8 +515,8 @@ async def stream_generate( if not self._loaded: await self.start() - if self._is_mllm and self._mllm_scheduler and (images or videos): - # Use MLLM scheduler for multimodal streaming + if self._is_mllm and self._mllm_scheduler: + # Use MLLM scheduler for all streaming when model is multimodal request_id = await self._mllm_scheduler.add_request_async( prompt=prompt, images=images, @@ -556,9 +580,9 @@ async def chat( """ Chat completion (non-streaming). - For MLLM models with images/videos, uses the native MLLM.chat() method - which properly processes multimodal content through the vision encoder. - For text-only requests, uses BatchGenerator for continuous batching. + For MLLM models, all requests (including text-only) are routed through + the MLLMScheduler for vision-aware batched generation. + For non-MLLM models, uses the LLM engine with BatchGenerator. Args: messages: List of chat messages (OpenAI format) @@ -667,9 +691,9 @@ async def stream_chat( """ Stream chat completion token by token. - For MLLM models with images/videos, uses the native MLLM.stream_chat() method - which properly processes multimodal content through the vision encoder. - For text-only requests, uses BatchGenerator for continuous batching. + For MLLM models, all requests (including text-only) are streamed through + the MLLMScheduler for vision-aware batched generation. + For non-MLLM models, uses the LLM engine with BatchGenerator. Args: messages: List of chat messages (OpenAI format) diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index fba3ae02..ee8d8da7 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -551,15 +551,20 @@ def _preprocess_request(self, request: MLLMBatchRequest) -> None: f"({processing_time:.2f}s)" ) - def _run_vision_encoding(self, request: MLLMBatchRequest) -> mx.array: + def _run_vision_encoding( + self, request: MLLMBatchRequest, cache: Optional[List[Any]] = None + ) -> mx.array: """ Run the initial VLM forward pass to encode vision and get first logits. This runs the full VLM model (vision + language) on the prompt, - which encodes the images and prepares the language model cache. + which encodes the images and fills the provided KV cache. Args: request: Preprocessed request with input_ids and pixel_values + cache: KV cache list for the language model. If provided, the + language model writes its KV state directly into this cache + during the forward pass. Returns: Logits from the forward pass @@ -574,13 +579,14 @@ def _run_vision_encoding(self, request: MLLMBatchRequest) -> mx.array: if request.image_grid_thw is not None: kwargs["image_grid_thw"] = request.image_grid_thw - # Run full VLM forward pass - # This processes vision inputs and fills the language model cache + # Run full VLM forward pass with cache. + # The VLM passes cache= through to self.language_model(), + # so the language model writes KV state directly into our cache. input_ids = request.input_ids if input_ids.ndim == 1: input_ids = input_ids[None, :] - output = self.model(input_ids, **kwargs) + output = self.model(input_ids, cache=cache, **kwargs) request.vision_encoded = True # Handle LanguageModelOutput or plain tensor @@ -594,8 +600,8 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: For MLLM, this is more complex than LLM: 1. Preprocess each request (tokenize, process images) - 2. Run vision encoding for each request (cannot batch vision yet) - 3. Set up BatchKVCache for language model generation + 2. Run vision encoding per-request with individual KVCache objects + 3. Merge individual caches into a BatchKVCache for generation Args: requests: Requests to process @@ -603,38 +609,46 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: Returns: MLLMBatch ready for generation """ + from mlx_lm.models.cache import make_prompt_cache + tic = time.perf_counter() # Preprocess all requests for req in requests: self._preprocess_request(req) - # Get token sequences and lengths - input_ids_list = [ - req.input_ids.tolist() if req.input_ids is not None else [0] - for req in requests - ] - lengths = [len(ids) for ids in input_ids_list] - max_length = max(lengths) - padding = [max_length - seq_len for seq_len in lengths] - - self._stats.prompt_tokens += sum(lengths) - - # Create batch cache for language model - batch_cache = _make_batch_cache(self.language_model, padding) + total_prompt_tokens = sum( + req.input_ids.size if req.input_ids is not None else 1 for req in requests + ) + self._stats.prompt_tokens += total_prompt_tokens + + # Guard against excessive memory usage during cache merge. + # Each token in the batch requires KV entries across all layers. + max_batch_tokens = self.prefill_step_size * len(requests) + if total_prompt_tokens > max_batch_tokens: + raise ValueError( + f"Total prompt tokens ({total_prompt_tokens}) exceeds safe limit " + f"({max_batch_tokens}) for {len(requests)} requests. " + f"Reduce prompt length or batch size." + ) - # Run vision encoding for each request and fill cache - # This must be done per-request because vision inputs differ + # Run vision encoding for each request with its own KVCache. + # Vision encoding cannot be batched because each request may have + # different images/pixel values. We pass a per-request KVCache to + # the VLM so the language model writes its KV state directly into it. first_tokens = [] all_logprobs = [] + per_request_caches = [] + + for req in requests: + # Create a fresh KVCache for this request's language model prefill + request_cache = make_prompt_cache(self.language_model) - for i, req in enumerate(requests): - # Run full VLM forward pass for this request - # This fills the cache for layer i with this request's KV states with mx.stream(MLLMBatchGenerator._stream): - logits = self._run_vision_encoding(req) + # Run VLM forward pass — cache= flows through to language_model + logits = self._run_vision_encoding(req, cache=request_cache) - # Extract last token logits + # Extract last token logits and sample last_logits = logits[:, -1, :] logprobs = last_logits - mx.logsumexp( last_logits, axis=-1, keepdims=True @@ -646,6 +660,35 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: first_tokens.append(sampled.item()) all_logprobs.append(logprobs.squeeze(0)) + per_request_caches.append(request_cache) + + # Merge per-request KVCaches into a single BatchKVCache. + # KVCache.merge() creates a BatchKVCache with proper left-padding + # alignment, so all requests share a single batched cache for + # subsequent generation steps. + from mlx_lm.models.cache import KVCache + + sample_cache = per_request_caches[0][0] + if not isinstance(sample_cache, KVCache): + raise ValueError( + f"MLLM continuous batching requires standard KVCache but got " + f"{type(sample_cache).__name__}. Disable --kv-cache-quantization " + f"when using multimodal models with --continuous-batching." + ) + + try: + batch_cache = [ + per_request_caches[0][layer_idx].merge( + [c[layer_idx] for c in per_request_caches] + ) + for layer_idx in range(len(per_request_caches[0])) + ] + except Exception as e: + logger.error( + f"Failed to merge per-request KV caches: {type(e).__name__}: {e}" + ) + raise + # Create initial y (first generated tokens) y = mx.array(first_tokens) @@ -710,10 +753,10 @@ def _next(self) -> List[MLLMBatchResponse]: num_active = len(batch) if batch else 0 # Only start a new batch when there is no active batch generating. - # MLLM vision encoding produces per-request KV caches that cannot be - # safely extended into an active batch's cache (shape mismatch in - # attention layers). Instead, queued requests wait until the current - # batch finishes, then all get processed together in one prefill. + # Per-request KV caches are created during vision encoding and then + # merged into a single BatchKVCache. Merging into an active batch + # mid-generation would cause shape mismatches in attention layers, + # so queued requests wait until the current batch finishes. if num_active == 0: requests = self.unprocessed_requests[: self.completion_batch_size]