Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm_mlx/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def clean_output_text(text: str) -> str:
"llava", "LLaVA", # LLaVA models
"idefics", "Idefics", # Idefics models
"paligemma", "PaliGemma", # PaliGemma
"gemma-3", "gemma3", # Gemma 3 (multimodal)
"gemma-3", "gemma3", # Gemma 3 (multimodal) - FIXED in mlx_vlm/models/gemma3/language.py
"pixtral", "Pixtral", # Pixtral
"molmo", "Molmo", # Molmo
"phi3-vision", "phi-3-vision", # Phi-3 Vision
Expand Down
33 changes: 20 additions & 13 deletions vllm_mlx/engine/batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,20 @@ class MLLMModelWrapper:
but MLLM models return LanguageModelOutput objects. This wrapper extracts
the logits from the output.

Additionally, some models like Gemma 3 require `pixel_values` as a required
positional argument, while others like Qwen2-VL make it optional. This wrapper
ensures `pixel_values=None` is passed for text-only requests.
Also handles Gemma 3's required pixel_values argument by injecting None
for text-only requests.
"""

def __init__(self, model):
    """Wrap ``model`` and record whether it is a Gemma 3 model.

    Args:
        model: The underlying MLLM; calls made on the wrapper are
            forwarded to it.
    """
    self._model = model
    # Gemma 3 is detected by a case-insensitive 'gemma3' match on the model's
    # `model_type`. getattr's default already covers models that lack the
    # attribute, so no separate hasattr check is needed.
    self._is_gemma3 = 'gemma3' in str(getattr(model, 'model_type', '')).lower()

def __call__(self, *args, **kwargs):
"""Call the model and extract logits from LanguageModelOutput.

Handles both models where pixel_values is optional (Qwen2-VL) and
models where it's required (Gemma 3) by ensuring pixel_values=None
is passed for text-only requests.
"""
# For text-only requests, BatchGenerator calls model(input_ids, cache=cache)
# But Gemma 3 requires pixel_values as 2nd positional arg.
# Inject pixel_values=None if not provided and only input_ids is passed
if 'pixel_values' not in kwargs and len(args) == 1:
"""Call the model and extract logits from LanguageModelOutput."""
# Gemma 3 requires pixel_values as a positional argument, unlike Qwen2-VL,
# which makes it optional. Inject pixel_values=None for text-only requests.
if self._is_gemma3 and 'pixel_values' not in kwargs:
kwargs['pixel_values'] = None

output = self._model(*args, **kwargs)
Expand Down Expand Up @@ -113,6 +108,18 @@ async def start(self) -> None:

from ..engine_core import EngineConfig, AsyncEngineCore
from ..scheduler import SchedulerConfig
import os

# Auto-configure Gemma 3 for continuous batching compatibility
# Gemma 3's RotatingKVCache with sliding_window causes garbled output
# in batch mode due to offset tracking corruption. Force full KVCache.
if self._is_mllm and ("gemma-3" in self._model_name.lower() or "gemma3" in self._model_name.lower()):
if os.environ.get("GEMMA3_SLIDING_WINDOW") is None:
os.environ["GEMMA3_SLIDING_WINDOW"] = "0"
logger.info(
"Auto-set GEMMA3_SLIDING_WINDOW=0 for continuous batching compatibility. "
"This uses full KVCache for all layers (~35GB at 50K tokens)."
)

# Load model and tokenizer
if self._is_mllm:
Expand Down