Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions mlx_vlm/models/pixtral/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ def get_input_embeddings(
# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)

# Get number of images
num_images = len(pixel_values[0])

        # Get the output hidden states from the vision model
if isinstance(pixel_values, list):
if input_ids.shape[0] == 1: # Batch size is 1
Expand All @@ -88,8 +91,13 @@ def get_input_embeddings(
if pixel_values.ndim == 3:
pixel_values = pixel_values[None, ...]

pixel_values = mx.split(pixel_values, num_images, axis=2)

# pass pixel_values as list of images, as each image is individually run through conv2d and position encoding
# reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21
# and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85
*_, hidden_states = self.vision_tower(
pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
[pv.transpose(0, 2, 3, 1) for pv in pixel_values], output_hidden_states=True
)
# Select the hidden states from the desired layer
selected_image_feature = hidden_states[self.vision_feature_layer]
Expand Down Expand Up @@ -119,7 +127,10 @@ def _merge_input_ids_with_image_features(
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1

image_embeddings = mx.split(image_features, image_features.shape[0])
# [IMG_BREAK] and [IMG_END] are missing with existing implementation
# image_embeddings = mx.split(image_features, image_features.shape[0])

image_embeddings = mx.split(image_features, num_image_patches, axis=1)
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]

Expand Down
10 changes: 5 additions & 5 deletions mlx_vlm/models/pixtral/vision.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from dataclasses import dataclass
from typing import Optional
from typing import Optional, List

import mlx.core as mx
import mlx.nn as nn
Expand Down Expand Up @@ -253,11 +253,11 @@ def __init__(self, config: VisionConfig):

def __call__(
self,
x: mx.array,
x: List[mx.array],
output_hidden_states: Optional[bool] = None,
) -> mx.array:
B, H, W, C = x.shape
patch_embeds_list = [self.patch_conv(img[None, :]) for img in x]
B, H, W, C = x[0].shape
patch_embeds_list = [self.patch_conv(img) for img in x]

patch_embeds = mx.concatenate(
[p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1
Expand Down Expand Up @@ -299,7 +299,7 @@ def __init__(self, config: VisionConfig):
self.vision_model = PixtralVisionModel(config)

def __call__(
self, x: mx.array, output_hidden_states: Optional[bool] = None
self, x: List[mx.array], output_hidden_states: Optional[bool] = None
) -> mx.array:
return self.vision_model(x, output_hidden_states)

Expand Down