135 changes: 134 additions & 1 deletion mistralrs-core/src/layers.rs
@@ -24,7 +24,7 @@ pub use crate::layers_utils::repeat_kv;
use crate::{
amoe::{AnyMoeTrainableLayer, MlpLayer},
gguf::Content,
models::llama,
models::{llama, smollm3},
ops::SplitOp,
vision_models::{
gemma3::config::Gemma3TextConfig,
@@ -970,6 +970,139 @@ impl Llama3RotaryEmbedding {
}
}

/// RoPE for SmolLm3
#[derive(Debug, Clone)]
pub struct SmolLm3RotaryEmbedding(RotaryEmbedding);

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum SmolLm3RopeType {
#[serde(rename = "llama3")]
Llama3,
#[serde(rename = "linear")]
Linear,
#[default]
#[serde(rename = "default")]
Default,
}

#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SmolLm3RopeConfig {
pub factor: f32,
pub low_freq_factor: Option<f32>,
pub high_freq_factor: Option<f32>,
pub original_max_position_embeddings: Option<usize>,
pub rope_type: SmolLm3RopeType,
}

fn calculate_default_inv_freq_smollm3(cfg: &smollm3::Config) -> Vec<f32> {
let head_dim = cfg.hidden_size / cfg.num_attention_heads;
(0..head_dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
.collect()
}

impl SmolLm3RotaryEmbedding {
pub fn new_llama3(
dtype: DType,
cfg: &smollm3::Config,
dev: &Device,
is_gpt_neox: bool,
) -> Result<Self> {
match &cfg.rope_scaling {
None
| Some(SmolLm3RopeConfig {
rope_type: SmolLm3RopeType::Default,
..
}) => Ok(Self(RotaryEmbedding::new(
cfg.rope_theta,
cfg.hidden_size / cfg.num_attention_heads,
cfg.max_position_embeddings,
dev,
is_gpt_neox,
dtype,
)?)),
Some(SmolLm3RopeConfig {
rope_type: SmolLm3RopeType::Llama3,
factor,
low_freq_factor,
high_freq_factor,
original_max_position_embeddings,
}) => {
let low_freq_factor = low_freq_factor.context("low_freq_factor is required")?;
let high_freq_factor = high_freq_factor.context("high_freq_factor is required")?;
let original_max_position_embeddings = original_max_position_embeddings
.context("original_max_position_embeddings is required")?;

let low_freq_wavelen = original_max_position_embeddings as f32 / low_freq_factor;
let high_freq_wavelen = original_max_position_embeddings as f32 / high_freq_factor;

let inv_freq = calculate_default_inv_freq_smollm3(cfg)
.into_iter()
.map(|freq| {
let wavelen = 2. * PI / freq;
if wavelen < high_freq_wavelen {
freq
} else if wavelen > low_freq_wavelen {
freq / *factor
} else {
let smooth = (original_max_position_embeddings as f32 / wavelen
- low_freq_factor)
/ (high_freq_factor - low_freq_factor);
(1. - smooth) * freq / *factor + smooth * freq
}
})
.collect::<Vec<_>>();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
let cos = freqs.cos()?.to_dtype(dtype)?;
Ok(Self(RotaryEmbedding {
sin,
cos,
is_gpt_neox,
}))
}
Some(SmolLm3RopeConfig {
rope_type: SmolLm3RopeType::Linear,
factor,
..
}) => {
let inv_freq_vec = calculate_default_inv_freq_smollm3(cfg)
.into_iter()
.map(|freq| freq / *factor)
.collect::<Vec<_>>();
let inv_freq_len = inv_freq_vec.len();
let inv_freq = Tensor::from_vec(inv_freq_vec, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, cfg.max_position_embeddings as u32, dev)?
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?.to_dtype(dtype)?;
let cos = freqs.cos()?.to_dtype(dtype)?;
Ok(Self(RotaryEmbedding {
sin,
cos,
is_gpt_neox,
}))
}
}
}

pub fn forward(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offsets: &[usize],
) -> Result<(Tensor, Tensor)> {
self.0.forward(q, k, seqlen_offsets)
}
}
Comment on lines +973 to +1104

🛠️ Refactor suggestion

Refactor to eliminate code duplication with Llama3RotaryEmbedding.

The SmolLm3RotaryEmbedding implementation is nearly identical to Llama3RotaryEmbedding (lines 657-971). The only differences are the config type and naming. This violates the DRY principle and will make maintenance more difficult.

Consider creating a generic implementation that both can use, as sketched below:

  • Extract the common rope logic into a generic function or trait
  • Pass the config-specific parameters (rope_theta, hidden_size, etc.) as arguments
  • Create thin wrappers for model-specific implementations
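
A minimal sketch of what the shared helper could look like, assuming the config-specific scaling values are extracted before the call; `Llama3ScalingParams` and `llama3_scaled_inv_freq` are hypothetical names introduced here for illustration, not existing items in `layers.rs`:

use std::f32::consts::PI;

/// Hypothetical bundle of the scaling fields both configs carry.
struct Llama3ScalingParams {
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
}

/// Config-agnostic form of the llama3-style inverse-frequency scaling.
/// Both Llama3RotaryEmbedding::new and SmolLm3RotaryEmbedding::new_llama3
/// could build their inv_freq vectors through this one function.
fn llama3_scaled_inv_freq(rope_theta: f32, head_dim: usize, p: &Llama3ScalingParams) -> Vec<f32> {
    let low_freq_wavelen = p.original_max_position_embeddings as f32 / p.low_freq_factor;
    let high_freq_wavelen = p.original_max_position_embeddings as f32 / p.high_freq_factor;

    (0..head_dim)
        .step_by(2)
        // Default inverse frequencies, identical to calculate_default_inv_freq_smollm3.
        .map(|i| 1f32 / rope_theta.powf(i as f32 / head_dim as f32))
        // Llama3-style wavelength-dependent rescaling, as in the diff above.
        .map(|freq| {
            let wavelen = 2. * PI / freq;
            if wavelen < high_freq_wavelen {
                freq
            } else if wavelen > low_freq_wavelen {
                freq / p.factor
            } else {
                let smooth = (p.original_max_position_embeddings as f32 / wavelen
                    - p.low_freq_factor)
                    / (p.high_freq_factor - p.low_freq_factor);
                (1. - smooth) * freq / p.factor + smooth * freq
            }
        })
        .collect()
}

Each model-specific constructor would then only validate its optional config fields and hand the resulting inv_freq to a shared sin/cos builder.
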
🤖 Prompt for AI Agents
In mistralrs-core/src/layers.rs between lines 973 and 1104, the
SmolLm3RotaryEmbedding implementation duplicates nearly all logic from
Llama3RotaryEmbedding (lines 657-971), differing only in config types and
naming. To fix this, refactor by extracting the shared rope calculation and
embedding creation logic into a generic function or trait that accepts config
parameters like rope_theta, hidden_size, and max_position_embeddings as
arguments. Then, replace the current SmolLm3RotaryEmbedding and
Llama3RotaryEmbedding new methods with thin wrappers that call this generic
implementation, eliminating duplication and improving maintainability.
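
As a complement, a sketch of the shared sin/cos construction the thin wrappers could delegate to, assuming only the candle_core types already used in this file; `build_rotary_from_inv_freq` is a hypothetical name, not an existing helper:

use candle_core::{DType, Device, Result, Tensor};

// Hypothetical shared builder: turns a scaled inv_freq vector into the sin/cos
// tables stored by RotaryEmbedding, mirroring the tensor code that is currently
// duplicated in both constructors.
fn build_rotary_from_inv_freq(
    inv_freq: Vec<f32>,
    max_position_embeddings: usize,
    dtype: DType,
    dev: &Device,
) -> Result<(Tensor, Tensor)> {
    let inv_freq_len = inv_freq.len();
    let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
    let t = Tensor::arange(0u32, max_position_embeddings as u32, dev)?
        .to_dtype(DType::F32)?
        .reshape((max_position_embeddings, 1))?;
    let freqs = t.matmul(&inv_freq)?;
    let sin = freqs.sin()?.to_dtype(dtype)?;
    let cos = freqs.cos()?.to_dtype(dtype)?;
    Ok((sin, cos))
}

With both helpers in place, SmolLm3RotaryEmbedding::new_llama3 would reduce to validating the optional fields, computing inv_freq, calling the builder, and wrapping the result in RotaryEmbedding { sin, cos, is_gpt_neox }.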


// https://github.com/huggingface/transformers/blob/f2c388e3f946862f657acc1e21b272ec946fc66c/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L107
#[derive(Debug, Clone)]
pub struct Qwen2VLRotaryEmbedding {
1 change: 1 addition & 0 deletions mistralrs-core/src/models/mod.rs
@@ -18,4 +18,5 @@ pub(crate) mod quantized_starcoder2;
pub(crate) mod qwen2;
pub(crate) mod qwen3;
pub(crate) mod qwen3_moe;
pub(crate) mod smollm3;
pub(crate) mod starcoder2;