vllm_mlx/engine/batched.py (188 changes: 187 additions & 1 deletion)
@@ -4,8 +4,15 @@

This engine wraps AsyncEngineCore to provide continuous batching
for better throughput when serving multiple concurrent requests.

For MLLM models, this engine supports a hybrid approach:
- Text-only requests: Use BatchGenerator for continuous batching
- Multimodal requests (with images/videos): Fall back to MLLM.chat() for correct processing

This is necessary because BatchGenerator only supports token IDs, not pixel_values.
"""

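# A rough, illustrative sketch of the routing rule described above (not code from
# this module; "batched_text_path" is shorthand for the BatchGenerator flow below):
#
#   has_media, images, videos = _extract_media_from_messages(messages)
#   if is_mllm and has_media:
#       output = mllm.chat(messages=messages, ...)      # vision path, unbatched
#   else:
#       output = batched_text_path(messages, ...)       # continuous batching path
#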
import asyncio
import logging
from typing import Any, AsyncIterator, Dict, List, Optional

@@ -16,6 +23,65 @@
logger = logging.getLogger(__name__)


def _extract_media_from_messages(messages: List[Dict[str, Any]]) -> tuple:
"""
Extract images and videos from OpenAI-format messages.

Returns:
Tuple of (has_media, images_list, videos_list)
"""
images = []
videos = []

for msg in messages:
content = msg.get("content")
if not isinstance(content, list):
continue

for item in content:
# Handle Pydantic models
if hasattr(item, "model_dump"):
item = item.model_dump()
elif hasattr(item, "dict"):
item = item.dict()

if not isinstance(item, dict):
continue

item_type = item.get("type", "")

if item_type == "image_url":
img_url = item.get("image_url", {})
if isinstance(img_url, str):
images.append(img_url)
elif isinstance(img_url, dict):
url = img_url.get("url", "")
if url:
images.append(url)

elif item_type == "image":
img = item.get("image") or item.get("url", "")
if img:
images.append(img)

elif item_type == "video_url":
vid_url = item.get("video_url", {})
if isinstance(vid_url, str):
videos.append(vid_url)
elif isinstance(vid_url, dict):
url = vid_url.get("url", "")
if url:
videos.append(url)

elif item_type == "video":
vid = item.get("video") or item.get("url", "")
if vid:
videos.append(vid)

has_media = bool(images or videos)
return has_media, images, videos

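# Illustrative example (not part of the original PR; the URL is a placeholder):
# a single user message carrying one image_url part is reported as
#
#   _extract_media_from_messages([
#       {"role": "user", "content": [
#           {"type": "text", "text": "Describe this image."},
#           {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
#       ]},
#   ])
#   # -> (True, ["https://example.com/cat.png"], [])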

class MLLMModelWrapper:
"""
Wrapper for MLLM models to make them compatible with BatchGenerator.
@@ -84,6 +150,7 @@ def __init__(
self._model = None
self._tokenizer = None
self._engine = None
self._mllm = None # Keep reference to MLLM for multimodal requests
self._loaded = False

@property
@@ -129,7 +196,10 @@ async def start(self) -> None:
trust_remote_code=self._trust_remote_code,
)
mllm.load()
# Keep reference to MLLM for multimodal requests
# (BatchGenerator can't handle pixel_values, so we use MLLM.chat() for images)
self._mllm = mllm
# Wrap MLLM model so BatchGenerator can use it for text-only requests
# (MLLM returns LanguageModelOutput, BatchGenerator expects logits)
self._model = MLLMModelWrapper(mllm.model)
self._tokenizer = mllm.processor
@@ -174,6 +244,7 @@ async def stop(self) -> None:
self._engine.engine.close()
self._engine = None
self._model = None
self._mllm = None
self._tokenizer = None
self._loaded = False
logger.info("BatchedEngine stopped")
@@ -320,6 +391,10 @@ async def chat(
"""
Chat completion (non-streaming).

For MLLM models with images/videos, uses the native MLLM.chat() method
which properly processes multimodal content through the vision encoder.
For text-only requests, uses BatchGenerator for continuous batching.

Args:
messages: List of chat messages
max_tokens: Maximum tokens to generate
@@ -336,6 +411,39 @@
if not self._loaded:
await self.start()

# Check for multimodal content in messages
has_media, extracted_images, extracted_videos = _extract_media_from_messages(messages)

# Also check explicit images/videos parameters
if images:
extracted_images.extend(images)
has_media = True
if videos:
extracted_videos.extend(videos)
has_media = True

# For MLLM with multimodal content, use native MLLM.chat() for correct processing
# BatchGenerator doesn't support pixel_values, so we can't batch multimodal requests
if self._is_mllm and has_media and self._mllm is not None:
logger.debug(f"Routing multimodal request to MLLM.chat() ({len(extracted_images)} images, {len(extracted_videos)} videos)")

# Run MLLM.chat() in thread pool to avoid blocking
output = await asyncio.to_thread(
self._mllm.chat,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
**kwargs,
)

return GenerationOutput(
text=clean_output_text(output.text),
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finish_reason=output.finish_reason or "stop",
)

# For text-only requests, use BatchGenerator for continuous batching
# Convert tools for template
template_tools = convert_tools_for_template(tools) if tools else None

@@ -364,6 +472,10 @@ async def stream_chat(
"""
Stream chat completion token by token.

For MLLM models with images/videos, uses the native MLLM.stream_chat() method
which properly processes multimodal content through the vision encoder.
For text-only requests, uses BatchGenerator for continuous batching.

Args:
messages: List of chat messages
max_tokens: Maximum tokens to generate
@@ -380,6 +492,80 @@
if not self._loaded:
await self.start()

# Check for multimodal content in messages
has_media, extracted_images, extracted_videos = _extract_media_from_messages(messages)

# Also check explicit images/videos parameters
if images:
extracted_images.extend(images)
has_media = True
if videos:
extracted_videos.extend(videos)
has_media = True

# For MLLM with multimodal content, use native MLLM.stream_chat() for correct processing
if self._is_mllm and has_media and self._mllm is not None:
logger.debug(f"Routing multimodal streaming request to MLLM.stream_chat() ({len(extracted_images)} images)")

# Run MLLM.stream_chat() in a background thread and bridge chunks back through a queue
import queue
import threading

result_queue = queue.Queue()
error_holder = [None]

def stream_worker():
try:
for chunk in self._mllm.stream_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
**kwargs,
):
result_queue.put(chunk)
result_queue.put(None) # Signal completion
except Exception as e:
error_holder[0] = e
result_queue.put(None)

thread = threading.Thread(target=stream_worker)
thread.start()

accumulated_text = ""
while True:
# Use asyncio.to_thread for non-blocking queue get
chunk = await asyncio.to_thread(result_queue.get)
if chunk is None:
if error_holder[0]:
raise error_holder[0]
break

last_chunk = chunk
new_text = chunk.text
accumulated_text += new_text

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=chunk.prompt_tokens,
completion_tokens=chunk.completion_tokens,
finished=False,
finish_reason=None,
)

thread.join()

# Final yield with finished=True
yield GenerationOutput(
text=clean_output_text(accumulated_text),
new_text="",
prompt_tokens=last_chunk.prompt_tokens if last_chunk else 0,
completion_tokens=last_chunk.completion_tokens if last_chunk else 0,
finished=True,
finish_reason="stop",
)
return

# For text-only requests, use BatchGenerator for continuous batching
# Convert tools for template
template_tools = convert_tools_for_template(tools) if tools else None
