Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 88 additions & 64 deletions vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
This engine wraps AsyncEngineCore to provide continuous batching
for better throughput when serving multiple concurrent requests.

For MLLM models, this engine supports a hybrid approach:
- Text-only requests: Use BatchGenerator for continuous batching
- Multimodal requests (with images/videos): Fall back to MLLM.chat() for correct processing

This is necessary because BatchGenerator only supports token IDs, not pixel_values.
For MLLM models, all requests (text-only and multimodal) are routed through
the MLLMScheduler, which handles vision encoding and batched generation via
MLLMBatchGenerator. MLLM models only initialise the MLLM scheduler (not the
LLM engine), so text-only requests must also be routed through it.
"""

import logging
Expand Down Expand Up @@ -325,70 +324,93 @@ def _apply_chat_template(
tools: list[dict] | None = None,
num_images: int = 0,
) -> str:
"""Apply chat template to messages."""
tokenizer = self.tokenizer
"""Apply chat template to messages.

if self._is_mllm and self._processor:
# Use mlx_vlm's chat template for MLLM
try:
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

config = getattr(self._model, "config", None)
if config is None:
config = load_config(self._model_name)

# Extract text from last user message
text_prompt = ""
for msg in reversed(messages):
if msg.get("role") == "user":
content = msg.get("content", "")
if isinstance(content, str):
text_prompt = content
elif isinstance(content, list):
for item in content:
if isinstance(item, str):
text_prompt = item
break
elif (
isinstance(item, dict)
and item.get("type") == "text"
):
text_prompt = item.get("text", "")
break
break

return apply_chat_template(
self._processor,
config,
text_prompt,
num_images=num_images,
)
except Exception as e:
logger.warning(f"Failed to apply MLLM chat template: {e}")
# Fall through to standard template
Uses the processor's (or tokenizer's) apply_chat_template with the
full message list so that system prompts and conversation history
are preserved. The previous implementation extracted only the last
user message text via mlx_vlm.prompt_utils.apply_chat_template,
which dropped system prompts and all prior turns.
"""
# Choose the best template applicator.
# For MLLM models, the processor handles special vision tokens.
# For text-only models, the tokenizer is sufficient.
template_applicator = None
if (
self._is_mllm
and self._processor
and hasattr(self._processor, "apply_chat_template")
):
template_applicator = self._processor
elif hasattr(self.tokenizer, "apply_chat_template"):
template_applicator = self.tokenizer

if template_applicator is not None:
# Convert OpenAI image_url content parts to HuggingFace format
# so the processor can insert the correct vision placeholder tokens.
if self._is_mllm and num_images > 0:
messages = self._prepare_mllm_messages(messages)

if hasattr(tokenizer, "apply_chat_template"):
enable_thinking = "coder" not in self._model_name.lower()
template_kwargs = {
"tokenize": False,
"add_generation_prompt": True,
"enable_thinking": enable_thinking,
}
if tools:
template_kwargs["tools"] = tools

try:
return tokenizer.apply_chat_template(messages, **template_kwargs)
except TypeError:
for key in ["tools", "enable_thinking"]:
return template_applicator.apply_chat_template(
messages, **template_kwargs
)
except TypeError as e:
# Some templates don't accept 'tools'; retry without them.
logger.debug(f"Chat template TypeError, retrying without extras: {e}")
for key in ["tools"]:
if key in template_kwargs:
del template_kwargs[key]
return tokenizer.apply_chat_template(messages, **template_kwargs)
return template_applicator.apply_chat_template(
messages, **template_kwargs
)
else:
# Fallback for models without apply_chat_template
prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
return prompt + "\nassistant:"

@staticmethod
def _prepare_mllm_messages(
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Rewrite OpenAI-style image parts into HuggingFace processor format.

    OpenAI chat requests describe images as
    ``{"type": "image_url", "image_url": {"url": ...}}`` content parts,
    whereas HuggingFace processors expect a bare ``{"type": "image"}``
    placeholder entry.

    Args:
        messages: Chat messages in OpenAI format; each message is a dict
            carrying at least ``role`` and ``content``.

    Returns:
        A new message list in which every ``image_url`` part has been
        replaced by ``{"type": "image"}``. Non-dict messages are dropped,
        and non-list contents pass through untouched.
    """

    def _convert(parts: list[Any]) -> list[Any]:
        # Keep dict/str parts (translating image_url parts); anything
        # else is silently discarded to avoid passing unexpected types.
        converted: list[Any] = []
        for part in parts:
            if isinstance(part, dict):
                if part.get("type") == "image_url":
                    converted.append({"type": "image"})
                else:
                    converted.append(part)
            elif isinstance(part, str):
                converted.append(part)
        return converted

    result: list[dict[str, Any]] = []
    for message in messages:
        if not isinstance(message, dict):
            # Skip malformed entries rather than fail the whole request.
            continue
        body = message.get("content")
        if isinstance(body, list):
            result.append({**message, "content": _convert(body)})
        else:
            result.append(message)
    return result

async def generate(
self,
prompt: str,
Expand Down Expand Up @@ -419,8 +441,10 @@ async def generate(
if not self._loaded:
await self.start()

if self._is_mllm and self._mllm_scheduler and (images or videos):
# Use MLLM scheduler for multimodal
if self._is_mllm and self._mllm_scheduler:
# Use MLLM scheduler for all requests when model is multimodal.
# MLLM models only initialise the _mllm_scheduler (not _engine),
# so text-only requests must also be routed here.
output = await self._mllm_scheduler.generate(
prompt=prompt,
images=images,
Expand All @@ -437,7 +461,7 @@ async def generate(
finish_reason=output.finish_reason,
)

# Use LLM engine for text-only
# Use LLM engine for text-only (non-MLLM models)
from ..request import SamplingParams

sampling_params = SamplingParams(
Expand Down Expand Up @@ -491,8 +515,8 @@ async def stream_generate(
if not self._loaded:
await self.start()

if self._is_mllm and self._mllm_scheduler and (images or videos):
# Use MLLM scheduler for multimodal streaming
if self._is_mllm and self._mllm_scheduler:
# Use MLLM scheduler for all streaming when model is multimodal
request_id = await self._mllm_scheduler.add_request_async(
prompt=prompt,
images=images,
Expand Down Expand Up @@ -556,9 +580,9 @@ async def chat(
"""
Chat completion (non-streaming).

For MLLM models with images/videos, uses the native MLLM.chat() method
which properly processes multimodal content through the vision encoder.
For text-only requests, uses BatchGenerator for continuous batching.
For MLLM models, all requests (including text-only) are routed through
the MLLMScheduler for vision-aware batched generation.
For non-MLLM models, uses the LLM engine with BatchGenerator.

Args:
messages: List of chat messages (OpenAI format)
Expand Down Expand Up @@ -667,9 +691,9 @@ async def stream_chat(
"""
Stream chat completion token by token.

For MLLM models with images/videos, uses the native MLLM.stream_chat() method
which properly processes multimodal content through the vision encoder.
For text-only requests, uses BatchGenerator for continuous batching.
For MLLM models, all requests (including text-only) are streamed through
the MLLMScheduler for vision-aware batched generation.
For non-MLLM models, uses the LLM engine with BatchGenerator.

Args:
messages: List of chat messages (OpenAI format)
Expand Down
105 changes: 74 additions & 31 deletions vllm_mlx/mllm_batch_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,15 +551,20 @@ def _preprocess_request(self, request: MLLMBatchRequest) -> None:
f"({processing_time:.2f}s)"
)

def _run_vision_encoding(self, request: MLLMBatchRequest) -> mx.array:
def _run_vision_encoding(
self, request: MLLMBatchRequest, cache: Optional[List[Any]] = None
) -> mx.array:
"""
Run the initial VLM forward pass to encode vision and get first logits.

This runs the full VLM model (vision + language) on the prompt,
which encodes the images and prepares the language model cache.
which encodes the images and fills the provided KV cache.

Args:
request: Preprocessed request with input_ids and pixel_values
cache: KV cache list for the language model. If provided, the
language model writes its KV state directly into this cache
during the forward pass.

Returns:
Logits from the forward pass
Expand All @@ -574,13 +579,14 @@ def _run_vision_encoding(self, request: MLLMBatchRequest) -> mx.array:
if request.image_grid_thw is not None:
kwargs["image_grid_thw"] = request.image_grid_thw

# Run full VLM forward pass
# This processes vision inputs and fills the language model cache
# Run full VLM forward pass with cache.
# The VLM passes cache= through to self.language_model(),
# so the language model writes KV state directly into our cache.
input_ids = request.input_ids
if input_ids.ndim == 1:
input_ids = input_ids[None, :]

output = self.model(input_ids, **kwargs)
output = self.model(input_ids, cache=cache, **kwargs)
request.vision_encoded = True

# Handle LanguageModelOutput or plain tensor
Expand All @@ -594,47 +600,55 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:

For MLLM, this is more complex than LLM:
1. Preprocess each request (tokenize, process images)
2. Run vision encoding for each request (cannot batch vision yet)
3. Set up BatchKVCache for language model generation
2. Run vision encoding per-request with individual KVCache objects
3. Merge individual caches into a BatchKVCache for generation

Args:
requests: Requests to process

Returns:
MLLMBatch ready for generation
"""
from mlx_lm.models.cache import make_prompt_cache

tic = time.perf_counter()

# Preprocess all requests
for req in requests:
self._preprocess_request(req)

# Get token sequences and lengths
input_ids_list = [
req.input_ids.tolist() if req.input_ids is not None else [0]
for req in requests
]
lengths = [len(ids) for ids in input_ids_list]
max_length = max(lengths)
padding = [max_length - seq_len for seq_len in lengths]

self._stats.prompt_tokens += sum(lengths)

# Create batch cache for language model
batch_cache = _make_batch_cache(self.language_model, padding)
total_prompt_tokens = sum(
req.input_ids.size if req.input_ids is not None else 1 for req in requests
)
self._stats.prompt_tokens += total_prompt_tokens

# Guard against excessive memory usage during cache merge.
# Each token in the batch requires KV entries across all layers.
max_batch_tokens = self.prefill_step_size * len(requests)
if total_prompt_tokens > max_batch_tokens:
raise ValueError(
f"Total prompt tokens ({total_prompt_tokens}) exceeds safe limit "
f"({max_batch_tokens}) for {len(requests)} requests. "
f"Reduce prompt length or batch size."
)

# Run vision encoding for each request and fill cache
# This must be done per-request because vision inputs differ
# Run vision encoding for each request with its own KVCache.
# Vision encoding cannot be batched because each request may have
# different images/pixel values. We pass a per-request KVCache to
# the VLM so the language model writes its KV state directly into it.
first_tokens = []
all_logprobs = []
per_request_caches = []

for req in requests:
# Create a fresh KVCache for this request's language model prefill
request_cache = make_prompt_cache(self.language_model)

for i, req in enumerate(requests):
# Run full VLM forward pass for this request
# This fills the cache for layer i with this request's KV states
with mx.stream(MLLMBatchGenerator._stream):
logits = self._run_vision_encoding(req)
# Run VLM forward pass — cache= flows through to language_model
logits = self._run_vision_encoding(req, cache=request_cache)

# Extract last token logits
# Extract last token logits and sample
last_logits = logits[:, -1, :]
logprobs = last_logits - mx.logsumexp(
last_logits, axis=-1, keepdims=True
Expand All @@ -646,6 +660,35 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
first_tokens.append(sampled.item())
all_logprobs.append(logprobs.squeeze(0))

per_request_caches.append(request_cache)

# Merge per-request KVCaches into a single BatchKVCache.
# KVCache.merge() creates a BatchKVCache with proper left-padding
# alignment, so all requests share a single batched cache for
# subsequent generation steps.
from mlx_lm.models.cache import KVCache

sample_cache = per_request_caches[0][0]
if not isinstance(sample_cache, KVCache):
raise ValueError(
f"MLLM continuous batching requires standard KVCache but got "
f"{type(sample_cache).__name__}. Disable --kv-cache-quantization "
f"when using multimodal models with --continuous-batching."
)

try:
batch_cache = [
per_request_caches[0][layer_idx].merge(
[c[layer_idx] for c in per_request_caches]
)
for layer_idx in range(len(per_request_caches[0]))
]
except Exception as e:
logger.error(
f"Failed to merge per-request KV caches: {type(e).__name__}: {e}"
)
raise

# Create initial y (first generated tokens)
y = mx.array(first_tokens)

Expand Down Expand Up @@ -710,10 +753,10 @@ def _next(self) -> List[MLLMBatchResponse]:
num_active = len(batch) if batch else 0

# Only start a new batch when there is no active batch generating.
# MLLM vision encoding produces per-request KV caches that cannot be
# safely extended into an active batch's cache (shape mismatch in
# attention layers). Instead, queued requests wait until the current
# batch finishes, then all get processed together in one prefill.
# Per-request KV caches are created during vision encoding and then
# merged into a single BatchKVCache. Merging into an active batch
# mid-generation would cause shape mismatches in attention layers,
# so queued requests wait until the current batch finishes.
if num_active == 0:
requests = self.unprocessed_requests[: self.completion_batch_size]

Expand Down
Loading