From c59f88979fe4f9b9eb6306703483eed5494aa2ec Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 1 Apr 2026 23:34:12 +0800 Subject: [PATCH 01/14] Support mxfp4 models --- Cargo.toml | 2 +- src/models/layers/linear.rs | 73 ++++++++++++++ src/models/layers/moe.rs | 185 ++++++++++++++++++++++++++++++++++++ src/models/qwen3_5_moe.rs | 13 ++- src/models/qwen3_moe.rs | 13 ++- src/utils/mod.rs | 5 +- 6 files changed, 284 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 41b9e892..48bf3a7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "7dcf5bd" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "38a4400" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" diff --git a/src/models/layers/linear.rs b/src/models/layers/linear.rs index 69846388..d9ebc010 100644 --- a/src/models/layers/linear.rs +++ b/src/models/layers/linear.rs @@ -432,6 +432,7 @@ pub enum LinearX { Linear(Linear), QLinear(QLinear), LnFp8(LnFp8), + LnMxfp4(LnMxfp4), } impl Module for LinearX { @@ -439,6 +440,7 @@ impl Module for LinearX { match self { Self::Linear(ln) => ln.forward(x), Self::QLinear(ln) => ln.forward(x), + Self::LnMxfp4(ln) => ln.forward(x), Self::LnFp8(ln) => ln.forward(x), } } @@ -452,6 +454,7 @@ impl LinearX { } Self::QLinear(ln) => ln.indexed_moe_forward(x, ids), Self::LnFp8(_) => panic!("LnFp8 does not support indexed_moe_forward yet"), + Self::LnMxfp4(_) => panic!("LnMxfp4 does not support indexed_moe_forward yet"), } } } @@ -478,6 +481,7 @@ impl LinearX { } Self::QLinear(ln) => ln.dequantize(), Self::LnFp8(_) => panic!("LnFp8 unable to be dequantized"), + Self::LnMxfp4(_) => panic!("LnMxfp4 unable to be dequantized"), } } } @@ -520,6 +524,11 @@ pub fn 
linear_x( } } + if cfg.quant_method == "mxfp4" { + let ln = LnMxfp4::load(in_dim, out_dim, vb.clone(), shards, true)?; + return Ok(LinearX::LnMxfp4(ln)); + } + let wna16 = WNA16::new( in_dim, out_dim, @@ -598,6 +607,11 @@ pub fn linear_no_bias_x( } } + if cfg.quant_method == "mxfp4" { + let ln = LnMxfp4::load(in_dim, out_dim, vb.clone(), shards, false)?; + return Ok(LinearX::LnMxfp4(ln)); + } + let wna16 = WNA16::new( in_dim, out_dim, @@ -1067,3 +1081,62 @@ impl Module for LnFp8 { } } } + +/// MXFP4 linear layer: packed FP4 E2M1 weights with E8M0 block scales. +#[derive(Debug, Clone)] +pub struct LnMxfp4 { + pub blocks: Tensor, + pub scales: Tensor, + pub bias: Option, +} + +impl LnMxfp4 { + pub fn load( + in_dim: usize, + out_dim: usize, + vb: VarBuilder, + shard: Shard, + load_bias: bool, + ) -> Result { + let blocks = vb.get_with_hints_dtype((out_dim, in_dim / 2), "blocks", shard, DType::U8)?; + let scales = vb.get_with_hints_dtype((out_dim, in_dim / 32), "scales", shard, DType::U8)?; + let bias = if load_bias { + Some(vb.get((out_dim,), "bias")?) + } else { + None + }; + Ok(Self { + blocks, + scales, + bias, + }) + } +} + +impl Module for LnMxfp4 { + fn forward(&self, x: &Tensor) -> Result { + let orig_dims = x.dims().to_vec(); + let x_2d = if orig_dims.len() > 2 { + let features = orig_dims[orig_dims.len() - 1]; + let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product(); + x.reshape((batch_size, features))? 
+ } else { + x.clone() + }; + + let result = attention_rs::mxfp4_linear::mxfp4_matmul( + &x_2d, + &self.blocks, + &self.scales, + self.bias.as_ref(), + )?; + + if orig_dims.len() > 2 { + let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec(); + new_dims.push(result.dim(1)?); + result.reshape(new_dims) + } else { + Ok(result) + } + } +} diff --git a/src/models/layers/moe.rs b/src/models/layers/moe.rs index 30f92d00..4adadafc 100644 --- a/src/models/layers/moe.rs +++ b/src/models/layers/moe.rs @@ -6,6 +6,7 @@ use crate::utils::config::Config; use crate::utils::config::QuantConfig; use attention_rs::moe; use attention_rs::moe::moe_gemm_fp8; +use attention_rs::mxfp4_linear; use candle_core::Module; use candle_core::{ quantized::{GgmlDType, QTensor}, @@ -1262,3 +1263,187 @@ impl FusedMoeFp8 { Ok(ys.to_dtype(self.dtype)?) } } + +pub struct FusedMoeMxfp4 { + gate: Linear, + gate_up_blocks: Tensor, + gate_up_scales: Tensor, + down_blocks: Tensor, + down_scales: Tensor, + w_size_n: usize, + act: candle_nn::Activation, + norm_topk_prob: bool, + routed_scaling_factor: Option, + num_experts_per_tok: usize, + all_reduce: AllReduce, + world_size: usize, + dtype: DType, +} + +impl FusedMoeMxfp4 { + pub fn new(cfg: &Config, vb: VarBuilderX, comm: Rc, dtype: DType) -> Result { + let moe_cfg = cfg.moe_cfg.as_ref().expect("MoE config is not available!"); + let num_experts = moe_cfg.num_experts.unwrap(); + + let gate = linear_no_bias( + cfg.hidden_size, + num_experts, + vb.pp("gate"), + Shard::default(), + &None, + &None, + dtype, + )?; + + let experts_vb = vb.pp("experts"); + + let mut gate_blocks_vec = Vec::new(); + let mut gate_scales_vec = Vec::new(); + let mut up_blocks_vec = Vec::new(); + let mut up_scales_vec = Vec::new(); + let mut down_blocks_vec = Vec::new(); + let mut down_scales_vec = Vec::new(); + + match &experts_vb.0 { + Either::Left(vb) => { + for i in 0..num_experts { + let expert_vb = vb.pp(i.to_string()); + + let gate_b = 
expert_vb.pp("gate_proj").get_with_hints_dtype( + (moe_cfg.moe_intermediate_size, cfg.hidden_size / 2), + "blocks", + shard(0, comm.rank(), comm.world_size()), + DType::U8, + )?; + let gate_s = expert_vb.pp("gate_proj").get_with_hints_dtype( + (moe_cfg.moe_intermediate_size, cfg.hidden_size / 32), + "scales", + shard(0, comm.rank(), comm.world_size()), + DType::U8, + )?; + + let up_b = expert_vb.pp("up_proj").get_with_hints_dtype( + (moe_cfg.moe_intermediate_size, cfg.hidden_size / 2), + "blocks", + shard(0, comm.rank(), comm.world_size()), + DType::U8, + )?; + let up_s = expert_vb.pp("up_proj").get_with_hints_dtype( + (moe_cfg.moe_intermediate_size, cfg.hidden_size / 32), + "scales", + shard(0, comm.rank(), comm.world_size()), + DType::U8, + )?; + + let down_b = expert_vb.pp("down_proj").get_with_hints_dtype( + (cfg.hidden_size, moe_cfg.moe_intermediate_size / 2), + "blocks", + shard(1, comm.rank(), comm.world_size()), + DType::U8, + )?; + let down_s = expert_vb.pp("down_proj").get_with_hints_dtype( + (cfg.hidden_size, moe_cfg.moe_intermediate_size / 32), + "scales", + shard(1, comm.rank(), comm.world_size()), + DType::U8, + )?; + + gate_blocks_vec.push(gate_b); + gate_scales_vec.push(gate_s); + up_blocks_vec.push(up_b); + up_scales_vec.push(up_s); + down_blocks_vec.push(down_b); + down_scales_vec.push(down_s); + } + } + _ => candle_core::bail!("FusedMoeMxfp4: GGUF loading not supported for MXFP4"), + } + + let gate_blocks = Tensor::stack(&gate_blocks_vec, 0)?; + let gate_scales = Tensor::stack(&gate_scales_vec, 0)?; + let up_blocks = Tensor::stack(&up_blocks_vec, 0)?; + let up_scales = Tensor::stack(&up_scales_vec, 0)?; + + let gate_up_blocks = Tensor::cat(&[&gate_blocks, &up_blocks], 1)?; + let gate_up_scales = Tensor::cat(&[&gate_scales, &up_scales], 1)?; + let w_size_n = gate_up_blocks.dim(1)? 
/ 2; + + let down_blocks = Tensor::stack(&down_blocks_vec, 0)?; + let down_scales = Tensor::stack(&down_scales_vec, 0)?; + + Ok(Self { + gate, + gate_up_blocks, + gate_up_scales, + down_blocks, + down_scales, + w_size_n, + act: candle_nn::Activation::Silu, + norm_topk_prob: moe_cfg.norm_topk_prob, + routed_scaling_factor: moe_cfg.routed_scaling_factor, + num_experts_per_tok: moe_cfg.num_experts_per_tok, + all_reduce: AllReduce::new(comm.clone()), + world_size: comm.world_size(), + dtype, + }) + } + + pub fn forward(&self, xs: &Tensor, _is_prefill: bool) -> Result { + let (num_tokens, hidden_dim) = xs.dims2()?; + let router_logits = self.gate.forward(xs)?; + + let (mut topk_weights, topk_ids) = attention_rs::topk::topk_softmax( + &router_logits.to_dtype(DType::F32)?, + self.num_experts_per_tok, + )?; + + if self.norm_topk_prob { + topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?; + } + + if let Some(routed_scaling_factor) = self.routed_scaling_factor { + topk_weights = (topk_weights * routed_scaling_factor)?; + } + + let xs = if xs.dtype() == DType::F32 { + xs.to_dtype(self.dtype)? + } else { + xs.clone() + }; + + let gate_up = mxfp4_linear::mxfp4_moe_gemm( + &xs, + &self.gate_up_blocks, + &self.gate_up_scales, + None, + &topk_ids, + )?; + + let gate = gate_up + .narrow(candle_core::D::Minus1, 0, self.w_size_n)? + .contiguous()?; + let up = gate_up + .narrow(candle_core::D::Minus1, self.w_size_n, self.w_size_n)? + .contiguous()?; + let down_inputs = (up * gate.apply(&self.act)?)?; + + let down_inputs_2d = down_inputs.reshape((num_tokens * self.num_experts_per_tok, ()))?; + + let down = mxfp4_linear::mxfp4_moe_gemm( + &down_inputs_2d, + &self.down_blocks, + &self.down_scales, + None, + &topk_ids, + )?; + + let mut ys = (down * topk_weights.unsqueeze(D::Minus1)?)? + .reshape((num_tokens, self.num_experts_per_tok, hidden_dim))? 
+ .sum(1)?; + + if self.world_size > 1 { + ys = self.all_reduce.apply(&ys)?; + } + Ok(ys.to_dtype(self.dtype)?) + } +} diff --git a/src/models/qwen3_5_moe.rs b/src/models/qwen3_5_moe.rs index 94ba618e..572a469a 100644 --- a/src/models/qwen3_5_moe.rs +++ b/src/models/qwen3_5_moe.rs @@ -6,7 +6,7 @@ use crate::models::layers::distributed::{Comm, ReplicatedLinear}; use crate::models::layers::linear::LinearX as Linear; use crate::models::layers::mask::get_attention_causal_mask; use crate::models::layers::mlp::MLP; -use crate::models::layers::moe::{FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ}; +use crate::models::layers::moe::{FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ, FusedMoeMxfp4}; use crate::models::layers::others::{embedding, rms_norm, NormX}; use crate::models::layers::rotary_emb::{ApplyRotaryEmbedding, ScalingRotaryEmbedding}; use crate::models::layers::VarBuilderX; @@ -32,6 +32,7 @@ enum MoeOrMlp { FusedMoeGGUF(FusedMoeGGUF), FusedMoeISQ(FusedMoeISQ), FusedMoeFp8(FusedMoeFp8), + FusedMoeMxfp4(FusedMoeMxfp4), } impl MoeOrMlp { @@ -41,6 +42,7 @@ impl MoeOrMlp { Self::FusedMoeGGUF(m) => m.forward(xs, is_prefill), Self::FusedMoeISQ(m) => m.forward(xs, is_prefill), Self::FusedMoeFp8(m) => m.forward(xs, is_prefill), + Self::FusedMoeMxfp4(m) => m.forward(xs, is_prefill), } } } @@ -126,8 +128,15 @@ impl Qwen3_5MoEDecoderLayer { dtype, quant_config, )?) + } else if quant_config.quant_method == "mxfp4" { + MoeOrMlp::FusedMoeMxfp4(FusedMoeMxfp4::new( + config, + vb.pp("mlp").clone(), + comm.clone(), + dtype, + )?) 
} else { - panic!("Unsupported quant method for MoE (use unquantized, gguf or fp8)!"); + panic!("Unsupported quant method for MoE (use unquantized, gguf, fp8 or mxfp4)!"); } } else if config.quant.is_some() { MoeOrMlp::FusedMoeISQ(FusedMoeISQ::new( diff --git a/src/models/qwen3_moe.rs b/src/models/qwen3_moe.rs index 52e30187..3d59ff4d 100644 --- a/src/models/qwen3_moe.rs +++ b/src/models/qwen3_moe.rs @@ -4,7 +4,7 @@ use crate::models::layers::distributed::{Comm, ReplicatedLinear}; use crate::models::layers::linear::LinearX as Linear; use crate::models::layers::mask::get_attention_causal_mask; use crate::models::layers::mlp::MLP; -use crate::models::layers::moe::{FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ}; +use crate::models::layers::moe::{FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ, FusedMoeMxfp4}; use crate::models::layers::others::{embedding, rms_norm, NormX}; use crate::models::layers::rotary_emb::{ApplyRotaryEmbedding, ScalingRotaryEmbedding}; use crate::models::layers::VarBuilderX; @@ -25,6 +25,7 @@ enum MoeOrMlp { FusedMoeGGUF(FusedMoeGGUF), FusedMoeISQ(FusedMoeISQ), FusedMoeFp8(FusedMoeFp8), + FusedMoeMxfp4(FusedMoeMxfp4), Mlp(MLP), } @@ -36,6 +37,7 @@ impl MoeOrMlp { Self::FusedMoeGGUF(m) => m.forward(xs, is_prefill), Self::FusedMoeISQ(m) => m.forward(xs, is_prefill), Self::FusedMoeFp8(m) => m.forward(xs, is_prefill), + Self::FusedMoeMxfp4(m) => m.forward(xs, is_prefill), } } } @@ -98,8 +100,15 @@ impl Qwen3DecoderLayer { dtype, quant_config, )?) + } else if quant_config.quant_method == "mxfp4" { + MoeOrMlp::FusedMoeMxfp4(FusedMoeMxfp4::new( + config, + vb.pp("mlp").clone(), + comm.clone(), + dtype, + )?) 
} else { - panic!("This feature is under developement (use unquantized, gguf or isq to gguf instead)!"); + panic!("This feature is under developement (use unquantized, gguf, fp8 or mxfp4 instead)!"); } } else if config.quant.is_some() { MoeOrMlp::FusedMoeISQ(FusedMoeISQ::new( diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 8ab2c47c..c722e372 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -996,8 +996,9 @@ pub fn init_config_tokenizer( assert!( qcfg.quant_method == "gptq" || qcfg.quant_method == "awq" - || qcfg.quant_method == "fp8", - "Invalid quantization format! Only `gptq`, `awq` and `fp8` supported" + || qcfg.quant_method == "fp8" + || qcfg.quant_method == "mxfp4", + "Invalid quantization format! Only `gptq`, `awq`, `fp8` and `mxfp4` supported" ); if qcfg.quant_method == "gptq" || qcfg.quant_method == "awq" { assert!( From 36da24a365734718d737160d77c67052f7df8fac Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 2 Apr 2026 09:37:24 +0000 Subject: [PATCH 02/14] Working on Hopper --- Cargo.toml | 2 +- src/models/layers/linear.rs | 23 ++++++++++-- src/models/layers/moe.rs | 62 +++++++++++++++++++++++--------- src/utils/config.rs | 70 +++++++++++++++++++++++++++++++++++++ src/utils/mod.rs | 18 ++++++---- 5 files changed, 148 insertions(+), 27 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 48bf3a7d..095ac17d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "38a4400" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "be94d47" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" diff --git a/src/models/layers/linear.rs b/src/models/layers/linear.rs index d9ebc010..7de8aa9e 100644 --- a/src/models/layers/linear.rs +++ 
b/src/models/layers/linear.rs @@ -3,6 +3,7 @@ use super::wna16::WNA16; use crate::models::layers::VarBuilderX; use crate::utils::config::QuantConfig; use crate::utils::should_skip_fp8_for_module; +use crate::utils::should_skip_quant_for_module; use candle_core::quantized; use candle_core::quantized::GgmlDType; use candle_core::Module; @@ -525,6 +526,10 @@ pub fn linear_x( } if cfg.quant_method == "mxfp4" { + if should_skip_quant_for_module(&module_path, cfg) { + let ln = linear(in_dim, out_dim, vb.clone(), shards, dtype)?; + return Ok(LinearX::Linear(ln)); + } let ln = LnMxfp4::load(in_dim, out_dim, vb.clone(), shards, true)?; return Ok(LinearX::LnMxfp4(ln)); } @@ -608,6 +613,10 @@ pub fn linear_no_bias_x( } if cfg.quant_method == "mxfp4" { + if should_skip_quant_for_module(&module_path, cfg) { + let ln = linear_no_bias(in_dim, out_dim, vb.clone(), shards, dtype)?; + return Ok(LinearX::Linear(ln)); + } let ln = LnMxfp4::load(in_dim, out_dim, vb.clone(), shards, false)?; return Ok(LinearX::LnMxfp4(ln)); } @@ -1098,9 +1107,17 @@ impl LnMxfp4 { shard: Shard, load_bias: bool, ) -> Result { - let blocks = vb.get_with_hints_dtype((out_dim, in_dim / 2), "blocks", shard, DType::U8)?; - let scales = vb.get_with_hints_dtype((out_dim, in_dim / 32), "scales", shard, DType::U8)?; - let bias = if load_bias { + let blocks = if vb.contains_tensor("weight_packed") { + vb.get_with_hints_dtype((out_dim, in_dim / 2), "weight_packed", shard, DType::U8)? + } else { + vb.get_with_hints_dtype((out_dim, in_dim / 2), "blocks", shard, DType::U8)? + }; + let scales = if vb.contains_tensor("weight_scale") { + vb.get_with_hints_dtype((out_dim, in_dim / 32), "weight_scale", shard, DType::U8)? + } else { + vb.get_with_hints_dtype((out_dim, in_dim / 32), "scales", shard, DType::U8)? + }; + let bias = if load_bias && vb.contains_tensor("bias") { Some(vb.get((out_dim,), "bias")?) 
} else { None diff --git a/src/models/layers/moe.rs b/src/models/layers/moe.rs index 4adadafc..14ad31cf 100644 --- a/src/models/layers/moe.rs +++ b/src/models/layers/moe.rs @@ -1281,6 +1281,22 @@ pub struct FusedMoeMxfp4 { } impl FusedMoeMxfp4 { + fn mxfp4_tensor_name_packed(vb: &candle_nn::var_builder::ShardedVarBuilder) -> &'static str { + if vb.contains_tensor("weight_packed") { + "weight_packed" + } else { + "blocks" + } + } + + fn mxfp4_tensor_name_scale(vb: &candle_nn::var_builder::ShardedVarBuilder) -> &'static str { + if vb.contains_tensor("weight_scale") { + "weight_scale" + } else { + "scales" + } + } + pub fn new(cfg: &Config, vb: VarBuilderX, comm: Rc, dtype: DType) -> Result { let moe_cfg = cfg.moe_cfg.as_ref().expect("MoE config is not available!"); let num_experts = moe_cfg.num_experts.unwrap(); @@ -1290,7 +1306,7 @@ impl FusedMoeMxfp4 { num_experts, vb.pp("gate"), Shard::default(), - &None, + &cfg.quantization_config, &None, dtype, )?; @@ -1309,41 +1325,53 @@ impl FusedMoeMxfp4 { for i in 0..num_experts { let expert_vb = vb.pp(i.to_string()); - let gate_b = expert_vb.pp("gate_proj").get_with_hints_dtype( + let gate_proj_vb = expert_vb.pp("gate_proj"); + let packed_name = Self::mxfp4_tensor_name_packed(&gate_proj_vb); + let scale_name = Self::mxfp4_tensor_name_scale(&gate_proj_vb); + + let gate_b = gate_proj_vb.get_with_hints_dtype( (moe_cfg.moe_intermediate_size, cfg.hidden_size / 2), - "blocks", + packed_name, shard(0, comm.rank(), comm.world_size()), DType::U8, )?; - let gate_s = expert_vb.pp("gate_proj").get_with_hints_dtype( + let gate_s = gate_proj_vb.get_with_hints_dtype( (moe_cfg.moe_intermediate_size, cfg.hidden_size / 32), - "scales", + scale_name, shard(0, comm.rank(), comm.world_size()), DType::U8, )?; - let up_b = expert_vb.pp("up_proj").get_with_hints_dtype( + let up_proj_vb = expert_vb.pp("up_proj"); + let packed_name = Self::mxfp4_tensor_name_packed(&up_proj_vb); + let scale_name = Self::mxfp4_tensor_name_scale(&up_proj_vb); + + let 
up_b = up_proj_vb.get_with_hints_dtype( (moe_cfg.moe_intermediate_size, cfg.hidden_size / 2), - "blocks", + packed_name, shard(0, comm.rank(), comm.world_size()), DType::U8, )?; - let up_s = expert_vb.pp("up_proj").get_with_hints_dtype( + let up_s = up_proj_vb.get_with_hints_dtype( (moe_cfg.moe_intermediate_size, cfg.hidden_size / 32), - "scales", + scale_name, shard(0, comm.rank(), comm.world_size()), DType::U8, )?; - let down_b = expert_vb.pp("down_proj").get_with_hints_dtype( + let down_proj_vb = expert_vb.pp("down_proj"); + let packed_name = Self::mxfp4_tensor_name_packed(&down_proj_vb); + let scale_name = Self::mxfp4_tensor_name_scale(&down_proj_vb); + + let down_b = down_proj_vb.get_with_hints_dtype( (cfg.hidden_size, moe_cfg.moe_intermediate_size / 2), - "blocks", + packed_name, shard(1, comm.rank(), comm.world_size()), DType::U8, )?; - let down_s = expert_vb.pp("down_proj").get_with_hints_dtype( + let down_s = down_proj_vb.get_with_hints_dtype( (cfg.hidden_size, moe_cfg.moe_intermediate_size / 32), - "scales", + scale_name, shard(1, comm.rank(), comm.world_size()), DType::U8, )?; @@ -1427,17 +1455,17 @@ impl FusedMoeMxfp4 { .contiguous()?; let down_inputs = (up * gate.apply(&self.act)?)?; - let down_inputs_2d = down_inputs.reshape((num_tokens * self.num_experts_per_tok, ()))?; - let down = mxfp4_linear::mxfp4_moe_gemm( - &down_inputs_2d, + &down_inputs, &self.down_blocks, &self.down_scales, None, &topk_ids, )?; - let mut ys = (down * topk_weights.unsqueeze(D::Minus1)?)? + let topk_weights = topk_weights.to_dtype(down.dtype())?; + let mut ys = down + .broadcast_mul(&topk_weights.unsqueeze(D::Minus1)?)? .reshape((num_tokens, self.num_experts_per_tok, hidden_dim))? 
.sum(1)?; diff --git a/src/utils/config.rs b/src/utils/config.rs index 342627f1..ae5b6a81 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -697,9 +697,77 @@ pub struct QuantConfig { pub desc_act: Option, pub checkpoint_format: Option, pub fmt: Option, + #[serde(default)] + pub format: Option, pub weight_block_size: Option>, #[serde(default, alias = "ignore")] pub modules_to_not_convert: Vec, + #[serde(default)] + pub config_groups: Option, +} + +impl QuantConfig { + /// Normalizes a compressed-tensors config into a flat quant_method. + /// If `quant_method == "compressed-tensors"` and the `format` field (or a + /// `config_groups` entry) indicates `mxfp4-pack-quantized`, rewrites + /// `quant_method` to `"mxfp4"` and extracts `group_size` / `ignore` list. + pub fn normalize_compressed_tensors(&mut self) { + if self.quant_method != "compressed-tensors" { + return; + } + + let is_mxfp4 = self + .format + .as_deref() + .map(|f| f.contains("mxfp4")) + .unwrap_or(false) + || self.detect_mxfp4_from_config_groups(); + + if is_mxfp4 { + self.quant_method = "mxfp4".to_string(); + self.extract_compressed_tensors_params(); + } + } + + fn detect_mxfp4_from_config_groups(&self) -> bool { + let groups = match &self.config_groups { + Some(v) => v, + None => return false, + }; + if let Some(obj) = groups.as_object() { + for (_key, group) in obj { + if let Some(fmt) = group.get("format").and_then(|v| v.as_str()) { + if fmt.contains("mxfp4") { + return true; + } + } + } + } + false + } + + fn extract_compressed_tensors_params(&mut self) { + let groups = match &self.config_groups { + Some(v) => v.clone(), + None => return, + }; + if let Some(obj) = groups.as_object() { + for (_key, group) in obj { + if let Some(weights) = group.get("weights") { + if self.group_size == 0 { + if let Some(gs) = weights.get("group_size").and_then(|v| v.as_i64()) { + self.group_size = gs as i32; + } + } + if self.bits == 0 { + if let Some(nb) = weights.get("num_bits").and_then(|v| v.as_u64()) 
{ + self.bits = nb as usize; + } + } + } + } + } + } } impl fmt::Debug for QuantConfig { @@ -712,7 +780,9 @@ impl fmt::Debug for QuantConfig { .field("desc_act", &self.desc_act) .field("checkpoint_format", &self.checkpoint_format) .field("fmt", &self.fmt) + .field("format", &self.format) .field("weight_block_size", &self.weight_block_size) + .field("modules_to_not_convert", &self.modules_to_not_convert) .finish() } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index c722e372..36e9cd74 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -78,6 +78,10 @@ pub fn should_skip_fp8_for_module(module_path: &str, cfg: &QuantConfig) -> bool .any(|item| module_path_matches_not_convert(module_path, item)) } +pub fn should_skip_quant_for_module(module_path: &str, cfg: &QuantConfig) -> bool { + should_skip_fp8_for_module(module_path, cfg) +} + pub fn hub_load_local_safetensors(path: &String, json_file: &str) -> Result> { crate::log_info!("{:}", Path::new(path).join(json_file).display()); let jsfile = std::fs::File::open(Path::new(path).join(json_file))?; @@ -698,10 +702,10 @@ fn merge_multimodal_top_level_config( ) -> Result<()> { if let Some(qcfg) = raw_root.get("quantization_config") { if !qcfg.is_null() { - config.quantization_config = Some( - serde_json::from_value::(qcfg.clone()) - .map_err(candle_core::Error::wrap)?, - ); + let mut parsed = serde_json::from_value::(qcfg.clone()) + .map_err(candle_core::Error::wrap)?; + parsed.normalize_compressed_tensors(); + config.quantization_config = Some(parsed); } } @@ -992,13 +996,15 @@ pub fn init_config_tokenizer( } } - if let Some(qcfg) = &config.quantization_config { + if let Some(qcfg) = &mut config.quantization_config { + qcfg.normalize_compressed_tensors(); assert!( qcfg.quant_method == "gptq" || qcfg.quant_method == "awq" || qcfg.quant_method == "fp8" || qcfg.quant_method == "mxfp4", - "Invalid quantization format! Only `gptq`, `awq`, `fp8` and `mxfp4` supported" + "Invalid quantization format! 
Only `gptq`, `awq`, `fp8` and `mxfp4` supported, got `{}`", + qcfg.quant_method ); if qcfg.quant_method == "gptq" || qcfg.quant_method == "awq" { assert!( From f14cc948486db73401268225b631fd1a8f75b148 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 2 Apr 2026 10:43:01 +0000 Subject: [PATCH 03/14] Fail fast if server port not available --- src/main.rs | 11 +++++++++-- src/utils/mod.rs | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index 4601e45b..69e36416 100644 --- a/src/main.rs +++ b/src/main.rs @@ -214,11 +214,18 @@ async fn main() -> Result<()> { args.yarn_scaling_factor, ); - let engine = LLMEngine::new(&econfig, dtype)?; - if args.server || args.ui_server || args.pd_server { + let server_port = if args.server || args.ui_server || args.pd_server { let port = args .port .unwrap_or(if args.pd_server { 7000 } else { 8000 }); + vllm_rs::utils::ensure_port_free("0.0.0.0", port as u16); + Some(port) + } else { + None + }; + + let engine = LLMEngine::new(&econfig, dtype)?; + if let Some(port) = server_port { run_server(engine.clone(), econfig.clone(), port, args.ui_server).await?; return Ok(()); } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 36e9cd74..ce99d197 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1862,3 +1862,20 @@ mod tests { assert!(!is_rope_i); } } + +/// Fail fast if `host:port` is already bound, before spending minutes loading +/// model weights. Prints a user-friendly error and exits the process when the +/// port is occupied. 
+pub fn ensure_port_free(host: &str, port: u16) { + let addr = format!("{host}:{port}"); + match std::net::TcpListener::bind(&addr) { + Ok(_listener) => { /* port is free; drop the listener immediately */ } + Err(e) => { + eprintln!( + "\n❌ Port {port} is already in use ({e}).\n \ + Free the port or choose a different one with --port .\n" + ); + std::process::exit(1); + } + } +} From 545d7f7c38b48c4e63c864168b4a8537249f2159 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 2 Apr 2026 11:44:44 +0000 Subject: [PATCH 04/14] Optimize decoding speed --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 095ac17d..0154d540 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "be94d47" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "0beb7b3" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" From a91761ee573ec499f4146c46bf706217ff5316bc Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 2 Apr 2026 23:37:42 +0800 Subject: [PATCH 05/14] Fix build on V100 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0154d540..7ea9a3e8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "0beb7b3" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "4920116" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" From c08b4f406f9ae012d357beb1858227fccbc4f16b 
Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 3 Apr 2026 06:34:48 +0000 Subject: [PATCH 06/14] Support NVFP4 & fix mxfp4 model loading --- Cargo.toml | 2 +- src/models/layers/deltanet.rs | 32 +- src/models/layers/linear.rs | 107 +++++++ src/models/layers/moe.rs | 241 +++++++++++++++ src/models/qwen3_5_moe.rs | 15 +- src/models/qwen3_moe.rs | 15 +- src/models/qwen3_vl/mod.rs | 5 +- src/utils/config.rs | 541 +++++++++++++++++++++++++++++++++- src/utils/mod.rs | 36 +-- 9 files changed, 940 insertions(+), 54 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7ea9a3e8..a780f2a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.4.3", rev = "4920116" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.5.0", rev = "f21d557" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" diff --git a/src/models/layers/deltanet.rs b/src/models/layers/deltanet.rs index 3404b85f..8a74859f 100644 --- a/src/models/layers/deltanet.rs +++ b/src/models/layers/deltanet.rs @@ -75,9 +75,9 @@ impl GatedDeltaNet { comm: Rc, config: &Config, dtype: DType, - is_fp8_quant: bool, + is_quantized: bool, ) -> Result { - let (quantization_config, quant) = if is_fp8_quant { + let (quantization_config, quant) = if is_quantized { (config.quantization_config.clone(), config.quant.clone()) } else { (None, None) @@ -110,8 +110,8 @@ impl GatedDeltaNet { false, vb.pp("in_proj_ba"), comm.clone(), - &None, - &None, + &quantization_config, + &quant, dtype, ); @@ -171,8 +171,8 @@ impl GatedDeltaNet { false, vb.pp(projection_key_map["in_proj_b"]), comm.clone(), - &None, - &None, + &quantization_config, + &quant, dtype, ) }; @@ -193,8 +193,8 @@ impl GatedDeltaNet { false, vb.pp(projection_key_map["in_proj_a"]), comm.clone(), 
- &None, - &None, + &quantization_config, + &quant, dtype, ) }; @@ -242,9 +242,9 @@ impl GatedDeltaNet { }); } Err(err) => { - if is_fp8_quant { + if is_quantized { candle_core::bail!( - "Unable to load TP-safe FP8 Qwen3.5 split in_proj_qkv: {}", + "Unable to load TP-safe quantized Qwen3.5 split in_proj_qkv: {}", err ); } @@ -435,11 +435,7 @@ impl GatedDeltaNet { ); } - let is_fp8_quant = if let Some(cfg) = config.quantization_config.as_ref() { - cfg.quant_method == "fp8" - } else { - false - }; + let is_quantized = config.quantization_config.is_some(); let kernel_dtype = if vb.is_qvar_builder() { DType::F32 } else { @@ -507,7 +503,7 @@ impl GatedDeltaNet { comm.clone(), config, dtype, - is_fp8_quant, + is_quantized, )?; // Conv1D weights are stored global; slice rank-local q/k/v channel blocks. @@ -584,12 +580,12 @@ impl GatedDeltaNet { hidden_size, vb.pp(gdn_key_map["out_proj"]), comm.clone(), - if is_fp8_quant { + if is_quantized { &config.quantization_config } else { &None }, - if is_fp8_quant { &config.quant } else { &None }, + if is_quantized { &config.quant } else { &None }, dtype, )? 
}; diff --git a/src/models/layers/linear.rs b/src/models/layers/linear.rs index 7de8aa9e..5e3f8447 100644 --- a/src/models/layers/linear.rs +++ b/src/models/layers/linear.rs @@ -434,6 +434,7 @@ pub enum LinearX { QLinear(QLinear), LnFp8(LnFp8), LnMxfp4(LnMxfp4), + LnNvfp4(LnNvfp4), } impl Module for LinearX { @@ -442,6 +443,7 @@ impl Module for LinearX { Self::Linear(ln) => ln.forward(x), Self::QLinear(ln) => ln.forward(x), Self::LnMxfp4(ln) => ln.forward(x), + Self::LnNvfp4(ln) => ln.forward(x), Self::LnFp8(ln) => ln.forward(x), } } @@ -456,6 +458,7 @@ impl LinearX { Self::QLinear(ln) => ln.indexed_moe_forward(x, ids), Self::LnFp8(_) => panic!("LnFp8 does not support indexed_moe_forward yet"), Self::LnMxfp4(_) => panic!("LnMxfp4 does not support indexed_moe_forward yet"), + Self::LnNvfp4(_) => panic!("LnNvfp4 does not support indexed_moe_forward yet"), } } } @@ -483,6 +486,7 @@ impl LinearX { Self::QLinear(ln) => ln.dequantize(), Self::LnFp8(_) => panic!("LnFp8 unable to be dequantized"), Self::LnMxfp4(_) => panic!("LnMxfp4 unable to be dequantized"), + Self::LnNvfp4(_) => panic!("LnNvfp4 unable to be dequantized"), } } } @@ -534,6 +538,15 @@ pub fn linear_x( return Ok(LinearX::LnMxfp4(ln)); } + if cfg.quant_method == "nvfp4" { + if should_skip_quant_for_module(&module_path, cfg) { + let ln = linear(in_dim, out_dim, vb.clone(), shards, dtype)?; + return Ok(LinearX::Linear(ln)); + } + let ln = LnNvfp4::load(in_dim, out_dim, vb.clone(), shards, true)?; + return Ok(LinearX::LnNvfp4(ln)); + } + let wna16 = WNA16::new( in_dim, out_dim, @@ -621,6 +634,15 @@ pub fn linear_no_bias_x( return Ok(LinearX::LnMxfp4(ln)); } + if cfg.quant_method == "nvfp4" { + if should_skip_quant_for_module(&module_path, cfg) { + let ln = linear_no_bias(in_dim, out_dim, vb.clone(), shards, dtype)?; + return Ok(LinearX::Linear(ln)); + } + let ln = LnNvfp4::load(in_dim, out_dim, vb.clone(), shards, false)?; + return Ok(LinearX::LnNvfp4(ln)); + } + let wna16 = WNA16::new( in_dim, out_dim, @@ 
-1157,3 +1179,88 @@ impl Module for LnMxfp4 { } } } + +/// NVFP4 linear layer: packed FP4 E2M1 weights with FP8 E4M3 block scales + F32 global scale. +#[derive(Debug, Clone)] +pub struct LnNvfp4 { + pub blocks: Tensor, + pub scales: Tensor, + pub global_scale: f32, + pub bias: Option, +} + +impl LnNvfp4 { + pub fn load( + in_dim: usize, + out_dim: usize, + vb: VarBuilder, + shard: Shard, + load_bias: bool, + ) -> Result { + let blocks = if vb.contains_tensor("weight_packed") { + vb.get_with_hints_dtype((out_dim, in_dim / 2), "weight_packed", shard, DType::U8)? + } else if vb.contains_tensor("weight") { + vb.get_with_hints_dtype((out_dim, in_dim / 2), "weight", shard, DType::U8)? + } else { + vb.get_with_hints_dtype((out_dim, in_dim / 2), "blocks", shard, DType::U8)? + }; + + let scale_dim = in_dim / 16; + let scales = if vb.contains_tensor("weight_scale") { + vb.get_with_hints_dtype((out_dim, scale_dim), "weight_scale", shard, DType::U8)? + } else { + vb.get_with_hints_dtype((out_dim, scale_dim), "scales", shard, DType::U8)? + }; + + let global_scale = if vb.contains_tensor("weight_scale_2") { + let t = match vb.get_with_hints_dtype((1,), "weight_scale_2", shard, DType::F32) { + Ok(t) => t, + Err(_) => vb.get_with_hints_dtype((), "weight_scale_2", shard, DType::F32)?, + }; + t.flatten_all()?.to_vec1::()?[0] + } else { + 1.0f32 + }; + + let bias = if load_bias && vb.contains_tensor("bias") { + Some(vb.get((out_dim,), "bias")?) + } else { + None + }; + Ok(Self { + blocks, + scales, + global_scale, + bias, + }) + } +} + +impl Module for LnNvfp4 { + fn forward(&self, x: &Tensor) -> Result { + let orig_dims = x.dims().to_vec(); + let x_2d = if orig_dims.len() > 2 { + let features = orig_dims[orig_dims.len() - 1]; + let batch_size: usize = orig_dims[..orig_dims.len() - 1].iter().product(); + x.reshape((batch_size, features))? 
+ } else { + x.clone() + }; + + let result = attention_rs::nvfp4_linear::nvfp4_matmul( + &x_2d, + &self.blocks, + &self.scales, + self.global_scale, + self.bias.as_ref(), + )?; + + if orig_dims.len() > 2 { + let mut new_dims = orig_dims[..orig_dims.len() - 1].to_vec(); + new_dims.push(result.dim(1)?); + result.reshape(new_dims) + } else { + Ok(result) + } + } +} diff --git a/src/models/layers/moe.rs b/src/models/layers/moe.rs index 14ad31cf..ea74c572 100644 --- a/src/models/layers/moe.rs +++ b/src/models/layers/moe.rs @@ -7,6 +7,7 @@ use crate::utils::config::QuantConfig; use attention_rs::moe; use attention_rs::moe::moe_gemm_fp8; use attention_rs::mxfp4_linear; +use attention_rs::nvfp4_linear; use candle_core::Module; use candle_core::{ quantized::{GgmlDType, QTensor}, @@ -1475,3 +1476,243 @@ impl FusedMoeMxfp4 { Ok(ys.to_dtype(self.dtype)?) } } +pub struct FusedMoeNvfp4 { + gate: Linear, + gate_up_blocks: Tensor, + gate_up_scales: Tensor, + gate_up_global_scales: Tensor, + down_blocks: Tensor, + down_scales: Tensor, + down_global_scales: Tensor, + w_size_n: usize, + act: candle_nn::Activation, + norm_topk_prob: bool, + routed_scaling_factor: Option, + num_experts_per_tok: usize, + all_reduce: AllReduce, + world_size: usize, + dtype: DType, +} + +impl FusedMoeNvfp4 { + fn tensor_name_packed(vb: &candle_nn::var_builder::ShardedVarBuilder) -> &'static str { + if vb.contains_tensor("weight_packed") { + "weight_packed" + } else if vb.contains_tensor("weight") { + "weight" + } else { + "blocks" + } + } + + fn tensor_name_scale(vb: &candle_nn::var_builder::ShardedVarBuilder) -> &'static str { + if vb.contains_tensor("weight_scale") { + "weight_scale" + } else { + "scales" + } + } + + fn load_global_scale(vb: &candle_nn::var_builder::ShardedVarBuilder, shard: Shard) -> f32 { + if vb.contains_tensor("weight_scale_2") { + vb.get_with_hints_dtype((1,), "weight_scale_2", shard, DType::F32) + .or_else(|_| vb.get_with_hints_dtype((), "weight_scale_2", shard, DType::F32)) + 
.and_then(|t| t.flatten_all()?.to_vec1::().map(|v| v[0])) + .unwrap_or(1.0) + } else { + 1.0 + } + } + + pub fn new(cfg: &Config, vb: VarBuilderX, comm: Rc, dtype: DType) -> Result { + let moe_cfg = cfg.moe_cfg.as_ref().expect("MoE config is not available!"); + let num_experts = moe_cfg.num_experts.unwrap(); + + let gate = linear_no_bias( + cfg.hidden_size, + num_experts, + vb.pp("gate"), + Shard::default(), + &cfg.quantization_config, + &None, + dtype, + )?; + + let experts_vb = vb.pp("experts"); + + let mut gate_blocks_vec = Vec::new(); + let mut gate_scales_vec = Vec::new(); + let mut gate_gscales_vec: Vec = Vec::new(); + let mut up_blocks_vec = Vec::new(); + let mut up_scales_vec = Vec::new(); + let mut up_gscales_vec: Vec = Vec::new(); + let mut down_blocks_vec = Vec::new(); + let mut down_scales_vec = Vec::new(); + let mut down_gscales_vec: Vec = Vec::new(); + + match &experts_vb.0 { + Either::Left(vb) => { + for i in 0..num_experts { + let expert_vb = vb.pp(i.to_string()); + + let gate_proj_vb = expert_vb.pp("gate_proj"); + let packed_name = Self::tensor_name_packed(&gate_proj_vb); + let scale_name = Self::tensor_name_scale(&gate_proj_vb); + let sh0 = shard(0, comm.rank(), comm.world_size()); + + gate_blocks_vec.push(gate_proj_vb.get_with_hints_dtype( + (moe_cfg.moe_intermediate_size, cfg.hidden_size / 2), + packed_name, + sh0, + DType::U8, + )?); + gate_scales_vec.push(gate_proj_vb.get_with_hints_dtype( + (moe_cfg.moe_intermediate_size, cfg.hidden_size / 16), + scale_name, + sh0, + DType::U8, + )?); + gate_gscales_vec.push(Self::load_global_scale(&gate_proj_vb, sh0)); + + let up_proj_vb = expert_vb.pp("up_proj"); + let packed_name = Self::tensor_name_packed(&up_proj_vb); + let scale_name = Self::tensor_name_scale(&up_proj_vb); + + up_blocks_vec.push(up_proj_vb.get_with_hints_dtype( + (moe_cfg.moe_intermediate_size, cfg.hidden_size / 2), + packed_name, + sh0, + DType::U8, + )?); + up_scales_vec.push(up_proj_vb.get_with_hints_dtype( + 
(moe_cfg.moe_intermediate_size, cfg.hidden_size / 16), + scale_name, + sh0, + DType::U8, + )?); + up_gscales_vec.push(Self::load_global_scale(&up_proj_vb, sh0)); + + let down_proj_vb = expert_vb.pp("down_proj"); + let packed_name = Self::tensor_name_packed(&down_proj_vb); + let scale_name = Self::tensor_name_scale(&down_proj_vb); + let sh1 = shard(1, comm.rank(), comm.world_size()); + + down_blocks_vec.push(down_proj_vb.get_with_hints_dtype( + (cfg.hidden_size, moe_cfg.moe_intermediate_size / 2), + packed_name, + sh1, + DType::U8, + )?); + down_scales_vec.push(down_proj_vb.get_with_hints_dtype( + (cfg.hidden_size, moe_cfg.moe_intermediate_size / 16), + scale_name, + sh1, + DType::U8, + )?); + down_gscales_vec.push(Self::load_global_scale(&down_proj_vb, sh1)); + } + } + _ => candle_core::bail!("FusedMoeNvfp4: GGUF loading not supported for NVFP4"), + } + + let gate_blocks = Tensor::stack(&gate_blocks_vec, 0)?; + let gate_scales = Tensor::stack(&gate_scales_vec, 0)?; + let up_blocks = Tensor::stack(&up_blocks_vec, 0)?; + let up_scales = Tensor::stack(&up_scales_vec, 0)?; + + let gate_up_blocks = Tensor::cat(&[&gate_blocks, &up_blocks], 1)?; + let gate_up_scales = Tensor::cat(&[&gate_scales, &up_scales], 1)?; + let w_size_n = gate_up_blocks.dim(1)? 
/ 2; + + let dev = gate_up_blocks.device(); + let gate_up_gscales: Vec = gate_gscales_vec + .iter() + .zip(up_gscales_vec.iter()) + .map(|(g, u)| (g + u) / 2.0) + .collect(); + let gate_up_global_scales = Tensor::from_vec(gate_up_gscales, (num_experts,), dev)?; + + let down_blocks = Tensor::stack(&down_blocks_vec, 0)?; + let down_scales = Tensor::stack(&down_scales_vec, 0)?; + let down_global_scales = Tensor::from_vec(down_gscales_vec, (num_experts,), dev)?; + + Ok(Self { + gate, + gate_up_blocks, + gate_up_scales, + gate_up_global_scales, + down_blocks, + down_scales, + down_global_scales, + w_size_n, + act: candle_nn::Activation::Silu, + norm_topk_prob: moe_cfg.norm_topk_prob, + routed_scaling_factor: moe_cfg.routed_scaling_factor, + num_experts_per_tok: moe_cfg.num_experts_per_tok, + all_reduce: AllReduce::new(comm.clone()), + world_size: comm.world_size(), + dtype, + }) + } + + pub fn forward(&self, xs: &Tensor, _is_prefill: bool) -> Result { + let (num_tokens, hidden_dim) = xs.dims2()?; + let router_logits = self.gate.forward(xs)?; + + let (mut topk_weights, topk_ids) = attention_rs::topk::topk_softmax( + &router_logits.to_dtype(DType::F32)?, + self.num_experts_per_tok, + )?; + + if self.norm_topk_prob { + topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?; + } + + if let Some(routed_scaling_factor) = self.routed_scaling_factor { + topk_weights = (topk_weights * routed_scaling_factor)?; + } + + let xs = if xs.dtype() == DType::F32 { + xs.to_dtype(self.dtype)? + } else { + xs.clone() + }; + + let gate_up = nvfp4_linear::nvfp4_moe_gemm( + &xs, + &self.gate_up_blocks, + &self.gate_up_scales, + &self.gate_up_global_scales, + None, + &topk_ids, + )?; + + let gate = gate_up + .narrow(candle_core::D::Minus1, 0, self.w_size_n)? + .contiguous()?; + let up = gate_up + .narrow(candle_core::D::Minus1, self.w_size_n, self.w_size_n)? 
+ .contiguous()?; + let down_inputs = (up * gate.apply(&self.act)?)?; + + let down = nvfp4_linear::nvfp4_moe_gemm( + &down_inputs, + &self.down_blocks, + &self.down_scales, + &self.down_global_scales, + None, + &topk_ids, + )?; + + let topk_weights = topk_weights.to_dtype(down.dtype())?; + let mut ys = down + .broadcast_mul(&topk_weights.unsqueeze(D::Minus1)?)? + .reshape((num_tokens, self.num_experts_per_tok, hidden_dim))? + .sum(1)?; + + if self.world_size > 1 { + ys = self.all_reduce.apply(&ys)?; + } + Ok(ys.to_dtype(self.dtype)?) + } +} diff --git a/src/models/qwen3_5_moe.rs b/src/models/qwen3_5_moe.rs index 572a469a..8154c8b6 100644 --- a/src/models/qwen3_5_moe.rs +++ b/src/models/qwen3_5_moe.rs @@ -6,7 +6,9 @@ use crate::models::layers::distributed::{Comm, ReplicatedLinear}; use crate::models::layers::linear::LinearX as Linear; use crate::models::layers::mask::get_attention_causal_mask; use crate::models::layers::mlp::MLP; -use crate::models::layers::moe::{FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ, FusedMoeMxfp4}; +use crate::models::layers::moe::{ + FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ, FusedMoeMxfp4, FusedMoeNvfp4, +}; use crate::models::layers::others::{embedding, rms_norm, NormX}; use crate::models::layers::rotary_emb::{ApplyRotaryEmbedding, ScalingRotaryEmbedding}; use crate::models::layers::VarBuilderX; @@ -33,6 +35,7 @@ enum MoeOrMlp { FusedMoeISQ(FusedMoeISQ), FusedMoeFp8(FusedMoeFp8), FusedMoeMxfp4(FusedMoeMxfp4), + FusedMoeNvfp4(FusedMoeNvfp4), } impl MoeOrMlp { @@ -43,6 +46,7 @@ impl MoeOrMlp { Self::FusedMoeISQ(m) => m.forward(xs, is_prefill), Self::FusedMoeFp8(m) => m.forward(xs, is_prefill), Self::FusedMoeMxfp4(m) => m.forward(xs, is_prefill), + Self::FusedMoeNvfp4(m) => m.forward(xs, is_prefill), } } } @@ -135,8 +139,15 @@ impl Qwen3_5MoEDecoderLayer { comm.clone(), dtype, )?) 
+ } else if quant_config.quant_method == "nvfp4" { + MoeOrMlp::FusedMoeNvfp4(FusedMoeNvfp4::new( + config, + vb.pp("mlp").clone(), + comm.clone(), + dtype, + )?) } else { - panic!("Unsupported quant method for MoE (use unquantized, gguf, fp8 or mxfp4)!"); + panic!("Unsupported quant method for MoE (use unquantized, gguf, fp8, mxfp4 or nvfp4)!"); } } else if config.quant.is_some() { MoeOrMlp::FusedMoeISQ(FusedMoeISQ::new( diff --git a/src/models/qwen3_moe.rs b/src/models/qwen3_moe.rs index 3d59ff4d..b855c9ad 100644 --- a/src/models/qwen3_moe.rs +++ b/src/models/qwen3_moe.rs @@ -4,7 +4,9 @@ use crate::models::layers::distributed::{Comm, ReplicatedLinear}; use crate::models::layers::linear::LinearX as Linear; use crate::models::layers::mask::get_attention_causal_mask; use crate::models::layers::mlp::MLP; -use crate::models::layers::moe::{FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ, FusedMoeMxfp4}; +use crate::models::layers::moe::{ + FusedMoe, FusedMoeFp8, FusedMoeGGUF, FusedMoeISQ, FusedMoeMxfp4, FusedMoeNvfp4, +}; use crate::models::layers::others::{embedding, rms_norm, NormX}; use crate::models::layers::rotary_emb::{ApplyRotaryEmbedding, ScalingRotaryEmbedding}; use crate::models::layers::VarBuilderX; @@ -26,6 +28,7 @@ enum MoeOrMlp { FusedMoeISQ(FusedMoeISQ), FusedMoeFp8(FusedMoeFp8), FusedMoeMxfp4(FusedMoeMxfp4), + FusedMoeNvfp4(FusedMoeNvfp4), Mlp(MLP), } @@ -38,6 +41,7 @@ impl MoeOrMlp { Self::FusedMoeISQ(m) => m.forward(xs, is_prefill), Self::FusedMoeFp8(m) => m.forward(xs, is_prefill), Self::FusedMoeMxfp4(m) => m.forward(xs, is_prefill), + Self::FusedMoeNvfp4(m) => m.forward(xs, is_prefill), } } } @@ -107,8 +111,15 @@ impl Qwen3DecoderLayer { comm.clone(), dtype, )?) + } else if quant_config.quant_method == "nvfp4" { + MoeOrMlp::FusedMoeNvfp4(FusedMoeNvfp4::new( + config, + vb.pp("mlp").clone(), + comm.clone(), + dtype, + )?) 
} else { - panic!("This feature is under developement (use unquantized, gguf, fp8 or mxfp4 instead)!"); + panic!("This feature is under developement (use unquantized, gguf, fp8, mxfp4 or nvfp4 instead)!"); } } else if config.quant.is_some() { MoeOrMlp::FusedMoeISQ(FusedMoeISQ::new( diff --git a/src/models/qwen3_vl/mod.rs b/src/models/qwen3_vl/mod.rs index 91bac98f..ac7a5689 100644 --- a/src/models/qwen3_vl/mod.rs +++ b/src/models/qwen3_vl/mod.rs @@ -69,8 +69,9 @@ impl Qwen3VLForConditionalGeneration { let mut vision_end_token_id = 0; if let Some(cfg) = try_parse_multimodal_extra_config(config)? { - if cfg.quantization_config.is_some() { - config_text.quantization_config = cfg.quantization_config.clone(); + if let Some(mut qcfg) = cfg.quantization_config.clone() { + qcfg.normalize_compressed_tensors(); + config_text.quantization_config = Some(qcfg); } spatial_merge_size = cfg.vision_config.spatial_merge_size; diff --git a/src/utils/config.rs b/src/utils/config.rs index ae5b6a81..8af2aa29 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -686,6 +686,34 @@ pub struct GenerationConfig { pub eos_token_id: Option, } +/// Match a module path against an ignore pattern. 
+/// Supports three pattern types: +/// - `re:` prefix → regex match +/// - Contains `*` → glob match (converted to regex: `*` becomes `.*`) +/// - Otherwise → literal suffix matching +pub fn match_ignore_pattern(module_path: &str, pattern: &str) -> bool { + if let Some(re_pat) = pattern.strip_prefix("re:") { + if let Ok(re) = regex::Regex::new(re_pat) { + return re.is_match(module_path); + } + return false; + } + if pattern.contains('*') { + let re_pat = format!("^{}$", regex::escape(pattern).replace(r"\*", ".*")); + if let Ok(re) = regex::Regex::new(&re_pat) { + return re.is_match(module_path); + } + return false; + } + let module_path = module_path.trim_end_matches(".weight"); + let item = pattern.trim_end_matches(".weight"); + module_path == item + || module_path.ends_with(item) + || module_path.ends_with(&format!(".{item}")) + || item.ends_with(module_path) + || item.ends_with(&format!(".{module_path}")) +} + #[derive(Serialize, Deserialize, PartialEq, Clone)] pub struct QuantConfig { pub quant_method: String, @@ -704,14 +732,46 @@ pub struct QuantConfig { pub modules_to_not_convert: Vec, #[serde(default)] pub config_groups: Option, + #[serde(default)] + pub quant_algo: Option, } impl QuantConfig { - /// Normalizes a compressed-tensors config into a flat quant_method. - /// If `quant_method == "compressed-tensors"` and the `format` field (or a - /// `config_groups` entry) indicates `mxfp4-pack-quantized`, rewrites - /// `quant_method` to `"mxfp4"` and extracts `group_size` / `ignore` list. + /// Normalizes a quantization config into a canonical quant_method string. + /// + /// Handles two families: + /// 1. `compressed-tensors` with `format` containing `mxfp4` → `"mxfp4"` + /// 2. `modelopt` with `quant_algo` == `NVFP4` → `"nvfp4"` + /// + /// Also extracts group_size / bits from config_groups when present. 
pub fn normalize_compressed_tensors(&mut self) { + if self.quant_method == "modelopt" { + if let Some(algo) = &self.quant_algo { + if algo.eq_ignore_ascii_case("NVFP4") || algo.eq_ignore_ascii_case("FP4") { + self.quant_method = "nvfp4".to_string(); + self.extract_compressed_tensors_params(); + if self.group_size == 0 { + self.group_size = 16; + } + if self.bits == 0 { + self.bits = 4; + } + return; + } + } + if self.detect_nvfp4_from_config_groups() { + self.quant_method = "nvfp4".to_string(); + self.extract_compressed_tensors_params(); + if self.group_size == 0 { + self.group_size = 16; + } + if self.bits == 0 { + self.bits = 4; + } + return; + } + } + if self.quant_method != "compressed-tensors" { return; } @@ -729,6 +789,36 @@ impl QuantConfig { } } + fn detect_nvfp4_from_config_groups(&self) -> bool { + let groups = match &self.config_groups { + Some(v) => v, + None => return false, + }; + if let Some(obj) = groups.as_object() { + for (_key, group) in obj { + if let Some(weights) = group.get("weights") { + if let Some(num_bits) = weights.get("num_bits").and_then(|v| v.as_u64()) { + if num_bits == 4 { + let is_float = weights + .get("type") + .and_then(|v| v.as_str()) + .map(|t| t == "float") + .unwrap_or(false); + let gs = weights + .get("group_size") + .and_then(|v| v.as_u64()) + .unwrap_or(0); + if is_float && gs == 16 { + return true; + } + } + } + } + } + } + false + } + fn detect_mxfp4_from_config_groups(&self) -> bool { let groups = match &self.config_groups { Some(v) => v, @@ -741,6 +831,13 @@ impl QuantConfig { return true; } } + if let Some(weights) = group.get("weights") { + if let Some(fmt) = weights.get("format").and_then(|v| v.as_str()) { + if fmt.contains("mxfp4") { + return true; + } + } + } } } false @@ -768,6 +865,18 @@ impl QuantConfig { } } } + + /// Check if a module path should be skipped for this quantization config. + /// Supports literal paths and `re:` prefixed regex patterns in + /// `modules_to_not_convert` / `ignore`. 
+ pub fn should_skip_module(&self, module_path: &str) -> bool { + if module_path.is_empty() || self.modules_to_not_convert.is_empty() { + return false; + } + self.modules_to_not_convert + .iter() + .any(|item| match_ignore_pattern(module_path, item)) + } } impl fmt::Debug for QuantConfig { @@ -786,3 +895,427 @@ impl fmt::Debug for QuantConfig { .finish() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_match_ignore_literal_exact() { + assert!(match_ignore_pattern("lm_head", "lm_head")); + assert!(match_ignore_pattern( + "model.layers.0.self_attn.q_proj", + "model.layers.0.self_attn.q_proj" + )); + } + + #[test] + fn test_match_ignore_literal_suffix() { + assert!(match_ignore_pattern( + "model.language_model.layers.0.linear_attn.out_proj", + "model.language_model.layers.0.linear_attn.out_proj" + )); + assert!(match_ignore_pattern("model.lm_head.weight", "lm_head")); + } + + #[test] + fn test_match_ignore_regex() { + assert!(match_ignore_pattern( + "model.layers.5.self_attn.q_proj", + "re:.*self_attn.*" + )); + assert!(match_ignore_pattern( + "model.layers.10.linear_attn.in_proj_qkv", + "re:.*linear_attn.*" + )); + assert!(match_ignore_pattern( + "model.layers.3.mlp.gate", + "re:.*.mlp.gate$" + )); + assert!(!match_ignore_pattern( + "model.layers.3.mlp.gate_proj", + "re:.*.mlp.gate$" + )); + assert!(match_ignore_pattern( + "model.visual.blocks.0.attn.qkv", + "re:.*visual.*" + )); + assert!(match_ignore_pattern("mtp.fc", "re:.*mtp.*")); + assert!(match_ignore_pattern( + "model.embed_tokens", + "re:.*embed_tokens.*" + )); + } + + #[test] + fn test_match_ignore_regex_no_false_positive() { + assert!(!match_ignore_pattern( + "model.layers.5.mlp.up_proj", + "re:.*self_attn.*" + )); + assert!(!match_ignore_pattern( + "model.layers.5.mlp.up_proj", + "re:.*linear_attn.*" + )); + assert!(!match_ignore_pattern( + "model.layers.5.mlp.up_proj", + "re:.*lm_head.*" + )); + } + + #[test] + fn test_should_skip_module() { + let cfg = QuantConfig { + quant_method: 
"mxfp4".to_string(), + bits: 4, + group_size: 32, + sym: None, + desc_act: None, + checkpoint_format: None, + fmt: None, + format: Some("mxfp4-pack-quantized".to_string()), + weight_block_size: None, + modules_to_not_convert: vec![ + "re:.*self_attn.*".to_string(), + "re:.*linear_attn.*".to_string(), + "re:.*.mlp.gate$".to_string(), + "re:.*lm_head.*".to_string(), + "re:.*embed_tokens.*".to_string(), + "re:.*visual.*".to_string(), + "re:.*mtp.*".to_string(), + ], + config_groups: None, + quant_algo: None, + }; + assert!(cfg.should_skip_module("model.layers.0.self_attn.q_proj")); + assert!(cfg.should_skip_module("model.layers.5.linear_attn.out_proj")); + assert!(cfg.should_skip_module("model.layers.3.mlp.gate")); + assert!(!cfg.should_skip_module("model.layers.3.mlp.gate_proj")); + assert!(!cfg.should_skip_module("model.layers.3.mlp.up_proj")); + assert!(!cfg.should_skip_module("model.layers.3.mlp.down_proj")); + assert!(cfg.should_skip_module("lm_head")); + assert!(cfg.should_skip_module("model.visual.blocks.0.attn.qkv")); + assert!(cfg.should_skip_module("mtp.fc")); + } + + #[test] + fn test_normalize_compressed_tensors_regex_ignore() { + let json = r#"{ + "quant_method": "compressed-tensors", + "format": "mxfp4-pack-quantized", + "config_groups": { + "group_0": { + "format": "mxfp4-pack-quantized", + "weights": {"num_bits": 4, "group_size": 32, "strategy": "group", "symmetric": true} + } + }, + "ignore": [ + "re:.*self_attn.*", + "re:.*linear_attn.*", + "re:.*.mlp.gate$", + "re:.*lm_head.*", + "re:.*embed_tokens.*", + "re:.*visual.*", + "re:.*mtp.*" + ] + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "mxfp4"); + assert_eq!(cfg.group_size, 32); + assert_eq!(cfg.bits, 4); + assert_eq!(cfg.modules_to_not_convert.len(), 7); + assert!(cfg.should_skip_module("model.layers.5.self_attn.q_proj")); + assert!(!cfg.should_skip_module("model.layers.5.mlp.up_proj")); + } + + #[test] + 
fn test_normalize_compressed_tensors_literal_ignore() { + let json = r#"{ + "quant_method": "compressed-tensors", + "format": "mxfp4-pack-quantized", + "config_groups": { + "group_0": { + "weights": {"num_bits": 4, "group_size": 32} + } + }, + "ignore": [ + "model.layers.0.linear_attn.out_proj", + "model.layers.0.linear_attn.in_proj_qkv", + "lm_head" + ] + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "mxfp4"); + assert!(cfg.should_skip_module("model.layers.0.linear_attn.out_proj")); + assert!(cfg.should_skip_module("lm_head")); + assert!(!cfg.should_skip_module("model.layers.1.mlp.up_proj")); + } + + #[test] + fn test_normalize_format_in_config_groups_only() { + let json = r#"{ + "quant_method": "compressed-tensors", + "config_groups": { + "group_0": { + "format": "mxfp4-pack-quantized", + "weights": {"num_bits": 4, "group_size": 32, "type": "float", "strategy": "group", "symmetric": true} + } + }, + "ignore": ["lm_head"] + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "mxfp4"); + assert_eq!(cfg.group_size, 32); + } + + #[test] + fn test_olka_4b_config() { + let json = r#"{ + "quant_method": "compressed-tensors", + "format": "mxfp4-pack-quantized", + "config_groups": { + "group_0": { + "targets": ["Linear"], + "weights": { + "num_bits": 4, + "type": "float", + "strategy": "group", + "group_size": 32, + "symmetric": true + } + } + }, + "ignore": [ + "model.language_model.embed_tokens", + "model.language_model.layers.0.input_layernorm", + "model.language_model.layers.0.linear_attn.conv1d", + "model.language_model.layers.0.linear_attn.in_proj_a", + "model.language_model.norm", + "model.visual.blocks.0.attn.proj", + "mtp.fc" + ] + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "mxfp4"); + 
assert_eq!(cfg.group_size, 32); + assert_eq!(cfg.bits, 4); + assert!(cfg.should_skip_module("model.language_model.embed_tokens")); + assert!(cfg.should_skip_module("model.language_model.layers.0.linear_attn.conv1d")); + assert!(!cfg.should_skip_module("model.language_model.layers.0.mlp.up_proj")); + } + + #[test] + fn test_kaitchup_27b_config() { + let json = r#"{ + "config_groups": { + "group_0": { + "format": "mxfp4-pack-quantized", + "input_activations": null, + "output_activations": null, + "targets": ["Linear"], + "weights": { + "actorder": null, + "block_structure": null, + "dynamic": false, + "group_size": 32, + "num_bits": 4, + "observer": "memoryless_minmax", + "observer_kwargs": {}, + "scale_dtype": "torch.uint8", + "strategy": "group", + "symmetric": true, + "type": "float", + "zp_dtype": null + } + } + }, + "format": "mxfp4-pack-quantized", + "global_compression_ratio": null, + "ignore": [ + "model.visual.blocks.0.attn.qkv", + "model.language_model.layers.0.linear_attn.out_proj", + "lm_head" + ], + "kv_cache_scheme": null, + "quant_method": "compressed-tensors", + "quantization_status": "compressed", + "sparsity_config": {}, + "transform_config": {}, + "version": "0.13.1.dev53+gd96634b" + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "mxfp4"); + assert_eq!(cfg.group_size, 32); + assert_eq!(cfg.bits, 4); + assert!(cfg.should_skip_module("model.language_model.layers.0.linear_attn.out_proj")); + assert!(cfg.should_skip_module("lm_head")); + } + + #[test] + fn test_122b_regex_config() { + let json = r#"{ + "quant_method": "compressed-tensors", + "format": "mxfp4-pack-quantized", + "quantization_status": "compressed", + "config_groups": { + "group_0": { + "format": "mxfp4-pack-quantized", + "weights": { + "num_bits": 4, + "type": "float", + "strategy": "group", + "group_size": 32, + "symmetric": true, + "scale_dtype": "torch.uint8", + "dynamic": false, + "actorder": 
null, + "block_structure": null, + "observer": "minmax", + "observer_kwargs": {}, + "zp_dtype": null + }, + "targets": ["Linear"], + "input_activations": null, + "output_activations": null + } + }, + "ignore": [ + "re:.*self_attn.*", + "re:.*linear_attn.*", + "re:.*.mlp.gate$", + "re:.*shared_expert_gate.*", + "re:.*lm_head.*", + "re:.*embed_tokens.*", + "re:.*visual.*", + "re:.*mtp.*" + ], + "kv_cache_scheme": null, + "sparsity_config": {}, + "transform_config": {} + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "mxfp4"); + assert_eq!(cfg.group_size, 32); + assert_eq!(cfg.bits, 4); + assert!(cfg.should_skip_module("model.layers.5.self_attn.q_proj")); + assert!(cfg.should_skip_module("model.layers.5.linear_attn.out_proj")); + assert!(cfg.should_skip_module("model.layers.3.mlp.gate")); + assert!(!cfg.should_skip_module("model.layers.3.mlp.gate_proj")); + assert!(cfg.should_skip_module("model.layers.3.shared_expert_gate")); + assert!(cfg.should_skip_module("lm_head")); + assert!(cfg.should_skip_module("model.embed_tokens")); + assert!(cfg.should_skip_module("model.visual.blocks.0.attn.qkv")); + assert!(cfg.should_skip_module("mtp.fc")); + assert!(!cfg.should_skip_module("model.layers.3.mlp.up_proj")); + assert!(!cfg.should_skip_module("model.layers.3.mlp.down_proj")); + } + + #[test] + fn test_2imi9_9b_config() { + let json = r#"{ + "config_groups": { + "group_0": { + "format": "mxfp4-pack-quantized", + "input_activations": null, + "output_activations": null, + "targets": ["Linear"], + "weights": { + "actorder": null, + "block_structure": null, + "dynamic": false, + "group_size": 32, + "num_bits": 4, + "observer": "memoryless_minmax", + "observer_kwargs": {}, + "scale_dtype": "torch.uint8", + "strategy": "group", + "symmetric": true, + "type": "float", + "zp_dtype": null + } + } + }, + "format": "mxfp4-pack-quantized", + "global_compression_ratio": null, + "ignore": [ + 
"model.layers.0.linear_attn.out_proj", + "model.layers.0.linear_attn.in_proj_qkv", + "lm_head" + ], + "kv_cache_scheme": null, + "quant_method": "compressed-tensors", + "quantization_status": "compressed", + "sparsity_config": {}, + "transform_config": {}, + "version": "0.14.1.a20260310" + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "mxfp4"); + assert_eq!(cfg.group_size, 32); + assert_eq!(cfg.bits, 4); + assert!(cfg.should_skip_module("model.layers.0.linear_attn.out_proj")); + assert!(cfg.should_skip_module("lm_head")); + assert!(!cfg.should_skip_module("model.layers.1.mlp.up_proj")); + } + + #[test] + fn test_nvfp4_axionml_4b_config() { + let json = r#"{ + "quant_method": "modelopt", + "quant_algo": "NVFP4", + "config_groups": { + "group_0": { + "input_activations": {"dynamic": false, "num_bits": 4, "type": "float", "group_size": 16}, + "weights": {"dynamic": false, "num_bits": 4, "type": "float", "group_size": 16}, + "targets": ["Linear"] + } + }, + "ignore": [ + "lm_head", + "model.language_model.layers.0.linear_attn.conv1d", + "model.visual*", + "mtp.layers.0*" + ] + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "nvfp4"); + assert_eq!(cfg.group_size, 16); + assert_eq!(cfg.bits, 4); + assert!(cfg.should_skip_module("lm_head")); + assert!(cfg.should_skip_module("model.visual.encoder.layers.0.self_attn")); + assert!(cfg.should_skip_module("mtp.layers.0.mlp.gate_proj")); + assert!(!cfg.should_skip_module("model.language_model.layers.1.mlp.up_proj")); + } + + #[test] + fn test_nvfp4_glob_wildcards() { + let json = r#"{ + "quant_method": "modelopt", + "quant_algo": "NVFP4", + "ignore": [ + "lm_head", + "*.mlp.shared_expert.*", + "model.layers.0.self_attn*", + "model.layers.92*" + ] + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + 
cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "nvfp4"); + assert!(cfg.should_skip_module("lm_head")); + assert!(cfg.should_skip_module("model.layers.5.mlp.shared_expert.gate_proj")); + assert!(cfg.should_skip_module("model.layers.0.self_attn.q_proj")); + assert!(cfg.should_skip_module("model.layers.0.self_attn.k_proj")); + assert!(cfg.should_skip_module("model.layers.92.self_attn.q_proj")); + assert!(!cfg.should_skip_module("model.layers.1.self_attn.q_proj")); + assert!(!cfg.should_skip_module("model.layers.5.mlp.gate_proj")); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index ce99d197..a851b421 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -60,26 +60,15 @@ macro_rules! serde_default { } pub fn module_path_matches_not_convert(module_path: &str, item: &str) -> bool { - let module_path = module_path.trim_end_matches(".weight"); - let item = item.trim_end_matches(".weight"); - module_path == item - || module_path.ends_with(item) - || module_path.ends_with(&format!(".{item}")) - || item.ends_with(module_path) - || item.ends_with(&format!(".{module_path}")) + crate::utils::config::match_ignore_pattern(module_path, item) } pub fn should_skip_fp8_for_module(module_path: &str, cfg: &QuantConfig) -> bool { - if module_path.is_empty() || cfg.modules_to_not_convert.is_empty() { - return false; - } - cfg.modules_to_not_convert - .iter() - .any(|item| module_path_matches_not_convert(module_path, item)) + cfg.should_skip_module(module_path) } pub fn should_skip_quant_for_module(module_path: &str, cfg: &QuantConfig) -> bool { - should_skip_fp8_for_module(module_path, cfg) + cfg.should_skip_module(module_path) } pub fn hub_load_local_safetensors(path: &String, json_file: &str) -> Result> { @@ -1002,8 +991,9 @@ pub fn init_config_tokenizer( qcfg.quant_method == "gptq" || qcfg.quant_method == "awq" || qcfg.quant_method == "fp8" - || qcfg.quant_method == "mxfp4", - "Invalid quantization format! 
Only `gptq`, `awq`, `fp8` and `mxfp4` supported, got `{}`", + || qcfg.quant_method == "mxfp4" + || qcfg.quant_method == "nvfp4", + "Invalid quantization format! Only `gptq`, `awq`, `fp8`, `mxfp4` and `nvfp4` supported, got `{}`", qcfg.quant_method ); if qcfg.quant_method == "gptq" || qcfg.quant_method == "awq" { @@ -1475,22 +1465,18 @@ pub fn get_arch_rope( ModelType::Qwen3, "<|im_start|>user\n {} <|im_end|>".to_string(), ), - "Qwen3_5ForCausalLM" | "qwen35" => ( - ModelType::Qwen3_5, + "qwen2moe" | "Qwen2MoeForCausalLM" | "qwen3moe" | "Qwen3MoeForCausalLM" => ( + ModelType::Qwen3MoE, "<|im_start|>user\n {} <|im_end|>".to_string(), ), - "Qwen3NextForCausalLM" => ( - ModelType::Qwen3_5MoE, + "Qwen3_5ForCausalLM" | "qwen35" => ( + ModelType::Qwen3_5, "<|im_start|>user\n {} <|im_end|>".to_string(), ), - "Qwen3_5MoeForCausalLM" | "qwen35moe" => ( + "Qwen3_5MoeForCausalLM" | "Qwen3NextForCausalLM" | "qwen35moe" => ( ModelType::Qwen3_5MoE, "<|im_start|>user\n {} <|im_end|>".to_string(), ), - "qwen2moe" | "Qwen2MoeForCausalLM" | "qwen3moe" | "Qwen3MoeForCausalLM" => ( - ModelType::Qwen3MoE, - "<|im_start|>user\n {} <|im_end|>".to_string(), - ), "Qwen3VLForConditionalGeneration" | "Qwen3VLMoeForConditionalGeneration" | "Qwen3_5ForConditionalGeneration" From 25af74e3fffdd20749f99c153e5d2917c4096ddc Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 3 Apr 2026 10:17:38 +0000 Subject: [PATCH 07/14] Update ReadMe --- ReadMe-CN.md | 40 +++++++++++++++++++++++++++++++--------- ReadMe.md | 43 +++++++++++++++++++++++++++++++++---------- 2 files changed, 64 insertions(+), 19 deletions(-) diff --git a/ReadMe-CN.md b/ReadMe-CN.md index 72ec00bf..68f892f1 100644 --- a/ReadMe-CN.md +++ b/ReadMe-CN.md @@ -40,7 +40,7 @@ | **Qwen3-30B-A3B** | Q4_K_M | **30B (MoE)**| **97.16** tokens/s | | **Qwen3.5-27B** | Q4_K_M | **27B (Dense)**| **45.20** tokens/s | | **Qwen3.5-27B** | FP8 | **27B (Dense)**| **42** tokens/s (**Hopper**) | -| **Qwen3.5-35B-A3B** | Q3_K_M | **35B (MoE)**| **106** 
tokens/s (**Hopper**) | +| **Qwen3.5-35B-A3B** | Q3_K_M/MXFP4 | **35B (MoE)**| **95-106** tokens/s (**Hopper**) | > vLLM.rs 在 **Metal (Apple Silicon, M4)** 上的性能 @@ -74,7 +74,7 @@ * ✅ Qwen3-VL (Dense, 多模态) * ✅ MiroThinker-v1.5 (30B, 235B) -支持 **Safetensor** (包含GPTQ, AWQ, FP8-blockwise 量化格式) 和 **GGUF** 格式。 +支持 **Safetensor** (包含GPTQ, AWQ, MXFP4, NVFP4, FP8-blockwise 量化格式) 和 **GGUF** 格式。 所有模型均支持硬件FP8 KvCache加速(需SM90+及关闭`flashinfer` 或 `flashattn` 特性)。 @@ -162,6 +162,7 @@ python3 -m pip install vllm_rs
FP8模型 +_FP8-Blockwise格式:_ ```bash # CUDA (MoE, Dense) sm90+ 设备需打开`cutlass`特性以支持FP8硬件加速 vllm-rs --m Qwen/Qwen3.5-27B-FP8 --ui-server --prefix-cache @@ -169,6 +170,15 @@ vllm-rs --m Qwen/Qwen3.5-27B-FP8 --ui-server --prefix-cache vllm-rs --m Qwen/Qwen3-4B-Instruct-2507-FP8 --ui-server --prefix-cache ``` +_MXFP4 格式:_ +```bash +python3 -m vllm_rs.server --m olka-fi/Qwen3.5-4B-MXFP4 --ui-server --prefix-cache +``` + +_NVFP4 格式:_ +```bash +python3 -m vllm_rs.server --m AxionML/Qwen3.5-9B-NVFP4 --ui-server --prefix-cache +```
@@ -284,13 +294,24 @@ cargo install --features metal
- FP8模型 + FP8/FP4模型 + _FP8格式:_ ```bash vllm-rs --d 0,1 --w /path/Qwen3-Coder-30B-A3B-Instruct-FP8/ --ui-server --prefix-cache # Or Qwen3-Next 80B vllm-rs --m Qwen/Qwen3-Coder-Next-FP8 --ui-server --d 0,1 --prefix-cache ``` + + _MXFP4格式:_ + ```bash + vllm-rs --m olka-fi/Qwen3.5-4B-MXFP4 --ui-server --prefix-cache + ``` + + _NVFP4格式:_ + ```bash + vllm-rs --m AxionML/Qwen3.5-9B-NVFP4 --ui-server --prefix-cache + ```
@@ -492,14 +513,15 @@ pip install target/wheels/vllm_rs-*-cp38-abi3-*.whl --force-reinstall * [x] PD(Prefill/Decode)分离(CUDA) * [x] PD(Prefill/Decode)分离(Metal) * [x] 内置 ChatGPT风格 Web 网页服务 -* [x] **Embedding API** -* [x] **Tokenize/Detokenize API** -* [x] **MCP集成与工具调用** -* [x] **公共前缀缓存** -* [x] **Claude/Anthropic API 兼容服务器** -* [x] **支持CUDA 13** +* [x] Embedding API +* [x] Tokenize/Detokenize API +* [x] MCP集成与工具调用 +* [x] 公共前缀缓存 +* [x] Claude/Anthropic API 兼容服务器 +* [x] 支持CUDA 13 * [x] **支持FlashInfer后端** * [x] **支持DeepGEMM后端 (Hopper)** +* [x] **MXFP4/NVFP4模型支持** * [ ] TentorRT-LLM 后端 ## 📚 参考项目 diff --git a/ReadMe.md b/ReadMe.md index 0d3f6fb7..3a960e7a 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -40,7 +40,7 @@ A blazing-fast ⚡, lightweight **Rust** 🦀 implementation of vLLM. | **Qwen3-30B-A3B** | Q4_K_M | **30B (MoE)**| **97.16** tokens/s | | **Qwen3.5-27B** | Q4_K_M | **27B (Dense)**| **45.20** tokens/s | | **Qwen3.5-27B** | FP8 | **27B (Dense)**| **42** tokens/s (**Hopper**) | -| **Qwen3.5-35B-A3B** | Q3_K_M | **35B (MoE)**| **106** tokens/s (**Hopper**) | +| **Qwen3.5-35B-A3B** | Q3_K_M/MXFP4 | **35B (MoE)**| **95-106** tokens/s (**Hopper**) | > **Metal (Apple Silicon, M4)**
@@ -75,7 +75,7 @@ See [**Full Performance Benchmarks →**](docs/performance.md) * ✅ Qwen3-VL (Dense, Multimodal model) * ✅ MiroThinker-v1.5 (30B, 235B) -Supports both **Safetensor** (including GPTQ, AWQ and FP8-blockwise formats) and **GGUF** formats. +Supports both **Safetensor** (including GPTQ, AWQ, MXFP4, NVFP4, and FP8-blockwise formats) and **GGUF** formats. All models support hardware FP8 KV-cache acceleration (requires SM90+ and disable `flashinfer` or `flashattn`). @@ -163,10 +163,22 @@ python3 -m vllm_rs.server --w /path/Qwen3.5-35B-A3B --isq q4k --d 0 --ui-server
- FP8 Model + FP8/FP4 Model +_FP8-Blockwise format:_ ```bash python3 -m vllm_rs.server --m Qwen/Qwen3.5-27B-FP8 --ui-server --prefix-cache +``` + +_MXFP4 format:_ + +```bash +python3 -m vllm_rs.server --m olka-fi/Qwen3.5-4B-MXFP4 --ui-server --prefix-cache +``` + +_NVFP4 format:_ +```bash +python3 -m vllm_rs.server --m AxionML/Qwen3.5-9B-NVFP4 --ui-server --prefix-cache ```
@@ -300,8 +312,9 @@ Use `--i` to enable interactive mode 🤖, `--ui-server` or `--server` to enable
- FP8 Model + FP8/FP4 Model +_FP8-Blockwise format:_ ```bash # CUDA (MoE, Dense), be sure to enable `cutlass` feature on sm90+ vllm-rs --m Qwen/Qwen3.5-27B-FP8 --ui-server --prefix-cache @@ -311,6 +324,15 @@ vllm-rs --m Qwen/Qwen3-Coder-Next-FP8 --ui-server --d 0,1 --prefix-cache vllm-rs --m Qwen/Qwen3.5-4B-FP8 --ui-server --prefix-cache ``` +_MXFP4 format:_ +```bash +vllm-rs --m olka-fi/Qwen3.5-4B-MXFP4 --ui-server --prefix-cache +``` + +_NVFP4 format:_ +```bash +vllm-rs --m AxionML/Qwen3.5-9B-NVFP4 --ui-server --prefix-cache +```
@@ -558,14 +580,15 @@ pip install target/wheels/vllm_rs-*-cp38-abi3-*.whl --force-reinstall * [x] Prefill-decode Disaggregation (CUDA) * [x] Prefill-decode Disaggregation (Metal) * [x] Built-in ChatGPT-like Web Server -* [x] **Embedding API** -* [x] **Tokenize/Detokenize API** -* [x] **MCP Integration & Tool Calling** -* [x] **Prefix Caching** -* [x] **Claude/Anthropic-compatible API Server** -* [x] **Support CUDA 13** +* [x] Embedding API +* [x] Tokenize/Detokenize API +* [x] MCP Integration & Tool Calling +* [x] Prefix Caching +* [x] Claude/Anthropic-compatible API Server +* [x] Support CUDA 13 * [x] **Support FlashInfer backend** * [x] **Support DeepGEMM backend (Hopper)** +* [x] **MXFP4/NVFP4 Model Support** * [ ] TentorRT-LLM --- From affebdbda2a28118c683ef525c9a30e9a453d374 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 3 Apr 2026 10:49:44 +0000 Subject: [PATCH 08/14] Compatible with more meta format --- src/utils/config.rs | 164 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 155 insertions(+), 9 deletions(-) diff --git a/src/utils/config.rs b/src/utils/config.rs index 8af2aa29..0102d791 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -716,6 +716,7 @@ pub fn match_ignore_pattern(module_path: &str, pattern: &str) -> bool { #[derive(Serialize, Deserialize, PartialEq, Clone)] pub struct QuantConfig { + #[serde(default)] pub quant_method: String, #[serde(default)] pub bits: usize, @@ -734,17 +735,46 @@ pub struct QuantConfig { pub config_groups: Option, #[serde(default)] pub quant_algo: Option, + #[serde(default)] + pub mode: Option, } impl QuantConfig { /// Normalizes a quantization config into a canonical quant_method string. /// - /// Handles two families: - /// 1. `compressed-tensors` with `format` containing `mxfp4` → `"mxfp4"` - /// 2. `modelopt` with `quant_algo` == `NVFP4` → `"nvfp4"` + /// Handles the following families: + /// 1. MLX-style: `mode` == `"nvfp4"` or `"mxfp4"` (quant_method may be empty) + /// 2. 
`modelopt` with `quant_algo` == `NVFP4` / `FP4` + /// 3. `compressed-tensors` with `format` containing `nvfp4` or `mxfp4` + /// 4. `compressed-tensors` detected from `config_groups` content /// /// Also extracts group_size / bits from config_groups when present. pub fn normalize_compressed_tensors(&mut self) { + // MLX-style: {"mode": "nvfp4", "bits": 4, "group_size": 16} + if let Some(mode) = &self.mode { + if mode.eq_ignore_ascii_case("nvfp4") { + self.quant_method = "nvfp4".to_string(); + if self.group_size == 0 { + self.group_size = 16; + } + if self.bits == 0 { + self.bits = 4; + } + return; + } + if mode.eq_ignore_ascii_case("mxfp4") { + self.quant_method = "mxfp4".to_string(); + if self.group_size == 0 { + self.group_size = 32; + } + if self.bits == 0 { + self.bits = 4; + } + return; + } + } + + // modelopt: {"quant_method": "modelopt", "quant_algo": "NVFP4"} if self.quant_method == "modelopt" { if let Some(algo) = &self.quant_algo { if algo.eq_ignore_ascii_case("NVFP4") || algo.eq_ignore_ascii_case("FP4") { @@ -776,12 +806,24 @@ impl QuantConfig { return; } - let is_mxfp4 = self - .format - .as_deref() - .map(|f| f.contains("mxfp4")) - .unwrap_or(false) - || self.detect_mxfp4_from_config_groups(); + // compressed-tensors: check format string for nvfp4 or mxfp4 + let format_str = self.format.as_deref().unwrap_or(""); + + let is_nvfp4 = format_str.contains("nvfp4") || self.detect_nvfp4_from_config_groups(); + + if is_nvfp4 { + self.quant_method = "nvfp4".to_string(); + self.extract_compressed_tensors_params(); + if self.group_size == 0 { + self.group_size = 16; + } + if self.bits == 0 { + self.bits = 4; + } + return; + } + + let is_mxfp4 = format_str.contains("mxfp4") || self.detect_mxfp4_from_config_groups(); if is_mxfp4 { self.quant_method = "mxfp4".to_string(); @@ -796,7 +838,20 @@ impl QuantConfig { }; if let Some(obj) = groups.as_object() { for (_key, group) in obj { + // Check group-level format (e.g. 
"nvfp4-pack-quantized") + if let Some(fmt) = group.get("format").and_then(|v| v.as_str()) { + if fmt.contains("nvfp4") { + return true; + } + } if let Some(weights) = group.get("weights") { + // Check weights-level format + if let Some(fmt) = weights.get("format").and_then(|v| v.as_str()) { + if fmt.contains("nvfp4") { + return true; + } + } + // Detect by parameters: 4-bit float with group_size=16 if let Some(num_bits) = weights.get("num_bits").and_then(|v| v.as_u64()) { if num_bits == 4 { let is_float = weights @@ -986,6 +1041,7 @@ mod tests { ], config_groups: None, quant_algo: None, + mode: None, }; assert!(cfg.should_skip_module("model.layers.0.self_attn.q_proj")); assert!(cfg.should_skip_module("model.layers.5.linear_attn.out_proj")); @@ -1318,4 +1374,94 @@ mod tests { assert!(!cfg.should_skip_module("model.layers.1.self_attn.q_proj")); assert!(!cfg.should_skip_module("model.layers.5.mlp.gate_proj")); } + + #[test] + fn test_nvfp4_mlx_community_config() { + // mlx-community/Qwen3.5-0.8B-nvfp4 style: mode field, no quant_method + let json = r#"{ + "group_size": 16, + "bits": 4, + "mode": "nvfp4" + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + assert_eq!(cfg.quant_method, ""); // default empty before normalization + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "nvfp4"); + assert_eq!(cfg.bits, 4); + assert_eq!(cfg.group_size, 16); + } + + #[test] + fn test_nvfp4_compressed_tensors_format() { + // RedHatAI/Qwen3.5-122B-A10B-NVFP4 style: compressed-tensors + nvfp4-pack-quantized + let json = r#"{ + "quant_method": "compressed-tensors", + "format": "nvfp4-pack-quantized", + "config_groups": { + "group_0": { + "format": "nvfp4-pack-quantized", + "targets": ["Linear"], + "weights": { + "num_bits": 4, + "type": "float", + "group_size": 16, + "strategy": "tensor_group", + "symmetric": true, + "dynamic": false, + "scale_dtype": "torch.float8_e4m3fn" + }, + "input_activations": { + "num_bits": 4, + "type": "float", + 
"group_size": 16, + "dynamic": "local", + "scale_dtype": "torch.float8_e4m3fn" + } + } + }, + "ignore": [ + "lm_head", + "model.visual.blocks.0.attn.qkv", + "model.language_model.layers.0.linear_attn.out_proj", + "model.language_model.layers.0.mlp.gate", + "model.language_model.layers.0.mlp.shared_expert_gate" + ] + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "nvfp4"); + assert_eq!(cfg.bits, 4); + assert_eq!(cfg.group_size, 16); + assert!(cfg.should_skip_module("lm_head")); + assert!(cfg.should_skip_module("model.visual.blocks.0.attn.qkv")); + assert!(cfg.should_skip_module("model.language_model.layers.0.linear_attn.out_proj")); + assert!(cfg.should_skip_module("model.language_model.layers.0.mlp.gate")); + assert!(cfg.should_skip_module("model.language_model.layers.0.mlp.shared_expert_gate")); + assert!(!cfg.should_skip_module("model.language_model.layers.0.mlp.gate_proj")); + assert!(!cfg.should_skip_module("model.language_model.layers.0.mlp.down_proj")); + } + + #[test] + fn test_nvfp4_compressed_tensors_detect_from_groups() { + // compressed-tensors without top-level format, detected from config_groups + let json = r#"{ + "quant_method": "compressed-tensors", + "config_groups": { + "group_0": { + "format": "nvfp4-pack-quantized", + "targets": ["Linear"], + "weights": { + "num_bits": 4, + "type": "float", + "group_size": 16 + } + } + } + }"#; + let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); + cfg.normalize_compressed_tensors(); + assert_eq!(cfg.quant_method, "nvfp4"); + assert_eq!(cfg.bits, 4); + assert_eq!(cfg.group_size, 16); + } } From f6416300093449111f9431f457f68f6e975bc76c Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 3 Apr 2026 11:16:06 +0000 Subject: [PATCH 09/14] Explicit error for unsupported MLX format --- src/utils/config.rs | 49 ++++++++++++++------------------------------- src/utils/mod.rs | 11 ++++++++++ 2 files changed, 26 
insertions(+), 34 deletions(-) diff --git a/src/utils/config.rs b/src/utils/config.rs index 0102d791..bf66af6a 100644 --- a/src/utils/config.rs +++ b/src/utils/config.rs @@ -743,37 +743,15 @@ impl QuantConfig { /// Normalizes a quantization config into a canonical quant_method string. /// /// Handles the following families: - /// 1. MLX-style: `mode` == `"nvfp4"` or `"mxfp4"` (quant_method may be empty) - /// 2. `modelopt` with `quant_algo` == `NVFP4` / `FP4` - /// 3. `compressed-tensors` with `format` containing `nvfp4` or `mxfp4` - /// 4. `compressed-tensors` detected from `config_groups` content + /// 1. `modelopt` with `quant_algo` == `NVFP4` / `FP4` + /// 2. `compressed-tensors` with `format` containing `nvfp4` or `mxfp4` + /// 3. `compressed-tensors` detected from `config_groups` content + /// + /// MLX-style quantization (`"mode": "nvfp4"/"mxfp4"`) uses an incompatible + /// packing format (U32 weights, integer scales) and is NOT supported. /// /// Also extracts group_size / bits from config_groups when present. 
pub fn normalize_compressed_tensors(&mut self) { - // MLX-style: {"mode": "nvfp4", "bits": 4, "group_size": 16} - if let Some(mode) = &self.mode { - if mode.eq_ignore_ascii_case("nvfp4") { - self.quant_method = "nvfp4".to_string(); - if self.group_size == 0 { - self.group_size = 16; - } - if self.bits == 0 { - self.bits = 4; - } - return; - } - if mode.eq_ignore_ascii_case("mxfp4") { - self.quant_method = "mxfp4".to_string(); - if self.group_size == 0 { - self.group_size = 32; - } - if self.bits == 0 { - self.bits = 4; - } - return; - } - } - // modelopt: {"quant_method": "modelopt", "quant_algo": "NVFP4"} if self.quant_method == "modelopt" { if let Some(algo) = &self.quant_algo { @@ -1376,19 +1354,22 @@ mod tests { } #[test] - fn test_nvfp4_mlx_community_config() { - // mlx-community/Qwen3.5-0.8B-nvfp4 style: mode field, no quant_method + fn test_mlx_nvfp4_not_normalized() { + // MLX-community models use an incompatible quantization format: + // U32-packed weights + integer U8 scales (NOT FP8 E4M3 block scales). + // These must NOT be normalized to our "nvfp4" quant_method. 
let json = r#"{ "group_size": 16, "bits": 4, "mode": "nvfp4" }"#; let mut cfg: QuantConfig = serde_json::from_str(json).unwrap(); - assert_eq!(cfg.quant_method, ""); // default empty before normalization + assert_eq!(cfg.quant_method, ""); cfg.normalize_compressed_tensors(); - assert_eq!(cfg.quant_method, "nvfp4"); - assert_eq!(cfg.bits, 4); - assert_eq!(cfg.group_size, 16); + assert_eq!( + cfg.quant_method, "", + "MLX mode=nvfp4 must not normalize to nvfp4" + ); } #[test] diff --git a/src/utils/mod.rs b/src/utils/mod.rs index a851b421..b12cd2bd 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -987,6 +987,17 @@ pub fn init_config_tokenizer( if let Some(qcfg) = &mut config.quantization_config { qcfg.normalize_compressed_tensors(); + if let Some(mode) = &qcfg.mode { + if mode.eq_ignore_ascii_case("nvfp4") || mode.eq_ignore_ascii_case("mxfp4") { + panic!( + "MLX-quantized models (mode=\"{}\") are not supported. \ + MLX uses an incompatible packing format (U32 weights with integer scales). \ + Please use a modelopt or compressed-tensors quantized model instead \ + (e.g. AxionML/Qwen3.5-*-NVFP4 or nvidia/*-NVFP4).", + mode + ); + } + } assert!( qcfg.quant_method == "gptq" || qcfg.quant_method == "awq" From ff5910f40b8df595cc906d74409a1491b83cdf70 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Sat, 4 Apr 2026 00:12:36 +0800 Subject: [PATCH 10/14] Fix compressed-tensors nvfp4 precision issue --- src/models/layers/linear.rs | 15 ++++++++++++++- src/models/layers/moe.rs | 24 ++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/src/models/layers/linear.rs b/src/models/layers/linear.rs index 5e3f8447..7777b827 100644 --- a/src/models/layers/linear.rs +++ b/src/models/layers/linear.rs @@ -1212,7 +1212,20 @@ impl LnNvfp4 { vb.get_with_hints_dtype((out_dim, scale_dim), "scales", shard, DType::U8)? 
}; - let global_scale = if vb.contains_tensor("weight_scale_2") { + let global_scale = if vb.contains_tensor("weight_global_scale") { + // compressed-tensors format: weight_global_scale is a divisor, invert it + let t = match vb.get_with_hints_dtype((1,), "weight_global_scale", shard, DType::F32) { + Ok(t) => t, + Err(_) => vb.get_with_hints_dtype((), "weight_global_scale", shard, DType::F32)?, + }; + let raw = t.flatten_all()?.to_vec1::()?[0]; + if raw != 0.0 { + 1.0 / raw + } else { + 1.0 + } + } else if vb.contains_tensor("weight_scale_2") { + // modelopt format: weight_scale_2 is the direct multiplier let t = match vb.get_with_hints_dtype((1,), "weight_scale_2", shard, DType::F32) { Ok(t) => t, Err(_) => vb.get_with_hints_dtype((), "weight_scale_2", shard, DType::F32)?, diff --git a/src/models/layers/moe.rs b/src/models/layers/moe.rs index ea74c572..f5686b5e 100644 --- a/src/models/layers/moe.rs +++ b/src/models/layers/moe.rs @@ -1514,7 +1514,20 @@ impl FusedMoeNvfp4 { } fn load_global_scale(vb: &candle_nn::var_builder::ShardedVarBuilder, shard: Shard) -> f32 { - if vb.contains_tensor("weight_scale_2") { + if vb.contains_tensor("weight_global_scale") { + // compressed-tensors format: weight_global_scale is a divisor, invert it + let raw = vb + .get_with_hints_dtype((1,), "weight_global_scale", shard, DType::F32) + .or_else(|_| vb.get_with_hints_dtype((), "weight_global_scale", shard, DType::F32)) + .and_then(|t| t.flatten_all()?.to_vec1::().map(|v| v[0])) + .unwrap_or(1.0); + if raw != 0.0 { + 1.0 / raw + } else { + 1.0 + } + } else if vb.contains_tensor("weight_scale_2") { + // modelopt format: weight_scale_2 is the direct multiplier vb.get_with_hints_dtype((1,), "weight_scale_2", shard, DType::F32) .or_else(|_| vb.get_with_hints_dtype((), "weight_scale_2", shard, DType::F32)) .and_then(|t| t.flatten_all()?.to_vec1::().map(|v| v[0])) @@ -1628,7 +1641,14 @@ impl FusedMoeNvfp4 { let gate_up_gscales: Vec = gate_gscales_vec .iter() .zip(up_gscales_vec.iter()) - 
.map(|(g, u)| (g + u) / 2.0) + .map(|(g, u)| { + if (g - u).abs() > f32::EPSILON { + crate::log_warn!( + "NVFP4 MoE: gate/up global scales differ ({g} vs {u}), using gate scale" + ); + } + *g + }) .collect(); let gate_up_global_scales = Tensor::from_vec(gate_up_gscales, (num_experts,), dev)?; From 57d65ac045134fb5a3e3b4f4844f875e98a79ec7 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Sat, 4 Apr 2026 01:24:05 +0800 Subject: [PATCH 11/14] Fix weight global scale sharding --- src/models/layers/linear.rs | 12 ++++++++---- src/models/layers/moe.rs | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/models/layers/linear.rs b/src/models/layers/linear.rs index 7777b827..edbcb44f 100644 --- a/src/models/layers/linear.rs +++ b/src/models/layers/linear.rs @@ -1212,11 +1212,15 @@ impl LnNvfp4 { vb.get_with_hints_dtype((out_dim, scale_dim), "scales", shard, DType::U8)? }; + let no_shard = Shard::default(); let global_scale = if vb.contains_tensor("weight_global_scale") { // compressed-tensors format: weight_global_scale is a divisor, invert it - let t = match vb.get_with_hints_dtype((1,), "weight_global_scale", shard, DType::F32) { + let t = match vb.get_with_hints_dtype((1,), "weight_global_scale", no_shard, DType::F32) + { Ok(t) => t, - Err(_) => vb.get_with_hints_dtype((), "weight_global_scale", shard, DType::F32)?, + Err(_) => { + vb.get_with_hints_dtype((), "weight_global_scale", no_shard, DType::F32)? 
+ } }; let raw = t.flatten_all()?.to_vec1::()?[0]; if raw != 0.0 { @@ -1226,9 +1230,9 @@ impl LnNvfp4 { } } else if vb.contains_tensor("weight_scale_2") { // modelopt format: weight_scale_2 is the direct multiplier - let t = match vb.get_with_hints_dtype((1,), "weight_scale_2", shard, DType::F32) { + let t = match vb.get_with_hints_dtype((1,), "weight_scale_2", no_shard, DType::F32) { Ok(t) => t, - Err(_) => vb.get_with_hints_dtype((), "weight_scale_2", shard, DType::F32)?, + Err(_) => vb.get_with_hints_dtype((), "weight_scale_2", no_shard, DType::F32)?, }; t.flatten_all()?.to_vec1::()?[0] } else { diff --git a/src/models/layers/moe.rs b/src/models/layers/moe.rs index f5686b5e..d3ac1c9d 100644 --- a/src/models/layers/moe.rs +++ b/src/models/layers/moe.rs @@ -1513,12 +1513,15 @@ impl FusedMoeNvfp4 { } } - fn load_global_scale(vb: &candle_nn::var_builder::ShardedVarBuilder, shard: Shard) -> f32 { + fn load_global_scale(vb: &candle_nn::var_builder::ShardedVarBuilder) -> f32 { + let no_shard = Shard::default(); if vb.contains_tensor("weight_global_scale") { // compressed-tensors format: weight_global_scale is a divisor, invert it let raw = vb - .get_with_hints_dtype((1,), "weight_global_scale", shard, DType::F32) - .or_else(|_| vb.get_with_hints_dtype((), "weight_global_scale", shard, DType::F32)) + .get_with_hints_dtype((1,), "weight_global_scale", no_shard, DType::F32) + .or_else(|_| { + vb.get_with_hints_dtype((), "weight_global_scale", no_shard, DType::F32) + }) .and_then(|t| t.flatten_all()?.to_vec1::().map(|v| v[0])) .unwrap_or(1.0); if raw != 0.0 { @@ -1528,8 +1531,8 @@ impl FusedMoeNvfp4 { } } else if vb.contains_tensor("weight_scale_2") { // modelopt format: weight_scale_2 is the direct multiplier - vb.get_with_hints_dtype((1,), "weight_scale_2", shard, DType::F32) - .or_else(|_| vb.get_with_hints_dtype((), "weight_scale_2", shard, DType::F32)) + vb.get_with_hints_dtype((1,), "weight_scale_2", no_shard, DType::F32) + .or_else(|_| vb.get_with_hints_dtype((), 
"weight_scale_2", no_shard, DType::F32)) .and_then(|t| t.flatten_all()?.to_vec1::().map(|v| v[0])) .unwrap_or(1.0) } else { @@ -1585,7 +1588,7 @@ impl FusedMoeNvfp4 { sh0, DType::U8, )?); - gate_gscales_vec.push(Self::load_global_scale(&gate_proj_vb, sh0)); + gate_gscales_vec.push(Self::load_global_scale(&gate_proj_vb)); let up_proj_vb = expert_vb.pp("up_proj"); let packed_name = Self::tensor_name_packed(&up_proj_vb); @@ -1603,7 +1606,7 @@ impl FusedMoeNvfp4 { sh0, DType::U8, )?); - up_gscales_vec.push(Self::load_global_scale(&up_proj_vb, sh0)); + up_gscales_vec.push(Self::load_global_scale(&up_proj_vb)); let down_proj_vb = expert_vb.pp("down_proj"); let packed_name = Self::tensor_name_packed(&down_proj_vb); @@ -1622,7 +1625,7 @@ impl FusedMoeNvfp4 { sh1, DType::U8, )?); - down_gscales_vec.push(Self::load_global_scale(&down_proj_vb, sh1)); + down_gscales_vec.push(Self::load_global_scale(&down_proj_vb)); } } _ => candle_core::bail!("FusedMoeNvfp4: GGUF loading not supported for NVFP4"), From 50664bbaa8e8cced5839c6cd2ea66292bbbec486 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Sun, 5 Apr 2026 10:53:36 +0000 Subject: [PATCH 12/14] Fix models & add test-model skill --- .cursor/skills/test-model/SKILL.md | 297 +++++++++++++++++++++++++++++ Cargo.toml | 2 +- ReadMe-CN.md | 1 + ReadMe.md | 1 + docs/test_model.md | 104 ++++++++++ src/core/runner.rs | 36 +++- src/models/layers/deltanet.rs | 115 ++++++++--- src/utils/graph.rs | 5 +- 8 files changed, 527 insertions(+), 34 deletions(-) create mode 100644 .cursor/skills/test-model/SKILL.md create mode 100644 docs/test_model.md diff --git a/.cursor/skills/test-model/SKILL.md b/.cursor/skills/test-model/SKILL.md new file mode 100644 index 00000000..f7125edf --- /dev/null +++ b/.cursor/skills/test-model/SKILL.md @@ -0,0 +1,297 @@ +--- +name: test-model +description: >- + Test LLM models served by vllm.rs for correctness, output quality, and + performance. 
Use when the user asks to test, benchmark, validate, or verify + models — either from a local folder path or HuggingFace model IDs. Supports + all vllm.rs-compatible formats: BF16, FP8, MXFP4, NVFP4, GGUF, GPTQ, AWQ, + ISQ, Dense, MoE, and Multimodal architectures. +--- + +# Test Model — Validate and Benchmark LLM Models on vllm.rs + +## Phase 0: Gather Model List + +Collect the models to test. The user provides **one or both** of: + +| Input | Format | Example | +|-------|--------|---------| +| **Local folder** | Absolute path to a directory containing model weights | `/data/models` or `/data/Qwen3.5-27B-FP8` | +| **HuggingFace IDs** | Comma-separated model IDs | `AxionML/Qwen3.5-2B-NVFP4, Qwen/Qwen3-4B` | + +### Detecting models in a local folder + +If the user provides a **parent directory** (not a single model), scan it to find testable models: + +```bash +# List subdirectories that look like model folders +for d in /data/*/; do + if [ -f "$d/config.json" ] || ls "$d"/*.gguf 2>/dev/null | head -1 >/dev/null; then + echo "$d" + fi +done +``` + +For each candidate directory, determine the model type by reading `config.json`: + +```python +import json, os, sys, glob + +def detect_model(path): + """Detect model type and quantization from a local directory.""" + config_path = os.path.join(path, "config.json") + gguf_files = glob.glob(os.path.join(path, "*.gguf")) + + info = {"path": path, "name": os.path.basename(path.rstrip("/"))} + + if gguf_files: + info["format"] = "gguf" + info["gguf_file"] = os.path.basename(gguf_files[0]) + return info + + if not os.path.exists(config_path): + return None + + cfg = json.load(open(config_path)) + arch = (cfg.get("architectures") or ["Unknown"])[0] + + supported = [ + "LlamaForCausalLM", "MistralForCausalLM", "Ministral3ForConditionalGeneration", + "Qwen2ForCausalLM", "Qwen3ForCausalLM", "Qwen3MoeForCausalLM", + "Qwen3_5ForCausalLM", "Qwen3_5MoeForCausalLM", + "Qwen3_5ForConditionalGeneration", 
"Qwen3_5MoeForConditionalGeneration", + "Qwen3NextForCausalLM", + "Qwen3VLForConditionalGeneration", + "Gemma3ForConditionalGeneration", "Gemma3ForCausalLM", + "Gemma4ForCausalLM", "Gemma4ForConditionalGeneration", + "Phi3ForCausalLM", "Phi4ForCausalLM", + "Glm4ForCausalLM", "Glm4MoeForCausalLM", + ] + if arch not in supported: + info["skip"] = f"Unsupported architecture: {arch}" + return info + + info["arch"] = arch + info["format"] = "safetensors" + + qcfg = cfg.get("quantization_config", {}) + qm = qcfg.get("quant_method", "") + if qm in ("fp8", "modelopt", "compressed-tensors"): + algo = qcfg.get("quant_algo", "") + fmt = qcfg.get("format", "") + if algo and ("nvfp4" in algo.lower() or "fp4" in algo.lower()): + info["quant"] = "nvfp4" + elif "nvfp4" in fmt.lower(): + info["quant"] = "nvfp4" + elif "mxfp4" in fmt.lower(): + info["quant"] = "mxfp4" + elif qm == "fp8": + info["quant"] = "fp8" + else: + info["quant"] = qm + elif qm in ("gptq", "awq"): + info["quant"] = qm + elif qm == "mxfp4": + info["quant"] = "mxfp4" + else: + info["quant"] = "bf16" + + return info +``` + +Present the detected models to the user as a table and confirm before proceeding. + +--- + +## Phase 1: Estimate GPU Requirements and Detect Hardware + +### Detect available GPUs + +```bash +nvidia-smi --query-gpu=index,name,memory.total,memory.free --format=csv,noheader,nounits +``` + +Parse the output to get `gpu_id`, `name`, `total_mb`, `free_mb` for each GPU. 
+ +### Estimate model memory + +Use these rough heuristics for memory estimation (single-GPU, including KV cache overhead): + +| Format | Estimate (GB) | +|--------|---------------| +| BF16 / FP16 | `params_B * 2.2` | +| FP8 | `params_B * 1.2` | +| MXFP4 / NVFP4 | `params_B * 0.8` | +| GGUF Q4_K_M | `params_B * 0.7` | +| GGUF Q3_K_M | `params_B * 0.55` | +| GGUF Q2_K | `params_B * 0.45` | +| MoE (A3B active) | Use active params for compute, total params for weight memory | + +Extract parameter count from the model name when possible (e.g. `Qwen3.5-27B` → 27B). +For MoE models with `A3B` in the name, the weight memory uses total params but fits better than dense. + +### GPU assignment rules + +1. If a model fits in one GPU's free memory, use `--d <gpu_id>` with the GPU that has the most free memory. +2. If a model needs 2 GPUs, use `--d <gpu_id_1>,<gpu_id_2>` with the two GPUs with the most free memory. +3. If a model exceeds all available GPU memory, report it as skipped and move to the next. +4. For models explicitly specified as multi-GPU by the user, respect that. + +--- + +## Phase 2: Build the Project + +Build using `run.sh` which compiles both the main binary and the runner: + +```bash +cd <project_root> +./run.sh --features cuda,nccl,flashinfer,cutlass,graph --release +``` + +Verify the build succeeds (exit code 0). The `Error: Must provide model_id or weight_path` message after build is expected — it means the binary compiled correctly. + +If the build fails, check and fix compilation errors before proceeding.
+ +--- + +## Phase 3: Create the Test Script + +Create `test_model.py` in the project root with the following capabilities: + +- Accept `--port` to specify the API server port +- Accept `--wait` for server readiness timeout +- Test both `thinking=false` and `thinking=true` modes +- Send a prompt with **at least 1024 input tokens** and request **at least 2048 output tokens** +- Measure end-to-end throughput (completion_tokens / total_time) +- Check output quality: detect excessive 3-gram repetition, too-short responses +- Report prompt tokens, completion tokens, time, throughput, and quality verdict +- Print a summary table at the end + +The prompt should be a substantive multi-topic question (algorithms, data structures, etc.) padded with context tokens to reach the 1k+ input requirement. Use `max_tokens: 2048` and `temperature: 0.7`. Set request timeout to 300s. + +For thinking mode, add `"extra_body": {"thinking": true}` to the payload. + +Quality checks: +- Response must be at least 100 characters +- 3-gram repetition: flag if any trigram appears more than `max(10, 5% of total trigrams)` times + +--- + +## Phase 4: Test Each Model + +For each model, execute this sequence: + +### Step 1: Kill previous instances + +```bash +pkill -9 -f 'vllm-rs|runner' 2>/dev/null +sleep 3 +``` + +Always wait 3 seconds after killing to ensure GPU memory is released. + +### Step 2: Start the server + +Build the server command based on model type: + +| Model source | Command pattern | +|-------------|-----------------| +| Local safetensors | `./target/release/vllm-rs --w <model_dir> --prefix-cache --ui-server --d <gpu_ids> --port 7000` | +| Local GGUF | `./target/release/vllm-rs --w <model_dir> --f <gguf_file> --prefix-cache --ui-server --d <gpu_ids> --port 7000` | +| HuggingFace ID | `./target/release/vllm-rs --m <model_id> --prefix-cache --ui-server --d <gpu_ids> --port 7000` | + +Run the server in the background with `RUST_BACKTRACE=1` for debugging.
+ +### Step 3: Wait for server readiness + +Poll `GET /v1/models` every 2-3 seconds until it returns HTTP 200, with a timeout of: +- Small models (< 10B): 120s +- Medium models (10-40B): 300s +- Large models (> 40B) or HF downloads: 600s + +### Step 4: Run the test script + +```bash +python3 test_model.py --port 7000 +``` + +### Step 5: Handle failures + +If the server fails to start or the test script returns errors: + +1. **Check server logs** for panics or errors +2. **Common issues and fixes**: + +| Error | Likely cause | Fix | +|-------|-------------|-----| +| `MLX-quantized models` panic | Incompatible NVFP4 packing | Skip model; use modelopt/compressed-tensors variant | +| `Unable to load ... projection weights` | DeltaNet weights not detected as quantized | Check `is_weight_quantized` in `deltanet.rs` | +| `CUDA out of memory` | Model too large for GPU | Try with more GPUs or skip | +| Server starts but API times out | Model too slow on prefill | Increase test timeout to 600s | +| `failed to fill whole buffer` | Runner process crashed | Check runner logs, enable `RUST_BACKTRACE=full` | + +3. **Debug with unwrap**: If the model crashes during inference, temporarily change `guard.step()` to `guard.step().unwrap()` in `src/core/engine.rs` to get a full stack trace. **Revert after debugging.** + +4. If a model cannot be fixed, record the failure reason and continue to the next model. + +--- + +## Phase 5: Summarize Results + +After all models are tested, produce a summary table: + +``` +## Test Results + +| # | Model | Format | GPUs | thinking=false | thinking=true | Quality | +|---|-------|--------|------|----------------|---------------|---------| +| 1 | Qwen3.5-27B-FP8 | FP8 | 1 | 1342 in / 2048 out, 42.2 tok/s | 1342 in / 2048 out, 42.2 tok/s | OK | +| 2 | ... | ... | ... | ... | ... | ... 
| + +### Notes +- Model X: SKIPPED — reason +- Model Y: FAILED — error description +``` + +Include for each model: +- Model name and quantization format +- Number of GPUs used +- Input/output token counts and throughput for both thinking modes +- Quality verdict (OK / ISSUES / FAILED / SKIPPED) + +--- + +## Quick Reference + +### Key files + +| File | Purpose | +|------|---------| +| `test_model.py` | OpenAI API test script (created by this skill) | +| `src/core/engine.rs` | Engine loop; `guard.step()` for debug | +| `src/models/layers/deltanet.rs` | DeltaNet layer; quantization detection | +| `src/models/layers/linear.rs` | Linear layer loaders (FP8, MXFP4, NVFP4) | +| `run.sh` | Build script (compiles both vllm-rs and runner) | + +### Build features + +| Feature set | When to use | +|-------------|-------------| +| `cuda,nccl,flashinfer,cutlass,graph` | SM80+ (Ampere/Ada/Hopper), recommended | +| `cuda,nccl,flashattn,cutlass,graph` | Alternative to flashinfer | +| `cuda,nccl,graph` | V100 (SM70), no flash attention | +| `metal` | macOS Apple Silicon | + +### Server flags + +| Flag | Purpose | +|------|---------| +| `--w <model_dir>` | Local model weight directory | +| `--f <gguf_file>` | GGUF filename within the weight directory | +| `--m <model_id>` | HuggingFace model ID (auto-downloads) | +| `--d <gpu_ids>` | GPU device IDs (e.g.
`0` or `0,1`) | +| `--port <port>` | API server port | +| `--prefix-cache` | Enable automatic prefix caching | +| `--ui-server` | Enable built-in ChatGPT-like web UI | +| `--isq <type>` | In-situ quantization (q2k, q3k, q4k, q5k, q6k, q8_0) | +| `--fp8-kvcache` | Enable FP8 KV cache (no flashinfer/flashattn) | diff --git a/Cargo.toml b/Cargo.toml index a780f2a6..d52a97dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.5.0", rev = "f21d557" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.5.0", rev = "0c1e716" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" diff --git a/ReadMe-CN.md b/ReadMe-CN.md index 68f892f1..392ccc87 100644 --- a/ReadMe-CN.md +++ b/ReadMe-CN.md @@ -95,6 +95,7 @@ - [Rust库](docs/rust_crate.md) - [Tokenize/Detokenize](docs/tokenize.md) - [性能测试](docs/performance.md) +- [模型测试 (AI辅助)](docs/test_model.md) ## 📘 使用方法(Python) ### 📦 使用 pip 安装 diff --git a/ReadMe.md b/ReadMe.md index 3a960e7a..f0d1acc6 100644 --- a/ReadMe.md +++ b/ReadMe.md @@ -95,6 +95,7 @@ All models support hardware FP8 KV-cache acceleration (requires SM90+ and disabl - [Rust crate](docs/rust_crate.md) - [Tokenize/Detokenize](docs/tokenize.md) - [Performance Benchmarks](docs/performance.md) +- [Model Testing (AI-Assisted)](docs/test_model.md) ## 📘 Usage in Python diff --git a/docs/test_model.md b/docs/test_model.md new file mode 100644 index 00000000..4fdc2881 --- /dev/null +++ b/docs/test_model.md @@ -0,0 +1,104 @@ +# Model Testing (AI-Assisted) + +vLLM.rs ships with a **Cursor Agent Skill** that automates the process of testing LLM models for correctness, output quality, and inference performance. It handles GPU detection, model loading, API testing, and result summarization. 
+ +## Prerequisites + +- [Cursor IDE](https://cursor.sh/) with Agent mode enabled (for other Agents, mention the skill file manually) +- The vLLM.rs repository cloned locally +- One or more NVIDIA GPUs with sufficient memory for the target models +- Python 3 with the requests library installed + +## How It Works + +The skill lives at .cursor/skills/test-model/SKILL.md and is **automatically activated** when you ask the agent to test, benchmark, or validate models. It walks the agent through five phases: + +| Phase | What happens | +|-------|-------------| +| **0 - Gather models** | Collects model list from a local folder scan or user-provided HuggingFace IDs. Auto-detects supported architectures and quantization formats. | +| **1 - Estimate resources** | Queries nvidia-smi for available GPUs and free memory. Estimates per-model memory requirements and assigns GPUs accordingly. | +| **2 - Build** | Compiles the project with run.sh --features cuda,nccl,flashinfer,cutlass,graph --release (builds both vllm-rs and the runner binary). | +| **3 - Create test script** | Writes test_model.py, an OpenAI-compatible API test that sends 1k+ input tokens, requests 1k+ output tokens, and measures throughput and quality. | +| **4 - Test each model** | Iteratively starts the server, waits for readiness, runs the test script (with and without reasoning/thinking), records results, kills the server, and moves to the next model. | +| **5 - Summarize** | Produces a markdown table with per-model results: format, GPUs, throughput, and quality verdict. | + +## Quick Start + +Open the project in Cursor and ask the agent: + +``` +Test all models in /data/ +``` + +The agent will scan the directory, identify compatible models, and test each one automatically. 
+ +You can also specify individual models: + +``` +Test models AxionML/Qwen3.5-2B-NVFP4, Qwen/Qwen3-4B +``` + +Or mix local paths and HuggingFace IDs: + +``` +Test /data/Qwen3.5-27B-FP8 and AxionML/Qwen3.5-9B-NVFP4 +``` + +## What Gets Tested + +For each model, the skill tests: + +| Aspect | Details | +|--------|---------| +| **Loading** | Server starts and model is accessible via /v1/models | +| **Inference (no thinking)** | 1k+ input tokens, 2k output tokens, thinking=false | +| **Inference (with thinking)** | Same prompt with thinking=true / reasoning enabled | +| **Output quality** | Coherence check, repetition detection (3-gram analysis) | +| **Performance** | End-to-end throughput in tokens/second | + +## Supported Model Formats + +| Format | Detection method | +|--------|-----------------| +| **BF16 / FP16** (safetensors) | No quantization_config in config.json | +| **FP8** (blockwise) | quant_method: fp8 | +| **MXFP4** | quant_method: mxfp4 or format contains mxfp4 | +| **NVFP4** (modelopt) | quant_method: modelopt with quant_algo: NVFP4 | +| **NVFP4** (compressed-tensors) | format contains nvfp4 | +| **GGUF** | .gguf file present in directory | +| **GPTQ / AWQ** | quant_method: gptq or awq | + +## GPU Assignment + +The agent automatically assigns GPUs based on available memory: + +- **Single GPU**: Models that fit in one GPU's free memory +- **Multi-GPU**: Large models that require 2+ GPUs (uses --d 0,1 etc.) +- **Skip**: Models that exceed all available GPU memory + +## Debugging Failures + +If a model fails to load or produce output, the skill guides the agent through: + +1. Checking server logs for panics or weight-loading errors +2. Enabling RUST_BACKTRACE=1 for stack traces +3. Temporarily using guard.step().unwrap() in engine.rs for crash debugging +4. 
Identifying common issues (MLX format incompatibility, quantization detection, OOM) + +## Example Output (Decoding performance) + +| # | Model | Format | GPUs | thinking=false | thinking=true | Quality | +|---|-------|--------|------|----------------|---------------|---------| +| 1 | Qwen3-4B-FP8 | FP8 | 1 | 1329 in / 2048 out, 168 tok/s | 1329 in / 2048 out, 169 tok/s | OK | +| 2 | Qwen3.5-27B | BF16 | 1 | 1342 in / 2048 out, 28 tok/s | 1342 in / 2048 out, 27 tok/s | OK | +| 3 | Qwen3.5-35B-A3B-GGUF | Q3_K_M | 1 | 1342 in / 2048 out, 95 tok/s | 1342 in / 2048 out, 98 tok/s | OK | + +## File Reference + +| File | Role | +|------|------| +| .cursor/skills/test-model/SKILL.md | The skill definition (read by the AI agent) | +| test_model.py | OpenAI API test script (created by the skill) | +| run.sh | Build script for both vllm-rs and runner binaries | +| src/core/engine.rs | Engine loop, guard.step() location for debug | +| src/models/layers/deltanet.rs | DeltaNet layer with per-weight quantization detection | diff --git a/src/core/runner.rs b/src/core/runner.rs index e3accdd8..113428cb 100644 --- a/src/core/runner.rs +++ b/src/core/runner.rs @@ -1103,6 +1103,32 @@ impl ModelRunner { Tensor::from_vec(batch_indices_vec, (batch_indices_len,), &self.device)?; let positions = Tensor::from_vec(positions_vec, (positions_len,), &self.device)?; + #[cfg(feature = "flashinfer")] + let cu_seqlens_q_host_u32: Vec = + cu_seqlens_q_vec.iter().map(|&x| x as u32).collect(); + + #[cfg(feature = "flashinfer")] + let prefill_plan_info = if let Some(params) = self.flashinfer_kv_params { + Some(attention_rs::flashinfer::prefill_plan( + &self.device, + &cu_seqlens_q_host_u32, + &indptr_host, + &kv_len_arr_host, + *cu_seqlens_q_vec.last().unwrap() as u32, + last_len_host.len(), + params.num_qo_heads, + params.num_kv_heads, + params.head_dim, + params.page_size, + params.out_dtype, + None, + )?) 
+ } else { + None + }; + #[cfg(not(feature = "flashinfer"))] + let prefill_plan_info = None; + Some(FlashInferMetadata { indptr, indptr_host, @@ -1110,12 +1136,14 @@ impl ModelRunner { last_len, last_len_host: Some(last_len_host), kv_len_arr_host: Some(kv_len_arr_host), - cu_seqlens_q_host: Some(cu_seqlens_q_vec.iter().map(|&x| x as u32).collect()), total_num_rows: Some(*cu_seqlens_q_vec.last().unwrap() as u32), batch_indices: Some(batch_indices), positions: Some(positions), use_cuda_graph: false, decode_plan_info: None, + prefill_plan_info, + mla_decode_plan_info: None, + mla_prefill_plan_info: None, }) } else { None @@ -1127,6 +1155,7 @@ impl ModelRunner { let input_metadata = InputMetadata { is_prefill: true, + is_mla: false, sequence_ids, mamba_slot_mapping, slot_mapping, @@ -1250,12 +1279,14 @@ impl ModelRunner { last_len, last_len_host: Some(last_len_host), kv_len_arr_host: Some(kv_len_arr_host), - cu_seqlens_q_host: None, total_num_rows: None, batch_indices: None, positions: None, use_cuda_graph, decode_plan_info: None, + prefill_plan_info: None, + mla_decode_plan_info: None, + mla_prefill_plan_info: None, }) } else { None @@ -1271,6 +1302,7 @@ impl ModelRunner { let input_metadata = InputMetadata { is_prefill: false, + is_mla: false, sequence_ids, mamba_slot_mapping, slot_mapping, diff --git a/src/models/layers/deltanet.rs b/src/models/layers/deltanet.rs index 8a74859f..c2329fcd 100644 --- a/src/models/layers/deltanet.rs +++ b/src/models/layers/deltanet.rs @@ -64,6 +64,42 @@ pub struct GatedDeltaNet { } impl GatedDeltaNet { + /// Check if a weight at the given VarBuilder path actually carries quantized data. + /// Returns false when the weight is stored in its original dtype (BF16/F16/F32) + /// even though the model-level quantization config is set. 
+ fn is_weight_quantized(vb: &VarBuilderX, quant_method: &str) -> bool { + if vb.is_qvar_builder() { + return false; + } + match quant_method { + "fp8" => vb.has_key("weight_scale") || vb.has_key("weight_scale_inv"), + "mxfp4" => vb.has_key("weight_packed") || vb.has_key("blocks"), + "nvfp4" => { + let has_packed = vb.has_key("weight_packed") || vb.has_key("blocks"); + let has_scale = vb.has_key("weight_scale") || vb.has_key("scales"); + let has_modelopt = vb.has_key("weight_scale_2") || vb.has_key("input_scale"); + (has_packed && has_scale) || (has_modelopt && has_scale) + } + _ => true, + } + } + + /// Resolve effective quantization config for a specific weight. + /// If the weight is not actually quantized, returns (None, None) so + /// the loader falls back to the standard unquantized path. + fn resolve_quant_for_weight( + vb: &VarBuilderX, + quantization_config: &Option, + quant: &Option, + ) -> (Option, Option) { + if let Some(cfg) = quantization_config { + if Self::is_weight_quantized(vb, &cfg.quant_method) { + return (quantization_config.clone(), quant.clone()); + } + } + (None, None) + } + fn load_projection( vb: &VarBuilderX, hidden_size: usize, @@ -93,30 +129,35 @@ impl GatedDeltaNet { // Qwen3Next format: fused qkvz + fused ba let projection_size_qkvz = key_dim_global * 2 + value_dim_global * 2; let projection_size_ba = num_v_heads_global * 2; + + let vb_qkvz = vb.pp("in_proj_qkvz"); + let (qc_qkvz, q_qkvz) = + Self::resolve_quant_for_weight(&vb_qkvz, &quantization_config, &quant); let fused_qkvz = TensorParallelColumnLinear::load_with_hints( hidden_size, projection_size_qkvz, false, - vb.pp("in_proj_qkvz"), + vb_qkvz, comm.clone(), - &quantization_config, - &quant, + &qc_qkvz, + &q_qkvz, dtype, ); + let vb_ba = vb.pp("in_proj_ba"); + let (qc_ba, q_ba) = Self::resolve_quant_for_weight(&vb_ba, &quantization_config, &quant); let fused_ba = TensorParallelColumnLinear::load_with_hints( hidden_size, projection_size_ba, false, - vb.pp("in_proj_ba"), + vb_ba, 
comm.clone(), - &quantization_config, - &quant, + &qc_ba, + &q_ba, dtype, ); if let (Ok(in_proj_qkvz), Ok(in_proj_ba)) = (fused_qkvz, fused_ba) { - // Qwen3 Next projection return Ok(GdnProjection::FusedQkvzBa { in_proj_qkvz, in_proj_ba, @@ -142,14 +183,16 @@ impl GatedDeltaNet { }, ) } else { + let vb_z = vb.pp(projection_key_map["in_proj_z"]); + let (qc_z, q_z) = Self::resolve_quant_for_weight(&vb_z, &quantization_config, &quant); TensorParallelColumnLinear::load_with_hints( hidden_size, value_dim_global, false, - vb.pp(projection_key_map["in_proj_z"]), + vb_z, comm.clone(), - &quantization_config, - &quant, + &qc_z, + &q_z, dtype, ) }; @@ -165,14 +208,16 @@ impl GatedDeltaNet { |w| undo_tiled_v_heads_first_dim(&w, num_k_heads_global, num_v_heads_global, 1), ) } else { + let vb_b = vb.pp(projection_key_map["in_proj_b"]); + let (qc_b, q_b) = Self::resolve_quant_for_weight(&vb_b, &quantization_config, &quant); TensorParallelColumnLinear::load_with_hints( hidden_size, num_v_heads_global, false, - vb.pp(projection_key_map["in_proj_b"]), + vb_b, comm.clone(), - &quantization_config, - &quant, + &qc_b, + &q_b, dtype, ) }; @@ -187,14 +232,16 @@ impl GatedDeltaNet { |w| undo_tiled_v_heads_first_dim(&w, num_k_heads_global, num_v_heads_global, 1), ) } else { + let vb_a = vb.pp(projection_key_map["in_proj_a"]); + let (qc_a, q_a) = Self::resolve_quant_for_weight(&vb_a, &quantization_config, &quant); TensorParallelColumnLinear::load_with_hints( hidden_size, num_v_heads_global, false, - vb.pp(projection_key_map["in_proj_a"]), + vb_a, comm.clone(), - &quantization_config, - &quant, + &qc_a, + &q_a, dtype, ) }; @@ -218,16 +265,19 @@ impl GatedDeltaNet { dtype, ) } else { + let vb_qkv = vb.pp(projection_key_map["in_proj_qkv"]); + let (qc_qkv, q_qkv) = + Self::resolve_quant_for_weight(&vb_qkv, &quantization_config, &quant); MergedParallelColumnLinear::load_merged_chunks( hidden_size, key_dim_global * 2 + value_dim_global, 0, vec![key_dim_global, key_dim_global, value_dim_global], 
None, - vb.pp(projection_key_map["in_proj_qkv"]), + vb_qkv, comm.clone(), - &quantization_config, - &quant, + &qc_qkv, + &q_qkv, dtype, ) }; @@ -242,7 +292,7 @@ impl GatedDeltaNet { }); } Err(err) => { - if is_quantized { + if is_quantized && !vb.is_qvar_builder() { candle_core::bail!( "Unable to load TP-safe quantized Qwen3.5 split in_proj_qkv: {}", err @@ -273,14 +323,17 @@ impl GatedDeltaNet { }, ) } else { + let vb_qkv = vb.pp(projection_key_map["in_proj_qkv"]); + let (qc_qkv, q_qkv) = + Self::resolve_quant_for_weight(&vb_qkv, &quantization_config, &quant); TensorParallelColumnLinear::load_with_hints( hidden_size, key_dim_global * 2 + value_dim_global, false, - vb.pp(projection_key_map["in_proj_qkv"]), + vb_qkv, comm.clone(), - &quantization_config, - &quant, + &qc_qkv, + &q_qkv, dtype, ) }; @@ -575,17 +628,19 @@ impl GatedDeltaNet { }, )? } else { + let vb_out = vb.pp(gdn_key_map["out_proj"]); + let (qc_out, q_out) = if is_quantized { + Self::resolve_quant_for_weight(&vb_out, &config.quantization_config, &config.quant) + } else { + (None, None) + }; TensorParallelRowLinear::load_with_hints( value_dim_global, hidden_size, - vb.pp(gdn_key_map["out_proj"]), + vb_out, comm.clone(), - if is_quantized { - &config.quantization_config - } else { - &None - }, - if is_quantized { &config.quant } else { &None }, + &qc_out, + &q_out, dtype, )? 
}; diff --git a/src/utils/graph.rs b/src/utils/graph.rs index 764578b7..53e1d3a0 100644 --- a/src/utils/graph.rs +++ b/src/utils/graph.rs @@ -499,12 +499,14 @@ impl GraphCapturer { last_len: flashinfer_last_len.narrow(0, 0, bs)?, last_len_host: Some(last_len_host[..bs].to_vec()), kv_len_arr_host, - cu_seqlens_q_host: None, total_num_rows: None, batch_indices: None, positions: None, use_cuda_graph: true, decode_plan_info, + prefill_plan_info: None, + mla_decode_plan_info: None, + mla_prefill_plan_info: None, }) }; #[cfg(not(feature = "flashinfer"))] @@ -512,6 +514,7 @@ impl GraphCapturer { let input_metadata = InputMetadata { is_prefill: false, + is_mla: false, sequence_ids: None, mamba_slot_mapping: Some(mamba_slot_mapping.narrow(0, 0, bs)?), slot_mapping: slot_mapping.narrow(0, 0, bs)?, From 89143798db687eedfe6a590f83b0765aaa928e6e Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Sun, 5 Apr 2026 12:20:54 +0000 Subject: [PATCH 13/14] Fix precision issue for mxfp4 models. --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index d52a97dd..fda4f241 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.5.0", rev = "0c1e716" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.5.0", rev = "db97c3e" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31" From 474e912dfdaf0c7890077a56253baee64d20ee6a Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Mon, 6 Apr 2026 14:27:21 +0000 Subject: [PATCH 14/14] Update dependency --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fda4f241..bc7b4ae5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "vllm-rs" -version = 
"0.9.16" +version = "0.10.0" edition = "2021" default-run = "vllm-rs" description = "A minimal, high-performance large language model (LLM) inference engine implementing vLLM in Rust." @@ -45,7 +45,7 @@ ahash = "0.8.11" reedline = "0.40.0" pyo3 = { version = "0.25.1", features = ["extension-module", "abi3-py38"], optional = true } parking_lot = "0.12.4" -attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.5.0", rev = "db97c3e" } +attention-rs = { git = "https://github.com/guoqingbao/attention.rs.git", version="0.5.0", rev = "935891b" } once_cell = "1.21.3" tqdm = "0.8.0" futures = "0.3.31"