diff --git a/README.md b/README.md
index b1f2ea22..8db90799 100644
--- a/README.md
+++ b/README.md
@@ -201,6 +201,71 @@ For full documentation, see the [docs](docs/) directory:
 
 See [benchmarks](docs/benchmarks/) for detailed results.
 
+## Gemma 3 Support
+
+This fork includes patches for Gemma 3 vision support. Gemma 3 is multimodal, but it is not detected as an MLLM out of the box, so these patches add it to the multimodal-model detection list.
+
+### Usage
+
+```bash
+# Start server with Gemma 3
+vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000
+
+# Verify it loaded as MLLM (not LLM)
+curl http://localhost:8000/health
+# Should show: "model_type": "mllm"
+```
+
+### Long Context Patch (mlx-vlm)
+
+Gemma 3's default `sliding_window=1024` limits context to ~10K tokens on Apple Silicon (the Metal GPU times out at higher context). To enable longer context (up to ~50K tokens), patch mlx-vlm:
+
+**Location:** `<your-python-env>/site-packages/mlx_vlm/models/gemma3/language.py`
+
+Find the `make_cache` method and replace it with:
+
+```python
+def make_cache(self):
+    import os
+    # Set GEMMA3_SLIDING_WINDOW=8192 for ~40K context
+    # Set GEMMA3_SLIDING_WINDOW=0 for ~50K context (full KVCache)
+    sliding_window = int(os.environ.get('GEMMA3_SLIDING_WINDOW', self.config.sliding_window))
+
+    caches = []
+    for i in range(self.config.num_hidden_layers):
+        if (
+            i % self.config.sliding_window_pattern
+            == self.config.sliding_window_pattern - 1
+        ):
+            caches.append(KVCache())
+        elif sliding_window == 0:
+            caches.append(KVCache())  # Full context for all layers
+        else:
+            caches.append(RotatingKVCache(max_size=sliding_window, keep=0))
+    return caches
+```
+
+**Usage:**
+
+```bash
+# Default (~10K max context)
+vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000
+
+# Extended context (~40K max)
+GEMMA3_SLIDING_WINDOW=8192 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000
+
+# Maximum context (~50K max)
+GEMMA3_SLIDING_WINDOW=0 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000
+```
+
+**Benchmark Results (M4 Max 128GB):**
+
+| Setting | Max Context | Memory |
+|---------|-------------|--------|
+| Default (1024) | ~10K tokens | ~16GB |
+| `GEMMA3_SLIDING_WINDOW=8192` | ~40K tokens | ~25GB |
+| `GEMMA3_SLIDING_WINDOW=0` | ~50K tokens | ~35GB |
+
 ## Contributing
 
 We welcome contributions! See [Contributing Guide](docs/development/contributing.md) for details.
diff --git a/vllm_mlx/api/utils.py b/vllm_mlx/api/utils.py
index ac9f8167..795ee39d 100644
--- a/vllm_mlx/api/utils.py
+++ b/vllm_mlx/api/utils.py
@@ -59,6 +59,7 @@ def clean_output_text(text: str) -> str:
     "llava", "LLaVA",  # LLaVA models
     "idefics", "Idefics",  # Idefics models
     "paligemma", "PaliGemma",  # PaliGemma
+    "gemma-3", "gemma3",  # Gemma 3 (multimodal)
     "pixtral", "Pixtral",  # Pixtral
     "molmo", "Molmo",  # Molmo
     "phi3-vision", "phi-3-vision",  # Phi-3 Vision
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index b4a72c3a..66feb91c 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -305,18 +305,39 @@
 
         # Build prompt using tokenizer
         if self._is_mllm:
-            # For MLLM, fall back to non-streaming chat
-            output = await self.chat(
-                messages=messages,
-                max_tokens=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                tools=tools,
-                images=images,
-                videos=videos,
-                **kwargs,
-            )
-            yield output
+            # For MLLM, bridge the model's synchronous stream_chat to async.
+            accumulated_text = ""
+            token_count = 0
+
+            # NOTE: list() buffers every chunk before the first yield (not true streaming); top_p/tools/images/videos are not forwarded on this path.
+            def run_stream():
+                return list(self._model.stream_chat(
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    **kwargs,
+                ))
+
+            chunks = await asyncio.to_thread(run_stream)
+
+            for chunk in chunks:
+                token_count += 1
+                new_text = chunk.text if hasattr(chunk, 'text') else str(chunk)
+                accumulated_text += new_text
+
+                finished = chunk.finish_reason is not None
+
+                yield GenerationOutput(
+                    text=accumulated_text,
+                    new_text=new_text,
+                    prompt_tokens=getattr(chunk, 'prompt_tokens', 0),
+                    completion_tokens=token_count,
+                    finished=finished,
+                    finish_reason=chunk.finish_reason if finished else None,
+                )
+
+                if finished:
+                    break
             return
 
         # For LLM, apply chat template and stream
diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py
index fc558876..52cf4fae 100644
--- a/vllm_mlx/models/mllm.py
+++ b/vllm_mlx/models/mllm.py
@@ -1025,6 +1025,13 @@
         images = []
         videos = []
         text_prompt = ""
+
+        logger.debug(f"MLLM.chat() called with {len(messages)} messages")
+        for i, msg in enumerate(messages):
+            logger.debug(f"  Message {i}: role={msg.get('role')}, content type={type(msg.get('content'))}")
+            if isinstance(msg.get('content'), list):
+                for j, item in enumerate(msg.get('content', [])):
+                    logger.debug(f"    Item {j}: type={item.get('type') if isinstance(item, dict) else type(item)}")
 
         for msg in messages:
             role = msg.get("role", "user")
