diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py index 795ee39d..a1279fc1 100644 --- a/vllm_mlx/api/utils.py +++ b/vllm_mlx/api/utils.py @@ -59,7 +59,7 @@ def clean_output_text(text: str) -> str: "llava", "LLaVA", # LLaVA models "idefics", "Idefics", # Idefics models "paligemma", "PaliGemma", # PaliGemma - "gemma-3", "gemma3", # Gemma 3 (multimodal) + "gemma-3", "gemma3", # Gemma 3 (multimodal) - FIXED in mlx_vlm/models/gemma3/language.py "pixtral", "Pixtral", # Pixtral "molmo", "Molmo", # Molmo "phi3-vision", "phi-3-vision", # Phi-3 Vision diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index 3aeda361..c63410fb 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -24,25 +24,20 @@ class MLLMModelWrapper: but MLLM models return LanguageModelOutput objects. This wrapper extracts the logits from the output. - Additionally, some models like Gemma 3 require `pixel_values` as a required - positional argument, while others like Qwen2-VL make it optional. This wrapper - ensures `pixel_values=None` is passed for text-only requests. + Also handles Gemma 3's required pixel_values argument by injecting None + for text-only requests. """ def __init__(self, model): self._model = model + # Detect if this is a Gemma 3 model (requires pixel_values as positional arg) + self._is_gemma3 = hasattr(model, 'model_type') and 'gemma3' in str(getattr(model, 'model_type', '')).lower() def __call__(self, *args, **kwargs): - """Call the model and extract logits from LanguageModelOutput. - - Handles both models where pixel_values is optional (Qwen2-VL) and - models where it's required (Gemma 3) by ensuring pixel_values=None - is passed for text-only requests. - """ - # For text-only requests, BatchGenerator calls model(input_ids, cache=cache) - # But Gemma 3 requires pixel_values as 2nd positional arg. 
- # Inject pixel_values=None if not provided and only input_ids is passed - if 'pixel_values' not in kwargs and len(args) == 1: + """Call the model and extract logits from LanguageModelOutput.""" + # Gemma 3 requires pixel_values as a positional argument, unlike Qwen + # which makes it optional. Inject pixel_values=None for text-only requests, but only when it was not already supplied by keyword or as the 2nd positional arg (injecting then would raise a duplicate-argument TypeError). + if self._is_gemma3 and 'pixel_values' not in kwargs and len(args) < 2: kwargs['pixel_values'] = None output = self._model(*args, **kwargs) @@ -113,6 +108,18 @@ async def start(self) -> None: from ..engine_core import EngineConfig, AsyncEngineCore from ..scheduler import SchedulerConfig + import os + + # Auto-configure Gemma 3 for continuous batching compatibility + # Gemma 3's RotatingKVCache with sliding_window causes garbled output + # in batch mode due to offset tracking corruption. Force full KVCache. + if self._is_mllm and ("gemma-3" in self._model_name.lower() or "gemma3" in self._model_name.lower()): + if os.environ.get("GEMMA3_SLIDING_WINDOW") is None: + os.environ["GEMMA3_SLIDING_WINDOW"] = "0" + logger.info( + "Auto-set GEMMA3_SLIDING_WINDOW=0 for continuous batching compatibility. " + "This uses full KVCache for all layers (~35GB at 50K tokens)." + ) # Load model and tokenizer if self._is_mllm: