Skip to content

🚧 [llm][npu][quant] Add W4A4 MXFP4 quantization support for Qwen3 Dense on Ascend NPU#23795

Open
TallMessiWu wants to merge 35 commits into
sgl-project:mainfrom
TallMessiWu:junlin_qwen3_dense_w4a4
Open

🚧 [llm][npu][quant] Add W4A4 MXFP4 quantization support for Qwen3 Dense on Ascend NPU#23795
TallMessiWu wants to merge 35 commits into
sgl-project:mainfrom
TallMessiWu:junlin_qwen3_dense_w4a4

Conversation

@TallMessiWu
Copy link
Copy Markdown
Contributor

Summary

Dependency: This PR depends on #22352 (W8A8 MXFP8) and #23650 (W4A8 MXFP4) PR, and should be merged after both land. It builds on the same NPU quantization infrastructure (_NPULinearMethodBase, ModelSlimConfig dispatch, etc.).

This PR adds W4A4 single-level MXFP4 quantization support for Qwen3 dense LLM models on Ascend NPU. It continues the NPU quantization work tracked in issue #21584.

Two modes are supported:

Online quantization (--quantization mxfp4_w4a4_npu)

  • New NPUMxfp4W4A4Config (layers/quantization/npu_mxfp4_w4a4.py) dispatches to NPUSingleLevelMXFP4LinearMethod.
  • At load time, FP16/BF16 weights are quantised online to single-level MXFP4 via npu_dynamic_mx_quant(dst_type=float4_e2m1fn_x2, round_mode="round"): produces packed uint8 weights (shape [out, in//2]) and FP8_E8M0 per-block scales.
  • Weights and scales are pre-transposed in process_weights_after_loading (no .contiguous() — see Implementation Notes).
  • At inference, activations are dynamically quantised with the same single-level API and the matmul is executed by npu_quant_matmul with x1_dtype=x2_dtype=float4_e2m1fn_x2 and group_sizes=[1, 1, 32].

Offline quantization (msmodelslim pre-quantized weights, --quantization modelslim)

  • Adds ModelSlimMXFP4Scheme (modelslim/schemes/modelslim_mxfp4.py) for the W4A4_MXFP4 scheme type.
  • The msmodelslim checkpoint stores weights as float8_e4m3fn (one FP4 value per byte, shape [out, in]) and scales as uint8 E8M0 (shape [out, in/32]).
  • At load time: weights are cast from float8_e4m3fn to float4_e2m1fn_x2 (packing 2 FP4 values per byte, shape [out, in//2]), then transposed to [in//2, out]; scales are reshaped [out, in/32][out, in/64, 2] (3D, required by npu_quant_matmul) then transposed to [in/64, out, 2].
  • At inference, activations are dynamically quantised to MXFP4 via npu_dynamic_mx_quant and the matmul runs via npu_quant_matmul with group_sizes=[1,1,32].

Key NPU APIs used

API Purpose
torch_npu.npu_dynamic_mx_quant(x, dst_type=float4_e2m1fn_x2, round_mode="round") Single-level MXFP4 quantisation of weights (online) and activations (inference)
torch_npu.npu_dtype_cast(weight, float4_e2m1fn_x2) Cast fp8-container FP4 weights to packed float4_e2m1fn_x2 format (offline)
torch_npu.npu_quant_matmul(..., x1_dtype=float4_e2m1fn_x2, x2_dtype=float4_e2m1fn_x2, group_sizes=[1,1,32]) Single-level MXFP4 quantised matmul

Files Changed

New files

File Change
srt/layers/quantization/npu_mxfp4_w4a4.py NewNPUMxfp4W4A4Config for online W4A4 MXFP4 (--quantization mxfp4_w4a4_npu)
srt/layers/quantization/modelslim/schemes/modelslim_mxfp4.py New — offline ModelSlimMXFP4Scheme for W4A4_MXFP4 msmodelslim checkpoints

Modified

File Change
srt/hardware_backend/npu/quantization/linear_method_npu.py Add NPUSingleLevelMXFP4LinearMethod (online single-level MXFP4 weight quantisation + inference)
srt/layers/quantization/__init__.py Register NPUMxfp4W4A4Config under "mxfp4_w4a4_npu"
srt/layers/quantization/modelslim/modelslim.py Add W4A4_MXFP4 branch → ModelSlimMXFP4Scheme in _get_scheme_from_parts()
srt/layers/quantization/modelslim/schemes/__init__.py Export ModelSlimMXFP4Scheme

Implementation Notes

Offline checkpoint weight format

The msmodelslim W4A4_MXFP4 checkpoint stores weights as float8_e4m3fn (one FP4 value per byte, using the fp8 dtype as a container) rather than packed uint8. In process_weights_after_loading, npu_dtype_cast(..., float4_e2m1fn_x2) re-packs them into 2-per-byte format, halving the last dimension ([out, in][out, in//2]).

Weight scale must be 3D for npu_quant_matmul with float4_e2m1fn_x2

Unlike the MXFP8 path where x2Scale is 2D, npu_quant_matmul with x2_dtype=float4_e2m1fn_x2 requires x2Scale to be 3D. The scale is reshaped from [out, in/32] to [out, in/64, 2] (pairing consecutive E8M0 values) before transposing to [in/64, out, 2].

.contiguous() is not called after transpose

Consistent with the W8A8 and W4A8 PRs:

  • Online (NPUSingleLevelMXFP4LinearMethod): transpose is applied via .data = in-place assignment without .contiguous(). npu_quant_matmul reads strides directly; calling .contiguous() would physically reorder the quantized data and break block-scale mapping.
  • Offline (ModelSlimMXFP4Scheme): same approach — .data = assignment preserves the non-contiguous view.

Performance Comparison Report

TODO: Performance numbers are not yet available. This section will be filled in once benchmark runs on Ascend hardware are complete.

Metric Baseline (BF16) Offline (ModelSlim W4A4_MXFP4) Online (--quantization mxfp4_w4a4_npu)
E2E Latency TBD TBD TBD
Memory (NPU) TBD TBD TBD

Related Issues

Closes part of #21584 (MXFP8/MXFP4 support on Ascend NPU for Qwen3 Dense LLM).

Depends on #22352 (W8A8 MXFP8) and #23650 (W4A8 MXFP4) PR.

…th B)

Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU,
supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8
checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant +
npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU

Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion
subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2
and other diffusion models on Ascend NPU. Also adds explicit
quantization field to diffusion ServerArgs so online quantization
can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call
- Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul)
- Simplify output shape restoration logic

Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
按 reviewer 建议重构架构分层:
- 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights)
- 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用
- 改进架构分层,符合现有 NPU INT8 方法模式
Fix weight loading for msmodelslim pre-quantized MXFP8 weights:
- Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors)
- Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export)
- Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
- Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode)
- Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format
- Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape
- Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv)
- Improve flatten-to-2D comment to explain NPU kernel requirement
…rate PR

