3 changes: 3 additions & 0 deletions chat_templates/idefics3.json
@@ -0,0 +1,3 @@
+{
+    "chat_template": "<|begin_of_text|>{% for message in messages %}{{message['role'].capitalize()}}{% if message['content'][0]['type'] == 'image' %}{{':'}}{% else %}{{': '}}{% endif %}{% for line in message['content'] %}{% if line['type'] == 'text' %}{{line['text']}}{% elif line['type'] == 'image' %}{{ '<image>' }}{% endif %}{% endfor %}<end_of_utterance>\n{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+}
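
For a quick sanity check, a single user turn with one image followed by a text question renders to something like the following (assuming `add_generation_prompt` is set; the `<image>` placeholder is expanded into the full patch-token sequence later, by the input processor):

```text
<|begin_of_text|>User:<image>What is in this image?<end_of_utterance>
Assistant:
```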
16 changes: 16 additions & 0 deletions mistralrs-core/src/pipeline/loaders/vision_loaders.rs
@@ -2125,6 +2125,22 @@ impl IsqModelLoader for Idefics3Loader {
Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
// // Attention (vision)
// Regex::new(
// r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
// )?,
// Regex::new(
// r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
// )?,
// Regex::new(
// r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
// )?,
// Regex::new(
// r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
// )?,
// MLP (vision)
// Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
// Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
])
}
}
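The net effect: in-situ quantization (ISQ) now targets only the text tower, while the vision encoder's attention and MLP projections stay unquantized (their patterns are left commented out above). A minimal standalone check of how these patterns gate tensors, using the `regex` crate with illustrative tensor names:

```rust
use regex::Regex;

fn main() -> Result<(), regex::Error> {
    // Same pattern as above: text-tower MLP weights are eligible for ISQ...
    let text_mlp = Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?;
    assert!(text_mlp.is_match("model.text_model.layers.7.mlp.down_proj.weight"));
    // ...while vision-tower tensors match nothing and are skipped.
    assert!(!text_mlp.is_match("model.vision_model.encoder.layers.7.mlp.fc2.weight"));
    Ok(())
}
```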
74 changes: 33 additions & 41 deletions mistralrs-core/src/sampler.rs
@@ -2,7 +2,6 @@

use std::{
    collections::{HashMap, HashSet},
-    iter::zip,
    sync::{Arc, Mutex},
};

@@ -256,53 +255,46 @@ impl Sampler {
        })
    }

-    fn get_top_logprobs(&self, probs: &[f32], argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
-        let mut argsort_indices_sorted = argsort_indices.to_vec();
-        // Sort by descending prob
-        argsort_indices_sorted.sort_by(|a, b| {
-            probs[*b as usize]
-                .partial_cmp(&probs[*a as usize])
-                .expect("No ordering.")
-        });
-        // These are where the top n are
-        let top_n_toks_range = 0..self.top_n_logprobs;
-        // The top n's values
-        let top_n_logprobs = argsort_indices_sorted[top_n_toks_range.clone()]
-            .iter()
-            .map(|x| probs[*x as usize].log(10.0))
-            .collect::<Vec<_>>();
-        // Find where they actually are in the logits
-        let mut top_n_toks = Vec::new();
-        for val in top_n_toks_range {
-            top_n_toks.push(argsort_indices[val]);
-        }
-
-        if let Some(tokenizer) = &self.tokenizer {
-            let mut bytes = Vec::new();
-            for tok in &top_n_toks {
-                bytes.push(
-                    tokenizer
-                        .decode(&[{ *tok }], false)
-                        .map_err(|x| Error::Msg(x.to_string()))?,
-                );
-            }
-
-            Ok(zip(bytes, zip(top_n_toks, top_n_logprobs))
-                .map(|(bytes, (token, logprob))| TopLogprob {
-                    token,
-                    logprob,
-                    bytes: Some(bytes),
-                })
-                .collect::<Vec<_>>())
-        } else {
-            Ok(zip(top_n_toks, top_n_logprobs)
-                .map(|(token, logprob)| TopLogprob {
-                    token,
-                    logprob,
-                    bytes: None,
-                })
-                .collect::<Vec<_>>())
-        }
-    }
+    fn get_top_logprobs(&self, probs: &[f32], _argsort_indices: &[u32]) -> Result<Vec<TopLogprob>> {
+        // Fast top-k selection without sorting the entire vocabulary
+        let k = self.top_n_logprobs.min(probs.len());
+        if k == 0 {
+            return Ok(Vec::new());
+        }
+
+        // Build (token, probability) pairs
+        let mut idx_probs: Vec<(u32, f32)> = (0..probs.len() as u32)
+            .map(|i| (i, probs[i as usize]))
+            .collect();
+        // Partition so that the top k probabilities occupy the first k positions.
+        // Selecting on index `k - 1` stays in bounds even when `k == probs.len()`,
+        // where selecting on `k` would panic.
+        idx_probs.select_nth_unstable_by(k - 1, |a, b| {
+            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
+        });
+        // Copy and sort only the top k elements by descending probability
+        let mut top_k: Vec<(u32, f32)> = idx_probs[..k].to_vec();
+        top_k.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
+        // Build the result vector with log10 of probabilities and optional decoding
+        let mut result = Vec::with_capacity(k);
+        if let Some(tokenizer) = &self.tokenizer {
+            for (token, prob) in top_k {
+                let decoded = tokenizer
+                    .decode(&[token], false)
+                    .map_err(|e| Error::Msg(e.to_string()))?;
+                result.push(TopLogprob {
+                    token,
+                    logprob: prob.log(10.0),
+                    bytes: Some(decoded),
+                });
+            }
+        } else {
+            for (token, prob) in top_k {
+                result.push(TopLogprob {
+                    token,
+                    logprob: prob.log(10.0),
+                    bytes: None,
+                });
+            }
+        }
+        Ok(result)
+    }

    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
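The win is algorithmic: `select_nth_unstable_by` partitions in O(n) average time, so only the k requested entries get sorted instead of the entire vocabulary (often 100k+ tokens) on every sampled token. A minimal standalone sketch of the same pattern on toy data:

```rust
/// Top-k by descending value: partition first (O(n) average), then sort only k items.
fn top_k_desc(mut pairs: Vec<(u32, f32)>, k: usize) -> Vec<(u32, f32)> {
    let k = k.min(pairs.len());
    if k == 0 {
        return Vec::new();
    }
    // After this call, the k largest values occupy pairs[..k] in arbitrary order.
    pairs.select_nth_unstable_by(k - 1, |a, b| b.1.total_cmp(&a.1));
    let mut top: Vec<(u32, f32)> = pairs[..k].to_vec();
    top.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    top
}

fn main() {
    let probs = vec![(0u32, 0.10f32), (1, 0.50), (2, 0.20), (3, 0.15), (4, 0.05)];
    assert_eq!(top_k_desc(probs, 2), vec![(1, 0.50), (2, 0.20)]);
}
```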
mistralrs-core/src/vision_models/idefics3/inputs_processor.rs
@@ -193,7 +193,12 @@ impl InputsProcessor for Idefics3ImageProcessor {
.expect("The image token <image> should be present in the text.")
.to_string();
for (i, image_prompt_string) in image_prompt_strings.into_iter().enumerate() {
sample.push_str(&format!("{image_prompt_string}{}", split_sample[i]));
sample.push_str(&format!(
"{image_prompt_string}{}",
split_sample
.get(i + 1)
.expect("Incorrect chat template. Use the one provided in `chat_templates` with the `--chat-template`/`chat_template` settings.")
));
}

seq.set_initial_prompt(sample.clone());
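The `i + 1` indexing matters because splitting on `<image>` yields one more text piece than there are images: piece 0 precedes the first image, and piece `i + 1` follows image `i`. A toy illustration with a hypothetical prompt (the real processor builds `sample` incrementally, but the indexing is the same):

```rust
fn main() {
    let prompt = "User:<image>Describe this image.<end_of_utterance>\n";
    let split_sample: Vec<&str> = prompt.split("<image>").collect();
    // N placeholders yield N + 1 pieces; the text after image i is piece i + 1.
    assert_eq!(split_sample, ["User:", "Describe this image.<end_of_utterance>\n"]);

    // Stand-in for the expanded per-image prompt string (row/col tokens elided).
    let image_prompt_strings = ["<fake_token_around_image><global-img>...<fake_token_around_image>"];
    let mut sample = split_sample[0].to_string();
    for (i, image_prompt_string) in image_prompt_strings.iter().enumerate() {
        sample.push_str(&format!("{image_prompt_string}{}", split_sample[i + 1]));
    }
    assert!(sample.starts_with("User:<fake_token_around_image>"));
}
```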
77 changes: 27 additions & 50 deletions mistralrs-core/src/vision_models/idefics3/mod.rs
@@ -6,10 +6,10 @@ mod vision;

use std::any::Any;

-use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
+use candle_core::{DType, Device, Result, Tensor, D};
pub use config::Idefics3Config;
pub use inputs_processor::Idefics3Processor;
-use mistralrs_quant::ShardedVarBuilder;
+use mistralrs_quant::{NonZeroOp, ShardedVarBuilder};
use vision::{Idefics3Connector, Idefics3VisionTransformer};

use crate::{
@@ -45,13 +45,11 @@ impl Idefics3Model {
        let connector = Idefics3Connector::new(
            cfg,
            vb_m.pp("connector")
-                .set_dtype(DType::F32)
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let vision = Idefics3VisionTransformer::new(
            &cfg.vision_config,
            vb_m.pp("vision_model")
-                .set_dtype(DType::F32)
                .set_device(normal_loading_metadata.real_device.clone()),
        )?;
        let text_model = Llama::new_inner(
@@ -73,42 +71,18 @@

    fn inputs_merger(
        &self,
-        input_ids: &Tensor,
+        indices: &Tensor,
        input_embeds: &Tensor,
        image_hidden_states: &Tensor,
    ) -> Result<Tensor> {
-        // Docs copied from Transformers impl
-        /*
-        This method aims at merging the token embeddings with the image hidden states into one single sequence of vectors that are fed to the transformer LM.
-        The merging happens as follows:
-        - The text token sequence is: `tok_1 tok_2 tok_3 <fake_token_around_image> <image> <image> ... <image> <fake_token_around_image> tok_4`.
-        - We get the image hidden states for the image through the vision encoder (and potentially the perceiver), and that hidden state is then projected into the text embedding space.
-        We thus have a sequence of image hidden states of size (1, image_seq_len, hidden_dim), where 1 is for batch_size of 1 image and hidden_dim is the hidden_dim of the LM transformer.
-        - The merging happens so that we obtain the following sequence: `vector_tok_1 vector_tok_2 vector_tok_3 vector_fake_tok_around_image {sequence of image_seq_len image hidden states} vector_fake_toke_around_image vector_tok_4`. That sequence is fed to the LM.
-        - To fit the format of that sequence, `input_ids`, `input_embeds`, `attention_mask` are all 3 adapted to insert the image hidden states.
-        */
-        let (_, _, vision_hidden_size) = image_hidden_states.dims3()?;
-        let bs = input_ids.dim(0)?;
-        let special_image_token_mask = input_ids.eq(self.config.image_token_id as f64)?;
-        let mut new_inputs_embeds = input_embeds.clone();
-        let reshaped_image_hidden_states =
-            image_hidden_states.reshape((bs, (), vision_hidden_size))?;
-        assert_eq!(input_embeds.dim(0)?, 1);
-        assert_eq!(reshaped_image_hidden_states.dim(0)?, 1);
-        let special_image_token_mask = special_image_token_mask.i(0)?.to_vec1::<u8>()?;
-        let mut image_hidden_state_i = 0;
-        for (i, v) in special_image_token_mask.iter().enumerate() {
-            if *v != 0 {
-                new_inputs_embeds = new_inputs_embeds.slice_assign(
-                    &[&.., &i, &..],
-                    &reshaped_image_hidden_states
-                        .i((.., image_hidden_state_i, ..))?
-                        .unsqueeze(1)?,
-                )?;
-                image_hidden_state_i += 1;
-            }
-        }
-        Ok(new_inputs_embeds)
+        let mut x_flat = input_embeds.flatten_all()?;
+        let src_flat = image_hidden_states.flatten_all()?;
+
+        let current_vals = x_flat.gather(indices, 0)?;
+        let diff = (src_flat - current_vals)?;
+        x_flat = x_flat.scatter_add(indices, &diff, 0)?;
+
+        x_flat.reshape(input_embeds.shape())
    }

    #[allow(clippy::too_many_arguments)]
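
The rewritten `inputs_merger` replaces the per-token `slice_assign` loop with three tensor ops: gather what currently sits at the image-token slots, subtract it from the projected image features, and `scatter_add` the difference back, so those slots end up holding exactly the image features (`x + (src - x) == src`) while every other slot is untouched. A standalone toy sketch of that identity with scalar slots (assumes only `candle_core`):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Six flattened "embedding" slots; slots 2 and 4 are image-token positions.
    let x_flat = Tensor::new(&[0f32, 1., 2., 3., 4., 5.], &dev)?;
    let src_flat = Tensor::new(&[20f32, 40.], &dev)?;
    let indices = Tensor::new(&[2u32, 4], &dev)?;

    let current_vals = x_flat.gather(&indices, 0)?;
    let diff = (src_flat - current_vals)?; // src - x at the image slots
    let merged = x_flat.scatter_add(&indices, &diff, 0)?; // x + (src - x) == src there

    assert_eq!(merged.to_vec1::<f32>()?, [0f32, 1., 20., 3., 40., 5.]);
    Ok(())
}
```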
@@ -123,6 +97,17 @@ impl Idefics3Model {
        flash_params: &FlashParams,
    ) -> Result<Tensor> {
        let input_embeds = if let Some(pixel_values) = pixel_values {
+            let input_embeds = self.text_model.get_input_embeddings(input_ids)?;
+            let special_image_mask = input_ids
+                .eq(self.config.image_token_id as f64)?
+                .unsqueeze(D::Minus1)?
+                .broadcast_as(input_embeds.shape())?
+                .to_dtype(DType::U32)?;
+
+            let mask_flat = special_image_mask.flatten_all()?;
+            // Nonzero before vision model to allow async processing all the way through logits.
+            let indices = mask_flat.nonzero()?.squeeze(1)?;
+
            // == START VISUAL INPUTS INTEGRATION ==
            let (batch_size, num_images, _, _, _) = pixel_values.dims5()?;
            let mut s = vec![batch_size * num_images];
@@ -195,23 +180,15 @@
            let pixel_values = pixel_values.to_dtype(self.dtype)?;

            // Get seq from vision encoder
-            let image_hidden_states = self.vision.forward(
-                &pixel_values.to_dtype(DType::F32)?,
-                Some(&patch_attention_mask),
-            )?;
+            let image_hidden_states = self
+                .vision
+                .forward(&pixel_values, Some(&patch_attention_mask))?;

            // Modality proj and perceiver resampling
            let image_hidden_states = self.connector.forward(&image_hidden_states)?;

-            self.inputs_merger(
-                input_ids,
-                &self
-                    .text_model
-                    .get_input_embeddings(input_ids)?
-                    .to_dtype(DType::F32)?,
-                &image_hidden_states,
-            )?
-            .to_dtype(self.dtype)?
+            self.inputs_merger(&indices, &input_embeds, &image_hidden_states)?
+                .to_dtype(self.dtype)?
        } else {
            self.text_model.get_input_embeddings(input_ids)?
        };
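For completeness, the shape bookkeeping behind `indices` in `forward`: the token-level mask is broadcast across the hidden dimension before flattening, so each image token contributes `hidden_size` flat indices, lining up one-to-one with the flattened image features that `inputs_merger` scatters in. A standalone sketch with toy sizes (the `nonzero` kernel from `mistralrs_quant` is stood in for on CPU):

```rust
use candle_core::{DType, Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let image_token_id = 9u32;
    let hidden = 4usize;
    // One sequence of five tokens; positions 1 and 3 hold the <image> placeholder.
    let input_ids = Tensor::new(&[[3u32, 9, 7, 9, 2]], &dev)?;

    let special_image_mask = input_ids
        .eq(image_token_id as f64)? // (1, 5), u8
        .unsqueeze(D::Minus1)? // (1, 5, 1)
        .broadcast_as((1, 5, hidden))? // (1, 5, hidden)
        .to_dtype(DType::U32)?;
    let mask_flat = special_image_mask.flatten_all()?;

    // CPU stand-in for the `nonzero` op used above.
    let indices: Vec<u32> = mask_flat
        .to_vec1::<u32>()?
        .iter()
        .enumerate()
        .filter_map(|(i, &v)| (v != 0).then_some(i as u32))
        .collect();

    // 2 image tokens * hidden_size flat slots to overwrite.
    assert_eq!(indices.len(), 2 * hidden);
    Ok(())
}
```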