Skip to content

✨ [llm][npu][quant] Add W8A8 MXFP8 quantization support for Qwen3 Dense on Ascend NPU#22352

Open
TallMessiWu wants to merge 24 commits into
sgl-project:mainfrom
TallMessiWu:junlin_qwen3_dense
Open

✨ [llm][npu][quant] Add W8A8 MXFP8 quantization support for Qwen3 Dense on Ascend NPU#22352
TallMessiWu wants to merge 24 commits into
sgl-project:mainfrom
TallMessiWu:junlin_qwen3_dense

Conversation

@TallMessiWu
Copy link
Copy Markdown
Contributor

@TallMessiWu TallMessiWu commented Apr 8, 2026

Summary

This PR adds W8A8 MXFP8 (Microscaling FP8) quantization support for Qwen3 / Qwen3.5 dense LLM models on Ascend NPU. It closes part of the NPU quantization gap tracked in issue #21584.

Prerequisites (both merged into main): #20922 (Diffusion MXFP8) | #22338 (Diffusion MXFP4). After the latest sync with upstream main, this PR no longer touches any multimodal_gen/... files — all diffusion changes are accepted from main.

Hardware requirement: Ascend A5 series or newer. npu_dynamic_mx_quant is not available on A2 / A3.

Two modes are supported:

Online quantization (--quantization mxfp8)

  • Reuses the existing Fp8Config path (triggered by --quantization mxfp8, use_mxfp8=True).
  • Adds an NPU bypass in Fp8Config.get_min_capability() (returns 0, skipping CUDA capability checks that are meaningless on NPU) — kept alongside the new upstream _is_musa branch.
  • Adds an NPU dispatch in Fp8Config.get_quant_method()NPUMXFP8LinearMethod (new class in srt/hardware_backend/npu/quantization/linear_method_npu.py).
  • At load time, FP16 / BF16 weights are quantized online to MXFP8 via npu_dynamic_mx_quant and pre-transposed to [in, out] layout. At inference, activations are quantized per-token and the matmul is executed by npu_quant_matmul with group_sizes=[1, 1, 32] (block_size = 32).

Offline quantization (msmodelslim pre-quantized weights)

  • Adds ModelSlimMXFP8Scheme (srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py) for loading weights pre-quantized by msmodelslim (float8_e4m3fn weights + uint8 scale in float8_e8m0fnu encoding).
  • Registered in ModelSlimConfig.get_linear_scheme() as ("W8A8_MXFP8", ModelSlimMXFP8Scheme) (table-driven dispatch introduced upstream); the scheme's __init__ accepts (and ignores) the standard quant_config / prefix kwargs to match the dispatch signature.

Bug fix

  • srt/layers/rotary_embedding/base.py: wrap fused_rope_qk_mqa import in try/except, falling back to None if the kernel is absent. Without this, a missing kernel in sgl_kernel_npu causes the whole module import to fail; ModelRegistry then silently skips the model and falls back to HF Transformers without quantization awareness, producing garbled output with FP8 weights interpreted as BF16.

Key NPU APIs used

API Purpose
torch_npu.npu_dynamic_mx_quant(x, dst_type=torch_npu.float8_e4m3fn) Dynamic MXFP8 quantization of activations / weights
torch_npu.npu_quant_matmul(..., group_sizes=[1, 1, 32]) MXFP8 quantized matmul (block_size = 32)
torch_npu.float8_e4m3fn / torch_npu.float8_e8m0fnu FP8 weight dtype / scale factor dtype

Files Changed

File Change
srt/hardware_backend/npu/quantization/linear_method_npu.py Add NPUMXFP8LinearMethod (online MXFP8 weight quantization + inference)
srt/layers/quantization/fp8.py NPU bypass in get_min_capability(); dispatch to NPUMXFP8LinearMethod on NPU + use_mxfp8=True path
srt/layers/quantization/modelslim/modelslim.py Register ("W8A8_MXFP8", ModelSlimMXFP8Scheme) in get_linear_scheme()
srt/layers/quantization/modelslim/schemes/__init__.py Export ModelSlimMXFP8Scheme; # isort: off block keeps the base-class import first to avoid circular dependency
srt/layers/quantization/modelslim/schemes/modelslim_mxfp8.py New — offline MXFP8 (ModelSlimMXFP8Scheme) for msmodelslim pre-quantized weights
srt/layers/rotary_embedding/base.py Guard fused_rope_qk_mqa import + None fallback

Final diff against upstream/main: 6 files, +245 / −2.

Implementation Notes

.contiguous() on transpose: online vs offline

The two paths differ intentionally:

  • Online (NPUMXFP8LinearMethod): calls .contiguous() after transpose. Safe because the quantized weight is freshly allocated by npu_dynamic_mx_quant; there are no pre-existing block-scale mappings tied to the original memory layout.
  • Offline (ModelSlimMXFP8Scheme): does not call .contiguous() after transpose, using .data assignment to preserve the non-contiguous transpose view. npu_quant_matmul reads strides correctly; calling .contiguous() would physically reorder the pre-quantized weight data and break the block-scale mapping, producing garbled output.

This pattern matches the vllm-ascend reference implementation.

Performance Comparison Report

Performance numbers are not yet available. This section will be filled in once benchmark runs on Ascend hardware are complete.

Metric Baseline Offline (ModelSlim MXFP8) Online (--quantization mxfp8)
E2E Latency TBD TBD TBD