Revert LLM-side MXFP8 changes to split into a separate PR.
This branch now only contains Wan2.2 Diffusion MXFP8 changes.

Reverted files:
- fp8.py: removed MXFP8LinearAscendMethod class and NPU branch
- mxfp8_method_npu.py: deleted (NPU MXFP8 linear method)
- test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test)

LLM MXFP8 code preserved on junlin_llm branch.
Conflict resolution: upstream refactored transformer loader class methods into standalone functions in transformer_load_utils.py. Preserved server_args.quantization priority (mxfp8/mxfp4/modelslim) in _resolve_quant_config.
…t_config

Upstream refactored class methods to standalone functions in transformer_load_utils.py but dropped the server_args.quantization priority path. Re-add it so mxfp8/mxfp4/modelslim still work.
1. Add NPUMXFP8LinearMethod to linear_method_npu.py (online quant)\n2. Add NPU dispatch branch in Fp8Config.get_quant_method\n3. Fix get_min_capability to return 0 on NPU\n4. Add ModelSlimMXFP8Scheme for offline pre-quantized MXFP8 weights\n5. Register W8A8_MXFP8 scheme in modelslim dispatcher
…nit__.py

Move ModelSlimLinearScheme import before ModelSlimMXFP8Scheme to resolve circular
dependency. modelslim_mxfp8.py imports ModelSlimLinearScheme from the schemes
package, which failed when __init__.py hadn't yet defined it.

