Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,71 @@ For full documentation, see the [docs](docs/) directory:

See [benchmarks](docs/benchmarks/) for detailed results.

## Gemma 3 Support

This fork includes patches for Gemma 3 vision support. Gemma 3 is a multimodal model, but it must be explicitly detected as an MLLM (rather than a text-only LLM) for image inputs to be handled.

### Usage

```bash
# Start server with Gemma 3
vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000

# Verify it loaded as MLLM (not LLM)
curl http://localhost:8000/health
# Should show: "model_type": "mllm"
```

### Long Context Patch (mlx-vlm)

Gemma 3's default `sliding_window=1024` limits context to ~10K tokens on Apple Silicon (Metal GPU timeout at higher context). To enable longer context (up to ~50K tokens), patch mlx-vlm:

**Location:** `<your-python-env>/lib/pythonX.Y/site-packages/mlx_vlm/models/gemma3/language.py`

Find the `make_cache` method and replace with:

```python
def make_cache(self):
import os
# Set GEMMA3_SLIDING_WINDOW=8192 for ~40K context
# Set GEMMA3_SLIDING_WINDOW=0 for ~50K context (full KVCache)
sliding_window = int(os.environ.get('GEMMA3_SLIDING_WINDOW', self.config.sliding_window))

caches = []
for i in range(self.config.num_hidden_layers):
if (
i % self.config.sliding_window_pattern
== self.config.sliding_window_pattern - 1
):
caches.append(KVCache())
elif sliding_window == 0:
caches.append(KVCache()) # Full context for all layers
else:
caches.append(RotatingKVCache(max_size=sliding_window, keep=0))
return caches
```

**Usage:**

```bash
# Default (~10K max context)
vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000

# Extended context (~40K max)
GEMMA3_SLIDING_WINDOW=8192 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000

# Maximum context (~50K max)
GEMMA3_SLIDING_WINDOW=0 vllm-mlx serve mlx-community/gemma-3-27b-it-4bit --port 8000
```

**Benchmark Results (M4 Max 128GB):**

| Setting | Max Context | Memory |
|---------|-------------|--------|
| Default (1024) | ~10K tokens | ~16GB |
| `GEMMA3_SLIDING_WINDOW=8192` | ~40K tokens | ~25GB |
| `GEMMA3_SLIDING_WINDOW=0` | ~50K tokens | ~35GB |

## Contributing

We welcome contributions! See [Contributing Guide](docs/development/contributing.md) for details.
Expand Down
1 change: 1 addition & 0 deletions vllm_mlx/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def clean_output_text(text: str) -> str:
"llava", "LLaVA", # LLaVA models
"idefics", "Idefics", # Idefics models
"paligemma", "PaliGemma", # PaliGemma
"gemma-3", "gemma3", # Gemma 3 (multimodal)
"pixtral", "Pixtral", # Pixtral
"molmo", "Molmo", # Molmo
"phi3-vision", "phi-3-vision", # Phi-3 Vision
Expand Down
45 changes: 33 additions & 12 deletions vllm_mlx/engine/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,18 +305,39 @@ async def stream_chat(

# Build prompt using tokenizer
if self._is_mllm:
# For MLLM, fall back to non-streaming chat
output = await self.chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
tools=tools,
images=images,
videos=videos,
**kwargs,
)
yield output
# For MLLM, use stream_chat which yields tokens incrementally
accumulated_text = ""
token_count = 0

# Run stream_chat in thread pool since it's synchronous
def run_stream():
return list(self._model.stream_chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
**kwargs,
))

chunks = await asyncio.to_thread(run_stream)

for chunk in chunks:
token_count += 1
new_text = chunk.text if hasattr(chunk, 'text') else str(chunk)
accumulated_text += new_text

finished = chunk.finish_reason is not None

yield GenerationOutput(
text=accumulated_text,
new_text=new_text,
prompt_tokens=getattr(chunk, 'prompt_tokens', 0),
completion_tokens=token_count,
finished=finished,
finish_reason=chunk.finish_reason if finished else None,
)

if finished:
break
return