Related Issues

Closes part of #21584 (MXFP8 / MXFP4 support on Ascend NPU for Qwen3 dense LLM).

Builds on the NPU quantization infrastructure introduced by the now-merged #20922 (Diffusion MXFP8) and #22338 (Diffusion MXFP4).


CI States

Latest PR Test (Base): ❌ Run #26176245090
Latest PR Test (Extra): ❌ Run #26176244908

…th B)

Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU,
supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8
checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant +
npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU

Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion
subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2
and other diffusion models on Ascend NPU. Also adds explicit
quantization field to diffusion ServerArgs so online quantization
can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call
- Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul)
- Simplify output shape restoration logic

Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
按 reviewer 建议重构架构分层:
- 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights)
- 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用
- 改进架构分层,符合现有 NPU INT8 方法模式
Fix weight loading for msmodelslim pre-quantized MXFP8 weights:
- Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors)
- Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export)
- Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
- Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode)
- Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format
- Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape
- Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv)
- Improve flatten-to-2D comment to explain NPU kernel requirement
…rate PR

Revert LLM-side MXFP8 changes to split into a separate PR.
This branch now only contains Wan2.2 Diffusion MXFP8 changes.

Reverted files:
- fp8.py: removed MXFP8LinearAscendMethod class and NPU branch
- mxfp8_method_npu.py: deleted (NPU MXFP8 linear method)
- test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test)

LLM MXFP8 code preserved on junlin_llm branch.
Conflict resolution: upstream refactored transformer loader class methods into standalone functions in transformer_load_utils.py. Preserved server_args.quantization priority (mxfp8/mxfp4/modelslim) in _resolve_quant_config.
…t_config

Upstream refactored class methods to standalone functions in transformer_load_utils.py but dropped the server_args.quantization priority path. Re-add it so mxfp8/mxfp4/modelslim still work.
1. Add NPUMXFP8LinearMethod to linear_method_npu.py (online quant)\n2. Add NPU dispatch branch in Fp8Config.get_quant_method\n3. Fix get_min_capability to return 0 on NPU\n4. Add ModelSlimMXFP8Scheme for offline pre-quantized MXFP8 weights\n5. Register W8A8_MXFP8 scheme in modelslim dispatcher
…nit__.py

Move ModelSlimLinearScheme import before ModelSlimMXFP8Scheme to resolve circular
dependency. modelslim_mxfp8.py imports ModelSlimLinearScheme from the schemes
package, which failed when __init__.py hadn't yet defined it.

The fix ensures base class is available in module namespace before any subclass
imports attempt to reference it.

Issue: ImportError: cannot import name 'ModelSlimLinearScheme' from partially
initialized module 'sglang.srt.layers.quantization.modelslim.schemes'
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds MXFP8 quantization support for Diffusion and LLM models on Ascend NPU, implementing new schemes like ModelSlimMXFP8Scheme and NPUMXFP8LinearMethod, and introducing a --quantization CLI flag. It also refactors the wan_repack tool and improves modelslim configuration discovery. Feedback suggests optimizing the multimodal implementation by pre-transposing weights and scales during loading to reduce forward-pass overhead and using robust fallback definitions for the float8_e8m0fnu data type to ensure environment compatibility.

@ping1jing2 ping1jing2 self-assigned this Apr 9, 2026
@github-actions github-actions Bot added documentation Improvements or additions to documentation lora Multi-modal multi-modal language model deepseek speculative-decoding hicache Hierarchical Caching for SGLang labels May 20, 2026
TallMessiWu added a commit to TallMessiWu/sglang that referenced this pull request May 20, 2026
…#22352 conflicts

Resolved 10 conflicts:
- Diffusion side (7 files): used upstream/main (includes merged sgl-project#20922/sgl-project#22338)
- LLM side fp8.py: kept both NPU and MUSA capability bypasses
- LLM side modelslim.py: added W8A8_MXFP8 to upstream's table-driven scheme dispatch
- LLM side transformers.py: used upstream/main (MoE refactoring with proper prefix)
@TallMessiWu TallMessiWu changed the title 🚧 [llm][npu][quant] Add W8A8 MXFP8 quantization support for Qwen3 Dense on Ascend NPU ✨ [llm][npu][quant] Add W8A8 MXFP8 quantization support for Qwen3 Dense on Ascend NPU May 20, 2026
@TallMessiWu TallMessiWu force-pushed the junlin_qwen3_dense branch from e7c54c6 to 29c04bc Compare May 20, 2026 14:58
… prerequisites

- Prerequisites sgl-project#20922 (Diffusion MXFP8) and sgl-project#22338 (Diffusion MXFP4) merged
  upstream — accept their canonical versions and remove our duplicate
  diffusion modifications from this PR.
- Adapt offline MXFP8 dispatch to upstream's table-driven
  ModelSlimConfig.get_linear_scheme by registering W8A8_MXFP8 →
  ModelSlimMXFP8Scheme; add a no-op __init__ on the scheme so its
  signature matches the other entries.
- Keep NPU bypass in Fp8Config.get_min_capability alongside the new
  upstream _is_musa branch; revert an unrelated MoE _use_aiter style
  change that drifted in via earlier merges.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

deepseek diffusion SGLang Diffusion documentation Improvements or additions to documentation hicache Hierarchical Caching for SGLang lora Multi-modal multi-modal language model npu quant LLM Quantization speculative-decoding

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants