Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions sgl-router/scripts/generate_vision_golden.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
"processor_class": "Llama4ImageProcessorFast",
"description": "Tile-based processing with 336x336 tiles and global tile",
},
"pixtral": {
"model_id": "mistralai/Pixtral-12B-2409",
"processor_class": "PixtralImageProcessor",
"description": "Dynamic resolution with CLIP normalization and bicubic resize",
},
}

# Default test images
Expand Down Expand Up @@ -547,6 +552,64 @@ def generate_golden_llama4_vision(image_path: str, output_dir: str) -> dict:
return result


def generate_golden_pixtral(image_path: str, output_dir: str) -> dict:
    """Generate golden output for Pixtral/Mistral3 Vision.

    Pixtral uses dynamic resolution processing:
    1. If image exceeds longest_edge (default 1024), scale down proportionally
    2. Resize to dimensions that are multiples of patch_size (default 16)
    3. Use bicubic interpolation for resize
    4. Normalize with CLIP mean/std

    Args:
        image_path: Path to the input image file.
        output_dir: Unused here; kept so all generators share the same
            signature for the dispatch table in generate_for_model().

    Returns:
        dict with:
        - pixel_values: [1, 3, H, W] where H, W are multiples of patch_size
        - image_sizes: [(H, W)] (only when the processor reports it)
        - num_tokens: (H // patch_h) * (W // patch_w) (only with image_sizes)
        - original_size, processor_config, config_info

    Token count: (H / patch_size) * (W / patch_size)
    """
    from transformers import PixtralImageProcessor

    processor = PixtralImageProcessor.from_pretrained("mistral-community/pixtral-12b")
    image = Image.open(image_path).convert("RGB")
    original_size = image.size

    # Process image
    outputs = processor(images=image, return_tensors="np")
    pixel_values = outputs["pixel_values"]
    image_sizes = outputs.get("image_sizes")

    result = {
        "pixel_values": pixel_values,
        "original_size": original_size,
        "processor_config": processor.to_dict(),
    }

    # Record resized dimensions and derive the token count from them
    # under a single guard (image_sizes may be absent from the outputs).
    if image_sizes is not None:
        result["image_sizes"] = np.array(image_sizes)

        h, w = image_sizes[0]
        # patch_size may be a plain int or a {"height": ..., "width": ...} dict
        patch_size = getattr(processor, "patch_size", {"height": 16, "width": 16})
        if isinstance(patch_size, dict):
            patch_h = patch_size.get("height", 16)
            patch_w = patch_size.get("width", 16)
        else:
            patch_h = patch_w = patch_size
        result["num_tokens"] = (h // patch_h) * (w // patch_w)

    # Add debug info for comparing against the Rust implementation
    result["config_info"] = {
        "longest_edge": processor.size.get("longest_edge", 1024),
        "patch_size": processor.patch_size,
        "image_mean": processor.image_mean,
        "image_std": processor.image_std,
    }

    return result


def generate_for_model(model_key: str, image_paths: list, output_dir: str):
"""Generate golden outputs for a specific model."""
print(f"\nGenerating golden outputs for {model_key}...")
Expand All @@ -560,6 +623,7 @@ def generate_for_model(model_key: str, image_paths: list, output_dir: str):
"phi3_vision": generate_golden_phi3_vision,
"phi4_vision": generate_golden_phi4_vision,
"llama4_vision": generate_golden_llama4_vision,
"pixtral": generate_golden_pixtral,
}.get(model_key)

if generator_fn is None:
Expand Down
2 changes: 1 addition & 1 deletion sgl-router/src/multimodal/vision/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,6 @@ pub use image_processor::{
pub use preprocessor_config::PreProcessorConfig;
pub use processors::{
Llama4VisionProcessor, LlavaNextProcessor, LlavaProcessor, Phi3VisionProcessor,
Phi4VisionProcessor, Qwen2VLProcessor, Qwen3VLProcessor,
Phi4VisionProcessor, PixtralProcessor, Qwen2VLProcessor, Qwen3VLProcessor,
};
pub use transforms::TransformError;
107 changes: 103 additions & 4 deletions sgl-router/src/multimodal/vision/preprocessor_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,97 @@
use std::collections::HashMap;

use image::imageops::FilterType;
use serde::Deserialize;
use serde::{Deserialize, Deserializer};

use super::transforms;

/// Struct to represent patch_size as dict {"height": x, "width": y}
///
/// Produced by `deserialize_patch_size`, which accepts either a bare integer
/// (filling both fields with the same value) or a height/width dict. Fields
/// are `Option` because a dict in the config may carry only one of the keys.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct PatchSize {
    // Patch height in pixels (typically 14 or 16).
    pub height: Option<u32>,
    // Patch width in pixels (typically 14 or 16).
    pub width: Option<u32>,
}

