15 changes: 11 additions & 4 deletions mistralrs-core/src/layers.rs
@@ -1896,10 +1896,17 @@ impl MlpLayer for Mlp {
         }
         let lhs = MatMul.qmethod_matmul(&xs, &*self.gate)?;
         let rhs = MatMul.qmethod_matmul(&xs, &*self.up)?;
-        let mut res = MatMul.qmethod_matmul(
-            &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
-            &*self.down,
-        )?;
+        let mut res = if matches!(
+            self.act,
+            Activation::Gelu | Activation::Silu | Activation::Relu
+        ) {
+            MatMul.qmethod_matmul(
+                &candle_nn::ops::mul_and_act(&lhs, &rhs, self.act.try_into()?)?,
+                &*self.down,
+            )?
+        } else {
+            MatMul.qmethod_matmul(&(&lhs.apply(&self.act)? * &rhs)?, &*self.down)?
+        };
         if self.gate.quantized_act_type().is_some() {
             res = res.to_dtype(original_dtype)?;
         }
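The new branch takes the fused `candle_nn::ops::mul_and_act` path only when the activation is GELU, SiLU, or ReLU; anything else falls back to applying the activation and multiplying separately. Below is a dependency-free sketch of that dispatch; the enum, the plain `f32` slices standing in for tensors, and the helper names are illustrative stand-ins, not the crate's actual types.

// Sketch of the dispatch added to Mlp::forward: use the fused
// multiply-and-activate kernel only for activations it supports, otherwise
// apply the activation first and multiply afterwards.
#[derive(Clone, Copy)]
enum Activation {
    Gelu,
    Silu,
    Relu,
    Tanh, // stands in for any activation the fused path does not cover
}

fn apply_act(a: Activation, x: f32) -> f32 {
    match a {
        Activation::Relu => x.max(0.0),
        Activation::Silu => x / (1.0 + (-x).exp()),
        Activation::Gelu => {
            // tanh approximation of GELU
            let c = (2.0 / std::f32::consts::PI).sqrt();
            0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
        }
        Activation::Tanh => x.tanh(),
    }
}

// Stand-in for the fused kernel: activation on the gate output, multiplied
// elementwise with the up-projection output, in a single pass.
fn mul_and_act(gate: &[f32], up: &[f32], a: Activation) -> Vec<f32> {
    gate.iter().zip(up).map(|(&g, &u)| apply_act(a, g) * u).collect()
}

fn gated_mlp_inner(gate: &[f32], up: &[f32], a: Activation) -> Vec<f32> {
    if matches!(a, Activation::Gelu | Activation::Silu | Activation::Relu) {
        // Fused path, mirroring the new `if matches!` branch.
        mul_and_act(gate, up, a)
    } else {
        // Fallback path: activation first, elementwise multiply second.
        let activated: Vec<f32> = gate.iter().map(|&g| apply_act(a, g)).collect();
        activated.iter().zip(up).map(|(&g, &u)| g * u).collect()
    }
}

fn main() {
    let gate = [1.0_f32, -2.0, 0.5];
    let up = [0.5_f32, 3.0, 2.0];
    println!("fused:    {:?}", gated_mlp_inner(&gate, &up, Activation::Silu));
    println!("fallback: {:?}", gated_mlp_inner(&gate, &up, Activation::Tanh));
}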
17 changes: 2 additions & 15 deletions mistralrs-core/src/layers_masker.rs
@@ -161,25 +161,12 @@ impl CausalMasker {
             return Ok(None);
         }
 
-        let mut causal_mask = {
+        let causal_mask = {
             let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
             let diagonal = past_kv_len as isize - sliding_window as isize - 1;
             let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
 
-            masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
-                .to_dtype(DType::U8)?
-        };
-
-        let zero = Tensor::new(0.0f32, input_ids.device())?;
-        causal_mask = {
-            let mask = causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
-            // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
-
-            masked_fill(
-                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
-                &mask,
-                f32::NEG_INFINITY,
-            )?
+            masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?.to_dtype(dtype)?
         };
 
         Ok(Some(causal_mask))
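The second masked-fill pass over a zero tensor is gone: the sliding-window mask is now built in one step, with blocked positions set to `f32::MIN` and the result cast directly to the attention dtype. A rough, loop-based sketch of the mask this produces follows; the real code works on tensors via `make_mask`, `apply_tril`, and `masked_fill`, and the window boundary below is meant to mirror the `past_kv_len as isize - sliding_window as isize - 1` diagonal passed to `apply_tril`.

// Rough sketch of a sliding-window causal mask: a query at absolute position
// q may attend to a key at absolute position k only if k <= q (causal) and k
// lies within the window behind q. Blocked positions get f32::MIN, visible
// ones 0.0.
fn sliding_window_causal_mask(
    tgt_len: usize,
    past_kv_len: usize,
    sliding_window: usize,
) -> Vec<Vec<f32>> {
    let total_len = tgt_len + past_kv_len;
    (0..tgt_len)
        .map(|row| {
            let q = past_kv_len + row; // absolute position of this query
            (0..total_len)
                .map(|k| {
                    let causal = k <= q;
                    let in_window = q <= k + sliding_window;
                    if causal && in_window { 0.0 } else { f32::MIN }
                })
                .collect()
        })
        .collect()
}

fn main() {
    // 3 new tokens on top of 4 cached ones, window of 4: each row shows which
    // absolute positions that query can attend to (1 = visible, 0 = blocked).
    for row in sliding_window_causal_mask(3, 4, 4) {
        let visible: Vec<u8> = row.iter().map(|&v| (v == 0.0) as u8).collect();
        println!("{visible:?}");
    }
}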
17 changes: 4 additions & 13 deletions mistralrs-core/src/models/gemma2.rs
@@ -72,7 +72,6 @@ struct Attention {
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
     use_sliding_window: bool,
-    sliding_window: Option<usize>,
     paged_attn: Option<PagedAttention>,
     sdpa_params: SdpaParams,
 }
@@ -146,7 +145,6 @@ impl Attention {
             head_dim,
             rotary_emb,
             use_sliding_window: layer_idx % 2 == 0, // Order is SWA, global, SWA
-            sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
                 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
@@ -249,17 +247,9 @@ impl Attention {
             },
             None => {
                 // self.sliding_window is None if !self.use_sliding_window
-                let (k, v, mask) =
-                    kv_cache.append_sliding_window(&k, &v, mask, self.sliding_window)?;
+                let (k, v) = kv_cache.append(&k, &v)?;
 
-                Sdpa.run_attention(
-                    &q,
-                    &k,
-                    &v,
-                    mask.as_ref(),
-                    Some(flash_params),
-                    &self.sdpa_params,
-                )?
+                Sdpa.run_attention(&q, &k, &v, mask, Some(flash_params), &self.sdpa_params)?
             }
         };
 
@@ -487,9 +477,10 @@ impl Model {
             ))?),
             device: normal_loading_metadata.real_device,
             hidden_size: cfg.hidden_size,
-            cache: EitherCache::Normal(NormalCache::new(
+            cache: EitherCache::Normal(NormalCache::new_sliding(
                 cfg.num_hidden_layers,
                 cfg.max_position_embeddings,
+                Some(cfg.sliding_window),
             )),
             max_seq_len: cfg.max_position_embeddings,
             sliding_window: cfg.sliding_window,
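The model files below all follow the pattern shown here for gemma2: the per-layer `sliding_window` field and the `append_sliding_window` call are removed, the window is supplied once when the cache is constructed with `NormalCache::new_sliding`, and the attention code simply calls `append` and passes its mask through. A toy sketch of that division of responsibility, using an illustrative cache type rather than the crate's real one:

// Toy sketch: the cache itself enforces the window, so attention layers only
// ever call append. The real cache stores tensors per layer and also covers
// the non-sliding case when the window is None.
use std::collections::VecDeque;

struct SlidingLayerCache {
    window: Option<usize>,
    keys: VecDeque<f32>,
    values: VecDeque<f32>,
}

impl SlidingLayerCache {
    fn new_sliding(window: Option<usize>) -> Self {
        Self {
            window,
            keys: VecDeque::new(),
            values: VecDeque::new(),
        }
    }

    // Analogue of kv_cache.append(&k, &v): push the new entries, then drop
    // the oldest ones once the configured window is exceeded.
    fn append(&mut self, k: f32, v: f32) -> (Vec<f32>, Vec<f32>) {
        self.keys.push_back(k);
        self.values.push_back(v);
        if let Some(w) = self.window {
            while self.keys.len() > w {
                self.keys.pop_front();
                self.values.pop_front();
            }
        }
        (
            self.keys.iter().copied().collect(),
            self.values.iter().copied().collect(),
        )
    }
}

fn main() {
    // Window of 3: after the fourth append only the last three entries remain.
    let mut cache = SlidingLayerCache::new_sliding(Some(3));
    for t in 0..4 {
        let (keys, _values) = cache.append(t as f32, t as f32 * 10.0);
        println!("step {t}: cached keys = {keys:?}");
    }
}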
10 changes: 4 additions & 6 deletions mistralrs-core/src/models/mistral.rs
@@ -60,7 +60,6 @@ struct Attention {
     num_kv_heads: usize,
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
-    sliding_window: Option<usize>,
     paged_attn: Option<PagedAttention>,
     sdpa_params: SdpaParams,
 }
@@ -125,7 +124,6 @@ impl Attention {
             num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
             head_dim,
             rotary_emb,
-            sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
                 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
@@ -220,14 +218,13 @@ impl Attention {
                 }
             },
             None => {
-                let (k, v, attn_mask) =
-                    kv_cache.append_sliding_window(&k, &v, attention_mask, self.sliding_window)?;
+                let (k, v) = kv_cache.append(&k, &v)?;
 
                 Sdpa.run_attention(
                     &q,
                     &k,
                     &v,
-                    attn_mask.as_ref(),
+                    attention_mask,
                     Some(flash_params),
                     &self.sdpa_params,
                 )?
@@ -468,9 +465,10 @@ impl Model {
             lm_head,
             sliding_window: cfg.sliding_window,
             device: normal_loading_metadata.real_device,
-            cache: EitherCache::Normal(NormalCache::new(
+            cache: EitherCache::Normal(NormalCache::new_sliding(
                 cfg.num_hidden_layers,
                 cfg.max_position_embeddings,
+                cfg.sliding_window,
             )),
             max_seq_len: cfg.max_position_embeddings,
             cfg: ModelConfigMetadata {
10 changes: 4 additions & 6 deletions mistralrs-core/src/models/mixtral.rs
@@ -60,7 +60,6 @@ struct Attention {
     num_kv_heads: usize,
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
-    sliding_window: Option<usize>,
     paged_attn: Option<PagedAttention>,
     sdpa_params: SdpaParams,
 }
@@ -125,7 +124,6 @@ impl Attention {
             num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
             head_dim,
             rotary_emb,
-            sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
                 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
@@ -220,14 +218,13 @@ impl Attention {
                 }
             },
             None => {
-                let (k, v, attn_mask) =
-                    kv_cache.append_sliding_window(&k, &v, attention_mask, self.sliding_window)?;
+                let (k, v) = kv_cache.append(&k, &v)?;
 
                 Sdpa.run_attention(
                     &q,
                     &k,
                     &v,
-                    attn_mask.as_ref(),
+                    attention_mask,
                     Some(flash_params),
                     &self.sdpa_params,
                 )?
@@ -605,9 +602,10 @@ impl Model {
             lm_head,
             sliding_window: cfg.sliding_window,
             device: normal_loading_metadata.real_device,
-            cache: EitherCache::Normal(NormalCache::new(
+            cache: EitherCache::Normal(NormalCache::new_sliding(
                 cfg.num_hidden_layers,
                 cfg.max_position_embeddings,
+                cfg.sliding_window,
             )),
             max_seq_len: cfg.max_position_embeddings,
             cfg: ModelConfigMetadata {
10 changes: 4 additions & 6 deletions mistralrs-core/src/models/phi3.rs
@@ -80,7 +80,6 @@ struct Attention {
     num_kv_heads: usize,
     head_dim: usize,
     rotary_emb: Arc<PhiRotaryEmbedding>,
-    sliding_window: Option<usize>,
     paged_attn: Option<PagedAttention>,
     sdpa_params: SdpaParams,
 }
@@ -119,7 +118,6 @@ impl Attention {
             num_heads,
             num_kv_heads,
             head_dim,
-            sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
                 n_kv_groups: num_heads / num_kv_heads,
@@ -217,14 +215,13 @@ impl Attention {
                 }
             },
             _ => {
-                let (k, v, attn_mask) =
-                    kv_cache.append_sliding_window(&k, &v, attention_mask, self.sliding_window)?;
+                let (k, v) = kv_cache.append(&k, &v)?;
 
                 Sdpa.run_attention(
                     &q,
                     &k,
                     &v,
-                    attn_mask.as_ref(),
+                    attention_mask,
                     Some(flash_params),
                     &self.sdpa_params,
                 )?
@@ -522,9 +519,10 @@ impl Model {
             norm,
             lm_head,
             device: normal_loading_metadata.real_device,
-            cache: EitherCache::Normal(NormalCache::new(
+            cache: EitherCache::Normal(NormalCache::new_sliding(
                 cfg.num_hidden_layers,
                 cfg.max_position_embeddings,
+                cfg.sliding_window,
             )),
             max_seq_len: cfg.max_position_embeddings,
             sliding_window: cfg.sliding_window,
10 changes: 4 additions & 6 deletions mistralrs-core/src/models/phi3_5_moe.rs
@@ -85,7 +85,6 @@ struct Attention {
     num_kv_heads: usize,
     head_dim: usize,
     rotary_emb: Arc<PhiRotaryEmbedding>,
-    sliding_window: Option<usize>,
     paged_attn: Option<PagedAttention>,
     sdpa_params: SdpaParams,
 }
@@ -151,7 +150,6 @@ impl Attention {
             num_heads: num_heads / comm.world_size(),
             num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
             head_dim,
-            sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
                 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
@@ -249,14 +247,13 @@ impl Attention {
                 }
             },
             _ => {
-                let (k, v, attn_mask) =
-                    kv_cache.append_sliding_window(&k, &v, attention_mask, self.sliding_window)?;
+                let (k, v) = kv_cache.append(&k, &v)?;
 
                 Sdpa.run_attention(
                     &q,
                     &k,
                     &v,
-                    attn_mask.as_ref(),
+                    attention_mask,
                     Some(flash_params),
                     &self.sdpa_params,
                 )?
@@ -676,9 +673,10 @@ impl Model {
             norm,
             lm_head,
             device: normal_loading_metadata.real_device,
-            cache: EitherCache::Normal(NormalCache::new(
+            cache: EitherCache::Normal(NormalCache::new_sliding(
                 cfg.num_hidden_layers,
                 cfg.max_position_embeddings,
+                cfg.sliding_window,
             )),
             max_seq_len: cfg.max_position_embeddings,
             sliding_window: cfg.sliding_window,
13 changes: 7 additions & 6 deletions mistralrs-core/src/models/quantized_phi3.rs
@@ -54,7 +54,6 @@ struct LayerWeights {
     head_dim: usize,
     cos: Tensor,
     sin: Tensor,
-    sliding_window: usize,
     paged_attn: Option<PagedAttention>,
     sdpa_params: SdpaParams,
     dtype: DType,
@@ -133,10 +132,9 @@ impl LayerWeights {
                 )?
             }
             None => {
-                let (k, v, attn_mask) =
-                    kv_cache.append_sliding_window(&k, &v, mask, Some(self.sliding_window))?;
+                let (k, v) = kv_cache.append(&k, &v)?;
 
-                Sdpa.run_attention(&q, &k, &v, attn_mask.as_ref(), None, &self.sdpa_params)?
+                Sdpa.run_attention(&q, &k, &v, mask, None, &self.sdpa_params)?
             }
         };
 
@@ -334,7 +332,6 @@ impl ModelConfig::FromGGUF for ModelWeights {
             head_dim,
             cos: cos.to_device(device)?,
             sin: sin.to_device(device)?,
-            sliding_window: context_window,
             paged_attn,
             sdpa_params: SdpaParams {
                 n_kv_groups: head_count / head_count_kv,
@@ -353,7 +350,11 @@ impl ModelConfig::FromGGUF for ModelWeights {
             output,
             mapper: Some(mapper),
             device: device.clone(),
-            cache: EitherCache::Normal(NormalCache::new(block_count, context_window)),
+            cache: EitherCache::Normal(NormalCache::new_sliding(
+                block_count,
+                context_window,
+                Some(context_window),
+            )),
             max_seq_len: context_window,
             dtype,
         })
10 changes: 4 additions & 6 deletions mistralrs-core/src/models/starcoder2.rs
@@ -151,7 +151,6 @@ struct Attention {
     num_kv_heads: usize,
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
-    sliding_window: Option<usize>,
     paged_attn: Option<PagedAttention>,
     sdpa_params: SdpaParams,
 }
@@ -217,7 +216,6 @@ impl Attention {
             num_kv_heads: (num_kv_heads / comm.world_size()).max(1),
             head_dim,
             rotary_emb,
-            sliding_window: cfg.sliding_window,
             paged_attn,
             sdpa_params: SdpaParams {
                 n_kv_groups: mistralrs_quant::compute_n_kv_groups(
@@ -312,14 +310,13 @@ impl Attention {
                 }
             },
             None => {
-                let (k, v, attn_mask) =
-                    kv_cache.append_sliding_window(&k, &v, attention_mask, self.sliding_window)?;
+                let (k, v) = kv_cache.append(&k, &v)?;
 
                 Sdpa.run_attention(
                     &q,
                     &k,
                     &v,
-                    attn_mask.as_ref(),
+                    attention_mask,
                     Some(flash_params),
                     &self.sdpa_params,
                 )?
@@ -526,9 +523,10 @@ impl Model {
             ))?),
             sliding_window: cfg.sliding_window,
             device: normal_loading_metadata.real_device,
-            cache: EitherCache::Normal(NormalCache::new(
+            cache: EitherCache::Normal(NormalCache::new_sliding(
                 cfg.num_hidden_layers,
                 cfg.max_position_embeddings,
+                cfg.sliding_window,
             )),
             max_seq_len: cfg.max_position_embeddings,
             cfg: ModelConfigMetadata {