🚧 [llm][npu][quant] Add W4A8 MXFP quantization support for Qwen3 Dense on Ascend NPU#23650
🚧 [llm][npu][quant] Add W4A8 MXFP quantization support for Qwen3 Dense on Ascend NPU#23650TallMessiWu wants to merge 28 commits into
Conversation
…th B) Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU, supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8 checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant + npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2 and other diffusion models on Ascend NPU. Also adds explicit quantization field to diffusion ServerArgs so online quantization can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call - Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul) - Simplify output shape restoration logic Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
按 reviewer 建议重构架构分层: - 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights) - 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用 - 改进架构分层,符合现有 NPU INT8 方法模式
Fix weight loading for msmodelslim pre-quantized MXFP8 weights: - Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors) - Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export) - Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
…weight processing.
- Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode) - Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format - Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape - Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv) - Improve flatten-to-2D comment to explain NPU kernel requirement
…rate PR Revert LLM-side MXFP8 changes to split into a separate PR. This branch now only contains Wan2.2 Diffusion MXFP8 changes. Reverted files: - fp8.py: removed MXFP8LinearAscendMethod class and NPU branch - mxfp8_method_npu.py: deleted (NPU MXFP8 linear method) - test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test) LLM MXFP8 code preserved on junlin_llm branch.
Conflict resolution: upstream refactored transformer loader class methods into standalone functions in transformer_load_utils.py. Preserved server_args.quantization priority (mxfp8/mxfp4/modelslim) in _resolve_quant_config.
…t_config Upstream refactored class methods to standalone functions in transformer_load_utils.py but dropped the server_args.quantization priority path. Re-add it so mxfp8/mxfp4/modelslim still work.
1. Add NPUMXFP8LinearMethod to linear_method_npu.py (online quant)\n2. Add NPU dispatch branch in Fp8Config.get_quant_method\n3. Fix get_min_capability to return 0 on NPU\n4. Add ModelSlimMXFP8Scheme for offline pre-quantized MXFP8 weights\n5. Register W8A8_MXFP8 scheme in modelslim dispatcher
…nit__.py Move ModelSlimLinearScheme import before ModelSlimMXFP8Scheme to resolve circular dependency. modelslim_mxfp8.py imports ModelSlimLinearScheme from the schemes package, which failed when __init__.py hadn't yet defined it. The fix ensures base class is available in module namespace before any subclass imports attempt to reference it. Issue: ImportError: cannot import name 'ModelSlimLinearScheme' from partially initialized module 'sglang.srt.layers.quantization.modelslim.schemes'
… LLM Implements NPUMXFP4W4A8LinearMethod for Ascend NPU online quantization of dense Qwen3/3.5 LLM models triggered via --quantization mxfp4_npu. Weight flow (process_weights_after_loading): BF16/FP16 → npu_dynamic_dual_level_mx_quant → FP4 (NZ format) l0_scale [in/512, out] (FP32) + l1_scale (FP8_E8M0) Inference flow (apply): activation → npu_dynamic_dual_level_mx_quant → FP4 → npu_dual_level_quant_matmul (W4A4 compute with dual-level scales) Config registered as 'mxfp4_npu' in QUANTIZATION_METHODS. Hardware: requires Ascend 950 (DualLevelQuantBatchMatmul not on A2/A3).
1. Add ModelSlimW4A8Int8 + NPUW4A8DynamicLinearMethod for W4A8_DYNAMIC dispatch 2. Add hardware warning + try/except guard in NPUMXFP4W4A8LinearMethod (Ascend 950 required) 3. Add MoE unquantized fallback warning in NPUMxfp4Config
There was a problem hiding this comment.
Code Review
This pull request introduces support for MXFP8 and MXFP4 quantization on Ascend NPU, covering both diffusion models and LLMs. It adds specialized quantization configurations, linear methods utilizing NPU-specific kernels, and ModelSlim schemes for pre-quantized weight inference. Additionally, the wan_repack.py tool is refactored to support Wan2.2, and a new --quantization CLI argument allows for explicit quantization overrides. Feedback focuses on optimizing inference performance by pre-transposing weights and scales during model loading and ensuring environment robustness with fallback mechanisms for NPU-specific dtypes.
| ) | ||
| from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme | ||
|
|
||
| MXFP8_BLOCK_SIZE = 32 |
There was a problem hiding this comment.
For robustness across different versions of torch_npu and torch, it is better to use a fallback mechanism for the float8_e8m0fnu dtype, similar to the implementation in the SRT backend.
MXFP8_BLOCK_SIZE = 32
_FLOAT8_E8M0FNU_DTYPE = getattr(
torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)
)| def process_weights_after_loading(self, layer: torch.nn.Module): | ||
| # weight is already float8_e4m3fn, no cast needed | ||
| weight = layer.weight.data | ||
| layer.weight = torch.nn.Parameter(weight, requires_grad=False) | ||
|
|
||
| # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2] | ||
| weight_scale = layer.weight_scale.data | ||
| weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2) | ||
| layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) |
There was a problem hiding this comment.
Pre-transposing the weight and scale tensors during model loading avoids the overhead of performing transposes on every forward pass. This optimization is already present in the SRT implementation of this scheme.
def process_weights_after_loading(self, layer: torch.nn.Module):
# weight is already float8_e4m3fn, no cast needed
weight = layer.weight.data
# Pre-transpose weight and scale to [in, out] for npu_quant_matmul.
# Use .data assignment without .contiguous() to preserve the transpose
# view strides — npu_quant_matmul reads strides correctly.
layer.weight = torch.nn.Parameter(weight.transpose(0, 1), requires_grad=False)
# Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2]
weight_scale = layer.weight_scale.data
weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2)
layer.weight_scale = torch.nn.Parameter(weight_scale.transpose(0, 1), requires_grad=False)| output = torch_npu.npu_quant_matmul( | ||
| qx, | ||
| layer.weight.transpose(0, 1), | ||
| layer.weight_scale.transpose(0, 1), | ||
| scale_dtype=torch_npu.float8_e8m0fnu, | ||
| pertoken_scale=input_scale, | ||
| pertoken_scale_dtype=torch_npu.float8_e8m0fnu, | ||
| bias=bias.to(torch.float32) if bias is not None else None, | ||
| output_dtype=original_dtype, | ||
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | ||
| ) |
There was a problem hiding this comment.
Use the pre-transposed weights and the robust dtype fallback in the matmul call.
| output = torch_npu.npu_quant_matmul( | |
| qx, | |
| layer.weight.transpose(0, 1), | |
| layer.weight_scale.transpose(0, 1), | |
| scale_dtype=torch_npu.float8_e8m0fnu, | |
| pertoken_scale=input_scale, | |
| pertoken_scale_dtype=torch_npu.float8_e8m0fnu, | |
| bias=bias.to(torch.float32) if bias is not None else None, | |
| output_dtype=original_dtype, | |
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | |
| ) | |
| # MXFP8 matmul (weight & scale already transposed at load time) | |
| output = torch_npu.npu_quant_matmul( | |
| qx, | |
| layer.weight, | |
| layer.weight_scale, | |
| scale_dtype=_FLOAT8_E8M0FNU_DTYPE, | |
| pertoken_scale=input_scale, | |
| pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, | |
| bias=bias.to(torch.float32) if bias is not None else None, | |
| output_dtype=original_dtype, | |
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | |
| ) |
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| MXFP8_BLOCK_SIZE = 32 |
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | ||
|
|
||
| weight_fp = layer.weight.data | ||
| if weight_fp.dtype not in (torch.float16, torch.bfloat16): | ||
| weight_fp = weight_fp.to(torch.bfloat16) | ||
|
|
||
| # Move weight to NPU if needed. We intentionally use a conditional | ||
| # move rather than an assert because `dit_cpu_offload` defaults to | ||
| # True in ServerArgs, which causes fsdp_load to move every parameter | ||
| # back to CPU after loading (even when the target device is NPU). | ||
| # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer | ||
| # here. The quantized fp8 weights produced below will remain on NPU | ||
| # for inference; if the model still needs to be offloaded after | ||
| # quantization (e.g. very large model on a small NPU), a higher-level | ||
| # offload pass can move them back afterwards. | ||
| if not weight_fp.is_npu: | ||
| weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") | ||
|
|
||
| # Online MXFP8 quantisation of weights (block_size=32) | ||
| qw, w_scale = torch_npu.npu_dynamic_mx_quant( | ||
| weight_fp, dst_type=torch_npu.float8_e4m3fn | ||
| ) | ||
| layer.weight = Parameter(qw, requires_grad=False) | ||
| layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) |
There was a problem hiding this comment.
Pre-transpose the weights and scales during model loading to improve inference performance. Since this is online quantization, using .contiguous() after transpose is safe and recommended for NPU matmul efficiency.
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | |
| weight_fp = layer.weight.data | |
| if weight_fp.dtype not in (torch.float16, torch.bfloat16): | |
| weight_fp = weight_fp.to(torch.bfloat16) | |
| # Move weight to NPU if needed. We intentionally use a conditional | |
| # move rather than an assert because `dit_cpu_offload` defaults to | |
| # True in ServerArgs, which causes fsdp_load to move every parameter | |
| # back to CPU after loading (even when the target device is NPU). | |
| # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer | |
| # here. The quantized fp8 weights produced below will remain on NPU | |
| # for inference; if the model still needs to be offloaded after | |
| # quantization (e.g. very large model on a small NPU), a higher-level | |
| # offload pass can move them back afterwards. | |
| if not weight_fp.is_npu: | |
| weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") | |
| # Online MXFP8 quantisation of weights (block_size=32) | |
| qw, w_scale = torch_npu.npu_dynamic_mx_quant( | |
| weight_fp, dst_type=torch_npu.float8_e4m3fn | |
| ) | |
| layer.weight = Parameter(qw, requires_grad=False) | |
| layer.weight_scale_inv = Parameter(w_scale, requires_grad=False) | |
| def process_weights_after_loading(self, layer: torch.nn.Module) -> None: | |
| weight_fp = layer.weight.data | |
| if weight_fp.dtype not in (torch.float16, torch.bfloat16): | |
| weight_fp = weight_fp.to(torch.bfloat16) | |
| # Move weight to NPU if needed. | |
| if not weight_fp.is_npu: | |
| weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}") | |
| # Online MXFP8 quantisation of weights (block_size=32) | |
| qw, w_scale = torch_npu.npu_dynamic_mx_quant( | |
| weight_fp, dst_type=torch_npu.float8_e4m3fn | |
| ) | |
| # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose) | |
| layer.weight = Parameter(qw.transpose(0, 1).contiguous(), requires_grad=False) | |
| layer.weight_scale_inv = Parameter(w_scale.transpose(0, 1).contiguous(), requires_grad=False) |
| output = torch_npu.npu_quant_matmul( | ||
| qx, | ||
| layer.weight.transpose(0, 1), | ||
| layer.weight_scale_inv.transpose(0, 1), | ||
| scale_dtype=torch_npu.float8_e8m0fnu, | ||
| pertoken_scale=input_scale, | ||
| pertoken_scale_dtype=torch_npu.float8_e8m0fnu, | ||
| bias=bias.to(torch.float32) if bias is not None else None, | ||
| output_dtype=original_dtype, | ||
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | ||
| ) |
There was a problem hiding this comment.
Update the matmul call to use the pre-transposed parameters and the robust dtype fallback.
| output = torch_npu.npu_quant_matmul( | |
| qx, | |
| layer.weight.transpose(0, 1), | |
| layer.weight_scale_inv.transpose(0, 1), | |
| scale_dtype=torch_npu.float8_e8m0fnu, | |
| pertoken_scale=input_scale, | |
| pertoken_scale_dtype=torch_npu.float8_e8m0fnu, | |
| bias=bias.to(torch.float32) if bias is not None else None, | |
| output_dtype=original_dtype, | |
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | |
| ) | |
| # MXFP8 matmul (weight & scale already transposed at load time) | |
| output = torch_npu.npu_quant_matmul( | |
| qx, | |
| layer.weight, | |
| layer.weight_scale_inv, | |
| scale_dtype=_FLOAT8_E8M0FNU_DTYPE, | |
| pertoken_scale=input_scale, | |
| pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE, | |
| bias=bias.to(torch.float32) if bias is not None else None, | |
| output_dtype=original_dtype, | |
| group_sizes=[1, 1, MXFP8_BLOCK_SIZE], | |
| ) |
Summary
This PR adds W4A8 MXFP quantization support for Qwen3 dense LLM models on Ascend NPU. It continues the NPU quantization work tracked in issue #21584.
Hardware requirement: Ascend 950 (Atlas A5)
Two modes are supported:
Online quantization (
--quantization mxfp4_w4a8_npu)NPUMxfp4Config(layers/quantization/npu_mxfp4.py) dispatches toNPUMXFP4W4A8LinearMethod.npu_dynamic_dual_level_mx_quant: producesfloat4_e2m1fn_x2weights, L0 per-block scale (FP32), and L1 per-channel scale (FP8_E8M0). The weight is then cast to FRACTAL_NZ format for NPU matmul efficiency.npu_dual_level_quant_matmul.Offline quantization (msmodelslim pre-quantized weights,
--quantization modelslim)ModelSlimMXFP4W4A8Scheme(modelslim/schemes/modelslim_mxfp4_w4a8.py) for theW4A8_MXFPscheme type.W4A8_MXFPstores weights asfloat8_e4m3fn(shape[out, in]) with auint8FP8_E8M0-biased scale (shape[out, in/32]). At load time, scale is reshaped to[out, in/64, 2]and both weight and scale are transposed (no.contiguous()— see Implementation Notes).npu_dynamic_mx_quantand the matmul runs vianpu_quant_matmulwithgroup_sizes=[1,1,32]— identical to the MXFP8 offline path.W4A8_DYNAMICdispatch →ModelSlimW4A8Int8(INT4 offline scheme for non-MXFP checkpoints).Key NPU APIs used
torch_npu.npu_dynamic_dual_level_mx_quant(x)torch_npu.npu_dual_level_quant_matmul(...)torch_npu.npu_format_cast(w.view(torch.int8), 29)torch_npu.npu_dynamic_mx_quant(x, dst_type=float8_e4m3fn)torch_npu.npu_quant_matmul(..., group_sizes=[1,1,32])Files Changed
New files
srt/layers/quantization/npu_mxfp4.pyNPUMxfp4Configfor online W4A8 MXFP4 (--quantization mxfp4_w4a8_npu)srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.pyModelSlimMXFP4W4A8SchemeforW4A8_MXFPmsmodelslim checkpointsModified — online W4A8 NPU dispatch
srt/hardware_backend/npu/quantization/linear_method_npu.pyNPUMXFP4W4A8LinearMethod(online dual-level MXFP4 weight quantisation + inference) andNPUW4A8DynamicLinearMethod(offline INT4 inference vianpu_weight_quant_batchmatmul)srt/layers/quantization/__init__.pyNPUMxfp4Configunder"mxfp4_w4a8_npu"srt/server_args.py"mxfp4_w4a8_npu"and"mxfp4_w4a4_npu"toQUANTIZATION_CHOICESModified — offline W4A8 registration & dispatch
srt/layers/quantization/modelslim/modelslim.pyW4A8_MXFPbranch →ModelSlimMXFP4W4A8SchemeandW4A8_DYNAMICbranch →ModelSlimW4A8Int8in_get_scheme_from_parts()srt/layers/quantization/modelslim/schemes/__init__.pyModelSlimMXFP4W4A8SchemeandModelSlimW4A8Int8Implementation Notes
W4A8_MXFP offline: weight format matches MXFP8
The current msmodelslim
W4A8_MXFPcheckpoint stores weights asfloat8_e4m3fn(identical layout toW8A8_MXFP8), not as packed FP4 uint8. This corresponds to an older msmodelslim export version. Consequently,ModelSlimMXFP4W4A8Schemeis structurally identical toModelSlimMXFP8Scheme— the distinction lies in the quantisation process, not the inference path. A future checkpoint format change (packed FP4) would require a separate scheme.Online W4A8: dual-level scale layout requirements
npu_dual_level_quant_matmulrequires:npu_format_cast(w.view(torch.int8), 29)since only int-dtype tensors are accepted.[in/512, out]: loaded as[out, in/512, 1], squeezed and transposed with.contiguous(). The.contiguous()here is safe — the scale is freshly allocated, unlike the offline non-contiguous transpose pattern..contiguous()asymmetry between online and offline pathsConsistent with the W8A8 PR (#22352):
NPUMXFP4W4A8LinearMethod): uses.contiguous()after transpose — safe because weights/scales are freshly allocated fromnpu_dynamic_dual_level_mx_quant.ModelSlimMXFP4W4A8Scheme): does not call.contiguous(), using.dataassignment to preserve the non-contiguous transpose view.npu_quant_matmulreads strides correctly;.contiguous()would physically reorder pre-quantized data and break block-scale mapping.Performance Comparison Report
Related Issues
Closes part of #21584 (MXFP8/MXFP4 support on Ascend NPU for Qwen3 Dense LLM).
Depends on #22352 (W8A8 MXFP8 for Qwen3 Dense on Ascend NPU).