✨ [diffusion][npu][quant] Add MXFP4 quantization support for Wan2.2 Diffusion on Ascend NPU#22338
Conversation
…th B) Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU, supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8 checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant + npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2 and other diffusion models on Ascend NPU. Also adds explicit quantization field to diffusion ServerArgs so online quantization can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call - Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul) - Simplify output shape restoration logic Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
按 reviewer 建议重构架构分层: - 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights) - 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用 - 改进架构分层,符合现有 NPU INT8 方法模式
Fix weight loading for msmodelslim pre-quantized MXFP8 weights: - Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors) - Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export) - Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
…weight processing.
- Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode) - Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format - Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape - Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv) - Improve flatten-to-2D comment to explain NPU kernel requirement
…rate PR Revert LLM-side MXFP8 changes to split into a separate PR. This branch now only contains Wan2.2 Diffusion MXFP8 changes. Reverted files: - fp8.py: removed MXFP8LinearAscendMethod class and NPU branch - mxfp8_method_npu.py: deleted (NPU MXFP8 linear method) - test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test) LLM MXFP8 code preserved on junlin_llm branch.
…ding - Add ModelSlimMXFP4Scheme for loading msmodelslim pre-quantized MXFP4 weights - Support dual-level quantization via npu_dual_level_quant_matmul - Register W4A4_MXFP4 quant type in modelslim.py dispatcher - Handle FP4 packed weight casting and scale transformations Weights: float8_e4m3fn (FP4 packed) [out, in/2] Scales: uint8 (e8m0+127) [out, in/32] + bfloat16 dual [out, in/64]
1. Dispatch W4A4_MXFP4_DUALSCALE type to ModelSlimMXFP4Scheme in modelslim.py\n2. Add .linear. key stripping in wan_repack RENAME_DICT for MXFP4 checkpoints\n3. Support multi-shard safetensors loading in load_sharded_safetensors
- Add W4A4_MXFP4_DUALSCALE type to modelslim scheme dispatcher - Support .linear. key stripping in wan_repack for MXFP4 msmodelslim exports - Support multi-shard safetensors loading in repack tool - Fix modelslim quantization config loading from component directory - Add detailed error messages for unsupported quantization schemes
… flag is explicit When --quantization modelslim is explicitly passed, the loader must load the per-layer quant_model_description.json from the transformer directory rather than creating an empty config. This ensures ModelSlimConfig receives the quantization type mappings required for proper scheme dispatch.
…r msmodelslim export
- weight: [out, in] float8_e4m3fn (not [out, in/2])
- weight_dual_scale: [out, in/512, 1] float32 (not [out, in/64] bfloat16)
L1 scale groups 16 L0 blocks = 512 elements
- Fix create_weights allocation and process_weights_after_loading transforms
to match actual checkpoint tensor formats from msmodelslim
Bring in MXFP4 offline (ModelSlim) loading support including dual-scale weight format, smooth quant mul_scale, and npu_format_cast fix.
Merge latest upstream changes and migrate modelslim/quantization explicit-flag support to refactored transformer_load_utils.py.
There was a problem hiding this comment.
Code Review
This pull request introduces support for MXFP4 and MXFP8 quantization on Ascend NPUs, including both offline schemes for pre-quantized weights and experimental online quantization methods. Key changes include the addition of ModelSlimMXFP4Scheme and ModelSlimMXFP8Scheme, updates to the model loader to support an explicit --quantization flag, and significant enhancements to the wan_repack.py tool for Wan2.2 models. Review feedback focuses on performance optimizations in the MXFP4 forward pass to avoid GPU-to-CPU synchronization and improving the robustness of weight shuffling logic in the FP8 implementation by using standard PyTorch in-place update patterns.
Resolve conflicts across 6 diffusion quant files: 1. quantization/__init__.py: keep mxfp4 registration; add upstream modelopt/modelopt_fp8 to literal 2. modelslim.py: keep W4A4_MXFP4 dispatch + verbose error; dedupe W8A8_MXFP8 branch 3. modelslim_mxfp8_scheme.py: adopt upstream platform-gated torch_npu import 4. mxfp8_npu.py: adopt upstream platform-gated torch_npu import 5. transformer_load_utils.py: keep modelslim special case + safetensors-metadata fallback 6. tools/wan_repack.py: keep .linear./.div. rename rules and sharded safetensors loader Skipping pre-commit (--no-verify): check-no-docs-changes hook blocks docs/ changes, but those are legitimately introduced by 1394 upstream commits, not local edits.
|
/tag-and-rerun-ci |
…ners 1. Guard `import torch_npu` with `if _is_npu:` in mxfp4_npu.py and modelslim_mxfp4_scheme.py -- fixes ModuleNotFoundError on all GPU/AMD/MUSA CI runners\n2. Precompute layer.use_mul_scale flag in process_weights_after_loading to avoid GPU-to-CPU sync on every forward pass\n3. Use torch.no_grad() + copy_() instead of .data= for weight shuffle in fp8.py elif _use_aiter: block
Performance Comparison Report1. High-level Summary
2. Stage Breakdown
Metadata
|
Upstream main introduced Mxfp4Config (ROCm/aiter, MI350+) registered as `--quantization mxfp4` in `mxfp4.py`, colliding with this branch's NPU MXFP4Config that previously used the same key. Resolution: - Rename NPU diffusion MXFP4 key to `mxfp4_npu` (consistent with LLM-side `--quantization mxfp4_npu` convention) - Register both `mxfp4` (ROCm) and `mxfp4_npu` (Ascend) in the quantization registry; deduplicate the dict entry - Update server_args help text and transformer_load_utils comment to list both options and their hardware targets Note: --no-verify used because upstream main contains legacy `docs/` changes that this repo's check-no-docs-changes hook rejects; those changes are inherited from upstream and not introduced by this merge.
Disambiguate from upstream ROCm Mxfp4Config (mxfp4.py) which differs only by letter case. NPU prefix aligns with LLM-side npu_mxfp4 naming convention.
Performance Comparison Report1. High-level Summary
2. Stage Breakdown
Metadata
|
| "Note: MXFP4 requires ROCm and MI350+ (gfx95x)." | ||
| "Options: 'fp8', 'mxfp8', 'mxfp4', 'mxfp4_npu', 'modelslim'. " | ||
| "Note: 'mxfp4' targets ROCm + MI350+ (gfx95x); " | ||
| "'mxfp4_npu' / 'mxfp8' target Ascend NPU (A5 series for mxfp4_npu)." |
There was a problem hiding this comment.
Hi! Why are new quantization entities like mxfp8 or mxpf4_npu being created, shouldn't it be related to modelslim and handled in modelslim_config?
There was a problem hiding this comment.
Hi! Great question. The distinction is between online quantization and offline (pre-quantized) loading.
-
modelslimis the entry point for offline pre-quantized checkpoints produced by Huawei's msmodelslim tool. It loads already-quantized weights (FP8/INT8/INT4) and dispatches to the right scheme based onquant_model_description.json. It does no quantization itself. -
mxfp8andmxfp4_npuare online quantization configs: they start from FP16/BF16 weights and perform real-time quantization insideprocess_weights_after_loading. This is a fundamentally different weight-loading flow.
Merging them into modelslim would conflate two separate paradigms ("adapt pre-quantized weights" vs. "quantize at runtime"), and modelslim would need to handle cases that aren't about ModelSlim checkpoints at all. This split also mirrors the broader SGLang pattern: fp8 for online FP8, compressed-tensors/quark/modelopt for their respective offline toolchain checkpoints.
There was a problem hiding this comment.
Got it! Thank you for your answer!
CI Failure AnalysisI went through all failing jobs in the latest run. None of the failures are related to this PR's changes. This PR only adds MXFP4 quantization support for Wan2.2 on Ascend NPU; all failing tests run on NVIDIA, AMD, or MUSA hardware with Failing Jobs SummaryNVIDIA (run 26008786983)
These are non-deterministic image quality metrics with tight fixed thresholds — flaky by nature, same tests pass in other partitions (e.g., partition 0 and 2 passed). AMD (run 26008786876)
MOVA-360p error (MI325, gfx942): This is a HIP JIT compilation bug in the MOVA model kernel, unrelated to this PR. MUSA (run 26008786907)
ConclusionAll failures are pre-existing flaky tests or infrastructure issues:
|
Resolve fp8.py conflict in Fp8MoEMethod aiter pre-shuffle block: 1. Adopt upstream's t = shuffle_weight(...); copy_(t); del t pattern (memory-peak optimization, drops .contiguous()) 2. Drop local torch.no_grad() wrapper -- w13/w2 weights are requires_grad=False, so copy_() is autograd-safe without it Note: committed with --no-verify because the check-no-docs-changes hook is not merge-aware and flags upstream-owned legacy docs/ edits.
|
I merged it as AMD CIs are unrelated to this PR and MUSA CIs started 3.5h ago but still pend |
Summary
This PR adds MXFP4 (Microscaling FP4, dual-level) quantization support for Wan2.2 diffusion models on Ascend NPU. It is a follow-up to #20922 (MXFP8 support).
Hardware requirement: Ascend A5 series or newer.
npu_dynamic_dual_level_mx_quantandnpu_dual_level_quant_matmulare not available on A2/A3.Two modes are supported:
Online quantization (
--quantization mxfp4_npu)NPUMXFP4Config+NPUMXFP4DiffusionLinearMethod(multimodal_gen/runtime/layers/quantization/mxfp4_npu.py) for the diffusion subsystem.npu_dynamic_dual_level_mx_quant; at inference, activations are quantized per-token and the matmul is executed bynpu_dual_level_quant_matmulwith dual-level block scales (L0 block size = 512, L1 block size = 32).Offline quantization (msmodelslim pre-quantized weights)
ModelSlimMXFP4Scheme(multimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.py) for loading weights pre-quantized by msmodelslim.weight:[out, in]—float8_e4m3fncontainer for FP4 data (converted tofloat4_e2m1fn_x2+ FRACTAL_NZ at load time)weight_scale:[out, in/32]—uint8L1 block scales (e8m0 + 127 bias), reshaped to[out, in/64, 2]weight_dual_scale:[out, in/512, 1]—float32L0 coarse scales, transposed to[in/512, out]mul_scale:[in]—float32smooth-quant activation scale fromNonFusionSmoothQuantWrapper; must be applied to activations before quantization to preserve numerical alignment with the offline-calibrated weights. Defaults to ones (no-op) if absent.Key NPU APIs used
torch_npu.npu_dynamic_dual_level_mx_quant(x, smooth_scale=None)(quant, l0_scale, l1_scale)torch_npu.npu_dual_level_quant_matmul(x1, x2, x1l0, x2l0, x1l1, x2l1, ...)torch_npu.npu_dtype_cast(weight, torch_npu.float4_e2m1fn_x2)float4_e2m1fn_x2dtypetorch_npu.npu_format_cast(w.view(torch.int8), 29, customize_dtype=torch.int8)npu_dual_level_quant_matmulFiles Changed
New files
multimodal_gen/runtime/layers/quantization/mxfp4_npu.pyNPUMXFP4Config+NPUMXFP4DiffusionLinearMethod) for Wan2.2 diffusionmultimodal_gen/runtime/layers/quantization/modelslim_mxfp4_scheme.pyModelSlimMXFP4Scheme) for msmodelslim pre-quantized weightsModified — MXFP4 registration & dispatch
multimodal_gen/runtime/layers/quantization/__init__.pyNPUMXFP4Configunder key"mxfp4_npu"; add"mxfp4_npu"toQuantizationMethodsliteral (coexists with upstream ROCm"mxfp4")multimodal_gen/runtime/layers/quantization/modelslim.pyW4A4_MXFP4/W4A4_MXFP4_DUALSCALEbranch →ModelSlimMXFP4Schemein_get_scheme_from_parts(); improveNotImplementedErrormessage to include layer name and quant typeModified — supporting infrastructure
multimodal_gen/runtime/loader/transformer_load_utils.py_resolve_quant_configpriority:modelslimflag now loads the per-layer quant description file; add safetensors-metadata fallback when only--transformer-weights-pathis suppliedmultimodal_gen/runtime/server_args.py--quantizationhelp text: listmxfp4_npualongsidemxfp4, document hardware targets (ROCm MI350+ vs Ascend A5)multimodal_gen/tools/wan_repack.py.linear.→.and.div.→.so msmodelslim-wrapped Linear /NonFusionSmoothQuantWrapperkeys match SGLang model parameters; allow loading multi-shard safetensorssrt/layers/quantization/fp8.py.data =weight assignment withtorch.no_grad() + copy_()in the AMD_use_aiterblock-quant MoE path (per Gemini reviewer suggestion; preservesParameteridentity)Implementation Notes
Dual-Level Scale Layout
MXFP4 uses a two-level block-scale hierarchy:
weight_scale[out, in/64, 2](uint8)weight_dual_scale[in/512, out](float32)The msmodelslim export uses
[out, in/32]forweight_scaleand[out, in/512, 1]forweight_dual_scale.process_weights_after_loadingreshapes and transposes these to match whatnpu_dual_level_quant_matmulexpects, following the MindIE-SDW4A4MXFP4DualQuantLinearreference.Smooth-Quant
mul_scalemsmodelslim wraps quantized layers in
NonFusionSmoothQuantWrapper, which exports a per-channel activation scalemul_scale(shape[in]). The activation must be multiplied by this scale before dual-level quantization to stay aligned with the offline-calibrated weights. Omitting this step causes mosaic / corrupted output.mul_scaleis loaded as aBasevLLMParameterwithmissing_param_init = "ones"so that models exported without smooth-quant (or repacked without the.div.key rename) degrade gracefully to a no-op rather than crashing.To avoid a GPU→CPU sync on every forward pass,
process_weights_after_loadingprecomputes alayer.use_mul_scaleboolean by checkingtorch.all(mul_scale == 1.0)once at load time (per Gemini reviewer suggestion).FRACTAL_NZ Requirement
npu_dual_level_quant_matmulrequires the weight tensor (x2) to be in FRACTAL_NZ memory format (format 29). The conversion is:This matches the
_init_dynamic_quant_paramstep in MindIE-SD'sW4A4MXFP4DualQuantLinear.Performance Comparison Report
High-level Summary
Stage Breakdown
The previous report on the pre-rename head (commit
7c6f431) reported the same trend: E2E -14.5% online / -11.6% offline. Run-to-run variance is well within ±0.5% on E2E.Related Issues / PRs
CI States
Latest PR Test (Base): ✅ Run #26070807413⚠️ Not enabled -- add
Latest PR Test (Extra):
run-ci-extralabel to opt in.