4 changes: 2 additions & 2 deletions mistralrs-core/src/vision_models/mllama/vision.rs
@@ -260,15 +260,15 @@ impl MLlamaMlp {
cfg.hidden_size,
cfg.intermediate_size,
&None,
- false,
+ true,
comm,
vb.pp("fc1"),
)?,
fc2: RowParallelLayer::new(
cfg.intermediate_size,
cfg.hidden_size,
&None,
- false,
+ true,
comm,
vb.pp("fc2"),
)?,
4 changes: 2 additions & 2 deletions mistralrs-core/src/vision_models/qwen2_5_vl/mod.rs
@@ -50,8 +50,8 @@ impl Qwen2_5VLModel {
let vision = Qwen2_5VLVisionModel::new(
&cfg.vision_config,
vb.pp("visual")
- .set_device(normal_loading_metadata.real_device.clone())
- .set_dtype(DType::F32),
+ .set_device(normal_loading_metadata.real_device.clone()),
+ &normal_loading_metadata.mapper.get_comm_for(0)?,
)?;
let text = Qwen2_5VLTextModel::new(
cfg,
68 changes: 52 additions & 16 deletions mistralrs-core/src/vision_models/qwen2_5_vl/vision.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Linear, Module};
- use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
+ use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};

use crate::{
layers::{self, Activation, Conv3dConfig, Conv3dNoBias, MatMul, RmsNorm},
@@ -64,11 +64,38 @@ struct VisionMlp {
}

impl VisionMlp {
- fn new(dim: usize, hidden_dim: usize, act: Activation, vb: ShardedVarBuilder) -> Result<Self> {
+ fn new(
+ dim: usize,
+ hidden_dim: usize,
+ act: Activation,
+ vb: ShardedVarBuilder,
+ comm: &Arc<mistralrs_quant::Comm>,
+ ) -> Result<Self> {
Ok(Self {
- gate_proj: mistralrs_quant::linear(dim, hidden_dim, &None, vb.pp("gate_proj"))?,
- up_proj: mistralrs_quant::linear(dim, hidden_dim, &None, vb.pp("up_proj"))?,
- down_proj: mistralrs_quant::linear(hidden_dim, dim, &None, vb.pp("down_proj"))?,
+ gate_proj: ColumnParallelLayer::new(
+ dim,
+ hidden_dim,
+ &None,
+ true,
+ comm,
+ vb.pp("gate_proj"),
+ )?,
+ up_proj: ColumnParallelLayer::new(
+ dim,
+ hidden_dim,
+ &None,
+ true,
+ comm,
+ vb.pp("up_proj"),
+ )?,
+ down_proj: RowParallelLayer::new(
+ hidden_dim,
+ dim,
+ &None,
+ true,
+ comm,
+ vb.pp("down_proj"),
+ )?,
act,
})
}
@@ -102,11 +129,10 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
}

fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
- let xs = xs.to_dtype(DType::F32)?;
- let cos = freqs.cos()?.unsqueeze(D::Minus2)?;
- let sin = freqs.sin()?.unsqueeze(D::Minus2)?;
+ let cos = freqs.cos()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;
+ let sin = freqs.sin()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;

- xs.broadcast_mul(&cos)? + rotate_half(&xs)?.broadcast_mul(&sin)
+ xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
}

// https://github.com/huggingface/transformers/blob/a769ed45e17c44fd17b85c025863c4e4f2f73634/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L325
@@ -182,7 +208,11 @@ struct VisionBlock {
}

impl VisionBlock {
- fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+ fn new(
+ cfg: &VisionConfig,
+ vb: ShardedVarBuilder,
+ comm: &Arc<mistralrs_quant::Comm>,
+ ) -> Result<Self> {
let norm1 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm1"))?;
let norm2 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm2"))?;

@@ -191,6 +221,7 @@ impl VisionBlock {
cfg.intermediate_size,
cfg.hidden_act,
vb.pp("mlp"),
+ comm,
)?;
let attn = VisionAttention::new(cfg.hidden_size, cfg.num_heads, vb.pp("attn"))?;

@@ -290,10 +321,14 @@ pub struct Qwen2_5VLVisionModel {
}

impl Qwen2_5VLVisionModel {
- pub fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+ pub fn new(
+ cfg: &VisionConfig,
+ vb: ShardedVarBuilder,
+ comm: &Arc<mistralrs_quant::Comm>,
+ ) -> Result<Self> {
let mut blocks = Vec::new();
for i in 0..cfg.depth {
- blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")))?);
+ blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
}

let patch_merger = PatchMerger::new(
@@ -454,6 +489,7 @@ impl Qwen2_5VLVisionModel {
rotary_pos_emb = rotary_pos_emb.index_select(&window_index, 0)?;
rotary_pos_emb = rotary_pos_emb.reshape((seq_len, ()))?;
rotary_pos_emb = Tensor::cat(&[&rotary_pos_emb; 2], D::Minus1)?;
+ rotary_pos_emb = rotary_pos_emb.to_dtype(xs.dtype())?;

let grid_thw = grid_thw.to_device(&Device::Cpu)?;
let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
@@ -470,13 +506,13 @@
cu_seqlens => {
let mut attention_mask =
Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
- .to_dtype(DType::F32)?;
+ .to_dtype(xs.dtype())?;
for i in 1..cu_seqlens.len() {
let a = cu_seqlens[i - 1] as usize;
let b = cu_seqlens[i] as usize;
attention_mask = attention_mask.slice_assign(
&[&.., &(a..b), &(a..b)],
- &Tensor::zeros((1, b - a, b - a), DType::F32, xs.device())?,
+ &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
)?;
}
Some(attention_mask)
@@ -487,13 +523,13 @@
cu_seqlens => {
let mut attention_mask =
Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
- .to_dtype(DType::F32)?;
+ .to_dtype(xs.dtype())?;
for i in 1..cu_seqlens.len() {
let a = cu_seqlens[i - 1] as usize;
let b = cu_seqlens[i] as usize;
attention_mask = attention_mask.slice_assign(
&[&.., &(a..b), &(a..b)],
- &Tensor::zeros((1, b - a, b - a), DType::F32, xs.device())?,
+ &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
)?;
}
Some(attention_mask)
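Note on the VisionMlp change above: gate_proj and up_proj become column-parallel (each rank owns a slice of the hidden dimension) while down_proj becomes row-parallel, so the only cross-rank communication is one all-reduce of down_proj's partial outputs. The dependency-free Rust sketch below shows that arithmetic with toy shapes; the helper names are illustrative and the elementwise activation/gating between the projections is omitted since it acts per hidden column and shards trivially. This is a sketch of the idea, not the mistralrs_quant implementation.

// Sketch only: why gate_proj/up_proj shard by output columns and down_proj by
// input rows. Each "rank" computes a partial result and a single sum (the
// all-reduce) reproduces the unsharded MLP exactly.
fn matmul(x: &[f32], w: &[Vec<f32>]) -> Vec<f32> {
    // x: [in], w: [in][out] -> y: [out]
    let out = w[0].len();
    let mut y = vec![0.0f32; out];
    for (i, xi) in x.iter().enumerate() {
        for (j, yj) in y.iter_mut().enumerate() {
            *yj += xi * w[i][j];
        }
    }
    y
}

fn main() {
    // Toy sizes: dim = 2, hidden_dim = 4, world size = 2.
    let w_up: Vec<Vec<f32>> = vec![vec![1., 2., 3., 4.], vec![5., 6., 7., 8.]];
    let w_down: Vec<Vec<f32>> = vec![vec![1., 0.], vec![0., 1.], vec![1., 1.], vec![2., 2.]];
    let x = vec![1.0f32, -1.0];

    // Unsharded reference: x @ w_up @ w_down.
    let reference = matmul(&matmul(&x, &w_up), &w_down);

    // Rank r keeps hidden columns [2r, 2r+2) of w_up (column-parallel) and the
    // matching rows of w_down (row-parallel), so no activation exchange is
    // needed between the two matmuls.
    let partial = |r: usize| -> Vec<f32> {
        let up: Vec<Vec<f32>> = w_up.iter().map(|row| row[2 * r..2 * r + 2].to_vec()).collect();
        let down: Vec<Vec<f32>> = w_down[2 * r..2 * r + 2].to_vec();
        matmul(&matmul(&x, &up), &down)
    };

    // "All-reduce": summing the per-rank partials recovers the dense result.
    let summed: Vec<f32> = partial(0).iter().zip(partial(1)).map(|(a, b)| a + b).collect();
    assert_eq!(reference, summed);
    println!("dense = {reference:?}, sharded sum = {summed:?}");
}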
4 changes: 2 additions & 2 deletions mistralrs-core/src/vision_models/qwen2vl/mod.rs
@@ -50,8 +50,8 @@ impl Qwen2VLModel {
let vision = Qwen2VLVisionModel::new(
&cfg.vision_config,
vb.pp("visual")
- .set_device(normal_loading_metadata.real_device.clone())
- .set_dtype(DType::F32),
+ .set_device(normal_loading_metadata.real_device.clone()),
+ &normal_loading_metadata.mapper.get_comm_for(0)?,
)?;
let text = Qwen2VLTextModel::new(
cfg,
45 changes: 32 additions & 13 deletions mistralrs-core/src/vision_models/qwen2vl/vision.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;

use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, Linear, Module};
- use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
+ use mistralrs_quant::{ColumnParallelLayer, QuantMethod, ShardedVarBuilder};

use crate::{
layers::{self, layer_norm, Activation, Conv3dConfig, Conv3dNoBias, MatMul},
@@ -63,10 +63,16 @@ struct VisionMlp {
}

impl VisionMlp {
- fn new(dim: usize, hidden_dim: usize, act: Activation, vb: ShardedVarBuilder) -> Result<Self> {
+ fn new(
+ dim: usize,
+ hidden_dim: usize,
+ act: Activation,
+ vb: ShardedVarBuilder,
+ comm: &Arc<mistralrs_quant::Comm>,
+ ) -> Result<Self> {
Ok(Self {
- fc1: mistralrs_quant::linear(dim, hidden_dim, &None, vb.pp("fc1"))?,
- fc2: mistralrs_quant::linear(hidden_dim, dim, &None, vb.pp("fc2"))?,
+ fc1: ColumnParallelLayer::new(dim, hidden_dim, &None, true, comm, vb.pp("fc1"))?,
+ fc2: ColumnParallelLayer::new(hidden_dim, dim, &None, true, comm, vb.pp("fc2"))?,
act,
})
}
@@ -85,11 +91,10 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
}

fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
- let xs = xs.to_dtype(DType::F32)?;
let cos = freqs.cos()?;
let sin = freqs.sin()?;

- xs.broadcast_mul(&cos)? + rotate_half(&xs)?.broadcast_mul(&sin)
+ xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
}

// https://github.com/huggingface/transformers/blob/a769ed45e17c44fd17b85c025863c4e4f2f73634/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L325
@@ -165,12 +170,22 @@ struct VisionBlock {
}

impl VisionBlock {
- fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+ fn new(
+ cfg: &VisionConfig,
+ vb: ShardedVarBuilder,
+ comm: &Arc<mistralrs_quant::Comm>,
+ ) -> Result<Self> {
let norm1 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm1"))?;
let norm2 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm2"))?;

let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
- let mlp = VisionMlp::new(cfg.embed_dim, mlp_hidden_dim, cfg.hidden_act, vb.pp("mlp"))?;
+ let mlp = VisionMlp::new(
+ cfg.embed_dim,
+ mlp_hidden_dim,
+ cfg.hidden_act,
+ vb.pp("mlp"),
+ comm,
+ )?;
let attn = VisionAttention::new(cfg.embed_dim, cfg.num_heads, vb.pp("attn"))?;

Ok(Self {
@@ -265,10 +280,14 @@ pub struct Qwen2VLVisionModel {
}

impl Qwen2VLVisionModel {
- pub fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+ pub fn new(
+ cfg: &VisionConfig,
+ vb: ShardedVarBuilder,
+ comm: &Arc<mistralrs_quant::Comm>,
+ ) -> Result<Self> {
let mut blocks = Vec::new();
for i in 0..cfg.depth {
- blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")))?);
+ blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
}

let patch_merger = PatchMerger::new(
@@ -342,7 +361,7 @@ impl Qwen2VLVisionModel {
.unsqueeze(1)?
.repeat((1, 1, 2))?
.unsqueeze(0)?
- .to_dtype(DType::F32)?;
+ .to_dtype(xs.dtype())?;

let grid_thw = grid_thw.to_device(&Device::Cpu)?;
let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
@@ -359,13 +378,13 @@
cu_seqlens => {
let mut attention_mask =
Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
- .to_dtype(DType::F32)?;
+ .to_dtype(xs.dtype())?;
for i in 1..cu_seqlens.len() {
let a = cu_seqlens[i - 1] as usize;
let b = cu_seqlens[i] as usize;
attention_mask = attention_mask.slice_assign(
&[&.., &(a..b), &(a..b)],
- &Tensor::zeros((1, b - a, b - a), DType::F32, xs.device())?,
+ &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
)?;
}
Some(attention_mask)
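The dtype edits in this file mirror the qwen2_5_vl ones: with .set_dtype(DType::F32) removed in mod.rs, the vision tower runs in whatever dtype the model was loaded in, so rotary cos/sin tables and attention masks are cast to the activations' dtype instead of upcasting activations to F32. Below is a minimal candle-core sketch of that pattern, assuming a half-precision activation tensor; apply_rotary and the toy shapes are illustrative, not the mistral.rs functions.

// Sketch: cast the small cos/sin tables down to the activation dtype rather
// than promoting the activations to F32. Requires the candle-core crate.
use candle_core::{DType, Device, Result, Tensor, D};

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last = xs.dim(D::Minus1)?;
    let x1 = xs.narrow(D::Minus1, 0, last / 2)?;
    let x2 = xs.narrow(D::Minus1, last / 2, last - last / 2)?;
    Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)
}

fn apply_rotary(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
    // freqs stays F32; only the cos/sin tables are cast to xs's dtype.
    let cos = freqs.cos()?.to_dtype(xs.dtype())?;
    let sin = freqs.sin()?.to_dtype(xs.dtype())?;
    xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Half-precision activations, F32 frequencies (toy shapes).
    let xs = Tensor::randn(0f32, 1f32, (1, 4, 8), &dev)?.to_dtype(DType::F16)?;
    let freqs = Tensor::randn(0f32, 1f32, (1, 4, 8), &dev)?;
    let out = apply_rotary(&xs, &freqs)?;
    // The activations keep their half-precision dtype end to end.
    assert_eq!(out.dtype(), DType::F16);
    Ok(())
}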