✨ [diffusion][npu][quant] Add MXFP8 quantization support for Wan2.2 Diffusion on Ascend NPU#20922
Conversation
…th B) Add NPUMXFP8LinearMethod that enables --quantization mxfp8 on Ascend NPU, supporting both online (FP16/BF16 → MXFP8) and offline (serialized FP8 checkpoint) quantization via torch_npu APIs (npu_dynamic_mx_quant + npu_quant_matmul with group_sizes=[1,1,32]).
…n Ascend NPU Add MXFP8Config and NPUMXFP8DiffusionLinearMethod for the diffusion subsystem (multimodal_gen), enabling --quantization mxfp8 for Wan2.2 and other diffusion models on Ascend NPU. Also adds explicit quantization field to diffusion ServerArgs so online quantization can be specified without pre-quantized weights.
- Ensure weight tensor is on NPU device before npu_dynamic_mx_quant call - Flatten input x to 2D before quantization so input_scale is 3D (required by npu_quant_matmul) - Simplify output shape restoration logic Fixes: dimension of x1Scale(pertoken_scale) should be 3 but was 4
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces comprehensive support for online MXFP8 quantization on Ascend NPUs, significantly enhancing the efficiency and performance of both Large Language Models (LLMs) and Wan2.2 Diffusion models within the SGLang framework. By integrating NPU-specific quantization methods and updating the model loading mechanisms, it allows users to leverage the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds support for MXFP8 quantization on Ascend NPU for both LLM serving and Diffusion models. It introduces new quantization methods (NPUMXFP8LinearMethod and NPUMXFP8DiffusionLinearMethod) that leverage torch_npu for online and offline quantization. The changes also include updates to server arguments and model loading logic to enable this feature, along with a new test suite. The implementation appears solid, with separate, clean implementations for the LLM and diffusion model paths. I have one suggestion to improve robustness in the LLM quantization path by ensuring the weight tensor is on the NPU device before quantization, mirroring the practice in the diffusion path.
按 reviewer 建议重构架构分层: - 在 fp8.py 新增 MXFP8LinearAscendMethod,负责权重定义(__init__、create_weights) - 简化 mxfp8_method_npu.py 中的 NPUMXFP8LinearMethod,只保留权重处理和 kernel 调用 - 改进架构分层,符合现有 NPU INT8 方法模式
Fix weight loading for msmodelslim pre-quantized MXFP8 weights: - Change weight dtype from int8 to float8_e4m3fn (actual storage format in safetensors) - Fix weight_scale shape from [out, in/32*2] to [out, in/32] (actual msmodelslim export) - Update process_weights_after_loading to reshape weight_scale [out, in/32] -> [out, -1, 2]
|
/rerun-failed-ci |
CI failure analysisI went through every failed/cancelled job. None of the failures are caused by changes in this PR. Breakdown by root cause: 1. NPU runner infra issue
2. Pre-existing NPU test failure (exists on
|
| Category | Jobs | Cause |
|---|---|---|
| NPU infra | 2 | Rust mirror network failure on self-hosted runner |
| Pre-existing NPU test | 2 | wan2_2_t2v_14b_w8a8_8npu quant_type unmapped (also fails on unrelated PRs) |
| AMD environment | 3 | NCCL init / VRAM cleanup / 1.7 ms latency overshoot |
| Pre-existing flaky GPU | 4 | Wan2.x / LTX / mova / fsdp consistency or performance (also fails on unrelated PRs) |
| Cascading fast-fail | 20 | Auto-skipped due to the above |
No action required on this PR's code. Re-running CI once the runner infra recovers should clear the cascading failures; the NPU wan2_2_t2v_14b_w8a8_8npu and the GPU consistency drifts need a separate fix on main.
Restore docs_new/index.mdx to upstream/main state. This file was modified by an automated github-actions[bot] commit (762f21f) that ran the LMSYS blog sync workflow against the fork's junlin branch, unrelated to the diffusion MXFP8 changes in this PR.
|
I merged it as we already analysis all CIs and these failed tests are unrelated to our change. please let me know if there are some other issues |
|
it won't work until #24540 is merged |
…iffusion on Ascend NPU (sgl-project#20922) Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
…iffusion on Ascend NPU (sgl-project#20922) Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
…iffusion on Ascend NPU (sgl-project#20922) Co-authored-by: ronnie_zheng <zl19940307@163.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
…#22352 conflicts Resolved 10 conflicts: - Diffusion side (7 files): used upstream/main (includes merged sgl-project#20922/sgl-project#22338) - LLM side fp8.py: kept both NPU and MUSA capability bypasses - LLM side modelslim.py: added W8A8_MXFP8 to upstream's table-driven scheme dispatch - LLM side transformers.py: used upstream/main (MoE refactoring with proper prefix)
… prerequisites - Prerequisites sgl-project#20922 (Diffusion MXFP8) and sgl-project#22338 (Diffusion MXFP4) merged upstream — accept their canonical versions and remove our duplicate diffusion modifications from this PR. - Adapt offline MXFP8 dispatch to upstream's table-driven ModelSlimConfig.get_linear_scheme by registering W8A8_MXFP8 → ModelSlimMXFP8Scheme; add a no-op __init__ on the scheme so its signature matches the other entries. - Keep NPU bypass in Fp8Config.get_min_capability alongside the new upstream _is_musa branch; revert an unrelated MoE _use_aiter style change that drifted in via earlier merges.
Summary
This PR adds MXFP8 (Microscaling FP8) quantization support for Wan2.2 diffusion models on Ascend NPU. It closes part of the NPU MXFP8 gap tracked in issue #14424.
Hardware requirement: Ascend A5 series or newer.
npu_dynamic_mx_quantis not available on A2/A3.Two modes are supported:
Online quantization (
--quantization mxfp8)MXFP8Config+NPUMXFP8DiffusionLinearMethod(multimodal_gen/runtime/layers/quantization/mxfp8_npu.py) for the diffusion subsystem.npu_dynamic_mx_quant; at inference, activations are quantized per-token and the matmul is executed bynpu_quant_matmulwithgroup_sizes=[1,1,32](block_size=32).Offline quantization (msmodelslim pre-quantized weights)
ModelSlimMXFP8Scheme(multimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.py) for loading weights pre-quantized by msmodelslim (float8_e4m3fnweights +uint8scale infloat8_e8m0fnuencoding).wan_repack.py refactor
multimodal_gen/tools/wan_repack.pyinto a one-step repack tool: copies the original HF Diffusers model, converts msmodelslim quant weights (renaming keys to Diffusers format), and restoresconfig.json— replacing a multi-step manual workflow. Fixes multiple bugs in the original script (glob patterns passed as literal paths, unconditionalquant_configkey update causingKeyError). SupportsWan2.2-TI2V-5B(single transformer) andWan2.2-T2V-A14B/Wan2.2-I2V-A14B(Cascade dual-transformer).Key NPU APIs used
torch_npu.npu_dynamic_mx_quant(x, dst_type=torch_npu.float8_e4m3fn)torch_npu.npu_quant_matmul(..., group_sizes=[1,1,32])torch_npu.float8_e4m3fn/torch_npu.float8_e8m0fnuFiles Changed
New files
multimodal_gen/runtime/layers/quantization/mxfp8_npu.pyMXFP8Config+NPUMXFP8DiffusionLinearMethod) for Wan2.2 diffusionmultimodal_gen/runtime/layers/quantization/modelslim_mxfp8_scheme.pyModelSlimMXFP8Scheme) for msmodelslim pre-quantized weightsModified — MXFP8 registration & dispatch
multimodal_gen/runtime/layers/quantization/__init__.pyMXFP8Config; add"mxfp8"toQuantizationMethodsliteralmultimodal_gen/runtime/layers/quantization/modelslim.pyW8A8_MXFP8branch →ModelSlimMXFP8Schemein_get_scheme_from_parts()Modified — CLI & loader support
multimodal_gen/runtime/server_args.py--quantizationCLI arg (explicit method override, e.g.--quantization mxfp8)multimodal_gen/runtime/loader/component_loaders/transformer_loader.py--quantizationflag; takes priority over auto-detection from config.json / metadatamultimodal_gen/runtime/loader/fsdp_load.pyweight_scaleto FSDP unused-key list (prevents crash on offline MXFP8 weight load)multimodal_gen/runtime/utils/quantization_utils.pyquant_model_description*.json(supports repacked filenames)Modified — tooling
multimodal_gen/tools/wan_repack.pyModified — minor refactor (srt)
srt/layers/quantization/fp8.pyapply_fp8_marlin_linear→torch.ops.sglang.apply_fp8_marlin_linear); restructureFp8MoEMethod.process_weights_after_loading()to move weight shuffle inside scale-processing blockwan_repack.py: Design Details
Bug Fixes
The original script contained four bugs that made it entirely non-functional:
load_sharded_safetensors()pathlib.Path(dir, "*model*.safetensors")passed directly toload_file()pathlib.Path(dir, "*.safetensors")creates a literal path with*in the filename — not a glob.load_file()does not expand globs, so every run raisesFileNotFoundError.convert_transformer()open(pathlib.Path(dir, "*quant_model_description*.json"))get_transformer_config()elsebranch — unknownmodel_typecausesNameError: name 'RENAME_DICT' is not definedif model_type == "Wan-T2V-14B", then referenced unconditionally outside.convert_transformer()update_dict_(original_quant_config, key, new_key)called unconditionally for every keydict.pop()on a missing key raisesKeyError.Fix for bugs 1 & 2: replaced glob-as-literal-path with
directory.glob(pattern), which returns a proper file list. Added existence and uniqueness checks with descriptive error messages.Fix for bug 3: added
else: raise ValueError(...)and extended support toWan2.2-I2V-A14BandWan2.2-TI2V-5B.Fix for bug 4: added
if key in quant_configguard before updating the quant description dict.One-Step Repack Workflow
Before — users had to run these steps manually:
After — single command:
python wan_repack.py \ --model-type Wan2.2-TI2V-5B \ --original-model-path Wan2.2-TI2V-5B-Diffusers \ --quant-path Wan2.2-TI2V-5B-quantized \ --output-path Wan2.2-TI2V-5B-Diffusers-MXFP8Internally, the new
repack()orchestrator runs three steps:shutil.copytree(original, output, ignore=transformer_dirs)— copies the full model (VAE, text encoder, scheduler, etc.) to the output path, skipping transformer dirs.convert_transformer()for each transformer dir — converts quantized weights todiffusion_pytorch_model.safetensorsand renames keys to HF Diffusers format.shutil.copy2(original/transformer/config.json, output/transformer/config.json)— restores the architecture config that was excluded in step 1.For Cascade models (
Wan2.2-T2V-A14B,Wan2.2-I2V-A14B), steps 2–3 repeat for bothtransformer/(sourced fromquant_path/high_noise_model/) andtransformer_2/(sourced fromquant_path/low_noise_model/). The cascade vs. single-model dispatch is driven byCASCADE_MODEL_TYPES.Summary
pathlib.Path(dir, "*.safetensors")— broken,FileNotFoundErroron every rundir.glob("*.safetensors")with existence and uniqueness checksget_transformer_config()"Wan-T2V-14B"; crashes withNameErrorfor any other typeWan2.2-T2V-A14B,Wan2.2-I2V-A14B,Wan2.2-TI2V-5B; raisesValueErroron unsupported typequant_configkey updatedict.pop()—KeyErrorfor non-quantized layersif key in quant_configguardrepack()call handles all steps--input-path,--output-pathonly--model-type,--original-model-path,--quant-path,--output-pathPerformance Comparison Report
1. Scripts
2. High-level Summary
3. Stage Breakdown
Metadata
ef874c0a1c92bf29a35e7f2e7efaf2bdaed748faef874c0a1c92bf29a35e7f2e7efaf2bdaed748faef874c0a1c92bf29a35e7f2e7efaf2bdaed748faRelated Issues
Closes part of #14424 (MXFP8/MXFP4 support on Ascend NPU for SGLang).