
[Diffusion][Quantization] SVDQuant W4A4 (Nunchaku) for Z-Image-Turbo #3830

Open
ultism wants to merge 7 commits into vllm-project:main from ultism:svdquant-converter

@ultism ultism commented May 23, 2026

[Diffusion][Quantization] SVDQuant W4A4 (Nunchaku) for Z-Image-Turbo

Summary

Integrate SVDQuant W4A4 NVFP4 (the Nunchaku family) as an offline-quantized backend for diffusion transformers, validated on Z-Image-Turbo on RTX 5090 (consumer Blackwell, SM_120).

Headline (1024×1024, 20 steps, seed=42, batch=1):

  • 2.24× speedup vs BF16 (11.07s → 4.94s)
  • -29% peak VRAM (24.26 → 17.14 GiB), -34% weights (20.87 → 13.74 GiB)

Why this is a new PR (not duplicate of #1986)

The closing comment on #1986 explained the three things that changed materially:

  1. On-disk format pivoted to canonical row-major NVFP4. Offline converter (vllm_omni/quantization/tools/convert_nunchaku_to_svdquant.py) emits a backend-agnostic checkpoint; the nunchaku PTX-MMA tile fragment layout is now produced at load time inside the backend, not baked into the on-disk weights. No more runtime MergedColumnParallelLinear output-half swap.
  2. Config + LinearMethod + backend live in vllm-omni, not vllm. The earlier plan was a split (config in vllm, glue in vllm-omni), but per review on the in-vllm PR (vllm-project/vllm#43471, now closed) the consensus was that SVDQuant — being a Python wrapper around nunchaku's W4A4 CUDA kernels with a consumer-GPU-only envelope — fits vllm-omni's "diffusion-side caller" pattern alongside DiffusionInt8Config / DiffusionMXFP4Config / DiffusionMXFP8Config rather than vllm proper. This PR is now the single home for SVDQuant in the vllm ecosystem.
  3. Reviewer concerns from #1986 ("[Feature] Integrate Nunchaku SVDQuant W4A4 for diffusion models") are addressed declaratively, not papered over (declarative quantization_config in transformer/config.json, strict missing-weight validation preserved via ComponentQuantizationConfig returning UnquantizedLinearMethod for non-quantized components, model-side key remapping reduced to one trailing-dot fix in stacked_params_mapping).

What's in this PR

svdquant-converter branch, 7 commits:

| Commit | Scope |
| --- | --- |
| ea7325bf | Config + LinearMethod + nunchaku backend infrastructure (migrated from closed vllm-project/vllm#43471). New files: vllm_omni/quantization/svdquant_config.py (DiffusionSVDQuantConfig + LinearMethod, backend-agnostic), svdquant_dispatch.py (hardware gate + select_backend), svdquant_nunchaku.py (nunchaku capability detection + lazy importlib wrappers + prepare_weights / apply). Promotes tools/svdquant_nvfp4_layout.py from re-export shim to real implementation. Factory registers "svdquant". 12 tests in tests/diffusion/quantization/test_svdquant_config.py. Converter also drops nunchaku's unused smooth_factor_orig suffix at group time (nunchaku itself marks it (Unused) in nunchaku/models/linear.py:54). |
| 3d7ff30a | Pre-commit (ruff format) cleanup |
| 762caa48 | Bench: use nvidia-smi for GPU memory (spawn-mode worker isolation) + zero-division guard |
| 1aa18b06 | Bench: --baseline-model + --quantization auto sentinel, for offline-quantized checkpoints where BF16 and quant trees are separate paths |
| 605b342a | Converter: modules_to_not_convert: ["lm_head"] so Qwen3 text-encoder's top-level LM head doesn't fall through to SVDQuant |
| a8e637d3 | Bench: emit Memory Profiling table (PR #1470 layout) |
| edad5f35 | Offline converter (convert_nunchaku_to_svdquant.py, 549 lines) + per-component config wiring + Z-Image transformer trailing-dot key fix + example label resolution |

The 4 [Benchmark] commits are bench scaffolding. Happy to split them into a separate [Benchmark] PR if reviewers prefer (the nvidia-smi memory fix in particular is a latent bug that hits every offline-quant PR going through spawn-mode workers).

Quantized checkpoints

Both produced by the included converter from the original nunchaku-published merged safetensors:

Test Plan

# Bench command — produces all 3 tables + 16 side-by-side PNGs
python benchmarks/diffusion/quantization_quality.py \
  --baseline-model Tongyi-MAI/Z-Image-Turbo \
  --model ultranationalism/nunchaku-z-image-turbo-svdq \
  --task t2i \
  --quantization auto \
  --prompts \
    "a close-up portrait of an elderly fisherman with weathered skin and a thick gray beard, soft natural light" \
    "an aerial view of a coral reef with crystal clear turquoise water" \
    "extreme close-up of a dewdrop on a red rose petal, morning sunlight" \
    "a bustling night market in Tokyo with neon signs, rain-slicked streets, and crowds with umbrellas" \
    "a vintage bookstore storefront with the sign CLASSICS AND RARE EDITIONS in elegant gold lettering" \
    "a campfire in a dark forest with sparks rising into a starry sky" \
    "a ballet dancer in mid-leap on an empty theater stage, dramatic spotlight from above" \
    "a cup of coffee on a wooden table, morning light" \
  --height 1024 --width 1024 \
  --num-inference-steps 20 \
  --seed 42 \
  --lpips-net alex \
  --output-dir ./svdquant_bench_output

Test Result

  • GPU: NVIDIA RTX 5090 D (Blackwell SM_120, 32 GiB) — primary bench target
  • Backend dispatched: nunchaku (the only in-tree backend in this PR; see Roadmap below for SM_100/103 + Ascend plans)
  • Stack: vLLM + this PR's svdquant-converter branch + nunchaku 1.2.1+cu12.8torch2.11 + PyTorch 2.11.0+cu128

Summary

Config Avg Time Speedup Memory (GiB) Mem Reduction Mean LPIPS
BF16 baseline 11.07s 1.00× 24.26 (ref)
SVDQuant W4A4 NVFP4 (nunchaku backend) 4.94s 2.24× 17.14 -29% 0.2324

Memory Profiling

First-prompt snapshot at 1024×1024, 20 steps, TP=1. Memory read via nvidia-smi --query-gpu=memory.used (vllm-omni spawn-mode workers have their own CUDA contexts; the bench driver process sees 0 GiB allocated).

| Config | Weights | Activations | Peak | Total Reduction |
| --- | --- | --- | --- | --- |
| BF16, TP=1 | 20.87 GiB | 3.39 GiB | 24.26 GiB | |
| SVDQuant, TP=1 | 13.74 GiB | 3.40 GiB | 17.14 GiB | -29% |

Activations are unchanged (text-encoder activations + diffusion latents are not quantized). The 7.13 GiB weights delta is the entire SVDQuant win.
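
As a quick sanity check, the headline ratios can be recomputed from the table values above. This is only arithmetic on the figures already reported in this description; the variable names are illustrative.

```python
# Figures copied from the Summary and Memory Profiling tables above.
bf16_time_s, svdq_time_s = 11.07, 4.94
bf16_peak_gib, svdq_peak_gib = 24.26, 17.14
bf16_weights_gib, svdq_weights_gib = 20.87, 13.74

speedup = bf16_time_s / svdq_time_s
peak_reduction = (bf16_peak_gib - svdq_peak_gib) / bf16_peak_gib
weights_reduction = (bf16_weights_gib - svdq_weights_gib) / bf16_weights_gib
weights_delta_gib = bf16_weights_gib - svdq_weights_gib

print(f"speedup: {speedup:.2f}x")                    # speedup: 2.24x
print(f"peak VRAM reduction: {peak_reduction:.0%}")  # peak VRAM reduction: 29%
print(f"weights reduction: {weights_reduction:.0%}") # weights reduction: 34%
print(f"weights delta: {weights_delta_gib:.2f} GiB") # weights delta: 7.13 GiB
```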

Spot-checked on a smaller box: RTX 5060 Ti (SM_120, 16 GiB) with --enable-cpu-offload runs 512×512 / 8 steps in 5.9 s, peak VRAM 8.5 GiB. Confirms backend dispatch + nunchaku weight-prep + apply path work under model CPU-offload, not just dense GPU residency.

Per-Prompt LPIPS (alex backbone)

| Prompt | LPIPS |
| --- | --- |
| a close-up portrait of an elderly fisherman with weathered skin... | 0.276 |
| an aerial view of a coral reef with crystal clear turquoise water | 0.210 |
| extreme close-up of a dewdrop on a red rose petal, morning sunlight | 0.394 |
| a bustling night market in Tokyo with neon signs, rain-slicked streets... | 0.312 |
| a vintage bookstore storefront with the sign CLASSICS AND RARE EDITIONS... | 0.161 |
| a campfire in a dark forest with sparks rising into a starry sky | 0.192 |
| a ballet dancer in mid-leap on an empty theater stage, dramatic spotlight... | 0.160 |
| a cup of coffee on a wooden table, morning light | 0.155 |
| mean | 0.232 |
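
The mean and the best/worst prompts can be cross-checked from the per-prompt values. The short prompt labels below are taken from the Visual Gallery; because the inputs are already rounded to three decimals, the recomputed mean only agrees with the reported 0.2324 up to rounding.

```python
# Per-prompt LPIPS values copied from the table above (alex backbone).
lpips_by_prompt = {
    "elderly fisherman portrait": 0.276,
    "aerial coral reef": 0.210,
    "dewdrop macro": 0.394,
    "Tokyo night market": 0.312,
    "vintage bookstore": 0.161,
    "campfire in dark forest": 0.192,
    "ballet dancer mid-leap": 0.160,
    "coffee on wooden table": 0.155,
}

mean_lpips = sum(lpips_by_prompt.values()) / len(lpips_by_prompt)
worst_prompt = max(lpips_by_prompt, key=lpips_by_prompt.get)  # highest LPIPS
best_prompt = min(lpips_by_prompt, key=lpips_by_prompt.get)   # lowest LPIPS

print(f"mean LPIPS ~ {mean_lpips:.4f}")
print(f"worst: {worst_prompt}, best: {best_prompt}")
```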

Quality trade-off (honest framing)

W4A4 is significantly more aggressive than the W8A8 baselines that other offline-quant PRs (ModelOpt FP8 #2913, MXFP8 #3140) report on. LPIPS scores reflect that:

These higher LPIPS values are inherent to W4A4 with SVD low-rank correction (rank=128) — the algorithm trades quality for compression more aggressively than W8A8. The PR's value proposition is 2.24× speedup + 34% weights compression, not strict pixel parity with BF16. Users who need higher fidelity should stay on the BF16 or int8/FP8 paths.

Visual Gallery

Z-Image-Turbo: BF16 vs SVDQuant W4A4 NVFP4 (8 prompts, same seed) [side-by-side images omitted]

| # | Prompt | Note |
| --- | --- | --- |
| 0 | elderly fisherman portrait | |
| 1 | aerial coral reef | |
| 2 | dewdrop macro | worst LPIPS |
| 3 | Tokyo night market | |
| 4 | vintage bookstore | text rendering |
| 5 | campfire in dark forest | |
| 6 | ballet dancer mid-leap | |
| 7 | coffee on wooden table | best LPIPS / #1986 reference |

Roadmap (forward-looking; not in this PR)

The dispatch architecture in svdquant_dispatch.py is built so new backends drop in as siblings — each backend is a single module exposing three functions (supports(cap, precision) -> bool, prepare_weights(layer, precision), apply(layer, x, bias)), and select_backend() returns the first one that claims the active platform. Adding a new backend requires zero changes to DiffusionSVDQuantLinearMethod or the on-disk format. The on-disk canonical row-major NVFP4 (or INT4-nibble) layout is the explicit cross-backend contract — one checkpoint serves all of them.

| Status | Backend | Hardware | Module | Notes |
| --- | --- | --- | --- | --- |
| ✅ Shipped (this PR) | nunchaku | Consumer NVIDIA: SM_75 Turing, SM_80/86/89 Ampere/Ada, SM_120 consumer Blackwell | svdquant_nunchaku.py | PTX-MMA fragment layout; prepare_weights repacks row-major → fragment at load time. Hopper SM_90 deliberately excluded (no validated kernel family). |
| 🛠️ Planned | flashinfer | Datacenter Blackwell: SM_100 (B200), SM_103 (GB300) | svdquant_flashinfer.py (TBD) | Native CuTe DSL W4A4 kernel landing in FlashInfer so SGLang and vllm-omni share the same primitive. Consumes the on-disk canonical row-major NVFP4 directly, so no second checkpoint is needed. |
| 🛠️ Planned | npu (Ascend 910x) | Huawei Ascend 910 / 910B / 910C via torch_npu | svdquant_npu.py (TBD) | Mirrors the existing Diffusion{Int8,MXFP4,MXFP8}Config NPU path: capability detection via current_omni_platform.is_npu(), kernel call via torch_npu.npu_*-family ops. Requires Huawei's W4A4 SVDQuant-equivalent primitive to ship in their CANN release; the on-disk format and dispatcher entry are ready ahead of time. |

Each future backend reuses the converter, the LinearMethod, the factory registration, the test suite, and the hardware gate. Only the GEMM/quantize call sites are new.
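
The "first backend that claims the platform wins" pattern described above can be sketched roughly as follows. This is not the code in svdquant_dispatch.py: the backends are stand-in objects rather than modules, `select_backend` takes the compute capability explicitly, and the precision labels (`"int4"`, `"nvfp4"`) and per-precision SM sets are assumptions inferred from the roadmap table.

```python
from types import SimpleNamespace

# Assumed SM sets, inferred from the roadmap table (not taken from the real code).
INT4_SMS = {75, 80, 86, 89}   # assumption: INT4 path on Turing / Ampere / Ada
NVFP4_SMS = {120}             # assumption: NVFP4 path on consumer Blackwell only


def _nunchaku_supports(cap: int, precision: str) -> bool:
    if precision == "nvfp4":
        return cap in NVFP4_SMS
    if precision == "int4":
        return cap in INT4_SMS
    return False


def _flashinfer_supports(cap: int, precision: str) -> bool:
    # Planned backend in the roadmap; shown here only to illustrate ordering.
    return precision == "nvfp4" and cap in {100, 103}


def _make_backend(name, supports):
    # Each backend exposes the three functions named in the description.
    return SimpleNamespace(
        name=name,
        supports=supports,
        prepare_weights=lambda layer, precision: None,  # placeholder
        apply=lambda layer, x, bias: x,                 # placeholder
    )


def candidate_backends():
    # Order matters: the first backend whose supports() returns True is used.
    return [
        _make_backend("flashinfer", _flashinfer_supports),
        _make_backend("nunchaku", _nunchaku_supports),
    ]


def select_backend(cap: int, precision: str):
    for backend in candidate_backends():
        if backend.supports(cap, precision):
            return backend
    raise RuntimeError(f"no SVDQuant backend for SM_{cap} / {precision}")


print(select_backend(120, "nvfp4").name)  # nunchaku
```

Under these assumptions, SM_90 (Hopper) falls through every candidate and raises, which matches the "deliberately excluded" note in the table.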

Other follow-ups (not in this PR)

  • Per-component quantization config (RFC #1044 OmniDiffusionConfig.quantization_targets): the current modules_to_not_convert: ["lm_head"] skip is a substring escape hatch. The cleaner long-term fix mirrors PR #2702 ("fix: do not apply FP8 quant config to vision/audio encoders for pre-quantized checkpoints", the Qwen3-Omni encoder fix): pass quant_config=None explicitly to non-quantized subcomponents at pipeline construction in pipeline_z_image.py. That requires plumbing quant_config through create_transformers_model and is orthogonal to this PR.
  • Bench scaffolding extraction: if reviewers prefer, the 4 [Benchmark] commits in this PR can move to a separate fix-only PR.

Closes / Refs


AI assistance: this PR's commits and PR description were produced with Claude Code assistance. Every change was reviewed and validated end-to-end on RTX 5090 + RTX 5060 Ti by the human submitter before push.

ultism and others added 5 commits May 23, 2026 17:14
…ponent glue

Pivot from the closed PR vllm-project#1986 design (runtime nunchaku-format glue) to
the post-RFC architecture where:
  - vLLM upstream (vllm-project/vllm) hosts the SVDQuant quantization
    config, linear method, dispatcher, and native SM_100/103 CuTe DSL
    kernel.
  - vllm-omni hosts only diffusion-specific glue and a one-time offline
    converter that emits canonical row-major NVFP4 checkpoints.

Components:

  - `vllm_omni/quantization/tools/convert_nunchaku_to_svdquant.py`:
    Standalone converter that ingests nunchaku-published merged
    safetensors and emits a vLLM-loadable diffusers pipeline tree in
    canonical row-major + FP4 nibble pack. This is the layout the
    SM_100 native CuTe kernel consumes directly; for the nunchaku
    backend on consumer GPUs (SM_75-89 / SM_120), vLLM repacks to the
    PTX-MMA tile layout at load time in
    `SVDQuantLinearMethod.process_weights_after_loading`. The on-disk
    format is backend-agnostic.

  - `vllm_omni/quantization/tools/svdquant_nvfp4_layout.py`: Thin
    re-export shim. The actual layout adapters live in
    `vllm/model_executor/layers/quantization/utils/svdquant_nvfp4_layout.py`
    in vLLM proper; this file keeps the import surface stable for
    downstream code that referenced the original vllm-omni location.

  - `vllm_omni/quantization/component_config.py`: per-component
    quantization config wiring so per-pipeline-component (transformer,
    text_encoder, vae, etc.) quant config can be declared declaratively
    in `transformer/config.json["quantization_config"]` rather than
    runtime monkey-patching. Addresses the "blanket strict-validation
    disable" review concern from vllm-project#1986.

  - `vllm_omni/diffusion/models/z_image/z_image_transformer.py`:
    trailing-dot fix in `stacked_params_mapping` (replaces the per-model
    diffusers->vLLM key remapping from the closed PR; reduced to one
    line under the canonical row-major design).

  - `examples/offline_inference/text_to_image/text_to_image.py`: smarter
    quantization label resolution that mirrors
    `OmniDiffusionConfig._propagate_quantization_from_tf_config` so the
    startup banner reflects on-disk per-component quant config rather
    than printing "None (BF16)" for an already-quantized checkpoint.
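
For readers unfamiliar with the "FP4 nibble pack" mentioned in the converter bullet above: two 4-bit codes share one byte. The sketch below shows the general idea only. The nibble order (even index in the low nibble) is an assumption, not something this commit message states, and the helper names are hypothetical.

```python
def pack_nibbles(codes: list[int]) -> bytes:
    """Pack pairs of 4-bit codes into bytes (assumed order: even index -> low nibble)."""
    if len(codes) % 2:
        raise ValueError("need an even number of 4-bit codes")
    if any(not 0 <= c <= 0xF for c in codes):
        raise ValueError("each code must fit in 4 bits")
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_nibbles(packed: bytes) -> list[int]:
    """Inverse of pack_nibbles under the same assumed nibble order."""
    codes = []
    for byte in packed:
        codes.append(byte & 0xF)  # low nibble first
        codes.append(byte >> 4)   # then high nibble
    return codes


row = [1, 2, 15, 0]
packed = pack_nibbles(row)
print(packed.hex())            # 210f
print(unpack_nibbles(packed))  # [1, 2, 15, 0]
```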

Canonical checkpoint produced by this converter:
  - HuggingFace: https://huggingface.co/ultranationalism/nunchaku-z-image-turbo-svdq
  - ModelScope: https://www.modelscope.cn/models/ultranationalism/Z-Image-Turbo-SVDQuant-NVFP4

Test plan (pending validation on consumer Blackwell SM_120):
  - E2E quantized Z-Image-Turbo inference on RTX 5090
  - BF16 vs SVDQuant LPIPS quality benchmark per PR vllm-project#1470 template

Refs: vllm-project#1986 (closed), RFC vllm-project/vllm#37908

AI assistance: this commit was produced with Claude Code assistance.

Signed-off-by: ultranationalism <www913363043@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: ultism <www913363043@gmail.com>
Per PR vllm-project#1470's review template, "量化展示" ("quantization showcase") requires two tables:
the Summary table (already emitted) and a Memory Profiling table
that breaks Peak into Weights + Activations.

Capture the extra weights/activations numbers in `_generate_image`
and `_generate_video` by snapshotting `memory_allocated()` right
before each `generate()` call (= weights + persistent buffers
already on device) and subtracting it from the post-generate
`max_memory_allocated()` to get the activations delta.

Surface the values in `run_benchmark` as a third markdown table
("### Memory Profiling") with the columns PR vllm-project#1470 used:
Weights / Activations / Peak / Total Reduction, broken down by
TP size (from `args.tensor_parallel_size`).

First-prompt snapshot is canonical, matching the existing Peak
column's "use first prompt's memory" convention.
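
A minimal sketch of the weights/activations split described in this commit, with the memory counters passed in as callables so it runs without a GPU. In the bench they would presumably be `torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated`; the function name and return keys are hypothetical. Note that a later commit in this PR replaces these counters with nvidia-smi, because they read 0 from the driver process under spawn-mode workers.

```python
from typing import Callable, Dict

GIB = 1024 ** 3


def profile_generate(
    generate: Callable[[], None],
    memory_allocated: Callable[[], int],
    max_memory_allocated: Callable[[], int],
) -> Dict[str, float]:
    # Snapshot before generate(): weights + persistent buffers already on device.
    weights_bytes = memory_allocated()
    generate()
    # Peak after generate(): weights plus transient activations.
    peak_bytes = max_memory_allocated()
    return {
        "weights_gib": weights_bytes / GIB,
        "activations_gib": (peak_bytes - weights_bytes) / GIB,
        "peak_gib": peak_bytes / GIB,
    }


# Usage with a fake allocator: 20 GiB resident, peak rises to 24 GiB.
state = {"current": 20 * GIB, "peak": 20 * GIB}


def fake_generate() -> None:
    state["peak"] = 24 * GIB


result = profile_generate(fake_generate, lambda: state["current"], lambda: state["peak"])
print(result)  # {'weights_gib': 20.0, 'activations_gib': 4.0, 'peak_gib': 24.0}
```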

Signed-off-by: ultranationalism <www913363043@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: ultism <www913363043@gmail.com>
Validated on RTX 5090 (consumer Blackwell, SM_120): the converter's
emitted `transformer/config.json["quantization_config"]` is read by
vllm-omni as a *global* config (`OmniDiffusionConfig.quantization_config`,
see `data.py:_propagate_quantization_from_tf_config`) and applied to
every component's nn.Linear, not just the Z-Image DiT. The
`{"model": None}` prefix rule masks `Qwen3ForCausalLM.model.layers.*`
but does NOT cover `Qwen3ForCausalLM.lm_head`, which lives at the
top level — so it fell through to the SVDQuant default and tripped a
tied-weight data_ptr error on the first text-encoder forward.

Add `"lm_head"` to the SVDQuant default's `modules_to_not_convert`.
This uses vLLM's standard HF-convention substring skip (see
`SVDQuantConfig.get_quant_method` → `is_layer_skipped(..., skip_with_substr=True)`),
matches every layer whose prefix contains `lm_head`, and is shorter
than enumerating additional component-prefix dict keys.
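
The substring skip described above behaves roughly like the following sketch. It is not vLLM's `is_layer_skipped`; the helper name and the example prefixes are illustrative only.

```python
def is_skipped_by_substring(prefix: str, modules_to_not_convert: list[str]) -> bool:
    """Return True if any skip entry occurs anywhere in the layer prefix."""
    return any(name in prefix for name in modules_to_not_convert)


skip = ["lm_head"]
print(is_skipped_by_substring("lm_head", skip))                            # True
print(is_skipped_by_substring("text_encoder.lm_head", skip))               # True
print(is_skipped_by_substring("model.layers.0.self_attn.qkv_proj", skip))  # False
```

The trade-off is that a substring match is broad: any layer whose prefix happens to contain `lm_head` is left unquantized, which is why the message calls it an escape hatch.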

Long-term: per-component quant config (RFC vllm-project#1044's
`OmniDiffusionConfig.quantization_targets`) would let `text_encoder/config.json`
declare its own `quantization_config=None` directly. Until then this
substring skip is the workable escape hatch.

Signed-off-by: ultranationalism <www913363043@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: ultism <www913363043@gmail.com>
…lity.py

Adds two pieces so the bench can run BF16 baseline vs offline-quantized
checkpoints (SVDQuant, MXFP8 offline, ModelOpt FP4, etc.) where the
quantized variant ships its own pipeline tree, distinct from the BF16
reference:

  - `--baseline-model PATH`: overrides the model path used for the
    BF16 baseline run. Defaults to `--model` so existing online-quant
    flows (fp8 / int8 / bitsandbytes) keep working unchanged.

  - `--quantization auto`: sentinel meaning "do not pass
    `quantization_config` to Omni; honor the on-disk
    `transformer/config.json["quantization_config"]`". Needed because
    offline-quantized checkpoints bake the method + per-layer skip
    list (`modules_to_not_convert`) into the config, and overriding
    them at the CLI would defeat the purpose.

Example (Z-Image SVDQuant NVFP4 W4A4 vs BF16 baseline):

  python benchmarks/diffusion/quantization_quality.py \
      --baseline-model Tongyi-MAI/Z-Image-Turbo \
      --model ultranationalism/Z-Image-Turbo-SVDQuant-NVFP4 \
      --task t2i --quantization auto \
      --prompts "a cup of coffee on a wooden table" \
      --height 1024 --width 1024 \
      --num-inference-steps 20 --seed 42
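
The two flags could be resolved along these lines. This is a sketch of the described behaviour, not the bench script itself: `resolve_run_config` is a hypothetical helper, and passing the raw CLI string as `quantization_config` is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    # Defaults to --model so online-quant flows keep working unchanged.
    parser.add_argument("--baseline-model", default=None)
    parser.add_argument("--quantization", default=None)
    return parser


def resolve_run_config(args: argparse.Namespace):
    baseline_model = args.baseline_model or args.model
    quant_kwargs = {}
    # "auto" is a sentinel: pass nothing and honor the on-disk config.
    if args.quantization not in (None, "auto"):
        quant_kwargs["quantization_config"] = args.quantization
    return baseline_model, quant_kwargs


args = build_parser().parse_args(
    ["--baseline-model", "bf16/path", "--model", "quant/path", "--quantization", "auto"]
)
print(resolve_run_config(args))  # ('bf16/path', {})
```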

Signed-off-by: ultranationalism <www913363043@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: ultism <www913363043@gmail.com>
`torch.cuda.memory_allocated()` / `max_memory_allocated()` from the
bench driver process always reports 0 in the spawn-mode vllm-omni
worker layout: the model lives in a child process with its own CUDA
context, and the driver allocates nothing. Switching to nvidia-smi
(via `_gpu_memory_gib`) reads the device-wide `memory.used`, which on
a single-GPU benchmark equals the worker's footprint.

Also:
  - Remove the now-meaningless `reset_peak_memory_stats()` call.
  - Guard the summary table's `(bl_mem - qt_mem) / bl_mem` against
    zero (returns 0% reduction with a 0.00 GiB display when nvidia-smi
    is unavailable). Previously a `ZeroDivisionError` killed the run
    after all images were already saved — so the LPIPS scoring and
    markdown emission were lost too.
  - Flush stdout per prompt + print elapsed/memory after each generate,
    so a long-running tmux/`tee` log shows progress in real time.
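
A hedged sketch of the two fixes in this commit: reading device-wide memory through nvidia-smi and guarding the reduction percentage against a zero baseline. The function names here are illustrative (the commit's helper is `_gpu_memory_gib`, whose body is not shown); the nvidia-smi flags used are the standard `--query-gpu`, `--format` and `--id` options.

```python
import shutil
import subprocess


def parse_memory_used_gib(output: str) -> float:
    """Parse `memory.used` (MiB, one line per GPU) and return GiB for the first line."""
    lines = output.strip().splitlines()
    try:
        return float(lines[0].strip()) / 1024.0
    except (IndexError, ValueError):
        return 0.0  # empty output or a non-numeric value such as "[N/A]"


def gpu_memory_gib(gpu_index: int = 0) -> float:
    """Device-wide used memory in GiB, or 0.0 when nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return 0.0
    try:
        completed = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=memory.used",
                "--format=csv,noheader,nounits",
                f"--id={gpu_index}",
            ],
            capture_output=True, text=True, check=True, timeout=10,
        )
    except (subprocess.SubprocessError, OSError):
        return 0.0
    return parse_memory_used_gib(completed.stdout)


def memory_reduction_pct(baseline_gib: float, quant_gib: float) -> float:
    """Percentage reduction vs. baseline; 0.0 when the baseline reading is zero."""
    if baseline_gib <= 0:
        return 0.0
    return (baseline_gib - quant_gib) / baseline_gib * 100.0


print(f"{memory_reduction_pct(24.26, 17.14):.0f}%")  # 29%
print(f"{memory_reduction_pct(0.0, 0.0):.0f}%")      # 0%
```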

Signed-off-by: ultranationalism <www913363043@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: ultism <www913363043@gmail.com>
@ultism ultism force-pushed the svdquant-converter branch from f781bb1 to 762caa4 Compare May 23, 2026 09:15

@hsliuustc0106 hsliuustc0106 left a comment

BLOCKING:

  • Correctness — Pre-commit check is failing. Please run pre-commit run --all-files locally and fix any issues before proceeding.

Pre-commit (ruff-format) reformatted four files modified in this PR:

  * vllm_omni/quantization/tools/convert_nunchaku_to_svdquant.py:
    `json.dumps({...})` expanded to multi-line dict form; argparse `help=`
    continuation strings re-indented; blank line before `def main`.
  * vllm_omni/quantization/tools/svdquant_nvfp4_layout.py: blank line
    after module docstring.
  * benchmarks/diffusion/quantization_quality.py: minor wrap.
  * examples/offline_inference/text_to_image/text_to_image.py: minor wrap.

Also drops an unused `import sys` flagged by ruff-check in the converter.

No behavior change.

Signed-off-by: ultism <www913363043@gmail.com>

You can refer to vllm_omni/quantization/int8_config.py about how to implement a vllm-independent quantization method.

imported from this module; keep the import surface stable.
"""

from vllm.model_executor.layers.quantization.utils.svdquant_nvfp4_layout import ( # noqa: F401

Let's just put this at vllm-omni instead of upstream vllm.

…d into vllm-omni

Migrate the SVDQuant W4A4 (Nunchaku family) quantization plumbing from
the vllm in-tree proposal (vllm-project/vllm#43471, now closed) into
vllm-omni. Per reviewer feedback on that PR (Isotr0py): "the linear
method here looks pretty much like a special vLLM-omni quant method
because it uses some custom 3rd party operator" — the caller side is
willing to take on wrapper and dispatch work, so SVDQuant lands here
next to the other Diffusion*Config siblings (Int8, MXFP4/8, GGUF, INC).

Structure (mirrors existing per-config files, but split for FlashInfer
forward-compat):

  svdquant_config.py        DiffusionSVDQuantConfig + LinearMethod;
                            backend-agnostic. `_backend` is selected at
                            __init__ via select_backend(); apply() and
                            process_weights_after_loading delegate to it.

  svdquant_dispatch.py      select_backend(precision) -> module,
                            assert_svdquant_supported() hardware gate.
                            Only this file knows the SM-to-backend
                            mapping. To add FlashInfer for SM_100/103
                            later: drop a new svdquant_flashinfer.py
                            exposing (supports, prepare_weights, apply),
                            and prepend it in _candidate_backends().

  svdquant_nunchaku.py      Nunchaku backend: has_nunchaku() capability
                            detection, lazy importlib wrappers around
                            svdq_gemm_w4a4 / svdq_quantize_w4a4_act_fuse_lora
                            (PyPI 'nunchaku' is a different project; the
                            install hint points to the GitHub releases),
                            plus prepare_weights() that repacks
                            canonical row-major NVFP4 into the
                            PTX-MMA fragment layout the kernel expects.

  tools/svdquant_nvfp4_layout.py
                            Bit-preserving fragment ↔ row-major helpers
                            for qweight / wscales. Previously a shim
                            re-exporting from vllm; now the real impl.
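
The "capability detection + lazy importlib wrappers" idea can be illustrated generically as below. This is not the content of svdquant_nunchaku.py: the helper names are hypothetical, and since the message notes that PyPI's `nunchaku` is a different project, a real `has_nunchaku()` may need a stricter check than mere module presence. The demo uses the standard-library `math` module so it runs anywhere.

```python
import importlib
import importlib.util
from functools import lru_cache


def has_module(name: str) -> bool:
    """Cheap presence check for a top-level module, without importing it."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


@lru_cache(maxsize=None)
def lazy_attr(module_name: str, attr: str):
    """Import the module on first use and return one attribute from it."""
    try:
        module = importlib.import_module(module_name)
    except ImportError as exc:
        raise ImportError(
            f"{module_name!r} is required for this backend; see its install instructions"
        ) from exc
    return getattr(module, attr)


# Demo with a stdlib module standing in for the optional dependency.
print(has_module("math"))              # True
print(lazy_attr("math", "sqrt")(9.0))  # 3.0
```

Deferring the import this way keeps the config and dispatcher importable on machines where the optional kernel package is not installed.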

Factory registers a new "svdquant" entry in _OVERRIDES. The converter
is updated to (1) import the layout helpers from the new vllm-omni
local path, and (2) drop nunchaku's `smooth_factor_orig` suffix at
group time — upstream nunchaku itself marks it "Unused"
(nunchaku/models/linear.py:54), it's never consumed in either int4
or nvfp4 path, and keeping it triggers a load-time KeyError because
the LinearMethod does not register a destination parameter.

Verified end-to-end on RTX 5060 Ti (SM_120, 16 GiB) with
--enable-cpu-offload: 512x512 / 8 steps generates an image in 5.9 s,
peak VRAM 8.5 GiB. select_backend() correctly picks
vllm_omni.quantization.svdquant_nunchaku. 12/12 tests in
tests/diffusion/quantization/test_svdquant_config.py pass (registry,
factory routing, hardware gate for Hopper / datacenter Blackwell /
pre-Blackwell NVFP4, create_weights parameter layout, skip-list).

Signed-off-by: ultism <www913363043@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Signed-off-by: ultism <www913363043@gmail.com>
@ultism ultism requested a review from yenuo26 as a code owner May 24, 2026 06:39