✨ [llm][npu][quant] Add W8A8 MXFP8 quantization support for Qwen3 Dense on Ascend NPU#22352
Open
TallMessiWu wants to merge 24 commits into
Open
✨ [llm][npu][quant] Add W8A8 MXFP8 quantization support for Qwen3 Dense on Ascend NPU#22352TallMessiWu wants to merge 24 commits into
TallMessiWu wants to merge 24 commits into
Conversation
…th B) Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU, supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8 checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant + npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2 and other diffusion models on Ascend NPU. Also adds explicit quantization field to diffusion ServerArgs so online quantization can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call - Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul) - Simplify output shape restoration logic Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
按 reviewer 建议重构架构分层: - 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights) - 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用 - 改进架构分层,符合现有 NPU INT8 方法模式
Fix weight loading for msmodelslim pre-quantized MXFP8 weights: - Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors) - Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export) - Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
…weight processing.
- Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode) - Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format - Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape - Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv) - Improve flatten-to-2D comment to explain NPU kernel requirement
…rate PR Revert LLM-side MXFP8 changes to split into a separate PR. This branch now only contains Wan2.2 Diffusion MXFP8 changes. Reverted files: - fp8.py: removed MXFP8LinearAscendMethod class and NPU branch - mxfp8_method_npu.py: deleted (NPU MXFP8 linear method) - test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test) LLM MXFP8 code preserved on junlin_llm branch.
Conflict resolution: upstream refactored transformer loader class methods into standalone functions in transformer_load_utils.py. Preserved server_args.quantization priority (mxfp8/mxfp4/modelslim) in _resolve_quant_config.
…t_config Upstream refactored class methods to standalone functions in transformer_load_utils.py but dropped the server_args.quantization priority path. Re-add it so mxfp8/mxfp4/modelslim still work.
1. Add NPUMXFP8LinearMethod to linear_method_npu.py (online quant)\n2. Add NPU dispatch branch in Fp8Config.get_quant_method\n3. Fix get_min_capability to return 0 on NPU\n4. Add ModelSlimMXFP8Scheme for offline pre-quantized MXFP8 weights\n5. Register W8A8_MXFP8 scheme in modelslim dispatcher
…nit__.py Move ModelSlimLinearScheme import before ModelSlimMXFP8Scheme to resolve circular dependency. modelslim_mxfp8.py imports ModelSlimLinearScheme from the schemes package, which failed when __init__.py hadn't yet defined it. The fix ensures base class is available in module namespace before any subclass imports attempt to reference it. Issue: ImportError: cannot import name 'ModelSlimLinearScheme' from partially initialized module 'sglang.srt.layers.quantization.modelslim.schemes'
Contributor
There was a problem hiding this comment.
Code Review
This pull request adds MXFP8 quantization support for Diffusion and LLM models on Ascend NPU, implementing new schemes like ModelSlimMXFP8Scheme and NPUMXFP8LinearMethod, and introducing a --quantization CLI flag. It also refactors the wan_repack tool and improves modelslim configuration discovery. Feedback suggests optimizing the multimodal implementation by pre-transposing weights and scales during loading to reduce forward-pass overhead and using robust fallback definitions for the float8_e8m0fnu data type to ensure environment compatibility.
2 tasks
TallMessiWu
added a commit
to TallMessiWu/sglang
that referenced
this pull request
May 20, 2026
…#22352 conflicts Resolved 10 conflicts: - Diffusion side (7 files): used upstream/main (includes merged sgl-project#20922/sgl-project#22338) - LLM side fp8.py: kept both NPU and MUSA capability bypasses - LLM side modelslim.py: added W8A8_MXFP8 to upstream's table-driven scheme dispatch - LLM side transformers.py: used upstream/main (MoE refactoring with proper prefix)
e7c54c6 to
29c04bc
Compare
… prerequisites - Prerequisites sgl-project#20922 (Diffusion MXFP8) and sgl-project#22338 (Diffusion MXFP4) merged upstream — accept their canonical versions and remove our duplicate diffusion modifications from this PR. - Adapt offline MXFP8 dispatch to upstream's table-driven ModelSlimConfig.get_linear_scheme by registering W8A8_MXFP8 → ModelSlimMXFP8Scheme; add a no-op __init__ on the scheme so its signature matches the other entries. - Keep NPU bypass in Fp8Config.get_min_capability alongside the new upstream _is_musa branch; revert an unrelated MoE _use_aiter style change that drifted in via earlier merges.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR adds W8A8 MXFP8 (Microscaling FP8) quantization support for Qwen3 / Qwen3.5 dense LLM models on Ascend NPU. It closes part of the NPU quantization gap tracked in issue #21584.
Prerequisites (both merged into
main): #20922 (Diffusion MXFP8) | #22338 (Diffusion MXFP4). After the latest sync with upstreammain, this PR no longer touches anymultimodal_gen/...files — all diffusion changes are accepted frommain.Hardware requirement: Ascend A5 series or newer.
npu_dynamic_mx_quantis not available on A2 / A3.Two modes are supported:
Online quantization (
--quantization mxfp8)Fp8Configpath (triggered by--quantization mxfp8,use_mxfp8=True).Fp8Config.get_min_capability()(returns 0, skipping CUDA capability checks that are meaningless on NPU) — kept alongside the new upstream_is_musabranch.Fp8Config.get_quant_method()→NPUMXFP8LinearMethod(new class insrt/hardware_backend/npu/quantization/linear_method_npu.py).npu_dynamic_mx_quantand pre-transposed to[in, out]layout. At inference, activations are quantized per-token and the matmul is executed bynpu_quant_matmulwithgroup_sizes=[1, 1, 32](block_size = 32).Offline quantization (msmodelslim pre-quantized weights)
ModelSlimMXFP8Scheme(srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py) for loading weights pre-quantized by msmodelslim (float8_e4m3fnweights +uint8scale infloat8_e8m0fnuencoding).ModelSlimConfig.get_linear_scheme()as("W8A8_MXFP8", ModelSlimMXFP8Scheme)(table-driven dispatch introduced upstream); the scheme's__init__accepts (and ignores) the standardquant_config/prefixkwargs to match the dispatch signature.Bug fix
srt/layers/rotary_embedding/base.py: wrapfused_rope_qk_mqaimport intry/except, falling back toNoneif the kernel is absent. Without this, a missing kernel insgl_kernel_npucauses the whole module import to fail;ModelRegistrythen silently skips the model and falls back to HF Transformers without quantization awareness, producing garbled output with FP8 weights interpreted as BF16.Key NPU APIs used
torch_npu.npu_dynamic_mx_quant(x, dst_type=torch_npu.float8_e4m3fn)torch_npu.npu_quant_matmul(..., group_sizes=[1, 1, 32])torch_npu.float8_e4m3fn/torch_npu.float8_e8m0fnuFiles Changed
srt/hardware_backend/npu/quantization/linear_method_npu.pyNPUMXFP8LinearMethod(online MXFP8 weight quantization + inference)srt/layers/quantization/fp8.pyget_min_capability(); dispatch toNPUMXFP8LinearMethodon NPU +use_mxfp8=Truepathsrt/layers/quantization/modelslim/modelslim.py("W8A8_MXFP8", ModelSlimMXFP8Scheme)inget_linear_scheme()srt/layers/quantization/modelslim/schemes/__init__.pyModelSlimMXFP8Scheme;# isort: offblock keeps the base-class import first to avoid circular dependencysrt/layers/quantization/modelslim/schemes/modelslim_mxfp8.pyModelSlimMXFP8Scheme) for msmodelslim pre-quantized weightssrt/layers/rotary_embedding/base.pyfused_rope_qk_mqaimport + None fallbackFinal diff against
upstream/main: 6 files, +245 / −2.Implementation Notes
.contiguous()on transpose: online vs offlineThe two paths differ intentionally:
NPUMXFP8LinearMethod): calls.contiguous()after transpose. Safe because the quantized weight is freshly allocated bynpu_dynamic_mx_quant; there are no pre-existing block-scale mappings tied to the original memory layout.ModelSlimMXFP8Scheme): does not call.contiguous()after transpose, using.dataassignment to preserve the non-contiguous transpose view.npu_quant_matmulreads strides correctly; calling.contiguous()would physically reorder the pre-quantized weight data and break the block-scale mapping, producing garbled output.This pattern matches the
vllm-ascendreference implementation.Performance Comparison Report
--quantization mxfp8)Related Issues
Closes part of #21584 (MXFP8 / MXFP4 support on Ascend NPU for Qwen3 dense LLM).
Builds on the NPU quantization infrastructure introduced by the now-merged #20922 (Diffusion MXFP8) and #22338 (Diffusion MXFP4).
CI States
Latest PR Test (Base): ❌ Run #26176245090
Latest PR Test (Extra): ❌ Run #26176244908