diff --git a/sgl-router/scripts/generate_vision_golden.py b/sgl-router/scripts/generate_vision_golden.py index ae282d2ab8b3..37965f47477f 100755 --- a/sgl-router/scripts/generate_vision_golden.py +++ b/sgl-router/scripts/generate_vision_golden.py @@ -66,6 +66,11 @@ "processor_class": "Llama4ImageProcessorFast", "description": "Tile-based processing with 336x336 tiles and global tile", }, + "pixtral": { + "model_id": "mistralai/Pixtral-12B-2409", + "processor_class": "PixtralImageProcessor", + "description": "Dynamic resolution with CLIP normalization and bicubic resize", + }, } # Default test images @@ -547,6 +552,64 @@ def generate_golden_llama4_vision(image_path: str, output_dir: str) -> dict: return result +def generate_golden_pixtral(image_path: str, output_dir: str) -> dict: + """Generate golden output for Pixtral/Mistral3 Vision. + + Pixtral uses dynamic resolution processing: + 1. If image exceeds longest_edge (default 1024), scale down proportionally + 2. Resize to dimensions that are multiples of patch_size (default 16) + 3. Use bicubic interpolation for resize + 4. Normalize with CLIP mean/std + + Output: + - pixel_values: [1, 3, H, W] where H, W are multiples of patch_size + - image_sizes: [(H, W)] + + Token count: (H / patch_size) * (W / patch_size) + """ + from transformers import PixtralImageProcessor + + processor = PixtralImageProcessor.from_pretrained("mistral-community/pixtral-12b") + image = Image.open(image_path).convert("RGB") + original_size = image.size + + # Process image + outputs = processor(images=image, return_tensors="np") + pixel_values = outputs["pixel_values"] + image_sizes = outputs.get("image_sizes") + + result = { + "pixel_values": pixel_values, + "original_size": original_size, + "processor_config": processor.to_dict(), + } + + if image_sizes is not None: + result["image_sizes"] = np.array(image_sizes) + + # Calculate num_tokens from image_sizes + if image_sizes is not None: + h, w = image_sizes[0] + patch_size = getattr(processor, "patch_size", {"height": 16, "width": 16}) + if isinstance(patch_size, dict): + patch_h = patch_size.get("height", 16) + patch_w = patch_size.get("width", 16) + else: + patch_h = patch_w = patch_size + num_tokens = (h // patch_h) * (w // patch_w) + result["num_tokens"] = num_tokens + + # Add debug info + result["config_info"] = { + "longest_edge": processor.size.get("longest_edge", 1024), + "patch_size": processor.patch_size, + "image_mean": processor.image_mean, + "image_std": processor.image_std, + } + + return result + + def generate_for_model(model_key: str, image_paths: list, output_dir: str): """Generate golden outputs for a specific model.""" print(f"\nGenerating golden outputs for {model_key}...") @@ -560,6 +623,7 @@ def generate_for_model(model_key: str, image_paths: list, output_dir: str): "phi3_vision": generate_golden_phi3_vision, "phi4_vision": generate_golden_phi4_vision, "llama4_vision": generate_golden_llama4_vision, + "pixtral": generate_golden_pixtral, }.get(model_key) if generator_fn is None: diff --git a/sgl-router/src/multimodal/vision/mod.rs b/sgl-router/src/multimodal/vision/mod.rs index 189a0c0f45e0..a46509a38b57 100644 --- a/sgl-router/src/multimodal/vision/mod.rs +++ b/sgl-router/src/multimodal/vision/mod.rs @@ -41,6 +41,6 @@ pub use image_processor::{ pub use preprocessor_config::PreProcessorConfig; pub use processors::{ Llama4VisionProcessor, LlavaNextProcessor, LlavaProcessor, Phi3VisionProcessor, - Phi4VisionProcessor, Qwen2VLProcessor, Qwen3VLProcessor, + Phi4VisionProcessor, PixtralProcessor, Qwen2VLProcessor, Qwen3VLProcessor, }; pub use transforms::TransformError; diff --git a/sgl-router/src/multimodal/vision/preprocessor_config.rs b/sgl-router/src/multimodal/vision/preprocessor_config.rs index 4c5a52e1134a..d237ebf2160d 100644 --- a/sgl-router/src/multimodal/vision/preprocessor_config.rs +++ b/sgl-router/src/multimodal/vision/preprocessor_config.rs @@ -6,10 +6,97 @@ use std::collections::HashMap; use image::imageops::FilterType; -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; use super::transforms; +/// Struct to represent patch_size as dict {"height": x, "width": y} +#[derive(Debug, Clone, Deserialize, Default)] +pub struct PatchSize { + pub height: Option, + pub width: Option, +} + +/// Custom deserializer for patch_size that handles both integer and dict formats. +/// - Integer format: `"patch_size": 16` -> PatchSize { height: 16, width: 16 } +/// - Dict format: `"patch_size": {"height": 16, "width": 16}` -> PatchSize { height: 16, width: 16 } +fn deserialize_patch_size<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + use std::fmt; + + use serde::de::{self, MapAccess, Visitor}; + + struct PatchSizeVisitor; + + impl<'de> Visitor<'de> for PatchSizeVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an integer, a dict with height/width, or null") + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_i64(self, value: i64) -> Result + where + E: de::Error, + { + let v = value as u32; + Ok(Some(PatchSize { + height: Some(v), + width: Some(v), + })) + } + + fn visit_u64(self, value: u64) -> Result + where + E: de::Error, + { + let v = value as u32; + Ok(Some(PatchSize { + height: Some(v), + width: Some(v), + })) + } + + fn visit_map(self, mut map: M) -> Result + where + M: MapAccess<'de>, + { + let mut height = None; + let mut width = None; + + while let Some(key) = map.next_key::()? { + match key.as_str() { + "height" => height = Some(map.next_value::()?), + "width" => width = Some(map.next_value::()?), + _ => { + let _ = map.next_value::()?; + } + } + } + + Ok(Some(PatchSize { height, width })) + } + } + + deserializer.deserialize_any(PatchSizeVisitor) +} + /// HuggingFace preprocessor_config.json structure. /// /// This struct captures the common fields across different vision model processors. @@ -73,8 +160,9 @@ pub struct PreProcessorConfig { // Model-specific fields // ===================== /// Vision encoder patch size (typically 14 or 16) - #[serde(default)] - pub patch_size: Option, + /// Can be an integer or a dict {"height": x, "width": y} + #[serde(default, deserialize_with = "deserialize_patch_size")] + pub patch_size: Option, /// Qwen-VL: merge size for token reduction #[serde(default)] @@ -151,6 +239,17 @@ impl PreProcessorConfig { serde_json::from_value(value) } + /// Get patch size as a simple usize. + /// + /// Returns the height value from PatchSize if available, falling back to provided default. + pub fn get_patch_size(&self, default: usize) -> usize { + self.patch_size + .as_ref() + .and_then(|p| p.height) + .map(|h| h as usize) + .unwrap_or(default) + } + /// Get image mean as fixed array, with fallback to CLIP defaults. pub fn get_image_mean(&self) -> [f64; 3] { self.image_mean @@ -310,7 +409,7 @@ mod tests { assert_eq!(config.min_pixels, Some(200704)); assert_eq!(config.max_pixels, Some(1003520)); - assert_eq!(config.patch_size, Some(14)); + assert_eq!(config.get_patch_size(0), 14); assert_eq!(config.merge_size, Some(2)); assert!((config.get_rescale_factor() - 1.0 / 255.0).abs() < 1e-10); } diff --git a/sgl-router/src/multimodal/vision/processors/llava.rs b/sgl-router/src/multimodal/vision/processors/llava.rs index 0574e7a06cea..ac9e0dd5180e 100644 --- a/sgl-router/src/multimodal/vision/processors/llava.rs +++ b/sgl-router/src/multimodal/vision/processors/llava.rs @@ -308,7 +308,7 @@ impl ImagePreProcessor for LlavaProcessor { config: &PreProcessorConfig, ) -> usize { // For LLaVA 1.5, token count is based on processed image size and patch size - let patch_size = config.patch_size.unwrap_or(self.patch_size as usize) as u32; + let patch_size = config.get_patch_size(self.patch_size as usize) as u32; let image_size = config .get_target_size() .map(|(h, _w)| h) diff --git a/sgl-router/src/multimodal/vision/processors/mod.rs b/sgl-router/src/multimodal/vision/processors/mod.rs index ba0b95e15ab7..4b2c48d2c04a 100644 --- a/sgl-router/src/multimodal/vision/processors/mod.rs +++ b/sgl-router/src/multimodal/vision/processors/mod.rs @@ -13,11 +13,13 @@ //! - **Phi3-Vision** (`phi3_vision`): Dynamic HD transform with 336x336 tiles //! - **Phi4-Vision** (`phi4_vision`): Dynamic HD transform with 448x448 tiles and SiGLIP encoder //! - **LLaMA 4 Vision** (`llama4_vision`): Tile-based processing with 336x336 tiles and global tile +//! - **Pixtral/Mistral3** (`pixtral`): CLIP-based preprocessing with dynamic resolution pub mod llama4_vision; pub mod llava; pub mod phi3_vision; pub mod phi4_vision; +pub mod pixtral; pub mod qwen2_vl; pub mod qwen3_vl; pub mod qwen_vl_base; @@ -26,5 +28,6 @@ pub use llama4_vision::Llama4VisionProcessor; pub use llava::{ImageAspectRatio, LlavaNextProcessor, LlavaProcessor}; pub use phi3_vision::Phi3VisionProcessor; pub use phi4_vision::Phi4VisionProcessor; +pub use pixtral::PixtralProcessor; pub use qwen2_vl::Qwen2VLProcessor; pub use qwen3_vl::Qwen3VLProcessor; diff --git a/sgl-router/src/multimodal/vision/processors/pixtral.rs b/sgl-router/src/multimodal/vision/processors/pixtral.rs new file mode 100644 index 000000000000..dd93d85e4206 --- /dev/null +++ b/sgl-router/src/multimodal/vision/processors/pixtral.rs @@ -0,0 +1,411 @@ +//! Pixtral/Mistral3 Vision image processor implementation. +//! +//! This module implements the image preprocessing for Pixtral/Mistral3 models, +//! matching the behavior of HuggingFace's `PixtralImageProcessor`. +//! +//! Key characteristics: +//! - CLIP normalization: mean [0.48145466, 0.4578275, 0.40821073], std [0.26862954, 0.26130258, 0.27577711] +//! - Bicubic resampling for resize +//! - Images resized to fit within longest_edge (default 1024) +//! - Output dimensions are multiples of patch_size (default 16) +//! - No tiling - single image output per input + +use std::collections::HashMap; + +use image::{imageops::FilterType, DynamicImage}; +use ndarray::{Array4, IxDyn}; + +use crate::multimodal::vision::{ + image_processor::{ImagePreProcessor, ModelSpecificValue, PreprocessedImages}, + preprocessor_config::PreProcessorConfig, + transforms::{self, TransformError}, +}; + +/// Default normalization mean values (CLIP) +const DEFAULT_IMAGE_MEAN: [f64; 3] = [0.48145466, 0.4578275, 0.40821073]; + +/// Default normalization std values (CLIP) +const DEFAULT_IMAGE_STD: [f64; 3] = [0.26862954, 0.26130258, 0.27577711]; + +/// Default longest edge for resize +const DEFAULT_LONGEST_EDGE: u32 = 1024; + +/// Default patch size +const DEFAULT_PATCH_SIZE: u32 = 16; + +/// Pixtral/Mistral3 Vision image processor. +/// +/// This processor handles image preprocessing for Pixtral and Mistral3 vision models. +/// Unlike tile-based processors (Phi3, LLaMA4), Pixtral processes images at their +/// natural resolution (up to a maximum), preserving aspect ratio. +#[derive(Debug, Clone)] +pub struct PixtralProcessor { + /// Maximum dimension for the longest edge + longest_edge: u32, + /// Patch size for calculating output dimensions + patch_size: u32, + /// Normalization mean values + image_mean: [f64; 3], + /// Normalization std values + image_std: [f64; 3], +} + +impl Default for PixtralProcessor { + fn default() -> Self { + Self::new() + } +} + +impl PixtralProcessor { + /// Creates a new Pixtral processor with default settings. + pub fn new() -> Self { + Self { + longest_edge: DEFAULT_LONGEST_EDGE, + patch_size: DEFAULT_PATCH_SIZE, + image_mean: DEFAULT_IMAGE_MEAN, + image_std: DEFAULT_IMAGE_STD, + } + } + + /// Creates a processor from a HuggingFace preprocessor config. + pub fn from_preprocessor_config(config: &PreProcessorConfig) -> Self { + let longest_edge = config + .size + .as_ref() + .and_then(|s| s.get("longest_edge").copied()) + .unwrap_or(DEFAULT_LONGEST_EDGE); + + // Patch size uses the new PatchSize type from config + let patch_size = config.get_patch_size(DEFAULT_PATCH_SIZE as usize) as u32; + + let image_mean = config + .image_mean + .as_ref() + .filter(|m| m.len() >= 3) + .map(|m| [m[0], m[1], m[2]]) + .unwrap_or(DEFAULT_IMAGE_MEAN); + + let image_std = config + .image_std + .as_ref() + .filter(|s| s.len() >= 3) + .map(|s| [s[0], s[1], s[2]]) + .unwrap_or(DEFAULT_IMAGE_STD); + + Self { + longest_edge, + patch_size, + image_mean, + image_std, + } + } + + /// Calculates the target output size for an image. + /// + /// The image is resized to fit within `longest_edge` while preserving aspect ratio. + /// The output dimensions are then adjusted to be multiples of `patch_size`. + fn get_resize_output_size(&self, height: u32, width: u32) -> (u32, u32) { + let max_size = self.longest_edge; + let patch_size = self.patch_size; + + // Calculate ratio for scaling down (only if larger than max_size) + let ratio = f64::max( + height as f64 / max_size as f64, + width as f64 / max_size as f64, + ); + + let (new_height, new_width) = if ratio > 1.0 { + // Scale down using floor to ensure we don't exceed max_size + let new_height = (height as f64 / ratio).floor() as u32; + let new_width = (width as f64 / ratio).floor() as u32; + (new_height, new_width) + } else { + (height, width) + }; + + // Calculate number of patches in each dimension + // Using: num_tokens = (dim - 1) / patch_size + 1 (i.e., ceiling division) + let num_height_tokens = (new_height.max(1) - 1) / patch_size + 1; + let num_width_tokens = (new_width.max(1) - 1) / patch_size + 1; + + // Final size is patches * patch_size + ( + num_height_tokens * patch_size, + num_width_tokens * patch_size, + ) + } + + /// Processes a single image through the Pixtral pipeline. + fn process_single_image( + &self, + image: &DynamicImage, + ) -> Result<(Array4, (usize, usize)), TransformError> { + let (orig_width, orig_height) = (image.width(), image.height()); + + // Step 1: Calculate output size + let (target_h, target_w) = self.get_resize_output_size(orig_height, orig_width); + + // Step 2: Resize image using bicubic interpolation + let resized = image.resize_exact(target_w, target_h, FilterType::CatmullRom); + + // Step 3: Convert to tensor (0-1 range) and normalize + let mut tensor = transforms::to_tensor(&resized); + transforms::normalize(&mut tensor, &self.image_mean, &self.image_std); + + // Step 4: Reshape to (1, C, H, W) + let (c, h, w) = (tensor.shape()[0], tensor.shape()[1], tensor.shape()[2]); + let output = tensor + .into_shape_with_order((1, c, h, w)) + .map_err(|e| TransformError::ShapeError(e.to_string()))?; + + Ok((output, (target_h as usize, target_w as usize))) + } +} + +impl ImagePreProcessor for PixtralProcessor { + fn default_mean(&self) -> [f64; 3] { + self.image_mean + } + + fn default_std(&self) -> [f64; 3] { + self.image_std + } + + fn preprocess( + &self, + images: &[DynamicImage], + config: &PreProcessorConfig, + ) -> Result { + if images.is_empty() { + return Err(TransformError::InvalidShape { + expected: "non-empty image batch".to_string(), + actual: vec![0], + }); + } + + // Apply config overrides if present + let processor = if config.size.is_some() + || config.patch_size.is_some() + || config.image_mean.is_some() + || config.image_std.is_some() + { + Self::from_preprocessor_config(config) + } else { + self.clone() + }; + + let mut all_pixel_values = Vec::new(); + let mut all_image_sizes = Vec::new(); + let mut original_sizes = Vec::new(); + let mut num_img_tokens = Vec::new(); + + for image in images { + let (pixels, size) = processor.process_single_image(image)?; + let tokens = processor.calculate_num_tokens(image.width(), image.height(), config); + + all_pixel_values.push(pixels); + all_image_sizes.push(size); + original_sizes.push((image.height(), image.width())); + num_img_tokens.push(tokens); + } + + // Pad images to the same size for batching + let max_height = all_image_sizes.iter().map(|(h, _)| *h).max().unwrap_or(0); + let max_width = all_image_sizes.iter().map(|(_, w)| *w).max().unwrap_or(0); + + // Create batch tensor with padding + let batch_size = all_pixel_values.len(); + let channels = 3; + let mut batch_tensor = + ndarray::ArrayD::::zeros(IxDyn(&[batch_size, channels, max_height, max_width])); + + for (i, (pixels, (h, w))) in all_pixel_values + .iter() + .zip(all_image_sizes.iter()) + .enumerate() + { + // Copy the image data into the batch (top-left aligned, zero-padded) + for c in 0..channels { + for y in 0..*h { + for x in 0..*w { + batch_tensor[[i, c, y, x]] = pixels[[0, c, y, x]]; + } + } + } + } + + // Store image sizes as model-specific data + let mut model_specific = HashMap::new(); + let image_sizes_flat: Vec = all_image_sizes + .iter() + .flat_map(|&(h, w)| vec![h as i64, w as i64]) + .collect(); + model_specific.insert( + "image_sizes".to_string(), + ModelSpecificValue::IntTensor { + data: image_sizes_flat, + shape: vec![batch_size, 2], + }, + ); + + Ok(PreprocessedImages { + pixel_values: batch_tensor, + num_img_tokens, + image_sizes: original_sizes, + model_specific, + }) + } + + fn calculate_num_tokens(&self, width: u32, height: u32, config: &PreProcessorConfig) -> usize { + let processor = Self::from_preprocessor_config(config); + let (target_h, target_w) = processor.get_resize_output_size(height, width); + let patch_size = processor.patch_size; + + // Number of tokens = num_patches_h * num_patches_w + let num_patches_h = target_h / patch_size; + let num_patches_w = target_w / patch_size; + (num_patches_h * num_patches_w) as usize + } + + fn model_name(&self) -> &'static str { + "pixtral" + } + + fn get_processed_size(&self, _config: &PreProcessorConfig) -> Option<(u32, u32)> { + // Pixtral has dynamic size based on input + None + } +} + +#[cfg(test)] +mod tests { + use image::{Rgb, RgbImage}; + + use super::*; + + fn create_test_image(width: u32, height: u32) -> DynamicImage { + let mut img = RgbImage::new(width, height); + for y in 0..height { + for x in 0..width { + let r = ((x * 255) / width.max(1)) as u8; + let g = ((y * 255) / height.max(1)) as u8; + let b = (((x + y) * 128) / (width + height).max(1)) as u8; + img.put_pixel(x, y, Rgb([r, g, b])); + } + } + DynamicImage::ImageRgb8(img) + } + + #[test] + fn test_resize_output_size_small_image() { + let processor = PixtralProcessor::new(); + + // Small image that doesn't need resizing - just pad to patch boundary + // 100x100 -> patches: ceil(100/16) = 7, output: 7*16 = 112 + let (h, w) = processor.get_resize_output_size(100, 100); + assert_eq!((h, w), (112, 112)); + } + + #[test] + fn test_resize_output_size_large_image() { + let processor = PixtralProcessor::new(); + + // Large image that needs resizing + // 2048x1024: ratio = 2048/1024 = 2.0 + // scaled: 2048/2 = 1024, 1024/2 = 512 + // patches h: ceil(1024/16) = 64, patches w: ceil(512/16) = 32 + // output: 64*16 = 1024, 32*16 = 512 + let (h, w) = processor.get_resize_output_size(2048, 1024); + assert_eq!((h, w), (1024, 512)); + } + + #[test] + fn test_resize_output_size_at_limit() { + let processor = PixtralProcessor::new(); + + // Image exactly at limit + // 1024x768: ratio = max(1024/1024, 768/1024) = 1.0 + // No resize needed + // patches h: ceil(1024/16) = 64, patches w: ceil(768/16) = 48 + // output: 64*16 = 1024, 48*16 = 768 + let (h, w) = processor.get_resize_output_size(1024, 768); + assert_eq!((h, w), (1024, 768)); + } + + #[test] + fn test_process_single_image() { + let processor = PixtralProcessor::new(); + let image = create_test_image(200, 150); + + let (tensor, size) = processor.process_single_image(&image).unwrap(); + + // 200x150 -> patches h: ceil(150/16) = 10, patches w: ceil(200/16) = 13 + // output: 10*16 = 160, 13*16 = 208 + assert_eq!(size, (160, 208)); + assert_eq!(tensor.shape(), &[1, 3, 160, 208]); + } + + #[test] + fn test_preprocess_batch() { + let processor = PixtralProcessor::new(); + let config = PreProcessorConfig::default(); + + let images = vec![create_test_image(200, 150), create_test_image(300, 100)]; + + let result = processor.preprocess(&images, &config).unwrap(); + + // First image: 150x200 -> 160x208 + // Second image: 100x300 -> 112x304 (ceil(100/16)=7, ceil(300/16)=19) + // Batch padded to max: 160x304 + assert_eq!(result.pixel_values.shape()[0], 2); // batch size + assert_eq!(result.pixel_values.shape()[1], 3); // channels + } + + #[test] + fn test_normalization_values() { + let processor = PixtralProcessor::new(); + + // Verify CLIP normalization values + assert!((processor.image_mean[0] - 0.48145466).abs() < 1e-6); + assert!((processor.image_mean[1] - 0.4578275).abs() < 1e-6); + assert!((processor.image_mean[2] - 0.40821073).abs() < 1e-6); + + assert!((processor.image_std[0] - 0.26862954).abs() < 1e-6); + assert!((processor.image_std[1] - 0.26130258).abs() < 1e-6); + assert!((processor.image_std[2] - 0.27577711).abs() < 1e-6); + } + + #[test] + fn test_from_config() { + let mut size = HashMap::new(); + size.insert("longest_edge".to_string(), 2048u32); + + let config = PreProcessorConfig { + size: Some(size), + patch_size: Some(crate::multimodal::vision::preprocessor_config::PatchSize { + height: Some(14), + width: Some(14), + }), + image_mean: Some(vec![0.5, 0.5, 0.5]), + image_std: Some(vec![0.5, 0.5, 0.5]), + ..Default::default() + }; + + let processor = PixtralProcessor::from_preprocessor_config(&config); + + assert_eq!(processor.longest_edge, 2048); + assert_eq!(processor.patch_size, 14); + assert_eq!(processor.image_mean, [0.5, 0.5, 0.5]); + assert_eq!(processor.image_std, [0.5, 0.5, 0.5]); + } + + #[test] + fn test_calculate_num_tokens() { + let processor = PixtralProcessor::new(); + let config = PreProcessorConfig::default(); + + // 200x150 -> 208x160 -> 13*10 = 130 patches + let tokens = processor.calculate_num_tokens(200, 150, &config); + assert_eq!(tokens, 130); + } +} diff --git a/sgl-router/src/multimodal/vision/processors/qwen2_vl.rs b/sgl-router/src/multimodal/vision/processors/qwen2_vl.rs index 614dc2c4c543..4505af3bd327 100644 --- a/sgl-router/src/multimodal/vision/processors/qwen2_vl.rs +++ b/sgl-router/src/multimodal/vision/processors/qwen2_vl.rs @@ -119,7 +119,7 @@ impl Qwen2VLProcessor { pub fn from_preprocessor_config(config: &PreProcessorConfig) -> Self { Self { inner: QwenVLProcessorBase::new(QwenVLConfig { - patch_size: config.patch_size.unwrap_or(DEFAULT_PATCH_SIZE), + patch_size: config.get_patch_size(DEFAULT_PATCH_SIZE), merge_size: config.merge_size.unwrap_or(DEFAULT_MERGE_SIZE), min_pixels: config.min_pixels.unwrap_or(DEFAULT_MIN_PIXELS), max_pixels: config.max_pixels.unwrap_or(DEFAULT_MAX_PIXELS), @@ -245,7 +245,9 @@ mod tests { use image::{Rgb, RgbImage}; use super::*; - use crate::multimodal::vision::image_processor::ModelSpecificValue; + use crate::multimodal::vision::{ + image_processor::ModelSpecificValue, preprocessor_config::PatchSize, + }; fn create_test_image(width: u32, height: u32, color: Rgb) -> DynamicImage { DynamicImage::from(RgbImage::from_pixel(width, height, color)) @@ -363,7 +365,10 @@ mod tests { do_normalize: Some(true), image_mean: Some(CLIP_MEAN.to_vec()), image_std: Some(CLIP_STD.to_vec()), - patch_size: Some(14), + patch_size: Some(PatchSize { + height: Some(14), + width: Some(14), + }), merge_size: Some(2), min_pixels: Some(DEFAULT_MIN_PIXELS), max_pixels: Some(DEFAULT_MAX_PIXELS), @@ -418,7 +423,10 @@ mod tests { #[test] fn test_qwen2_vl_from_config() { let config = PreProcessorConfig { - patch_size: Some(16), + patch_size: Some(PatchSize { + height: Some(16), + width: Some(16), + }), merge_size: Some(4), min_pixels: Some(100000), max_pixels: Some(500000), diff --git a/sgl-router/src/multimodal/vision/processors/qwen3_vl.rs b/sgl-router/src/multimodal/vision/processors/qwen3_vl.rs index f306027a76b0..c3b0ef319a08 100644 --- a/sgl-router/src/multimodal/vision/processors/qwen3_vl.rs +++ b/sgl-router/src/multimodal/vision/processors/qwen3_vl.rs @@ -120,7 +120,7 @@ impl Qwen3VLProcessor { pub fn from_preprocessor_config(config: &PreProcessorConfig) -> Self { Self { inner: QwenVLProcessorBase::new(QwenVLConfig { - patch_size: config.patch_size.unwrap_or(DEFAULT_PATCH_SIZE), + patch_size: config.get_patch_size(DEFAULT_PATCH_SIZE), merge_size: config.merge_size.unwrap_or(DEFAULT_MERGE_SIZE), min_pixels: config.min_pixels.unwrap_or(DEFAULT_MIN_PIXELS), max_pixels: config.max_pixels.unwrap_or(DEFAULT_MAX_PIXELS), @@ -246,7 +246,9 @@ mod tests { use image::{Rgb, RgbImage}; use super::*; - use crate::multimodal::vision::image_processor::ModelSpecificValue; + use crate::multimodal::vision::{ + image_processor::ModelSpecificValue, preprocessor_config::PatchSize, + }; fn create_test_image(width: u32, height: u32, color: Rgb) -> DynamicImage { DynamicImage::from(RgbImage::from_pixel(width, height, color)) @@ -338,7 +340,10 @@ mod tests { do_normalize: Some(true), image_mean: Some(QWEN3_MEAN.to_vec()), image_std: Some(QWEN3_STD.to_vec()), - patch_size: Some(16), + patch_size: Some(PatchSize { + height: Some(16), + width: Some(16), + }), merge_size: Some(2), min_pixels: Some(DEFAULT_MIN_PIXELS), max_pixels: Some(DEFAULT_MAX_PIXELS), @@ -398,7 +403,10 @@ mod tests { #[test] fn test_qwen3_vl_from_config() { let config = PreProcessorConfig { - patch_size: Some(16), + patch_size: Some(PatchSize { + height: Some(16), + width: Some(16), + }), merge_size: Some(4), min_pixels: Some(100000), max_pixels: Some(500000), diff --git a/sgl-router/src/multimodal/vision/transforms.rs b/sgl-router/src/multimodal/vision/transforms.rs index 61485485d125..60dc3454f939 100644 --- a/sgl-router/src/multimodal/vision/transforms.rs +++ b/sgl-router/src/multimodal/vision/transforms.rs @@ -24,6 +24,9 @@ pub enum TransformError { #[error("Inconsistent tensor shapes in batch")] InconsistentShapes, + + #[error("Shape error: {0}")] + ShapeError(String), } pub type Result = std::result::Result; diff --git a/sgl-router/tests/vision_golden_tests.rs b/sgl-router/tests/vision_golden_tests.rs index eea7781d82c6..18e71eff18a1 100644 --- a/sgl-router/tests/vision_golden_tests.rs +++ b/sgl-router/tests/vision_golden_tests.rs @@ -19,8 +19,8 @@ use std::{fs::File, io::Read, path::Path}; use ndarray::{Array4, Array5}; use sgl_model_gateway::multimodal::vision::{ image_processor::ModelSpecificValue, ImagePreProcessor, Llama4VisionProcessor, LlavaProcessor, - Phi3VisionProcessor, Phi4VisionProcessor, PreProcessorConfig, Qwen2VLProcessor, - Qwen3VLProcessor, + Phi3VisionProcessor, Phi4VisionProcessor, PixtralProcessor, PreProcessorConfig, + Qwen2VLProcessor, Qwen3VLProcessor, }; /// Load a numpy .npz file and extract pixel_values @@ -396,7 +396,7 @@ fn run_qwen2_vl_golden_test(image_name: &str) { // Verify shapes match let expected_num_patches = grid_t * grid_h * grid_w; - let patch_size = config.patch_size.unwrap_or(14); + let patch_size = config.get_patch_size(14); let temporal_patch_size = config.temporal_patch_size.unwrap_or(2); let expected_patch_features = 3 * temporal_patch_size * patch_size * patch_size; @@ -580,7 +580,7 @@ fn run_qwen3_vl_golden_test(image_name: &str) { // Verify shapes match (Qwen3-VL has patch_size=16) let expected_num_patches = grid_t * grid_h * grid_w; - let patch_size = config.patch_size.unwrap_or(16); + let patch_size = config.get_patch_size(16); let temporal_patch_size = config.temporal_patch_size.unwrap_or(2); let expected_patch_features = 3 * temporal_patch_size * patch_size * patch_size; @@ -1388,3 +1388,214 @@ fn test_llama4_vision_golden_odd_dims() { fn test_llama4_vision_golden_grayscale() { run_llama4_vision_golden_test("grayscale"); } + +// ============================================================================ +// Pixtral/Mistral3 Vision tests +// ============================================================================ + +/// Load image_sizes from npz file for Pixtral +fn load_pixtral_image_sizes(path: &Path) -> Vec<(usize, usize)> { + let file = File::open(path).expect("Failed to open golden file"); + let mut npz = npyz::npz::NpzArchive::new(file).expect("Failed to parse npz"); + + let reader = npz + .by_name("image_sizes") + .expect("Failed to read npz") + .expect("No image_sizes"); + + let shape = reader.shape().to_vec(); + + // Read data as i64 vec (numpy default for int) + let data: Vec = reader.into_vec().expect("Failed to read array"); + + // Convert to Vec<(usize, usize)> + let num_images = shape[0] as usize; + (0..num_images) + .map(|i| (data[i * 2] as usize, data[i * 2 + 1] as usize)) + .collect() +} + +/// Run a Pixtral golden test for a specific image. +/// +/// This test validates: +/// 1. Output shape matches (batch, 3, H, W) +/// 2. image_sizes match +/// 3. Pixel values match HuggingFace output +/// 4. Token count is correct +/// +/// Pixtral processing: +/// - Longest edge: 1024 (default) +/// - Patch size: 16 +/// - Normalization: CLIP mean/std +/// - No tiling - single output per image +fn run_pixtral_golden_test(image_name: &str) { + let golden_dir = Path::new("tests/fixtures/golden/pixtral"); + let image_path = Path::new("tests/fixtures/images").join(format!("{}.jpg", image_name)); + + if !golden_dir.exists() || !image_path.exists() { + eprintln!( + "Golden test fixtures for pixtral/{} not found, skipping test", + image_name + ); + eprintln!("Run: python scripts/generate_vision_golden.py --model pixtral"); + return; + } + + let npz_path = golden_dir.join(format!("golden_{}.npz", image_name)); + let config = load_config(&golden_dir.join("preprocessor_config.json")); + + // Load golden values + let golden_pixels = load_golden_npz(&npz_path); + let golden_shape: Vec = golden_pixels.shape().to_vec(); + let golden_image_sizes = load_pixtral_image_sizes(&npz_path); + let golden_num_tokens = load_golden_num_tokens(&npz_path); + + // Process image with our Rust processor + let image = image::open(&image_path).expect("Failed to open image"); + let processor = PixtralProcessor::from_preprocessor_config(&config); + let result = processor + .preprocess(&[image], &config) + .expect("Processing failed"); + + // Check image_sizes from model_specific + let rust_image_sizes: Vec<(usize, usize)> = match result.model_specific.get("image_sizes") { + Some(ModelSpecificValue::IntTensor { data, shape }) => { + let num_images = shape[0]; + (0..num_images) + .map(|i| (data[i * 2] as usize, data[i * 2 + 1] as usize)) + .collect() + } + _ => panic!("Expected image_sizes in model_specific"), + }; + + println!( + "pixtral - {} image - Image sizes: golden={:?}, rust={:?}", + image_name, golden_image_sizes, rust_image_sizes + ); + assert_eq!( + golden_image_sizes, rust_image_sizes, + "image_sizes mismatch for {}", + image_name + ); + + // Check num_tokens + let rust_num_tokens = result.num_img_tokens[0]; + println!( + "pixtral - {} image - Tokens: golden={}, rust={}", + image_name, golden_num_tokens, rust_num_tokens + ); + assert_eq!( + golden_num_tokens, rust_num_tokens, + "num_tokens mismatch for {}", + image_name + ); + + // Check output shape + let rust_shape = result.pixel_values.shape(); + println!( + "pixtral - {} image - Shape: golden={:?}, rust={:?}", + image_name, golden_shape, rust_shape + ); + + // Pixtral outputs [batch, C, H, W] with padding to max size in batch + // Single image should match golden shape exactly + assert_eq!(rust_shape[0], 1, "Expected batch dim to be 1"); + assert_eq!(rust_shape[1], golden_shape[1], "Channel mismatch"); + assert!( + rust_shape[2] >= golden_shape[2], + "Height {} < golden height {}", + rust_shape[2], + golden_shape[2] + ); + assert!( + rust_shape[3] >= golden_shape[3], + "Width {} < golden width {}", + rust_shape[3], + golden_shape[3] + ); + + // Compare pixel values - only compare the actual image region, not padding + let rust_pixels = result.pixel_values_flat(); + let golden_pixels_flat: Vec = golden_pixels.iter().copied().collect(); + + // Calculate indices for the actual image region (not padding) + let h = golden_shape[2]; + let w = golden_shape[3]; + let rust_w = rust_shape[3]; + + let mut max_diff = 0.0f32; + for c in 0..3 { + for y in 0..h { + for x in 0..w { + let golden_idx = c * h * w + y * w + x; + let rust_idx = c * rust_shape[2] * rust_w + y * rust_w + x; + let diff = (rust_pixels[rust_idx] - golden_pixels_flat[golden_idx]).abs(); + max_diff = max_diff.max(diff); + } + } + } + + println!( + "pixtral - {} image - Max pixel diff: {:.6}", + image_name, max_diff + ); + + // Allow tolerance for bicubic interpolation differences between PIL and Rust image library + // Pixtral uses bicubic which has larger differences than bilinear + assert!( + max_diff < 0.06, + "Max pixel difference {} exceeds tolerance 0.06 for {}", + max_diff, + image_name + ); +} + +#[test] +fn test_pixtral_golden_square() { + run_pixtral_golden_test("square"); +} + +#[test] +fn test_pixtral_golden_tall() { + run_pixtral_golden_test("tall"); +} + +#[test] +fn test_pixtral_golden_wide() { + run_pixtral_golden_test("wide"); +} + +#[test] +fn test_pixtral_golden_small() { + run_pixtral_golden_test("small"); +} + +#[test] +fn test_pixtral_golden_tiny() { + run_pixtral_golden_test("tiny"); +} + +#[test] +fn test_pixtral_golden_very_tall() { + run_pixtral_golden_test("very_tall"); +} + +#[test] +fn test_pixtral_golden_very_wide() { + run_pixtral_golden_test("very_wide"); +} + +#[test] +fn test_pixtral_golden_large() { + run_pixtral_golden_test("large"); +} + +#[test] +fn test_pixtral_golden_odd_dims() { + run_pixtral_golden_test("odd_dims"); +} + +#[test] +fn test_pixtral_golden_grayscale() { + run_pixtral_golden_test("grayscale"); +}