diff --git a/README.md b/README.md index c8d783f0b4..6ce1a57eb9 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis - Check out UQFF for prequantized models of various methods! - Models can be found [here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c). -- 💎💎💎 Run the **Gemma 3** Model (*text only for now, vision coming very soon!*): +- 💎💎💎 Run the **Gemma 3** Model with 128k context length and vision support: [documentation](docs/GEMMA3.md) ``` ./mistralrs-server -i vision-plain -m google/gemma-3-4b-it -a gemma3 diff --git a/docs/GEMMA3.md b/docs/GEMMA3.md new file mode 100644 index 0000000000..9420f50477 --- /dev/null +++ b/docs/GEMMA3.md @@ -0,0 +1,192 @@ +# Gemma 3 Model: [`google/gemma-3-4b-it`](https://huggingface.co/google/gemma-3-4b-it) + +The Gemma 3 model is a family of multimodal (text+vision) models with 128k context length. The collection can be found [here](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d), with model sizes ranging from 4B to 27B. + +We support the Gemma 3 Model in the Rust, Python, and HTTP APIs, including ISQ for increased performance. + +The Python and HTTP APIs support sending images as: +- URL +- Path to a local image +- [Base64](https://en.wikipedia.org/wiki/Base64) encoded string + +The Rust API takes an image from the [image](https://docs.rs/image/latest/image/index.html) crate. + +## HTTP server +You can find this example [here](../examples/server/gemma3.py). + +We support an OpenAI compatible HTTP API for vision models. This example demonstrates sending a chat completion request with an image. + +> Note: The image_url may be either a path, URL, or a base64 encoded string. + +--- + +**Image:** +Mount Washington +
Credit
+

**Prompt:**
```
What is this?
```

**Output:**
```
The image shows Mount Washington in New Hampshire, USA. It's a prominent peak in the White Mountains, known for its extreme weather conditions and being the highest peak in the Northeastern United States. The image captures it covered in snow with a dramatic sky above. The structures at the summit are communication towers.

The winding path visible on the mountain slopes appears to be part of the Mount Washington Auto Road, a historic road that allows vehicles to drive to the summit.
```

---

1) Start the server

> [!NOTE]
> You should replace `--features ...` with one of the features specified [here](../README.md#supported-accelerators), or remove it for pure CPU inference.

```
cargo run --release --features ... -- --port 1234 vision-plain -m google/gemma-3-12b-it -a gemma3
```

2) Send a request

```py
from openai import OpenAI
import httpx
import textwrap
import json


client = OpenAI(api_key="foobar", base_url="http://localhost:1234/v1/")


completion = client.chat.completions.create(
    model="gemma3",
    messages=[
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg"
                    },
                },
                {
                    "type": "text",
                    "text": "What is this?",
                },
            ],
        },
    ],
    max_tokens=256,
    frequency_penalty=1.0,
    top_p=0.1,
    temperature=0,
)
resp = completion.choices[0].message.content
print(resp)
```

- You can find an example of encoding the [image via base64 here](../examples/server/phi3v_base64.py).
- You can find an example of loading an [image locally here](../examples/server/phi3v_local_img.py).

---

## Rust
You can find this example [here](../mistralrs/examples/gemma3/main.rs).

This is a minimal example of running the Gemma 3 model: it downloads an image from a URL and asks for a detailed description of the scene.

```rust
use anyhow::Result;
use mistralrs::{IsqType, TextMessageRole, VisionLoaderType, VisionMessages, VisionModelBuilder};

#[tokio::main]
async fn main() -> Result<()> {
    let model = VisionModelBuilder::new("google/gemma-3-12b-it", VisionLoaderType::Gemma3)
        .with_isq(IsqType::Q4K)
        .with_logging()
        .build()
        .await?;

    let bytes = match reqwest::blocking::get(
        "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg",
    ) {
        Ok(http_resp) => http_resp.bytes()?.to_vec(),
        Err(e) => anyhow::bail!(e),
    };
    let image = image::load_from_memory(&bytes)?;

    let messages = VisionMessages::new().add_image_message(
        TextMessageRole::User,
        "What is depicted here? Please describe the scene in detail.",
        image,
        &model,
    )?;

    let response = model.send_chat_request(messages).await?;

    println!("{}", response.choices[0].message.content.as_ref().unwrap());
    dbg!(
        response.usage.avg_prompt_tok_per_sec,
        response.usage.avg_compl_tok_per_sec
    );

    Ok(())
}
```

## Python
You can find this example [here](../examples/python/gemma3.py).

This example demonstrates loading and sending a chat completion request with an image.

> Note: The image_url may be either a path, URL, or a base64 encoded string.
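
For the two non-URL forms mentioned in the note, here is a minimal sketch of building the `image_url` entry from a local file path and from a base64-encoded image. The `my_image.jpg` path and the `data:image/jpeg;base64,...` form are illustrative assumptions; the linked phi3v examples below show the exact variants used in this repository, and the full request example follows.

```py
import base64

# Hypothetical local image path; any file readable by the client works.
local_path = "my_image.jpg"

# Variant 1: pass the path itself as the "url" value.
path_image_content = {"type": "image_url", "image_url": {"url": local_path}}

# Variant 2: embed the image as a base64 data URI (assumed JPEG here).
with open(local_path, "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")
base64_image_content = {
    "type": "image_url",
    "image_url": {"url": f"data:image/jpeg;base64,{encoded}"},
}

# Either dict can replace the image_url entry in the "content" list of the
# requests shown in this document.
```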
+ +```py +from mistralrs import Runner, Which, ChatCompletionRequest, VisionArchitecture + +runner = Runner( + which=Which.VisionPlain( + model_id="google/gemma-3-12b-it", + arch=VisionArchitecture.Gemma3, + ), +) + +res = runner.send_chat_completion_request( + ChatCompletionRequest( + model="gemma3", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg" + }, + }, + { + "type": "text", + "text": "What is this?", + }, + ], + } + ], + max_tokens=256, + presence_penalty=1.0, + top_p=0.1, + temperature=0.1, + ) +) +print(res.choices[0].message.content) +print(res.usage) + +``` + +- You can find an example of encoding the [image via base64 here](../examples/python/phi3v_base64.py). +- You can find an example of loading an [image locally here](../examples/python/phi3v_local_img.py). \ No newline at end of file diff --git a/examples/python/gemma3.py b/examples/python/gemma3.py new file mode 100644 index 0000000000..c52acaa83f --- /dev/null +++ b/examples/python/gemma3.py @@ -0,0 +1,37 @@ +from mistralrs import Runner, Which, ChatCompletionRequest, VisionArchitecture + +runner = Runner( + which=Which.VisionPlain( + model_id="google/gemma-3-12b-it", + arch=VisionArchitecture.Gemma3, + ), +) + +res = runner.send_chat_completion_request( + ChatCompletionRequest( + model="gemma3", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg" + }, + }, + { + "type": "text", + "text": "What is this?", + }, + ], + } + ], + max_tokens=256, + presence_penalty=1.0, + top_p=0.1, + temperature=0.1, + ) +) +print(res.choices[0].message.content) +print(res.usage) diff --git a/examples/server/gemma3.py b/examples/server/gemma3.py new file mode 100644 index 0000000000..b09ac850e7 --- /dev/null +++ b/examples/server/gemma3.py @@ -0,0 +1,63 @@ +from openai import OpenAI +import httpx +import textwrap +import json + + +def log_response(response: httpx.Response): + request = response.request + print(f"Request: {request.method} {request.url}") + print(" Headers:") + for key, value in request.headers.items(): + if key.lower() == "authorization": + value = "[...]" + if key.lower() == "cookie": + value = value.split("=")[0] + "=..." + print(f" {key}: {value}") + print(" Body:") + try: + request_body = json.loads(request.content) + print(textwrap.indent(json.dumps(request_body, indent=2), " ")) + except json.JSONDecodeError: + print(textwrap.indent(request.content.decode(), " ")) + print(f"Response: status_code={response.status_code}") + print(" Headers:") + for key, value in response.headers.items(): + if key.lower() == "set-cookie": + value = value.split("=")[0] + "=..." 
+ print(f" {key}: {value}") + + +client = OpenAI(api_key="foobar", base_url="http://localhost:1234/v1/") + +# Enable this to log requests and responses +# client._client = httpx.Client( +# event_hooks={"request": [print], "response": [log_response]} +# ) + +completion = client.chat.completions.create( + model="gemma3", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg" + }, + }, + { + "type": "text", + "text": "What is this?", + }, + ], + }, + ], + max_tokens=256, + frequency_penalty=1.0, + top_p=0.1, + temperature=0, +) +resp = completion.choices[0].message.content +print(resp) diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs index 057636b0c0..858b24d493 100644 --- a/mistralrs-core/src/attention.rs +++ b/mistralrs-core/src/attention.rs @@ -226,7 +226,7 @@ fn naive_sdpa( candle_nn::ops::inplace_attn_softmax_last_dim( &mut att, - &mask, + &mask.contiguous()?, sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0), )?; diff --git a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs index 2ed518bfed..55e1528a78 100644 --- a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs @@ -3134,11 +3134,11 @@ impl VisionModelLoader for Gemma3Loader { fn get_processor( &self, _model_config: &str, - _processor_config: Option, + processor_config: Option, _preprocessor_config: PreProcessorConfig, _max_edge: Option, ) -> Arc { - Arc::new(Gemma3Processor) + Arc::new(Gemma3Processor::new(processor_config.unwrap())) } fn supports_paged_attention(&self) -> bool { true diff --git a/mistralrs-core/src/vision_models/gemma3/config.rs b/mistralrs-core/src/vision_models/gemma3/config.rs index 6bdd5508f0..11c6761f89 100644 --- a/mistralrs-core/src/vision_models/gemma3/config.rs +++ b/mistralrs-core/src/vision_models/gemma3/config.rs @@ -3,6 +3,7 @@ use mistralrs_quant::QuantizedConfig; use crate::{ layers::{Activation, Gemma3RopeScalingConfig}, serde_default_fn, + vision_models::siglip::SiglipVisionConfig, }; serde_default_fn!(bool, attention_bias, false); @@ -63,4 +64,7 @@ pub struct Gemma3TextConfig { #[derive(Debug, Clone, serde::Deserialize)] pub struct Gemma3Config { pub text_config: Gemma3TextConfig, + pub vision_config: SiglipVisionConfig, + pub image_token_index: usize, + pub mm_tokens_per_image: usize, } diff --git a/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs b/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs index c9984d6dfa..6a9486e9c2 100644 --- a/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs @@ -2,8 +2,11 @@ use std::{any::Any, num::NonZeroUsize, sync::Arc}; -use candle_core::{Device, Result}; -use image::DynamicImage; +use candle_core::{Device, Result, Tensor}; +use image::{DynamicImage, GenericImageView}; +use itertools::Itertools; +use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms}; +use regex::Regex; use tokenizers::Tokenizer; use tracing::warn; @@ -18,28 +21,51 @@ use crate::{ sequence::Sequence, vision_models::{ image_processor::{ImagePreProcessor, PreprocessedImages}, - preprocessor_config::PreProcessorConfig, + preprocessor_config::{PreProcessorConfig, ToFilter}, + processor_config::ProcessorConfig, ModelInputs, }, }; use super::Gemma3SpecificArgs; 
-struct Gemma3ImageProcessor; +struct Gemma3ImageProcessor { + full_image_sequence: String, +} + +const IMAGE_TOKEN: &str = ""; +const BOI_TOKEN: &str = ""; +const EOI_TOKEN: &str = ""; -pub struct Gemma3Processor; +pub struct Gemma3Processor { + full_image_sequence: String, +} + +impl Gemma3Processor { + pub fn new(processor_config: ProcessorConfig) -> Self { + let image_tokens_expanded = + vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join(""); + let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n"); + + Self { + full_image_sequence, + } + } +} impl Processor for Gemma3Processor { fn inputs_processor(&self) -> Arc { - Arc::new(Gemma3ImageProcessor) + Arc::new(Gemma3ImageProcessor { + full_image_sequence: self.full_image_sequence.clone(), + }) } fn get_special_tokens(&self) -> &[&'static str] { - &[] + &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN] } fn template_action(&self) -> MessagesAction { - MessagesAction::FlattenOnlyText + MessagesAction::Keep } } @@ -49,7 +75,7 @@ impl InputsProcessor for Gemma3ImageProcessor { } fn process_inputs( &self, - _tokenizer: Option>, + tokenizer: Option>, input_seqs: &mut [&mut Sequence], is_prompt: bool, is_xlora: bool, @@ -57,7 +83,7 @@ impl InputsProcessor for Gemma3ImageProcessor { no_kv_cache: bool, last_n_context_len: Option<(usize, usize)>, return_raw_logits: bool, - _other_config: Option>, + other_config: Option>, mut paged_attn_metadata: Option>, prompt_chunksize: Option, mapper: Option<&dyn DeviceMapper>, @@ -76,6 +102,11 @@ impl InputsProcessor for Gemma3ImageProcessor { if prompt_chunksize.is_some() { warn!("`prompt_chunksize` is set. Gemma3 does not support prompt batching."); } + let Some(tokenizer) = tokenizer else { + return Box::new(std::iter::once(Err(anyhow::Error::msg( + "Idefics3ImageProcessor requires a specified tokenizer.", + )))); + }; let text_models_inputs_processor::InnerInputProcessorOutput { inputs: @@ -125,12 +156,112 @@ impl InputsProcessor for Gemma3ImageProcessor { .unwrap() }; + let config = other_config.expect("Need a PreProcessorConfig config."); + let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed."); + + let has_images = input_seqs + .iter() + .all(|seq| seq.images().is_some_and(|images| !images.is_empty())); + + let (new_input, pixel_values) = if has_images { + let mut pixel_values_accum = Vec::new(); + let mut all_ids = Vec::new(); + let re = Regex::new(BOI_TOKEN).unwrap(); + for seq in input_seqs.iter_mut() { + let PreprocessedImages { + pixel_values, + pixel_attention_mask: _, + image_sizes: _, + num_img_tokens: _, + aspect_ratio_ids: _, + aspect_ratio_mask: _, + num_tiles: _, + image_grid_thw: _, + video_grid_thw: _, + rows: _, + cols: _, + pixel_values_list: _, + tgt_sizes: _, + image_sizes_all: _, + num_crops, + } = self + .preprocess( + seq.take_images() + .expect("Need to have images by this point."), + vec![], + config, + device, + (usize::MAX, usize::MAX), // Don't use it here... 
+ ) + .expect("Preprocessing failed"); + + let num_crops = num_crops.unwrap(); + + // Deliberately no .unsqueeze here + pixel_values_accum.push(pixel_values.clone()); + + let mut prompt = tokenizer + .decode(seq.get_toks(), false) + .expect("Detokenization failed!"); + + let image_indexes: Vec = + re.find_iter(&prompt).map(|mat| mat.start()).collect(); + + assert_ne!(pixel_values.dim(0).unwrap(), image_indexes.len()); + + for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() { + if num != 0 { + let formatted_image_text = format!( + "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}", vec![BOI_TOKEN.to_string(); num].join(" ") + ); + prompt = format!( + "{}{formatted_image_text}{}", + &prompt[..idx], + &prompt[idx + BOI_TOKEN.len()..] + ); + } + } + + prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence); + + seq.set_initial_prompt(prompt.clone()); + let toks = tokenizer + .encode(prompt, false) + .expect("Detokenization failed!"); + + let ids = toks.get_ids().to_vec(); + all_ids.push(ids.clone()); + + seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut()); + } + + let mut all_ids_new = Vec::new(); + let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap(); + for ids in all_ids { + let pad = max_len - ids.len(); + all_ids_new + .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap()); + } + + ( + Some(Tensor::stack(&all_ids_new, 0).unwrap()), + Some(Tensor::cat(&pixel_values_accum, 0).unwrap()), + ) + } else { + (None, None) + }; + + let input = match new_input { + Some(new_input) => new_input, + None => input, + }; + let inputs: Box = Box::new(ModelInputs { input_ids: input, seqlen_offsets: positions, context_lens, position_ids, - pixel_values: None, + pixel_values, model_specific_args: Box::new(Gemma3SpecificArgs), paged_attn_meta, flash_meta, @@ -142,18 +273,193 @@ impl InputsProcessor for Gemma3ImageProcessor { } } +impl Gemma3ImageProcessor { + fn pan_and_scan( + &self, + image: &DynamicImage, + pan_and_scan_min_crop_size: usize, + pan_and_scan_max_num_crops: usize, + pan_and_scan_min_ratio_to_activate: f64, + ) -> Vec { + let (width, height) = image.dimensions(); + + let (num_crops_w, num_crops_h) = if width >= height { + if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate { + return vec![]; + } + + // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize; + num_crops_w = num_crops_w + .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize); + + // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. + num_crops_w = num_crops_w.max(2); + num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops); + + (num_crops_w, 1) + } else { + if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate { + return vec![]; + } + + // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size. + let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize; + num_crops_h = num_crops_h + .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize); + + // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops]. 
+ num_crops_h = num_crops_h.max(2); + num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops); + + (1, num_crops_h) + }; + + let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize; + let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize; + + if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size { + return vec![]; + } + + let crop_positions_w = (0..num_crops_w) + .map(|i| i * crop_size_w) + .collect::>(); + let crop_positions_h = (0..num_crops_h) + .map(|i| i * crop_size_h) + .collect::>(); + + let mut image_crops = Vec::new(); + for (pos_h, pos_w) in crop_positions_h + .into_iter() + .cartesian_product(crop_positions_w) + { + image_crops.push(image.crop_imm( + pos_w as u32, + pos_h as u32, + crop_size_w as u32, + crop_size_h as u32, + )); + } + + image_crops + } + + fn process_images_for_pan_and_scan( + &self, + images: Vec, + pan_and_scan_min_crop_size: usize, + pan_and_scan_max_num_crops: usize, + pan_and_scan_min_ratio_to_activate: f64, + ) -> (Vec, Vec) { + let mut pas_images_list = Vec::new(); + let mut num_crops = Vec::new(); + + for image in images { + let pas_images = self.pan_and_scan( + &image, + pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate, + ); + num_crops.push(pas_images.len()); + pas_images_list.extend([vec![image], pas_images].concat()); + } + + (pas_images_list, num_crops) + } +} + impl ImagePreProcessor for Gemma3ImageProcessor { const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5]; const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5]; fn preprocess( &self, - _images: Vec, - _videos: Vec>, - _config: &PreProcessorConfig, - _device: &Device, + mut images: Vec, + videos: Vec>, + config: &PreProcessorConfig, + device: &Device, (_bs, _max_num_images): (usize, usize), ) -> Result { - todo!() + assert!(videos.is_empty()); + + let do_resize = config.do_resize.unwrap(); + let size = config.size.as_ref().unwrap(); + let (height, width) = (size["height"], size["width"]); + let resample = config.resampling.to_filter()?; + let do_rescale = config.do_rescale.unwrap(); + let rescale_factor = config.rescale_factor.unwrap(); + let do_normalize = config.do_normalize.unwrap(); + let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN); + let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD); + let do_convert_rgb = config.do_convert_rgb.unwrap_or(true); + let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb); + // https://github.com/huggingface/transformers/blob/ea219ed164bead55a5513e8cfaa17a25d5613b9e/src/transformers/models/gemma3/processing_gemma3.py#L42 + let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256); + let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4); + let pan_and_scan_min_ratio_to_activate = + config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2); + + for image in images.iter_mut() { + // Convert to rgb + if config.do_convert_rgb.is_some_and(|x| x) { + *image = DynamicImage::ImageRgb8(image.to_rgb8()); + } + } + + let num_crops = if do_pan_and_scan { + let (new_images, num_crops) = self.process_images_for_pan_and_scan( + images, + pan_and_scan_min_crop_size, + pan_and_scan_max_num_crops, + pan_and_scan_min_ratio_to_activate, + ); + images = new_images; + num_crops + } else { + vec![0] + }; + + let mut pixel_values = Vec::new(); + for mut image in images { + if do_resize { + image = image.resize_exact(width, height, resample); + } + + let transforms = Transforms { + input: &ToTensorNoNorm, + 
inner_transforms: &[ + &do_rescale.then_some(Rescale { + factor: Some(rescale_factor), + }), + &do_normalize.then(|| Normalize { + mean: image_mean.to_vec(), + std: image_std.to_vec(), + }), + ], + }; + + let image = image.apply(transforms, device)?; + pixel_values.push(image.unsqueeze(0)?); + } + + Ok(PreprocessedImages { + pixel_values: Tensor::cat(&pixel_values, 0)?, + pixel_attention_mask: None, + image_sizes: None, + num_img_tokens: None, + aspect_ratio_ids: None, + aspect_ratio_mask: None, + num_tiles: None, + image_grid_thw: None, + video_grid_thw: None, + rows: None, + cols: None, + pixel_values_list: None, + tgt_sizes: None, + image_sizes_all: None, + num_crops: Some(num_crops), + }) } } diff --git a/mistralrs-core/src/vision_models/gemma3/mmproj.rs b/mistralrs-core/src/vision_models/gemma3/mmproj.rs new file mode 100644 index 0000000000..f5e9a7cda5 --- /dev/null +++ b/mistralrs-core/src/vision_models/gemma3/mmproj.rs @@ -0,0 +1,77 @@ +use candle_core::{Result, Tensor}; +use candle_nn::Module; +use mistralrs_quant::ShardedVarBuilder; + +use crate::{ + layers::{AvgPool2d, RmsNorm}, + utils::unvarbuilder::UnVarBuilder, +}; + +use super::config::Gemma3Config; + +pub struct Gemma3MultiModalProjector { + mm_input_projection_weight: Tensor, + mm_soft_emb_norm: RmsNorm, + patches_per_image: usize, + avg_pool: AvgPool2d, +} + +impl Gemma3MultiModalProjector { + pub fn new(cfg: &Gemma3Config, vb: ShardedVarBuilder) -> Result { + let mm_input_projection_weight = vb.get( + (cfg.vision_config.hidden_size, cfg.text_config.hidden_size), + "mm_input_projection_weight", + )?; + let mm_soft_emb_norm = RmsNorm::new_gemma( + cfg.vision_config.hidden_size, + cfg.vision_config.layer_norm_eps, + vb.pp("mm_soft_emb_norm"), + )?; + + let patches_per_image = cfg.vision_config.image_size / cfg.vision_config.patch_size; + let tokens_per_side = cfg.mm_tokens_per_image.isqrt(); + let kernel_size = patches_per_image / tokens_per_side; + let avg_pool = AvgPool2d::new(kernel_size, kernel_size); + + Ok(Self { + mm_input_projection_weight, + mm_soft_emb_norm, + patches_per_image, + avg_pool, + }) + } + + pub fn forward(&self, xs: &Tensor) -> Result { + let (bs, _, seqlen) = xs.dims3()?; + + let mut reshaped_vision_outputs = xs.transpose(1, 2)?; + reshaped_vision_outputs = reshaped_vision_outputs.reshape(( + bs, + seqlen, + self.patches_per_image, + self.patches_per_image, + ))?; + reshaped_vision_outputs = reshaped_vision_outputs.contiguous()?; + + let mut pooled_vision_outputs = self.avg_pool.forward(&reshaped_vision_outputs)?; + pooled_vision_outputs = pooled_vision_outputs.flatten_from(2)?; + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)?; + + let normed_vision_outputs = self.mm_soft_emb_norm.forward(&pooled_vision_outputs)?; + + normed_vision_outputs.broadcast_matmul(&self.mm_input_projection_weight) + } + + pub fn residual_tensors(&self) -> Vec<(String, Tensor)> { + let uvb = UnVarBuilder::new(); + + uvb.add_tensor( + "mm_input_projection_weight", + self.mm_input_projection_weight.clone(), + ); + uvb.pp("mm_soft_emb_norm") + .add(&self.mm_soft_emb_norm.undo_gemma().unwrap()); + + uvb.to_safetensors() + } +} diff --git a/mistralrs-core/src/vision_models/gemma3/mod.rs b/mistralrs-core/src/vision_models/gemma3/mod.rs index 6bc644cf57..6dbb88eb12 100644 --- a/mistralrs-core/src/vision_models/gemma3/mod.rs +++ b/mistralrs-core/src/vision_models/gemma3/mod.rs @@ -2,29 +2,38 @@ use std::sync::Arc; -use candle_core::{Device, Result, Tensor}; +use candle_core::{DType, Device, Result, Tensor, D}; use 
config::Gemma3Config; use mistralrs_quant::{QuantMethod, ShardedVarBuilder}; +use mmproj::Gemma3MultiModalProjector; use text::TextModel; use crate::{ amoe::{AnyMoeBaseModelMixin, MlpLayer}, device_map::DeviceMapper, + ops::NonZeroOp, paged_attention::{AttentionImplementation, ModelConfigMetadata}, pipeline::{ text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata}, EitherCache, IsqModel, NormalLoadingMetadata, VisionModel, }, + utils::unvarbuilder::UnVarBuilder, AnyMoeConfig, AnyMoeExpertType, }; pub mod config; mod inputs_processor; +mod mmproj; mod text; pub(crate) use inputs_processor::Gemma3Processor; +use super::siglip::SiglipVisionTransformer; + pub struct Gemma3Model { language_model: TextModel, + multi_modal_projector: Gemma3MultiModalProjector, + vision_tower: SiglipVisionTransformer, + cfg: Gemma3Config, } impl Gemma3Model { @@ -35,6 +44,7 @@ impl Gemma3Model { normal_loading_metadata: NormalLoadingMetadata, attention_mechanism: AttentionImplementation, ) -> Result { + assert!(cfg.image_token_index < cfg.text_config.vocab_size); Ok(Self { language_model: TextModel::new( &cfg.text_config, @@ -43,19 +53,55 @@ impl Gemma3Model { normal_loading_metadata, attention_mechanism, )?, + multi_modal_projector: Gemma3MultiModalProjector::new( + cfg, + vb.pp("multi_modal_projector"), + )?, + vision_tower: SiglipVisionTransformer::new( + &cfg.vision_config, + vb.pp("vision_tower").pp("vision_model"), + )?, + cfg: cfg.clone(), }) } fn forward( &self, input_ids: &Tensor, + pixel_values: Option, seqlen_offsets: &[usize], context_lens: Vec<(usize, usize)>, metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>, flash_params: &FlashParams, ) -> Result { - self.language_model.forward( + let mut input_embeds = self.language_model.embed_tokens(input_ids)?; + if let Some(pixel_values) = pixel_values { + let dtype = self.vision_tower.dtype(); + let vision_outputs = + self.vision_tower + .forward(&pixel_values.to_dtype(dtype)?, None, None)?; + let image_features = self.multi_modal_projector.forward(&vision_outputs)?; + + let special_image_mask = input_ids + .eq(self.cfg.image_token_index as f64)? + .unsqueeze(D::Minus1)? + .broadcast_as(input_embeds.shape())? 
+ .to_dtype(DType::U32)?; + + let mask_flat = special_image_mask.flatten_all()?; + let mut x_flat = input_embeds.flatten_all()?; + let src_flat = image_features.flatten_all()?; + + let indices = mask_flat.nonzero()?.squeeze(1)?; + let current_vals = x_flat.gather(&indices, 0)?; + let diff = (src_flat - current_vals)?; + x_flat = x_flat.scatter_add(&indices, &diff, 0)?; + + input_embeds = x_flat.reshape(input_embeds.shape())?; + }; + self.language_model.forward_embeds( input_ids, + input_embeds, seqlen_offsets, context_lens, metadata, @@ -75,7 +121,17 @@ impl IsqModel for Gemma3Model { } fn residual_tensors(&self) -> Vec<(String, Tensor)> { - self.language_model.residual_tensors() + let uvb = UnVarBuilder::new(); + + uvb.pp("multi_modal_projector") + .extend(self.multi_modal_projector.residual_tensors()); + uvb.pp("language_model") + .extend(self.language_model.residual_tensors()); + uvb.pp("vision_tower") + .pp("vision_model") + .extend(self.vision_tower.residual_tensors()); + + uvb.to_safetensors() } fn imatrix_names(&self) -> candle_core::Result>> { @@ -89,7 +145,7 @@ impl VisionModel for Gemma3Model { fn forward( &self, input_ids: &Tensor, - _pixel_values: Option, + pixel_values: Option, seqlen_offsets: &[usize], context_lens: Vec<(usize, usize)>, _position_ids: Vec, @@ -99,6 +155,7 @@ impl VisionModel for Gemma3Model { ) -> candle_core::Result { self.forward( input_ids, + pixel_values, seqlen_offsets, context_lens, metadata, diff --git a/mistralrs-core/src/vision_models/gemma3/text.rs b/mistralrs-core/src/vision_models/gemma3/text.rs index bf11d4c55d..e081ba3291 100644 --- a/mistralrs-core/src/vision_models/gemma3/text.rs +++ b/mistralrs-core/src/vision_models/gemma3/text.rs @@ -536,15 +536,19 @@ impl TextModel { }) } - pub fn forward( + pub fn embed_tokens(&self, input_ids: &Tensor) -> Result { + self.embed_tokens.forward(input_ids) + } + + pub fn forward_embeds( &self, input_ids: &Tensor, + mut xs: Tensor, seqlen_offsets: &[usize], context_lens: Vec<(usize, usize)>, metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>, flash_params: &FlashParams, ) -> Result { - let mut xs = self.embed_tokens.forward(input_ids)?; let cache = &mut self.cache.normal().0; let attention_mask = CausalMasker.make_causal_mask_matrix( input_ids, diff --git a/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs b/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs index 1e951c5be2..428ce544bb 100644 --- a/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs +++ b/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs @@ -232,6 +232,7 @@ impl InputsProcessor for Idefics2ImageProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( seq.take_images() @@ -402,6 +403,7 @@ impl ImagePreProcessor for Idefics2ImageProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs b/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs index d8889cebce..1a4a880123 100644 --- a/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs @@ -209,6 +209,7 @@ impl InputsProcessor for Idefics3ImageProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( seq.take_images() @@ -245,7 +246,7 @@ impl InputsProcessor for Idefics3ImageProcessor { 
seq.set_initial_prompt(sample.clone()); let toks = tokenizer - .encode(sample, true) + .encode(sample, false) .expect("Detokenization failed!"); let ids = toks.get_ids().to_vec(); @@ -599,6 +600,7 @@ impl ImagePreProcessor for Idefics3ImageProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/image_processor.rs b/mistralrs-core/src/vision_models/image_processor.rs index dd5b3d2eb5..c672e32272 100644 --- a/mistralrs-core/src/vision_models/image_processor.rs +++ b/mistralrs-core/src/vision_models/image_processor.rs @@ -36,6 +36,8 @@ pub(crate) struct PreprocessedImages { pub(crate) tgt_sizes: Option, /// Without batch size. Per image. (w,h). pub(crate) image_sizes_all: Option>, + /// Without batch size + pub(crate) num_crops: Option>, } /// ImagePreProcessor: process images for the model (similar to `InputsProcessor`, typically called by it) diff --git a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs index dc9fa787a6..e90aefb7ec 100644 --- a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs @@ -141,6 +141,7 @@ impl InputsProcessor for LLaVAInputProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( imgs.clone(), @@ -238,7 +239,7 @@ impl InputsProcessor for LLaVAInputProcessor { .map(|s| { // we don't use encode_batch here, because encode_batch will pad 0 to the end of the shor sequences, which will cause the image_ids_pad to be wrong. tokenizer - .encode(*s, true) + .encode(*s, false) .unwrap() .get_ids() .to_vec() @@ -398,6 +399,7 @@ impl ImagePreProcessor for LLaVAInputProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs index 9e6a1195e8..a3c1db2d65 100644 --- a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs @@ -154,6 +154,7 @@ impl InputsProcessor for LLaVANextInputProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( imgs.clone(), @@ -282,7 +283,7 @@ impl InputsProcessor for LLaVANextInputProcessor { .map(|s| { // we don't use encode_batch here, because encode_batch will pad 0 to the end of the shor sequences, which will cause the image_ids_pad to be wrong. 
tokenizer - .encode(*s, true) + .encode(*s, false) .unwrap() .get_ids() .to_vec() @@ -458,6 +459,7 @@ impl ImagePreProcessor for LLaVANextInputProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs b/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs index e354381807..74f2f10c24 100644 --- a/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs @@ -208,6 +208,7 @@ impl InputsProcessor for MiniCpmOImageProcessor { pixel_values_list, tgt_sizes, image_sizes_all, + num_crops: _, } = self .preprocess( seq.take_images() @@ -274,7 +275,7 @@ impl InputsProcessor for MiniCpmOImageProcessor { .im_start_token .clone() .unwrap_or(DEFAULT_IM_START_TOKEN.to_string()), - true, + false, ) .unwrap() .get_ids()[0]; @@ -284,7 +285,7 @@ impl InputsProcessor for MiniCpmOImageProcessor { .im_end_token .clone() .unwrap_or(DEFAULT_IM_END_TOKEN.to_string()), - true, + false, ) .unwrap() .get_ids()[0]; @@ -294,7 +295,7 @@ impl InputsProcessor for MiniCpmOImageProcessor { .slice_start_token .clone() .unwrap_or(DEFAULT_SLICE_START_TOKEN.to_string()), - true, + false, ) .unwrap() .get_ids()[0]; @@ -304,13 +305,13 @@ impl InputsProcessor for MiniCpmOImageProcessor { .slice_end_token .clone() .unwrap_or(DEFAULT_SLICE_END_TOKEN.to_string()), - true, + false, ) .unwrap() .get_ids()[0]; let input_ids = tokenizer - .encode(final_text, true) + .encode(final_text, false) .unwrap() .get_ids() .to_vec(); @@ -779,6 +780,7 @@ impl ImagePreProcessor for MiniCpmOImageProcessor { pixel_values_list: Some(pixel_values), tgt_sizes: Some(tgt_sizes), image_sizes_all: Some(image_sizes), + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs index a7259ac6d8..712fd82685 100644 --- a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs @@ -310,6 +310,7 @@ impl InputsProcessor for MLlamaImageProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( seq.take_images() @@ -828,6 +829,7 @@ impl ImagePreProcessor for MLlamaImageProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs b/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs index 15e9552c3a..62b29baaee 100644 --- a/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs +++ b/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs @@ -139,6 +139,7 @@ impl InputsProcessor for Phi3InputsProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( imgs, @@ -567,6 +568,7 @@ impl ImagePreProcessor for Phi3InputsProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/phi4/inputs_processor.rs b/mistralrs-core/src/vision_models/phi4/inputs_processor.rs index 0877187bd8..2110231583 100644 --- a/mistralrs-core/src/vision_models/phi4/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/phi4/inputs_processor.rs @@ -137,6 +137,7 @@ impl InputsProcessor for Phi4MMInputsProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all, + num_crops: _, } = self .preprocess( imgs, @@ 
-243,7 +244,7 @@ impl InputsProcessor for Phi4MMInputsProcessor { seq.set_toks_and_reallocate( tokenizer - .encode(detokenized.clone(), true) + .encode(detokenized.clone(), false) .expect("Encode failed") .get_ids() .to_vec(), @@ -659,6 +660,7 @@ impl ImagePreProcessor for Phi4MMInputsProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: Some(image_sizes), + num_crops: None, }) } } diff --git a/mistralrs-core/src/vision_models/preprocessor_config.rs b/mistralrs-core/src/vision_models/preprocessor_config.rs index a2d577d659..a3ee94e688 100644 --- a/mistralrs-core/src/vision_models/preprocessor_config.rs +++ b/mistralrs-core/src/vision_models/preprocessor_config.rs @@ -19,6 +19,7 @@ pub struct PreProcessorConfig { #[serde(alias = "norm_std")] pub(crate) image_std: Option<[f64; 3]>, pub(crate) rescale_factor: Option, + #[serde(alias = "resample")] pub(crate) resampling: Option, pub(crate) max_image_size: Option>, pub(crate) size: Option>, @@ -44,6 +45,12 @@ pub struct PreProcessorConfig { pub(crate) im_id_start: Option, pub(crate) im_id_end: Option, pub(crate) dynamic_hd: Option, + #[serde(alias = "image_seq_length")] + pub(crate) image_seq_len: Option, + pub(crate) pan_and_scan_min_crop_size: Option, + pub(crate) pan_and_scan_max_num_crops: Option, + pub(crate) pan_and_scan_min_ratio_to_activate: Option, + pub(crate) do_pan_and_scan: Option, } #[allow(dead_code)] diff --git a/mistralrs-core/src/vision_models/processor_config.rs b/mistralrs-core/src/vision_models/processor_config.rs index fba572b558..0e05d7841f 100644 --- a/mistralrs-core/src/vision_models/processor_config.rs +++ b/mistralrs-core/src/vision_models/processor_config.rs @@ -4,5 +4,6 @@ use serde::Deserialize; #[derive(Deserialize, Debug, Default)] pub struct ProcessorConfig { pub(crate) chat_template: Option, + #[serde(alias = "image_seq_length")] pub(crate) image_seq_len: Option, } diff --git a/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs b/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs index 17fd0aacff..70024de4ad 100644 --- a/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs @@ -215,6 +215,7 @@ impl InputsProcessor for Qwen2_5VLImageProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( seq.clone_images() @@ -326,7 +327,7 @@ impl InputsProcessor for Qwen2_5VLImageProcessor { seq.set_initial_prompt(detok.clone()); let toks = tokenizer - .encode(detok, true) + .encode(detok, false) .expect("Detokenization failed!"); let ids = toks.get_ids().to_vec(); @@ -378,7 +379,7 @@ impl InputsProcessor for Qwen2_5VLImageProcessor { ); let ids = tokenizer - .encode(prompt, true) + .encode(prompt, false) .expect("Tokenization failed!"); input_ids_searching.push(ids.get_ids().to_vec()); @@ -696,6 +697,7 @@ impl ImagePreProcessor for Qwen2_5VLImageProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }); } @@ -735,6 +737,7 @@ impl ImagePreProcessor for Qwen2_5VLImageProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }); } unreachable!() diff --git a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs index ea2eaab205..4a1e0d1ef7 100644 --- a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs @@ -264,6 +264,7 @@ impl 
InputsProcessor for Qwen2VLImageProcessor { pixel_values_list: _, tgt_sizes: _, image_sizes_all: _, + num_crops: _, } = self .preprocess( seq.clone_images() @@ -371,7 +372,7 @@ impl InputsProcessor for Qwen2VLImageProcessor { seq.set_initial_prompt(detok.clone()); let toks = tokenizer - .encode(detok, true) + .encode(detok, false) .expect("Detokenization failed!"); let ids = toks.get_ids().to_vec(); @@ -422,7 +423,7 @@ impl InputsProcessor for Qwen2VLImageProcessor { ); let ids = tokenizer - .encode(prompt, true) + .encode(prompt, false) .expect("Tokenization failed!"); input_ids_searching.push(ids.get_ids().to_vec()); @@ -693,6 +694,7 @@ impl ImagePreProcessor for Qwen2VLImageProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }); } @@ -732,6 +734,7 @@ impl ImagePreProcessor for Qwen2VLImageProcessor { pixel_values_list: None, tgt_sizes: None, image_sizes_all: None, + num_crops: None, }); } unreachable!() diff --git a/mistralrs/examples/gemma3/main.rs b/mistralrs/examples/gemma3/main.rs new file mode 100644 index 0000000000..83a644d8b7 --- /dev/null +++ b/mistralrs/examples/gemma3/main.rs @@ -0,0 +1,36 @@ +use anyhow::Result; +use mistralrs::{IsqType, TextMessageRole, VisionLoaderType, VisionMessages, VisionModelBuilder}; + +#[tokio::main] +async fn main() -> Result<()> { + let model = VisionModelBuilder::new("google/gemma-3-12b-it", VisionLoaderType::Gemma3) + .with_isq(IsqType::Q4K) + .with_logging() + .build() + .await?; + + let bytes = match reqwest::blocking::get( + "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg", + ) { + Ok(http_resp) => http_resp.bytes()?.to_vec(), + Err(e) => anyhow::bail!(e), + }; + let image = image::load_from_memory(&bytes)?; + + let messages = VisionMessages::new().add_image_message( + TextMessageRole::User, + "What is this?", + image, + &model, + )?; + + let response = model.send_chat_request(messages).await?; + + println!("{}", response.choices[0].message.content.as_ref().unwrap()); + dbg!( + response.usage.avg_prompt_tok_per_sec, + response.usage.avg_compl_tok_per_sec + ); + + Ok(()) +}
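+
+// Running this example is roughly the following (a sketch: the exact `--features ...`
+// value depends on your accelerator, as described in the README, and the example name
+// is assumed to follow the `examples/gemma3/main.rs` directory name):
+//
+//     cargo run --release --features ... --example gemma3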