diff --git a/Cargo.toml b/Cargo.toml index e78c0aa7..80a1d680 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.6.0", rev = "ab3fd74" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.6.0", rev = "0ed5cd5" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" diff --git a/ReadMe-CN.md b/ReadMe-CN.md index c5505d65..802539e4 100644 --- a/ReadMe-CN.md +++ b/ReadMe-CN.md @@ -14,7 +14,7 @@ | **⚡** | 极致性能 | 原生 Flash Attention、FlashInfer、CUDA Graphs、持续批处理、前缀缓存、PD 分离。消费级 GPU 上 `30B+` 模型解码速度高达 **197 tok/s** | | **🪶** | 极简内核 | 核心调度 + 注意力逻辑仅 **< 5000 行** Rust 代码 | | **🌍** | 跨平台 | CUDA(Linux/Windows)、Metal(macOS),统一二进制,统一 API | -| **🏭** | 生产就绪 | OpenAI/Anthropic 兼容 API、内置 ChatGPT 风格 Web UI、MCP 工具调用、结构化输出、Embedding + 分词器端点 | +| **🏭** | 生产就绪 | OpenAI/Anthropic 兼容 API、内置 ChatGPT 风格 Web UI、MCP 工具调用、结构化输出、Embedding + 分词器端点、MTP | | **🗜️** | 极致 KV 压缩 | TurboQuant(`2–4 位` KV 缓存)以极小的质量损失将上下文扩展至 **4.3 倍**。单卡 24/32 GB GPU 即可运行 `30B+` MoE 模型并支持**百万级上下文** | | **🔥** | V100 + NVFP4 | 业界首创:V100 上运行 NVFP4 + 低位 KV 缓存推理 — 无需硬件 FP4,旧 GPU 重获新生 | | **🐍** | 轻量 Python 绑定 | 需要 Python 入口时可选 PyO3 wheel 包 | @@ -55,6 +55,11 @@ xinfer --w /home/Qwen3.6-35B-A3B --d 0,1 --ui-server python3 -m xinfer.server --m Qwen/Qwen3.6-27B-FP8 --kvcache-dtype turbo4 --ui-server ``` +**MTP** +```bash +xinfer --w /home/Qwen3.6-35B-A3B --d 0,1 --ui-server --mtp 2 +``` + > **提示:** 浏览器打开 `http://IP:8001` 即可使用内置对话界面,或使用 `http://IP:8000/v1/` 作为 API 服务 `Base URL`。 --- @@ -77,7 +82,7 @@ python3 -m xinfer.server --m Qwen/Qwen3.6-27B-FP8 --kvcache-dtype turbo4 --ui-se > 测试平台:**V100-32G**、**A100-40G**、**Hopper-80G** 及 **RTX 5090** -| 模型 | 格式 | 大小 | 输出速度 | +| 模型 | 格式 | 大小 | 输出速度 (非MTP) | |---|---|---|---| | Ministral-3-3B (**多模态**) | ISQ (BF16→Q4K) | 3B | **193.67** tokens/s | | Qwen3-VL-8B-Instruct (**多模态**) | Q8_0 | 8B | **112.51** tokens/s | @@ -164,6 +169,9 @@ xinfer --m Qwen/Qwen3.6-35B-A3B-FP8 --kvcache-dtype fp8 # 27B Dense + turbo4 xinfer --m Qwen/Qwen3.6-27B-FP8 --kvcache-dtype turbo4 +# 26B Gemma4 (本地模型, 使用`--kv-fraction`选项增加kvcache占用) +xinfer --w /data/gemma-4-26B-A4B-it --ui-server --port 9000 --kv-fraction 0.8 + # 30B MoE GGUF + turbo4 xinfer --m unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF \ --f Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf --kvcache-dtype turbo4 @@ -454,6 +462,7 @@ xinfer --m unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF \ | `--frequency-penalty` | 高频惩罚(−2 到 2) | | `--mcp-config` | MCP 服务器 JSON 配置 | | `--mcp-command` / `--mcp-args` | 单个 MCP 服务器命令及参数 | +| `--mtp` | 启用MTP (只针对包含MTP层的模型) ,例如 `--mtp 2`,单次推理2个tokens | --- @@ -499,7 +508,7 @@ xinfer --m unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF \ * [x] **MXFP4/NVFP4 模型支持** * [x] **支持 Turboquant(4 位、3 位)KvCache** * [ ] TentorRT-LLM - +* [x] Multi-token Prediction (MTP) --- ## 📚 参考项目 diff --git a/ReadMe.md b/ReadMe.md index 867df761..0ac80837 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -14,7 +14,7 @@ | **⚡** | Fast | Native Flash Attention, FlashInfer, CUDA Graphs, continuous batching, prefix caching, PD disaggregation. Up to **197 tok/s** decode for `30B+` models on consumer GPUs | | **🪶** | Tiny footprint | Core scheduling + attention logic in **< 5 000 lines** of Rust | | **🌍** | Cross-platform | CUDA (Linux/Windows), Metal (macOS). Same binary, same API | -| **🏭** | Production-ready | OpenAI/Anthropic-compatible APIs, built-in ChatGPT-style Web UI, MCP tool calling, structured outputs, embedding + tokenizer endpoints | +| **🏭** | Production-ready | OpenAI/Anthropic-compatible APIs, built-in ChatGPT-style Web UI, MCP tool calling, structured outputs, embedding + tokenizer endpoints, multi-token prediction (MTP) | | **🗜️** | Aggressive KV compression | TurboQuant (`2–4 bit` KV cache) extends context up to **4.3×** with minimal quality loss. Run `30B+` MoE models with **millions of context** on single 24/32 GB GPUs | | **🔥** | V100 + NVFP4 | First-ever NVFP4 + low-bit KV cache on V100 — no hardware FP4 needed, coherent output on legacy GPUs | | **🐍** | Lightweight Python bindings | Optional PyO3 wheel when you need a Python entry point | @@ -55,6 +55,11 @@ xinfer --w /home/Qwen3.6-35B-A3B --d 0,1 --ui-server python3 -m xinfer.server --m Qwen/Qwen3.6-27B-FP8 --kvcache-dtype turbo4 --ui-server ``` +**MTP** +```bash +xinfer --w /home/Qwen3.6-35B-A3B --d 0,1 --ui-server --mtp 2 +``` + > **Tip:** Open `http://IP:8001` for the built-in chat UI, or use `http://IP:8000/v1/` as your API `Base URL`. --- @@ -77,7 +82,7 @@ Add `--kvcache-dtype` to compress KV cache and extend context length: > Tested on **V100-32G**, **A100-40G**, **Hopper-80G** and **RTX 5090** -| Model | Format | Size | Decoding Speed | +| Model | Format | Size | Decoding Speed (without MTP) | |---|---|---|---| | Ministral-3-3B (**Multimodal**) | ISQ (BF16→Q4K) | 3B | **193.67** tokens/s | | Qwen3-VL-8B-Instruct (**Multimodal**) | Q8_0 | 8B | **112.51** tokens/s | @@ -164,6 +169,9 @@ xinfer --m Qwen/Qwen3.6-35B-A3B-FP8 --kvcache-dtype fp8 # 27B Dense + turbo4 xinfer --m Qwen/Qwen3.6-27B-FP8 --kvcache-dtype turbo4 +# 26B Gemma4 (local model, occupy more kvcache with --kv-fraction) +xinfer --w /data/gemma-4-26B-A4B-it --ui-server --port 9000 --kv-fraction 0.8 + # 30B MoE GGUF + turbo4 xinfer --m unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF \ --f Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf --kvcache-dtype turbo4 @@ -459,6 +467,7 @@ Constraint-based generation via llguidance — Lark grammars, regex, JSON Schema | `--frequency-penalty` | Penalize frequent tokens (−2 to 2) | | `--mcp-config` | MCP servers JSON config | | `--mcp-command` / `--mcp-args` | Single MCP server command + args | +| `--mtp`| Multi-token prediction, usage `--mtp 2` for two-token prediction per forward pass | --- @@ -504,6 +513,7 @@ Constraint-based generation via llguidance — Lark grammars, regex, JSON Schema * [x] **MXFP4/NVFP4 Model Support** * [x] **Support Turboquant (4-bit, 3-bit) KvCache** * [ ] TentorRT-LLM +* [x] **Multi-token Prediciton (MTP)** --- diff --git a/src/core/block_manager.rs b/src/core/block_manager.rs index 103b107a..1a4e6d65 100644 --- a/src/core/block_manager.rs +++ b/src/core/block_manager.rs @@ -145,6 +145,13 @@ impl BlockManager { } } + /// Allocate a single free block and return its ID, or None if no blocks available. + pub fn alloc_free_block(&mut self) -> Option { + let block_id = self.free_block_ids.pop_front()?; + self.allocate_block(block_id); + Some(block_id) + } + fn image_prefix_seed(images: &ImageData) -> u64 { let mut hasher = std::collections::hash_map::DefaultHasher::new(); images.raw.hash(&mut hasher); diff --git a/src/core/engine.rs b/src/core/engine.rs index 0008bf17..fa6c7054 100644 --- a/src/core/engine.rs +++ b/src/core/engine.rs @@ -833,6 +833,16 @@ impl LLMEngine { } } + // Pre-allocate blocks for MTP speculative positions before cloning sequences + if !is_prefill { + if let Some(mtp_tokens) = self.econfig.mtp_num_speculative_tokens { + if mtp_tokens > 0 { + self.scheduler + .pre_allocate_mtp_blocks(&scheduled_ids, mtp_tokens + 1); + } + } + } + let seqs = self.scheduler.get_sequences(&scheduled_ids); let owned_seqs: Vec = seqs.iter().map(|s| (*s).clone()).collect(); Ok(Some((scheduled_ids, is_prefill, owned_seqs))) @@ -897,6 +907,172 @@ impl LLMEngine { } } + /// Phase 2b: Run MTP speculative decode forward pass. + /// Returns Vec> where each inner vec contains all accepted tokens for that sequence. + pub fn run_forward_mtp( + runners: &Arc>, + owned_seqs: &[Sequence], + ) -> Result>> { + match &mut *runners.write() { + RunnerType::Thread(model_runner) => { + let seq_refs: Vec<&Sequence> = owned_seqs.iter().collect(); + model_runner.run_mtp_decode(Seqs::SeqRefs(&seq_refs)) + } + RunnerType::Process(ref mut runner_streams) => { + use crate::runner::{receive_local, send_local, MessageType}; + use interprocess::TryClone; + + let sequences = owned_seqs + .iter() + .map(|s| DecodeSequence::new(s)) + .collect::>(); + let request = MessageType::RunDecodeMTP(sequences); + + let cloned_streams: Vec = runner_streams + .iter_mut() + .map(|s| s.try_clone().expect("clone failed")) + .collect(); + + if let Some(mut stream) = cloned_streams.into_iter().next() { + send_local(&mut vec![stream.try_clone()?], &request, false)?; + let response = receive_local(&mut stream, false)?; + match response { + MessageType::RunResponseMTP(multi_tokens) => { + if multi_tokens.is_empty() { + candle_core::bail!("MTP runner returned empty response") + } + Ok(multi_tokens) + } + other => { + candle_core::bail!("Unexpected MTP response type: {:?}", other) + } + } + } else { + candle_core::bail!("No runner streams available for MTP") + } + } + } + } + + /// Phase 3b: MTP-aware finish step that handles multiple tokens per sequence. + pub fn finish_step_mtp( + &mut self, + scheduled_ids: Vec, + multi_output_ids: Vec>, + ) -> Result { + use std::time::{SystemTime, UNIX_EPOCH}; + + self.scheduler + .postprocess_multi(&scheduled_ids, &multi_output_ids); + + let mut count = 0; + for (i, &idx) in scheduled_ids.iter().enumerate() { + let sq = self.scheduler.get_running(idx); + if let Some(s) = sq { + let seq_id = s.id; + let tokens = &multi_output_ids[i]; + + if s.is_finished() { + let num_appended = + s.output_ids.len() - self.decode_length.get(&seq_id).copied().unwrap_or(0); + let tokens_to_stream = &tokens[..num_appended.min(tokens.len())]; + if let Some(sender) = self.stream_senders.get_mut(&seq_id) { + if let Some(request_type) = self.request_types.get(&seq_id) { + let prompt_start_time = s.created_time(); + let decode_finish_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() + as usize; + let decode_start_time = self + .decode_start_times + .get(&seq_id) + .copied() + .unwrap_or(decode_finish_time); + + if s.is_tool_call_end { + if *request_type == RequestType::Stream { + if let Some(decoder) = self.stream_decoders.get_mut(&seq_id) { + if let Some(tok) = decoder.step(s.last_token)? { + let _ = sender + .try_send(StreamItem::Token(tok, s.last_token)); + } + } + } else { + let _ = sender.try_send(StreamItem::TokenID(s.last_token)); + } + } + + if *request_type == RequestType::Stream { + for &tok in tokens_to_stream { + if let Some(decoder) = self.stream_decoders.get_mut(&seq_id) { + if let Some(text) = decoder.step(tok)? { + let _ = sender.try_send(StreamItem::Token(text, tok)); + } + } + } + let _ = sender.try_send(StreamItem::Done(( + prompt_start_time, + decode_start_time, + decode_finish_time, + s.output_ids.len(), + s.stop_sequence.clone(), + ))); + } else { + let _ = sender.try_send(StreamItem::Completion(( + prompt_start_time, + decode_start_time, + decode_finish_time, + s.output_ids.clone(), + s.stop_sequence.clone(), + ))); + } + } + } + self.decode_start_times.remove(&seq_id); + self.decode_length.remove(&seq_id); + self.active_requests.remove(&seq_id); + let _ = self.notify_runner_finished(seq_id); + count += 1; + } else { + if !self.decode_start_times.contains_key(&seq_id) { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() as usize; + self.decode_start_times.insert(seq_id, now); + } + if let Some(length) = self.decode_length.get_mut(&seq_id) { + *length = s.output_len(); + } else { + self.decode_length.insert(seq_id, s.output_len()); + } + if let Some(sender) = self.stream_senders.get_mut(&seq_id) { + if let Some(request_type) = self.request_types.get(&seq_id) { + for &tok in tokens { + if *request_type == RequestType::Stream { + if let Some(decoder) = self.stream_decoders.get_mut(&seq_id) { + if let Some(text) = decoder.step(tok)? { + let _ = sender.try_send(StreamItem::Token(text, tok)); + } + } + } else { + let _ = sender.try_send(StreamItem::TokenID(tok)); + } + } + } + } + count += 1; + } + } + } + if self.econfig.server_mode.unwrap_or(true) { + self.may_print_decoding_throughput(&scheduled_ids); + } + self.check_canceled(None); + Ok(count) + } + /// Phase 3: Postprocess forward pass results, deliver tokens, and do maintenance. pub fn finish_step( &mut self, @@ -1686,11 +1862,13 @@ impl LLMEngine { pub fn start_engine(engine: Arc>) { GLOBAL_RT.spawn(async move { let engine = engine.clone(); - let (is_pd_server, runners) = { + let (is_pd_server, runners, mtp_enabled) = { let guard = engine.read(); + let has_mtp = guard.econfig.mtp_num_speculative_tokens.is_some(); ( guard.is_pd_mode() && guard.is_pd_server(), guard.runners.clone(), + has_mtp, ) }; loop { @@ -1723,26 +1901,54 @@ impl LLMEngine { // Engine lock released -- server can accept new requests during forward pass if let Some((scheduled_ids, is_prefill, owned_seqs)) = prep { - // Phase 2: Forward pass (only runner lock, engine lock FREE) - match Self::run_forward(&runners, &owned_seqs, is_prefill) { - Ok(output_ids) => { - // Phase 3: Postprocess (engine lock held briefly) - let mut guard = engine.write(); - match guard.finish_step(scheduled_ids, is_prefill, output_ids) { - Ok(n) => task_processed = n, - Err(e) => { - crate::log_error!("[Engine Loop] Finish error: {:?}", e); - if !guard.cancel_all_with_reason(Some(e.to_string())) { - std::process::exit(1); + // Use MTP for decode steps when enabled (batch_size=1 only for now) + if mtp_enabled && !is_prefill && owned_seqs.len() == 1 { + match Self::run_forward_mtp(&runners, &owned_seqs) { + Ok(multi_output_ids) => { + let mut guard = engine.write(); + match guard.finish_step_mtp(scheduled_ids, multi_output_ids) { + Ok(n) => task_processed = n, + Err(e) => { + crate::log_error!( + "[Engine Loop] MTP Finish error: {:?}", + e + ); + if !guard.cancel_all_with_reason(Some(e.to_string())) { + std::process::exit(1); + } } } } + Err(e) => { + crate::log_error!("[Engine Loop] MTP Forward error: {:?}", e); + let mut guard = engine.write(); + if !guard.cancel_all_with_reason(Some(e.to_string())) { + std::process::exit(1); + } + } } - Err(e) => { - crate::log_error!("[Engine Loop] Forward error: {:?}", e); - let mut guard = engine.write(); - if !guard.cancel_all_with_reason(Some(e.to_string())) { - std::process::exit(1); + } else { + // Phase 2: Standard forward pass (only runner lock, engine lock FREE) + match Self::run_forward(&runners, &owned_seqs, is_prefill) { + Ok(output_ids) => { + // Phase 3: Postprocess (engine lock held briefly) + let mut guard = engine.write(); + match guard.finish_step(scheduled_ids, is_prefill, output_ids) { + Ok(n) => task_processed = n, + Err(e) => { + crate::log_error!("[Engine Loop] Finish error: {:?}", e); + if !guard.cancel_all_with_reason(Some(e.to_string())) { + std::process::exit(1); + } + } + } + } + Err(e) => { + crate::log_error!("[Engine Loop] Forward error: {:?}", e); + let mut guard = engine.write(); + if !guard.cancel_all_with_reason(Some(e.to_string())) { + std::process::exit(1); + } } } } diff --git a/src/core/mod.rs b/src/core/mod.rs index 95563506..ea38e202 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -1,5 +1,6 @@ pub mod block_manager; pub mod engine; +pub mod mtp; pub mod prefix_cache; pub mod runner; pub mod scheduler; diff --git a/src/core/mtp.rs b/src/core/mtp.rs new file mode 100644 index 00000000..51ce5aa4 --- /dev/null +++ b/src/core/mtp.rs @@ -0,0 +1,160 @@ +// src/core/mtp.rs +// Multi-Token Prediction (MTP) speculative decoding support. +// +// MTP uses lightweight prediction heads built into the model (e.g. Qwen3.5, DeepSeek-V3) +// to draft future tokens using the backbone's hidden states and KV cache. +// Accepted draft tokens are verified in a single target-model forward pass. + +use candle_core::{Result, Tensor, D}; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// Configuration for MTP speculative decoding at the engine level. +#[derive(Debug, Clone)] +pub struct MtpEngineConfig { + /// Number of speculative tokens to propose per step. + pub num_speculative_tokens: usize, +} + +impl MtpEngineConfig { + pub fn new(num_speculative_tokens: usize) -> Self { + Self { + num_speculative_tokens: num_speculative_tokens.max(1), + } + } +} + +/// Outcome of MTP verification for a single sequence. +#[derive(Debug, Clone)] +pub struct MtpVerifyResult { + /// All accepted tokens (draft tokens that matched the target model). + pub accepted_tokens: Vec, + /// The continuation token sampled from the first rejection point. + pub continuation_token: u32, + /// How many of the proposed drafts were accepted. + pub num_accepted: usize, + /// Total number proposed. + pub num_proposed: usize, +} + +/// Verify draft tokens against target model logits (greedy / argmax). +/// +/// Uses a single batched argmax over all rows + vectorized comparison on GPU, +/// then transfers results to CPU in one shot. +/// +/// `verify_logits`: shape [N+1, vocab_size] where N = len(draft_tokens). +/// - Position 0 predicts draft_tokens[0] +/// - Position i predicts draft_tokens[i] (for i < N) +/// - Position N provides the continuation token after last accepted draft +pub fn verify_draft_greedy( + verify_logits: &Tensor, + draft_tokens: &[u32], +) -> Result { + let num_positions = verify_logits.dim(0)?; + let num_proposed = draft_tokens.len(); + + if num_positions == 0 || num_proposed == 0 { + let first_token = if num_positions > 0 { + verify_logits + .get(0)? + .argmax(D::Minus1)? + .to_scalar::()? + } else { + 0 + }; + return Ok(MtpVerifyResult { + accepted_tokens: vec![], + continuation_token: first_token, + num_accepted: 0, + num_proposed, + }); + } + + // Keep verifier argmax aligned with the normal sampler path, which promotes + // logits to F32 before selecting tokens. + let verify_logits = verify_logits.to_dtype(candle_core::DType::F32)?; + let all_target_tokens = verify_logits.argmax(D::Minus1)?; + let target_vec: Vec = all_target_tokens.to_vec1()?; + + let compare_len = num_proposed.min(num_positions); + let mut num_accepted = 0; + for i in 0..compare_len { + if target_vec[i] == draft_tokens[i] { + num_accepted += 1; + } else { + break; + } + } + + let accepted_tokens = draft_tokens[..num_accepted].to_vec(); + let continuation_token = if num_accepted < num_positions { + target_vec[num_accepted] + } else { + target_vec[num_positions - 1] + }; + + Ok(MtpVerifyResult { + accepted_tokens, + continuation_token, + num_accepted, + num_proposed, + }) +} + +/// Global MTP statistics tracker. +pub static MTP_TOTAL_PROPOSED: AtomicUsize = AtomicUsize::new(0); +pub static MTP_TOTAL_ACCEPTED: AtomicUsize = AtomicUsize::new(0); +pub static MTP_TOTAL_STEPS: AtomicUsize = AtomicUsize::new(0); + +pub fn mtp_stats_update(proposed: usize, accepted: usize) { + MTP_TOTAL_PROPOSED.fetch_add(proposed, Ordering::Relaxed); + MTP_TOTAL_ACCEPTED.fetch_add(accepted, Ordering::Relaxed); + MTP_TOTAL_STEPS.fetch_add(1, Ordering::Relaxed); +} + +pub fn mtp_stats_acceptance_rate() -> f64 { + let proposed = MTP_TOTAL_PROPOSED.load(Ordering::Relaxed); + let accepted = MTP_TOTAL_ACCEPTED.load(Ordering::Relaxed); + if proposed == 0 { + 0.0 + } else { + accepted as f64 / proposed as f64 + } +} + +pub fn mtp_stats_avg_tokens_per_step() -> f64 { + let steps = MTP_TOTAL_STEPS.load(Ordering::Relaxed); + let accepted = MTP_TOTAL_ACCEPTED.load(Ordering::Relaxed); + if steps == 0 { + 1.0 + } else { + // Each step produces: 1 anchor + accepted drafts + 1 continuation + (accepted + 2 * steps) as f64 / steps as f64 + } +} + +pub fn mtp_stats_summary() -> String { + let proposed = MTP_TOTAL_PROPOSED.load(Ordering::Relaxed); + let accepted = MTP_TOTAL_ACCEPTED.load(Ordering::Relaxed); + let steps = MTP_TOTAL_STEPS.load(Ordering::Relaxed); + format!( + "MTP Stats: proposed={}, accepted={}, acceptance_rate={:.2}%, avg_tokens/step={:.2}", + proposed, + accepted, + if proposed > 0 { + accepted as f64 / proposed as f64 * 100.0 + } else { + 0.0 + }, + if steps > 0 { + (accepted + 2 * steps) as f64 / steps as f64 + } else { + 1.0 + }, + ) +} + +pub fn mtp_stats_reset() { + MTP_TOTAL_PROPOSED.store(0, Ordering::Relaxed); + MTP_TOTAL_ACCEPTED.store(0, Ordering::Relaxed); + MTP_TOTAL_STEPS.store(0, Ordering::Relaxed); +} diff --git a/src/core/runner.rs b/src/core/runner.rs index 2fcd5978..c58c71fe 100644 --- a/src/core/runner.rs +++ b/src/core/runner.rs @@ -4,6 +4,7 @@ use crate::models::gemma4::Gemma4ForCausalLM; use crate::models::layers::distributed::Comm; use crate::models::layers::linear::set_linear_is_prefill; use crate::models::layers::VarBuilderX; +use crate::models::qwen3_5_mtp::Qwen3_5MtpHead; use crate::server::EmbeddingStrategy; use crate::transfer::Transfer; #[cfg(all(feature = "cuda", feature = "graph"))] @@ -55,6 +56,7 @@ pub struct CachedSamplingParams { pub presence_penalty: Option, } +#[derive(Clone, Copy)] pub enum Seqs<'a> { SeqRefs(&'a [&'a Sequence]), DecodeVec(&'a Vec), @@ -117,7 +119,9 @@ pub struct ModelRunner { device: Device, config: EngineConfig, #[cfg(all(feature = "cuda", feature = "graph"))] - pub capturer: GraphCapturer>, + pub decode_capturer: GraphCapturer>, + #[cfg(all(feature = "cuda", feature = "graph"))] + pub mtp_capturer: Option>>, #[cfg(feature = "flashinfer")] flashinfer_kv_params: Option, logit_processor: LogitsProcessor, @@ -133,6 +137,16 @@ pub struct ModelRunner { /// Whether this runner is on the first rank (for logging) is_first_rank: bool, model_type: ModelType, + /// MTP head for speculative decoding (Qwen3.5 only for now) + mtp_head: Option>, + /// Number of speculative tokens to draft per step + mtp_num_speculative: usize, +} + +struct MtpSeqInfo { + id: usize, + len: usize, + block_table: Vec, } impl ModelRunner { @@ -148,6 +162,136 @@ impl ModelRunner { ) } + fn compute_slot_mappings( + &self, + seq_info: &MtpSeqInfo, + num_tokens: usize, + block_size: usize, + ctx: &str, + ) -> Result> { + let mut slots = Vec::with_capacity(num_tokens); + for i in 0..num_tokens { + let pos = seq_info.len + i; + let block_idx = pos / block_size; + let block_offset = pos % block_size; + if block_idx < seq_info.block_table.len() { + let physical_block = seq_info.block_table[block_idx] as i64; + slots.push(physical_block * block_size as i64 + block_offset as i64); + } else { + candle_core::bail!( + "MTP {} missing KV block: block_idx {} >= block_table.len() {}. \ + Blocks must be pre-allocated before MTP.", + ctx, + block_idx, + seq_info.block_table.len() + ); + } + } + Ok(slots) + } + + fn build_mtp_metadata( + &self, + seq_info: &MtpSeqInfo, + slot_mappings: &[i64], + q_len: usize, + ) -> Result { + let total_kv_len = (seq_info.len + q_len) as u32; + let mamba_slot_mapping = self.prepare_mamba_slot_mapping(&[seq_info.id], false)?; + + #[cfg(feature = "flashinfer")] + let flashinfer_metadata = if let Some(params) = self.flashinfer_kv_params { + let num_pages = seq_info.block_table.len(); + let indptr_host = vec![0u32, num_pages as u32]; + let indices_vec: Vec = seq_info.block_table.clone(); + let last_page_tokens = + total_kv_len as usize - (num_pages.saturating_sub(1)) * params.page_size; + let last_len_host = vec![last_page_tokens as u32]; + let kv_len_arr_host = vec![total_kv_len]; + let q_cu_seqlens_host = vec![0u32, q_len as u32]; + + #[cfg(all(feature = "cuda", feature = "graph"))] + let use_graph = self + .mtp_capturer + .as_ref() + .map_or(false, |c| c.is_mtp_captured(q_len)); + #[cfg(not(all(feature = "cuda", feature = "graph")))] + let use_graph = false; + + let prefill_plan_info = if use_graph { + None + } else { + Some(attention_rs::flashinfer::prefill_plan( + &self.device, + &q_cu_seqlens_host, + &indptr_host, + &kv_len_arr_host, + q_len as u32, + 1, + params.num_qo_heads, + params.num_kv_heads, + params.head_dim, + params.page_size, + params.out_dtype, + None, + Some(params.kv_dtype), + false, + )?) + }; + + Some(attention_rs::FlashInferMetadata { + indptr: Tensor::from_vec(indptr_host.clone(), (2,), &self.device)?, + indptr_host, + indices: Tensor::from_vec(indices_vec, (num_pages,), &self.device)?, + last_len: Tensor::from_vec(last_len_host.clone(), (1,), &self.device)?, + last_len_host: Some(last_len_host), + kv_len_arr_host: Some(kv_len_arr_host), + total_num_rows: Some(q_len as u32), + batch_indices: None, + positions: None, + use_cuda_graph: use_graph, + decode_plan_info: None, + prefill_plan_info, + mla_decode_plan_info: None, + mla_prefill_plan_info: None, + }) + } else { + None + }; + #[cfg(not(feature = "flashinfer"))] + let flashinfer_metadata = None; + + Ok(InputMetadata { + is_prefill: true, + is_mla: self.is_mla_model(), + sequence_ids: Some(vec![seq_info.id]), + mamba_slot_mapping, + slot_mapping: Tensor::from_vec(slot_mappings.to_vec(), (q_len,), &self.device)?, + context_lens: Some(Tensor::from_vec(vec![total_kv_len], (1,), &self.device)?), + block_tables: Some(Tensor::from_vec( + seq_info.block_table.clone(), + (1, seq_info.block_table.len()), + &self.device, + )?), + seqlens: None, + cu_seqlens_q: Some(Tensor::from_vec( + vec![0u32, q_len as u32], + (2,), + &self.device, + )?), + cu_seqlens_k: Some(Tensor::from_vec( + vec![0u32, total_kv_len], + (2,), + &self.device, + )?), + max_seqlen_q: q_len, + max_seqlen_k: seq_info.len + q_len, + max_context_len: seq_info.len + q_len, + flashinfer_metadata, + is_mtp_verify: true, + }) + } + fn prepare_mamba_slot_mapping( &self, sequence_ids: &[usize], @@ -449,6 +593,34 @@ impl ModelRunner { } ); + #[cfg(all(feature = "cuda", feature = "graph"))] + let mtp_wrapper = if econfig.mtp_num_speculative_tokens.unwrap_or(0) > 0 { + Some(crate::graph_wrapper!( + &model, + device, + { + Qwen3 => EmbedInputs, + Qwen3MoE => EmbedInputs, + Qwen3_5 => EmbedInputs, + Qwen3_5MoE => EmbedInputs, + LLaMa => EmbedInputs, + LLaMa4 => NoneArg, + Phi4 => EmbedInputs, + GLM4 => EmbedInputs, + GLM4MoE => EmbedInputs, + GLM4MoeLite => EmbedInputs, + DeepSeek => EmbedInputs, + Mistral3VL => NoneArg, + Gemma3 => NoneArg, + Gemma4 => EmbedInputs, + Qwen3VL => NoneArg, + MiniMax => EmbedInputs, + } + )) + } else { + None + }; + let allocator = if let Some(s) = stream { use crate::runner::{receive_local, send_local, MessageType}; use interprocess::TryClone; @@ -548,18 +720,27 @@ impl ModelRunner { ); } } + if is_hybrid_mamba_model && econfig.mtp_num_speculative_tokens.unwrap_or(0) > 0 { + // MTP verification mutates Qwen3.5 linear-attention state speculatively. + // Keep at least one snapshot per active sequence so rejected drafts can + // be rolled back before replaying only the accepted prefix. + mamba_prefix_capacity = mamba_prefix_capacity.max(mamba_cache_capacity.max(1)); + } match &model { Model::Qwen3_5(model) => { model.preallocate_mamba_cache(mamba_cache_capacity)?; model.set_mamba_prefix_cache_capacity(mamba_prefix_capacity); + model.preallocate_mtp_hidden_buffer(econfig.max_num_seqs.max(8))?; } Model::Qwen3_5MoE(model) => { model.preallocate_mamba_cache(mamba_cache_capacity)?; model.set_mamba_prefix_cache_capacity(mamba_prefix_capacity); + model.preallocate_mtp_hidden_buffer(econfig.max_num_seqs.max(8))?; } Model::Qwen3VL(model) => { model.preallocate_mamba_cache(mamba_cache_capacity)?; model.set_mamba_prefix_cache_capacity(mamba_prefix_capacity); + model.preallocate_mtp_hidden_buffer(econfig.max_num_seqs.max(8))?; } _ => {} } @@ -721,6 +902,61 @@ impl ModelRunner { } } + let (mtp_head, mtp_num_speculative) = if let Some(num_spec) = + econfig.mtp_num_speculative_tokens + { + let is_mtp_model_type = matches!( + model_type, + ModelType::Qwen3_5 | ModelType::Qwen3_5MoE | ModelType::Qwen3VL + ); + let has_mtp_config = config.mtp_num_hidden_layers.unwrap_or(0) > 0; + let has_mtp_weights = vb.pp("mtp").has_key("fc.weight") + || vb.pp("mtp").has_key("layers.0.mlp.gate_proj.weight") + || vb.pp("mtp").has_key("layers.0.mlp.gate.weight"); + + if is_mtp_model_type && (has_mtp_config || has_mtp_weights) && has_mtp_weights { + match crate::models::qwen3_5_mtp::Qwen3_5MtpHead::new( + vb, + comm.clone(), + config, + dtype, + is_rope_i, + &device, + ) { + Ok(head) => { + crate::log_info!( + "MTP head loaded: {} speculative tokens per step", + num_spec + ); + (Some(Arc::new(head)), num_spec) + } + Err(e) => { + crate::log_warn!("Failed to load MTP head: {}. MTP disabled.", e); + (None, 0) + } + } + } else if !is_mtp_model_type { + crate::log_info!( + "MTP requested but model type {:?} does not support MTP. MTP disabled.", + model_type + ); + (None, 0) + } else if !has_mtp_weights { + crate::log_info!( + "MTP requested but model weights do not contain MTP layers. MTP disabled." + ); + (None, 0) + } else { + crate::log_info!( + "MTP requested but model config has no MTP layers (mtp_num_hidden_layers={}). MTP disabled.", + config.mtp_num_hidden_layers.unwrap_or(0) + ); + (None, 0) + } + } else { + (None, 0) + }; + Ok(Self { model, gpu_kv_cache: Arc::new(Mutex::new(gpu_kv_cache)), @@ -729,7 +965,7 @@ impl ModelRunner { device, config: econfig.clone(), #[cfg(all(feature = "cuda", feature = "graph"))] - capturer: GraphCapturer::new( + decode_capturer: GraphCapturer::new( wrapper, graph_capture_max_num_seqs, econfig.max_model_len.unwrap_or(32768), @@ -739,6 +975,19 @@ impl ModelRunner { &flashinfer_kv_params, matches!(model_type, ModelType::GLM4MoeLite | ModelType::DeepSeek), ), + #[cfg(all(feature = "cuda", feature = "graph"))] + mtp_capturer: mtp_wrapper.map(|w| { + GraphCapturer::new( + w, + graph_capture_max_num_seqs, + econfig.max_model_len.unwrap_or(32768), + econfig.block_size, + config.hidden_size, + #[cfg(feature = "flashinfer")] + &flashinfer_kv_params, + matches!(model_type, ModelType::GLM4MoeLite | ModelType::DeepSeek), + ) + }), #[cfg(feature = "flashinfer")] flashinfer_kv_params, logit_processor: LogitsProcessor::new(seed, temperature, top_k, top_p), @@ -752,9 +1001,31 @@ impl ModelRunner { transfer, is_first_rank: comm.rank() == 0, model_type, + mtp_head, + mtp_num_speculative, }) } + /// Initialize MTP head for speculative decoding. + /// Should be called after model construction when MTP is enabled. + pub fn init_mtp( + &mut self, + mtp_head: Arc, + num_speculative: usize, + ) -> Result<()> { + self.mtp_head = Some(mtp_head); + self.mtp_num_speculative = num_speculative; + crate::log_info!( + "MTP initialized: {} speculative tokens per step", + num_speculative, + ); + Ok(()) + } + + pub fn has_mtp(&self) -> bool { + self.mtp_head.is_some() && self.mtp_num_speculative > 0 + } + pub fn get_kv_cache(&self) -> MutexGuard<'_, Vec<(Tensor, Tensor)>> { loop { if let Ok(v) = self.gpu_kv_cache.try_lock() { @@ -846,6 +1117,15 @@ impl ModelRunner { } } + fn mtp_rollback_mamba(&self, seq_id: usize, keep_tokens: usize) -> Result { + match &self.model { + Model::Qwen3_5(m) => m.mtp_rollback_mamba(seq_id, keep_tokens), + Model::Qwen3_5MoE(m) => m.mtp_rollback_mamba(seq_id, keep_tokens), + Model::Qwen3VL(m) => m.mtp_rollback_mamba(seq_id, keep_tokens), + _ => Ok(false), + } + } + #[allow(unused)] pub fn run(&self, seqs: Seqs, is_prefill: bool) -> Result> { #[cfg(feature = "nvtx")] @@ -877,33 +1157,33 @@ impl ModelRunner { let input_batch = input_ids.dim(0)?; let require_exact_graph = input_metadata.mamba_slot_mapping.is_some(); let can_replay = if require_exact_graph { - self.capturer.is_exact_captured(input_batch) + self.decode_capturer.is_exact_captured(input_batch) } else { - self.capturer.is_captured(input_batch) + self.decode_capturer.is_captured(input_batch) }; if !is_prefill && can_replay { let logits = match &self.model { Model::Qwen3_5(model) => { let _guard = model.lock_mamba_cache_for_graph(); - self.capturer + self.decode_capturer .replay(&input_ids, &positions, &input_metadata)? } Model::Qwen3_5MoE(model) => { let _guard = model.lock_mamba_cache_for_graph(); - self.capturer + self.decode_capturer .replay(&input_ids, &positions, &input_metadata)? } Model::Qwen3VL(model) => { if let Some(_guard) = model.lock_mamba_cache_for_graph() { - self.capturer + self.decode_capturer .replay(&input_ids, &positions, &input_metadata)? } else { - self.capturer + self.decode_capturer .replay(&input_ids, &positions, &input_metadata)? } } _ => self - .capturer + .decode_capturer .replay(&input_ids, &positions, &input_metadata)?, }; let output_ids = self.sample(&logits, seqs, is_prefill)?; @@ -1002,6 +1282,365 @@ impl ModelRunner { Ok(output_ids) } + /// MTP Step 1: single-token decode to get anchor token + hidden state. + /// Tries CUDA graph replay first (the graph's internal buffer for the + /// post-norm hidden state is accessible via take_last_hidden_for_mtp), + /// falling back to eager forward_with_hidden. + fn mtp_decode_step1(&self, seqs: Seqs, _seq_info: &MtpSeqInfo) -> Result<(u32, Tensor)> { + let (input_ids, positions, mut input_metadata) = match &seqs { + Seqs::SeqRefs(seqs_ref) => self.prepare_decode(*seqs_ref)?, + Seqs::DecodeVec(decode_seqs) => self.prepare_decode(decode_seqs.iter())?, + }; + + let _decode_guard = set_linear_is_prefill(false); + + // Try CUDA graph replay for the decode forward. The model's forward() + // stores hidden states in last_hidden_for_mtp during both capture and + // replay (the cached tensor shares GPU storage with the graph output, + // so it's updated in-place on replay). + #[cfg(all(feature = "cuda", feature = "graph"))] + { + let input_batch = input_ids.dim(0)?; + let require_exact_graph = input_metadata.mamba_slot_mapping.is_some(); + let can_replay = if require_exact_graph { + self.decode_capturer.is_exact_captured(input_batch) + } else { + self.decode_capturer.is_captured(input_batch) + }; + if can_replay { + let logits = match &self.model { + Model::Qwen3_5(model) => { + let _guard = model.lock_mamba_cache_for_graph(); + self.decode_capturer + .replay(&input_ids, &positions, &input_metadata)? + } + Model::Qwen3_5MoE(model) => { + let _guard = model.lock_mamba_cache_for_graph(); + self.decode_capturer + .replay(&input_ids, &positions, &input_metadata)? + } + Model::Qwen3VL(model) => { + if let Some(_guard) = model.lock_mamba_cache_for_graph() { + self.decode_capturer + .replay(&input_ids, &positions, &input_metadata)? + } else { + self.decode_capturer + .replay(&input_ids, &positions, &input_metadata)? + } + } + _ => self + .decode_capturer + .replay(&input_ids, &positions, &input_metadata)?, + }; + + let hidden_states = match &self.model { + Model::Qwen3_5(model) => model.take_last_hidden_for_mtp(), + Model::Qwen3_5MoE(model) => model.take_last_hidden_for_mtp(), + Model::Qwen3VL(model) => model.take_last_hidden_for_mtp(), + _ => None, + }; + + if let Some(hidden_states) = hidden_states { + let anchor_token = self.sample(&logits, seqs, false)?[0]; + let seq_hidden = if hidden_states.dims().len() == 2 && hidden_states.dim(0)? > 1 + { + hidden_states.get(hidden_states.dim(0)? - 1)? + } else if hidden_states.dims().len() == 2 { + hidden_states.get(0)? + } else { + hidden_states + }; + return Ok((anchor_token, seq_hidden)); + } + } + } + + // Fallback: eager forward_with_hidden (no graph available or hidden state extraction failed) + #[cfg(feature = "flashinfer")] + if let Some(fm) = input_metadata.flashinfer_metadata.as_mut() { + if input_metadata.is_mla { + if fm.mla_decode_plan_info.is_none() { + if let Some(params) = self.flashinfer_kv_params { + fm.mla_decode_plan_info = Some(attention_rs::mla::mla_decode_plan( + &self.device, + params.kv_dtype, + &fm.indptr_host, + input_ids.dim(0)?, + params.num_qo_heads, + params.page_size, + fm.use_cuda_graph, + )?); + } + } + } else if fm.decode_plan_info.is_none() { + if let Some(params) = self.flashinfer_kv_params { + fm.decode_plan_info = Some(attention_rs::flashinfer::decode_plan( + &self.device, + params.kv_dtype, + params.out_dtype, + &fm.indptr_host, + fm.last_len_host.as_deref(), + fm.kv_len_arr_host.as_deref(), + input_ids.dim(0)?, + params.num_qo_heads, + params.num_kv_heads, + params.head_dim, + params.page_size, + fm.use_cuda_graph, + )?); + } + } + } + + let kv_cache = self.get_kv_cache(); + let (logits, hidden_states) = match &self.model { + Model::Qwen3_5(model) => model.forward_with_hidden( + &input_ids, + &positions, + Some(&kv_cache), + &input_metadata, + false, + )?, + Model::Qwen3_5MoE(model) => model.forward_with_hidden( + &input_ids, + &positions, + Some(&kv_cache), + &input_metadata, + false, + )?, + Model::Qwen3VL(model) => model.forward_with_hidden( + &input_ids, + &positions, + Some(&kv_cache), + &input_metadata, + false, + )?, + _ => { + drop(kv_cache); + candle_core::bail!("MTP Step 1 requires Qwen3.5 model"); + } + }; + drop(kv_cache); + + let anchor_token = self.sample(&logits, seqs, false)?[0]; + + let seq_hidden = if hidden_states.dims().len() == 2 && hidden_states.dim(0)? > 1 { + hidden_states.get(hidden_states.dim(0)? - 1)? + } else if hidden_states.dims().len() == 2 { + hidden_states.get(0)? + } else { + hidden_states.clone() + }; + + Ok((anchor_token, seq_hidden)) + } + + /// Run MTP speculative decode for a batch of sequences. + /// Returns Vec> where each inner vec contains all accepted tokens for that sequence + /// (anchor + accepted drafts + bonus token). + /// + /// Optimized flow: + /// 1. Run main model decode via CUDA graph replay (when available) + extract hidden state + /// 2. Sample anchor token from logits + /// 3. MTP head drafts K tokens autoregressively (no KV cache) + /// 4. Verify: run main model on [anchor, draft_0, ..., draft_{K-1}] using native flash + /// 5. On partial rejection: roll back GDN state to the accepted token boundary + /// 6. Greedy-accept matching prefix; take bonus token at first mismatch + pub fn run_mtp_decode(&self, seqs: Seqs) -> Result>> { + let mtp_head = match &self.mtp_head { + Some(h) => h.clone(), + None => { + let output = self.run(seqs, false)?; + return Ok(output.into_iter().map(|t| vec![t]).collect()); + } + }; + + let (batch_size, seq_infos) = match &seqs { + Seqs::SeqRefs(s) => { + let infos: Vec = s + .iter() + .map(|seq| MtpSeqInfo { + id: seq.id, + len: seq.len(), + block_table: seq.block_table.clone(), + }) + .collect(); + (s.len(), infos) + } + Seqs::DecodeVec(d) => { + let infos: Vec = d + .iter() + .map(|ds| MtpSeqInfo { + id: ds.id, + len: ds.len, + block_table: ds.block_tables.clone(), + }) + .collect(); + (d.len(), infos) + } + }; + + if batch_size != 1 { + let output = self.run(seqs, false)?; + return Ok(output.into_iter().map(|t| vec![t]).collect()); + } + + let seq_info = &seq_infos[0]; + let num_draft = self.mtp_num_speculative; + + // Step 1: Main model decode for logits + hidden state. + let (anchor_token, seq_hidden) = self.mtp_decode_step1(seqs, seq_info)?; + + // Step 2: Draft K tokens using MTP head (GPU-resident, no per-step CPU sync) + let embed_weight = match &self.model { + Model::Qwen3_5(m) => m.embed_weight().clone(), + Model::Qwen3_5MoE(m) => m.embed_weight().clone(), + Model::Qwen3VL(m) => m + .embed_weight() + .expect("Qwen3VL MTP requires Qwen3.5 text backbone") + .clone(), + _ => unreachable!(), + }; + let lm_head_fn = |hidden: &Tensor| -> Result { + match &self.model { + Model::Qwen3_5(m) => m.forward_lm_head(hidden), + Model::Qwen3_5MoE(m) => m.forward_lm_head(hidden), + Model::Qwen3VL(m) => m.forward_lm_head(hidden), + _ => unreachable!(), + } + }; + + let base_position = seq_info.len.saturating_sub(1); + let anchor_token_tensor = Tensor::from_vec(vec![anchor_token], (1,), &self.device)?; + let (draft_tokens, _last_hidden) = mtp_head.draft_tokens_gpu( + &seq_hidden, + &anchor_token_tensor, + num_draft, + &embed_weight, + lm_head_fn, + base_position, + )?; + + if draft_tokens.is_empty() { + return Ok(vec![vec![anchor_token]]); + } + + // Step 3: Verify draft tokens via prefill-style forward on [anchor, draft_0..K-1]. + let mut verify_tokens = vec![anchor_token]; + verify_tokens.extend_from_slice(&draft_tokens); + let verify_len = verify_tokens.len(); + + let block_size = self.config.block_size; + let slot_mappings = + self.compute_slot_mappings(seq_info, verify_len, block_size, "verify")?; + + let verify_input_ids = Tensor::from_vec(verify_tokens, (verify_len,), &self.device)?; + let verify_positions_tensor = Tensor::from_vec( + (0..verify_len) + .map(|i| (seq_info.len + i) as i64) + .collect::>(), + (verify_len,), + &self.device, + )?; + + let verify_metadata = + self.build_mtp_metadata(seq_info, &slot_mappings[..verify_len], verify_len)?; + + let _prefill_guard = set_linear_is_prefill(true); + + #[cfg(all(feature = "cuda", feature = "graph"))] + let use_mtp_graph = self + .mtp_capturer + .as_ref() + .map_or(false, |c| c.is_mtp_captured(verify_len)); + #[cfg(not(all(feature = "cuda", feature = "graph")))] + let use_mtp_graph = false; + + let all_logits_result = if use_mtp_graph { + #[cfg(all(feature = "cuda", feature = "graph"))] + { + self.mtp_capturer.as_ref().unwrap().replay_mtp( + &verify_input_ids, + &verify_positions_tensor, + &verify_metadata, + ) + } + #[cfg(not(all(feature = "cuda", feature = "graph")))] + { + unreachable!() + } + } else { + let kv_cache = self.get_kv_cache(); + let res = match &self.model { + Model::Qwen3_5(model) => model.forward( + &verify_input_ids, + &verify_positions_tensor, + Some(&kv_cache), + &verify_metadata, + false, + ), + Model::Qwen3_5MoE(model) => model.forward( + &verify_input_ids, + &verify_positions_tensor, + Some(&kv_cache), + &verify_metadata, + false, + ), + Model::Qwen3VL(model) => model.forward( + &verify_input_ids, + &verify_positions_tensor, + Some(&kv_cache), + &verify_metadata, + None, + ), + _ => unreachable!(), + }; + drop(kv_cache); + res + }; + let all_logits = match all_logits_result { + Ok(logits) => logits, + Err(err) => { + return Err(err); + } + }; + + let verify_result = match crate::core::mtp::verify_draft_greedy(&all_logits, &draft_tokens) + { + Ok(result) => result, + Err(err) => { + return Err(err); + } + }; + + if verify_result.num_accepted < verify_result.num_proposed { + let commit_len = 1 + verify_result.num_accepted; + // KV cache does not need explicit rollback because the next decode writes + // the continuation token at the first rejected slot and stale later slots + // are outside the sequence length. + let restored = self.mtp_rollback_mamba(seq_info.id, commit_len)?; + if !restored { + candle_core::bail!( + "MTP failed to roll back mamba-state snapshot for seq {} to {} verified token(s)", + seq_info.id, + commit_len + ); + } + } + + let mut result_tokens = Vec::with_capacity(2 + verify_result.num_accepted); + result_tokens.push(anchor_token); + result_tokens.extend_from_slice(&verify_result.accepted_tokens); + result_tokens.push(verify_result.continuation_token); + + crate::core::mtp::mtp_stats_update(verify_result.num_proposed, verify_result.num_accepted); + if crate::core::mtp::MTP_TOTAL_STEPS.load(std::sync::atomic::Ordering::Relaxed) % 256 == 0 { + crate::log_info!("{}", crate::core::mtp::mtp_stats_summary()); + } + + Ok(vec![result_tokens]) + } + pub fn embed(&self, seqs: &[&Sequence], strategy: &EmbeddingStrategy) -> Result>> { let (input_ids, positions, input_metadata) = self.prepare_prefill(seqs)?; @@ -1284,6 +1923,7 @@ impl ModelRunner { params.out_dtype, None, Some(params.kv_dtype), + false, )?) } }; @@ -1330,6 +1970,7 @@ impl ModelRunner { max_context_len, seqlens: Some(cu_seqlens_q_vec[1..].to_vec()), flashinfer_metadata, + is_mtp_verify: false, }; Ok((input_ids, positions, input_metadata)) @@ -1379,9 +2020,9 @@ impl ModelRunner { _ => false, }; if require_exact_graph { - self.capturer.is_exact_captured(seq_refs.len()) + self.decode_capturer.is_exact_captured(seq_refs.len()) } else { - self.capturer.is_captured(seq_refs.len()) + self.decode_capturer.is_captured(seq_refs.len()) } }; #[cfg(not(all(feature = "cuda", feature = "graph")))] @@ -1479,6 +2120,7 @@ impl ModelRunner { max_context_len, seqlens: None, flashinfer_metadata, + is_mtp_verify: false, }; Ok((input_ids, positions, input_metadata)) @@ -1708,8 +2350,25 @@ impl ModelRunner { #[cfg(all(feature = "cuda", feature = "graph"))] pub fn warmup_capture(&mut self) -> Result<()> { - let kv_cache_lock = self.gpu_kv_cache.lock().unwrap(); // no custom method call on `self` - self.capturer.capture(&self.device, Some(&kv_cache_lock))?; + let kv_cache_lock = self.gpu_kv_cache.lock().unwrap(); + self.decode_capturer + .capture(&self.device, Some(&kv_cache_lock))?; + + if self.mtp_num_speculative > 0 { + // self.decode_capturer.model.sync()?; + if let Some(mtp_cap) = &mut self.mtp_capturer { + crate::log_info!( + "Capturing MTP verify graphs for up to {} draft tokens...", + self.mtp_num_speculative + ); + mtp_cap.capture_mtp( + &self.device, + Some(&kv_cache_lock), + self.mtp_num_speculative, + )?; + } + } + match &self.model { Model::Qwen3_5(model) => model.reset_mamba_cache()?, Model::Qwen3_5MoE(model) => model.reset_mamba_cache()?, diff --git a/src/core/scheduler.rs b/src/core/scheduler.rs index bd51f8fd..4761255e 100644 --- a/src/core/scheduler.rs +++ b/src/core/scheduler.rs @@ -642,6 +642,98 @@ impl Scheduler { } } + /// MTP-aware postprocess: accepts multiple tokens per sequence. + /// For each sequence in `ids`, `multi_output_ids[i]` contains all accepted tokens + /// (anchor + drafts + continuation). Each token is appended and checked for EOS/stop. + pub fn postprocess_multi(&mut self, ids: &[usize], multi_output_ids: &[Vec]) { + for (i, &idx) in ids.iter().enumerate() { + if idx >= self.running.len() { + continue; + } + let tokens = &multi_output_ids[i]; + for &token in tokens { + let _seq_id = self.running[idx].id; + + if self.is_pd_server() { + break; + } + + if self.running[idx].sampling_params.mcp_mode.is_some() { + let is_end = self.is_tool_call_end(token, idx); + if is_end { + let seq = &mut self.running[idx]; + seq.append_token(token); + seq.is_tool_call_end = true; + seq.status = SequenceStatus::Finished; + self.block_manager + .capture_mamba_prefix_state(seq, seq.len()); + self.block_manager.cache_sequence(seq); + self.block_manager.deallocate(seq); + break; + } + } + + let matched_stop_sequence_idx = + self.stop_sequence_match_index(token, &self.running[idx]); + let hit_stop_sequence = matched_stop_sequence_idx.is_some(); + let seq = &mut self.running[idx]; + + if hit_stop_sequence + || self.eos_token_id.contains(&token) + || seq.output_len() >= seq.sampling_params.max_tokens.unwrap_or(16384) + || seq.len() > self.cfg.max_num_batched_tokens + { + if hit_stop_sequence { + seq.hit_stop_sequence = true; + seq.stop_sequence = matched_stop_sequence_idx.and_then(|stop_idx| { + seq.sampling_params + .stop_sequences + .as_ref() + .and_then(|stops| stops.get(stop_idx)) + .cloned() + }); + } + seq.status = SequenceStatus::Finished; + self.block_manager + .capture_mamba_prefix_state(seq, seq.len()); + self.block_manager.cache_sequence(seq); + self.block_manager.deallocate(seq); + break; + } else { + seq.append_token(token); + if seq.len() % self.cfg.block_size == 1 && seq.len() > 1 { + let _ = self.block_manager.may_append(seq); + } + if seq.len() % self.cfg.block_size == 0 { + self.block_manager + .capture_mamba_prefix_state(seq, seq.len()); + } + } + } + } + } + + /// Pre-allocate KV cache blocks for MTP speculative tokens. + /// Called before MTP runs to ensure the verification forward has room + /// to write KV for speculative positions. + pub fn pre_allocate_mtp_blocks(&mut self, ids: &[usize], extra_tokens: usize) { + for &idx in ids { + if idx >= self.running.len() { + continue; + } + let seq = &mut self.running[idx]; + let needed_len = seq.len() + extra_tokens; + let needed_blocks = needed_len.div_ceil(self.cfg.block_size); + while seq.block_table.len() < needed_blocks { + if let Some(block_id) = self.block_manager.alloc_free_block() { + seq.block_table.push(block_id as u32); + } else { + break; + } + } + } + } + pub fn clear_finished(&mut self) { let is_pd_server = self.is_pd_server(); let mut finished_counts = Vec::new(); diff --git a/src/main.rs b/src/main.rs index 7ae7cc6c..b944208c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -218,7 +218,8 @@ async fn main() -> Result<()> { args.disable_reasoning, args.disable_cuda_graph, Some(args.prefill_chunk_size), - ); + ) + .with_mtp(args.mtp); let server_port = if server { let port = args diff --git a/src/models/layers/attention.rs b/src/models/layers/attention.rs index 701f0827..d5660b09 100644 --- a/src/models/layers/attention.rs +++ b/src/models/layers/attention.rs @@ -845,6 +845,89 @@ impl Attention { }; self.o_proj.forward(&y) } + + /// Optimized single-token attention without KV cache or paged attention. + /// For seq_len=1, self-attention is trivially: softmax([1x1]) * V = V. + /// We still compute Q/K/V projections and apply RoPE for correctness of + /// the output projection, but skip the attention kernel entirely. + pub fn forward_single_token_no_cache( + &self, + xs: &Tensor, + rotary_emb: &Arc, + positions: &Tensor, + ) -> Result { + let (seq_len, _) = xs.dims2()?; + + let (q_raw, k, v) = match &self.qkv_proj { + QkvProjection::Separate { + q_proj, + k_proj, + v_proj, + } => ( + q_proj.forward(xs)?, + k_proj.forward(xs)?, + v_proj.forward(xs)?, + ), + QkvProjection::Packed(qkv_proj) => { + let qkv = qkv_proj.forward(xs)?; + (qkv[0].clone(), qkv[1].clone(), qkv[2].clone()) + } + }; + + // Handle attn_output_gate (Qwen3.5 uses gated attention: Q proj outputs Q + gate) + let local_q_dim = self.num_heads * self.head_dim; + let (q_linear, gate) = if self.attn_output_gate { + let q_gate = q_raw.reshape((seq_len, self.num_heads, self.head_dim * 2))?; + let q = q_gate.narrow(2, 0, self.head_dim)?; + let gate = q_gate.narrow(2, self.head_dim, self.head_dim)?; + ( + q.reshape((seq_len, local_q_dim))?, + Some(gate.reshape((seq_len, local_q_dim))?), + ) + } else { + (q_raw, None) + }; + + let q = q_linear.reshape((seq_len, self.num_heads, self.head_dim))?; + let k = k.reshape((seq_len, self.num_kv_heads, self.head_dim))?; + let v = v.reshape((seq_len, self.num_kv_heads, self.head_dim))?; + + // Apply rotary embeddings (needed even for single token for positional encoding) + let (_q, _k) = match rotary_emb.apply_rotary_emb_qkv(&q, &k, positions)? { + Some((q_new, k_new)) => (q_new, k_new), + None => (q, k), + }; + + // For seq_len=1: softmax(Q*K^T / sqrt(d)) * V = V (single element softmax = 1.0) + // Output shape: (seq_len, num_heads * head_dim) after GQA expansion + let n_rep = self.num_heads / self.num_kv_heads; + let y = if n_rep > 1 { + v.unsqueeze(2)? + .expand((seq_len, self.num_kv_heads, n_rep, self.head_dim))? + .reshape((seq_len, self.num_heads * self.head_dim))? + } else { + v.reshape((seq_len, self.num_heads * self.head_dim))? + }; + + // Apply gated attention if needed + let y = if let Some(gate) = gate { + let gate = if gate.dtype() != y.dtype() { + gate.to_dtype(y.dtype())? + } else { + gate + }; + y.broadcast_mul(&candle_nn::ops::sigmoid(&gate)?)? + } else { + y + }; + + let y = if self.is_qvar_builder { + y + } else { + y.to_dtype(xs.dtype())? + }; + self.o_proj.forward(&y) + } } pub struct NaiveAttention { diff --git a/src/models/layers/deltanet.rs b/src/models/layers/deltanet.rs index 05000325..f46216b9 100644 --- a/src/models/layers/deltanet.rs +++ b/src/models/layers/deltanet.rs @@ -67,6 +67,8 @@ pub struct GatedDeltaNet { /// The model's native dtype (BF16/F16). Used for projection input and weight loading. /// Quantized projections (FP8/NVFP4/QLinear) handle dtype internally. model_dtype: DType, + conv_mtp_state: Tensor, + recurrent_mtp_state: Tensor, } impl GatedDeltaNet { @@ -709,7 +711,18 @@ impl GatedDeltaNet { ) .ok(); let scale = 1.0f64 / (head_k_dim as f64).sqrt(); - + let d_conv = key_dim * 2 + value_dim; + let max_verify_tokens = 16; + let conv_mtp_state = Tensor::zeros( + (max_verify_tokens, d_conv, conv_kernel_size - 1), + gdn_dtype, + &vb.device(), + )?; + let recurrent_mtp_state = Tensor::zeros( + (max_verify_tokens, num_v_heads, head_k_dim, head_v_dim), + DType::F32, + &vb.device(), + )?; Ok(Self { projection, out_proj, @@ -735,6 +748,8 @@ impl GatedDeltaNet { } else { dtype }, + conv_mtp_state, + recurrent_mtp_state, }) } @@ -755,9 +770,6 @@ impl GatedDeltaNet { let is_prefill = input_metadata.is_prefill; let (q, k, v, z, b, a) = self.project_inputs(xs)?; - // Upcast projection outputs to gdn_dtype for GDN core ops (conv1d, gating, recurrence). - // For GGUF/F16 mode (gdn_dtype=F32), this promotes BF16 outputs to F32. - // For standard BF16 models (gdn_dtype=BF16), this is a no-op. let (q, k, v, z, b, a) = if q.dtype() != self.gdn_dtype { ( q.to_dtype(self.gdn_dtype)?, @@ -779,13 +791,19 @@ impl GatedDeltaNet { .as_ref() .expect("cu_seqlens_q must be present in prefill!"); + let conv_snapshots = if input_metadata.is_mtp_verify { + Some(self.conv_mtp_state.narrow(0, 0, token_count)?) + } else { + None + }; let out = gdn::causal_conv1d_fwd( &mixed_qkv, &self.conv_weight, self.conv_bias.as_ref(), &mut conv_state, + conv_snapshots.as_ref(), Some(cu_seqlens), - true, // SiLU activation + true, )?; (out, Some(conv_state)) } else { @@ -834,6 +852,11 @@ impl GatedDeltaNet { .expect("cu_seqlens_q must be present in prefill!"); let global_state = mamba_cache.recurrent_state_mut(self.gdn_layer_idx); + let recurrent_snapshots = if input_metadata.is_mtp_verify { + Some(self.recurrent_mtp_state.narrow(0, 0, token_count)?) + } else { + None + }; if self.num_k_heads != self.num_v_heads { gdn::gated_delta_rule_recurrence_varlen_gqa( @@ -846,6 +869,7 @@ impl GatedDeltaNet { seq_slots, &cu_seqlens, self.scale as f32, + recurrent_snapshots.as_ref(), )? } else { let q_scaled = (&q * self.scale)?; @@ -858,6 +882,7 @@ impl GatedDeltaNet { global_state, seq_slots, &cu_seqlens, + recurrent_snapshots.as_ref(), )? } } else { @@ -905,4 +930,33 @@ impl GatedDeltaNet { Ok(out) } } + + /// Roll back this layer's GDN state to the position after `keep_tokens` tokens + /// were processed during MTP verification. Indexes into the per-token snapshot + /// buffers written by the prefill kernels and restores the slot state. + pub fn rollback_mtp_verify( + &self, + mamba_cache: &mut MambaCache, + seq_slots: &Tensor, + keep_tokens: usize, + ) -> Result<()> { + if keep_tokens == 0 { + return Ok(()); + } + let idx = keep_tokens - 1; + + let conv_snapshot = self.conv_mtp_state.narrow(0, idx, 1)?; + let conv_state_dtype = mamba_cache.conv_state(self.gdn_layer_idx).dtype(); + let conv_snapshot = if conv_snapshot.dtype() != conv_state_dtype { + conv_snapshot.to_dtype(conv_state_dtype)? + } else { + conv_snapshot + }; + mamba_cache.set_batch_conv_state(self.gdn_layer_idx, seq_slots, &conv_snapshot)?; + + let rec_snapshot = self.recurrent_mtp_state.narrow(0, idx, 1)?; + mamba_cache.set_batch_recurrent_state(self.gdn_layer_idx, seq_slots, &rec_snapshot)?; + + Ok(()) + } } diff --git a/src/models/mod.rs b/src/models/mod.rs index fb83ff36..4cc97320 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -13,5 +13,6 @@ pub mod phi4; pub mod qwen3; pub mod qwen3_5; pub mod qwen3_5_moe; +pub mod qwen3_5_mtp; pub mod qwen3_moe; pub mod qwen3_vl; diff --git a/src/models/qwen3_5.rs b/src/models/qwen3_5.rs index f92e0b58..b2b4bcd5 100644 --- a/src/models/qwen3_5.rs +++ b/src/models/qwen3_5.rs @@ -194,6 +194,10 @@ pub struct Qwen3_5ForCausalLM { dtype: DType, vocab_size: usize, is_qvar_builder: bool, + /// Pre-allocated hidden state buffer for MTP speculative decoding. + /// Allocated outside the CUDA graph pool so copy_ into it is graph-safe. + /// Shape: (max_graph_bs, hidden_size), allocated on first decode forward. + pub mtp_hidden_buffer: std::sync::Mutex>, } impl Qwen3_5ForCausalLM { @@ -480,6 +484,7 @@ impl Qwen3_5ForCausalLM { dtype, vocab_size, is_qvar_builder, + mtp_hidden_buffer: std::sync::Mutex::new(None), }) } @@ -492,6 +497,18 @@ impl Qwen3_5ForCausalLM { } } + pub fn embed_weight(&self) -> &Tensor { + self.embed_tokens.embeddings() + } + + /// Get the last hidden state for MTP from the pre-allocated buffer. + /// Returns row 0 of the buffer (MTP decode always uses batch_size=1). + pub fn take_last_hidden_for_mtp(&self) -> Option { + let guard = self.mtp_hidden_buffer.lock().ok()?; + let buf = guard.as_ref()?; + buf.get(0).ok() + } + fn forward_inner( &self, input_ids: &Tensor, @@ -565,12 +582,24 @@ impl Qwen3_5ForCausalLM { let xs = self.norm.forward(&xs)?; if return_hidden { xs.to_dtype(DType::F32) - } else if self.is_qvar_builder { - self.lm_head.forward(&xs) } else { - self.lm_head - .forward(&xs.to_dtype(self.dtype)?)? - .to_dtype(DType::F32) + // Copy hidden state into the pre-allocated MTP buffer (graph-safe). + // copy_ is a CUDA memcpy kernel that gets captured into the graph, + // so the buffer is updated in-place on every graph replay. + if let Ok(guard) = self.mtp_hidden_buffer.lock() { + if let Some(buf) = guard.as_ref() { + if xs.elem_count() <= buf.elem_count() { + let _ = buf.copy_(&xs, 0); + } + } + } + if self.is_qvar_builder { + self.lm_head.forward(&xs) + } else { + self.lm_head + .forward(&xs.to_dtype(self.dtype)?)? + .to_dtype(DType::F32) + } } } @@ -614,6 +643,36 @@ impl Qwen3_5ForCausalLM { ) } + /// Forward pass that returns both logits and hidden states. + /// Used by MTP speculative decoding to get backbone hidden states for the draft head. + pub fn forward_with_hidden( + &self, + input_ids: &Tensor, + positions: &Tensor, + kv_caches: Option<&Vec<(Tensor, Tensor)>>, + input_metadata: &InputMetadata, + embeded_inputs: bool, + ) -> Result<(Tensor, Tensor)> { + let hidden = self.forward_inner( + input_ids, + positions, + kv_caches, + input_metadata, + embeded_inputs, + &None, + &None, + true, + )?; + let logits = if self.is_qvar_builder { + self.lm_head.forward(&hidden)? + } else { + self.lm_head + .forward(&hidden.to_dtype(self.dtype)?)? + .to_dtype(DType::F32)? + }; + Ok((logits, hidden)) + } + pub fn forward_with_deepstack( &self, input_ids: &Tensor, @@ -636,6 +695,17 @@ impl Qwen3_5ForCausalLM { ) } + /// Apply lm_head to hidden states to get logits. Used by MTP drafting. + pub fn forward_lm_head(&self, hidden: &Tensor) -> Result { + if self.is_qvar_builder { + self.lm_head.forward(hidden) + } else { + self.lm_head + .forward(&hidden.to_dtype(self.dtype)?)? + .to_dtype(DType::F32) + } + } + pub fn get_vocab_size(&self) -> usize { self.vocab_size } @@ -664,6 +734,20 @@ impl Qwen3_5ForCausalLM { self.mamba_cache.write().reserve_capacity(max_num_seqs) } + /// Pre-allocate the MTP hidden state buffer outside of CUDA graph capture. + /// Must be called before warmup_capture so the buffer lives in regular GPU memory. + pub fn preallocate_mtp_hidden_buffer(&self, max_batch_size: usize) -> Result<()> { + let buf = Tensor::zeros( + (max_batch_size, self.config.hidden_size), + self.dtype, + &self.device, + )?; + if let Ok(mut guard) = self.mtp_hidden_buffer.lock() { + *guard = Some(buf); + } + Ok(()) + } + pub fn set_mamba_prefix_cache_capacity(&self, capacity: usize) { self.mamba_cache.write().set_prefix_cache_capacity(capacity); } @@ -694,6 +778,22 @@ impl Qwen3_5ForCausalLM { self.mamba_cache.write().restore_prefix_state(seq_id, hash) } + pub fn mtp_rollback_mamba(&self, seq_id: usize, keep_tokens: usize) -> Result { + let mut mamba_cache = self.mamba_cache.write(); + let slots = mamba_cache + .get_slots_for_sequences(&[seq_id])? + .into_iter() + .map(|s| s as i64) + .collect::>(); + let seq_slots = Tensor::from_vec(slots, (1,), &self.device)?; + for layer in &self.layers { + if let Qwen3_5AttnType::LinearAttention(gdn) = &layer.attn { + gdn.rollback_mtp_verify(&mut mamba_cache, &seq_slots, keep_tokens)?; + } + } + Ok(true) + } + pub fn reset_mamba_cache(&self) -> Result<()> { self.mamba_cache.write().reset_all() } diff --git a/src/models/qwen3_5_moe.rs b/src/models/qwen3_5_moe.rs index e6eed8b2..6be64e44 100644 --- a/src/models/qwen3_5_moe.rs +++ b/src/models/qwen3_5_moe.rs @@ -333,6 +333,8 @@ pub struct Qwen3_5MoEForCausalLM { dtype: DType, vocab_size: usize, is_qvar_builder: bool, + /// Pre-allocated hidden state buffer for MTP speculative decoding (graph-safe). + pub mtp_hidden_buffer: std::sync::Mutex>, } impl Qwen3_5MoEForCausalLM { @@ -606,6 +608,7 @@ impl Qwen3_5MoEForCausalLM { dtype, vocab_size, is_qvar_builder, + mtp_hidden_buffer: std::sync::Mutex::new(None), }) } @@ -618,6 +621,66 @@ impl Qwen3_5MoEForCausalLM { } } + pub fn embed_weight(&self) -> &Tensor { + self.embed_tokens.embeddings() + } + + pub fn take_last_hidden_for_mtp(&self) -> Option { + let guard = self.mtp_hidden_buffer.lock().ok()?; + let buf = guard.as_ref()?; + buf.get(0).ok() + } + + pub fn preallocate_mtp_hidden_buffer(&self, max_batch_size: usize) -> Result<()> { + let buf = Tensor::zeros( + (max_batch_size, self.config.hidden_size), + self.dtype, + &self.device, + )?; + if let Ok(mut guard) = self.mtp_hidden_buffer.lock() { + *guard = Some(buf); + } + Ok(()) + } + + pub fn forward_with_hidden( + &self, + input_ids: &Tensor, + positions: &Tensor, + kv_caches: Option<&Vec<(Tensor, Tensor)>>, + input_metadata: &InputMetadata, + embeded_inputs: bool, + ) -> Result<(Tensor, Tensor)> { + let hidden = self.forward_inner( + input_ids, + positions, + kv_caches, + input_metadata, + embeded_inputs, + &None, + &None, + true, + )?; + let logits = if self.is_qvar_builder { + self.lm_head.forward(&hidden)? + } else { + self.lm_head + .forward(&hidden.to_dtype(self.dtype)?)? + .to_dtype(DType::F32)? + }; + Ok((logits, hidden)) + } + + pub fn forward_lm_head(&self, hidden: &Tensor) -> Result { + if self.is_qvar_builder { + self.lm_head.forward(hidden) + } else { + self.lm_head + .forward(&hidden.to_dtype(self.dtype)?)? + .to_dtype(DType::F32) + } + } + fn forward_inner( &self, input_ids: &Tensor, @@ -689,12 +752,21 @@ impl Qwen3_5MoEForCausalLM { let xs = self.norm.forward(&xs)?; if return_hidden { xs.to_dtype(DType::F32) - } else if self.is_qvar_builder { - self.lm_head.forward(&xs) } else { - self.lm_head - .forward(&xs.to_dtype(self.dtype)?)? - .to_dtype(DType::F32) + if let Ok(guard) = self.mtp_hidden_buffer.lock() { + if let Some(buf) = guard.as_ref() { + if xs.elem_count() <= buf.elem_count() { + let _ = buf.copy_(&xs, 0); + } + } + } + if self.is_qvar_builder { + self.lm_head.forward(&xs) + } else { + self.lm_head + .forward(&xs.to_dtype(self.dtype)?)? + .to_dtype(DType::F32) + } } } @@ -818,6 +890,23 @@ impl Qwen3_5MoEForCausalLM { self.mamba_cache.write().restore_prefix_state(seq_id, hash) } + pub fn mtp_rollback_mamba(&self, seq_id: usize, keep_tokens: usize) -> Result { + let mut mamba_cache = self.mamba_cache.write(); + + let slots = mamba_cache + .get_slots_for_sequences(&[seq_id])? + .into_iter() + .map(|s| s as i64) + .collect::>(); + let seq_slots = Tensor::from_vec(slots, (1,), &self.device)?; + for layer in &self.layers { + if let Qwen3_5MoEAttnType::LinearAttention(gdn) = &layer.attn { + gdn.rollback_mtp_verify(&mut mamba_cache, &seq_slots, keep_tokens)?; + } + } + Ok(true) + } + pub fn reset_mamba_cache(&self) -> Result<()> { self.mamba_cache.write().reset_all() } diff --git a/src/models/qwen3_5_mtp.rs b/src/models/qwen3_5_mtp.rs new file mode 100644 index 00000000..1c21f0c0 --- /dev/null +++ b/src/models/qwen3_5_mtp.rs @@ -0,0 +1,487 @@ +// src/models/qwen3_5_mtp.rs +// Qwen3.5 MTP (Multi-Token Prediction) Head +// +// The MTP head is a lightweight transformer layer that predicts future tokens +// using the backbone model's hidden states and KV cache. +// +// Supports both dense and MoE variants: +// Dense: mtp.layers.0.mlp.{gate_proj,up_proj,down_proj} +// MoE: mtp.layers.0.mlp.{gate,experts.N.{gate_proj,up_proj,down_proj},shared_expert.*,shared_expert_gate} + +use crate::models::layers::attention::Attention; +use crate::models::layers::distributed::{Comm, ReplicatedLinear}; +use crate::models::layers::linear::LinearX as Linear; +use crate::models::layers::mlp::MLP; +use crate::models::layers::moe::{FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ}; +use crate::models::layers::others::{rms_norm, NormX}; +use crate::models::layers::rotary_emb::{ApplyRotaryEmbedding, ScalingRotaryEmbedding}; +use crate::models::layers::VarBuilderX; +use crate::utils::config::Config; +use candle_core::{DType, Device, Module, Result, Tensor, D}; +use std::rc::Rc; +use std::sync::Arc; + +enum MtpMlp { + Dense(MLP), + Moe { + fused_moe: MtpFusedMoe, + shared_gate: Option, + shared_expert: Option, + }, +} + +enum MtpFusedMoe { + BF16(FusedMoe), + FP8(FusedMoeFp8), + GGUF(FusedMoeGGUF), + ISQ(FusedMoeISQ), +} + +impl MtpFusedMoe { + fn forward(&self, xs: &Tensor, is_prefill: bool) -> Result { + match self { + Self::BF16(m) => m.forward(xs, is_prefill), + Self::FP8(m) => m.forward(xs, is_prefill), + Self::GGUF(m) => m.forward(xs, is_prefill), + Self::ISQ(m) => m.forward(xs, is_prefill), + } + } +} + +impl MtpMlp { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Self::Dense(mlp) => mlp.forward(xs), + Self::Moe { + fused_moe, + shared_gate, + shared_expert, + } => { + let shared_output = match (shared_gate, shared_expert) { + (Some(sg), Some(se)) => { + let gate = candle_nn::ops::sigmoid(&sg.forward(xs)?)?; + let shared_out = se.forward(xs)?; + Some(gate.broadcast_mul(&shared_out)?) + } + _ => None, + }; + let moe_output = fused_moe.forward(xs, false)?; + if let Some(shared_output) = shared_output { + (moe_output + shared_output).map_err(Into::into) + } else { + Ok(moe_output) + } + } + } + } +} + +pub struct Qwen3_5MtpHead { + pre_fc_norm_hidden: NormX, + pre_fc_norm_embedding: NormX, + fc: ReplicatedLinear, + layer: Qwen3_5MtpDecoderLayer, + norm: NormX, + rotary_emb: Arc, + device: Device, + dtype: DType, +} + +struct Qwen3_5MtpDecoderLayer { + attn: Attention, + mlp: MtpMlp, + input_layernorm: NormX, + post_attention_layernorm: NormX, +} + +impl Qwen3_5MtpDecoderLayer { + /// Forward for MTP head - single token, no KV cache needed. + /// For seq_len=1, self-attention is trivially identity on the value, + /// so we compute: output = O_proj(V_proj(norm(x))) after RoPE on Q/K. + /// This avoids going through PagedAttention/FlashInfer backends entirely. + fn forward_single_token( + &self, + xs: &Tensor, + positions: &Tensor, + rotary_emb: &Arc, + ) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let rope: Arc = rotary_emb.clone(); + let attn_output = self + .attn + .forward_single_token_no_cache(&xs, &rope, positions)?; + let xs = (attn_output + residual)?; + let residual = &xs; + let xs = self.post_attention_layernorm.forward(&xs)?; + let mlp_output = self.mlp.forward(&xs)?; + residual + mlp_output + } +} + +impl Qwen3_5MtpHead { + pub fn new( + vb: &VarBuilderX, + comm: Rc, + config: &Config, + dtype: DType, + is_rope_i: bool, + device: &Device, + ) -> Result { + let hidden_size = config.hidden_size; + let is_qvar_builder = vb.is_qvar_builder(); + let prefix = "mtp."; + + let pre_fc_norm_hidden = rms_norm( + hidden_size, + config.rms_norm_eps, + vb.pp(&format!("{}pre_fc_norm_hidden", prefix)), + DType::F32, + !is_qvar_builder, + )?; + + let pre_fc_norm_embedding = rms_norm( + hidden_size, + config.rms_norm_eps, + vb.pp(&format!("{}pre_fc_norm_embedding", prefix)), + DType::F32, + !is_qvar_builder, + )?; + + let fc = ReplicatedLinear::load_no_bias( + hidden_size * 2, + hidden_size, + vb.pp(&format!("{}fc", prefix)), + &None, + &None, + dtype, + )?; + + let norm = rms_norm( + hidden_size, + config.rms_norm_eps, + vb.pp(&format!("{}norm", prefix)), + DType::F32, + !is_qvar_builder, + )?; + + let rotary_emb = Arc::new(ScalingRotaryEmbedding::new( + if is_qvar_builder || config.higher_precision_required() { + DType::F32 + } else { + dtype + }, + config, + device, + is_rope_i, + config.rope_theta, + )?); + + let layer_prefix = format!("{}layers.0", prefix); + let attn = Attention::new( + if is_qvar_builder { + vb.pp(&layer_prefix) + } else { + vb.pp(&format!("{}.self_attn", layer_prefix)) + }, + comm.clone(), + config, + None, + config.sliding_window, + dtype, + )?; + + let mlp_vb = if is_qvar_builder { + vb.pp(&layer_prefix) + } else { + vb.pp(&format!("{}.mlp", layer_prefix)) + }; + + let is_moe = config.moe_cfg.is_some() && mlp_vb.has_key("gate.weight"); + + let mlp = if is_moe { + let moe_cfg = config.moe_cfg.as_ref().unwrap(); + let fused_moe = if is_qvar_builder { + MtpFusedMoe::GGUF(FusedMoeGGUF::new( + config, + mlp_vb.clone(), + comm.clone(), + dtype, + )?) + } else if let Some(quant_config) = &config.quantization_config { + if quant_config.quant_method == "fp8" { + MtpFusedMoe::FP8(FusedMoeFp8::new( + config, + mlp_vb.clone(), + comm.clone(), + dtype, + quant_config, + )?) + } else { + MtpFusedMoe::BF16(FusedMoe::new(config, mlp_vb.clone(), comm.clone(), dtype)?) + } + } else if config.quant.is_some() { + MtpFusedMoe::ISQ(FusedMoeISQ::new( + config, + mlp_vb.clone(), + comm.clone(), + dtype, + )?) + } else { + MtpFusedMoe::BF16(FusedMoe::new(config, mlp_vb.clone(), comm.clone(), dtype)?) + }; + + let (shared_gate, shared_expert) = if let Some(intermediate_size) = + moe_cfg.shared_expert_intermediate_size + { + if intermediate_size > 0 { + let ws = match &mlp_vb.0 { + either::Either::Left(vb) => vb + .pp("shared_expert_gate") + .get((1, hidden_size), "weight")?, + either::Either::Right(vb) => { + let ws = vb.pp("ffn_gate_inp_shexp").get((hidden_size,), "weight")?; + ws.dequantize(&vb.device())?.reshape((1, hidden_size))? + } + } + .to_dtype( + if is_qvar_builder || config.quant.is_some() { + DType::F32 + } else { + dtype + }, + )?; + let shared_gate = Linear::new(ws, None, &None)?; + let shared_mlp = MLP::new( + if is_qvar_builder { + mlp_vb.clone() + } else { + mlp_vb.pp("shared_expert").clone() + }, + comm.clone(), + hidden_size, + intermediate_size, + &config.hidden_act, + &config.quantization_config, + &config.quant, + false, + dtype, + if is_qvar_builder { "_shexp" } else { "" }, + )?; + (Some(shared_gate), Some(shared_mlp)) + } else { + (None, None) + } + } else { + (None, None) + }; + + MtpMlp::Moe { + fused_moe, + shared_gate, + shared_expert, + } + } else { + MtpMlp::Dense(MLP::new( + mlp_vb, + comm.clone(), + hidden_size, + config.intermediate_size, + &config.hidden_act, + &config.quantization_config, + &config.quant, + false, + dtype, + "", + )?) + }; + + let input_layernorm = rms_norm( + hidden_size, + config.rms_norm_eps, + vb.pp(&format!("{}.input_layernorm", layer_prefix)), + DType::F32, + !is_qvar_builder, + )?; + + let post_attention_layernorm = rms_norm( + hidden_size, + config.rms_norm_eps, + vb.pp(&format!("{}.post_attention_layernorm", layer_prefix)), + DType::F32, + !is_qvar_builder, + )?; + + let layer = Qwen3_5MtpDecoderLayer { + attn, + mlp, + input_layernorm, + post_attention_layernorm, + }; + + Ok(Self { + pre_fc_norm_hidden, + pre_fc_norm_embedding, + fc, + layer, + norm, + rotary_emb, + device: device.clone(), + dtype, + }) + } + + /// Single MTP draft step. + /// + /// Given the backbone's last hidden states and the embedding of the current token, + /// produces the next hidden state for the MTP head. + /// The caller should apply lm_head to get logits. + /// + /// `backbone_hidden`: [batch, hidden_size] - last hidden from the backbone + /// `token_embedding`: [batch, hidden_size] - embedding of the last sampled/draft token + /// `positions`: position IDs for this step + /// `kv_cache`: MTP head's own KV cache (separate from backbone) + /// `input_metadata`: attention metadata for this step + pub fn forward_step( + &self, + backbone_hidden: &Tensor, + token_embedding: &Tensor, + positions: &Tensor, + ) -> Result { + let norm_hidden = self.pre_fc_norm_hidden.forward(backbone_hidden)?; + let norm_embed = self.pre_fc_norm_embedding.forward(token_embedding)?; + + // Concat order: [embedding, hidden] — matches vLLM/HuggingFace weight layout + // The FC weight's first half corresponds to embedding columns + let norm_embed = norm_embed.to_dtype(norm_hidden.dtype())?; + let fused = Tensor::cat(&[norm_embed, norm_hidden], D::Minus1)?; + let fused = fused.to_dtype(self.dtype)?; + let xs = self.fc.forward(&fused)?; + + // MTP head uses single-token attention without KV cache + let xs = self + .layer + .forward_single_token(&xs, positions, &self.rotary_emb)?; + + self.norm.forward(&xs) + } + + /// Draft K tokens with all operations on GPU (no per-step CPU round-trips). + /// + /// Uses GPU-resident argmax + gather-based embedding lookup to keep draft + /// tokens on device. Only transfers the final token list to CPU once. + /// + /// Returns (draft_token_ids, last_hidden_state). + pub fn draft_tokens_gpu( + &self, + initial_hidden: &Tensor, + anchor_token_tensor: &Tensor, + num_tokens: usize, + embed_weight: &Tensor, + lm_head_fn: impl Fn(&Tensor) -> Result, + positions_base: usize, + ) -> Result<(Vec, Tensor)> { + let mut gpu_draft_tokens: Vec = Vec::with_capacity(num_tokens); + let mut current_hidden = if initial_hidden.dims().len() == 1 { + initial_hidden.unsqueeze(0)? + } else { + initial_hidden.clone() + }; + let mut current_token_t = anchor_token_tensor.reshape((1,))?; + + for step in 0..num_tokens { + let token_embed = embed_weight.index_select(¤t_token_t, 0)?; + + let pos = (positions_base + step) as i64; + let positions = Tensor::from_vec(vec![pos], (1,), &self.device)?; + + let hidden_out = self.forward_step(¤t_hidden, &token_embed, &positions)?; + + let logits = lm_head_fn(&hidden_out.to_dtype(self.dtype)?)?; + let logits_last = if logits.dims().len() == 2 { + logits.get(logits.dim(0)? - 1)? + } else { + logits + }; + let next_token_t = logits_last.to_dtype(DType::F32)?.argmax(D::Minus1)?; + + gpu_draft_tokens.push(next_token_t.clone()); + current_hidden = if hidden_out.dims().len() == 2 { + hidden_out.get(hidden_out.dim(0)? - 1)?.unsqueeze(0)? + } else { + hidden_out + }; + current_token_t = next_token_t.reshape((1,))?; + } + + let draft_tokens: Vec = if gpu_draft_tokens.is_empty() { + vec![] + } else { + let stacked = Tensor::stack(&gpu_draft_tokens, 0)?; + stacked.to_vec1::()? + }; + + let final_hidden = current_hidden.squeeze(0)?; + Ok((draft_tokens, final_hidden)) + } + + /// Legacy draft method with CPU round-trips (kept for compatibility). + pub fn draft_tokens( + &self, + initial_hidden: &Tensor, + anchor_token: u32, + num_tokens: usize, + embed_fn: impl Fn(u32) -> Result, + lm_head_fn: impl Fn(&Tensor) -> Result, + positions_base: usize, + ) -> Result<(Vec, Tensor)> { + let mut draft_tokens = Vec::with_capacity(num_tokens); + let mut current_hidden = initial_hidden.clone(); + let mut current_token = anchor_token; + + for step in 0..num_tokens { + let token_embed = embed_fn(current_token)?; + let token_embed = match token_embed.dims().len() { + 1 => token_embed.unsqueeze(0)?, + _ => token_embed, + }; + + let current_hidden_2d = match current_hidden.dims().len() { + 1 => current_hidden.unsqueeze(0)?, + _ => current_hidden.clone(), + }; + + let pos = (positions_base + step) as i64; + let positions = Tensor::from_vec(vec![pos], (1,), &self.device)?; + + let hidden_out = self.forward_step(¤t_hidden_2d, &token_embed, &positions)?; + + let logits = lm_head_fn(&hidden_out.to_dtype(self.dtype)?)?; + let logits = logits.to_dtype(DType::F32)?; + let logits_last = if logits.dims().len() == 2 { + logits.get(logits.dim(0)? - 1)? + } else { + logits + }; + let next_token = logits_last.argmax(D::Minus1)?.to_scalar::()?; + + draft_tokens.push(next_token); + current_hidden = if hidden_out.dims().len() == 2 { + hidden_out.get(hidden_out.dim(0)? - 1)? + } else { + hidden_out + }; + current_token = next_token; + } + + Ok((draft_tokens, current_hidden)) + } + + pub fn device(&self) -> &Device { + &self.device + } + + pub fn dtype(&self) -> DType { + self.dtype + } +} diff --git a/src/models/qwen3_vl/mod.rs b/src/models/qwen3_vl/mod.rs index 35268917..059989a6 100644 --- a/src/models/qwen3_vl/mod.rs +++ b/src/models/qwen3_vl/mod.rs @@ -575,6 +575,14 @@ impl Qwen3VLForConditionalGeneration { } } + pub fn mtp_rollback_mamba(&self, seq_id: usize, keep_tokens: usize) -> Result { + match &self.text_model { + Qwen3TextModel::Dense35(m) => m.mtp_rollback_mamba(seq_id, keep_tokens), + Qwen3TextModel::MoE35(m) => m.mtp_rollback_mamba(seq_id, keep_tokens), + _ => Ok(false), + } + } + pub fn reset_mamba_cache(&self) -> Result<()> { match &self.text_model { Qwen3TextModel::Dense35(m) => m.reset_mamba_cache(), @@ -582,4 +590,79 @@ impl Qwen3VLForConditionalGeneration { _ => Ok(()), } } + + /// Forward pass that returns both logits and hidden states (for MTP drafting). + pub fn forward_with_hidden( + &self, + input_ids: &Tensor, + positions: &Tensor, + kv_caches: Option<&Vec<(Tensor, Tensor)>>, + input_metadata: &InputMetadata, + embeded_inputs: bool, + ) -> Result<(Tensor, Tensor)> { + match &self.text_model { + Qwen3TextModel::Dense35(m) => m.forward_with_hidden( + input_ids, + positions, + kv_caches, + input_metadata, + embeded_inputs, + ), + Qwen3TextModel::MoE35(m) => m.forward_with_hidden( + input_ids, + positions, + kv_caches, + input_metadata, + embeded_inputs, + ), + _ => { + candle_core::bail!("forward_with_hidden only supported for Qwen3.5 text models") + } + } + } + + /// Apply lm_head to hidden states (for MTP drafting). + pub fn forward_lm_head(&self, hidden: &Tensor) -> Result { + match &self.text_model { + Qwen3TextModel::Dense35(m) => m.forward_lm_head(hidden), + Qwen3TextModel::MoE35(m) => m.forward_lm_head(hidden), + _ => candle_core::bail!("forward_lm_head only supported for Qwen3.5 text models"), + } + } + + /// Get token embedding for a single token (for MTP drafting). + pub fn embed_forward(&self, input_ids: &Tensor) -> Result { + match &self.text_model { + Qwen3TextModel::Dense35(m) => m.embed_forward(input_ids), + Qwen3TextModel::MoE35(m) => m.embed_forward(input_ids), + Qwen3TextModel::Dense(m) => m.embed_forward(input_ids), + Qwen3TextModel::MoE(m) => m.embed_forward(input_ids), + } + } + + pub fn embed_weight(&self) -> Option<&candle_core::Tensor> { + match &self.text_model { + Qwen3TextModel::Dense35(m) => Some(m.embed_weight()), + Qwen3TextModel::MoE35(m) => Some(m.embed_weight()), + _ => None, + } + } + + /// Take the cached last hidden state for MTP + pub fn take_last_hidden_for_mtp(&self) -> Option { + match &self.text_model { + Qwen3TextModel::Dense35(m) => m.take_last_hidden_for_mtp(), + Qwen3TextModel::MoE35(m) => m.take_last_hidden_for_mtp(), + _ => None, + } + } + + /// Pre-allocate the MTP hidden state buffer + pub fn preallocate_mtp_hidden_buffer(&self, max_batch_size: usize) -> Result<()> { + match &self.text_model { + Qwen3TextModel::Dense35(m) => m.preallocate_mtp_hidden_buffer(max_batch_size), + Qwen3TextModel::MoE35(m) => m.preallocate_mtp_hidden_buffer(max_batch_size), + _ => Ok(()), + } + } } diff --git a/src/runner/mod.rs b/src/runner/mod.rs index c49886a2..98a109f5 100644 --- a/src/runner/mod.rs +++ b/src/runner/mod.rs @@ -181,9 +181,15 @@ pub enum MessageType { /// Sent by main process to request inference on sequences. RunDecode((Vec, bool)), + /// Sent by main process to request MTP speculative decode on sequences. + RunDecodeMTP(Vec), + /// Sent by runner in response to `Run` with generated token IDs. RunResponse(Vec), + /// Sent by runner in response to `RunDecodeMTP` with multiple tokens per sequence. + RunResponseMTP(Vec>), + /// Sent by main process to request embedding on sequences. RunEmbed((Vec, EmbeddingStrategy)), @@ -816,6 +822,26 @@ pub fn run_runner_process(args: Vec) -> anyhow::Result<()> { false, )?; } + Ok(MessageType::RunDecodeMTP(sequences)) => { + let outputs = runner.run_mtp_decode(Seqs::DecodeVec(&sequences)); + match outputs { + Ok(multi_tokens) => { + send_local( + &mut vec![stream.try_clone()?], + &MessageType::RunResponseMTP(multi_tokens), + false, + )?; + } + Err(e) => { + crate::log_error!("Runner MTP decode error: {:?}", e); + send_local( + &mut vec![stream.try_clone()?], + &MessageType::RunResponseMTP(vec![]), + false, + )?; + } + } + } Ok(MessageType::ClearBlocks(block_ids)) => { let ret = runner.clear_blocks(block_ids); if ret.is_err() { diff --git a/src/runner/runner.rs b/src/runner/runner.rs index 44ba4945..7f59d7e7 100644 --- a/src/runner/runner.rs +++ b/src/runner/runner.rs @@ -281,6 +281,26 @@ pub fn run_runner() -> anyhow::Result<()> { false, )?; } + Ok(MessageType::RunDecodeMTP(sequences)) => { + let outputs = runner.run_mtp_decode(Seqs::DecodeVec(&sequences)); + match outputs { + Ok(multi_tokens) => { + send_local( + &mut vec![stream.try_clone()?], + &MessageType::RunResponseMTP(multi_tokens), + false, + )?; + } + Err(e) => { + xinfer::log_error!("Runner MTP decode error: {:?}", e); + send_local( + &mut vec![stream.try_clone()?], + &MessageType::RunResponseMTP(vec![]), + false, + )?; + } + } + } Ok(MessageType::RunEmbed((sequences, strategy))) => { use xinfer::core::sequence::Sequence; let refs: Vec<&Sequence> = sequences.iter().collect(); diff --git a/src/server/mod.rs b/src/server/mod.rs index 7b66b5f1..7a8f7b8f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1007,6 +1007,12 @@ pub struct Args { /// Metal uses half of this value after rounding. #[arg(long, default_value_t = crate::utils::config::DEFAULT_PREFILL_CHUNK_SIZE)] pub prefill_chunk_size: usize, + + /// Enable MTP (Multi-Token Prediction) speculative decoding. + /// Specifies the number of speculative draft tokens per step (e.g. 3-7). + /// The model must have MTP heads (e.g. Qwen3.5, DeepSeek-V3). + #[arg(long, default_value = None)] + pub mtp: Option, } /// Result of executing tool calls via MCP diff --git a/src/utils/config.rs b/src/utils/config.rs index 6ae8749d..32d1d28e 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -303,6 +303,10 @@ pub struct Config { pub extra_config_json: Option, #[serde(default)] pub is_f16_mode: bool, + #[serde(default)] + pub mtp_num_hidden_layers: Option, + #[serde(default)] + pub mtp_use_dedicated_embeddings: Option, } impl Config { @@ -397,6 +401,10 @@ pub struct EngineConfig { pub disable_cuda_graph: bool, #[serde(default = "default_prefill_chunk_size")] pub prefill_chunk_size: usize, + /// MTP (Multi-Token Prediction) speculative decoding: number of draft tokens per step. + /// None means MTP is disabled. + #[serde(default)] + pub mtp_num_speculative_tokens: Option, } #[cfg(feature = "python")] @@ -483,6 +491,8 @@ pub struct EngineConfig { #[pyo3(get, set)] #[serde(default = "default_prefill_chunk_size")] pub prefill_chunk_size: usize, + #[serde(default)] + pub mtp_num_speculative_tokens: Option, } impl EngineConfig { @@ -587,8 +597,14 @@ impl EngineConfig { prefill_chunk_size: normalize_prefill_chunk_size( prefill_chunk_size.unwrap_or(DEFAULT_PREFILL_CHUNK_SIZE), ), + mtp_num_speculative_tokens: None, } } + + pub fn with_mtp(mut self, mtp_tokens: Option) -> Self { + self.mtp_num_speculative_tokens = mtp_tokens; + self + } } #[derive(Clone, Debug, serde::Deserialize)] diff --git a/src/utils/graph.rs b/src/utils/graph.rs index 79759e29..53d4a391 100644 --- a/src/utils/graph.rs +++ b/src/utils/graph.rs @@ -374,6 +374,24 @@ pub struct GraphCaptureVars { pub outputs: BTreeMap, } +pub struct MtpGraphCaptureVars { + pub input_ids: Tensor, + pub positions: Tensor, + pub mamba_slot_mapping: Tensor, + pub slot_mapping: Tensor, + pub context_lens: Tensor, + pub block_tables: Tensor, + pub cu_seqlens_q: Tensor, + pub cu_seqlens_k: Tensor, + #[cfg(feature = "flashinfer")] + pub flashinfer_indptr: Tensor, + #[cfg(feature = "flashinfer")] + pub flashinfer_indices: Tensor, + #[cfg(feature = "flashinfer")] + pub flashinfer_last_len: Tensor, + pub outputs: BTreeMap, +} + pub struct GraphCapturer { pub model: M, pub graph_bs: Vec, @@ -386,6 +404,7 @@ pub struct GraphCapturer { #[cfg(feature = "flashinfer")] pub flashinfer_kv_params: Option, pub is_mla: bool, + pub mtp_graph_vars: Option, } pub fn planned_graph_capture_batches(max_num_seqs: usize) -> Vec { @@ -463,6 +482,7 @@ impl GraphCapturer { #[cfg(feature = "flashinfer")] flashinfer_kv_params: flashinfer_kv_params.clone(), is_mla, + mtp_graph_vars: None, } } @@ -599,6 +619,7 @@ impl GraphCapturer { max_context_len: self.max_model_len, seqlens: None, flashinfer_metadata, + is_mtp_verify: false, }; let should_capture = @@ -793,6 +814,261 @@ impl GraphCapturer { candle_core::bail!("Graph is not captured!") } } + + pub fn capture_mtp( + &mut self, + device: &Device, + kv_caches: Option<&Vec<(Tensor, Tensor)>>, + mtp_num_speculative: usize, + ) -> Result<()> { + if mtp_num_speculative == 0 { + return Ok(()); + } + + self.device = Some(device.clone()); + let verify_len = mtp_num_speculative + 1; + let max_num_blocks = (self.max_model_len + self.block_size - 1) / self.block_size; + + let input_ids = Tensor::zeros((verify_len,), DType::U32, device)?; + let positions = Tensor::zeros((verify_len,), DType::I64, device)?; + let mamba_slot_mapping = Tensor::zeros((1,), DType::I64, device)?; + let slot_mapping = Tensor::zeros((verify_len,), DType::I64, device)?; + let context_lens = Tensor::zeros((1,), DType::U32, device)?; + let block_tables = Tensor::zeros((1, max_num_blocks), DType::U32, device)?; + let cu_seqlens_q = Tensor::zeros((2,), DType::U32, device)?; + let cu_seqlens_k = Tensor::zeros((2,), DType::U32, device)?; + + #[cfg(feature = "flashinfer")] + let flashinfer_indptr = Tensor::zeros((2,), DType::U32, device)?; + #[cfg(feature = "flashinfer")] + let flashinfer_indices = Tensor::zeros((max_num_blocks,), DType::U32, device)?; + #[cfg(feature = "flashinfer")] + let flashinfer_last_len = Tensor::zeros((1,), DType::U32, device)?; + + #[cfg(feature = "flashinfer")] + let use_flashinfer = self.flashinfer_kv_params.is_some(); + #[cfg(not(feature = "flashinfer"))] + let use_flashinfer = false; + + let capture_in_warmup = use_flashinfer; + + #[cfg(feature = "flashinfer")] + let flashinfer_metadata = if let Some(params) = self.flashinfer_kv_params { + let indptr_host = vec![0u32, max_num_blocks as u32]; + let kv_len_arr_host = vec![self.max_model_len as u32]; + let q_cu_seqlens_host = vec![0u32, verify_len as u32]; + + let prefill_plan_info = attention_rs::flashinfer::graph_prefill_plan( + device, + &q_cu_seqlens_host, + &indptr_host, + &kv_len_arr_host, + verify_len as u32, + 1, + params.num_qo_heads, + params.num_kv_heads, + params.head_dim, + params.page_size, + params.out_dtype, + None, + Some(params.kv_dtype), + )?; + + Some(attention_rs::FlashInferMetadata { + indptr: flashinfer_indptr.clone(), + indptr_host, + indices: flashinfer_indices.clone(), + last_len: flashinfer_last_len.clone(), + last_len_host: Some(vec![self.max_model_len as u32]), + kv_len_arr_host: Some(kv_len_arr_host), + total_num_rows: Some(verify_len as u32), + batch_indices: None, + positions: None, + use_cuda_graph: true, + decode_plan_info: None, + prefill_plan_info: Some(prefill_plan_info), + mla_decode_plan_info: None, + mla_prefill_plan_info: None, + }) + } else { + None + }; + #[cfg(not(feature = "flashinfer"))] + let flashinfer_metadata = None; + + let input_metadata = InputMetadata { + is_prefill: true, + is_mla: self.is_mla, + sequence_ids: Some(vec![0]), + mamba_slot_mapping: Some(mamba_slot_mapping.clone()), + slot_mapping: slot_mapping.clone(), + block_tables: Some(block_tables.clone()), + context_lens: Some(context_lens.clone()), + cu_seqlens_q: Some(cu_seqlens_q.clone()), + cu_seqlens_k: Some(cu_seqlens_k.clone()), + max_seqlen_q: verify_len, + max_seqlen_k: self.max_model_len, + max_context_len: self.max_model_len, + seqlens: None, + flashinfer_metadata, + is_mtp_verify: true, + }; + + let mut outputs = BTreeMap::::new(); + let _guard = candle_core::cuda_backend::cuda_param_cache_scope(true); + + for is_warmup in [true, false] { + if !is_warmup || capture_in_warmup { + self.model.start_capture(verify_len)?; + } + if is_warmup { + let _ = self.model.forward( + &input_ids, + &positions, + kv_caches, + &input_metadata, + false, + )?; + } else { + let out = self.model.forward( + &input_ids, + &positions, + kv_caches, + &input_metadata, + false, + )?; + outputs.insert(verify_len, out); + } + if !is_warmup || capture_in_warmup { + self.model.end_capture(!is_warmup)?; + } + } + + crate::log_warn!( + "Captured MTP verify graph len={} (flashinfer={})", + verify_len, + use_flashinfer + ); + + self.mtp_graph_vars = Some(MtpGraphCaptureVars { + input_ids, + positions, + mamba_slot_mapping, + slot_mapping, + context_lens, + block_tables, + cu_seqlens_q, + cu_seqlens_k, + #[cfg(feature = "flashinfer")] + flashinfer_indptr, + #[cfg(feature = "flashinfer")] + flashinfer_indices, + #[cfg(feature = "flashinfer")] + flashinfer_last_len, + outputs, + }); + Ok(()) + } + + pub fn is_mtp_captured(&self, verify_len: usize) -> bool { + self.mtp_graph_vars + .as_ref() + .map_or(false, |v| v.outputs.contains_key(&verify_len)) + } + + pub fn replay_mtp( + &self, + input_ids: &Tensor, + positions: &Tensor, + input_metadata: &InputMetadata, + ) -> Result { + let verify_len = input_ids.dim(0)?; + let max_num_blocks = (self.max_model_len + self.block_size - 1) / self.block_size; + + let mtp_vars = self + .mtp_graph_vars + .as_ref() + .ok_or_else(|| candle_core::Error::msg("MTP graphs not captured"))?; + + if !mtp_vars.outputs.contains_key(&verify_len) { + candle_core::bail!("MTP verify graph for len {} is not captured!", verify_len); + } + + mtp_vars.input_ids.zero_()?; + mtp_vars.input_ids.copy_(input_ids, 0)?; + mtp_vars.positions.zero_()?; + mtp_vars.positions.copy_(positions, 0)?; + + if let Some(ms_mapping) = input_metadata.mamba_slot_mapping.as_ref() { + mtp_vars.mamba_slot_mapping.zero_()?; + mtp_vars.mamba_slot_mapping.copy_(ms_mapping, 0)?; + } + + mtp_vars.slot_mapping.zero_()?; + mtp_vars + .slot_mapping + .copy_(&input_metadata.slot_mapping, 0)?; + + if let Some(c_lens) = input_metadata.context_lens.as_ref() { + mtp_vars.context_lens.zero_()?; + mtp_vars.context_lens.copy_(c_lens, 0)?; + } + + if let Some(b_tables) = input_metadata.block_tables.as_ref() { + let padded_table = b_tables + .pad_with_zeros(1, 0, max_num_blocks - b_tables.dim(1)?)? + .contiguous()?; + mtp_vars.block_tables.zero_()?; + mtp_vars.block_tables.copy_(&padded_table, 0)?; + } + + if let Some(cu_q) = input_metadata.cu_seqlens_q.as_ref() { + mtp_vars.cu_seqlens_q.copy_(cu_q, 0)?; + } + if let Some(cu_k) = input_metadata.cu_seqlens_k.as_ref() { + mtp_vars.cu_seqlens_k.copy_(cu_k, 0)?; + } + + #[cfg(feature = "flashinfer")] + if let Some(fm) = input_metadata.flashinfer_metadata.as_ref() { + mtp_vars.flashinfer_indptr.zero_()?; + mtp_vars.flashinfer_indptr.copy_(&fm.indptr, 0)?; + mtp_vars.flashinfer_indices.zero_()?; + mtp_vars.flashinfer_indices.copy_(&fm.indices, 0)?; + mtp_vars.flashinfer_last_len.zero_()?; + mtp_vars.flashinfer_last_len.copy_(&fm.last_len, 0)?; + + if let Some(params) = self.flashinfer_kv_params { + let dev = self + .device + .as_ref() + .ok_or_else(|| candle_core::Error::msg("graph device is missing"))?; + let kv_len_arr_host = fm.kv_len_arr_host.as_deref().ok_or_else(|| { + candle_core::Error::msg("mtp replay requires kv_len_arr_host") + })?; + let q_cu_seqlens_host = vec![0u32, verify_len as u32]; + let _ = attention_rs::flashinfer::graph_prefill_plan( + dev, + &q_cu_seqlens_host, + &fm.indptr_host, + kv_len_arr_host, + verify_len as u32, + 1, + params.num_qo_heads, + params.num_kv_heads, + params.head_dim, + params.page_size, + params.out_dtype, + None, + Some(params.kv_dtype), + )?; + } + } + + self.model.replay(verify_len)?; + + mtp_vars.outputs[&verify_len].contiguous() + } } unsafe impl Send for CudaGraph {} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index fd30bb1c..802a760d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -642,6 +642,8 @@ pub fn config_from_gguf( is_multi_model: None, extra_config_json, is_f16_mode: false, + mtp_num_hidden_layers: None, + mtp_use_dedicated_embeddings: None, }; if arch == "gemma4" || arch == "gemma3" {