# For LLM, apply chat template and stream
Expand Down
151 changes: 151 additions & 0 deletions vllm_mlx/models/mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,13 @@ def chat(
images = []
videos = []
text_prompt = ""

logger.info(f"MLLM.chat() called with {len(messages)} messages")
for i, msg in enumerate(messages):
logger.info(f" Message {i}: role={msg.get('role')}, content type={type(msg.get('content'))}")
if isinstance(msg.get('content'), list):
for j, item in enumerate(msg.get('content', [])):
logger.info(f" Item {j}: type={item.get('type') if isinstance(item, dict) else type(item)}")

for msg in messages:
role = msg.get("role", "user")
Expand Down Expand Up @@ -1121,6 +1128,150 @@ def chat(
completion_tokens=generation_tokens,
)

def stream_chat(
self,
messages: list[dict],
max_tokens: int = 256,
temperature: float = 0.7,
**kwargs,
) -> Iterator[MLLMOutput]:
"""
Stream chat with OpenAI-compatible message format.

Supports multimodal content in messages:
- {"type": "text", "text": "..."}
- {"type": "image_url", "image_url": {"url": "..."}}
- {"type": "image_url", "image_url": {"url": "data:image/...;base64,..."}}

Args:
messages: List of chat messages (OpenAI format)
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
**kwargs: Additional parameters

Yields:
MLLMOutput with incremental text chunks
"""
if not self._loaded:
self.load()

try:
from mlx_vlm import stream_generate
from mlx_vlm.prompt_utils import apply_chat_template
except ImportError:
# Fallback to non-streaming if stream_generate not available
output = self.chat(
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
**kwargs,
)
yield output
return

# Extract text and images from messages (same logic as chat())
images = []
videos = []
text_prompt = ""

for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")

if isinstance(content, str):
if role == "user":
text_prompt = content
elif isinstance(content, list):
# OpenAI multimodal format
for item in content:
if isinstance(item, str):
text_prompt = item
continue

# Convert Pydantic models to dicts
if hasattr(item, "model_dump"):
item = item.model_dump()
elif hasattr(item, "dict"):
item = item.dict()

if isinstance(item, dict):
item_type = item.get("type", "")

if item_type == "text":
text_prompt = item.get("text", "")

elif item_type == "image_url":
img_url = item.get("image_url", {})
if isinstance(img_url, str):
images.append(img_url)
else:
images.append(img_url.get("url", ""))

elif item_type == "image":
images.append(item.get("image", item.get("url", "")))

elif item_type == "video":
videos.append(item.get("video", item.get("url", "")))

# Process images
all_images = []
if images:
all_images.extend(self._prepare_images(images))

# Process videos
video_fps = kwargs.pop("video_fps", DEFAULT_FPS)
video_max_frames = kwargs.pop("video_max_frames", MAX_FRAMES)
for video_path in videos:
frames = self._prepare_video(
video_path, fps=video_fps, max_frames=video_max_frames
)
all_images.extend(frames)

# Apply chat template
try:
formatted_prompt = apply_chat_template(
self.processor,
self.config,
text_prompt,
num_images=len(all_images),
)
except Exception as e:
logger.warning(f"Failed to apply chat template: {e}, using raw prompt")
formatted_prompt = text_prompt

# Stream generate tokens
accumulated_text = ""
token_count = 0

for chunk in stream_generate(
self.model,
self.processor,
formatted_prompt,
all_images if all_images else None,
max_tokens=max_tokens,
temp=temperature,
**kwargs,
):
token_count += 1
# chunk is a GenerationResult with .text attribute containing the new token
new_text = chunk.text if hasattr(chunk, 'text') else str(chunk)
accumulated_text += new_text

yield MLLMOutput(
text=new_text, # Just the new token for streaming
finish_reason=None,
prompt_tokens=getattr(chunk, 'prompt_tokens', 0),
completion_tokens=token_count,
)

# Final yield with finish_reason
yield MLLMOutput(
text="",
finish_reason="stop",
prompt_tokens=getattr(chunk, 'prompt_tokens', 0) if 'chunk' in dir() else 0,
completion_tokens=token_count,
)

def describe_image(
self,
image: str,
Expand Down
28 changes: 26 additions & 2 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,32 @@ async def create_chat_completion(request: ChatCompletionRequest):
"""
engine = get_engine()

# Extract text, images, and videos from messages
messages, images, videos = extract_multimodal_content(request.messages)
# For MLLM models, keep original messages with embedded images
# (MLLM.chat() extracts images from message content internally)
print(f"DEBUG: engine.is_mllm = {engine.is_mllm}")
if engine.is_mllm:
print("DEBUG: Taking MLLM path")
# Convert Pydantic messages to dicts preserving full content
messages = []
for msg in request.messages:
msg_dict = msg.model_dump() if hasattr(msg, 'model_dump') else dict(msg)
messages.append(msg_dict)
images, videos = [], [] # MLLM extracts these from messages
# Debug: log message structure
import logging
_logger = logging.getLogger(__name__)
_logger.info(f"MLLM: Processing {len(messages)} messages")
for i, m in enumerate(messages):
c = m.get('content')
if isinstance(c, list):
_logger.info(f" Msg {i}: role={m.get('role')}, content is list with {len(c)} items")
for j, item in enumerate(c):
_logger.info(f" Item {j}: {item.get('type') if isinstance(item, dict) else type(item)}")
else:
_logger.info(f" Msg {i}: role={m.get('role')}, content is {type(c).__name__}")
else:
# For LLM, extract text, images, and videos separately
messages, images, videos = extract_multimodal_content(request.messages)

has_media = bool(images or videos)

Expand Down