From 3838c80ecae0c14816eae051b8db0ae1a50eb0da Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 7 Apr 2026 10:26:20 -0400 Subject: [PATCH 1/7] gemma4 unified --- mlx_engine/model_kit/model_kit.py | 2 + mlx_engine/model_kit/vision_add_ons/gemma4.py | 117 +++++++++++++++ .../process_prompt_with_images.py | 2 + tests/test_vision_models.py | 141 ++++++++++++++++++ 4 files changed, 262 insertions(+) create mode 100644 mlx_engine/model_kit/vision_add_ons/gemma4.py diff --git a/mlx_engine/model_kit/model_kit.py b/mlx_engine/model_kit/model_kit.py index 9f6de4b8..4cdba3d1 100644 --- a/mlx_engine/model_kit/model_kit.py +++ b/mlx_engine/model_kit/model_kit.py @@ -9,6 +9,7 @@ import logging from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn from mlx_engine.model_kit.vision_add_ons.gemma3 import Gemma3VisionAddOn +from mlx_engine.model_kit.vision_add_ons.gemma4 import Gemma4VisionAddOn from mlx_engine.model_kit.vision_add_ons.pixtral import PixtralVisionAddOn from mlx_engine.model_kit.vision_add_ons.gemma3n import Gemma3nVisionAddOn from mlx_engine.model_kit.vision_add_ons.mistral3 import Mistral3VisionAddOn @@ -39,6 +40,7 @@ class ModelKit: VISION_ADD_ON_MAP = { "gemma3": Gemma3VisionAddOn, "gemma3n": Gemma3nVisionAddOn, + "gemma4": Gemma4VisionAddOn, "lfm2-vl": LFM2VisionAddOn, "mistral3": Mistral3VisionAddOn, "pixtral": PixtralVisionAddOn, diff --git a/mlx_engine/model_kit/vision_add_ons/gemma4.py b/mlx_engine/model_kit/vision_add_ons/gemma4.py new file mode 100644 index 00000000..86e7d24d --- /dev/null +++ b/mlx_engine/model_kit/vision_add_ons/gemma4.py @@ -0,0 +1,117 @@ +import logging +from pathlib import Path + +from mlx import nn +import mlx.core as mx + +from mlx_vlm.models.gemma4 import ( + ModelConfig as Gemma4ModelConfig, + TextConfig as Gemma4TextConfig, + VisionConfig as Gemma4VisionConfig, + VisionModel as Gemma4VisionTower, +) +from mlx_vlm.models.gemma4.gemma4 import MultimodalEmbedder, masked_scatter +from mlx_vlm.utils import load_processor, sanitize_weights + +from mlx_engine.model_kit.vision_add_ons.base import BaseVisionAddOn +from mlx_engine.model_kit.vision_add_ons.load_utils import ( + load_and_filter_weights, + load_and_parse_config, + maybe_apply_quantization, + prepare_components, +) +from mlx_engine.model_kit.vision_add_ons.process_prompt_with_images import ( + common_process_prompt_with_images, +) + +logger = logging.getLogger(__name__) + + +class Gemma4VisionComponents(nn.Module): + def __init__(self, vision_tower: nn.Module, embed_vision: nn.Module): + super().__init__() + self.vision_tower = vision_tower + self.embed_vision = embed_vision + + +class Gemma4VisionAddOn(BaseVisionAddOn): + """ + Vision add-on for Gemma4 models. + + Gemma4's text model still applies `embed_scale` when `input_embeddings` are + provided, so image features must be pre-divided by that scale before being + scattered into the mixed prompt embeddings. + """ + + def __init__(self, model_path: Path): + super().__init__() + + config, config_dict = load_and_parse_config( + model_path=model_path, + model_config_class=Gemma4ModelConfig, + vision_config_class=Gemma4VisionConfig, + text_config_class=Gemma4TextConfig, + ) + + components = Gemma4VisionComponents( + vision_tower=Gemma4VisionTower(config.vision_config), + embed_vision=MultimodalEmbedder( + embedding_dim=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + eps=config.vision_config.rms_norm_eps, + ), + ) + + processor = load_processor(model_path=model_path, add_detokenizer=True) + vision_weights = load_and_filter_weights(model_path, components) + vision_weights = sanitize_weights( + components.vision_tower.__class__, vision_weights, config.vision_config + ) + maybe_apply_quantization(components, config_dict, vision_weights) + prepare_components(components, vision_weights) + + logger.info(f"Vision add-on loaded successfully from {model_path}") + + self.vision_tower = components.vision_tower + self.embed_vision = components.embed_vision + self.config = config + self.processor = processor + + def compute_embeddings( + self, + text_model: nn.Module, + prompt_tokens: mx.array, + images_b64: list[str], + max_size: tuple[int, int] | None, + ) -> tuple[mx.array, mx.array]: + """Compute input_ids and embeddings for text with images.""" + input_ids, pixel_values, _, _ = common_process_prompt_with_images( + prompt_tokens=prompt_tokens, + images_b64=images_b64, + processor=self.processor, + config=self.config, + max_size=max_size, + ) + + input_embeddings = text_model.language_model.model.embed_tokens(input_ids) + + image_features = self.vision_tower(pixel_values) + image_features = self.embed_vision(image_features).astype( + input_embeddings.dtype + ) + + # Gemma4TextModel applies embed_scale even when input_embeddings are provided. + scaled_image_features = ( + image_features / text_model.language_model.model.embed_scale + ) + + image_mask = input_ids == self.config.image_token_id + image_mask_expanded = mx.expand_dims(image_mask, -1) + image_mask_expanded = mx.broadcast_to( + image_mask_expanded, input_embeddings.shape + ) + + final_inputs_embeds = masked_scatter( + input_embeddings, image_mask_expanded, scaled_image_features + ) + return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) diff --git a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py index 0dde9705..24183576 100644 --- a/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py +++ b/mlx_engine/model_kit/vision_add_ons/process_prompt_with_images.py @@ -49,6 +49,8 @@ def common_process_prompt_with_images( if hasattr(config, "image_token_index"): image_token_index = config.image_token_index + elif hasattr(config, "image_token_id"): + image_token_index = config.image_token_id elif hasattr(config.vision_config, "image_token_id"): image_token_index = config.vision_config.image_token_id else: diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py index 3ac81774..0da85308 100644 --- a/tests/test_vision_models.py +++ b/tests/test_vision_models.py @@ -13,6 +13,7 @@ model_load_and_tokenize_prompt, ) from textwrap import dedent +from transformers import AutoProcessor MAX_IMAGE_SIZE = (1024, 1024) @@ -111,6 +112,24 @@ def toucan_test_runner( return generated_text + def build_gemma4_prompt( + self, + model_path: Path, + prompt: str, + *, + text_only: bool = False, + ) -> str: + processor = AutoProcessor.from_pretrained(model_path) + content = [{"type": "text", "text": prompt}] + if not text_only: + content.insert(0, {"type": "image", "base64": self.toucan_image_b64}) + conversation = [{"role": "user", "content": content}] + return processor.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=True, + ) + ### MODEL-SPECIFIC TESTS ### def test_llama_3_2_vision_instruct(self): """Test Llama 3.2 11B Vision Instruct model""" @@ -660,6 +679,128 @@ def test_gemma3n_text_only(self): "lmstudio-community/gemma-3n-E2B-it-MLX-4bit", prompt, text_only=True ) + @pytest.mark.heavy + def test_gemma4_vision(self): + """Test Gemma 4 model via the unified multimodal path.""" + model_name = "lmstudio-community/gemma-4-E2B-it-MLX-4bit" + model_path = model_getter(model_name) + prompt = self.build_gemma4_prompt(model_path, self.description_prompt) + self.toucan_test_runner( + model_name, + prompt, + supplemental_accept_phrases=["bird"], + ) + + @pytest.mark.heavy + def test_gemma4_text_only(self): + """Test Gemma 4 model with text only via the unified multimodal path.""" + model_name = "lmstudio-community/gemma-4-E2B-it-MLX-4bit" + model_path = model_getter(model_name) + prompt = self.build_gemma4_prompt( + model_path, + self.text_only_prompt, + text_only=True, + ) + self.toucan_test_runner(model_name, prompt, text_only=True) + + @pytest.mark.heavy + def test_gemma4_text_only_generation_caching(self): + """Ensure unified-arch Gemma 4 reuses cross-prompt cache for text-only turns.""" + model_name = "lmstudio-community/gemma-4-E2B-it-MLX-4bit" + model_path = model_getter(model_name) + processor = AutoProcessor.from_pretrained(model_path) + model_kit = load_model( + model_path=model_path, + max_kv_size=MAX_KV_CACHE_SIZE, + max_seq_nums=1, + prefill_step_size=CACHING_TEST_PREFILL_STEP_SIZE, + ) + + def render_prompt(conversation): + return processor.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=True, + ) + + def generate_text(prompt): + prompt_tokens = tokenize(model_kit, prompt) + reporter = RecordingReporter() + + generated_text = "" + for result in create_generator( + model_kit=model_kit, + prompt_tokens=prompt_tokens, + seed=0, + temp=0.0, + max_tokens=1000, + repetition_penalty=1.01, # to enable this code path + prompt_progress_reporter=reporter, + ): + generated_text += result.text + print(result.text, end="", flush=True) + if result.stop_condition: + break + print("\n", flush=True) + return generated_text, reporter + + first_conversation = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Tell me a 500-word story about a traveler, and make the main character's name distinctive.", + } + ], + } + ] + + prompt = render_prompt(first_conversation) + generated_text, reporter = generate_text(prompt) + first_update_events = [ + event for event in reporter.events if event["type"] == "update" + ] + assert len(first_update_events) > 0 + begin_event = reporter.events[0] + assert begin_event["type"] == "begin" + assert begin_event["cached_tokens"] == 0 + assert len(generated_text) > 0 + + second_conversation = first_conversation + [ + { + "role": "assistant", + "content": [{"type": "text", "text": generated_text}], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What was the main character's name? Answer with only the name.", + } + ], + }, + ] + prompt = render_prompt(second_conversation) + num_tokens = len(model_kit.tokenize(prompt)) + # Without caching, the follow-up prompt is long enough to require multi-chunk prefill. + assert num_tokens > CACHING_TEST_PREFILL_STEP_SIZE + + follow_up_text, reporter = generate_text(prompt) + second_update_events = [ + event for event in reporter.events if event["type"] == "update" + ] + begin_event = reporter.events[0] + assert begin_event["type"] == "begin" + assert begin_event["cached_tokens"] > 0 + assert len(second_update_events) <= len(first_update_events) + assert ( + second_update_events[-1]["prefill_tokens_processed"] + < first_update_events[-1]["prefill_tokens_processed"] + ) + assert len(follow_up_text.strip()) > 0 + # TODO(will): Parameterize and de-dup def test_gemma3n_text_only_generation_caching(self): """Ensure that text only prompts with vlms take full advantage of caching generated tokens""" From 12750c7acbc2d04e41ab490e841d52085f3819da Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 7 Apr 2026 11:24:08 -0400 Subject: [PATCH 2/7] wire in per layers inputs --- mlx_engine/model_kit/__init__.py | 2 + mlx_engine/model_kit/batched_model_kit.py | 53 +----- mlx_engine/model_kit/patches/gemma4.py | 138 ++++++++++++++++ mlx_engine/model_kit/vision_add_ons/gemma4.py | 38 +++++ requirements.txt | 2 +- tests/test_gemma4_patch.py | 151 ++++++++++++++++++ tests/test_gemma4_vision_addon.py | 54 +++++++ 7 files changed, 386 insertions(+), 52 deletions(-) create mode 100644 mlx_engine/model_kit/patches/gemma4.py create mode 100644 tests/test_gemma4_patch.py create mode 100644 tests/test_gemma4_vision_addon.py diff --git a/mlx_engine/model_kit/__init__.py b/mlx_engine/model_kit/__init__.py index 90da3d13..ac21310b 100644 --- a/mlx_engine/model_kit/__init__.py +++ b/mlx_engine/model_kit/__init__.py @@ -6,9 +6,11 @@ """ from .patches.gemma3n import apply_patches as _apply_patches_gemma3n +from .patches.gemma4 import apply_patches as _apply_patches_gemma4 from .patches.ernie_4_5 import apply_patches as _apply_patches_ernie_4_5 from .patches.qwen3_5 import apply_patches as _apply_patches_qwen3_5 _apply_patches_gemma3n() +_apply_patches_gemma4() _apply_patches_ernie_4_5() _apply_patches_qwen3_5() diff --git a/mlx_engine/model_kit/batched_model_kit.py b/mlx_engine/model_kit/batched_model_kit.py index 6344d256..4dcd21eb 100644 --- a/mlx_engine/model_kit/batched_model_kit.py +++ b/mlx_engine/model_kit/batched_model_kit.py @@ -28,52 +28,6 @@ logger = logging.getLogger(__name__) -class _BatchedLogitsProcessorAdapter: - """ - Adapt mlx-engine logits processors to mlx-lm's batched generation contract. - - BatchGenerator keeps prompt history in Python lists and passes the history - before the current input token has been appended. Our processors are written - against the sequential contract, so we restore that view here. - - Remove this adapter after https://github.com/ml-explore/mlx-lm/pull/1115 - is merged. - """ - - def __init__(self, processors, initial_input_tokens): - self._processors = processors or [] - self._current_input_tokens = ( - mx.array(initial_input_tokens) if initial_input_tokens else None - ) - - def sampler(self, sampler): - if sampler is None: - return None - - def wrapped(logprobs): - sampled = sampler(logprobs) - self._current_input_tokens = mx.array(sampled).reshape(-1) - return sampled - - return wrapped - - def logits_processors(self): - return [self._wrap_processor(processor) for processor in self._processors] - - def _wrap_processor(self, processor): - def wrapped(tokens, logits): - if not isinstance(tokens, mx.array): - tokens = mx.array(tokens) - if ( - self._current_input_tokens is not None - and self._current_input_tokens.size > 0 - ): - tokens = mx.concatenate([tokens, self._current_input_tokens]) - return processor(tokens, logits) - - return wrapped - - def _prepare_prompt_cache_for_generation( prompt_cache: LRUPromptCache, model_key: str, prompt_tokens: list[int] ): @@ -374,9 +328,6 @@ def get_next_request(timeout=None): cache, cached_prefix, rest = _prepare_prompt_cache_for_generation( self._prompt_cache, current_model_key, request.prompt_tokens ) - adapter = _BatchedLogitsProcessorAdapter( - request.logits_processors, rest[-1:] - ) # Add to batch (uid,) = batch_generator.insert( @@ -384,8 +335,8 @@ def get_next_request(timeout=None): [request.max_tokens], caches=[cache], all_tokens=[cached_prefix], - samplers=[adapter.sampler(request.samplers)], - logits_processors=[adapter.logits_processors()], + samplers=[request.samplers], + logits_processors=[request.logits_processors], ) # Track this request diff --git a/mlx_engine/model_kit/patches/gemma4.py b/mlx_engine/model_kit/patches/gemma4.py new file mode 100644 index 00000000..057ffa68 --- /dev/null +++ b/mlx_engine/model_kit/patches/gemma4.py @@ -0,0 +1,138 @@ +""" +Gemma 4 compatibility patch for unified multimodal prompts. + +The upstream mlx-lm Gemma4TextModel reconstructs per-layer inputs from the +current prompt chunk when the caller does not pass them explicitly. That is +fine for text-only prompts, but incorrect for unified multimodal prompts +because Gemma 4 image/audio token ids fall inside the per-layer-input vocab +and must be masked to 0 before lookup, matching mlx-vlm. + +This patch lets mlx-engine stash the full prompt's masked per-layer inputs on +the text model and have later prefill chunks slice the correct window from +that stored state. +""" + +from typing import Any, Optional + +import mlx.core as mx + +from mlx_lm.models.gemma4_text import Gemma4TextModel + +OriginalGemma4TextModel = Gemma4TextModel + + +class PatchedGemma4TextModel(Gemma4TextModel): + def __init__(self, config): + super().__init__(config) + self.prompt_per_layer_inputs = None + + def reset_prompt_per_layer_input_state(self) -> None: + self.prompt_per_layer_inputs = None + + def set_prompt_per_layer_inputs( + self, + prompt_per_layer_inputs: Optional[mx.array], + ) -> None: + if prompt_per_layer_inputs is None: + self.prompt_per_layer_inputs = None + return + if prompt_per_layer_inputs.ndim == 3: + prompt_per_layer_inputs = prompt_per_layer_inputs[None] + self.prompt_per_layer_inputs = prompt_per_layer_inputs + + def __call__( + self, + inputs: mx.array = None, + cache=None, + input_embeddings: Optional[mx.array] = None, + per_layer_inputs: Optional[mx.array] = None, + ): + effective_per_layer_inputs = per_layer_inputs + if ( + effective_per_layer_inputs is None + and input_embeddings is not None + and self.prompt_per_layer_inputs is not None + ): + effective_per_layer_inputs = self.prompt_per_layer_inputs + + if effective_per_layer_inputs is not None: + effective_per_layer_inputs = self._slice_per_layer_inputs( + per_layer_inputs=effective_per_layer_inputs, + cache=cache, + batch_size=( + input_embeddings.shape[0] + if input_embeddings is not None + else inputs.shape[0] + ), + target_len=( + input_embeddings.shape[1] + if input_embeddings is not None + else inputs.shape[1] + ), + ) + + return super().__call__( + inputs, + cache=cache, + input_embeddings=input_embeddings, + per_layer_inputs=effective_per_layer_inputs, + ) + + def _slice_per_layer_inputs( + self, + *, + per_layer_inputs: mx.array, + cache: Optional[Any], + batch_size: int, + target_len: int, + ) -> mx.array: + if per_layer_inputs.ndim == 3: + per_layer_inputs = per_layer_inputs[None] + if per_layer_inputs.ndim != 4: + raise ValueError( + "Gemma 4 prompt per-layer inputs must have shape " + "(batch, seq, num_layers, hidden)." + ) + if per_layer_inputs.shape[0] != batch_size: + if per_layer_inputs.shape[0] == 1: + per_layer_inputs = mx.broadcast_to( + per_layer_inputs, + (batch_size,) + tuple(per_layer_inputs.shape[1:]), + ) + else: + raise ValueError( + "Gemma 4 prompt per-layer inputs batch dimension does not " + "match the current input batch size." + ) + if per_layer_inputs.shape[1] < target_len: + raise ValueError( + "Gemma 4 prompt per-layer inputs are shorter than the current " + "input chunk." + ) + if per_layer_inputs.shape[1] == target_len: + return per_layer_inputs + + cache_offset = self._cache_offset(cache) + max_start = max(per_layer_inputs.shape[1] - target_len, 0) + start = min(cache_offset, max_start) + return per_layer_inputs[:, start : start + target_len] + + @staticmethod + def _cache_offset(cache: Optional[Any]) -> int: + for layer_cache in cache or []: + if layer_cache is None or not hasattr(layer_cache, "offset"): + continue + offset = layer_cache.offset + if isinstance(offset, int): + return offset + if isinstance(offset, mx.array) and offset.ndim == 0: + return offset.item() + if isinstance(offset, mx.array): + return offset[0].item() + return 0 + + +def apply_patches(): + import mlx_lm.models.gemma4_text + + mlx_lm.models.gemma4_text.Gemma4TextModel = PatchedGemma4TextModel diff --git a/mlx_engine/model_kit/vision_add_ons/gemma4.py b/mlx_engine/model_kit/vision_add_ons/gemma4.py index 86e7d24d..b66943bf 100644 --- a/mlx_engine/model_kit/vision_add_ons/gemma4.py +++ b/mlx_engine/model_kit/vision_add_ons/gemma4.py @@ -27,6 +27,26 @@ logger = logging.getLogger(__name__) +def _compute_prompt_per_layer_inputs( + language_model: nn.Module, + input_ids: mx.array, + image_token_id: int, + audio_token_id: int | None, +) -> mx.array | None: + if not getattr(language_model, "hidden_size_per_layer_input", 0): + return None + + image_mask_ids = input_ids == image_token_id + audio_mask_ids = ( + input_ids == audio_token_id + if audio_token_id is not None + else mx.zeros_like(image_mask_ids) + ) + text_mask = ~(image_mask_ids | audio_mask_ids) + per_layer_inputs_tokens = mx.where(text_mask, input_ids, mx.zeros_like(input_ids)) + return language_model._get_per_layer_inputs(per_layer_inputs_tokens) + + class Gemma4VisionComponents(nn.Module): def __init__(self, vision_tower: nn.Module, embed_vision: nn.Module): super().__init__() @@ -77,6 +97,12 @@ def __init__(self, model_path: Path): self.config = config self.processor = processor + def clear_prediction_state(self, text_model: nn.Module) -> None: + language_model = text_model.language_model.model + reset = getattr(language_model, "reset_prompt_per_layer_input_state", None) + if reset is not None: + reset() + def compute_embeddings( self, text_model: nn.Module, @@ -114,4 +140,16 @@ def compute_embeddings( final_inputs_embeds = masked_scatter( input_embeddings, image_mask_expanded, scaled_image_features ) + + prompt_per_layer_inputs = _compute_prompt_per_layer_inputs( + text_model.language_model.model, + input_ids, + self.config.image_token_id, + getattr(self.config, "audio_token_id", None), + ) + if prompt_per_layer_inputs is not None: + text_model.language_model.model.set_prompt_per_layer_inputs( + prompt_per_layer_inputs + ) + return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) diff --git a/requirements.txt b/requirements.txt index 487b85c9..4bf6bc16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,7 +25,7 @@ lark==1.3.1 markdown-it-py==4.0.0 mdurl==0.1.2 mlx==0.31.1 -mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@3257c3df172977c97fdfe3740e3a5edeb812e0b5 +mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@dcbf6e33d135a1b7c6767ca0fe7ebbd23df814a7 mlx-metal==0.31.1 mlx-vlm @ git+https://github.com/Blaizzy/mlx-vlm.git@23e1dffd224488141a4f022b6d21d6a730f11507 nest-asyncio==1.6.0 diff --git a/tests/test_gemma4_patch.py b/tests/test_gemma4_patch.py new file mode 100644 index 00000000..b5d6a241 --- /dev/null +++ b/tests/test_gemma4_patch.py @@ -0,0 +1,151 @@ +import mlx.core as mx +from mlx_lm.generate import generate_step + +from mlx_engine.model_kit.patches.gemma4 import ( + OriginalGemma4TextModel as _OrigTextModel, + PatchedGemma4TextModel, + apply_patches, +) + +apply_patches() + +from mlx_lm.models.gemma4_text import Model, ModelArgs # noqa: E402 + +GEMMA4_TEXT_CONFIG = { + "model_type": "gemma4_text", + "hidden_size": 32, + "num_hidden_layers": 2, + "intermediate_size": 64, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "num_global_key_value_heads": 1, + "head_dim": 16, + "global_head_dim": 16, + "sliding_window": 8, + "sliding_window_pattern": 1, + "layer_types": ["full_attention", "full_attention"], + "hidden_size_per_layer_input": 8, + "vocab_size": 32, + "vocab_size_per_layer_input": 32, + "num_kv_shared_layers": 0, +} + + +def make_model(**text_config_overrides): + args = ModelArgs.from_dict( + { + **GEMMA4_TEXT_CONFIG, + **text_config_overrides, + } + ) + return Model(args) + + +def test_gemma4_text_only_patched_matches_unpatched(): + import mlx_lm.models.gemma4_text as mod + + patched_text_model = mod.Gemma4TextModel + assert patched_text_model is PatchedGemma4TextModel + assert _OrigTextModel is not PatchedGemma4TextModel + + mx.random.seed(0) + model = make_model() + tokens = mx.array([[1, 2, 3, 4, 5, 6]], dtype=mx.int32) + patched_logits = model(tokens) + mx.eval(patched_logits) + patched_logits = mx.array(patched_logits) + + mod.Gemma4TextModel = _OrigTextModel + try: + mx.random.seed(0) + unpatched_model = make_model() + finally: + mod.Gemma4TextModel = patched_text_model + + unpatched_logits = unpatched_model(tokens) + mx.eval(unpatched_logits) + max_diff = mx.max(mx.abs(patched_logits - unpatched_logits)).item() + assert mx.allclose(patched_logits, unpatched_logits, atol=1e-4).item(), ( + f"Patched Gemma 4 text-only logits diverged from unpatched mlx-lm " + f"(max diff {max_diff:.6f})." + ) + + +def test_gemma4_prompt_per_layer_inputs_chunked_prefill_matches_unchunked(): + mx.random.seed(0) + + model = make_model() + text_model = model.model + prompt = mx.array([1, 2, 3, 4, 5, 6], dtype=mx.int32) + prompt_embeddings = text_model.embed_tokens(prompt[None]).squeeze(0) + prompt_per_layer_inputs = mx.full( + ( + 1, + prompt.shape[0], + text_model.config.num_hidden_layers, + text_model.hidden_size_per_layer_input, + ), + 3.0, + dtype=prompt_embeddings.dtype, + ) + + def first_step_logprobs(prefill_step_size: int) -> mx.array: + text_model.set_prompt_per_layer_inputs(prompt_per_layer_inputs) + step = generate_step( + prompt, + model, + max_tokens=1, + sampler=lambda x: mx.argmax(x, axis=-1), + input_embeddings=prompt_embeddings, + prefill_step_size=prefill_step_size, + ) + _, logprobs = next(step) + step.close() + mx.eval(logprobs) + return logprobs + + reference_logprobs = first_step_logprobs(prefill_step_size=16) + chunked_logprobs = first_step_logprobs(prefill_step_size=2) + + max_diff = mx.max(mx.abs(reference_logprobs - chunked_logprobs)).item() + assert mx.allclose(reference_logprobs, chunked_logprobs, atol=1e-4).item(), ( + f"Chunked Gemma 4 prompt per-layer-input prefill mismatch " + f"(max diff {max_diff:.6f})." + ) + + +def test_gemma4_prompt_state_matches_explicit_per_layer_inputs(): + mx.random.seed(0) + + model = make_model() + text_model = model.model + prompt = mx.array([[1, 2, 3, 4]], dtype=mx.int32) + prompt_embeddings = text_model.embed_tokens(prompt) + prompt_per_layer_inputs = mx.full( + ( + 1, + prompt.shape[1], + text_model.config.num_hidden_layers, + text_model.hidden_size_per_layer_input, + ), + 2.0, + dtype=prompt_embeddings.dtype, + ) + + explicit_logits = model( + prompt, + input_embeddings=prompt_embeddings, + per_layer_inputs=prompt_per_layer_inputs, + ) + text_model.set_prompt_per_layer_inputs(prompt_per_layer_inputs) + stateful_logits = model( + prompt, + input_embeddings=prompt_embeddings, + ) + mx.eval(explicit_logits, stateful_logits) + + max_diff = mx.max(mx.abs(explicit_logits - stateful_logits)).item() + assert mx.allclose(explicit_logits, stateful_logits, atol=1e-4).item(), ( + f"Stored Gemma 4 prompt per-layer inputs diverged from explicit inputs " + f"(max diff {max_diff:.6f})." + ) diff --git a/tests/test_gemma4_vision_addon.py b/tests/test_gemma4_vision_addon.py new file mode 100644 index 00000000..ff9a0a8c --- /dev/null +++ b/tests/test_gemma4_vision_addon.py @@ -0,0 +1,54 @@ +from types import SimpleNamespace + +import mlx.core as mx + +from mlx_engine.model_kit.vision_add_ons.gemma4 import _compute_prompt_per_layer_inputs + + +class _FakeGemma4TextModel: + def __init__(self, hidden_size_per_layer_input: int = 1): + self.hidden_size_per_layer_input = hidden_size_per_layer_input + self.seen_input_ids = None + + def _get_per_layer_inputs(self, input_ids: mx.array) -> mx.array: + self.seen_input_ids = input_ids + return input_ids[..., None, None] + + +class _FakeGemma4Model: + def __init__(self, hidden_size_per_layer_input: int = 1): + self.language_model = SimpleNamespace( + model=_FakeGemma4TextModel(hidden_size_per_layer_input) + ) + + +def test_compute_prompt_per_layer_inputs_masks_special_tokens(): + text_model = _FakeGemma4Model() + input_ids = mx.array([[1, 99, 2, 100, 3]], dtype=mx.int32) + + prompt_per_layer_inputs = _compute_prompt_per_layer_inputs( + text_model.language_model.model, + input_ids, + image_token_id=99, + audio_token_id=100, + ) + + assert text_model.language_model.model.seen_input_ids.tolist() == [[1, 0, 2, 0, 3]] + assert prompt_per_layer_inputs.shape == (1, 5, 1, 1) + assert prompt_per_layer_inputs[0, :, 0, 0].tolist() == [1, 0, 2, 0, 3] + + +def test_compute_prompt_per_layer_inputs_skips_models_without_per_layer_inputs(): + text_model = _FakeGemma4Model(hidden_size_per_layer_input=0) + input_ids = mx.array([[1, 99, 2]], dtype=mx.int32) + + assert ( + _compute_prompt_per_layer_inputs( + text_model.language_model.model, + input_ids, + image_token_id=99, + audio_token_id=100, + ) + is None + ) + assert text_model.language_model.model.seen_input_ids is None From bdde0afed7b4b6d605db53676bc03514454d95af Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 7 Apr 2026 11:51:52 -0400 Subject: [PATCH 3/7] refactor --- mlx_engine/model_kit/patches/gemma4.py | 106 ++---------- mlx_engine/model_kit/vision_add_ons/gemma4.py | 51 ++---- tests/test_gemma4_patch.py | 151 ------------------ tests/test_gemma4_vision_addon.py | 54 ------- 4 files changed, 32 insertions(+), 330 deletions(-) delete mode 100644 tests/test_gemma4_patch.py delete mode 100644 tests/test_gemma4_vision_addon.py diff --git a/mlx_engine/model_kit/patches/gemma4.py b/mlx_engine/model_kit/patches/gemma4.py index 057ffa68..67e8792c 100644 --- a/mlx_engine/model_kit/patches/gemma4.py +++ b/mlx_engine/model_kit/patches/gemma4.py @@ -1,15 +1,6 @@ """ -Gemma 4 compatibility patch for unified multimodal prompts. - -The upstream mlx-lm Gemma4TextModel reconstructs per-layer inputs from the -current prompt chunk when the caller does not pass them explicitly. That is -fine for text-only prompts, but incorrect for unified multimodal prompts -because Gemma 4 image/audio token ids fall inside the per-layer-input vocab -and must be masked to 0 before lookup, matching mlx-vlm. - -This patch lets mlx-engine stash the full prompt's masked per-layer inputs on -the text model and have later prefill chunks slice the correct window from -that stored state. +Patch Gemma 4 so unified multimodal prompts reuse the full prompt's masked +per-layer-input token ids during chunked prefill. """ from typing import Any, Optional @@ -18,27 +9,11 @@ from mlx_lm.models.gemma4_text import Gemma4TextModel -OriginalGemma4TextModel = Gemma4TextModel - class PatchedGemma4TextModel(Gemma4TextModel): def __init__(self, config): super().__init__(config) - self.prompt_per_layer_inputs = None - - def reset_prompt_per_layer_input_state(self) -> None: - self.prompt_per_layer_inputs = None - - def set_prompt_per_layer_inputs( - self, - prompt_per_layer_inputs: Optional[mx.array], - ) -> None: - if prompt_per_layer_inputs is None: - self.prompt_per_layer_inputs = None - return - if prompt_per_layer_inputs.ndim == 3: - prompt_per_layer_inputs = prompt_per_layer_inputs[None] - self.prompt_per_layer_inputs = prompt_per_layer_inputs + self.prompt_per_layer_input_ids = None def __call__( self, @@ -47,82 +22,33 @@ def __call__( input_embeddings: Optional[mx.array] = None, per_layer_inputs: Optional[mx.array] = None, ): - effective_per_layer_inputs = per_layer_inputs if ( - effective_per_layer_inputs is None + per_layer_inputs is None and input_embeddings is not None - and self.prompt_per_layer_inputs is not None + and self.prompt_per_layer_input_ids is not None ): - effective_per_layer_inputs = self.prompt_per_layer_inputs - - if effective_per_layer_inputs is not None: - effective_per_layer_inputs = self._slice_per_layer_inputs( - per_layer_inputs=effective_per_layer_inputs, - cache=cache, - batch_size=( - input_embeddings.shape[0] - if input_embeddings is not None - else inputs.shape[0] - ), - target_len=( - input_embeddings.shape[1] - if input_embeddings is not None - else inputs.shape[1] - ), - ) + prompt_per_layer_input_ids = self.prompt_per_layer_input_ids + if prompt_per_layer_input_ids.shape[1] != input_embeddings.shape[-2]: + start = self._cache_offset(cache) + target_len = input_embeddings.shape[-2] + prompt_per_layer_input_ids = prompt_per_layer_input_ids[ + :, start : start + target_len + ] + per_layer_inputs = self._get_per_layer_inputs(prompt_per_layer_input_ids) return super().__call__( inputs, cache=cache, input_embeddings=input_embeddings, - per_layer_inputs=effective_per_layer_inputs, + per_layer_inputs=per_layer_inputs, ) - def _slice_per_layer_inputs( - self, - *, - per_layer_inputs: mx.array, - cache: Optional[Any], - batch_size: int, - target_len: int, - ) -> mx.array: - if per_layer_inputs.ndim == 3: - per_layer_inputs = per_layer_inputs[None] - if per_layer_inputs.ndim != 4: - raise ValueError( - "Gemma 4 prompt per-layer inputs must have shape " - "(batch, seq, num_layers, hidden)." - ) - if per_layer_inputs.shape[0] != batch_size: - if per_layer_inputs.shape[0] == 1: - per_layer_inputs = mx.broadcast_to( - per_layer_inputs, - (batch_size,) + tuple(per_layer_inputs.shape[1:]), - ) - else: - raise ValueError( - "Gemma 4 prompt per-layer inputs batch dimension does not " - "match the current input batch size." - ) - if per_layer_inputs.shape[1] < target_len: - raise ValueError( - "Gemma 4 prompt per-layer inputs are shorter than the current " - "input chunk." - ) - if per_layer_inputs.shape[1] == target_len: - return per_layer_inputs - - cache_offset = self._cache_offset(cache) - max_start = max(per_layer_inputs.shape[1] - target_len, 0) - start = min(cache_offset, max_start) - return per_layer_inputs[:, start : start + target_len] - @staticmethod def _cache_offset(cache: Optional[Any]) -> int: for layer_cache in cache or []: - if layer_cache is None or not hasattr(layer_cache, "offset"): + offset = getattr(layer_cache, "offset", None) + if offset is None: continue - offset = layer_cache.offset if isinstance(offset, int): return offset if isinstance(offset, mx.array) and offset.ndim == 0: diff --git a/mlx_engine/model_kit/vision_add_ons/gemma4.py b/mlx_engine/model_kit/vision_add_ons/gemma4.py index b66943bf..9ca057dd 100644 --- a/mlx_engine/model_kit/vision_add_ons/gemma4.py +++ b/mlx_engine/model_kit/vision_add_ons/gemma4.py @@ -27,24 +27,14 @@ logger = logging.getLogger(__name__) -def _compute_prompt_per_layer_inputs( - language_model: nn.Module, +def _mask_prompt_per_layer_input_ids( input_ids: mx.array, image_token_id: int, - audio_token_id: int | None, -) -> mx.array | None: - if not getattr(language_model, "hidden_size_per_layer_input", 0): - return None - - image_mask_ids = input_ids == image_token_id - audio_mask_ids = ( - input_ids == audio_token_id - if audio_token_id is not None - else mx.zeros_like(image_mask_ids) - ) - text_mask = ~(image_mask_ids | audio_mask_ids) - per_layer_inputs_tokens = mx.where(text_mask, input_ids, mx.zeros_like(input_ids)) - return language_model._get_per_layer_inputs(per_layer_inputs_tokens) + audio_token_id: int, +) -> mx.array: + masked_input_ids = mx.where(input_ids == image_token_id, 0, input_ids) + masked_input_ids = mx.where(input_ids == audio_token_id, 0, masked_input_ids) + return masked_input_ids class Gemma4VisionComponents(nn.Module): @@ -97,12 +87,6 @@ def __init__(self, model_path: Path): self.config = config self.processor = processor - def clear_prediction_state(self, text_model: nn.Module) -> None: - language_model = text_model.language_model.model - reset = getattr(language_model, "reset_prompt_per_layer_input_state", None) - if reset is not None: - reset() - def compute_embeddings( self, text_model: nn.Module, @@ -119,7 +103,8 @@ def compute_embeddings( max_size=max_size, ) - input_embeddings = text_model.language_model.model.embed_tokens(input_ids) + language_model = text_model.language_model.model + input_embeddings = language_model.embed_tokens(input_ids) image_features = self.vision_tower(pixel_values) image_features = self.embed_vision(image_features).astype( @@ -127,9 +112,7 @@ def compute_embeddings( ) # Gemma4TextModel applies embed_scale even when input_embeddings are provided. - scaled_image_features = ( - image_features / text_model.language_model.model.embed_scale - ) + scaled_image_features = image_features / language_model.embed_scale image_mask = input_ids == self.config.image_token_id image_mask_expanded = mx.expand_dims(image_mask, -1) @@ -141,15 +124,13 @@ def compute_embeddings( input_embeddings, image_mask_expanded, scaled_image_features ) - prompt_per_layer_inputs = _compute_prompt_per_layer_inputs( - text_model.language_model.model, - input_ids, - self.config.image_token_id, - getattr(self.config, "audio_token_id", None), - ) - if prompt_per_layer_inputs is not None: - text_model.language_model.model.set_prompt_per_layer_inputs( - prompt_per_layer_inputs + if language_model.hidden_size_per_layer_input: + language_model.prompt_per_layer_input_ids = ( + _mask_prompt_per_layer_input_ids( + input_ids, + self.config.image_token_id, + self.config.audio_token_id, + ) ) return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) diff --git a/tests/test_gemma4_patch.py b/tests/test_gemma4_patch.py deleted file mode 100644 index b5d6a241..00000000 --- a/tests/test_gemma4_patch.py +++ /dev/null @@ -1,151 +0,0 @@ -import mlx.core as mx -from mlx_lm.generate import generate_step - -from mlx_engine.model_kit.patches.gemma4 import ( - OriginalGemma4TextModel as _OrigTextModel, - PatchedGemma4TextModel, - apply_patches, -) - -apply_patches() - -from mlx_lm.models.gemma4_text import Model, ModelArgs # noqa: E402 - -GEMMA4_TEXT_CONFIG = { - "model_type": "gemma4_text", - "hidden_size": 32, - "num_hidden_layers": 2, - "intermediate_size": 64, - "num_attention_heads": 2, - "num_key_value_heads": 1, - "num_global_key_value_heads": 1, - "head_dim": 16, - "global_head_dim": 16, - "sliding_window": 8, - "sliding_window_pattern": 1, - "layer_types": ["full_attention", "full_attention"], - "hidden_size_per_layer_input": 8, - "vocab_size": 32, - "vocab_size_per_layer_input": 32, - "num_kv_shared_layers": 0, -} - - -def make_model(**text_config_overrides): - args = ModelArgs.from_dict( - { - **GEMMA4_TEXT_CONFIG, - **text_config_overrides, - } - ) - return Model(args) - - -def test_gemma4_text_only_patched_matches_unpatched(): - import mlx_lm.models.gemma4_text as mod - - patched_text_model = mod.Gemma4TextModel - assert patched_text_model is PatchedGemma4TextModel - assert _OrigTextModel is not PatchedGemma4TextModel - - mx.random.seed(0) - model = make_model() - tokens = mx.array([[1, 2, 3, 4, 5, 6]], dtype=mx.int32) - patched_logits = model(tokens) - mx.eval(patched_logits) - patched_logits = mx.array(patched_logits) - - mod.Gemma4TextModel = _OrigTextModel - try: - mx.random.seed(0) - unpatched_model = make_model() - finally: - mod.Gemma4TextModel = patched_text_model - - unpatched_logits = unpatched_model(tokens) - mx.eval(unpatched_logits) - max_diff = mx.max(mx.abs(patched_logits - unpatched_logits)).item() - assert mx.allclose(patched_logits, unpatched_logits, atol=1e-4).item(), ( - f"Patched Gemma 4 text-only logits diverged from unpatched mlx-lm " - f"(max diff {max_diff:.6f})." - ) - - -def test_gemma4_prompt_per_layer_inputs_chunked_prefill_matches_unchunked(): - mx.random.seed(0) - - model = make_model() - text_model = model.model - prompt = mx.array([1, 2, 3, 4, 5, 6], dtype=mx.int32) - prompt_embeddings = text_model.embed_tokens(prompt[None]).squeeze(0) - prompt_per_layer_inputs = mx.full( - ( - 1, - prompt.shape[0], - text_model.config.num_hidden_layers, - text_model.hidden_size_per_layer_input, - ), - 3.0, - dtype=prompt_embeddings.dtype, - ) - - def first_step_logprobs(prefill_step_size: int) -> mx.array: - text_model.set_prompt_per_layer_inputs(prompt_per_layer_inputs) - step = generate_step( - prompt, - model, - max_tokens=1, - sampler=lambda x: mx.argmax(x, axis=-1), - input_embeddings=prompt_embeddings, - prefill_step_size=prefill_step_size, - ) - _, logprobs = next(step) - step.close() - mx.eval(logprobs) - return logprobs - - reference_logprobs = first_step_logprobs(prefill_step_size=16) - chunked_logprobs = first_step_logprobs(prefill_step_size=2) - - max_diff = mx.max(mx.abs(reference_logprobs - chunked_logprobs)).item() - assert mx.allclose(reference_logprobs, chunked_logprobs, atol=1e-4).item(), ( - f"Chunked Gemma 4 prompt per-layer-input prefill mismatch " - f"(max diff {max_diff:.6f})." - ) - - -def test_gemma4_prompt_state_matches_explicit_per_layer_inputs(): - mx.random.seed(0) - - model = make_model() - text_model = model.model - prompt = mx.array([[1, 2, 3, 4]], dtype=mx.int32) - prompt_embeddings = text_model.embed_tokens(prompt) - prompt_per_layer_inputs = mx.full( - ( - 1, - prompt.shape[1], - text_model.config.num_hidden_layers, - text_model.hidden_size_per_layer_input, - ), - 2.0, - dtype=prompt_embeddings.dtype, - ) - - explicit_logits = model( - prompt, - input_embeddings=prompt_embeddings, - per_layer_inputs=prompt_per_layer_inputs, - ) - text_model.set_prompt_per_layer_inputs(prompt_per_layer_inputs) - stateful_logits = model( - prompt, - input_embeddings=prompt_embeddings, - ) - mx.eval(explicit_logits, stateful_logits) - - max_diff = mx.max(mx.abs(explicit_logits - stateful_logits)).item() - assert mx.allclose(explicit_logits, stateful_logits, atol=1e-4).item(), ( - f"Stored Gemma 4 prompt per-layer inputs diverged from explicit inputs " - f"(max diff {max_diff:.6f})." - ) diff --git a/tests/test_gemma4_vision_addon.py b/tests/test_gemma4_vision_addon.py deleted file mode 100644 index ff9a0a8c..00000000 --- a/tests/test_gemma4_vision_addon.py +++ /dev/null @@ -1,54 +0,0 @@ -from types import SimpleNamespace - -import mlx.core as mx - -from mlx_engine.model_kit.vision_add_ons.gemma4 import _compute_prompt_per_layer_inputs - - -class _FakeGemma4TextModel: - def __init__(self, hidden_size_per_layer_input: int = 1): - self.hidden_size_per_layer_input = hidden_size_per_layer_input - self.seen_input_ids = None - - def _get_per_layer_inputs(self, input_ids: mx.array) -> mx.array: - self.seen_input_ids = input_ids - return input_ids[..., None, None] - - -class _FakeGemma4Model: - def __init__(self, hidden_size_per_layer_input: int = 1): - self.language_model = SimpleNamespace( - model=_FakeGemma4TextModel(hidden_size_per_layer_input) - ) - - -def test_compute_prompt_per_layer_inputs_masks_special_tokens(): - text_model = _FakeGemma4Model() - input_ids = mx.array([[1, 99, 2, 100, 3]], dtype=mx.int32) - - prompt_per_layer_inputs = _compute_prompt_per_layer_inputs( - text_model.language_model.model, - input_ids, - image_token_id=99, - audio_token_id=100, - ) - - assert text_model.language_model.model.seen_input_ids.tolist() == [[1, 0, 2, 0, 3]] - assert prompt_per_layer_inputs.shape == (1, 5, 1, 1) - assert prompt_per_layer_inputs[0, :, 0, 0].tolist() == [1, 0, 2, 0, 3] - - -def test_compute_prompt_per_layer_inputs_skips_models_without_per_layer_inputs(): - text_model = _FakeGemma4Model(hidden_size_per_layer_input=0) - input_ids = mx.array([[1, 99, 2]], dtype=mx.int32) - - assert ( - _compute_prompt_per_layer_inputs( - text_model.language_model.model, - input_ids, - image_token_id=99, - audio_token_id=100, - ) - is None - ) - assert text_model.language_model.model.seen_input_ids is None From 476c740951cc68d2fb83f624817302a6fee88f2d Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 7 Apr 2026 12:28:10 -0400 Subject: [PATCH 4/7] inline --- mlx_engine/model_kit/vision_add_ons/gemma4.py | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/mlx_engine/model_kit/vision_add_ons/gemma4.py b/mlx_engine/model_kit/vision_add_ons/gemma4.py index 9ca057dd..d6845805 100644 --- a/mlx_engine/model_kit/vision_add_ons/gemma4.py +++ b/mlx_engine/model_kit/vision_add_ons/gemma4.py @@ -27,16 +27,6 @@ logger = logging.getLogger(__name__) -def _mask_prompt_per_layer_input_ids( - input_ids: mx.array, - image_token_id: int, - audio_token_id: int, -) -> mx.array: - masked_input_ids = mx.where(input_ids == image_token_id, 0, input_ids) - masked_input_ids = mx.where(input_ids == audio_token_id, 0, masked_input_ids) - return masked_input_ids - - class Gemma4VisionComponents(nn.Module): def __init__(self, vision_tower: nn.Module, embed_vision: nn.Module): super().__init__() @@ -125,12 +115,11 @@ def compute_embeddings( ) if language_model.hidden_size_per_layer_input: - language_model.prompt_per_layer_input_ids = ( - _mask_prompt_per_layer_input_ids( - input_ids, - self.config.image_token_id, - self.config.audio_token_id, - ) + masked_input_ids = mx.where( + input_ids == self.config.image_token_id, 0, input_ids + ) + language_model.prompt_per_layer_input_ids = mx.where( + input_ids == self.config.audio_token_id, 0, masked_input_ids ) return input_ids.squeeze(0), final_inputs_embeds.squeeze(0) From 21750de73b2d4fa16961cadfaf9a180cf4dc0bf0 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Wed, 8 Apr 2026 13:42:38 -0400 Subject: [PATCH 5/7] gemma4 unified tests --- tests/patched_model_test_utils.py | 277 +++++++++++++++ tests/test_patched_gemma4.py | 173 ++++++++++ tests/test_patched_models.py | 536 ------------------------------ tests/test_patched_qwen3_5.py | 339 +++++++++++++++++++ 4 files changed, 789 insertions(+), 536 deletions(-) create mode 100644 tests/patched_model_test_utils.py create mode 100644 tests/test_patched_gemma4.py delete mode 100644 tests/test_patched_models.py create mode 100644 tests/test_patched_qwen3_5.py diff --git a/tests/patched_model_test_utils.py b/tests/patched_model_test_utils.py new file mode 100644 index 00000000..f13be2f9 --- /dev/null +++ b/tests/patched_model_test_utils.py @@ -0,0 +1,277 @@ +"""Shared helpers for patched-model parity tests.""" + +from __future__ import annotations + +from contextlib import contextmanager +from pathlib import Path +from typing import Iterator + +import numpy as np +import pytest + +import mlx.core as mx +import mlx_lm.models.gemma4_text as gemma4_text_module +import mlx_lm.models.qwen3_5 as qwen3_5_module +import mlx_lm.utils +from mlx_lm.models.cache import make_prompt_cache + +import mlx_engine.model_kit # noqa: F401 +from mlx_engine.model_kit.patches.gemma4 import PatchedGemma4TextModel +from mlx_engine.model_kit.patches.qwen3_5 import ( + OriginalDecoderLayer, + OriginalQwen3_5TextModel, +) +from mlx_vlm.models.cache import make_prompt_cache as make_vlm_prompt_cache +from mlx_vlm.utils import load_model as vlm_load_model, load_processor +from tests.shared import model_getter +from transformers import AutoProcessor + +OriginalGemma4TextModel = PatchedGemma4TextModel.__mro__[1] + +REAL_MODEL_CASES = [ + pytest.param("lmstudio-community/Qwen3.5-2B-MLX-4bit", id="dense"), + pytest.param( + "lmstudio-community/Qwen3.5-35B-A3B-MLX-4bit", + marks=pytest.mark.heavy, + id="moe", + ), +] +GEMMA4_MODEL_NAME = "lmstudio-community/gemma-4-E2B-it-MLX-4bit" +GEMMA4_IMAGE_TOPK = 5 +GEMMA4_IMAGE_TOPK_PROB_RTOL = 0.25 +GEMMA4_IMAGE_TOPK_PROB_REF_FLOOR = 1e-3 + + +def get_real_model_path(model_name: str) -> Path: + model_path = model_getter(model_name) + if not any(model_path.glob("*.safetensors")): + pytest.skip(f"{model_name}: no local MLX safetensors found in {model_path}") + return model_path + + +def max_abs_diff(actual: mx.array, reference: mx.array) -> float: + return float(mx.max(mx.abs(actual - reference)).item()) + + +def tokenize_prompt(tokenizer, prompt: str) -> list[int]: + ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt)) + return [ids] if isinstance(ids, int) else ids + + +@contextmanager +def _temporary_bindings(module, **replacements) -> Iterator[None]: + current_bindings = {name: getattr(module, name) for name in replacements} + for name, replacement in replacements.items(): + setattr(module, name, replacement) + try: + yield + finally: + for name, current_binding in current_bindings.items(): + setattr(module, name, current_binding) + + +def _assert_restorable_binding(original, current, label: str) -> None: + if original is current: + raise AssertionError( + f"Expected a pristine {label} reference captured before mlx-engine " + "patched mlx-lm." + ) + + +def _load_unpatched_mlx_lm( + model_path: Path, *, module, replacements: dict[str, object] +): + with _temporary_bindings(module, **replacements): + return mlx_lm.utils.load(model_path) + + +def load_patched_mlx_lm(model_path: Path): + return mlx_lm.utils.load(model_path) + + +def load_unpatched_qwen_mlx_lm(model_path: Path): + _assert_restorable_binding( + OriginalDecoderLayer, + qwen3_5_module.DecoderLayer, + "qwen3.5 DecoderLayer", + ) + _assert_restorable_binding( + OriginalQwen3_5TextModel, + qwen3_5_module.Qwen3_5TextModel, + "qwen3.5 Qwen3_5TextModel", + ) + return _load_unpatched_mlx_lm( + model_path, + module=qwen3_5_module, + replacements={ + "DecoderLayer": OriginalDecoderLayer, + "Qwen3_5TextModel": OriginalQwen3_5TextModel, + }, + ) + + +def load_unpatched_gemma4_mlx_lm(model_path: Path): + _assert_restorable_binding( + OriginalGemma4TextModel, + gemma4_text_module.Gemma4TextModel, + "Gemma4TextModel", + ) + return _load_unpatched_mlx_lm( + model_path, + module=gemma4_text_module, + replacements={"Gemma4TextModel": OriginalGemma4TextModel}, + ) + + +def load_vlm(model_path: Path): + result = vlm_load_model(model_path) + return result[0] if isinstance(result, tuple) else result + + +def load_vlm_processor(model_path: Path): + return load_processor(model_path, add_detokenizer=True) + + +def build_gemma4_prompt( + model_path: Path, + user_text: str, + *, + image_b64: str | None = None, +) -> str: + processor = AutoProcessor.from_pretrained(model_path) + content = [{"type": "text", "text": user_text}] + if image_b64 is not None: + content.insert(0, {"type": "image", "base64": image_b64}) + conversation = [{"role": "user", "content": content}] + return processor.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=True, + ) + + +def first_mlx_lm_generation_logits( + model, + prompt_tokens: mx.array, + *, + input_embeddings: mx.array | None = None, + prefill_step_size: int = 2048, +) -> mx.array: + """Return the first-step logits from mlx-lm's generation path.""" + prompt_cache = make_prompt_cache(model) + remaining_tokens = prompt_tokens + remaining_embeddings = input_embeddings + + while len(remaining_tokens) > 1: + n_to_process = min(prefill_step_size, len(remaining_tokens) - 1) + if n_to_process <= 0: + break + kwargs = {"cache": prompt_cache} + if remaining_embeddings is not None: + kwargs["input_embeddings"] = remaining_embeddings[:n_to_process][None] + model(remaining_tokens[:n_to_process][None], **kwargs) + mx.eval([cache.state for cache in prompt_cache]) + remaining_tokens = remaining_tokens[n_to_process:] + if remaining_embeddings is not None: + remaining_embeddings = remaining_embeddings[n_to_process:] + mx.clear_cache() + + kwargs = {"cache": prompt_cache} + if remaining_embeddings is not None: + kwargs["input_embeddings"] = remaining_embeddings[None] + logits = model(remaining_tokens[None], **kwargs) + mx.eval(logits) + return mx.array(logits[0, -1, :]) + + +def first_vlm_generation_logits( + model, + *, + input_ids: mx.array, + pixel_values: mx.array, + attention_mask: mx.array, + prefill_step_size: int = 2048, +) -> mx.array: + """Return the first-step logits from mlx-vlm's generation path.""" + prompt_cache = make_vlm_prompt_cache(model.language_model) + embedding_output = model.get_input_embeddings( + input_ids=input_ids, + pixel_values=pixel_values, + mask=attention_mask, + ) + inputs_embeds = embedding_output.inputs_embeds + kwargs = { + key: value + for key, value in embedding_output.to_dict().items() + if key != "inputs_embeds" and value is not None + } + + while inputs_embeds.shape[1] > 1: + n_to_process = min(prefill_step_size, inputs_embeds.shape[1] - 1) + if n_to_process <= 0: + break + model.language_model( + inputs=input_ids[:, :n_to_process], + inputs_embeds=inputs_embeds[:, :n_to_process], + cache=prompt_cache, + n_to_process=n_to_process, + **kwargs, + ) + mx.eval([cache.state for cache in prompt_cache]) + input_ids = input_ids[:, n_to_process:] + inputs_embeds = inputs_embeds[:, n_to_process:] + mx.clear_cache() + + outputs = model.language_model( + input_ids[:, -1:], + inputs_embeds=inputs_embeds[:, -1:], + cache=prompt_cache, + **kwargs, + ) + mx.eval(outputs.logits) + return mx.array(outputs.logits[0, -1, :]) + + +def topk_token_ids(logits: mx.array, k: int) -> list[int]: + values = np.array(logits.tolist(), dtype=np.float32) + return [int(index) for index in np.argsort(values)[-k:][::-1]] + + +def gather_values(values: mx.array, token_ids: list[int]) -> list[float]: + return [float(values[token_id].item()) for token_id in token_ids] + + +def softmax_probabilities(logits: mx.array) -> mx.array: + return mx.softmax(logits.astype(mx.float32), axis=-1) + + +def relative_differences( + actual_values: list[float], + reference_values: list[float], + reference_floor: float, +) -> list[float]: + diffs = [] + for actual, reference in zip(actual_values, reference_values): + scale = max(abs(reference), reference_floor) + diffs.append(abs(actual - reference) / scale) + return diffs + + +def format_token_values(token_ids: list[int], values: list[float], tokenizer) -> str: + parts = [] + for token_id, value in zip(token_ids, values): + parts.append(f"{token_id}:{tokenizer.decode([token_id])!r}:{value:.6f}") + return "[" + ", ".join(parts) + "]" + + +def resolve_image_token_index(config) -> int | None: + vision_config = getattr(config, "vision_config", None) + return getattr( + config, + "image_token_index", + getattr( + config, + "image_token_id", + getattr(vision_config, "image_token_id", None), + ), + ) diff --git a/tests/test_patched_gemma4.py b/tests/test_patched_gemma4.py new file mode 100644 index 00000000..526f3594 --- /dev/null +++ b/tests/test_patched_gemma4.py @@ -0,0 +1,173 @@ +"""Tests for the Gemma 4 monkey patch.""" + +from pathlib import Path + +import pytest + +import mlx.core as mx + +from mlx_engine.generate import load_model +from mlx_engine.utils.image_utils import convert_to_pil +from mlx_engine.utils.prompt_progress_reporter import DefaultPromptProgressReporter +from mlx_vlm.utils import prepare_inputs + +from tests.patched_model_test_utils import ( + GEMMA4_IMAGE_TOPK, + GEMMA4_IMAGE_TOPK_PROB_REF_FLOOR, + GEMMA4_IMAGE_TOPK_PROB_RTOL, + GEMMA4_MODEL_NAME, + build_gemma4_prompt, + first_mlx_lm_generation_logits, + first_vlm_generation_logits, + format_token_values, + gather_values, + get_real_model_path, + load_patched_mlx_lm, + load_unpatched_gemma4_mlx_lm, + load_vlm, + load_vlm_processor, + max_abs_diff, + relative_differences, + resolve_image_token_index, + softmax_probabilities, + tokenize_prompt, + topk_token_ids, +) +from tests.shared import read_image_b64 + + +def test_gemma4_text_only_generation_patched_matches_unpatched(): + """The Gemma 4 patch must be a no-op for text-only generation.""" + model_path = get_real_model_path(GEMMA4_MODEL_NAME) + prefill_step_size = 16 + user_text = " ".join( + f"Segment {index}: explain why careful benchmarking matters before changing an inference stack." + for index in range(1, 25) + ) + prompt = build_gemma4_prompt(model_path, user_text) + + patched_model, patched_tokenizer = load_patched_mlx_lm(model_path) + patched_prompt_tokens = tokenize_prompt(patched_tokenizer, prompt) + assert len(patched_prompt_tokens) > prefill_step_size * 2 + patched_first_logits = first_mlx_lm_generation_logits( + patched_model, + mx.array(patched_prompt_tokens), + prefill_step_size=prefill_step_size, + ) + del patched_model + del patched_tokenizer + mx.clear_cache() + + unpatched_model, unpatched_tokenizer = load_unpatched_gemma4_mlx_lm(model_path) + assert tokenize_prompt(unpatched_tokenizer, prompt) == patched_prompt_tokens + unpatched_first_logits = first_mlx_lm_generation_logits( + unpatched_model, + mx.array(patched_prompt_tokens), + prefill_step_size=prefill_step_size, + ) + del unpatched_model + del unpatched_tokenizer + mx.clear_cache() + + diff = max_abs_diff(patched_first_logits, unpatched_first_logits) + assert diff == 0.0, ( + "Gemma 4 text-only generation logits mismatch between patched and " + f"unpatched mlx-lm (max diff {diff:.6f})." + ) + + +@pytest.mark.heavy +def test_gemma4_image_prompt_unified_arch_top5_matches_vlm(): + """Image+text Gemma 4 generation should stay close to native mlx-vlm.""" + model_path = get_real_model_path(GEMMA4_MODEL_NAME) + image_b64 = read_image_b64( + Path(__file__).parent.parent / "demo-data" / "toucan.jpeg" + ) + prompt = build_gemma4_prompt(model_path, "What is this?", image_b64=image_b64) + prefill_step_size = 2048 + + model_kit = load_model( + model_path=model_path, + max_seq_nums=1, + prefill_step_size=prefill_step_size, + ) + prompt_tokens = model_kit.tokenize(prompt) + input_tokens, input_embeddings = model_kit.process_prompt( + prompt_tokens, + images_b64=[image_b64], + prompt_progress_reporter=DefaultPromptProgressReporter(), + generate_args={}, + max_image_size=(1024, 1024), + ) + assert input_embeddings is not None + unified_first_logits = first_mlx_lm_generation_logits( + model_kit.model, + input_tokens, + input_embeddings=input_embeddings, + prefill_step_size=prefill_step_size, + ) + model_kit.shutdown() + del model_kit + mx.clear_cache() + + vlm_model = load_vlm(model_path) + vlm_processor = load_vlm_processor(model_path) + vlm_inputs = prepare_inputs( + processor=vlm_processor, + images=convert_to_pil([image_b64]), + prompts=prompt, + image_token_index=resolve_image_token_index(vlm_model.config), + resize_shape=None, + ) + native_input_ids = vlm_inputs["input_ids"] + native_attention_mask = vlm_inputs["attention_mask"] + native_pixel_values = vlm_inputs["pixel_values"] + + assert input_tokens.tolist() == native_input_ids[0].tolist() + + vlm_first_logits = first_vlm_generation_logits( + vlm_model, + input_ids=native_input_ids, + pixel_values=native_pixel_values, + attention_mask=native_attention_mask, + prefill_step_size=prefill_step_size, + ) + tokenizer = ( + vlm_processor.tokenizer + if hasattr(vlm_processor, "tokenizer") + else vlm_processor + ) + del vlm_model + mx.clear_cache() + + unified_top5_ids = topk_token_ids(unified_first_logits, GEMMA4_IMAGE_TOPK) + vlm_top5_ids = topk_token_ids(vlm_first_logits, GEMMA4_IMAGE_TOPK) + unified_logits = gather_values(unified_first_logits, unified_top5_ids) + vlm_logits = gather_values(vlm_first_logits, unified_top5_ids) + unified_probabilities = softmax_probabilities(unified_first_logits) + vlm_probabilities = softmax_probabilities(vlm_first_logits) + unified_top5_probabilities = gather_values(unified_probabilities, unified_top5_ids) + vlm_top5_probabilities = gather_values(vlm_probabilities, unified_top5_ids) + + assert unified_top5_ids == vlm_top5_ids, ( + "Top-5 token IDs/order mismatch: " + f"unified={format_token_values(unified_top5_ids, unified_top5_probabilities, tokenizer)} " + f"vlm={format_token_values(vlm_top5_ids, gather_values(vlm_probabilities, vlm_top5_ids), tokenizer)}" + ) + + relative_diffs = relative_differences( + unified_top5_probabilities, + vlm_top5_probabilities, + GEMMA4_IMAGE_TOPK_PROB_REF_FLOOR, + ) + max_relative_diff = max(relative_diffs) + + assert max_relative_diff <= GEMMA4_IMAGE_TOPK_PROB_RTOL, ( + "Top-5 probabilities exceeded tolerance: " + f"max relative diff {max_relative_diff:.6f}; " + f"relative_diffs={relative_diffs}; " + f"unified_logits={format_token_values(unified_top5_ids, unified_logits, tokenizer)} " + f"vlm_logits={format_token_values(unified_top5_ids, vlm_logits, tokenizer)} " + f"unified_probabilities={format_token_values(unified_top5_ids, unified_top5_probabilities, tokenizer)} " + f"vlm_probabilities={format_token_values(unified_top5_ids, vlm_top5_probabilities, tokenizer)}" + ) diff --git a/tests/test_patched_models.py b/tests/test_patched_models.py deleted file mode 100644 index 0f8b8ccd..00000000 --- a/tests/test_patched_models.py +++ /dev/null @@ -1,536 +0,0 @@ -""" -Tests for monkey-patched model classes (mlx_engine.model_kit.patches). - -Follows the upstream mlx-lm/tests/test_models.py pattern: construct models -from small configs with random weights, test through the public interface. - -Also includes cross-engine tests that load a real model from disk and compare -patched mlx-lm logits against unpatched mlx-lm and native mlx-vlm. -""" - -import pytest -from pathlib import Path - -import mlx.core as mx -from mlx_lm.generate import generate_step -from mlx_lm.models.cache import ArraysCache, BatchKVCache, KVCache, make_prompt_cache - -# Import pristine mlx-lm classes from the patch module's cached pre-patch refs. -from mlx_engine.model_kit.patches.qwen3_5 import ( - OriginalDecoderLayer as _OrigDecoderLayer, - OriginalQwen3_5TextModel as _OrigTextModel, - PatchedDecoderLayer, - PatchedQwen3_5TextModel, - apply_patches, -) - -apply_patches() - -from mlx_lm.models.qwen3_5 import Model, ModelArgs # noqa: E402 - -QWEN3_5_TEXT_CONFIG = { - "model_type": "qwen3_5", - "hidden_size": 128, - "num_hidden_layers": 4, - "intermediate_size": 128, - "num_attention_heads": 8, - "num_key_value_heads": 4, - "vocab_size": 1000, - "linear_num_value_heads": 4, - "linear_num_key_heads": 4, - "linear_key_head_dim": 32, - "linear_value_head_dim": 32, - "linear_conv_kernel_dim": 3, - "rms_norm_eps": 1e-5, - "head_dim": 64, - "rope_theta": 1000.0, - "partial_rotary_factor": 0.5, - "max_position_embeddings": 1000, -} - - -def make_model(**text_config_overrides): - args = ModelArgs.from_dict( - { - "model_type": "qwen3_5", - "text_config": { - **QWEN3_5_TEXT_CONFIG, - **text_config_overrides, - }, - } - ) - return Model(args) - - -def make_batched_prompt_cache(model, left_padding): - cache = model.make_cache() - for i, layer_cache in enumerate(cache): - if isinstance(layer_cache, ArraysCache): - layer_cache.left_padding = mx.array(left_padding) - elif type(layer_cache) is KVCache: - cache[i] = BatchKVCache(left_padding) - else: - raise AssertionError(f"Unexpected cache type: {type(layer_cache)!r}") - return cache - - -def _assert_pristine_qwen3_5_refs() -> None: - if _OrigDecoderLayer is PatchedDecoderLayer: - raise AssertionError( - "Expected a pristine qwen3.5 DecoderLayer reference, but the test " - "harness captured the patched class." - ) - if _OrigTextModel is PatchedQwen3_5TextModel: - raise AssertionError( - "Expected a pristine qwen3.5 Qwen3_5TextModel reference, but the " - "test harness captured the patched class." - ) - - -@pytest.mark.parametrize("use_mrope", [False, True], ids=["text_only", "mrope"]) -def test_qwen3_5_prefill_decode_consistency(use_mrope): - """Full-sequence prefill and incremental prefill+decode must produce - the same last-token logits. - - A correct autoregressive model satisfies: - model(all_tokens)[-1] == model(tokens[:-1]); model(tokens[-1]) - - Parameterized over: - - text_only: standard RoPE path (no input_embeddings, no MRoPE state) - - mrope: MRoPE path with non-degenerate 3D positions and input_embeddings, - simulating a vision request - """ - model = make_model() - text_model = model.language_model.model - tokens = mx.array([[0, 1, 2, 3]]) - - if use_mrope: - embeddings = text_model.embed_tokens(tokens) - # 3D positions simulating a vision prompt where image tokens create - # different spatial positions across temporal/height/width dims. - # rope_deltas and position_ids must be consistent: the last token's - # position (1) must equal cache_offset (3) + rope_deltas (-2). - position_ids = mx.array( - [ - [[0, 1, 0, 1]], # temporal - [[0, 0, 1, 1]], # height — differs from dim 0 during prefill - [[0, 1, 1, 1]], # width — differs from both during prefill - ] - ) - rope_deltas = mx.array(-2) - else: - embeddings = None - position_ids = None - rope_deltas = None - - # Full prefill: all tokens at once with cache - cache_full = make_prompt_cache(model) - text_model.position_ids = position_ids - text_model.rope_deltas = rope_deltas - full_output = model(tokens, cache=cache_full, input_embeddings=embeddings) - mx.eval(full_output) - full_last_logits = full_output[0, -1, :] - - # Incremental: prefill N-1 tokens, then decode last token - cache_incr = make_prompt_cache(model) - text_model.position_ids = position_ids - text_model.rope_deltas = rope_deltas - model( - tokens[:, :-1], - cache=cache_incr, - input_embeddings=embeddings[:, :-1] if embeddings is not None else None, - ) - mx.eval([c.state for c in cache_incr]) - - decode_output = model(tokens[:, -1:], cache=cache_incr) - mx.eval(decode_output) - decode_logits = decode_output[0, -1, :] - - max_diff = mx.max(mx.abs(full_last_logits - decode_logits)).item() - assert mx.allclose(full_last_logits, decode_logits, atol=1e-4).item(), ( - f"Prefill-decode logit mismatch (max diff {max_diff:.6f})." - ) - - -def test_qwen3_5_mrope_chunked_prefill_matches_unchunked(): - """Chunked prefill must match unchunked prefill even when a later prefill - chunk is still inside a non-sequential MRoPE image span. - - The prompt is only 6 tokens long, but forcing prefill_step_size=2 reproduces - the same chunking scenario as a 512-token prefill boundary in production: - - chunk 1 processes text + first image token - - chunk 2 processes image tokens while cache_offset > 0 - Later prompt chunks must continue using the stored 3D position_ids until the - precomputed prompt positions are exhausted. - """ - mx.random.seed(0) - - model = make_model( - num_hidden_layers=2, - full_attention_interval=2, - ) - text_model = model.language_model.model - - tokens = mx.array([0, 1, 2, 3, 4, 5]) - embeddings = text_model.embed_tokens(tokens) - - # Synthetic prompt layout: - # token 0: text - # tokens 1-4: 2x2 image span with non-sequential 3D positions - # token 5: trailing text - # - # The final text token's position (3) equals cache_offset (5) + rope_deltas (-2), - # so the reference unchunked path and decode path are consistent. - position_ids = mx.array( - [ - [[0, 1, 1, 1, 1, 3]], # temporal - [[0, 1, 1, 2, 2, 3]], # height - [[0, 1, 2, 1, 2, 3]], # width - ] - ) - rope_deltas = mx.array(-2) - - def first_step_logprobs(prefill_step_size: int) -> mx.array: - text_model.position_ids = position_ids - text_model.rope_deltas = rope_deltas - step = generate_step( - tokens, - model, - max_tokens=1, - prefill_step_size=prefill_step_size, - input_embeddings=embeddings, - ) - _, logprobs = next(step) - step.close() - mx.eval(logprobs) - return logprobs - - # Single prefill chunk for tokens[:-1]. - reference_logprobs = first_step_logprobs(prefill_step_size=16) - - # Forces multiple prefill chunks; chunk 2 still lies inside the image span. - chunked_logprobs = first_step_logprobs(prefill_step_size=2) - - max_diff = mx.max(mx.abs(reference_logprobs - chunked_logprobs)).item() - assert mx.allclose(reference_logprobs, chunked_logprobs, atol=1e-4).item(), ( - f"Chunked MRoPE prefill mismatch (max diff {max_diff:.6f})." - ) - - -def test_qwen3_5_mrope_later_single_image_chunk_matches_unchunked(): - """Chunked prefill must match unchunked prefill even when the full image - span is contained in a later prefill chunk. - - This isolates the more general bug: the implementation must keep using the - stored multimodal prompt positions for any later chunk that still belongs to - the original multimodal prompt. The image tokens do not cross a chunk - boundary here; earlier text simply pushes them out of the first chunk. - - With prefill_step_size=2: - - chunk 1 processes only leading text - - chunk 2 processes the entire image span - - chunk 3 processes trailing text - """ - mx.random.seed(0) - - model = make_model( - num_hidden_layers=2, - full_attention_interval=2, - ) - text_model = model.language_model.model - - tokens = mx.array([0, 1, 2, 3, 4, 5]) - embeddings = text_model.embed_tokens(tokens) - - # Synthetic prompt layout: - # tokens 0-1: leading text - # tokens 2-3: 1x2 image span with non-sequential 3D positions - # tokens 4-5: trailing text - # - # The image span is fully contained in chunk 2 when prefill_step_size=2. - # Continuation after the prompt is sequential again, so rope_deltas is 0. - position_ids = mx.array( - [ - [[0, 1, 2, 2, 4, 5]], # temporal - [[0, 1, 2, 2, 4, 5]], # height - [[0, 1, 2, 3, 4, 5]], # width - ] - ) - rope_deltas = mx.array(0) - - def first_step_logprobs(prefill_step_size: int) -> mx.array: - text_model.position_ids = position_ids - text_model.rope_deltas = rope_deltas - step = generate_step( - tokens, - model, - max_tokens=1, - prefill_step_size=prefill_step_size, - input_embeddings=embeddings, - ) - _, logprobs = next(step) - step.close() - mx.eval(logprobs) - return logprobs - - # Single prefill chunk for tokens[:-1]. - reference_logprobs = first_step_logprobs(prefill_step_size=16) - - # Chunk 2 contains the whole image span but is not the first prefill chunk. - chunked_logprobs = first_step_logprobs(prefill_step_size=2) - - max_diff = mx.max(mx.abs(reference_logprobs - chunked_logprobs)).item() - assert mx.allclose(reference_logprobs, chunked_logprobs, atol=1e-4).item(), ( - f"Later-chunk MRoPE prefill mismatch (max diff {max_diff:.6f})." - ) - - -def test_qwen3_5_text_only_uncached_matches_prompt_cache(): - """Direct uncached text-only forwards should match the cached prefill path.""" - model = make_model() - tokens = mx.array([[0, 1, 2, 3]]) - - reference_cache = make_prompt_cache(model) - reference_logits = model(tokens, cache=reference_cache) - mx.eval(reference_logits) - - uncached_logits = model(tokens, cache=None) - mx.eval(uncached_logits) - - max_diff = mx.max(mx.abs(reference_logits - uncached_logits)).item() - assert mx.allclose(reference_logits, uncached_logits, atol=1e-4).item(), ( - f"Text-only uncached logits mismatch (max diff {max_diff:.6f})." - ) - - -def test_qwen3_5_text_only_batch_cache_matches_prompt_cache(): - """Text-only forwards should handle BatchKVCache vector offsets.""" - model = make_model() - text_model = model.language_model.model - tokens = mx.array([[0, 1, 2, 3]]) - - reference_cache = make_prompt_cache(model) - reference_logits = model(tokens, cache=reference_cache) - mx.eval(reference_logits) - - batch_cache = make_batched_prompt_cache(model, [0]) - assert isinstance(batch_cache[text_model.fa_idx], BatchKVCache) - assert batch_cache[text_model.fa_idx].offset.ndim == 1 - - batch_logits = model(tokens, cache=batch_cache) - mx.eval(batch_logits) - - max_diff = mx.max(mx.abs(reference_logits - batch_logits)).item() - assert mx.allclose(reference_logits, batch_logits, atol=1e-4).item(), ( - f"Text-only batch-cache logits mismatch (max diff {max_diff:.6f})." - ) - - -# --------------------------------------------------------------------------- -# Cross-engine tests: compare patched mlx-lm against unpatched and mlx-vlm -# using a real model loaded from disk. -# --------------------------------------------------------------------------- - -REAL_MODEL_CASES = [ - pytest.param("lmstudio-community/Qwen3.5-2B-MLX-4bit", id="dense"), - pytest.param( - "lmstudio-community/Qwen3.5-35B-A3B-MLX-4bit", - marks=pytest.mark.heavy, - id="moe", - ), -] - - -def _get_real_model_path(model_name: str) -> Path: - from tests.shared import model_getter - - return model_getter(model_name) - - -def _load_patched_mlx_lm(model_path: Path): - """Load model using mlx-lm with patches already applied.""" - import mlx_lm.utils - - model, tokenizer = mlx_lm.utils.load(model_path) - return model, tokenizer - - -def _load_unpatched_mlx_lm(model_path: Path): - """Load model using mlx-lm with original (unpatched) classes temporarily restored.""" - import mlx_lm.models.qwen3_5 as mod - import mlx_lm.utils - - _assert_pristine_qwen3_5_refs() - patched_dl = mod.DecoderLayer - patched_tm = mod.Qwen3_5TextModel - mod.DecoderLayer = _OrigDecoderLayer - mod.Qwen3_5TextModel = _OrigTextModel - try: - model, tokenizer = mlx_lm.utils.load(model_path) - finally: - mod.DecoderLayer = patched_dl - mod.Qwen3_5TextModel = patched_tm - return model, tokenizer - - -def _load_vlm(model_path: Path): - """Load model using mlx-vlm's native loader.""" - from mlx_vlm.utils import load_model as vlm_load_model - - result = vlm_load_model(model_path) - return result[0] if isinstance(result, tuple) else result - - -@pytest.mark.parametrize("model_name", REAL_MODEL_CASES) -def test_qwen3_5_text_only_patched_matches_unpatched(model_name): - """Text-only logits from the patched mlx-lm model must match the - unpatched mlx-lm model for both dense and MoE Qwen3.5 variants. - - This validates that the MRoPE patch is a no-op for text-only inference: - the patched model delegates to the original mlx-lm code paths and must - produce identical logits to the unpatched original. - - Models are loaded and unloaded sequentially to limit memory usage. - """ - model_path = _get_real_model_path(model_name) - tokens = mx.array([[0, 1, 2, 3, 4, 5, 6, 7]]) - - # --- Patched mlx-lm --- - patched_model, _ = _load_patched_mlx_lm(model_path) - patched_logits = patched_model(tokens) - mx.eval(patched_logits) - # Detach from model graph before unloading - patched_logits = mx.array(patched_logits) - del patched_model - mx.clear_cache() - - # --- Unpatched mlx-lm --- - unpatched_model, _ = _load_unpatched_mlx_lm(model_path) - unpatched_logits = unpatched_model(tokens) - mx.eval(unpatched_logits) - unpatched_logits = mx.array(unpatched_logits) - del unpatched_model - mx.clear_cache() - - # --- Compare: run the required invariant before failing --- - diff_patched_unpatched = mx.max(mx.abs(patched_logits - unpatched_logits)).item() - - failures = [] - if diff_patched_unpatched != 0.0: - failures.append( - f"Patched vs unpatched mlx-lm: max diff {diff_patched_unpatched:.6f}" - ) - - summary = f"\n patched vs unpatched: {diff_patched_unpatched:.6f}" - assert len(failures) == 0, ( - f"{model_name}: Logit mismatch:{summary}\nFailures: {'; '.join(failures)}" - ) - - -@pytest.mark.parametrize("model_name", REAL_MODEL_CASES) -def test_qwen3_5_image_prompt_patched_matches_vlm(model_name): - """Image-prompt logits from the patched mlx-lm model must match the native - mlx-vlm LanguageModel for both dense and MoE Qwen3.5 variants. - - This validates two things: - 1. The vision add-on's _compute_image_mrope_state produces the same 3D - position IDs as mlx-vlm's LanguageModel.get_rope_index. - 2. Given those positions, the patched model's forward pass produces the - same logits as mlx-vlm (within bfloat16 tolerance from different - GatedDeltaNet implementations). - - Uses a synthetic token sequence with image token runs — no actual image - pixels are processed. The test focuses on position computation and - position threading through the model. - - Models are loaded and unloaded sequentially to limit memory usage. - """ - from mlx_engine.model_kit.vision_add_ons.qwen3_5 import _compute_image_mrope_state - - model_path = _get_real_model_path(model_name) - - # --- Load vlm model first to get config and compute reference positions --- - vlm_model = _load_vlm(model_path) - config = vlm_model.config - - # Construct a synthetic token sequence with an image span. - # Layout: [text, text, vision_start, image*4, text, text] - # With spatial_merge_size=2 and grid_thw=[1,4,4]: - # llm_grid = [1, 2, 2] → 4 image tokens - image_grid_thw = mx.array([[1, 4, 4]]) - tokens_list = [ - 0, - 1, - config.vision_start_token_id, - config.image_token_id, - config.image_token_id, - config.image_token_id, - config.image_token_id, - 2, - 3, - ] - tokens = mx.array([tokens_list]) - - # Compute positions via mlx-vlm's get_rope_index - vlm_position_ids, vlm_rope_deltas = vlm_model.language_model.get_rope_index( - tokens, image_grid_thw=image_grid_thw - ) - mx.eval(vlm_position_ids, vlm_rope_deltas) - - # Compute positions via the vision add-on's method - addon_position_ids, addon_rope_deltas = _compute_image_mrope_state( - mx.array(tokens_list), image_grid_thw, config - ) - mx.eval(addon_position_ids, addon_rope_deltas) - - # --- Assert position computation matches --- - position_ids_match = mx.array_equal(addon_position_ids, vlm_position_ids).item() - rope_deltas_match = mx.array_equal(addon_rope_deltas, vlm_rope_deltas).item() - - # --- Forward pass: vlm --- - vlm_model.language_model._position_ids = vlm_position_ids - vlm_model.language_model._rope_deltas = vlm_rope_deltas - vlm_logits = vlm_model.language_model( - tokens, cache=None, position_ids=vlm_position_ids - ).logits - mx.eval(vlm_logits) - vlm_logits = mx.array(vlm_logits) - del vlm_model - mx.clear_cache() - - # --- Forward pass: patched mlx-lm --- - patched_model, _ = _load_patched_mlx_lm(model_path) - patched_text_model = patched_model.language_model.model - patched_text_model.position_ids = addon_position_ids - patched_text_model.rope_deltas = addon_rope_deltas - patched_logits = patched_model(tokens) - mx.eval(patched_logits) - patched_logits = mx.array(patched_logits) - del patched_model - mx.clear_cache() - - # --- Compare: run all checks before failing --- - diff_logits = mx.max(mx.abs(patched_logits - vlm_logits)).item() - - failures = [] - if not position_ids_match: - failures.append( - f"Position IDs mismatch: addon={addon_position_ids.tolist()}, " - f"vlm={vlm_position_ids.tolist()}" - ) - if not rope_deltas_match: - failures.append( - f"Rope deltas mismatch: addon={addon_rope_deltas.item()}, " - f"vlm={vlm_rope_deltas.item()}" - ) - if diff_logits != 0.0: - failures.append(f"Logit mismatch: max diff {diff_logits:.6f}") - - summary = ( - f"\n position_ids match: {position_ids_match}" - f"\n rope_deltas match: {rope_deltas_match}" - f"\n logit max diff: {diff_logits:.6f}" - ) - assert len(failures) == 0, ( - f"{model_name}: Image prompt mismatch (expected zero diff):{summary}\nFailures: {'; '.join(failures)}" - ) diff --git a/tests/test_patched_qwen3_5.py b/tests/test_patched_qwen3_5.py new file mode 100644 index 00000000..817c13e1 --- /dev/null +++ b/tests/test_patched_qwen3_5.py @@ -0,0 +1,339 @@ +"""Tests for the Qwen3.5 monkey patches.""" + +import pytest + +import mlx.core as mx +from mlx_lm.generate import generate_step +from mlx_lm.models.cache import ArraysCache, BatchKVCache, KVCache, make_prompt_cache +from mlx_engine.model_kit.vision_add_ons.qwen3_5 import _compute_image_mrope_state + +from tests.patched_model_test_utils import ( + REAL_MODEL_CASES, + get_real_model_path, + load_patched_mlx_lm, + load_unpatched_qwen_mlx_lm, + load_vlm, + max_abs_diff, +) +from mlx_lm.models.qwen3_5 import Model, ModelArgs + +QWEN3_5_TEXT_CONFIG = { + "model_type": "qwen3_5", + "hidden_size": 128, + "num_hidden_layers": 4, + "intermediate_size": 128, + "num_attention_heads": 8, + "num_key_value_heads": 4, + "vocab_size": 1000, + "linear_num_value_heads": 4, + "linear_num_key_heads": 4, + "linear_key_head_dim": 32, + "linear_value_head_dim": 32, + "linear_conv_kernel_dim": 3, + "rms_norm_eps": 1e-5, + "head_dim": 64, + "rope_theta": 1000.0, + "partial_rotary_factor": 0.5, + "max_position_embeddings": 1000, +} + +QWEN3_5_MROPE_CHUNK_CASES = [ + pytest.param( + { + "tokens": [0, 1, 2, 3, 4, 5], + "position_ids": [ + [[0, 1, 1, 1, 1, 3]], + [[0, 1, 1, 2, 2, 3]], + [[0, 1, 2, 1, 2, 3]], + ], + "rope_deltas": -2, + "failure_label": "Chunked MRoPE prefill mismatch", + }, + id="crosses_chunk_boundary", + ), + pytest.param( + { + "tokens": [0, 1, 2, 3, 4, 5], + "position_ids": [ + [[0, 1, 2, 2, 4, 5]], + [[0, 1, 2, 2, 4, 5]], + [[0, 1, 2, 3, 4, 5]], + ], + "rope_deltas": 0, + "failure_label": "Later-chunk MRoPE prefill mismatch", + }, + id="image_span_in_later_chunk", + ), +] + + +def make_model(**text_config_overrides): + args = ModelArgs.from_dict( + { + "model_type": "qwen3_5", + "text_config": { + **QWEN3_5_TEXT_CONFIG, + **text_config_overrides, + }, + } + ) + return Model(args) + + +def make_batched_prompt_cache(model, left_padding): + cache = model.make_cache() + for index, layer_cache in enumerate(cache): + if isinstance(layer_cache, ArraysCache): + layer_cache.left_padding = mx.array(left_padding) + elif type(layer_cache) is KVCache: + cache[index] = BatchKVCache(left_padding) + else: + raise AssertionError(f"Unexpected cache type: {type(layer_cache)!r}") + return cache + + +def _first_generate_step_logprobs( + model, + tokens: mx.array, + *, + input_embeddings: mx.array | None = None, + position_ids: mx.array | None = None, + rope_deltas: mx.array | None = None, + prefill_step_size: int, +) -> mx.array: + text_model = model.language_model.model + text_model.position_ids = position_ids + text_model.rope_deltas = rope_deltas + step = generate_step( + tokens, + model, + max_tokens=1, + prefill_step_size=prefill_step_size, + input_embeddings=input_embeddings, + ) + _, logprobs = next(step) + step.close() + mx.eval(logprobs) + return logprobs + + +@pytest.mark.parametrize("use_mrope", [False, True], ids=["text_only", "mrope"]) +def test_qwen3_5_prefill_decode_consistency(use_mrope): + """Full-sequence prefill and incremental prefill+decode must agree.""" + model = make_model() + text_model = model.language_model.model + tokens = mx.array([[0, 1, 2, 3]]) + + if use_mrope: + embeddings = text_model.embed_tokens(tokens) + position_ids = mx.array( + [ + [[0, 1, 0, 1]], + [[0, 0, 1, 1]], + [[0, 1, 1, 1]], + ] + ) + rope_deltas = mx.array(-2) + else: + embeddings = None + position_ids = None + rope_deltas = None + + cache_full = make_prompt_cache(model) + text_model.position_ids = position_ids + text_model.rope_deltas = rope_deltas + full_output = model(tokens, cache=cache_full, input_embeddings=embeddings) + mx.eval(full_output) + full_last_logits = full_output[0, -1, :] + + cache_incr = make_prompt_cache(model) + text_model.position_ids = position_ids + text_model.rope_deltas = rope_deltas + model( + tokens[:, :-1], + cache=cache_incr, + input_embeddings=embeddings[:, :-1] if embeddings is not None else None, + ) + mx.eval([cache.state for cache in cache_incr]) + + decode_output = model(tokens[:, -1:], cache=cache_incr) + mx.eval(decode_output) + decode_logits = decode_output[0, -1, :] + + diff = max_abs_diff(full_last_logits, decode_logits) + assert mx.allclose(full_last_logits, decode_logits, atol=1e-4).item(), ( + f"Prefill-decode logit mismatch (max diff {diff:.6f})." + ) + + +@pytest.mark.parametrize("case", QWEN3_5_MROPE_CHUNK_CASES) +def test_qwen3_5_mrope_chunked_prefill_matches_unchunked(case): + """Chunked MRoPE prefill must match unchunked prefill.""" + mx.random.seed(0) + + model = make_model( + num_hidden_layers=2, + full_attention_interval=2, + ) + text_model = model.language_model.model + tokens = mx.array(case["tokens"]) + embeddings = text_model.embed_tokens(tokens) + position_ids = mx.array(case["position_ids"]) + rope_deltas = mx.array(case["rope_deltas"]) + + reference_logprobs = _first_generate_step_logprobs( + model, + tokens, + input_embeddings=embeddings, + position_ids=position_ids, + rope_deltas=rope_deltas, + prefill_step_size=16, + ) + chunked_logprobs = _first_generate_step_logprobs( + model, + tokens, + input_embeddings=embeddings, + position_ids=position_ids, + rope_deltas=rope_deltas, + prefill_step_size=2, + ) + + diff = max_abs_diff(reference_logprobs, chunked_logprobs) + assert mx.allclose(reference_logprobs, chunked_logprobs, atol=1e-4).item(), ( + f"{case['failure_label']} (max diff {diff:.6f})." + ) + + +def test_qwen3_5_text_only_uncached_matches_prompt_cache(): + """Direct uncached text-only forwards should match the cached prefill path.""" + model = make_model() + tokens = mx.array([[0, 1, 2, 3]]) + + reference_cache = make_prompt_cache(model) + reference_logits = model(tokens, cache=reference_cache) + mx.eval(reference_logits) + + uncached_logits = model(tokens, cache=None) + mx.eval(uncached_logits) + + diff = max_abs_diff(reference_logits, uncached_logits) + assert mx.allclose(reference_logits, uncached_logits, atol=1e-4).item(), ( + f"Text-only uncached logits mismatch (max diff {diff:.6f})." + ) + + +def test_qwen3_5_text_only_batch_cache_matches_prompt_cache(): + """Text-only forwards should handle BatchKVCache vector offsets.""" + model = make_model() + text_model = model.language_model.model + tokens = mx.array([[0, 1, 2, 3]]) + + reference_cache = make_prompt_cache(model) + reference_logits = model(tokens, cache=reference_cache) + mx.eval(reference_logits) + + batch_cache = make_batched_prompt_cache(model, [0]) + assert isinstance(batch_cache[text_model.fa_idx], BatchKVCache) + assert batch_cache[text_model.fa_idx].offset.ndim == 1 + + batch_logits = model(tokens, cache=batch_cache) + mx.eval(batch_logits) + + diff = max_abs_diff(reference_logits, batch_logits) + assert mx.allclose(reference_logits, batch_logits, atol=1e-4).item(), ( + f"Text-only batch-cache logits mismatch (max diff {diff:.6f})." + ) + + +@pytest.mark.parametrize("model_name", REAL_MODEL_CASES) +def test_qwen3_5_text_only_patched_matches_unpatched(model_name): + """The Qwen3.5 patch must be a no-op for text-only inference.""" + model_path = get_real_model_path(model_name) + tokens = mx.array([[0, 1, 2, 3, 4, 5, 6, 7]]) + + patched_model, _ = load_patched_mlx_lm(model_path) + patched_logits = patched_model(tokens) + mx.eval(patched_logits) + patched_logits = mx.array(patched_logits) + del patched_model + mx.clear_cache() + + unpatched_model, _ = load_unpatched_qwen_mlx_lm(model_path) + unpatched_logits = unpatched_model(tokens) + mx.eval(unpatched_logits) + unpatched_logits = mx.array(unpatched_logits) + del unpatched_model + mx.clear_cache() + + diff = max_abs_diff(patched_logits, unpatched_logits) + assert diff == 0.0, ( + f"{model_name}: patched vs unpatched mlx-lm logits mismatch " + f"(max diff {diff:.6f})." + ) + + +@pytest.mark.parametrize("model_name", REAL_MODEL_CASES) +def test_qwen3_5_image_prompt_patched_matches_vlm(model_name): + """The patched Qwen3.5 image path must match native mlx-vlm.""" + model_path = get_real_model_path(model_name) + vlm_model = load_vlm(model_path) + config = vlm_model.config + + image_grid_thw = mx.array([[1, 4, 4]]) + tokens_list = [ + 0, + 1, + config.vision_start_token_id, + config.image_token_id, + config.image_token_id, + config.image_token_id, + config.image_token_id, + 2, + 3, + ] + tokens = mx.array([tokens_list]) + + vlm_position_ids, vlm_rope_deltas = vlm_model.language_model.get_rope_index( + tokens, image_grid_thw=image_grid_thw + ) + mx.eval(vlm_position_ids, vlm_rope_deltas) + + addon_position_ids, addon_rope_deltas = _compute_image_mrope_state( + mx.array(tokens_list), image_grid_thw, config + ) + mx.eval(addon_position_ids, addon_rope_deltas) + + assert mx.array_equal(addon_position_ids, vlm_position_ids).item(), ( + f"{model_name}: position IDs mismatch: addon={addon_position_ids.tolist()}, " + f"vlm={vlm_position_ids.tolist()}" + ) + assert mx.array_equal(addon_rope_deltas, vlm_rope_deltas).item(), ( + f"{model_name}: rope deltas mismatch: addon={addon_rope_deltas.item()}, " + f"vlm={vlm_rope_deltas.item()}" + ) + + vlm_model.language_model._position_ids = vlm_position_ids + vlm_model.language_model._rope_deltas = vlm_rope_deltas + vlm_logits = vlm_model.language_model( + tokens, cache=None, position_ids=vlm_position_ids + ).logits + mx.eval(vlm_logits) + vlm_logits = mx.array(vlm_logits) + del vlm_model + mx.clear_cache() + + patched_model, _ = load_patched_mlx_lm(model_path) + patched_text_model = patched_model.language_model.model + patched_text_model.position_ids = addon_position_ids + patched_text_model.rope_deltas = addon_rope_deltas + patched_logits = patched_model(tokens) + mx.eval(patched_logits) + patched_logits = mx.array(patched_logits) + del patched_model + mx.clear_cache() + + diff = max_abs_diff(patched_logits, vlm_logits) + assert diff == 0.0, ( + f"{model_name}: image prompt logits mismatch against mlx-vlm " + f"(max diff {diff:.6f})." + ) From fbaed17da3bd8aa94a0feb8a8e1622b24928d40c Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Wed, 8 Apr 2026 15:38:47 -0400 Subject: [PATCH 6/7] refactor tests --- mlx_engine/model_kit/patches/gemma4.py | 6 +- tests/patched_model_test_utils.py | 83 ++------------------------ tests/test_patched_gemma4.py | 50 +++++++++++++--- tests/test_patched_qwen3_5.py | 42 ++++++++++++- 4 files changed, 92 insertions(+), 89 deletions(-) diff --git a/mlx_engine/model_kit/patches/gemma4.py b/mlx_engine/model_kit/patches/gemma4.py index 67e8792c..46ae1f08 100644 --- a/mlx_engine/model_kit/patches/gemma4.py +++ b/mlx_engine/model_kit/patches/gemma4.py @@ -9,8 +9,12 @@ from mlx_lm.models.gemma4_text import Gemma4TextModel +# Stable alias to the pristine mlx-lm class captured before apply_patches() +# mutates mlx_lm.models.gemma4_text in place. +OriginalGemma4TextModel = Gemma4TextModel -class PatchedGemma4TextModel(Gemma4TextModel): + +class PatchedGemma4TextModel(OriginalGemma4TextModel): def __init__(self, config): super().__init__(config) self.prompt_per_layer_input_ids = None diff --git a/tests/patched_model_test_utils.py b/tests/patched_model_test_utils.py index f13be2f9..ce6524e8 100644 --- a/tests/patched_model_test_utils.py +++ b/tests/patched_model_test_utils.py @@ -10,36 +10,13 @@ import pytest import mlx.core as mx -import mlx_lm.models.gemma4_text as gemma4_text_module -import mlx_lm.models.qwen3_5 as qwen3_5_module import mlx_lm.utils from mlx_lm.models.cache import make_prompt_cache import mlx_engine.model_kit # noqa: F401 -from mlx_engine.model_kit.patches.gemma4 import PatchedGemma4TextModel -from mlx_engine.model_kit.patches.qwen3_5 import ( - OriginalDecoderLayer, - OriginalQwen3_5TextModel, -) from mlx_vlm.models.cache import make_prompt_cache as make_vlm_prompt_cache from mlx_vlm.utils import load_model as vlm_load_model, load_processor from tests.shared import model_getter -from transformers import AutoProcessor - -OriginalGemma4TextModel = PatchedGemma4TextModel.__mro__[1] - -REAL_MODEL_CASES = [ - pytest.param("lmstudio-community/Qwen3.5-2B-MLX-4bit", id="dense"), - pytest.param( - "lmstudio-community/Qwen3.5-35B-A3B-MLX-4bit", - marks=pytest.mark.heavy, - id="moe", - ), -] -GEMMA4_MODEL_NAME = "lmstudio-community/gemma-4-E2B-it-MLX-4bit" -GEMMA4_IMAGE_TOPK = 5 -GEMMA4_IMAGE_TOPK_PROB_RTOL = 0.25 -GEMMA4_IMAGE_TOPK_PROB_REF_FLOOR = 1e-3 def get_real_model_path(model_name: str) -> Path: @@ -78,9 +55,11 @@ def _assert_restorable_binding(original, current, label: str) -> None: ) -def _load_unpatched_mlx_lm( - model_path: Path, *, module, replacements: dict[str, object] -): +def assert_restorable_binding(original, current, label: str) -> None: + _assert_restorable_binding(original, current, label) + + +def load_unpatched_mlx_lm(model_path: Path, *, module, replacements: dict[str, object]): with _temporary_bindings(module, **replacements): return mlx_lm.utils.load(model_path) @@ -89,40 +68,6 @@ def load_patched_mlx_lm(model_path: Path): return mlx_lm.utils.load(model_path) -def load_unpatched_qwen_mlx_lm(model_path: Path): - _assert_restorable_binding( - OriginalDecoderLayer, - qwen3_5_module.DecoderLayer, - "qwen3.5 DecoderLayer", - ) - _assert_restorable_binding( - OriginalQwen3_5TextModel, - qwen3_5_module.Qwen3_5TextModel, - "qwen3.5 Qwen3_5TextModel", - ) - return _load_unpatched_mlx_lm( - model_path, - module=qwen3_5_module, - replacements={ - "DecoderLayer": OriginalDecoderLayer, - "Qwen3_5TextModel": OriginalQwen3_5TextModel, - }, - ) - - -def load_unpatched_gemma4_mlx_lm(model_path: Path): - _assert_restorable_binding( - OriginalGemma4TextModel, - gemma4_text_module.Gemma4TextModel, - "Gemma4TextModel", - ) - return _load_unpatched_mlx_lm( - model_path, - module=gemma4_text_module, - replacements={"Gemma4TextModel": OriginalGemma4TextModel}, - ) - - def load_vlm(model_path: Path): result = vlm_load_model(model_path) return result[0] if isinstance(result, tuple) else result @@ -132,24 +77,6 @@ def load_vlm_processor(model_path: Path): return load_processor(model_path, add_detokenizer=True) -def build_gemma4_prompt( - model_path: Path, - user_text: str, - *, - image_b64: str | None = None, -) -> str: - processor = AutoProcessor.from_pretrained(model_path) - content = [{"type": "text", "text": user_text}] - if image_b64 is not None: - content.insert(0, {"type": "image", "base64": image_b64}) - conversation = [{"role": "user", "content": content}] - return processor.apply_chat_template( - conversation, - tokenize=False, - add_generation_prompt=True, - ) - - def first_mlx_lm_generation_logits( model, prompt_tokens: mx.array, diff --git a/tests/test_patched_gemma4.py b/tests/test_patched_gemma4.py index 526f3594..74e66fd4 100644 --- a/tests/test_patched_gemma4.py +++ b/tests/test_patched_gemma4.py @@ -5,25 +5,23 @@ import pytest import mlx.core as mx +import mlx_lm.models.gemma4_text as gemma4_text_module from mlx_engine.generate import load_model +from mlx_engine.model_kit.patches.gemma4 import OriginalGemma4TextModel from mlx_engine.utils.image_utils import convert_to_pil from mlx_engine.utils.prompt_progress_reporter import DefaultPromptProgressReporter from mlx_vlm.utils import prepare_inputs from tests.patched_model_test_utils import ( - GEMMA4_IMAGE_TOPK, - GEMMA4_IMAGE_TOPK_PROB_REF_FLOOR, - GEMMA4_IMAGE_TOPK_PROB_RTOL, - GEMMA4_MODEL_NAME, - build_gemma4_prompt, + assert_restorable_binding, first_mlx_lm_generation_logits, first_vlm_generation_logits, format_token_values, gather_values, get_real_model_path, + load_unpatched_mlx_lm, load_patched_mlx_lm, - load_unpatched_gemma4_mlx_lm, load_vlm, load_vlm_processor, max_abs_diff, @@ -34,6 +32,45 @@ topk_token_ids, ) from tests.shared import read_image_b64 +from transformers import AutoProcessor + +pytestmark = pytest.mark.heavy + +GEMMA4_MODEL_NAME = "lmstudio-community/gemma-4-E2B-it-MLX-4bit" +GEMMA4_IMAGE_TOPK = 5 +GEMMA4_IMAGE_TOPK_PROB_RTOL = 0.25 +GEMMA4_IMAGE_TOPK_PROB_REF_FLOOR = 1e-3 + + +def build_gemma4_prompt( + model_path: Path, + user_text: str, + *, + image_b64: str | None = None, +) -> str: + processor = AutoProcessor.from_pretrained(model_path) + content = [{"type": "text", "text": user_text}] + if image_b64 is not None: + content.insert(0, {"type": "image", "base64": image_b64}) + conversation = [{"role": "user", "content": content}] + return processor.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=True, + ) + + +def load_unpatched_gemma4_mlx_lm(model_path: Path): + assert_restorable_binding( + OriginalGemma4TextModel, + gemma4_text_module.Gemma4TextModel, + "Gemma4TextModel", + ) + return load_unpatched_mlx_lm( + model_path, + module=gemma4_text_module, + replacements={"Gemma4TextModel": OriginalGemma4TextModel}, + ) def test_gemma4_text_only_generation_patched_matches_unpatched(): @@ -76,7 +113,6 @@ def test_gemma4_text_only_generation_patched_matches_unpatched(): ) -@pytest.mark.heavy def test_gemma4_image_prompt_unified_arch_top5_matches_vlm(): """Image+text Gemma 4 generation should stay close to native mlx-vlm.""" model_path = get_real_model_path(GEMMA4_MODEL_NAME) diff --git a/tests/test_patched_qwen3_5.py b/tests/test_patched_qwen3_5.py index 817c13e1..02ff5210 100644 --- a/tests/test_patched_qwen3_5.py +++ b/tests/test_patched_qwen3_5.py @@ -5,17 +5,32 @@ import mlx.core as mx from mlx_lm.generate import generate_step from mlx_lm.models.cache import ArraysCache, BatchKVCache, KVCache, make_prompt_cache +from mlx_lm.models.qwen3_5 import Model, ModelArgs +import mlx_lm.models.qwen3_5 as qwen3_5_module + from mlx_engine.model_kit.vision_add_ons.qwen3_5 import _compute_image_mrope_state +from mlx_engine.model_kit.patches.qwen3_5 import ( + OriginalDecoderLayer, + OriginalQwen3_5TextModel, +) from tests.patched_model_test_utils import ( - REAL_MODEL_CASES, + assert_restorable_binding, get_real_model_path, + load_unpatched_mlx_lm, load_patched_mlx_lm, - load_unpatched_qwen_mlx_lm, load_vlm, max_abs_diff, ) -from mlx_lm.models.qwen3_5 import Model, ModelArgs + +REAL_MODEL_CASES = [ + pytest.param("lmstudio-community/Qwen3.5-2B-MLX-4bit", id="dense"), + pytest.param( + "lmstudio-community/Qwen3.5-35B-A3B-MLX-4bit", + marks=pytest.mark.heavy, + id="moe", + ), +] QWEN3_5_TEXT_CONFIG = { "model_type": "qwen3_5", @@ -92,6 +107,27 @@ def make_batched_prompt_cache(model, left_padding): return cache +def load_unpatched_qwen_mlx_lm(model_path): + assert_restorable_binding( + OriginalDecoderLayer, + qwen3_5_module.DecoderLayer, + "qwen3.5 DecoderLayer", + ) + assert_restorable_binding( + OriginalQwen3_5TextModel, + qwen3_5_module.Qwen3_5TextModel, + "qwen3.5 Qwen3_5TextModel", + ) + return load_unpatched_mlx_lm( + model_path, + module=qwen3_5_module, + replacements={ + "DecoderLayer": OriginalDecoderLayer, + "Qwen3_5TextModel": OriginalQwen3_5TextModel, + }, + ) + + def _first_generate_step_logprobs( model, tokens: mx.array, From d1f3c64d8caecf5946beafc01efe43a9266085e2 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Wed, 8 Apr 2026 15:51:42 -0400 Subject: [PATCH 7/7] simplify --- tests/patched_model_test_utils.py | 130 +++--------------------------- tests/test_patched_gemma4.py | 123 ++++++++++++++++++++++++---- tests/test_patched_qwen3_5.py | 13 +-- 3 files changed, 118 insertions(+), 148 deletions(-) diff --git a/tests/patched_model_test_utils.py b/tests/patched_model_test_utils.py index ce6524e8..49b1d92e 100644 --- a/tests/patched_model_test_utils.py +++ b/tests/patched_model_test_utils.py @@ -6,7 +6,6 @@ from pathlib import Path from typing import Iterator -import numpy as np import pytest import mlx.core as mx @@ -14,8 +13,7 @@ from mlx_lm.models.cache import make_prompt_cache import mlx_engine.model_kit # noqa: F401 -from mlx_vlm.models.cache import make_prompt_cache as make_vlm_prompt_cache -from mlx_vlm.utils import load_model as vlm_load_model, load_processor +from mlx_vlm.utils import load_model as vlm_load_model from tests.shared import model_getter @@ -30,11 +28,6 @@ def max_abs_diff(actual: mx.array, reference: mx.array) -> float: return float(mx.max(mx.abs(actual - reference)).item()) -def tokenize_prompt(tokenizer, prompt: str) -> list[int]: - ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt)) - return [ids] if isinstance(ids, int) else ids - - @contextmanager def _temporary_bindings(module, **replacements) -> Iterator[None]: current_bindings = {name: getattr(module, name) for name in replacements} @@ -47,20 +40,16 @@ def _temporary_bindings(module, **replacements) -> Iterator[None]: setattr(module, name, current_binding) -def _assert_restorable_binding(original, current, label: str) -> None: - if original is current: - raise AssertionError( - f"Expected a pristine {label} reference captured before mlx-engine " - "patched mlx-lm." - ) - - -def assert_restorable_binding(original, current, label: str) -> None: - _assert_restorable_binding(original, current, label) - - -def load_unpatched_mlx_lm(model_path: Path, *, module, replacements: dict[str, object]): - with _temporary_bindings(module, **replacements): +def load_unpatched_mlx_lm( + model_path: Path, *, module, original_bindings: dict[str, object] +): + for name, original in original_bindings.items(): + if getattr(module, name) is original: + raise AssertionError( + f"Expected a pristine {module.__name__}.{name} reference captured " + "before mlx-engine patched mlx-lm." + ) + with _temporary_bindings(module, **original_bindings): return mlx_lm.utils.load(model_path) @@ -73,10 +62,6 @@ def load_vlm(model_path: Path): return result[0] if isinstance(result, tuple) else result -def load_vlm_processor(model_path: Path): - return load_processor(model_path, add_detokenizer=True) - - def first_mlx_lm_generation_logits( model, prompt_tokens: mx.array, @@ -109,96 +94,3 @@ def first_mlx_lm_generation_logits( logits = model(remaining_tokens[None], **kwargs) mx.eval(logits) return mx.array(logits[0, -1, :]) - - -def first_vlm_generation_logits( - model, - *, - input_ids: mx.array, - pixel_values: mx.array, - attention_mask: mx.array, - prefill_step_size: int = 2048, -) -> mx.array: - """Return the first-step logits from mlx-vlm's generation path.""" - prompt_cache = make_vlm_prompt_cache(model.language_model) - embedding_output = model.get_input_embeddings( - input_ids=input_ids, - pixel_values=pixel_values, - mask=attention_mask, - ) - inputs_embeds = embedding_output.inputs_embeds - kwargs = { - key: value - for key, value in embedding_output.to_dict().items() - if key != "inputs_embeds" and value is not None - } - - while inputs_embeds.shape[1] > 1: - n_to_process = min(prefill_step_size, inputs_embeds.shape[1] - 1) - if n_to_process <= 0: - break - model.language_model( - inputs=input_ids[:, :n_to_process], - inputs_embeds=inputs_embeds[:, :n_to_process], - cache=prompt_cache, - n_to_process=n_to_process, - **kwargs, - ) - mx.eval([cache.state for cache in prompt_cache]) - input_ids = input_ids[:, n_to_process:] - inputs_embeds = inputs_embeds[:, n_to_process:] - mx.clear_cache() - - outputs = model.language_model( - input_ids[:, -1:], - inputs_embeds=inputs_embeds[:, -1:], - cache=prompt_cache, - **kwargs, - ) - mx.eval(outputs.logits) - return mx.array(outputs.logits[0, -1, :]) - - -def topk_token_ids(logits: mx.array, k: int) -> list[int]: - values = np.array(logits.tolist(), dtype=np.float32) - return [int(index) for index in np.argsort(values)[-k:][::-1]] - - -def gather_values(values: mx.array, token_ids: list[int]) -> list[float]: - return [float(values[token_id].item()) for token_id in token_ids] - - -def softmax_probabilities(logits: mx.array) -> mx.array: - return mx.softmax(logits.astype(mx.float32), axis=-1) - - -def relative_differences( - actual_values: list[float], - reference_values: list[float], - reference_floor: float, -) -> list[float]: - diffs = [] - for actual, reference in zip(actual_values, reference_values): - scale = max(abs(reference), reference_floor) - diffs.append(abs(actual - reference) / scale) - return diffs - - -def format_token_values(token_ids: list[int], values: list[float], tokenizer) -> str: - parts = [] - for token_id, value in zip(token_ids, values): - parts.append(f"{token_id}:{tokenizer.decode([token_id])!r}:{value:.6f}") - return "[" + ", ".join(parts) + "]" - - -def resolve_image_token_index(config) -> int | None: - vision_config = getattr(config, "vision_config", None) - return getattr( - config, - "image_token_index", - getattr( - config, - "image_token_id", - getattr(vision_config, "image_token_id", None), - ), - ) diff --git a/tests/test_patched_gemma4.py b/tests/test_patched_gemma4.py index 74e66fd4..4ec7d0b6 100644 --- a/tests/test_patched_gemma4.py +++ b/tests/test_patched_gemma4.py @@ -2,6 +2,7 @@ from pathlib import Path +import numpy as np import pytest import mlx.core as mx @@ -11,25 +12,16 @@ from mlx_engine.model_kit.patches.gemma4 import OriginalGemma4TextModel from mlx_engine.utils.image_utils import convert_to_pil from mlx_engine.utils.prompt_progress_reporter import DefaultPromptProgressReporter -from mlx_vlm.utils import prepare_inputs +from mlx_vlm.models.cache import make_prompt_cache as make_vlm_prompt_cache +from mlx_vlm.utils import load_processor, prepare_inputs from tests.patched_model_test_utils import ( - assert_restorable_binding, first_mlx_lm_generation_logits, - first_vlm_generation_logits, - format_token_values, - gather_values, get_real_model_path, load_unpatched_mlx_lm, load_patched_mlx_lm, load_vlm, - load_vlm_processor, max_abs_diff, - relative_differences, - resolve_image_token_index, - softmax_probabilities, - tokenize_prompt, - topk_token_ids, ) from tests.shared import read_image_b64 from transformers import AutoProcessor @@ -42,6 +34,11 @@ GEMMA4_IMAGE_TOPK_PROB_REF_FLOOR = 1e-3 +def tokenize_prompt(tokenizer, prompt: str) -> list[int]: + ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt)) + return [ids] if isinstance(ids, int) else ids + + def build_gemma4_prompt( model_path: Path, user_text: str, @@ -61,15 +58,107 @@ def build_gemma4_prompt( def load_unpatched_gemma4_mlx_lm(model_path: Path): - assert_restorable_binding( - OriginalGemma4TextModel, - gemma4_text_module.Gemma4TextModel, - "Gemma4TextModel", - ) return load_unpatched_mlx_lm( model_path, module=gemma4_text_module, - replacements={"Gemma4TextModel": OriginalGemma4TextModel}, + original_bindings={"Gemma4TextModel": OriginalGemma4TextModel}, + ) + + +def load_vlm_processor(model_path: Path): + return load_processor(model_path, add_detokenizer=True) + + +def first_vlm_generation_logits( + model, + *, + input_ids: mx.array, + pixel_values: mx.array, + attention_mask: mx.array, + prefill_step_size: int = 2048, +) -> mx.array: + """Return the first-step logits from mlx-vlm's generation path.""" + prompt_cache = make_vlm_prompt_cache(model.language_model) + embedding_output = model.get_input_embeddings( + input_ids=input_ids, + pixel_values=pixel_values, + mask=attention_mask, + ) + inputs_embeds = embedding_output.inputs_embeds + kwargs = { + key: value + for key, value in embedding_output.to_dict().items() + if key != "inputs_embeds" and value is not None + } + + while inputs_embeds.shape[1] > 1: + n_to_process = min(prefill_step_size, inputs_embeds.shape[1] - 1) + if n_to_process <= 0: + break + model.language_model( + inputs=input_ids[:, :n_to_process], + inputs_embeds=inputs_embeds[:, :n_to_process], + cache=prompt_cache, + n_to_process=n_to_process, + **kwargs, + ) + mx.eval([cache.state for cache in prompt_cache]) + input_ids = input_ids[:, n_to_process:] + inputs_embeds = inputs_embeds[:, n_to_process:] + mx.clear_cache() + + outputs = model.language_model( + input_ids[:, -1:], + inputs_embeds=inputs_embeds[:, -1:], + cache=prompt_cache, + **kwargs, + ) + mx.eval(outputs.logits) + return mx.array(outputs.logits[0, -1, :]) + + +def topk_token_ids(logits: mx.array, k: int) -> list[int]: + values = np.array(logits.tolist(), dtype=np.float32) + return [int(index) for index in np.argsort(values)[-k:][::-1]] + + +def gather_values(values: mx.array, token_ids: list[int]) -> list[float]: + return [float(values[token_id].item()) for token_id in token_ids] + + +def softmax_probabilities(logits: mx.array) -> mx.array: + return mx.softmax(logits.astype(mx.float32), axis=-1) + + +def relative_differences( + actual_values: list[float], + reference_values: list[float], + reference_floor: float, +) -> list[float]: + diffs = [] + for actual, reference in zip(actual_values, reference_values): + scale = max(abs(reference), reference_floor) + diffs.append(abs(actual - reference) / scale) + return diffs + + +def format_token_values(token_ids: list[int], values: list[float], tokenizer) -> str: + parts = [] + for token_id, value in zip(token_ids, values): + parts.append(f"{token_id}:{tokenizer.decode([token_id])!r}:{value:.6f}") + return "[" + ", ".join(parts) + "]" + + +def resolve_image_token_index(config) -> int | None: + vision_config = getattr(config, "vision_config", None) + return getattr( + config, + "image_token_index", + getattr( + config, + "image_token_id", + getattr(vision_config, "image_token_id", None), + ), ) diff --git a/tests/test_patched_qwen3_5.py b/tests/test_patched_qwen3_5.py index 02ff5210..f840135d 100644 --- a/tests/test_patched_qwen3_5.py +++ b/tests/test_patched_qwen3_5.py @@ -15,7 +15,6 @@ ) from tests.patched_model_test_utils import ( - assert_restorable_binding, get_real_model_path, load_unpatched_mlx_lm, load_patched_mlx_lm, @@ -108,20 +107,10 @@ def make_batched_prompt_cache(model, left_padding): def load_unpatched_qwen_mlx_lm(model_path): - assert_restorable_binding( - OriginalDecoderLayer, - qwen3_5_module.DecoderLayer, - "qwen3.5 DecoderLayer", - ) - assert_restorable_binding( - OriginalQwen3_5TextModel, - qwen3_5_module.Qwen3_5TextModel, - "qwen3.5 Qwen3_5TextModel", - ) return load_unpatched_mlx_lm( model_path, module=qwen3_5_module, - replacements={ + original_bindings={ "DecoderLayer": OriginalDecoderLayer, "Qwen3_5TextModel": OriginalQwen3_5TextModel, },