diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index 032b43f2..1257d531 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -75,6 +75,9 @@ def get_input_embeddings( # Get the input embeddings from the language model inputs_embeds = self.language_model.model.embed_tokens(input_ids) + # Get number of images + num_images = len(pixel_values[0]) + # Get the ouptut hidden states from the vision model if isinstance(pixel_values, list): if input_ids.shape[0] == 1: # Batch size is 1 @@ -88,8 +91,13 @@ def get_input_embeddings( if pixel_values.ndim == 3: pixel_values = pixel_values[None, ...] + pixel_values = mx.split(pixel_values, num_images, axis=2) + + # pass pixel_values as list of images, as each image is individually run through conv2d and position encoding + # reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 + # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 *_, hidden_states = self.vision_tower( - pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + [pv.transpose(0, 2, 3, 1) for pv in pixel_values], output_hidden_states=True ) # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] @@ -119,7 +127,10 @@ def _merge_input_ids_with_image_features( text_segments.append(inputs_embeds[:, start_idx:position]) start_idx = position + 1 - image_embeddings = mx.split(image_features, image_features.shape[0]) + # [IMG_BREAK] and [IMG_END] are missing with existing implementation + # image_embeddings = mx.split(image_features, image_features.shape[0]) + + image_embeddings = mx.split(image_features, num_image_patches, axis=1) final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] final_embeddings += [inputs_embeds[:, start_idx:]] diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py index 2db77f4d..ce5bce8c 100644 --- a/mlx_vlm/models/pixtral/vision.py +++ b/mlx_vlm/models/pixtral/vision.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Optional +from typing import Optional, List import mlx.core as mx import mlx.nn as nn @@ -253,11 +253,11 @@ def __init__(self, config: VisionConfig): def __call__( self, - x: mx.array, + x: List[mx.array], output_hidden_states: Optional[bool] = None, ) -> mx.array: - B, H, W, C = x.shape - patch_embeds_list = [self.patch_conv(img[None, :]) for img in x] + B, H, W, C = x[0].shape + patch_embeds_list = [self.patch_conv(img) for img in x] patch_embeds = mx.concatenate( [p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1 @@ -299,7 +299,7 @@ def __init__(self, config: VisionConfig): self.vision_model = PixtralVisionModel(config) def __call__( - self, x: mx.array, output_hidden_states: Optional[bool] = None + self, x: List[mx.array], output_hidden_states: Optional[bool] = None ) -> mx.array: return self.vision_model(x, output_hidden_states)