diff --git a/README.md b/README.md
index c8d783f0b4..6ce1a57eb9 100644
--- a/README.md
+++ b/README.md
@@ -31,7 +31,7 @@ Please submit requests for new models [here](https://github.com/EricLBuehler/mis
- Check out UQFF for prequantized models of various methods!
- Models can be found [here](https://huggingface.co/collections/EricB/uqff-670e4a49d56ecdd3f7f0fd4c).
-- 💎💎💎 Run the **Gemma 3** Model (*text only for now, vision coming very soon!*):
+- 💎💎💎 Run the **Gemma 3** Model with 128k context length and vision support: [documentation](docs/GEMMA3.md)
```
./mistralrs-server -i vision-plain -m google/gemma-3-4b-it -a gemma3
diff --git a/docs/GEMMA3.md b/docs/GEMMA3.md
new file mode 100644
index 0000000000..9420f50477
--- /dev/null
+++ b/docs/GEMMA3.md
@@ -0,0 +1,192 @@
+# Gemma 3 Model: [`google/gemma-3-4b-it`](https://huggingface.co/google/gemma-3-4b-it)
+
+Gemma 3 is a family of multimodal (text+vision) models with a 128k context length. The collection can be found [here](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d), with model sizes ranging from 4B to 27B.
+
+We support the Gemma 3 Model in the Rust, Python, and HTTP APIs, including ISQ for increased performance.
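+
+For example, ISQ can be applied when launching the interactive server (a sketch; `Q4K` is one of the supported ISQ types, and the flags mirror the other examples in this repository):
+
+```
+./mistralrs-server --isq Q4K -i vision-plain -m google/gemma-3-4b-it -a gemma3
+```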
+
+The Python and HTTP APIs support sending images as:
+- URL
+- Path to a local image
+- [Base64](https://en.wikipedia.org/wiki/Base64) encoded string
+
+The Rust API takes an image from the [image](https://docs.rs/image/latest/image/index.html) crate.
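+
+As a minimal sketch, the three `image_url` forms accepted by the Python and HTTP APIs look like this (the file name is illustrative; the linked base64 and local-image examples below show the exact expected encodings):
+
+```py
+import base64
+
+url_form = {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}}
+path_form = {"type": "image_url", "image_url": {"url": "my_image.jpg"}}  # local file path
+with open("my_image.jpg", "rb") as f:
+    b64_form = {"type": "image_url", "image_url": {"url": base64.b64encode(f.read()).decode()}}
+```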
+
+## HTTP server
+You can find this example [here](../examples/server/gemma3.py).
+
+We support an OpenAI-compatible HTTP API for vision models. This example demonstrates sending a chat completion request with an image.
+
+> Note: The image_url may be either a path, URL, or a base64 encoded string.
+
+---
+
+**Image:**
+
+<img src="https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg" alt="A mountain with snow on it."/>
+
+**Prompt:**
+```
+What is this?
+```
+
+**Output:**
+```
+The image shows Mount Washington in New Hampshire, USA. It's a prominent peak in the White Mountains, known for its extreme weather conditions and being the highest peak in the Northeastern United States. The image captures it covered in snow with a dramatic sky above. The structures at the summit are communication towers.
+
+The winding path visible on the mountain slopes appears to be part of the Mount Washington Auto Road, a historic road that allows vehicles to drive to the summit.
+```
+
+---
+
+1) Start the server
+
+> [!NOTE]
+> You should replace `--features ...` with one of the features specified [here](../README.md#supported-accelerators), or remove it for pure CPU inference.
+
+```
+cargo run --release --features ... -- --port 1234 vision-plain -m google/gemma-3-12b-it -a gemma3
+```
+
+2) Send a request
+
+```py
+from openai import OpenAI
+import httpx
+import textwrap
+import json
+
+
+client = OpenAI(api_key="foobar", base_url="http://localhost:1234/v1/")
+
+
+completion = client.chat.completions.create(
+ model="gemma3",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg"
+ },
+ },
+ {
+ "type": "text",
+ "text": "What is this?",
+ },
+ ],
+ },
+ ],
+ max_tokens=256,
+ frequency_penalty=1.0,
+ top_p=0.1,
+ temperature=0,
+)
+resp = completion.choices[0].message.content
+print(resp)
+
+```
+
+- You can find an example of encoding the [image via base64 here](../examples/server/phi3v_base64.py).
+- You can find an example of loading an [image locally here](../examples/server/phi3v_local_img.py).
+
+---
+
+## Rust
+You can find this example [here](../mistralrs/examples/gemma3/main.rs).
+
+This is a minimal example of running the Gemma 3 model with an image.
+
+```rust
+use anyhow::Result;
+use mistralrs::{IsqType, TextMessageRole, VisionLoaderType, VisionMessages, VisionModelBuilder};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ let model =
+ VisionModelBuilder::new("google/gemma-3-12b-it", VisionLoaderType::Gemma3)
+ .with_isq(IsqType::Q4K)
+ .with_logging()
+ .build()
+ .await?;
+
+ let bytes = match reqwest::blocking::get(
+ "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg",
+ ) {
+ Ok(http_resp) => http_resp.bytes()?.to_vec(),
+ Err(e) => anyhow::bail!(e),
+ };
+ let image = image::load_from_memory(&bytes)?;
+
+ let messages = VisionMessages::new().add_image_message(
+ TextMessageRole::User,
+ "What is depicted here? Please describe the scene in detail.",
+ image,
+ &model,
+ )?;
+
+ let response = model.send_chat_request(messages).await?;
+
+ println!("{}", response.choices[0].message.content.as_ref().unwrap());
+ dbg!(
+ response.usage.avg_prompt_tok_per_sec,
+ response.usage.avg_compl_tok_per_sec
+ );
+
+ Ok(())
+}
+```
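+
+To try it, something like `cargo run --release --features ... --example gemma3` from the repository root should work, swapping `--features ...` for your accelerator as in the HTTP example above.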
+
+## Python
+You can find this example [here](../examples/python/gemma3.py).
+
+This example demonstrates loading the model and sending a chat completion request with an image.
+
+> Note: the image_url may be either a path, URL, or a base64 encoded string.
+
+```py
+from mistralrs import Runner, Which, ChatCompletionRequest, VisionArchitecture
+
+runner = Runner(
+ which=Which.VisionPlain(
+ model_id="google/gemma-3-12b-it",
+ arch=VisionArchitecture.Gemma3,
+ ),
+)
+
+res = runner.send_chat_completion_request(
+ ChatCompletionRequest(
+ model="gemma3",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg"
+ },
+ },
+ {
+ "type": "text",
+ "text": "What is this?",
+ },
+ ],
+ }
+ ],
+ max_tokens=256,
+ presence_penalty=1.0,
+ top_p=0.1,
+ temperature=0.1,
+ )
+)
+print(res.choices[0].message.content)
+print(res.usage)
+
+```
+
+- You can find an example of encoding the [image via base64 here](../examples/python/phi3v_base64.py).
+- You can find an example of loading an [image locally here](../examples/python/phi3v_local_img.py).
\ No newline at end of file
diff --git a/examples/python/gemma3.py b/examples/python/gemma3.py
new file mode 100644
index 0000000000..c52acaa83f
--- /dev/null
+++ b/examples/python/gemma3.py
@@ -0,0 +1,37 @@
+from mistralrs import Runner, Which, ChatCompletionRequest, VisionArchitecture
+
+runner = Runner(
+ which=Which.VisionPlain(
+ model_id="google/gemma-3-12b-it",
+ arch=VisionArchitecture.Gemma3,
+ ),
+)
+
+res = runner.send_chat_completion_request(
+ ChatCompletionRequest(
+ model="gemma3",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg"
+ },
+ },
+ {
+ "type": "text",
+ "text": "What is this?",
+ },
+ ],
+ }
+ ],
+ max_tokens=256,
+ presence_penalty=1.0,
+ top_p=0.1,
+ temperature=0.1,
+ )
+)
+print(res.choices[0].message.content)
+print(res.usage)
diff --git a/examples/server/gemma3.py b/examples/server/gemma3.py
new file mode 100644
index 0000000000..b09ac850e7
--- /dev/null
+++ b/examples/server/gemma3.py
@@ -0,0 +1,63 @@
+from openai import OpenAI
+import httpx
+import textwrap
+import json
+
+
+def log_response(response: httpx.Response):
+ request = response.request
+ print(f"Request: {request.method} {request.url}")
+ print(" Headers:")
+ for key, value in request.headers.items():
+ if key.lower() == "authorization":
+ value = "[...]"
+ if key.lower() == "cookie":
+ value = value.split("=")[0] + "=..."
+ print(f" {key}: {value}")
+ print(" Body:")
+ try:
+ request_body = json.loads(request.content)
+ print(textwrap.indent(json.dumps(request_body, indent=2), " "))
+ except json.JSONDecodeError:
+ print(textwrap.indent(request.content.decode(), " "))
+ print(f"Response: status_code={response.status_code}")
+ print(" Headers:")
+ for key, value in response.headers.items():
+ if key.lower() == "set-cookie":
+ value = value.split("=")[0] + "=..."
+ print(f" {key}: {value}")
+
+
+client = OpenAI(api_key="foobar", base_url="http://localhost:1234/v1/")
+
+# Enable this to log requests and responses
+# client._client = httpx.Client(
+# event_hooks={"request": [print], "response": [log_response]}
+# )
+
+completion = client.chat.completions.create(
+ model="gemma3",
+ messages=[
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg"
+ },
+ },
+ {
+ "type": "text",
+ "text": "What is this?",
+ },
+ ],
+ },
+ ],
+ max_tokens=256,
+ frequency_penalty=1.0,
+ top_p=0.1,
+ temperature=0,
+)
+resp = completion.choices[0].message.content
+print(resp)
diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs
index 057636b0c0..858b24d493 100644
--- a/mistralrs-core/src/attention.rs
+++ b/mistralrs-core/src/attention.rs
@@ -226,7 +226,7 @@ fn naive_sdpa(
candle_nn::ops::inplace_attn_softmax_last_dim(
&mut att,
- &mask,
+ &mask.contiguous()?,
sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0),
)?;
diff --git a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs
index 2ed518bfed..55e1528a78 100644
--- a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs
+++ b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs
@@ -3134,11 +3134,11 @@ impl VisionModelLoader for Gemma3Loader {
fn get_processor(
&self,
_model_config: &str,
- _processor_config: Option<ProcessorConfig>,
+ processor_config: Option<ProcessorConfig>,
_preprocessor_config: PreProcessorConfig,
_max_edge: Option<u32>,
) -> Arc<dyn Processor + Send + Sync> {
- Arc::new(Gemma3Processor)
+ Arc::new(Gemma3Processor::new(processor_config.unwrap()))
}
fn supports_paged_attention(&self) -> bool {
true
diff --git a/mistralrs-core/src/vision_models/gemma3/config.rs b/mistralrs-core/src/vision_models/gemma3/config.rs
index 6bdd5508f0..11c6761f89 100644
--- a/mistralrs-core/src/vision_models/gemma3/config.rs
+++ b/mistralrs-core/src/vision_models/gemma3/config.rs
@@ -3,6 +3,7 @@ use mistralrs_quant::QuantizedConfig;
use crate::{
layers::{Activation, Gemma3RopeScalingConfig},
serde_default_fn,
+ vision_models::siglip::SiglipVisionConfig,
};
serde_default_fn!(bool, attention_bias, false);
@@ -63,4 +64,7 @@ pub struct Gemma3TextConfig {
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Gemma3Config {
pub text_config: Gemma3TextConfig,
+ pub vision_config: SiglipVisionConfig,
+ pub image_token_index: usize,
+ pub mm_tokens_per_image: usize,
}
diff --git a/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs b/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs
index c9984d6dfa..6a9486e9c2 100644
--- a/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs
@@ -2,8 +2,11 @@
use std::{any::Any, num::NonZeroUsize, sync::Arc};
-use candle_core::{Device, Result};
-use image::DynamicImage;
+use candle_core::{Device, Result, Tensor};
+use image::{DynamicImage, GenericImageView};
+use itertools::Itertools;
+use mistralrs_vision::{ApplyTransforms, Normalize, Rescale, ToTensorNoNorm, Transforms};
+use regex::Regex;
use tokenizers::Tokenizer;
use tracing::warn;
@@ -18,28 +21,51 @@ use crate::{
sequence::Sequence,
vision_models::{
image_processor::{ImagePreProcessor, PreprocessedImages},
- preprocessor_config::PreProcessorConfig,
+ preprocessor_config::{PreProcessorConfig, ToFilter},
+ processor_config::ProcessorConfig,
ModelInputs,
},
};
use super::Gemma3SpecificArgs;
-struct Gemma3ImageProcessor;
+struct Gemma3ImageProcessor {
+ full_image_sequence: String,
+}
+
+const IMAGE_TOKEN: &str = "<image_soft_token>";
+const BOI_TOKEN: &str = "<start_of_image>";
+const EOI_TOKEN: &str = "<end_of_image>";
-pub struct Gemma3Processor;
+pub struct Gemma3Processor {
+ full_image_sequence: String,
+}
+
+impl Gemma3Processor {
+ pub fn new(processor_config: ProcessorConfig) -> Self {
+ let image_tokens_expanded =
+ vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join("");
+ let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n");
+
+ Self {
+ full_image_sequence,
+ }
+ }
+}
impl Processor for Gemma3Processor {
fn inputs_processor(&self) -> Arc<dyn InputsProcessor> {
- Arc::new(Gemma3ImageProcessor)
+ Arc::new(Gemma3ImageProcessor {
+ full_image_sequence: self.full_image_sequence.clone(),
+ })
}
fn get_special_tokens(&self) -> &[&'static str] {
- &[]
+ &[BOI_TOKEN, EOI_TOKEN, IMAGE_TOKEN]
}
fn template_action(&self) -> MessagesAction {
- MessagesAction::FlattenOnlyText
+ MessagesAction::Keep
}
}
@@ -49,7 +75,7 @@ impl InputsProcessor for Gemma3ImageProcessor {
}
fn process_inputs(
&self,
- _tokenizer: Option<Arc<Tokenizer>>,
+ tokenizer: Option<Arc<Tokenizer>>,
input_seqs: &mut [&mut Sequence],
is_prompt: bool,
is_xlora: bool,
@@ -57,7 +83,7 @@ impl InputsProcessor for Gemma3ImageProcessor {
no_kv_cache: bool,
last_n_context_len: Option<(usize, usize)>,
return_raw_logits: bool,
- _other_config: Option<Arc<dyn Any>>,
+ other_config: Option<Arc<dyn Any>>,
mut paged_attn_metadata: Option>,
prompt_chunksize: Option<NonZeroUsize>,
mapper: Option<&dyn DeviceMapper>,
@@ -76,6 +102,11 @@ impl InputsProcessor for Gemma3ImageProcessor {
if prompt_chunksize.is_some() {
warn!("`prompt_chunksize` is set. Gemma3 does not support prompt batching.");
}
+ let Some(tokenizer) = tokenizer else {
+ return Box::new(std::iter::once(Err(anyhow::Error::msg(
+ "Idefics3ImageProcessor requires a specified tokenizer.",
+ ))));
+ };
let text_models_inputs_processor::InnerInputProcessorOutput {
inputs:
@@ -125,12 +156,112 @@ impl InputsProcessor for Gemma3ImageProcessor {
.unwrap()
};
+ let config = other_config.expect("Need a PreProcessorConfig config.");
+ let config: &PreProcessorConfig = config.downcast_ref().expect("Downcast failed.");
+
+ let has_images = input_seqs
+ .iter()
+ .all(|seq| seq.images().is_some_and(|images| !images.is_empty()));
+
+ let (new_input, pixel_values) = if has_images {
+ let mut pixel_values_accum = Vec::new();
+ let mut all_ids = Vec::new();
+ let re = Regex::new(BOI_TOKEN).unwrap();
+ for seq in input_seqs.iter_mut() {
+ let PreprocessedImages {
+ pixel_values,
+ pixel_attention_mask: _,
+ image_sizes: _,
+ num_img_tokens: _,
+ aspect_ratio_ids: _,
+ aspect_ratio_mask: _,
+ num_tiles: _,
+ image_grid_thw: _,
+ video_grid_thw: _,
+ rows: _,
+ cols: _,
+ pixel_values_list: _,
+ tgt_sizes: _,
+ image_sizes_all: _,
+ num_crops,
+ } = self
+ .preprocess(
+ seq.take_images()
+ .expect("Need to have images by this point."),
+ vec![],
+ config,
+ device,
+ (usize::MAX, usize::MAX), // Don't use it here...
+ )
+ .expect("Preprocessing failed");
+
+ let num_crops = num_crops.unwrap();
+
+ // Deliberately no .unsqueeze here
+ pixel_values_accum.push(pixel_values.clone());
+
+ let mut prompt = tokenizer
+ .decode(seq.get_toks(), false)
+ .expect("Detokenization failed!");
+
+ let image_indexes: Vec<usize> =
+ re.find_iter(&prompt).map(|mat| mat.start()).collect();
+
+ assert_ne!(pixel_values.dim(0).unwrap(), image_indexes.len());
+
+ for (num, idx) in num_crops.into_iter().zip(image_indexes).rev() {
+ if num != 0 {
+ let formatted_image_text = format!(
+ "Here is the original image {BOI_TOKEN} and here are some crops to help you see better {}", vec![BOI_TOKEN.to_string(); num].join(" ")
+ );
+ prompt = format!(
+ "{}{formatted_image_text}{}",
+ &prompt[..idx],
+ &prompt[idx + BOI_TOKEN.len()..]
+ );
+ }
+ }
+
+ prompt = prompt.replace(BOI_TOKEN, &self.full_image_sequence);
+
+ seq.set_initial_prompt(prompt.clone());
+ let toks = tokenizer
+ .encode(prompt, false)
+ .expect("Detokenization failed!");
+
+ let ids = toks.get_ids().to_vec();
+ all_ids.push(ids.clone());
+
+ seq.set_toks_and_reallocate(ids, paged_attn_metadata.as_mut());
+ }
+
+ let mut all_ids_new = Vec::new();
+ let max_len = all_ids.iter().map(|ids| ids.len()).max().unwrap();
+ for ids in all_ids {
+ let pad = max_len - ids.len();
+ all_ids_new
+ .push(Tensor::new([ids, vec![0; pad]].concat(), input.device()).unwrap());
+ }
+
+ (
+ Some(Tensor::stack(&all_ids_new, 0).unwrap()),
+ Some(Tensor::cat(&pixel_values_accum, 0).unwrap()),
+ )
+ } else {
+ (None, None)
+ };
+
+ let input = match new_input {
+ Some(new_input) => new_input,
+ None => input,
+ };
+
let inputs: Box<dyn Any> = Box::new(ModelInputs {
input_ids: input,
seqlen_offsets: positions,
context_lens,
position_ids,
- pixel_values: None,
+ pixel_values,
model_specific_args: Box::new(Gemma3SpecificArgs),
paged_attn_meta,
flash_meta,
@@ -142,18 +273,193 @@ impl InputsProcessor for Gemma3ImageProcessor {
}
}
+impl Gemma3ImageProcessor {
+ fn pan_and_scan(
+ &self,
+ image: &DynamicImage,
+ pan_and_scan_min_crop_size: usize,
+ pan_and_scan_max_num_crops: usize,
+ pan_and_scan_min_ratio_to_activate: f64,
+ ) -> Vec<DynamicImage> {
+ let (width, height) = image.dimensions();
+
+ let (num_crops_w, num_crops_h) = if width >= height {
+ if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
+ return vec![];
+ }
+
+ // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
+ let mut num_crops_w = (width as f64 / height as f64 + 0.5).floor() as usize;
+ num_crops_w = num_crops_w
+ .min((width as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);
+
+ // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
+ num_crops_w = num_crops_w.max(2);
+ num_crops_w = num_crops_w.min(pan_and_scan_max_num_crops);
+
+ (num_crops_w, 1)
+ } else {
+ if (width as f64 / height as f64) < pan_and_scan_min_ratio_to_activate {
+ return vec![];
+ }
+
+ // Select ideal number of crops close to the image aspect ratio and such that crop_size > min_crop_size.
+ let mut num_crops_h = (height as f64 / width as f64 + 0.5).floor() as usize;
+ num_crops_h = num_crops_h
+ .min((height as f64 / pan_and_scan_min_crop_size as f64).floor() as usize);
+
+ // Make sure the number of crops is in range [2, pan_and_scan_max_num_crops].
+ num_crops_h = num_crops_h.max(2);
+ num_crops_h = num_crops_h.min(pan_and_scan_max_num_crops);
+
+ (1, num_crops_h)
+ };
+
+ let crop_size_w = (width as f64 / num_crops_w as f64).ceil() as usize;
+ let crop_size_h = (height as f64 / num_crops_h as f64).ceil() as usize;
+
+ if crop_size_w.min(crop_size_h) < pan_and_scan_min_crop_size {
+ return vec![];
+ }
+
+ let crop_positions_w = (0..num_crops_w)
+ .map(|i| i * crop_size_w)
+ .collect::<Vec<_>>();
+ let crop_positions_h = (0..num_crops_h)
+ .map(|i| i * crop_size_h)
+ .collect::<Vec<_>>();
+
+ let mut image_crops = Vec::new();
+ for (pos_h, pos_w) in crop_positions_h
+ .into_iter()
+ .cartesian_product(crop_positions_w)
+ {
+ image_crops.push(image.crop_imm(
+ pos_w as u32,
+ pos_h as u32,
+ crop_size_w as u32,
+ crop_size_h as u32,
+ ));
+ }
+
+ image_crops
+ }
+
+ fn process_images_for_pan_and_scan(
+ &self,
+ images: Vec<DynamicImage>,
+ pan_and_scan_min_crop_size: usize,
+ pan_and_scan_max_num_crops: usize,
+ pan_and_scan_min_ratio_to_activate: f64,
+ ) -> (Vec<DynamicImage>, Vec<usize>) {
+ let mut pas_images_list = Vec::new();
+ let mut num_crops = Vec::new();
+
+ for image in images {
+ let pas_images = self.pan_and_scan(
+ &image,
+ pan_and_scan_min_crop_size,
+ pan_and_scan_max_num_crops,
+ pan_and_scan_min_ratio_to_activate,
+ );
+ num_crops.push(pas_images.len());
+ pas_images_list.extend([vec![image], pas_images].concat());
+ }
+
+ (pas_images_list, num_crops)
+ }
+}
+
impl ImagePreProcessor for Gemma3ImageProcessor {
const DEFAULT_MEAN: [f64; 3] = [0.5, 0.5, 0.5];
const DEFAULT_STD: [f64; 3] = [0.5, 0.5, 0.5];
fn preprocess(
&self,
- _images: Vec<DynamicImage>,
- _videos: Vec<Vec<DynamicImage>>,
- _config: &PreProcessorConfig,
- _device: &Device,
+ mut images: Vec<DynamicImage>,
+ videos: Vec<Vec<DynamicImage>>,
+ config: &PreProcessorConfig,
+ device: &Device,
(_bs, _max_num_images): (usize, usize),
) -> Result {
- todo!()
+ assert!(videos.is_empty());
+
+ let do_resize = config.do_resize.unwrap();
+ let size = config.size.as_ref().unwrap();
+ let (height, width) = (size["height"], size["width"]);
+ let resample = config.resampling.to_filter()?;
+ let do_rescale = config.do_rescale.unwrap();
+ let rescale_factor = config.rescale_factor.unwrap();
+ let do_normalize = config.do_normalize.unwrap();
+ let image_mean = config.image_mean.unwrap_or(Self::DEFAULT_MEAN);
+ let image_std = config.image_std.unwrap_or(Self::DEFAULT_STD);
+ let do_convert_rgb = config.do_convert_rgb.unwrap_or(true);
+ let do_pan_and_scan = config.do_pan_and_scan.unwrap_or(do_convert_rgb);
+ // https://github.com/huggingface/transformers/blob/ea219ed164bead55a5513e8cfaa17a25d5613b9e/src/transformers/models/gemma3/processing_gemma3.py#L42
+ let pan_and_scan_min_crop_size = config.pan_and_scan_min_crop_size.unwrap_or(256);
+ let pan_and_scan_max_num_crops = config.pan_and_scan_max_num_crops.unwrap_or(4);
+ let pan_and_scan_min_ratio_to_activate =
+ config.pan_and_scan_min_ratio_to_activate.unwrap_or(1.2);
+
+ for image in images.iter_mut() {
+ // Convert to rgb
+ if config.do_convert_rgb.is_some_and(|x| x) {
+ *image = DynamicImage::ImageRgb8(image.to_rgb8());
+ }
+ }
+
+ let num_crops = if do_pan_and_scan {
+ let (new_images, num_crops) = self.process_images_for_pan_and_scan(
+ images,
+ pan_and_scan_min_crop_size,
+ pan_and_scan_max_num_crops,
+ pan_and_scan_min_ratio_to_activate,
+ );
+ images = new_images;
+ num_crops
+ } else {
+ vec![0]
+ };
+
+ let mut pixel_values = Vec::new();
+ for mut image in images {
+ if do_resize {
+ image = image.resize_exact(width, height, resample);
+ }
+
+ let transforms = Transforms {
+ input: &ToTensorNoNorm,
+ inner_transforms: &[
+ &do_rescale.then_some(Rescale {
+ factor: Some(rescale_factor),
+ }),
+ &do_normalize.then(|| Normalize {
+ mean: image_mean.to_vec(),
+ std: image_std.to_vec(),
+ }),
+ ],
+ };
+
+ let image = image.apply(transforms, device)?;
+ pixel_values.push(image.unsqueeze(0)?);
+ }
+
+ Ok(PreprocessedImages {
+ pixel_values: Tensor::cat(&pixel_values, 0)?,
+ pixel_attention_mask: None,
+ image_sizes: None,
+ num_img_tokens: None,
+ aspect_ratio_ids: None,
+ aspect_ratio_mask: None,
+ num_tiles: None,
+ image_grid_thw: None,
+ video_grid_thw: None,
+ rows: None,
+ cols: None,
+ pixel_values_list: None,
+ tgt_sizes: None,
+ image_sizes_all: None,
+ num_crops: Some(num_crops),
+ })
}
}
diff --git a/mistralrs-core/src/vision_models/gemma3/mmproj.rs b/mistralrs-core/src/vision_models/gemma3/mmproj.rs
new file mode 100644
index 0000000000..f5e9a7cda5
--- /dev/null
+++ b/mistralrs-core/src/vision_models/gemma3/mmproj.rs
@@ -0,0 +1,77 @@
+use candle_core::{Result, Tensor};
+use candle_nn::Module;
+use mistralrs_quant::ShardedVarBuilder;
+
+use crate::{
+ layers::{AvgPool2d, RmsNorm},
+ utils::unvarbuilder::UnVarBuilder,
+};
+
+use super::config::Gemma3Config;
+
+pub struct Gemma3MultiModalProjector {
+ mm_input_projection_weight: Tensor,
+ mm_soft_emb_norm: RmsNorm,
+ patches_per_image: usize,
+ avg_pool: AvgPool2d,
+}
+
+impl Gemma3MultiModalProjector {
+ pub fn new(cfg: &Gemma3Config, vb: ShardedVarBuilder) -> Result<Self> {
+ let mm_input_projection_weight = vb.get(
+ (cfg.vision_config.hidden_size, cfg.text_config.hidden_size),
+ "mm_input_projection_weight",
+ )?;
+ let mm_soft_emb_norm = RmsNorm::new_gemma(
+ cfg.vision_config.hidden_size,
+ cfg.vision_config.layer_norm_eps,
+ vb.pp("mm_soft_emb_norm"),
+ )?;
+
+ let patches_per_image = cfg.vision_config.image_size / cfg.vision_config.patch_size;
+ let tokens_per_side = cfg.mm_tokens_per_image.isqrt();
+ let kernel_size = patches_per_image / tokens_per_side;
+ let avg_pool = AvgPool2d::new(kernel_size, kernel_size);
+
+ Ok(Self {
+ mm_input_projection_weight,
+ mm_soft_emb_norm,
+ patches_per_image,
+ avg_pool,
+ })
+ }
+
+ pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+ let (bs, _, seqlen) = xs.dims3()?;
+
+ let mut reshaped_vision_outputs = xs.transpose(1, 2)?;
+ reshaped_vision_outputs = reshaped_vision_outputs.reshape((
+ bs,
+ seqlen,
+ self.patches_per_image,
+ self.patches_per_image,
+ ))?;
+ reshaped_vision_outputs = reshaped_vision_outputs.contiguous()?;
+
+ let mut pooled_vision_outputs = self.avg_pool.forward(&reshaped_vision_outputs)?;
+ pooled_vision_outputs = pooled_vision_outputs.flatten_from(2)?;
+ pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)?;
+
+ let normed_vision_outputs = self.mm_soft_emb_norm.forward(&pooled_vision_outputs)?;
+
+ normed_vision_outputs.broadcast_matmul(&self.mm_input_projection_weight)
+ }
+
+ pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
+ let uvb = UnVarBuilder::new();
+
+ uvb.add_tensor(
+ "mm_input_projection_weight",
+ self.mm_input_projection_weight.clone(),
+ );
+ uvb.pp("mm_soft_emb_norm")
+ .add(&self.mm_soft_emb_norm.undo_gemma().unwrap());
+
+ uvb.to_safetensors()
+ }
+}
diff --git a/mistralrs-core/src/vision_models/gemma3/mod.rs b/mistralrs-core/src/vision_models/gemma3/mod.rs
index 6bc644cf57..6dbb88eb12 100644
--- a/mistralrs-core/src/vision_models/gemma3/mod.rs
+++ b/mistralrs-core/src/vision_models/gemma3/mod.rs
@@ -2,29 +2,38 @@
use std::sync::Arc;
-use candle_core::{Device, Result, Tensor};
+use candle_core::{DType, Device, Result, Tensor, D};
use config::Gemma3Config;
use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
+use mmproj::Gemma3MultiModalProjector;
use text::TextModel;
use crate::{
amoe::{AnyMoeBaseModelMixin, MlpLayer},
device_map::DeviceMapper,
+ ops::NonZeroOp,
paged_attention::{AttentionImplementation, ModelConfigMetadata},
pipeline::{
text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
EitherCache, IsqModel, NormalLoadingMetadata, VisionModel,
},
+ utils::unvarbuilder::UnVarBuilder,
AnyMoeConfig, AnyMoeExpertType,
};
pub mod config;
mod inputs_processor;
+mod mmproj;
mod text;
pub(crate) use inputs_processor::Gemma3Processor;
+use super::siglip::SiglipVisionTransformer;
+
pub struct Gemma3Model {
language_model: TextModel,
+ multi_modal_projector: Gemma3MultiModalProjector,
+ vision_tower: SiglipVisionTransformer,
+ cfg: Gemma3Config,
}
impl Gemma3Model {
@@ -35,6 +44,7 @@ impl Gemma3Model {
normal_loading_metadata: NormalLoadingMetadata,
attention_mechanism: AttentionImplementation,
) -> Result<Self> {
+ assert!(cfg.image_token_index < cfg.text_config.vocab_size);
Ok(Self {
language_model: TextModel::new(
&cfg.text_config,
@@ -43,19 +53,55 @@ impl Gemma3Model {
normal_loading_metadata,
attention_mechanism,
)?,
+ multi_modal_projector: Gemma3MultiModalProjector::new(
+ cfg,
+ vb.pp("multi_modal_projector"),
+ )?,
+ vision_tower: SiglipVisionTransformer::new(
+ &cfg.vision_config,
+ vb.pp("vision_tower").pp("vision_model"),
+ )?,
+ cfg: cfg.clone(),
})
}
fn forward(
&self,
input_ids: &Tensor,
+ pixel_values: Option<Tensor>,
seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
- self.language_model.forward(
+ let mut input_embeds = self.language_model.embed_tokens(input_ids)?;
+ if let Some(pixel_values) = pixel_values {
+ let dtype = self.vision_tower.dtype();
+ let vision_outputs =
+ self.vision_tower
+ .forward(&pixel_values.to_dtype(dtype)?, None, None)?;
+ let image_features = self.multi_modal_projector.forward(&vision_outputs)?;
+
+ let special_image_mask = input_ids
+ .eq(self.cfg.image_token_index as f64)?
+ .unsqueeze(D::Minus1)?
+ .broadcast_as(input_embeds.shape())?
+ .to_dtype(DType::U32)?;
+
+ let mask_flat = special_image_mask.flatten_all()?;
+ let mut x_flat = input_embeds.flatten_all()?;
+ let src_flat = image_features.flatten_all()?;
+
+ let indices = mask_flat.nonzero()?.squeeze(1)?;
+ let current_vals = x_flat.gather(&indices, 0)?;
+ let diff = (src_flat - current_vals)?;
+ x_flat = x_flat.scatter_add(&indices, &diff, 0)?;
+
+ input_embeds = x_flat.reshape(input_embeds.shape())?;
+ };
+ self.language_model.forward_embeds(
input_ids,
+ input_embeds,
seqlen_offsets,
context_lens,
metadata,
@@ -75,7 +121,17 @@ impl IsqModel for Gemma3Model {
}
fn residual_tensors(&self) -> Vec<(String, Tensor)> {
- self.language_model.residual_tensors()
+ let uvb = UnVarBuilder::new();
+
+ uvb.pp("multi_modal_projector")
+ .extend(self.multi_modal_projector.residual_tensors());
+ uvb.pp("language_model")
+ .extend(self.language_model.residual_tensors());
+ uvb.pp("vision_tower")
+ .pp("vision_model")
+ .extend(self.vision_tower.residual_tensors());
+
+ uvb.to_safetensors()
}
fn imatrix_names(&self) -> candle_core::Result<Vec<Option<String>>> {
@@ -89,7 +145,7 @@ impl VisionModel for Gemma3Model {
fn forward(
&self,
input_ids: &Tensor,
- _pixel_values: Option<Tensor>,
+ pixel_values: Option<Tensor>,
seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
_position_ids: Vec<usize>,
@@ -99,6 +155,7 @@ impl VisionModel for Gemma3Model {
) -> candle_core::Result<Tensor> {
self.forward(
input_ids,
+ pixel_values,
seqlen_offsets,
context_lens,
metadata,
diff --git a/mistralrs-core/src/vision_models/gemma3/text.rs b/mistralrs-core/src/vision_models/gemma3/text.rs
index bf11d4c55d..e081ba3291 100644
--- a/mistralrs-core/src/vision_models/gemma3/text.rs
+++ b/mistralrs-core/src/vision_models/gemma3/text.rs
@@ -536,15 +536,19 @@ impl TextModel {
})
}
- pub fn forward(
+ pub fn embed_tokens(&self, input_ids: &Tensor) -> Result<Tensor> {
+ self.embed_tokens.forward(input_ids)
+ }
+
+ pub fn forward_embeds(
&self,
input_ids: &Tensor,
+ mut xs: Tensor,
seqlen_offsets: &[usize],
context_lens: Vec<(usize, usize)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
flash_params: &FlashParams,
) -> Result<Tensor> {
- let mut xs = self.embed_tokens.forward(input_ids)?;
let cache = &mut self.cache.normal().0;
let attention_mask = CausalMasker.make_causal_mask_matrix(
input_ids,
diff --git a/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs b/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs
index 1e951c5be2..428ce544bb 100644
--- a/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs
+++ b/mistralrs-core/src/vision_models/idefics2/idefics2_input_processor.rs
@@ -232,6 +232,7 @@ impl InputsProcessor for Idefics2ImageProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
seq.take_images()
@@ -402,6 +403,7 @@ impl ImagePreProcessor for Idefics2ImageProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs b/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs
index d8889cebce..1a4a880123 100644
--- a/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/idefics3/inputs_processor.rs
@@ -209,6 +209,7 @@ impl InputsProcessor for Idefics3ImageProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
seq.take_images()
@@ -245,7 +246,7 @@ impl InputsProcessor for Idefics3ImageProcessor {
seq.set_initial_prompt(sample.clone());
let toks = tokenizer
- .encode(sample, true)
+ .encode(sample, false)
.expect("Detokenization failed!");
let ids = toks.get_ids().to_vec();
@@ -599,6 +600,7 @@ impl ImagePreProcessor for Idefics3ImageProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/image_processor.rs b/mistralrs-core/src/vision_models/image_processor.rs
index dd5b3d2eb5..c672e32272 100644
--- a/mistralrs-core/src/vision_models/image_processor.rs
+++ b/mistralrs-core/src/vision_models/image_processor.rs
@@ -36,6 +36,8 @@ pub(crate) struct PreprocessedImages {
pub(crate) tgt_sizes: Option,
/// Without batch size. Per image. (w,h).
pub(crate) image_sizes_all: Option<Vec<(u32, u32)>>,
+ /// Without batch size
+ pub(crate) num_crops: Option<Vec<usize>>,
}
/// ImagePreProcessor: process images for the model (similar to `InputsProcessor`, typically called by it)
diff --git a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs
index dc9fa787a6..e90aefb7ec 100644
--- a/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/llava/llava_inputs_processor.rs
@@ -141,6 +141,7 @@ impl InputsProcessor for LLaVAInputProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
imgs.clone(),
@@ -238,7 +239,7 @@ impl InputsProcessor for LLaVAInputProcessor {
.map(|s| {
// we don't use encode_batch here, because encode_batch will pad 0 to the end of the shor sequences, which will cause the image_ids_pad to be wrong.
tokenizer
- .encode(*s, true)
+ .encode(*s, false)
.unwrap()
.get_ids()
.to_vec()
@@ -398,6 +399,7 @@ impl ImagePreProcessor for LLaVAInputProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs
index 9e6a1195e8..a3c1db2d65 100644
--- a/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/llava/llava_next_inputs_processor.rs
@@ -154,6 +154,7 @@ impl InputsProcessor for LLaVANextInputProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
imgs.clone(),
@@ -282,7 +283,7 @@ impl InputsProcessor for LLaVANextInputProcessor {
.map(|s| {
// we don't use encode_batch here, because encode_batch will pad 0 to the end of the shor sequences, which will cause the image_ids_pad to be wrong.
tokenizer
- .encode(*s, true)
+ .encode(*s, false)
.unwrap()
.get_ids()
.to_vec()
@@ -458,6 +459,7 @@ impl ImagePreProcessor for LLaVANextInputProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs b/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs
index e354381807..74f2f10c24 100644
--- a/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/minicpmo/inputs_processor.rs
@@ -208,6 +208,7 @@ impl InputsProcessor for MiniCpmOImageProcessor {
pixel_values_list,
tgt_sizes,
image_sizes_all,
+ num_crops: _,
} = self
.preprocess(
seq.take_images()
@@ -274,7 +275,7 @@ impl InputsProcessor for MiniCpmOImageProcessor {
.im_start_token
.clone()
.unwrap_or(DEFAULT_IM_START_TOKEN.to_string()),
- true,
+ false,
)
.unwrap()
.get_ids()[0];
@@ -284,7 +285,7 @@ impl InputsProcessor for MiniCpmOImageProcessor {
.im_end_token
.clone()
.unwrap_or(DEFAULT_IM_END_TOKEN.to_string()),
- true,
+ false,
)
.unwrap()
.get_ids()[0];
@@ -294,7 +295,7 @@ impl InputsProcessor for MiniCpmOImageProcessor {
.slice_start_token
.clone()
.unwrap_or(DEFAULT_SLICE_START_TOKEN.to_string()),
- true,
+ false,
)
.unwrap()
.get_ids()[0];
@@ -304,13 +305,13 @@ impl InputsProcessor for MiniCpmOImageProcessor {
.slice_end_token
.clone()
.unwrap_or(DEFAULT_SLICE_END_TOKEN.to_string()),
- true,
+ false,
)
.unwrap()
.get_ids()[0];
let input_ids = tokenizer
- .encode(final_text, true)
+ .encode(final_text, false)
.unwrap()
.get_ids()
.to_vec();
@@ -779,6 +780,7 @@ impl ImagePreProcessor for MiniCpmOImageProcessor {
pixel_values_list: Some(pixel_values),
tgt_sizes: Some(tgt_sizes),
image_sizes_all: Some(image_sizes),
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs
index a7259ac6d8..712fd82685 100644
--- a/mistralrs-core/src/vision_models/mllama/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/mllama/inputs_processor.rs
@@ -310,6 +310,7 @@ impl InputsProcessor for MLlamaImageProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
seq.take_images()
@@ -828,6 +829,7 @@ impl ImagePreProcessor for MLlamaImageProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs b/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs
index 15e9552c3a..62b29baaee 100644
--- a/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/phi3/phi3_inputs_processor.rs
@@ -139,6 +139,7 @@ impl InputsProcessor for Phi3InputsProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
imgs,
@@ -567,6 +568,7 @@ impl ImagePreProcessor for Phi3InputsProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/phi4/inputs_processor.rs b/mistralrs-core/src/vision_models/phi4/inputs_processor.rs
index 0877187bd8..2110231583 100644
--- a/mistralrs-core/src/vision_models/phi4/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/phi4/inputs_processor.rs
@@ -137,6 +137,7 @@ impl InputsProcessor for Phi4MMInputsProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all,
+ num_crops: _,
} = self
.preprocess(
imgs,
@@ -243,7 +244,7 @@ impl InputsProcessor for Phi4MMInputsProcessor {
seq.set_toks_and_reallocate(
tokenizer
- .encode(detokenized.clone(), true)
+ .encode(detokenized.clone(), false)
.expect("Encode failed")
.get_ids()
.to_vec(),
@@ -659,6 +660,7 @@ impl ImagePreProcessor for Phi4MMInputsProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: Some(image_sizes),
+ num_crops: None,
})
}
}
diff --git a/mistralrs-core/src/vision_models/preprocessor_config.rs b/mistralrs-core/src/vision_models/preprocessor_config.rs
index a2d577d659..a3ee94e688 100644
--- a/mistralrs-core/src/vision_models/preprocessor_config.rs
+++ b/mistralrs-core/src/vision_models/preprocessor_config.rs
@@ -19,6 +19,7 @@ pub struct PreProcessorConfig {
#[serde(alias = "norm_std")]
pub(crate) image_std: Option<[f64; 3]>,
pub(crate) rescale_factor: Option<f64>,
+ #[serde(alias = "resample")]
pub(crate) resampling: Option,
pub(crate) max_image_size: Option>,
pub(crate) size: Option>,
@@ -44,6 +45,12 @@ pub struct PreProcessorConfig {
pub(crate) im_id_start: Option,
pub(crate) im_id_end: Option,
pub(crate) dynamic_hd: Option,
+ #[serde(alias = "image_seq_length")]
+ pub(crate) image_seq_len: Option<usize>,
+ pub(crate) pan_and_scan_min_crop_size: Option<usize>,
+ pub(crate) pan_and_scan_max_num_crops: Option<usize>,
+ pub(crate) pan_and_scan_min_ratio_to_activate: Option<f64>,
+ pub(crate) do_pan_and_scan: Option<bool>,
}
#[allow(dead_code)]
diff --git a/mistralrs-core/src/vision_models/processor_config.rs b/mistralrs-core/src/vision_models/processor_config.rs
index fba572b558..0e05d7841f 100644
--- a/mistralrs-core/src/vision_models/processor_config.rs
+++ b/mistralrs-core/src/vision_models/processor_config.rs
@@ -4,5 +4,6 @@ use serde::Deserialize;
#[derive(Deserialize, Debug, Default)]
pub struct ProcessorConfig {
pub(crate) chat_template: Option<String>,
+ #[serde(alias = "image_seq_length")]
pub(crate) image_seq_len: Option<usize>,
}
diff --git a/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs b/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs
index 17fd0aacff..70024de4ad 100644
--- a/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/qwen2_5_vl/inputs_processor.rs
@@ -215,6 +215,7 @@ impl InputsProcessor for Qwen2_5VLImageProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
seq.clone_images()
@@ -326,7 +327,7 @@ impl InputsProcessor for Qwen2_5VLImageProcessor {
seq.set_initial_prompt(detok.clone());
let toks = tokenizer
- .encode(detok, true)
+ .encode(detok, false)
.expect("Detokenization failed!");
let ids = toks.get_ids().to_vec();
@@ -378,7 +379,7 @@ impl InputsProcessor for Qwen2_5VLImageProcessor {
);
let ids = tokenizer
- .encode(prompt, true)
+ .encode(prompt, false)
.expect("Tokenization failed!");
input_ids_searching.push(ids.get_ids().to_vec());
@@ -696,6 +697,7 @@ impl ImagePreProcessor for Qwen2_5VLImageProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
});
}
@@ -735,6 +737,7 @@ impl ImagePreProcessor for Qwen2_5VLImageProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
});
}
unreachable!()
diff --git a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs
index ea2eaab205..4a1e0d1ef7 100644
--- a/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs
+++ b/mistralrs-core/src/vision_models/qwen2vl/inputs_processor.rs
@@ -264,6 +264,7 @@ impl InputsProcessor for Qwen2VLImageProcessor {
pixel_values_list: _,
tgt_sizes: _,
image_sizes_all: _,
+ num_crops: _,
} = self
.preprocess(
seq.clone_images()
@@ -371,7 +372,7 @@ impl InputsProcessor for Qwen2VLImageProcessor {
seq.set_initial_prompt(detok.clone());
let toks = tokenizer
- .encode(detok, true)
+ .encode(detok, false)
.expect("Detokenization failed!");
let ids = toks.get_ids().to_vec();
@@ -422,7 +423,7 @@ impl InputsProcessor for Qwen2VLImageProcessor {
);
let ids = tokenizer
- .encode(prompt, true)
+ .encode(prompt, false)
.expect("Tokenization failed!");
input_ids_searching.push(ids.get_ids().to_vec());
@@ -693,6 +694,7 @@ impl ImagePreProcessor for Qwen2VLImageProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
});
}
@@ -732,6 +734,7 @@ impl ImagePreProcessor for Qwen2VLImageProcessor {
pixel_values_list: None,
tgt_sizes: None,
image_sizes_all: None,
+ num_crops: None,
});
}
unreachable!()
diff --git a/mistralrs/examples/gemma3/main.rs b/mistralrs/examples/gemma3/main.rs
new file mode 100644
index 0000000000..83a644d8b7
--- /dev/null
+++ b/mistralrs/examples/gemma3/main.rs
@@ -0,0 +1,36 @@
+use anyhow::Result;
+use mistralrs::{IsqType, TextMessageRole, VisionLoaderType, VisionMessages, VisionModelBuilder};
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ let model = VisionModelBuilder::new("google/gemma-3-12b-it", VisionLoaderType::Gemma3)
+ .with_isq(IsqType::Q4K)
+ .with_logging()
+ .build()
+ .await?;
+
+ let bytes = match reqwest::blocking::get(
+ "https://www.nhmagazine.com/content/uploads/2019/05/mtwashingtonFranconia-2-19-18-108-Edit-Edit.jpg",
+ ) {
+ Ok(http_resp) => http_resp.bytes()?.to_vec(),
+ Err(e) => anyhow::bail!(e),
+ };
+ let image = image::load_from_memory(&bytes)?;
+
+ let messages = VisionMessages::new().add_image_message(
+ TextMessageRole::User,
+ "What is this?",
+ image,
+ &model,
+ )?;
+
+ let response = model.send_chat_request(messages).await?;
+
+ println!("{}", response.choices[0].message.content.as_ref().unwrap());
+ dbg!(
+ response.usage.avg_prompt_tok_per_sec,
+ response.usage.avg_compl_tok_per_sec
+ );
+
+ Ok(())
+}