diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 843c54d822..89c6622709 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -282,6 +282,7 @@ pub struct PhiRopeConfig {
     pub original_max_position_embeddings: usize,
     pub rope_theta: f64,
     pub head_dim: usize,
+    pub partial_rotary_factor: Option<f64>,
 }
 
 impl PhiRotaryEmbedding {
@@ -294,7 +295,7 @@ impl PhiRotaryEmbedding {
         dev: &Device,
     ) -> Result<Self> {
         let max_seq_len = cfg.max_position_embeddings;
-        let dim = cfg.head_dim;
+        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
 
         // Calculate scale
         let scale =
@@ -356,7 +357,7 @@ impl PhiRotaryEmbedding {
 
     fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
         let max_seq_len = cfg.max_position_embeddings;
-        let dim = cfg.head_dim;
+        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
 
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
@@ -391,7 +392,7 @@ impl PhiRotaryEmbedding {
         dev: &Device,
     ) -> Result<Self> {
         let max_seq_len = cfg.max_position_embeddings;
-        let dim = cfg.head_dim;
+        let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
 
         if !matches!(scaling_type, ScaledRopeType::Su) {
             candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
@@ -513,7 +514,51 @@ impl PhiRotaryEmbedding {
         let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
-        if all_same {
+
+        let rot_dim = cos.dim(D::Minus1)? * 2;
+
+        // Partial rotary case (e.g. Phi 4 mini): rotate only the first `rot_dim` features.
+        if rot_dim != q.dim(D::Minus1)? {
+            let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
+            let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
+            let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
+            let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;
+
+            let (q_rot, k_rot) = if all_same {
+                let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
+                let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
+                let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
+                let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
+                (q_embed, k_embed)
+            } else {
+                let mut q_embeds = Vec::new();
+                let mut k_embeds = Vec::new();
+                for (i, offset) in seqlen_offsets.iter().enumerate() {
+                    let cos = cos.narrow(0, *offset, seq_len)?;
+                    let sin = sin.narrow(0, *offset, seq_len)?;
+                    let q_embed = candle_nn::rotary_emb::rope(
+                        &q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
+                        &cos,
+                        &sin,
+                    )?;
+                    let k_embed = candle_nn::rotary_emb::rope(
+                        &k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
+                        &cos,
+                        &sin,
+                    )?;
+                    q_embeds.push(q_embed);
+                    k_embeds.push(k_embed);
+                }
+                let q_rot = Tensor::cat(&q_embeds, 0)?;
+                let k_rot = Tensor::cat(&k_embeds, 0)?;
+                (q_rot, k_rot)
+            };
+
+            Ok((
+                Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
+                Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
+            ))
+        } else if all_same {
             let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
             let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
             let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
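For reference, a minimal sketch of what the partial-rotary branch computes, written against candle's public `rotary_emb::rope` (the function name and shape assumptions here are illustrative, not part of this patch):

```rust
use candle_core::{D, Result, Tensor};

/// Apply RoPE to only the first `rot_dim` features of each head, passing the
/// rest through untouched. `q` is assumed to be `(b, h, seq, head_dim)` and
/// `cos`/`sin` to hold `rot_dim / 2` frequencies, as in the diff above.
fn apply_partial_rope(q: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let head_dim = q.dim(D::Minus1)?;
    let rot_dim = cos.dim(D::Minus1)? * 2;
    if rot_dim == head_dim {
        // partial_rotary_factor absent or 1.0: plain full-width RoPE.
        return candle_nn::rotary_emb::rope(&q.contiguous()?, cos, sin);
    }
    // Split the head dimension, rotate the front, and re-attach the tail.
    let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
    let q_pass = q.narrow(D::Minus1, rot_dim, head_dim - rot_dim)?;
    let q_rot = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, cos, sin)?;
    Tensor::cat(&[q_rot, q_pass], D::Minus1)
}
```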
diff --git a/mistralrs-core/src/layers_masker.rs b/mistralrs-core/src/layers_masker.rs
index 6607177fb3..d292b290ce 100644
--- a/mistralrs-core/src/layers_masker.rs
+++ b/mistralrs-core/src/layers_masker.rs
@@ -161,12 +161,24 @@ impl CausalMasker {
             return Ok(None);
         }
 
-        let causal_mask = {
+        let mut causal_mask = {
             let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
             let diagonal = past_kv_len as isize - sliding_window as isize - 1;
             let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
 
-            masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?.to_dtype(dtype)?
+            masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
+                .to_dtype(DType::U8)?
+        };
+
+        let zero = Tensor::new(0.0f32, input_ids.device())?;
+        causal_mask = {
+            let mask = causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
+            // Mask: 0 means attend (add 0.0), 1 means mask out (add -inf)
+            masked_fill(
+                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
+                &mask,
+                f32::NEG_INFINITY,
+            )?
+        };
 
         Ok(Some(causal_mask))
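As a plain-Rust illustration of the band structure this additive mask encodes (a toy sketch under assumed window semantics, not the crate's implementation):

```rust
/// Build an additive attention mask for `tgt_len` new queries given
/// `past_kv_len` cached keys: a key is visible iff it is not in the future
/// and lies within the sliding window. Visible slots add 0.0, hidden slots
/// add -inf before the softmax.
fn sliding_causal_mask(tgt_len: usize, past_kv_len: usize, window: usize) -> Vec<Vec<f32>> {
    (0..tgt_len)
        .map(|i| {
            let q_pos = past_kv_len + i; // global position of this query
            (0..past_kv_len + tgt_len)
                .map(|k_pos| {
                    let visible = k_pos <= q_pos && q_pos - k_pos < window;
                    if visible { 0.0 } else { f32::NEG_INFINITY }
                })
                .collect()
        })
        .collect()
}

fn main() {
    // One new query at global position 4 with a window of 3 sees keys 2..=4.
    let row = &sliding_causal_mask(1, 4, 3)[0];
    assert!(row[..2].iter().all(|v| v.is_infinite()));
    assert!(row[2..].iter().all(|&v| v == 0.0));
}
```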
diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs
index a12a67707c..157c814417 100644
--- a/mistralrs-core/src/models/phi3.rs
+++ b/mistralrs-core/src/models/phi3.rs
@@ -53,6 +53,7 @@ pub struct Config {
     pub quantization_config: Option<QuantizedConfig>,
     #[serde(default = "word_emb_default")]
     pub tie_word_embeddings: bool,
+    pub partial_rotary_factor: Option<f64>,
 }
 
 impl From<Config> for PhiRopeConfig {
@@ -63,6 +64,7 @@ impl From<Config> for PhiRopeConfig {
             original_max_position_embeddings: val.original_max_position_embeddings,
             rope_theta: val.rope_theta,
             head_dim: val.hidden_size / val.num_attention_heads,
+            partial_rotary_factor: val.partial_rotary_factor,
         }
     }
 }
diff --git a/mistralrs-core/src/models/phi3_5_moe.rs b/mistralrs-core/src/models/phi3_5_moe.rs
index 9e052e8192..5683b6c749 100644
--- a/mistralrs-core/src/models/phi3_5_moe.rs
+++ b/mistralrs-core/src/models/phi3_5_moe.rs
@@ -66,6 +66,7 @@ impl From<Config> for PhiRopeConfig {
             original_max_position_embeddings: val.original_max_position_embeddings,
             rope_theta: val.rope_theta,
             head_dim: val.hidden_size / val.num_attention_heads,
+            partial_rotary_factor: None,
         }
     }
 }
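To make the deserialization side concrete, a self-contained sketch of how the new optional field behaves (the struct and values are illustrative; Phi-4-mini's published config.json uses a factor of 0.75, while Phi-3 configs simply omit the key):

```rust
use serde::Deserialize;

/// Only the fields relevant here; names follow the Hugging Face config.json keys.
#[derive(Deserialize)]
struct MiniPhiConfig {
    hidden_size: usize,
    num_attention_heads: usize,
    partial_rotary_factor: Option<f64>, // absent (Phi-3) => None => rotate everything
}

fn main() -> Result<(), serde_json::Error> {
    // Phi-4-mini style config: only 3/4 of each head is rotated.
    let cfg: MiniPhiConfig = serde_json::from_str(
        r#"{"hidden_size": 3072, "num_attention_heads": 24, "partial_rotary_factor": 0.75}"#,
    )?;
    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
    let rot_dim = (head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;
    assert_eq!((head_dim, rot_dim), (128, 96));
    Ok(())
}
```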
diff --git a/mistralrs-core/src/pipeline/cache_manager.rs b/mistralrs-core/src/pipeline/cache_manager.rs
index dce1f25a61..dbaa72d7d1 100644
--- a/mistralrs-core/src/pipeline/cache_manager.rs
+++ b/mistralrs-core/src/pipeline/cache_manager.rs
@@ -116,6 +116,7 @@ impl SingleCache {
             let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
             self.all_data = Some(ad);
         };
+        // Expand kv cache
         if self.current_seq_len + seq_len > self.capacity_seq_len {
             let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
@@ -134,7 +135,9 @@ impl SingleCache {
             ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
             self.all_data = Some(ad);
         }
+
         let ad = self.all_data.as_mut().unwrap();
+
         ad.slice_set(src, self.dim, self.current_seq_len)?;
         self.current_seq_len += seq_len;
         Ok(())
@@ -152,16 +155,18 @@ pub struct RotatingCache {
     // max_seq_len is the size of the rotating buffer, it is actually allowed for the full
     // sequence to grow past this limit.
     pub max_seq_len: usize,
+    pub capacity_seq_len: usize,
 }
 
 impl RotatingCache {
-    pub fn new(dim: usize, max_seq_len: usize) -> Self {
+    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
         Self {
             all_data: None,
             dim,
             offset: 0,
             current_seq_len: 0,
             max_seq_len,
+            capacity_seq_len,
         }
     }
@@ -224,10 +229,32 @@ impl RotatingCache {
         // self.all_data.get_or_insert_with.
         if self.all_data.is_none() {
             let mut shape = src.dims().to_vec();
-            shape[self.dim] = self.max_seq_len;
+            shape[self.dim] = self.capacity_seq_len;
             let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
             self.all_data = Some(ad)
         };
+
+        // Expand kv cache; the rotating cache only grows while the sequence still fits inside the window.
+        if self.current_seq_len + seq_len > self.capacity_seq_len
+            && self.current_seq_len + seq_len < self.max_seq_len
+        {
+            let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
+            let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
+            self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
+            if self.capacity_seq_len > self.max_seq_len {
+                candle_core::bail!(
+                    "kv-cache: requested capacity ({}) above max seq len ({})",
+                    self.capacity_seq_len,
+                    self.max_seq_len
+                )
+            }
+            let mut shape = src.dims().to_vec();
+            shape[self.dim] = self.capacity_seq_len;
+            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
+            ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
+            self.all_data = Some(ad);
+        }
+
         let ad = self.all_data.as_mut().unwrap();
 
         self.current_seq_len += seq_len;
@@ -278,9 +305,9 @@ impl KvCache {
         Self::Normal { k, v }
     }
 
-    pub fn new_rotating(dim: usize, sliding_window: usize) -> Self {
-        let k = RotatingCache::new(dim, sliding_window);
-        let v = RotatingCache::new(dim, sliding_window);
+    pub fn new_rotating(dim: usize, sliding_window: usize, capacity_seq_len: usize) -> Self {
+        let k = RotatingCache::new(dim, sliding_window, capacity_seq_len);
+        let v = RotatingCache::new(dim, sliding_window, capacity_seq_len);
         Self::Rotating { k, v }
     }
@@ -409,7 +436,8 @@ impl NormalCache {
             Some(sliding_window) => Arc::new(Mutex::new(Self(vec![
                 KvCache::new_rotating(
                     2,
-                    sliding_window
+                    sliding_window,
+                    Self::CACHE_GROW_SIZE
                 );
                 len
             ]))),
@@ -432,7 +460,7 @@ impl NormalCache {
                     caches.push(KvCache::new_normal(2, max_seq_len, Self::CACHE_GROW_SIZE));
                 }
                 NormalCacheType::SlidingWindow { window } => {
-                    caches.push(KvCache::new_rotating(2, window));
+                    caches.push(KvCache::new_rotating(2, window, Self::CACHE_GROW_SIZE));
                 }
             }
         }
@@ -532,6 +560,7 @@ impl CacheManager for NormalCa
         let template_cache_csl = old_k.current_seq_len;
         let template_cache_msl = old_k.max_seq_len;
         let template_cache_offset = old_k.offset;
+        let template_cache_capsl = old_k.capacity_seq_len;
 
         caches.push(KvCache::Rotating {
             k: RotatingCache {
@@ -540,6 +569,7 @@ impl CacheManager for NormalCa
                 current_seq_len: template_cache_csl,
                 max_seq_len: template_cache_msl,
                 offset: template_cache_offset,
+                capacity_seq_len: template_cache_capsl,
             },
             v: RotatingCache {
                 all_data: v_cache.map(|x| x.contiguous().unwrap()),
@@ -547,6 +577,7 @@ impl CacheManager for NormalCa
                 current_seq_len: template_cache_csl,
                 max_seq_len: template_cache_msl,
                 offset: template_cache_offset,
+                capacity_seq_len: template_cache_capsl,
             },
         });
     }
@@ -620,6 +651,7 @@ impl CacheManager for NormalCa
                 current_seq_len: cache_k.current_seq_len,
                 max_seq_len: cache_k.max_seq_len,
                 offset: cache_k.offset,
+                capacity_seq_len: cache_k.capacity_seq_len,
             },
             v: RotatingCache {
                 all_data: Some(v),
@@ -627,6 +659,7 @@ impl CacheManager for NormalCa
                 current_seq_len: cache_v.current_seq_len,
                 max_seq_len: cache_v.max_seq_len,
                 offset: cache_v.offset,
+                capacity_seq_len: cache_v.capacity_seq_len,
             },
         });
     }
@@ -731,6 +764,7 @@ impl CacheManager for NormalCa
                 current_seq_len: 0,
                 max_seq_len: template_cache_msl,
                 offset: 0,
+                capacity_seq_len: 0,
             },
             v: RotatingCache {
                 all_data: None,
@@ -738,6 +772,7 @@ impl CacheManager for NormalCa
                 current_seq_len: 0,
                 max_seq_len: template_cache_msl,
                 offset: 0,
+                capacity_seq_len: 0,
             },
         };
         *layer = cache;
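Both caches share the same block-rounding growth rule; a standalone sketch of the arithmetic (the 512 block size is illustrative, standing in for `NormalCache::CACHE_GROW_SIZE`):

```rust
/// Illustrative stand-in for NormalCache::CACHE_GROW_SIZE.
const CACHE_GROW_SIZE: usize = 512;

/// Return the capacity after appending `incoming` tokens to a cache that
/// currently holds `current` tokens in a buffer of `capacity` slots.
fn grown_capacity(current: usize, incoming: usize, capacity: usize) -> usize {
    if current + incoming <= capacity {
        return capacity; // still fits, no reallocation
    }
    // Round the shortfall up to whole blocks, as the diff's div_ceil does.
    let diff = current + incoming - capacity;
    let n_blocks = diff.div_ceil(CACHE_GROW_SIZE);
    capacity + n_blocks * CACHE_GROW_SIZE
}

fn main() {
    // A shortfall of 488 tokens rounds up to one extra block.
    assert_eq!(grown_capacity(400, 600, 512), 1024);
    // A shortfall of a single token still costs a full block.
    assert_eq!(grown_capacity(512, 1, 512), 1024);
}
```

The rotating variant additionally caps growth at the sliding-window size: once the total sequence reaches `max_seq_len`, the buffer stops growing and wraps around instead, which is why the expand branch bails if the rounded capacity would exceed the window.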
+++ b/mistralrs-core/src/pipeline/loaders/normal_loaders.rs
@@ -1546,6 +1546,7 @@ struct Phi3BasicConfig {
     quantization_config: Option<QuantizedConfig>,
     #[serde(default = "word_emb_default")]
     tie_word_embeddings: bool,
+    partial_rotary_factor: Option<f64>,
 }
 
 impl Phi3BasicConfig {
@@ -1570,6 +1571,7 @@ impl Phi3BasicConfig {
             sliding_window: basic_config.sliding_window,
             quantization_config: basic_config.quantization_config,
             tie_word_embeddings: basic_config.tie_word_embeddings,
+            partial_rotary_factor: basic_config.partial_rotary_factor,
         })
     }
 }
diff --git a/mistralrs-core/src/prefix_cacher.rs b/mistralrs-core/src/prefix_cacher.rs
index 17b9be6658..b20e84d2e3 100644
--- a/mistralrs-core/src/prefix_cacher.rs
+++ b/mistralrs-core/src/prefix_cacher.rs
@@ -117,6 +117,7 @@ impl PrefixCacheManagerV2 {
                 current_seq_len: k.current_seq_len,
                 max_seq_len: k.max_seq_len,
                 offset: k.offset,
+                capacity_seq_len: k.capacity_seq_len,
             },
             v: RotatingCache {
                 all_data: v.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
@@ -124,6 +125,7 @@ impl PrefixCacheManagerV2 {
                 current_seq_len: v.current_seq_len,
                 max_seq_len: v.max_seq_len,
                 offset: v.offset,
+                capacity_seq_len: v.capacity_seq_len,
             },
         }
     }
diff --git a/mistralrs-core/src/vision_models/phi3/mod.rs b/mistralrs-core/src/vision_models/phi3/mod.rs
index b28b3552bf..d73ce1a6c6 100644
--- a/mistralrs-core/src/vision_models/phi3/mod.rs
+++ b/mistralrs-core/src/vision_models/phi3/mod.rs
@@ -90,6 +90,7 @@ impl From<Config> for PhiRopeConfig {
         original_max_position_embeddings: val.original_max_position_embeddings,
         rope_theta: val.rope_theta,
         head_dim: val.hidden_size / val.num_attention_heads,
+        partial_rotary_factor: None,
     }
 }
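The net effect of the plumbing across all four conversion sites, as a self-contained sketch (the struct is a stand-in for the crate's `PhiRopeConfig`; 128 and 0.75 mirror Phi-4-mini's head_dim and published factor):

```rust
/// Stand-in for the crate's PhiRopeConfig, reduced to the two fields
/// involved in computing the rotated width.
struct PhiRopeConfig {
    head_dim: usize,
    partial_rotary_factor: Option<f64>,
}

impl PhiRopeConfig {
    /// Width of the rotated slice; None behaves like a factor of 1.0.
    fn rot_dim(&self) -> usize {
        (self.head_dim as f64 * self.partial_rotary_factor.unwrap_or(1.)) as usize
    }
}

fn main() {
    // Phi-3.5-MoE and the vision Phi-3 pass None: behavior is unchanged.
    let moe = PhiRopeConfig { head_dim: 128, partial_rotary_factor: None };
    // Phi-4-mini style configs rotate only 3/4 of each head.
    let phi4_mini = PhiRopeConfig { head_dim: 128, partial_rotary_factor: Some(0.75) };
    assert_eq!(moe.rot_dim(), 128);
    assert_eq!(phi4_mini.rot_dim(), 96);
}
```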