Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
"""

from .patches.gemma3n import apply_patches as _apply_patches_gemma3n
from .patches.gemma4 import apply_patches as _apply_patches_gemma4
from .patches.ernie_4_5 import apply_patches as _apply_patches_ernie_4_5
from .patches.qwen3_5 import apply_patches as _apply_patches_qwen3_5

_apply_patches_gemma3n()
_apply_patches_gemma4()
_apply_patches_ernie_4_5()
_apply_patches_qwen3_5()
53 changes: 2 additions & 51 deletions mlx_engine/model_kit/batched_model_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,52 +28,6 @@
logger = logging.getLogger(__name__)


class _BatchedLogitsProcessorAdapter:
    """
    Bridge mlx-engine logits processors onto mlx-lm's batched generation API.

    BatchGenerator keeps prompt history in Python lists and invokes logits
    processors before the current input token has been appended to that
    history. Our processors are written against the sequential contract
    (history includes the pending token), so this adapter re-appends the
    pending token(s) before delegating.

    Remove this adapter after https://github.com/ml-explore/mlx-lm/pull/1115
    is merged.
    """

    def __init__(self, processors, initial_input_tokens):
        self._processors = processors or []
        if initial_input_tokens:
            self._current_input_tokens = mx.array(initial_input_tokens)
        else:
            self._current_input_tokens = None

    def sampler(self, sampler):
        """Wrap *sampler* so each sampled token becomes the next pending input."""
        if sampler is None:
            return None

        def record_and_sample(logprobs):
            picked = sampler(logprobs)
            # The token just sampled is the next step's "current input".
            self._current_input_tokens = mx.array(picked).reshape(-1)
            return picked

        return record_and_sample

    def logits_processors(self):
        """Return the wrapped processors, preserving their original order."""
        return [self._wrap_processor(p) for p in self._processors]

    def _wrap_processor(self, processor):
        def with_pending_tokens(tokens, logits):
            history = tokens if isinstance(tokens, mx.array) else mx.array(tokens)
            pending = self._current_input_tokens
            if pending is not None and pending.size > 0:
                history = mx.concatenate([history, pending])
            return processor(history, logits)

        return with_pending_tokens


def _prepare_prompt_cache_for_generation(
prompt_cache: LRUPromptCache, model_key: str, prompt_tokens: list[int]
):
Expand Down Expand Up @@ -374,18 +328,15 @@ def get_next_request(timeout=None):
cache, cached_prefix, rest = _prepare_prompt_cache_for_generation(
self._prompt_cache, current_model_key, request.prompt_tokens
)
adapter = _BatchedLogitsProcessorAdapter(
request.logits_processors, rest[-1:]
)

# Add to batch
(uid,) = batch_generator.insert(
[rest],
[request.max_tokens],
caches=[cache],
all_tokens=[cached_prefix],
samplers=[adapter.sampler(request.samplers)],
logits_processors=[adapter.logits_processors()],
samplers=[request.samplers],
logits_processors=[request.logits_processors],
)

# Track this request
Expand Down
2 changes: 2 additions & 0 deletions mlx_engine/model_kit/model_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn
from mlx_engine.model_kit.vision_add_ons.gemma3 import Gemma3VisionAddOn
from mlx_engine.model_kit.vision_add_ons.gemma4 import Gemma4VisionAddOn
from mlx_engine.model_kit.vision_add_ons.pixtral import PixtralVisionAddOn
from mlx_engine.model_kit.vision_add_ons.gemma3n import Gemma3nVisionAddOn
from mlx_engine.model_kit.vision_add_ons.mistral3 import Mistral3VisionAddOn
Expand Down Expand Up @@ -39,6 +40,7 @@ class ModelKit:
VISION_ADD_ON_MAP = {
"gemma3": Gemma3VisionAddOn,
"gemma3n": Gemma3nVisionAddOn,
"gemma4": Gemma4VisionAddOn,
"lfm2-vl": LFM2VisionAddOn,
"mistral3": Mistral3VisionAddOn,
"pixtral": PixtralVisionAddOn,
Expand Down
68 changes: 68 additions & 0 deletions mlx_engine/model_kit/patches/gemma4.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you confirm the logits are the same for both text-only and vision prompts before and after the patch? Consider adding tests to test_patched_models.py in line with the Qwen 3.5 heavy tests to verify.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I verified that the logits matched for text-only work, and that the logits are close-enough within a tolerance for image+text work. I added a test for each of these cases.

Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
Patch Gemma 4 so unified multimodal prompts reuse the full prompt's masked
per-layer-input token ids during chunked prefill.
"""