The fix ensures base class is available in module namespace before any subclass
imports attempt to reference it.

Issue: ImportError: cannot import name 'ModelSlimLinearScheme' from partially
initialized module 'sglang.srt.layers.quantization.modelslim.schemes'
… LLM

Implements NPUMXFP4W4A8LinearMethod for Ascend NPU online quantization

of dense Qwen3/3.5 LLM models triggered via --quantization mxfp4_npu.

Weight flow (process_weights_after_loading):

BF16/FP16 → npu_dynamic_dual_level_mx_quant → FP4 (NZ format)

l0_scale [in/512, out] (FP32) + l1_scale (FP8_E8M0)
Inference flow (apply):

activation → npu_dynamic_dual_level_mx_quant → FP4

→ npu_dual_level_quant_matmul (W4A4 compute with dual-level scales)

Config registered as 'mxfp4_npu' in QUANTIZATION_METHODS.

Hardware: requires Ascend 950 (DualLevelQuantBatchMatmul not on A2/A3).
1. Add ModelSlimW4A8Int8 + NPUW4A8DynamicLinearMethod for W4A8_DYNAMIC dispatch
2. Add hardware warning + try/except guard in NPUMXFP4W4A8LinearMethod (Ascend 950 required)
3. Add MoE unquantized fallback warning in NPUMxfp4Config
Add `--quantization mxfp4w4a4_npu` support for Qwen3/3.5 dense models on
Ascend NPU. Uses single-level npu_dynamic_mx_quant(float4_e2m1fn) for both
weights and activations, analogous to the MXFP8 path but with FP4 dtype.

- linear_method_npu.py: add NPUSingleLevelMXFP4LinearMethod (W4A4 single-level)
- npu_mxfp4_w4a4.py: new NPUMxfp4W4A4Config registered as mxfp4w4a4_npu
- __init__.py: register mxfp4w4a4_npu in BASE_QUANTIZATION_METHODS
Issue 2: delete erroneous W4A8_MXFP → ModelSlimMXFP8Scheme dispatch in _get_scheme_from_parts; unsupported quant_type now raises NotImplementedError as intended.
Issue 4 (w4a4 side): add logger.warning before UnquantizedFusedMoEMethod fallback in NPUMxfp4W4A4Config.
ModelSlimMXFP4Scheme: loads msmodelslim pre-quantized FP4 weights (fp8 container), casts to float4_e2m1fn_x2, runs single-level MXFP4 matmul via npu_quant_matmul(group_sizes=[1,1,32]). Mirrors NPUSingleLevelMXFP4LinearMethod but offline.
1. Add W4A8_MXFP offline scheme (ModelSlimMXFP4W4A8Scheme)
2. Rename mxfp4_npu -> mxfp4_w4a8_npu (online W4A8 key)
3. Rename mxfp4w4a4_npu -> mxfp4_w4a4_npu (align naming style)
1. Use float4_e2m1fn_x2 (not float4_e2m1fn which doesn't exist) as dst_type
2. Add round_mode="round" to npu_dynamic_mx_quant calls
3. Pass x1_dtype/x2_dtype=float4_e2m1fn_x2 to npu_quant_matmul
4. Fix transpose: use .data= assignment without .contiguous() to preserve block-scale mapping
1. Replace float4_e2m1fn (non-existent) with float4_e2m1fn_x2 as dst_type
2. Add round_mode="round" to npu_dynamic_mx_quant
3. Pass x1_dtype/x2_dtype=float4_e2m1fn_x2 to npu_quant_matmul
npu_quant_matmul with float4_e2m1fn_x2 requires x2Scale (weight_scale) to
be 3D. Reshape [out, in/32] -> [out, in/64, 2] then transpose to [in/64, out, 2].
Matches vllm-ascend w4a4_mxfp4 process_weights_after_loading pattern.
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

diffusion SGLang Diffusion npu quant LLM Quantization

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant