
🚧 [llm][npu][quant] Add W4A8 MXFP quantization support for Qwen3 Dense on Ascend NPU#23650

Open
TallMessiWu wants to merge 28 commits into sgl-project:main from TallMessiWu:junlin_qwen3_dense_w4a8

@TallMessiWu

Summary

Dependency: This PR depends on #22352 (W8A8 MXFP8 for Qwen3 Dense on Ascend NPU) and should be merged after that PR lands, as it builds on the same NPU quantization infrastructure (_NPULinearMethodBase, ModelSlimConfig dispatch, etc.).

This PR adds W4A8 MXFP quantization support for Qwen3 dense LLM models on Ascend NPU. It continues the NPU quantization work tracked in issue #21584.

Hardware requirement: Ascend 950 (Atlas A5)

Two modes are supported:

Online quantization (--quantization mxfp4_w4a8_npu)

  • New NPUMxfp4Config (layers/quantization/npu_mxfp4.py) dispatches to NPUMXFP4W4A8LinearMethod.
  • At load time, FP16/BF16 weights are quantised online to dual-level MXFP4 via npu_dynamic_dual_level_mx_quant, which produces float4_e2m1fn_x2 weights, an L0 per-block scale (FP32), and an L1 per-channel scale (FP8_E8M0). The weight is then cast to FRACTAL_NZ format for NPU matmul efficiency.
  • At inference, activations are quantised with the same dual-level API and the matmul is executed by npu_dual_level_quant_matmul.
  • Note: the actual matmul compute is W4A4 (both operands in FP4 with dual-level scales); "A8" refers to the FP8_E8M0 L1 scale format. There is no W4A8 mixed-precision public kernel in the current torch_npu API.
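
To make the "FP4 values plus a power-of-two scale" idea concrete, here is a minimal pure-Python sketch of single-level MXFP4 block quantisation. It is an illustration under assumptions, not the behaviour of npu_dynamic_dual_level_mx_quant: the function names are hypothetical, the shared exponent follows the common MX convention (floor(log2(amax)) minus the largest E2M1 exponent), rounding is nearest-value with ties going to the smaller magnitude, and the L0 FP32 scale of the dual-level scheme is not modelled at all.

```python
import math

# Magnitudes representable by FP4 E2M1 (the sign is handled separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
E2M1_EMAX = 2  # largest E2M1 exponent: 6.0 == 1.5 * 2**2


def quantize_block_mxfp4(block):
    """Quantise one block to (shared_exponent, fp4_values).

    The shared scale is 2**shared_exponent, i.e. a pure power of two,
    which is the kind of value an E8M0 scale can encode.
    """
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    shared_exp = math.floor(math.log2(amax)) - E2M1_EMAX
    scale = 2.0 ** shared_exp
    quantized = []
    for v in block:
        # Saturate at the largest representable magnitude (6.0).
        mag = min(abs(v) / scale, E2M1_GRID[-1])
        # Nearest grid point; min() keeps the smaller magnitude on ties.
        nearest = min(E2M1_GRID, key=lambda g: abs(g - mag))
        quantized.append(math.copysign(nearest, v))
    return shared_exp, quantized


def dequantize_block(shared_exp, quantized):
    """Map FP4 grid values back to real values using the shared scale."""
    return [q * 2.0 ** shared_exp for q in quantized]
```

In this sketch a real block would hold 32 elements; the length is left free so small inputs are easy to inspect by hand.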

Offline quantization (msmodelslim pre-quantized weights, --quantization modelslim)

  • Adds ModelSlimMXFP4W4A8Scheme (modelslim/schemes/modelslim_mxfp4_w4a8.py) for the W4A8_MXFP scheme type.
  • The current msmodelslim checkpoint format for W4A8_MXFP stores weights as float8_e4m3fn (shape [out, in]) with a uint8 FP8_E8M0-biased scale (shape [out, in/32]). At load time, scale is reshaped to [out, in/64, 2] and both weight and scale are transposed (no .contiguous() — see Implementation Notes).
  • At inference, activations are dynamically quantised to FP8 via npu_dynamic_mx_quant and the matmul runs via npu_quant_matmul with group_sizes=[1,1,32] — identical to the MXFP8 offline path.
  • Also adds W4A8_DYNAMIC dispatch → ModelSlimW4A8Int8 (INT4 offline scheme for non-MXFP checkpoints).
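
The load-time reshape and transpose described above can be followed with plain shape arithmetic. The helper below is a hypothetical sketch (offline_mx_shapes is not a name from the PR); it assumes only what the description states: a [out, in] weight, a [out, in/32] scale reshaped to [out, in/64, 2], and transpose(0, 1) applied to both.

```python
MX_BLOCK_SIZE = 32  # matches group_sizes=[1, 1, 32]


def offline_mx_shapes(out_features, in_features):
    """Return (weight_shape, scale_shape) after the load-time reshape and transpose."""
    # reshape(out, -1, 2) needs an even number of 32-element blocks per row.
    if in_features % (2 * MX_BLOCK_SIZE) != 0:
        raise ValueError("in_features must be a multiple of 64")
    checkpoint_scale = (out_features, in_features // MX_BLOCK_SIZE)
    paired_scale = (out_features, checkpoint_scale[1] // 2, 2)
    # transpose(0, 1) swaps only the first two dimensions.
    weight_t = (in_features, out_features)
    scale_t = (paired_scale[1], paired_scale[0], paired_scale[2])
    return weight_t, scale_t
```

For example, out=4096 and in=1024 give a [1024, 4096] weight and a [16, 4096, 2] scale.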

Key NPU APIs used

| API | Purpose |
| --- | --- |
| torch_npu.npu_dynamic_dual_level_mx_quant(x) | Dual-level MXFP4 quantisation of activations/weights (FP4 + L0 FP32 scale + L1 FP8_E8M0 scale) |
| torch_npu.npu_dual_level_quant_matmul(...) | MXFP4 dual-level quantised matmul (online mode, Ascend 950 only) |
| torch_npu.npu_format_cast(w.view(torch.int8), 29) | Convert FP4 weight to FRACTAL_NZ layout (required by dual-level matmul) |
| torch_npu.npu_dynamic_mx_quant(x, dst_type=float8_e4m3fn) | Dynamic MXFP8 activation quantisation (offline mode) |
| torch_npu.npu_quant_matmul(..., group_sizes=[1,1,32]) | MXFP8 quantised matmul (offline mode, block_size=32) |
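
Several of these APIs exchange scales in FP8_E8M0, which the offline checkpoint stores as biased uint8 values. As a rough reference, the sketch below decodes and encodes such a byte under the usual E8M0 convention (8 exponent bits, no mantissa, bias 127, 0xFF reserved for NaN). The helper names are hypothetical and this is not taken from torch_npu.

```python
import math

E8M0_BIAS = 127
E8M0_NAN = 255


def decode_e8m0(byte):
    """Interpret a biased uint8 exponent as the scale 2**(byte - 127)."""
    if not 0 <= byte <= 255:
        raise ValueError("E8M0 values are single bytes")
    if byte == E8M0_NAN:
        return math.nan
    return 2.0 ** (byte - E8M0_BIAS)


def encode_e8m0(exponent):
    """Encode the scale 2**exponent, clamping to the representable range."""
    return max(0, min(E8M0_NAN - 1, exponent + E8M0_BIAS))
```

Because the format has no mantissa, every scale is an exact power of two, so applying it amounts to an exponent shift.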

Files Changed

New files

| File | Change |
| --- | --- |
| srt/layers/quantization/npu_mxfp4.py | New: NPUMxfp4Config for online W4A8 MXFP4 (--quantization mxfp4_w4a8_npu) |
| srt/layers/quantization/modelslim/schemes/modelslim_mxfp4_w4a8.py | New: offline ModelSlimMXFP4W4A8Scheme for W4A8_MXFP msmodelslim checkpoints |

Modified — online W4A8 NPU dispatch

| File | Change |
| --- | --- |
| srt/hardware_backend/npu/quantization/linear_method_npu.py | Add NPUMXFP4W4A8LinearMethod (online dual-level MXFP4 weight quantisation + inference) and NPUW4A8DynamicLinearMethod (offline INT4 inference via npu_weight_quant_batchmatmul) |
| srt/layers/quantization/__init__.py | Register NPUMxfp4Config under "mxfp4_w4a8_npu" |
| srt/server_args.py | Add "mxfp4_w4a8_npu" and "mxfp4_w4a4_npu" to QUANTIZATION_CHOICES |

Modified — offline W4A8 registration & dispatch

| File | Change |
| --- | --- |
| srt/layers/quantization/modelslim/modelslim.py | Add W4A8_MXFP branch → ModelSlimMXFP4W4A8Scheme and W4A8_DYNAMIC branch → ModelSlimW4A8Int8 in _get_scheme_from_parts() |
| srt/layers/quantization/modelslim/schemes/__init__.py | Register ModelSlimMXFP4W4A8Scheme and ModelSlimW4A8Int8 |

Implementation Notes

W4A8_MXFP offline: weight format matches MXFP8

The current msmodelslim W4A8_MXFP checkpoint stores weights as float8_e4m3fn (identical layout to W8A8_MXFP8), not as packed FP4 uint8. This corresponds to an older msmodelslim export version. Consequently, ModelSlimMXFP4W4A8Scheme is structurally identical to ModelSlimMXFP8Scheme — the distinction lies in the quantisation process, not the inference path. A future checkpoint format change (packed FP4) would require a separate scheme.

Online W4A8: dual-level scale layout requirements

npu_dual_level_quant_matmul requires:

  • Weight in FRACTAL_NZ format (format=29), cast via npu_format_cast(w.view(torch.int8), 29) since only int-dtype tensors are accepted.
  • L0 weight scale shape [in/512, out]: loaded as [out, in/512, 1], squeezed and transposed with .contiguous(). The .contiguous() here is safe — the scale is freshly allocated, unlike the offline non-contiguous transpose pattern.

.contiguous() asymmetry between online and offline paths

Consistent with the W8A8 PR (#22352):

  • Online (NPUMXFP4W4A8LinearMethod): uses .contiguous() after transpose — safe because weights/scales are freshly allocated from npu_dynamic_dual_level_mx_quant.
  • Offline (ModelSlimMXFP4W4A8Scheme): does not call .contiguous(), using .data assignment to preserve the non-contiguous transpose view. npu_quant_matmul reads strides correctly; .contiguous() would physically reorder pre-quantized data and break block-scale mapping.
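
The difference between a transposed view and a contiguous copy comes down to strides. The pure-Python sketch below (hypothetical helper names, no torch dependency) shows that transpose(0, 1) keeps the data in place and only swaps strides, whereas a contiguous tensor of the same shape would have different strides and therefore a different physical order. It illustrates the memory layouts only and makes no claim about how npu_quant_matmul consumes them.

```python
def row_major_strides(shape):
    """Element strides of a contiguous row-major tensor with the given shape."""
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def transpose_view(shape, strides):
    """Swap the first two dims without moving data, like tensor.transpose(0, 1)."""
    new_shape = (shape[1], shape[0]) + tuple(shape[2:])
    new_strides = (strides[1], strides[0]) + tuple(strides[2:])
    return new_shape, new_strides


def offset(index, strides):
    """Storage offset (in elements) of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))
```

For a [4, 6] weight the contiguous strides are (6, 1). The transposed view has shape [6, 4] with strides (1, 6), while a contiguous [6, 4] tensor would have strides (4, 1), which is the physical reordering that the offline path avoids.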

Performance Comparison Report

TODO: Performance numbers are not yet available. This section will be filled in once benchmark runs on Ascend hardware are complete.

| Metric | Baseline (BF16) | Offline (ModelSlim W4A8_MXFP) | Online (--quantization mxfp4_w4a8_npu) |
| --- | --- | --- | --- |
| E2E Latency | TBD | TBD | TBD |
| Memory (NPU) | TBD | TBD | TBD |

Related Issues

Closes part of #21584 (MXFP8/MXFP4 support on Ascend NPU for Qwen3 Dense LLM).

Depends on #22352 (W8A8 MXFP8 for Qwen3 Dense on Ascend NPU).

…th B)

Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU,
supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8
checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant +
npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU

Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion
subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2
and other diffusion models on Ascend NPU. Also adds explicit
quantization field to diffusion ServerArgs so online quantization
can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call
- Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul)
- Simplify output shape restoration logic

Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
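
The dimension error above follows from shape bookkeeping. The sketch below assumes the activation scale has the same [..., in/64, 2] layout as the reshaped weight scale; that layout is an inference from this PR's description, not a documented torch_npu guarantee, and the helper names are hypothetical. Under that assumption a 3-D input yields a 4-D scale, while flattening to 2-D first yields the 3-D scale the matmul expects.

```python
import math


def flatten_to_2d(shape):
    """[..., hidden] -> (tokens, hidden), as x.reshape(-1, x.shape[-1]) would give."""
    return (math.prod(shape[:-1]), shape[-1])


def mx_scale_shape(x_shape, block_size=32):
    """Assumed scale layout: the last dim is split into pairs of blocks."""
    return tuple(x_shape[:-1]) + (x_shape[-1] // (2 * block_size), 2)


def restore_output_shape(orig_shape, out_features):
    """Put the leading dims back after the 2-D matmul."""
    return tuple(orig_shape[:-1]) + (out_features,)
```

With x of shape [2, 8, 1024], quantising directly gives a scale of shape [2, 8, 16, 2] (4-D), whereas flattening to [16, 1024] first gives [16, 16, 2] (3-D).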
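Refactor the architecture layering as suggested by the reviewer:
- Add MXFP8LinearAscendMethod in fp8.py, responsible for weight definition (__init__, create_weights)
- Simplify NPUMXFP8LinearMethod in mxfp8_method_npu.py, keeping only weight processing and kernel calls
- Improve the layering to match the existing NPU INT8 method pattern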
Fix weight loading for msmodelslim pre-quantized MXFP8 weights:
- Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors)
- Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export)
- Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
- Remove unused __init__ (no quant_config/prefix needed, MXFP8 has only one mode)
- Fix weight dtype: float8_e4m3fn (not int8) to match msmodelslim checkpoint format
- Fix weight_scale shape: [out, in/32] (not in/32*2) to match actual tensor shape
- Add comment explaining weight_scale name must match checkpoint key (not weight_scale_inv)
- Improve flatten-to-2D comment to explain NPU kernel requirement
…rate PR

Revert LLM-side MXFP8 changes to split into a separate PR.
This branch now only contains Wan2.2 Diffusion MXFP8 changes.

Reverted files:
- fp8.py: removed MXFP8LinearAscendMethod class and NPU branch
- mxfp8_method_npu.py: deleted (NPU MXFP8 linear method)
- test_ascend_mxfp8_quantization.py: deleted (LLM MXFP8 test)

LLM MXFP8 code preserved on junlin_llm branch.
Conflict resolution: upstream refactored transformer loader class methods into standalone functions in transformer_load_utils.py. Preserved server_args.quantization priority (mxfp8/mxfp4/modelslim) in _resolve_quant_config.
…t_config

Upstream refactored class methods to standalone functions in transformer_load_utils.py but dropped the server_args.quantization priority path. Re-add it so mxfp8/mxfp4/modelslim still work.
1. Add NPUMXFP8LinearMethod to linear_method_npu.py (online quant)
2. Add NPU dispatch branch in Fp8Config.get_quant_method
3. Fix get_min_capability to return 0 on NPU
4. Add ModelSlimMXFP8Scheme for offline pre-quantized MXFP8 weights
5. Register W8A8_MXFP8 scheme in modelslim dispatcher
…nit__.py

Move ModelSlimLinearScheme import before ModelSlimMXFP8Scheme to resolve circular
dependency. modelslim_mxfp8.py imports ModelSlimLinearScheme from the schemes
package, which failed when __init__.py hadn't yet defined it.

The fix ensures base class is available in module namespace before any subclass
imports attempt to reference it.

Issue: ImportError: cannot import name 'ModelSlimLinearScheme' from partially
initialized module 'sglang.srt.layers.quantization.modelslim.schemes'
… LLM

Implements NPUMXFP4W4A8LinearMethod for Ascend NPU online quantization
of dense Qwen3/3.5 LLM models triggered via --quantization mxfp4_npu.

Weight flow (process_weights_after_loading):
  BF16/FP16 → npu_dynamic_dual_level_mx_quant → FP4 (NZ format)
  l0_scale [in/512, out] (FP32) + l1_scale (FP8_E8M0)

Inference flow (apply):
  activation → npu_dynamic_dual_level_mx_quant → FP4
  → npu_dual_level_quant_matmul (W4A4 compute with dual-level scales)

Config registered as 'mxfp4_npu' in QUANTIZATION_METHODS.

Hardware: requires Ascend 950 (DualLevelQuantBatchMatmul not on A2/A3).
1. Add ModelSlimW4A8Int8 + NPUW4A8DynamicLinearMethod for W4A8_DYNAMIC dispatch
2. Add hardware warning + try/except guard in NPUMXFP4W4A8LinearMethod (Ascend 950 required)
3. Add MoE unquantized fallback warning in NPUMxfp4Config
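
One way such a try/except guard could look is sketched below. This is a hypothetical illustration rather than the PR's implementation: call_with_hardware_hint is an invented helper, and it assumes an unsupported kernel surfaces as a RuntimeError. It wraps the call and re-raises with a message that names the hardware requirement.

```python
def call_with_hardware_hint(kernel, *args, requirement="Ascend 950", **kwargs):
    """Run `kernel`; on RuntimeError, re-raise with a hint about required hardware."""
    try:
        return kernel(*args, **kwargs)
    except RuntimeError as exc:
        name = getattr(kernel, "__name__", "kernel")
        raise RuntimeError(
            f"{name} failed; this path requires {requirement}. "
            f"Original error: {exc}"
        ) from exc
```

Chaining with `from exc` keeps the original traceback available, so the hint adds context without hiding the underlying failure.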

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for MXFP8 and MXFP4 quantization on Ascend NPU, covering both diffusion models and LLMs. It adds specialized quantization configurations, linear methods utilizing NPU-specific kernels, and ModelSlim schemes for pre-quantized weight inference. Additionally, the wan_repack.py tool is refactored to support Wan2.2, and a new --quantization CLI argument allows for explicit quantization overrides. Feedback focuses on optimizing inference performance by pre-transposing weights and scales during model loading and ensuring environment robustness with fallback mechanisms for NPU-specific dtypes.

)
from sglang.srt.layers.quantization.modelslim.schemes import ModelSlimLinearScheme

MXFP8_BLOCK_SIZE = 32

medium

For robustness across different versions of torch_npu and torch, it is better to use a fallback mechanism for the float8_e8m0fnu dtype, similar to the implementation in the SRT backend.

MXFP8_BLOCK_SIZE = 32
_FLOAT8_E8M0FNU_DTYPE = getattr(
    torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)
)
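
The nested getattr above resolves to None when neither module exposes the dtype, and that None would then be passed on to the matmul call. The sketch below uses stand-in namespace objects instead of the real torch and torch_npu modules, and hypothetical helper names, to show the same lookup with an optional fail-fast variant.

```python
from types import SimpleNamespace


def resolve_attr(name, *modules):
    """Return the first non-None `name` attribute found on the modules, else None."""
    for module in modules:
        value = getattr(module, name, None)
        if value is not None:
            return value
    return None


def require_attr(name, *modules):
    """Like resolve_attr, but raise instead of silently returning None."""
    value = resolve_attr(name, *modules)
    if value is None:
        raise ImportError(f"{name} is not available in the installed packages")
    return value


# Stand-ins for illustration only; real code would pass torch_npu and torch.
fake_torch_npu = SimpleNamespace()
fake_torch = SimpleNamespace(float8_e8m0fnu="torch.float8_e8m0fnu")
```

Failing at import time makes a missing dtype visible immediately instead of at the first forward pass.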

Comment on lines +65 to +73
    def process_weights_after_loading(self, layer: torch.nn.Module):
        # weight is already float8_e4m3fn, no cast needed
        weight = layer.weight.data
        layer.weight = torch.nn.Parameter(weight, requires_grad=False)

        # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2]
        weight_scale = layer.weight_scale.data
        weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2)
        layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)

medium

Pre-transposing the weight and scale tensors during model loading avoids the overhead of performing transposes on every forward pass. This optimization is already present in the SRT implementation of this scheme.

    def process_weights_after_loading(self, layer: torch.nn.Module):
        # weight is already float8_e4m3fn, no cast needed
        weight = layer.weight.data
        # Pre-transpose weight and scale to [in, out] for npu_quant_matmul.
        # Use .data assignment without .contiguous() to preserve the transpose
        # view strides — npu_quant_matmul reads strides correctly.
        layer.weight = torch.nn.Parameter(weight.transpose(0, 1), requires_grad=False)

        # Reshape weight_scale: [out, in/32] -> [out, in/32//2, 2]
        weight_scale = layer.weight_scale.data
        weight_scale = weight_scale.reshape(weight_scale.shape[0], -1, 2)
        layer.weight_scale = torch.nn.Parameter(weight_scale.transpose(0, 1), requires_grad=False)

Comment on lines +102 to +112
    output = torch_npu.npu_quant_matmul(
        qx,
        layer.weight.transpose(0, 1),
        layer.weight_scale.transpose(0, 1),
        scale_dtype=torch_npu.float8_e8m0fnu,
        pertoken_scale=input_scale,
        pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
        bias=bias.to(torch.float32) if bias is not None else None,
        output_dtype=original_dtype,
        group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
    )

medium

Use the pre-transposed weights and the robust dtype fallback in the matmul call.

Suggested change

    -output = torch_npu.npu_quant_matmul(
    -    qx,
    -    layer.weight.transpose(0, 1),
    -    layer.weight_scale.transpose(0, 1),
    -    scale_dtype=torch_npu.float8_e8m0fnu,
    -    pertoken_scale=input_scale,
    -    pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
    -    bias=bias.to(torch.float32) if bias is not None else None,
    -    output_dtype=original_dtype,
    -    group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
    -)
    +# MXFP8 matmul (weight & scale already transposed at load time)
    +output = torch_npu.npu_quant_matmul(
    +    qx,
    +    layer.weight,
    +    layer.weight_scale,
    +    scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
    +    pertoken_scale=input_scale,
    +    pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
    +    bias=bias.to(torch.float32) if bias is not None else None,
    +    output_dtype=original_dtype,
    +    group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
    +)


logger = init_logger(__name__)

MXFP8_BLOCK_SIZE = 32

medium

Add a fallback for the float8_e8m0fnu dtype to ensure compatibility across different environment versions.

MXFP8_BLOCK_SIZE = 32
_FLOAT8_E8M0FNU_DTYPE = getattr(
    torch_npu, "float8_e8m0fnu", getattr(torch, "float8_e8m0fnu", None)
)

Comment on lines +108 to +131
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        weight_fp = layer.weight.data
        if weight_fp.dtype not in (torch.float16, torch.bfloat16):
            weight_fp = weight_fp.to(torch.bfloat16)

        # Move weight to NPU if needed. We intentionally use a conditional
        # move rather than an assert because `dit_cpu_offload` defaults to
        # True in ServerArgs, which causes fsdp_load to move every parameter
        # back to CPU after loading (even when the target device is NPU).
        # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer
        # here. The quantized fp8 weights produced below will remain on NPU
        # for inference; if the model still needs to be offloaded after
        # quantization (e.g. very large model on a small NPU), a higher-level
        # offload pass can move them back afterwards.
        if not weight_fp.is_npu:
            weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}")

        # Online MXFP8 quantisation of weights (block_size=32)
        qw, w_scale = torch_npu.npu_dynamic_mx_quant(
            weight_fp, dst_type=torch_npu.float8_e4m3fn
        )
        layer.weight = Parameter(qw, requires_grad=False)
        layer.weight_scale_inv = Parameter(w_scale, requires_grad=False)

medium

Pre-transpose the weights and scales during model loading to improve inference performance. Since this is online quantization, using .contiguous() after transpose is safe and recommended for NPU matmul efficiency.

Suggested change

    -def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    -    weight_fp = layer.weight.data
    -    if weight_fp.dtype not in (torch.float16, torch.bfloat16):
    -        weight_fp = weight_fp.to(torch.bfloat16)
    -    # Move weight to NPU if needed. We intentionally use a conditional
    -    # move rather than an assert because `dit_cpu_offload` defaults to
    -    # True in ServerArgs, which causes fsdp_load to move every parameter
    -    # back to CPU after loading (even when the target device is NPU).
    -    # npu_dynamic_mx_quant requires an NPU tensor, so we must transfer
    -    # here. The quantized fp8 weights produced below will remain on NPU
    -    # for inference; if the model still needs to be offloaded after
    -    # quantization (e.g. very large model on a small NPU), a higher-level
    -    # offload pass can move them back afterwards.
    -    if not weight_fp.is_npu:
    -        weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}")
    -    # Online MXFP8 quantisation of weights (block_size=32)
    -    qw, w_scale = torch_npu.npu_dynamic_mx_quant(
    -        weight_fp, dst_type=torch_npu.float8_e4m3fn
    -    )
    -    layer.weight = Parameter(qw, requires_grad=False)
    -    layer.weight_scale_inv = Parameter(w_scale, requires_grad=False)
    +def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    +    weight_fp = layer.weight.data
    +    if weight_fp.dtype not in (torch.float16, torch.bfloat16):
    +        weight_fp = weight_fp.to(torch.bfloat16)
    +    # Move weight to NPU if needed.
    +    if not weight_fp.is_npu:
    +        weight_fp = weight_fp.to(f"npu:{torch.npu.current_device()}")
    +    # Online MXFP8 quantisation of weights (block_size=32)
    +    qw, w_scale = torch_npu.npu_dynamic_mx_quant(
    +        weight_fp, dst_type=torch_npu.float8_e4m3fn
    +    )
    +    # Pre-transpose to [in, out] for npu_quant_matmul (avoid per-call transpose)
    +    layer.weight = Parameter(qw.transpose(0, 1).contiguous(), requires_grad=False)
    +    layer.weight_scale_inv = Parameter(w_scale.transpose(0, 1).contiguous(), requires_grad=False)

Comment on lines +154 to +164
    output = torch_npu.npu_quant_matmul(
        qx,
        layer.weight.transpose(0, 1),
        layer.weight_scale_inv.transpose(0, 1),
        scale_dtype=torch_npu.float8_e8m0fnu,
        pertoken_scale=input_scale,
        pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
        bias=bias.to(torch.float32) if bias is not None else None,
        output_dtype=original_dtype,
        group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
    )

medium

Update the matmul call to use the pre-transposed parameters and the robust dtype fallback.

Suggested change

    -output = torch_npu.npu_quant_matmul(
    -    qx,
    -    layer.weight.transpose(0, 1),
    -    layer.weight_scale_inv.transpose(0, 1),
    -    scale_dtype=torch_npu.float8_e8m0fnu,
    -    pertoken_scale=input_scale,
    -    pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
    -    bias=bias.to(torch.float32) if bias is not None else None,
    -    output_dtype=original_dtype,
    -    group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
    -)
    +# MXFP8 matmul (weight & scale already transposed at load time)
    +output = torch_npu.npu_quant_matmul(
    +    qx,
    +    layer.weight,
    +    layer.weight_scale_inv,
    +    scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
    +    pertoken_scale=input_scale,
    +    pertoken_scale_dtype=_FLOAT8_E8M0FNU_DTYPE,
    +    bias=bias.to(torch.float32) if bias is not None else None,
    +    output_dtype=original_dtype,
    +    group_sizes=[1, 1, MXFP8_BLOCK_SIZE],
    +)

Labels

diffusion SGLang Diffusion npu quant LLM Quantization