from typing import Any, Optional

import mlx.core as mx

from mlx_lm.models.gemma4_text import Gemma4TextModel

# Stable alias to the pristine mlx-lm class captured before apply_patches()
# mutates mlx_lm.models.gemma4_text in place.
OriginalGemma4TextModel = Gemma4TextModel


class PatchedGemma4TextModel(OriginalGemma4TextModel):
    """Gemma4 text model that rebuilds per-layer inputs during chunked prefill.

    The vision add-on stashes the full multimodal prompt's masked token ids on
    ``prompt_per_layer_input_ids``. When a prefill chunk arrives as embeddings
    without explicit ``per_layer_inputs``, this subclass slices the stashed ids
    that line up with the chunk and derives the per-layer inputs the base model
    would otherwise compute from raw token ids.
    """

    def __init__(self, config):
        super().__init__(config)
        # Set by the Gemma4 vision add-on before multimodal prefill; stays
        # None for plain text generation.
        self.prompt_per_layer_input_ids = None

    def __call__(
        self,
        inputs: mx.array = None,
        cache=None,
        input_embeddings: Optional[mx.array] = None,
        per_layer_inputs: Optional[mx.array] = None,
    ):
        stored_ids = self.prompt_per_layer_input_ids
        if (
            per_layer_inputs is None
            and input_embeddings is not None
            and stored_ids is not None
        ):
            chunk_len = input_embeddings.shape[-2]
            if stored_ids.shape[1] != chunk_len:
                # Mid-prefill chunk: keep only the ids this chunk covers,
                # starting at the number of tokens already in the KV cache.
                start = self._cache_offset(cache)
                stored_ids = stored_ids[:, start : start + chunk_len]
            per_layer_inputs = self._get_per_layer_inputs(stored_ids)

        return super().__call__(
            inputs,
            cache=cache,
            input_embeddings=input_embeddings,
            per_layer_inputs=per_layer_inputs,
        )

    @staticmethod
    def _cache_offset(cache: Optional[Any]) -> int:
        """Return how many tokens the KV cache already holds (0 if unknown)."""
        for layer_cache in cache or []:
            offset = getattr(layer_cache, "offset", None)
            if offset is None:
                continue
            if isinstance(offset, int):
                return offset
            if isinstance(offset, mx.array):
                # Scalar array or per-sequence vector; first entry suffices.
                return offset.item() if offset.ndim == 0 else offset[0].item()
        return 0


def apply_patches():
    """Swap mlx-lm's Gemma4TextModel for the patched subclass, in place."""
    import mlx_lm.models.gemma4_text as gemma4_text_module

    gemma4_text_module.Gemma4TextModel = PatchedGemma4TextModel
125 changes: 125 additions & 0 deletions mlx_engine/model_kit/vision_add_ons/gemma4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import logging
from pathlib import Path

from mlx import nn
import mlx.core as mx

from mlx_vlm.models.gemma4 import (
ModelConfig as Gemma4ModelConfig,
TextConfig as Gemma4TextConfig,
VisionConfig as Gemma4VisionConfig,
VisionModel as Gemma4VisionTower,
)
from mlx_vlm.models.gemma4.gemma4 import MultimodalEmbedder, masked_scatter
from mlx_vlm.utils import load_processor, sanitize_weights

from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn
from mlx_engine.model_kit.vision_add_ons.load_utils import (
load_and_filter_weights,
load_and_parse_config,
maybe_apply_quantization,
prepare_components,
)
from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import (
common_process_prompt_with_images,
)

logger = logging.getLogger(__name__)


class Gemma4VisionComponents(nn.Module):
    """Container bundling the Gemma4 vision tower and multimodal embedder.

    Grouping both under a single ``nn.Module`` lets the shared loading
    utilities (weight filtering, quantization, materialization) treat them
    as one module tree.
    """

    def __init__(self, vision_tower: nn.Module, embed_vision: nn.Module):
        super().__init__()
        self.vision_tower = vision_tower
        self.embed_vision = embed_vision


