Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,28 @@ class MLLMModelWrapper:
BatchGenerator expects model output to be subscriptable (logits array),
but MLLM models return LanguageModelOutput objects. This wrapper extracts
the logits from the output.

Additionally, some models like Gemma 3 require `pixel_values` as a required
positional argument, while others like Qwen2-VL make it optional. This wrapper
ensures `pixel_values=None` is passed for text-only requests.
"""

def __init__(self, model):
self._model = model

def __call__(self, *args, **kwargs):
"""Call the model and extract logits from LanguageModelOutput."""
"""Call the model and extract logits from LanguageModelOutput.

Handles both models where pixel_values is optional (Qwen2-VL) and
models where it's required (Gemma 3) by ensuring pixel_values=None
is passed for text-only requests.
"""
# For text-only requests, BatchGenerator calls model(input_ids, cache=cache)
# But Gemma 3 requires pixel_values as 2nd positional arg.
# Inject pixel_values=None if not provided and only input_ids is passed
if 'pixel_values' not in kwargs and len(args) == 1:
kwargs['pixel_values'] = None

output = self._model(*args, **kwargs)
# If output has logits attribute, return just the logits
if hasattr(output, 'logits'):
Expand Down