Merged
184 changes: 99 additions & 85 deletions backends/candle/src/models/qwen3.rs
@@ -23,6 +23,7 @@ pub struct Qwen3Config {
pub rope_theta: f32,
pub sliding_window: Option<usize>,
pub use_sliding_window: bool,
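+// Used below as the padding id: Qwen3 padding reuses the EOS token.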
+pub eos_token_id: usize,
}

struct Qwen3Attention {
@@ -164,8 +165,8 @@ impl Qwen3Attention {
.concat(),
)?;

-let (q, _res) = self.q_norm.forward(&q, None)?;
-let (k, _res) = self.k_norm.forward(&k, None)?;
+let (q, _) = self.q_norm.forward(&q, None)?;
+let (k, _) = self.k_norm.forward(&k, None)?;

let q = q.transpose(1, 2)?;
let k = k.transpose(1, 2)?;
@@ -355,16 +356,21 @@ impl Qwen3Layer {
) -> Result<Tensor> {
let _enter = self.span.enter();

-let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, None)?;
+let (normed_hidden_states, residual) =
+self.input_layer_norm.forward(hidden_states, None)?;

let attn_output =
self.attention
.forward(&normed_hidden_states, attention_bias, cos, sin)?;

let (normed_attn_res_output, attn_res) = self
.post_attention_layer_norm
-.forward(&attn_output, Some(&res))?;
+.forward(&attn_output, Some(&residual))?;

let mlp_output = self.mlp.forward(&normed_attn_res_output)?;

let output = (&mlp_output + &attn_res)?;

Ok(output)
}
}
@@ -378,6 +384,7 @@ pub struct Qwen3Model {
pool: Pool,
pub device: Device,
num_attention_heads: usize,
+pad_token_id: u32,

span: tracing::Span,
}
@@ -427,12 +434,35 @@ impl Qwen3Model {
rotary_cache,
rotary_dim,
pool,
+pad_token_id: config.eos_token_id as u32,
device: vb.device().clone(),
num_attention_heads: config.num_attention_heads,
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}

+fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
+let (bs, dim, seq_len, _) = attention_bias.dims4()?;
+
+let device = attention_bias.device();
+
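+// Row-major upper-triangular indicator: 1 where key position j is ahead of query position i, 0 elsewhere.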
+let mask: Vec<u8> = (0..seq_len)
+.flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
+.collect();
+
+let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
+let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
+
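+// where_cond keeps f32::MIN above the diagonal and 0.0 elsewhere; f32::MIN rather than
+// -inf so a fully masked row still softmaxes to finite values.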
+let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
+let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+
+let causal_mask = causal_mask
+.where_cond(&negatives, &zeros)?
+.to_device(device)?;
+
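+// Fold the causal part into the padding bias so the layers see a single additive mask.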
+attention_bias.broadcast_add(&causal_mask)
+}

pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
let _enter = self.span.enter();

@@ -441,93 +471,77 @@

let shape = (batch_size, max_length);

-let (input_ids, position_ids, input_lengths, attention_bias, _attention_mask) =
-if batch_size > 1 {
-// Prepare padded batch
-let elems = batch_size * max_length;
-
-let mut input_ids = Vec::with_capacity(elems);
-let mut position_ids = Vec::with_capacity(elems);
-let mut attention_mask = Vec::with_capacity(elems);
-let mut attention_bias = Vec::with_capacity(elems);
-let mut input_lengths = Vec::with_capacity(batch_size);
-let mut masking = false;
-
-for i in 0..batch_size {
-let start = batch.cumulative_seq_lengths[i] as usize;
-let end = batch.cumulative_seq_lengths[i + 1] as usize;
-let seq_length = end - start;
-input_lengths.push(seq_length);
-
-// Input ids
-for j in start..end {
-input_ids.push(batch.input_ids[j]);
-position_ids.push(batch.position_ids[j]);
-attention_mask.push(1.0_f32);
-attention_bias.push(0.0);
-}
+let (input_ids, position_ids, input_lengths, attention_bias) = if batch_size > 1 {
+// Prepare padded batch
+let elems = batch_size * max_length;
+
+let mut input_ids = Vec::with_capacity(elems);
+let mut position_ids = Vec::with_capacity(elems);
+let mut attention_bias = Vec::with_capacity(elems);
+let mut input_lengths = Vec::with_capacity(batch_size);
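+// Tracks whether any sequence actually needed padding; if none did, no bias tensor is built.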
+let mut masking = false;
+
+for i in 0..batch_size {
+let start = batch.cumulative_seq_lengths[i] as usize;
+let end = batch.cumulative_seq_lengths[i + 1] as usize;
+let seq_length = end - start;
+input_lengths.push(seq_length);
+
+for j in start..end {
+input_ids.push(batch.input_ids[j]);
+position_ids.push(batch.position_ids[j]);
+attention_bias.push(0.0);
+}

-// Pad to max_length
-for _ in seq_length..max_length {
-input_ids.push(0);
-position_ids.push(0);
-attention_mask.push(0.0_f32);
-attention_bias.push(f32::NEG_INFINITY);
-masking = true;
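+// Left padding: pad id / position / bias entries are inserted ahead of the sequence's tokens.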
+let padding = max_length - seq_length;
+if padding > 0 {
+masking = true;
+for _ in 0..padding {
+input_ids.insert(start, self.pad_token_id);
+position_ids.insert(start, 0);
+attention_bias.insert(start, f32::MIN);
+}
+}
}

-let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
-let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
-let attention_mask = if masking {
-Some(Tensor::from_vec(attention_mask, shape, &self.device)?)
-} else {
-None
-};
-
-let attention_bias = if masking {
-let attention_bias = Tensor::from_vec(
-attention_bias,
-(batch_size, 1, 1, max_length),
-&self.device,
-)?;
-// Broadcast once instead of at every layer
-let attention_bias = attention_bias
-.broadcast_as((
-batch_size,
-self.num_attention_heads,
-max_length,
-max_length,
-))?
-.contiguous()?;
-Some(attention_bias)
-} else {
-None
-};
-
-(
-input_ids,
-position_ids,
-input_lengths,
-attention_bias,
-attention_mask,
-)
+let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
+
+let attention_bias = if masking {
+let attention_bias =
+Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?;
+// Broadcast once instead of at every layer
+let attention_bias = attention_bias
+.broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))?
+.contiguous()?;
+Some(attention_bias)
+} else {
-let input_ids = Tensor::from_vec(
-batch.input_ids.clone(),
-(1, batch.input_ids.len()),
-&self.device,
-)?;
-let position_ids = Tensor::from_vec(
-batch.position_ids.clone(),
-(1, batch.position_ids.len()),
-&self.device,
-)?;
-let input_lengths = vec![batch.input_ids.len()];
-
-(input_ids, position_ids, input_lengths, None, None)
+None
+};
+
+(input_ids, position_ids, input_lengths, attention_bias)
+} else {
+let input_ids = Tensor::from_vec(
+batch.input_ids.clone(),
+(1, batch.input_ids.len()),
+&self.device,
+)?;
+let position_ids = Tensor::from_vec(
+batch.position_ids.clone(),
+(1, batch.position_ids.len()),
+&self.device,
+)?;
+let input_lengths = vec![batch.input_ids.len()];
+
+(input_ids, position_ids, input_lengths, None)
+};

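+// The bias built above only encodes padding; add the causal triangle before running the layers.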
+let attention_bias = if let Some(attn_bias) = attention_bias {
+Some(self.get_causal_attention_bias(attn_bias)?)
+} else {
+None
+};

let mut hidden_states = self.embeddings.forward(&input_ids)?;

let cos = self
@@ -583,7 +597,7 @@ impl Qwen3Model {
.iter()
.map(|&i| {
let i = i as usize;
-let last_token_idx = input_lengths[i] - 1;
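+// With left padding, every sequence's last real token sits at the same index.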
+let last_token_idx = max_length - 1;
outputs.i((i, last_token_idx))?.unsqueeze(0)
})
.collect();