class Gemma4VisionAddOn(BaseVisionAddOn):
    """
    Vision add-on for Gemma4 models.

    Gemma4's text model still applies `embed_scale` when `input_embeddings` are
    provided, so image features must be pre-divided by that scale before being
    scattered into the mixed prompt embeddings.
    """

    def __init__(self, model_path: Path):
        """Load the Gemma4 vision tower and multimodal embedder from disk."""
        super().__init__()

        # Parse the combined model/vision/text configuration from the checkpoint.
        config, config_dict = load_and_parse_config(
            model_path=model_path,
            model_config_class=Gemma4ModelConfig,
            vision_config_class=Gemma4VisionConfig,
            text_config_class=Gemma4TextConfig,
        )

        vision_modules = Gemma4VisionComponents(
            vision_tower=Gemma4VisionTower(config.vision_config),
            embed_vision=MultimodalEmbedder(
                embedding_dim=config.vision_config.hidden_size,
                text_hidden_size=config.text_config.hidden_size,
                eps=config.vision_config.rms_norm_eps,
            ),
        )

        processor = load_processor(model_path=model_path, add_detokenizer=True)
        # Pull the vision weights out of the checkpoint, sanitize them for the
        # tower's expected layout, then quantize (if configured) and materialize.
        weights = load_and_filter_weights(model_path, vision_modules)
        weights = sanitize_weights(
            vision_modules.vision_tower.__class__, weights, config.vision_config
        )
        maybe_apply_quantization(vision_modules, config_dict, weights)
        prepare_components(vision_modules, weights)

        logger.info(f"Vision add-on loaded successfully from {model_path}")

        self.vision_tower = vision_modules.vision_tower
        self.embed_vision = vision_modules.embed_vision
        self.config = config
        self.processor = processor

    def compute_embeddings(
        self,
        text_model: nn.Module,
        prompt_tokens: mx.array,
        images_b64: list[str],
        max_size: tuple[int, int] | None,
    ) -> tuple[mx.array, mx.array]:
        """Compute input_ids and embeddings for text with images."""
        input_ids, pixel_values, _, _ = common_process_prompt_with_images(
            prompt_tokens=prompt_tokens,
            images_b64=images_b64,
            processor=self.processor,
            config=self.config,
            max_size=max_size,
        )

        language_model = text_model.language_model.model
        text_embeddings = language_model.embed_tokens(input_ids)

        vision_features = self.embed_vision(self.vision_tower(pixel_values)).astype(
            text_embeddings.dtype
        )
        # Gemma4TextModel applies embed_scale even when input_embeddings are
        # provided, so pre-divide the image features to cancel it out.
        vision_features = vision_features / language_model.embed_scale

        is_image_token = input_ids == self.config.image_token_id
        scatter_mask = mx.broadcast_to(
            mx.expand_dims(is_image_token, -1), text_embeddings.shape
        )
        merged_embeddings = masked_scatter(
            text_embeddings, scatter_mask, vision_features
        )

        if language_model.hidden_size_per_layer_input:
            # Zero out image (then audio) placeholder tokens before stashing
            # the ids that the patched text model uses to rebuild per-layer
            # inputs during chunked prefill.
            without_images = mx.where(is_image_token, 0, input_ids)
            language_model.prompt_per_layer_input_ids = mx.where(
                input_ids == self.config.audio_token_id, 0, without_images
            )

        return input_ids.squeeze(0), merged_embeddings.squeeze(0)
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def common_process_prompt_with_images(

if hasattr(config, "image_token_index"):
image_token_index = config.image_token_index
elif hasattr(config, "image_token_id"):
image_token_index = config.image_token_id
elif hasattr(config.vision_config, "image_token_id"):
image_token_index = config.vision_config.image_token_id
else:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ lark==1.3.1
markdown-it-py==4.0.0
mdurl==0.1.2
mlx==0.31.1
mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@3257c3df172977c97fdfe3740e3a5edeb812e0b5
mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@dcbf6e33d135a1b7c6767ca0fe7ebbd23df814a7
mlx-metal==0.31.1
mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm.git@23e1dffd224488141a4f022b6d21d6a730f11507
nest-asyncio==1.6.0
Expand Down
Loading
Loading