diff --git a/mistralrs-core/src/vision_models/mllama/vision.rs b/mistralrs-core/src/vision_models/mllama/vision.rs
index 3332cf13f2..cadbea45c0 100644
--- a/mistralrs-core/src/vision_models/mllama/vision.rs
+++ b/mistralrs-core/src/vision_models/mllama/vision.rs
@@ -260,7 +260,7 @@ impl MLlamaMlp {
                 cfg.hidden_size,
                 cfg.intermediate_size,
                 &None,
-                false,
+                true,
                 comm,
                 vb.pp("fc1"),
             )?,
@@ -268,7 +268,7 @@ impl MLlamaMlp {
                 cfg.intermediate_size,
                 cfg.hidden_size,
                 &None,
-                false,
+                true,
                 comm,
                 vb.pp("fc2"),
             )?,
diff --git a/mistralrs-core/src/vision_models/qwen2_5_vl/mod.rs b/mistralrs-core/src/vision_models/qwen2_5_vl/mod.rs
index 3372e1ed2d..249df21a2d 100644
--- a/mistralrs-core/src/vision_models/qwen2_5_vl/mod.rs
+++ b/mistralrs-core/src/vision_models/qwen2_5_vl/mod.rs
@@ -50,8 +50,8 @@ impl Qwen2_5VLModel {
         let vision = Qwen2_5VLVisionModel::new(
             &cfg.vision_config,
             vb.pp("visual")
-                .set_device(normal_loading_metadata.real_device.clone())
-                .set_dtype(DType::F32),
+                .set_device(normal_loading_metadata.real_device.clone()),
+            &normal_loading_metadata.mapper.get_comm_for(0)?,
         )?;
         let text = Qwen2_5VLTextModel::new(
             cfg,
diff --git a/mistralrs-core/src/vision_models/qwen2_5_vl/vision.rs b/mistralrs-core/src/vision_models/qwen2_5_vl/vision.rs
index 4ab66cbc86..5ed9cd5d6e 100644
--- a/mistralrs-core/src/vision_models/qwen2_5_vl/vision.rs
+++ b/mistralrs-core/src/vision_models/qwen2_5_vl/vision.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
 
 use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Linear, Module};
-use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
+use mistralrs_quant::{ColumnParallelLayer, QuantMethod, RowParallelLayer, ShardedVarBuilder};
 
 use crate::{
     layers::{self, Activation, Conv3dConfig, Conv3dNoBias, MatMul, RmsNorm},
@@ -64,11 +64,38 @@ struct VisionMlp {
 }
 
 impl VisionMlp {
-    fn new(dim: usize, hidden_dim: usize, act: Activation, vb: ShardedVarBuilder) -> Result<Self> {
+    fn new(
+        dim: usize,
+        hidden_dim: usize,
+        act: Activation,
+        vb: ShardedVarBuilder,
+        comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
         Ok(Self {
-            gate_proj: mistralrs_quant::linear(dim, hidden_dim, &None, vb.pp("gate_proj"))?,
-            up_proj: mistralrs_quant::linear(dim, hidden_dim, &None, vb.pp("up_proj"))?,
-            down_proj: mistralrs_quant::linear(hidden_dim, dim, &None, vb.pp("down_proj"))?,
+            gate_proj: ColumnParallelLayer::new(
+                dim,
+                hidden_dim,
+                &None,
+                true,
+                comm,
+                vb.pp("gate_proj"),
+            )?,
+            up_proj: ColumnParallelLayer::new(
+                dim,
+                hidden_dim,
+                &None,
+                true,
+                comm,
+                vb.pp("up_proj"),
+            )?,
+            down_proj: RowParallelLayer::new(
+                hidden_dim,
+                dim,
+                &None,
+                true,
+                comm,
+                vb.pp("down_proj"),
+            )?,
             act,
         })
     }
@@ -102,11 +129,10 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 }
 
 fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
-    let xs = xs.to_dtype(DType::F32)?;
-    let cos = freqs.cos()?.unsqueeze(D::Minus2)?;
-    let sin = freqs.sin()?.unsqueeze(D::Minus2)?;
+    let cos = freqs.cos()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;
+    let sin = freqs.sin()?.unsqueeze(D::Minus2)?.to_dtype(xs.dtype())?;
 
-    xs.broadcast_mul(&cos)? + rotate_half(&xs)?.broadcast_mul(&sin)
+    xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
 }
 
 // https://github.com/huggingface/transformers/blob/a769ed45e17c44fd17b85c025863c4e4f2f73634/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L325
@@ -182,7 +208,11 @@ struct VisionBlock {
 }
 
 impl VisionBlock {
-    fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+    fn new(
+        cfg: &VisionConfig,
+        vb: ShardedVarBuilder,
+        comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
         let norm1 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm1"))?;
         let norm2 = RmsNorm::new(cfg.hidden_size, 1e-6, vb.pp("norm2"))?;
 
@@ -191,6 +221,7 @@ impl VisionBlock {
             cfg.intermediate_size,
             cfg.hidden_act,
             vb.pp("mlp"),
+            comm,
         )?;
 
         let attn = VisionAttention::new(cfg.hidden_size, cfg.num_heads, vb.pp("attn"))?;
@@ -290,10 +321,14 @@ pub struct Qwen2_5VLVisionModel {
 }
 
 impl Qwen2_5VLVisionModel {
-    pub fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+    pub fn new(
+        cfg: &VisionConfig,
+        vb: ShardedVarBuilder,
+        comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
         let mut blocks = Vec::new();
         for i in 0..cfg.depth {
-            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")))?);
+            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
         }
 
         let patch_merger = PatchMerger::new(
@@ -454,6 +489,7 @@ impl Qwen2_5VLVisionModel {
         rotary_pos_emb = rotary_pos_emb.index_select(&window_index, 0)?;
         rotary_pos_emb = rotary_pos_emb.reshape((seq_len, ()))?;
         rotary_pos_emb = Tensor::cat(&[&rotary_pos_emb; 2], D::Minus1)?;
+        rotary_pos_emb = rotary_pos_emb.to_dtype(xs.dtype())?;
 
         let grid_thw = grid_thw.to_device(&Device::Cpu)?;
         let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
@@ -470,13 +506,13 @@
             cu_seqlens => {
                 let mut attention_mask =
                     Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
-                        .to_dtype(DType::F32)?;
+                        .to_dtype(xs.dtype())?;
                 for i in 1..cu_seqlens.len() {
                     let a = cu_seqlens[i - 1] as usize;
                     let b = cu_seqlens[i] as usize;
                     attention_mask = attention_mask.slice_assign(
                         &[&.., &(a..b), &(a..b)],
-                        &Tensor::zeros((1, b - a, b - a), DType::F32, xs.device())?,
+                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                     )?;
                 }
                 Some(attention_mask)
@@ -487,13 +523,13 @@
             cu_seqlens => {
                 let mut attention_mask =
                     Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
-                        .to_dtype(DType::F32)?;
+                        .to_dtype(xs.dtype())?;
                 for i in 1..cu_seqlens.len() {
                     let a = cu_seqlens[i - 1] as usize;
                     let b = cu_seqlens[i] as usize;
                     attention_mask = attention_mask.slice_assign(
                         &[&.., &(a..b), &(a..b)],
-                        &Tensor::zeros((1, b - a, b - a), DType::F32, xs.device())?,
+                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                     )?;
                 }
                 Some(attention_mask)
diff --git a/mistralrs-core/src/vision_models/qwen2vl/mod.rs b/mistralrs-core/src/vision_models/qwen2vl/mod.rs
index 2790244aad..2af67e360e 100644
--- a/mistralrs-core/src/vision_models/qwen2vl/mod.rs
+++ b/mistralrs-core/src/vision_models/qwen2vl/mod.rs
@@ -50,8 +50,8 @@ impl Qwen2VLModel {
         let vision = Qwen2VLVisionModel::new(
             &cfg.vision_config,
             vb.pp("visual")
-                .set_device(normal_loading_metadata.real_device.clone())
-                .set_dtype(DType::F32),
+                .set_device(normal_loading_metadata.real_device.clone()),
+            &normal_loading_metadata.mapper.get_comm_for(0)?,
         )?;
         let text = Qwen2VLTextModel::new(
             cfg,
diff --git a/mistralrs-core/src/vision_models/qwen2vl/vision.rs b/mistralrs-core/src/vision_models/qwen2vl/vision.rs
index c966d73dd5..1c7e6fe9da 100644
--- a/mistralrs-core/src/vision_models/qwen2vl/vision.rs
+++ b/mistralrs-core/src/vision_models/qwen2vl/vision.rs
@@ -2,7 +2,7 @@ use std::sync::Arc;
 
 use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{LayerNorm, Linear, Module};
-use mistralrs_quant::{QuantMethod, ShardedVarBuilder};
+use mistralrs_quant::{ColumnParallelLayer, QuantMethod, ShardedVarBuilder};
 
 use crate::{
     layers::{self, layer_norm, Activation, Conv3dConfig, Conv3dNoBias, MatMul},
@@ -63,10 +63,16 @@ struct VisionMlp {
 }
 
 impl VisionMlp {
-    fn new(dim: usize, hidden_dim: usize, act: Activation, vb: ShardedVarBuilder) -> Result<Self> {
+    fn new(
+        dim: usize,
+        hidden_dim: usize,
+        act: Activation,
+        vb: ShardedVarBuilder,
+        comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
         Ok(Self {
-            fc1: mistralrs_quant::linear(dim, hidden_dim, &None, vb.pp("fc1"))?,
-            fc2: mistralrs_quant::linear(hidden_dim, dim, &None, vb.pp("fc2"))?,
+            fc1: ColumnParallelLayer::new(dim, hidden_dim, &None, true, comm, vb.pp("fc1"))?,
+            fc2: ColumnParallelLayer::new(hidden_dim, dim, &None, true, comm, vb.pp("fc2"))?,
             act,
         })
     }
@@ -85,11 +91,10 @@ fn rotate_half(xs: &Tensor) -> Result<Tensor> {
 }
 
 fn apply_rotary_pos_emb_vision(xs: &Tensor, freqs: &Tensor) -> Result<Tensor> {
-    let xs = xs.to_dtype(DType::F32)?;
     let cos = freqs.cos()?;
     let sin = freqs.sin()?;
 
-    xs.broadcast_mul(&cos)? + rotate_half(&xs)?.broadcast_mul(&sin)
+    xs.broadcast_mul(&cos)? + rotate_half(xs)?.broadcast_mul(&sin)
 }
 
 // https://github.com/huggingface/transformers/blob/a769ed45e17c44fd17b85c025863c4e4f2f73634/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L325
@@ -165,12 +170,22 @@ struct VisionBlock {
 }
 
 impl VisionBlock {
-    fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+    fn new(
+        cfg: &VisionConfig,
+        vb: ShardedVarBuilder,
+        comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
         let norm1 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm1"))?;
         let norm2 = layer_norm(cfg.embed_dim, 1e-6, vb.pp("norm2"))?;
 
         let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
-        let mlp = VisionMlp::new(cfg.embed_dim, mlp_hidden_dim, cfg.hidden_act, vb.pp("mlp"))?;
+        let mlp = VisionMlp::new(
+            cfg.embed_dim,
+            mlp_hidden_dim,
+            cfg.hidden_act,
+            vb.pp("mlp"),
+            comm,
+        )?;
         let attn = VisionAttention::new(cfg.embed_dim, cfg.num_heads, vb.pp("attn"))?;
 
         Ok(Self {
@@ -265,10 +280,14 @@ pub struct Qwen2VLVisionModel {
 }
 
 impl Qwen2VLVisionModel {
-    pub fn new(cfg: &VisionConfig, vb: ShardedVarBuilder) -> Result<Self> {
+    pub fn new(
+        cfg: &VisionConfig,
+        vb: ShardedVarBuilder,
+        comm: &Arc<mistralrs_quant::Comm>,
+    ) -> Result<Self> {
         let mut blocks = Vec::new();
         for i in 0..cfg.depth {
-            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")))?);
+            blocks.push(VisionBlock::new(cfg, vb.pp(format!("blocks.{i}")), comm)?);
         }
 
         let patch_merger = PatchMerger::new(
@@ -342,7 +361,7 @@ impl Qwen2VLVisionModel {
             .unsqueeze(1)?
             .repeat((1, 1, 2))?
             .unsqueeze(0)?
-            .to_dtype(DType::F32)?;
+            .to_dtype(xs.dtype())?;
 
         let grid_thw = grid_thw.to_device(&Device::Cpu)?;
         let cu_seqlens = (grid_thw.i((.., 1))? * grid_thw.i((.., 2))?)?
@@ -359,13 +378,13 @@
             cu_seqlens => {
                 let mut attention_mask =
                     Tensor::full(f32::MIN, (1, seq_len, seq_len), xs.device())?
-                        .to_dtype(DType::F32)?;
+                        .to_dtype(xs.dtype())?;
                 for i in 1..cu_seqlens.len() {
                     let a = cu_seqlens[i - 1] as usize;
                     let b = cu_seqlens[i] as usize;
                     attention_mask = attention_mask.slice_assign(
                         &[&.., &(a..b), &(a..b)],
-                        &Tensor::zeros((1, b - a, b - a), DType::F32, xs.device())?,
+                        &Tensor::zeros((1, b - a, b - a), xs.dtype(), xs.device())?,
                     )?;
                 }
                 Some(attention_mask)
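
Note: the sketch below is illustrative only and is not part of the patch above. It shows how a VisionMlp built from ColumnParallelLayer/RowParallelLayer (as constructed in the Qwen2.5-VL hunk) is typically driven at runtime. The struct fields mirror the field names used by the hunk; the forward body and the reduce-over-comm behaviour attributed to RowParallelLayer are assumptions for illustration, not code taken from mistral.rs.

use std::sync::Arc;

use candle_core::{Result, Tensor};
use candle_nn::Module;
use mistralrs_quant::QuantMethod;

use crate::layers::Activation;

struct VisionMlp {
    gate_proj: Arc<dyn QuantMethod>,
    up_proj: Arc<dyn QuantMethod>,
    down_proj: Arc<dyn QuantMethod>,
    act: Activation,
}

impl VisionMlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // gate_proj and up_proj were built as ColumnParallelLayer, so each rank
        // holds a shard of the intermediate dimension and these two projections
        // run without cross-rank communication.
        let gate = self.act.forward(&self.gate_proj.forward(xs)?)?;
        let up = self.up_proj.forward(xs)?;
        // down_proj was built as RowParallelLayer: it consumes the sharded
        // intermediate width and (assumption) reduces the partial results over
        // the comm supplied at construction, so every rank ends up with the
        // full-width output.
        self.down_proj.forward(&(gate * up)?)
    }
}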