/// Custom deserializer for patch_size that handles both integer and dict formats.
/// - Integer format: `"patch_size": 16` -> PatchSize { height: 16, width: 16 }
/// - Dict format: `"patch_size": {"height": 16, "width": 16}` -> PatchSize { height: 16, width: 16 }
/// - `null` -> `None`
///
/// # Errors
/// Returns a deserialization error when an integer value is negative or does
/// not fit in `u32` (previously these were silently truncated via `as u32`).
fn deserialize_patch_size<'de, D>(deserializer: D) -> Result<Option<PatchSize>, D::Error>
where
    D: Deserializer<'de>,
{
    use std::fmt;

    use serde::de::{self, MapAccess, Visitor};

    struct PatchSizeVisitor;

    impl<'de> Visitor<'de> for PatchSizeVisitor {
        type Value = Option<PatchSize>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("an integer, a dict with height/width, or null")
        }

        fn visit_none<E>(self) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }

        fn visit_unit<E>(self) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            Ok(None)
        }

        fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Reject negative / out-of-range values instead of wrapping with `as u32`.
            let v = u32::try_from(value)
                .map_err(|_| de::Error::custom(format!("patch_size out of range: {}", value)))?;
            Ok(Some(PatchSize {
                height: Some(v),
                width: Some(v),
            }))
        }

        fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
        where
            E: de::Error,
        {
            // Reject values that do not fit in u32 instead of truncating.
            let v = u32::try_from(value)
                .map_err(|_| de::Error::custom(format!("patch_size out of range: {}", value)))?;
            Ok(Some(PatchSize {
                height: Some(v),
                width: Some(v),
            }))
        }

        fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
        where
            M: MapAccess<'de>,
        {
            let mut height = None;
            let mut width = None;

            while let Some(key) = map.next_key::<String>()? {
                match key.as_str() {
                    "height" => height = Some(map.next_value::<u32>()?),
                    "width" => width = Some(map.next_value::<u32>()?),
                    _ => {
                        // Ignore unknown keys so extra config fields don't fail parsing.
                        let _ = map.next_value::<de::IgnoredAny>()?;
                    }
                }
            }

            Ok(Some(PatchSize { height, width }))
        }
    }

    deserializer.deserialize_any(PatchSizeVisitor)
}

/// HuggingFace preprocessor_config.json structure.
///
/// This struct captures the common fields across different vision model processors.
Expand Down Expand Up @@ -73,8 +160,9 @@ pub struct PreProcessorConfig {
// Model-specific fields
// =====================
/// Vision encoder patch size (typically 14 or 16)
#[serde(default)]
pub patch_size: Option<usize>,
/// Can be an integer or a dict {"height": x, "width": y}
#[serde(default, deserialize_with = "deserialize_patch_size")]
pub patch_size: Option<PatchSize>,

/// Qwen-VL: merge size for token reduction
#[serde(default)]
Expand Down Expand Up @@ -151,6 +239,17 @@ impl PreProcessorConfig {
serde_json::from_value(value)
}

/// Get patch size as a simple usize.
///
/// Prefers the height component of `PatchSize`, falls back to the width
/// component when only width is present (all currently supported models use
/// square patches, so the two are expected to agree), and finally falls back
/// to the provided `default`.
pub fn get_patch_size(&self, default: usize) -> usize {
    self.patch_size
        .as_ref()
        .and_then(|p| p.height.or(p.width))
        .map(|v| v as usize)
        .unwrap_or(default)
}
Comment on lines +245 to +251
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The get_patch_size method currently only returns the height component of PatchSize. If patch_size can represent non-square dimensions (i.e., height != width), then this method might not provide the full information needed by all consumers. It assumes patch_size is always square or only the height is relevant. Consider clarifying this assumption in the documentation or providing a more comprehensive method if both dimensions are needed for certain models.


/// Get image mean as fixed array, with fallback to CLIP defaults.
pub fn get_image_mean(&self) -> [f64; 3] {
self.image_mean
Expand Down Expand Up @@ -310,7 +409,7 @@ mod tests {

assert_eq!(config.min_pixels, Some(200704));
assert_eq!(config.max_pixels, Some(1003520));
assert_eq!(config.patch_size, Some(14));
assert_eq!(config.get_patch_size(0), 14);
assert_eq!(config.merge_size, Some(2));
assert!((config.get_rescale_factor() - 1.0 / 255.0).abs() < 1e-10);
}
Expand Down
2 changes: 1 addition & 1 deletion sgl-router/src/multimodal/vision/processors/llava.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ impl ImagePreProcessor for LlavaProcessor {
config: &PreProcessorConfig,
) -> usize {
// For LLaVA 1.5, token count is based on processed image size and patch size
let patch_size = config.patch_size.unwrap_or(self.patch_size as usize) as u32;
let patch_size = config.get_patch_size(self.patch_size as usize) as u32;
let image_size = config
.get_target_size()
.map(|(h, _w)| h)
Expand Down
3 changes: 3 additions & 0 deletions sgl-router/src/multimodal/vision/processors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
//! - **Phi3-Vision** (`phi3_vision`): Dynamic HD transform with 336x336 tiles
//! - **Phi4-Vision** (`phi4_vision`): Dynamic HD transform with 448x448 tiles and SiGLIP encoder
//! - **LLaMA 4 Vision** (`llama4_vision`): Tile-based processing with 336x336 tiles and global tile
//! - **Pixtral/Mistral3** (`pixtral`): CLIP-based preprocessing with dynamic resolution

pub mod llama4_vision;
pub mod llava;
pub mod phi3_vision;
pub mod phi4_vision;
pub mod pixtral;
pub mod qwen2_vl;
pub mod qwen3_vl;
pub mod qwen_vl_base;
Expand All @@ -26,5 +28,6 @@ pub use llama4_vision::Llama4VisionProcessor;
pub use llava::{ImageAspectRatio, LlavaNextProcessor, LlavaProcessor};
pub use phi3_vision::Phi3VisionProcessor;
pub use phi4_vision::Phi4VisionProcessor;
pub use pixtral::PixtralProcessor;
pub use qwen2_vl::Qwen2VLProcessor;
pub use qwen3_vl::Qwen3VLProcessor;
Loading
Loading