🚧 [llm][npu][quant] Add W4A4 MXFP4 quantization support for Qwen3 Dense on Ascend NPU#23795
Open
TallMessiWu wants to merge 35 commits into
Open
🚧 [llm][npu][quant] Add W4A4 MXFP4 quantization support for Qwen3 Dense on Ascend NPU#23795TallMessiWu wants to merge 35 commits into
TallMessiWu wants to merge 35 commits into
Conversation
…th B) Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU, supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8 checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant + npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2 and other diffusion models on Ascend NPU. Also adds explicit quantization field to diffusion ServerArgs so online quantization can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call - Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul) - Simplify output shape restoration logic Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
按 reviewer 建议重构架构分层: - 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights) - 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用 - 改进架构分层,符合现有 NPU INT8 方法模式
Fix weight loading for msmodelslim pre-quantized MXFP8 weights: - Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors) - Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export) - Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
…weight processing.
- Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode) - Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format - Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape - Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv) - Improve flatten-to-2D comment to explain NPU kernel requirement
…rate PR Revert LLM-side MXFP8 changes to split into a separate PR. This branch now only contains Wan2.2 Diffusion MXFP8 changes. Reverted files: - fp8.py: removed MXFP8LinearAscendMethod class and NPU branch - mxfp8_method_npu.py: deleted (NPU MXFP8 linear method) - test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test) LLM MXFP8 code preserved on junlin_llm branch.
Conflict resolution: upstream refactored transformer loader class methods into standalone functions in transformer_load_utils.py. Preserved server_args.quantization priority (mxfp8/mxfp4/modelslim) in _resolve_quant_config.
…t_config Upstream refactored class methods to standalone functions in transformer_load_utils.py but dropped the server_args.quantization priority path. Re-add it so mxfp8/mxfp4/modelslim still work.
1. Add NPUMXFP8LinearMethod to linear_method_npu.py (online quant)\n2. Add NPU dispatch branch in Fp8Config.get_quant_method\n3. Fix get_min_capability to return 0 on NPU\n4. Add ModelSlimMXFP8Scheme for offline pre-quantized MXFP8 weights\n5. Register W8A8_MXFP8 scheme in modelslim dispatcher
…nit__.py Move ModelSlimLinearScheme import before ModelSlimMXFP8Scheme to resolve circular dependency. modelslim_mxfp8.py imports ModelSlimLinearScheme from the schemes package, which failed when __init__.py hadn't yet defined it. The fix ensures base class is available in module namespace before any subclass imports attempt to reference it. Issue: ImportError: cannot import name 'ModelSlimLinearScheme' from partially initialized module 'sglang.srt.layers.quantization.modelslim.schemes'
… LLM Implements NPUMXFP4W4A8LinearMethod for Ascend NPU online quantization of dense Qwen3/3.5 LLM models triggered via --quantization mxfp4_npu. Weight flow (process_weights_after_loading): BF16/FP16 → npu_dynamic_dual_level_mx_quant → FP4 (NZ format) l0_scale [in/512, out] (FP32) + l1_scale (FP8_E8M0) Inference flow (apply): activation → npu_dynamic_dual_level_mx_quant → FP4 → npu_dual_level_quant_matmul (W4A4 compute with dual-level scales) Config registered as 'mxfp4_npu' in QUANTIZATION_METHODS. Hardware: requires Ascend 950 (DualLevelQuantBatchMatmul not on A2/A3).
1. Add ModelSlimW4A8Int8 + NPUW4A8DynamicLinearMethod for W4A8_DYNAMIC dispatch 2. Add hardware warning + try/except guard in NPUMXFP4W4A8LinearMethod (Ascend 950 required) 3. Add MoE unquantized fallback warning in NPUMxfp4Config
Add `--quantization mxfp4w4a4_npu` support for Qwen3/3.5 dense models on Ascend NPU. Uses single-level npu_dynamic_mx_quant(float4_e2m1fn) for both weights and activations, analogous to the MXFP8 path but with FP4 dtype. - linear_method_npu.py: add NPUSingleLevelMXFP4LinearMethod (W4A4 single-level) - npu_mxfp4_w4a4.py: new NPUMxfp4W4A4Config registered as mxfp4w4a4_npu - __init__.py: register mxfp4w4a4_npu in BASE_QUANTIZATION_METHODS
Issue 2: delete erroneous W4A8_MXFP → ModelSlimMXFP8Scheme dispatch in _get_scheme_from_parts; unsupported quant_type now raises NotImplementedError as intended. Issue 4 (w4a4 side): add logger.warning before UnquantizedFusedMoEMethod fallback in NPUMxfp4W4A4Config.
ModelSlimMXFP4Scheme: loads msmodelslim pre-quantized FP4 weights (fp8 container), casts to float4_e2m1fn_x2, runs single-level MXFP4 matmul via npu_quant_matmul(group_sizes=[1,1,32]). Mirrors NPUSingleLevelMXFP4LinearMethod but offline.
1. Add W4A8_MXFP offline scheme (ModelSlimMXFP4W4A8Scheme) 2. Rename mxfp4_npu -> mxfp4_w4a8_npu (online W4A8 key) 3. Rename mxfp4w4a4_npu -> mxfp4_w4a4_npu (align naming style)
1. Use float4_e2m1fn_x2 (not float4_e2m1fn which doesn't exist) as dst_type 2. Add round_mode="round" to npu_dynamic_mx_quant calls 3. Pass x1_dtype/x2_dtype=float4_e2m1fn_x2 to npu_quant_matmul 4. Fix transpose: use .data= assignment without .contiguous() to preserve block-scale mapping
1. Replace float4_e2m1fn (non-existent) with float4_e2m1fn_x2 as dst_type 2. Add round_mode="round" to npu_dynamic_mx_quant 3. Pass x1_dtype/x2_dtype=float4_e2m1fn_x2 to npu_quant_matmul
npu_quant_matmul with float4_e2m1fn_x2 requires x2Scale (weight_scale) to be 3D. Reshape [out, in/32] -> [out, in/64, 2] then transpose to [in/64, out, 2]. Matches vllm-ascend w4a4_mxfp4 process_weights_after_loading pattern.
Contributor
|
Warning You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again! |
2 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds W4A4 single-level MXFP4 quantization support for Qwen3 dense LLM models on Ascend NPU. It continues the NPU quantization work tracked in issue #21584.
Two modes are supported:
Online quantization (
--quantization mxfp4_w4a4_npu)NPUMxfp4W4A4Config(layers/quantization/npu_mxfp4_w4a4.py) dispatches toNPUSingleLevelMXFP4LinearMethod.npu_dynamic_mx_quant(dst_type=float4_e2m1fn_x2, round_mode="round"): produces packeduint8weights (shape[out, in//2]) and FP8_E8M0 per-block scales.process_weights_after_loading(no.contiguous()— see Implementation Notes).npu_quant_matmulwithx1_dtype=x2_dtype=float4_e2m1fn_x2andgroup_sizes=[1, 1, 32].Offline quantization (msmodelslim pre-quantized weights,
--quantization modelslim)ModelSlimMXFP4Scheme(modelslim/schemes/modelslim_mxfp4.py) for theW4A4_MXFP4scheme type.float8_e4m3fn(one FP4 value per byte, shape[out, in]) and scales asuint8E8M0 (shape[out, in/32]).float8_e4m3fntofloat4_e2m1fn_x2(packing 2 FP4 values per byte, shape[out, in//2]), then transposed to[in//2, out]; scales are reshaped[out, in/32]→[out, in/64, 2](3D, required bynpu_quant_matmul) then transposed to[in/64, out, 2].npu_dynamic_mx_quantand the matmul runs vianpu_quant_matmulwithgroup_sizes=[1,1,32].Key NPU APIs used
torch_npu.npu_dynamic_mx_quant(x, dst_type=float4_e2m1fn_x2, round_mode="round")torch_npu.npu_dtype_cast(weight, float4_e2m1fn_x2)float4_e2m1fn_x2format (offline)torch_npu.npu_quant_matmul(..., x1_dtype=float4_e2m1fn_x2, x2_dtype=float4_e2m1fn_x2, group_sizes=[1,1,32])Files Changed
New files
srt/layers/quantization/npu_mxfp4_w4a4.pyNPUMxfp4W4A4Configfor online W4A4 MXFP4 (--quantization mxfp4_w4a4_npu)srt/layers/quantization/modelslim/schemes/modelslim_mxfp4.pyModelSlimMXFP4SchemeforW4A4_MXFP4msmodelslim checkpointsModified
srt/hardware_backend/npu/quantization/linear_method_npu.pyNPUSingleLevelMXFP4LinearMethod(online single-level MXFP4 weight quantisation + inference)srt/layers/quantization/__init__.pyNPUMxfp4W4A4Configunder"mxfp4_w4a4_npu"srt/layers/quantization/modelslim/modelslim.pyW4A4_MXFP4branch →ModelSlimMXFP4Schemein_get_scheme_from_parts()srt/layers/quantization/modelslim/schemes/__init__.pyModelSlimMXFP4SchemeImplementation Notes
Offline checkpoint weight format
The msmodelslim
W4A4_MXFP4checkpoint stores weights asfloat8_e4m3fn(one FP4 value per byte, using the fp8 dtype as a container) rather than packeduint8. Inprocess_weights_after_loading,npu_dtype_cast(..., float4_e2m1fn_x2)re-packs them into 2-per-byte format, halving the last dimension ([out, in]→[out, in//2]).Weight scale must be 3D for
npu_quant_matmulwithfloat4_e2m1fn_x2Unlike the MXFP8 path where
x2Scaleis 2D,npu_quant_matmulwithx2_dtype=float4_e2m1fn_x2requiresx2Scaleto be 3D. The scale is reshaped from[out, in/32]to[out, in/64, 2](pairing consecutive E8M0 values) before transposing to[in/64, out, 2]..contiguous()is not called after transposeConsistent with the W8A8 and W4A8 PRs:
NPUSingleLevelMXFP4LinearMethod): transpose is applied via.data =in-place assignment without.contiguous().npu_quant_matmulreads strides directly; calling.contiguous()would physically reorder the quantized data and break block-scale mapping.ModelSlimMXFP4Scheme): same approach —.data =assignment preserves the non-contiguous view.Performance Comparison Report
Related Issues
Closes part of #21584 (MXFP8/MXFP4 support on Ascend NPU for Qwen3 Dense LLM).
Depends on #22352 (W8A8 MXFP8) and #23650 (W4A8 MXFP4) PR.