diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index d94cfbb9..3aeda361 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -23,13 +23,28 @@ class MLLMModelWrapper: BatchGenerator expects model output to be subscriptable (logits array), but MLLM models return LanguageModelOutput objects. This wrapper extracts the logits from the output. + + Additionally, some models like Gemma 3 require `pixel_values` as a required + positional argument, while others like Qwen2-VL make it optional. This wrapper + ensures `pixel_values=None` is passed for text-only requests. """ def __init__(self, model): self._model = model def __call__(self, *args, **kwargs): - """Call the model and extract logits from LanguageModelOutput.""" + """Call the model and extract logits from LanguageModelOutput. + + Handles both models where pixel_values is optional (Qwen2-VL) and + models where it's required (Gemma 3) by ensuring pixel_values=None + is passed for text-only requests. + """ + # For text-only requests, BatchGenerator calls model(input_ids, cache=cache) + # But Gemma 3 requires pixel_values as 2nd positional arg. + # Inject pixel_values=None if not provided and only input_ids is passed + if 'pixel_values' not in kwargs and len(args) == 1: + kwargs['pixel_values'] = None + output = self._model(*args, **kwargs) # If output has logits attribute, return just the logits if hasattr(output, 'logits'):