@@ -1121,6 +1128,150 @@
             completion_tokens=generation_tokens,
         )
 
+    def stream_chat(
+        self,
+        messages: list[dict],
+        max_tokens: int = 256,
+        temperature: float = 0.7,
+        **kwargs,
+    ) -> Iterator[MLLMOutput]:
+        """
+        Stream chat with OpenAI-compatible message format.
+
+        Supports multimodal content in messages:
+        - {"type": "text", "text": "..."}
+        - {"type": "image_url", "image_url": {"url": "..."}}
+        - {"type": "image_url", "image_url": {"url": "data:image/...;base64,..."}}
+
+        Args:
+            messages: List of chat messages (OpenAI format)
+            max_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            **kwargs: Additional parameters
+
+        Yields:
+            MLLMOutput with incremental text chunks
+        """
+        if not self._loaded:
+            self.load()
+
+        try:
+            from mlx_vlm import stream_generate
+            from mlx_vlm.prompt_utils import apply_chat_template
+        except ImportError:
+            # Fallback to non-streaming if stream_generate not available
+            output = self.chat(
+                messages=messages,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                **kwargs,
+            )
+            yield output
+            return
+
+        # Extract text and images from messages (same logic as chat())
+        images = []
+        videos = []
+        text_prompt = ""
+
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if isinstance(content, str):
+                if role == "user":
+                    text_prompt = content
+            elif isinstance(content, list):
+                # OpenAI multimodal format
+                for item in content:
+                    if isinstance(item, str):
+                        text_prompt = item
+                        continue
+
+                    # Convert Pydantic models to dicts
+                    if hasattr(item, "model_dump"):
+                        item = item.model_dump()
+                    elif hasattr(item, "dict"):
+                        item = item.dict()
+
+                    if isinstance(item, dict):
+                        item_type = item.get("type", "")
+
+                        if item_type == "text":
+                            text_prompt = item.get("text", "")
+
+                        elif item_type == "image_url":
+                            img_url = item.get("image_url", {})
+                            if isinstance(img_url, str):
+                                images.append(img_url)
+                            else:
+                                images.append(img_url.get("url", ""))
+
+                        elif item_type == "image":
+                            images.append(item.get("image", item.get("url", "")))
+
+                        elif item_type == "video":
+                            videos.append(item.get("video", item.get("url", "")))
+
+        # Process images
+        all_images = []
+        if images:
+            all_images.extend(self._prepare_images(images))
+
+        # Process videos
+        video_fps = kwargs.pop("video_fps", DEFAULT_FPS)
+        video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES)
+        for video_path in videos:
+            frames = self._prepare_video(
+                video_path, fps=video_fps, max_frames=video_max_frames
+            )
+            all_images.extend(frames)
+
+        # Apply chat template
+        try:
+            formatted_prompt = apply_chat_template(
+                self.processor,
+                self.config,
+                text_prompt,
+                num_images=len(all_images),
+            )
+        except Exception as e:
+            logger.warning(f"Failed to apply chat template: {e}, using raw prompt")
+            formatted_prompt = text_prompt
+
+        # Stream generate tokens
+        accumulated_text = ""
+        token_count = 0
+        prompt_tokens = 0
+        for chunk in stream_generate(
+            self.model,
+            self.processor,
+            formatted_prompt,
+            all_images if all_images else None,
+            max_tokens=max_tokens,
+            temp=temperature,
+            **kwargs,
+        ):
+            token_count += 1
+            prompt_tokens = getattr(chunk, 'prompt_tokens', prompt_tokens)
+            new_text = chunk.text if hasattr(chunk, 'text') else str(chunk)
+            accumulated_text += new_text
+
+            yield MLLMOutput(
+                text=new_text,  # Just the new token for streaming
+                finish_reason=None,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=token_count,
+            )
+
+        # Final yield; prompt_tokens was tracked in the loop, so this is safe even if no chunks were produced
+        yield MLLMOutput(
+            text="",
+            finish_reason="stop",
+            prompt_tokens=prompt_tokens,
+            completion_tokens=token_count,
+        )
+
     def describe_image(
         self,
         image: str,
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index efb8f220..0fb31e39 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -633,7 +633,29 @@
     """
     engine = get_engine()
 
-    # Extract text, images, and videos from messages
-    messages, images, videos = extract_multimodal_content(request.messages)
+    # For MLLM models, keep original messages with embedded images
+    # (MLLM.chat() extracts images from message content internally)
+    if engine.is_mllm:
+        # Convert Pydantic messages to dicts preserving full content
+        messages = []
+        for msg in request.messages:
+            msg_dict = msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg)
+            messages.append(msg_dict)
+        images, videos = [], []  # MLLM extracts these from messages
+        # Trace message structure at debug level only
+        import logging
+        _logger = logging.getLogger(__name__)
+        _logger.debug(f"MLLM: Processing {len(messages)} messages")
+        for i, m in enumerate(messages):
+            c = m.get('content')
+            if isinstance(c, list):
+                _logger.debug(f"  Msg {i}: role={m.get('role')}, content is list with {len(c)} items")
+                for j, item in enumerate(c):
+                    _logger.debug(f"    Item {j}: {item.get('type') if isinstance(item, dict) else type(item)}")
+            else:
+                _logger.debug(f"  Msg {i}: role={m.get('role')}, content is {type(c).__name__}")
+    else:
+        # For LLM, extract text, images, and videos separately
+        messages, images, videos = extract_multimodal_content(request.messages)
 
     has_media = bool(images or videos)