Skip to content

[Feature] Integrate Nunchaku SVDQuant W4A4 for diffusion models#1986

Open
ultism wants to merge 1 commit into
vllm-project:mainfrom
ultism:svdquant-pr
Open

[Feature] Integrate Nunchaku SVDQuant W4A4 for diffusion models#1986
ultism wants to merge 1 commit into
vllm-project:mainfrom
ultism:svdquant-pr

Conversation

@ultism
Copy link
Copy Markdown

@ultism ultism commented Mar 18, 2026

Summary

  • Integrate Nunchaku as a quantization backend for diffusion transformers, enabling W4A4 inference with SVD low-rank correction.
  • Verified on Z-Image-Turbo: ~2.2x speedup over BF16 on RTX 5090 with comparable image quality.

Motivation

SVDQuant (W4A4) provides significant inference speedup and reduced memory footprint for DiT models. Nunchaku's PTX-optimized kernels are community-proven (FLUX, Qwen-Image) and lightweight enough to integrate as an optional backend.

The main blocker during integration was weight key mapping: Nunchaku checkpoints use diffusers-style naming while vLLM models use different conventions, and the naming is not standardized across models (Z-Image: w13/w2, Flux: linear_in/linear_out, QwenImage: no remap needed). This mapping must currently be hardcoded per-model in load_weights, which is the primary effort when adding new model support.

Additionally, Nunchaku's weight format is highly optimized (tiled/interleaved MMA layout via PTX assembly), so the glue code (weight packing, activation swap, shape calculations) is tightly coupled to Nunchaku's internal layout. This means weight-level manipulation (e.g. row-swapping for SwiGLU convention) is not possible — we handle this via runtime output swap instead.

Changes

  • NunchakuConfig / NunchakuLinearMethod (svdq_nunchaku.py): vLLM quantization plugin with W4A4 GEMM + SVD low-rank correction. Quantizes QKV, MergedColumnParallel, and RowParallel layers; leaves ReplicatedLinear (adaLN, embedders) unquantized.
  • Gated-activation output swap: Nunchaku checkpoints (from diffusers) store merged gate+up weights in diffusers order [linear ; activation], while vLLM's SiluAndMul expects [activation ; linear]. Applied automatically at runtime in NunchakuLinearMethod.apply() for all MergedColumnParallelLinear layers.
  • DiffusionNunchakuConfig (nunchaku.py): Per-model weight key mapping table for translating diffusers-style naming to vLLM conventions.
  • Z-Image support: key remapping (net.0.projw13, net.2w2) in load_weights, fixed stacked_params_mapping substring collision (.w1 falsely matching .w13).
  • Example script text_to_image_quant.py.

Quantized Model

Quality Comparison (RTX 5090, seed=42, Z-Image-Turbo 1024x1024)

BF16 (13.4s) Nunchaku W4A4 nvfp4 (6.0s, 2.2x faster)
bf16 quant

Follow-up Plans

  • Auto-infer rank/precision from safetensors file metadata (currently must be specified manually via --rank / --precision). Nunchaku checkpoints embed quantization_config (including rank, group_size, method) and model_class in safetensors metadata — the same mechanism Nunchaku's own from_pretrained uses. This would eliminate the need for users to specify these parameters manually.
  • Auto key mapping: derive weight name mapping from Nunchaku model metadata on meta device, eliminating per-model hardcoding
  • CI/CD tests: unit tests for weight loading, key remapping, and E2E inference

Test Plan

  • E2E quantized Z-Image-Turbo inference (RTX 5090, RTX 5060 Ti)
  • BF16 vs quantized visual quality comparison (same seed, same GPU)
  • CPU offload compatibility verified

Closes #507

@ultism ultism requested a review from hsliuustc0106 as a code owner March 18, 2026 17:52
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a2cefeb602

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +419 to +423
set_weight_attrs(qweight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": default_weight_loader,
})
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Shard Nunchaku tensors before calling default_weight_loader

When tensor_parallel_size > 1, create_weights() allocates qweight (and the analogous wscales/proj_* tensors below) at partition-local shapes, but ZImageTransformer2DModel.load_weights() later calls each parameter's weight_loader directly and this loader is still default_weight_loader or a thin wrapper around it. default_weight_loader only copies exact-size tensors, so a full Nunchaku checkpoint cannot be loaded into rank-local storage; the new TP path will fail as soon as users enable tensor parallelism.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Acknowledged. This is a known limitation — TP support for Nunchaku quantized layers is not implemented yet. Current diffusion models don't use TP with quantization, so this is deferred to a follow-up PR when TP + quantization is needed.

Comment on lines +447 to +451
torch.empty(
input_size_per_partition,
rank,
dtype=lora_dtype,
),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Pad or validate low-rank dimensions before creating proj_down

pack_lowrank_weight() pads low-rank weights to a multiple of 16, but proj_down/proj_up are still allocated with the unpadded rank. For any checkpoint whose rank is not already divisible by 16, the packed tensor becomes larger than the destination parameter and default_weight_loader will raise during load. Since the config and example CLI accept arbitrary --rank values, non-16 ranks are currently impossible even though the packing path is explicitly trying to support them.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not an issue in practice. Nunchaku only produces checkpoints with rank as a multiple of 16 (32, 64, 128). This is a constraint from the upstream Nunchaku library itself — their CUDA kernels require 16-aligned rank. The pad logic is a defensive measure but will never be triggered with real checkpoints.

Comment on lines +332 to +333
has_quant = self.od_config and getattr(self.od_config, "quantization_config", None) is not None
if loaded_weights is not None and not has_quant:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep strict load checks for required quantized tensors

ZImageTransformer2DModel.load_weights() now skips any checkpoint entry it cannot map, and this new not has_quant guard disables the only global completeness check for every quantized model. That means a mismatched Nunchaku checkpoint, unsupported model, or missing required tensor like qweight will now "load" without an error and leave the layer using whatever placeholder/uninitialized parameter was created. It would be safer to exempt only the truly optional extras instead of turning off strictness wholesale.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional and required. Nunchaku checkpoints inherently have a non-uniform parameter structure: wtscale exists only for 102 out of 136 quantized layers (absent from QKV layers), and wcscales exists only for 34 layers (QKV only). Our vLLM integration creates these parameters uniformly for all quantized layers and fills missing ones with defaults (1.0) during process_weights_after_loading.

Nunchaku's own from_pretrained (in nunchaku/models/transformers/transformer_zimage.py) uses load_state_dict(strict=True) successfully because its model class defines parameters to exactly match the checkpoint layout. Replicating that exact parameter structure in vLLM would require coupling the model definition to Nunchaku's per-layer-type conventions, which would be far more complex and fragile than the current approach.

The strict check is only disabled for quantized models; non-quantized models retain full strict checking.

…ion models

Integrate Nunchaku library as a quantization backend for diffusion
transformers, enabling W4A4 inference with SVD low-rank correction.

Key changes:

- NunchakuConfig / NunchakuLinearMethod: vLLM quantization plugin that
  quantizes QKV, MergedColumnParallel, and RowParallel layers using
  Nunchaku's PTX-optimized W4A4/W4A16 CUDA kernels with SVD low-rank
  correction. Leaves ReplicatedLinear (adaLN, embedders) unquantized.

- Gated-activation output swap in NunchakuLinearMethod.apply(): Nunchaku
  checkpoints (quantized from diffusers) store merged gate+up weights in
  diffusers order [linear ; activation], while vLLM's SiluAndMul expects
  [activation ; linear]. The qweight uses a tiled MMA layout that prevents
  row-swap, so we swap the output halves at runtime instead. Applied
  automatically for all MergedColumnParallelLinear layers.

- DiffusionNunchakuConfig with per-model weight key mapping table for
  translating diffusers-style naming to vLLM model conventions.

- Z-Image support: key remapping (net.0.proj -> w13, net.2 -> w2) in
  load_weights, fixed stacked_params_mapping substring collision.

- Example script text_to_image_quant.py for quantized inference.

Verified on RTX 5090: ~2.2x speedup over BF16 with comparable image quality.

Closes vllm-project#507

Signed-off-by: ultranationalism <www913363043@gmail.com>
@lishunyang12
Copy link
Copy Markdown
Collaborator

I think the image quality presevation is pretty good. Can you have more image comparision between W4A4 and BF16 and use LPIPs to quantify the difference? You can refer to #1470

@lishunyang12
Copy link
Copy Markdown
Collaborator

lishunyang12 commented Mar 19, 2026

This method is dedicated for diffusion models so i think it may not be realistic to implement on vllm side and omni reuses. But, my concern is that there are some invasive changes on model pipeline files and made some customized linear work. @ZJY0516 @Isotr0py PTAL.

@ultism
Copy link
Copy Markdown
Author

ultism commented Mar 19, 2026

I think the image quality presevation is pretty good. Can you have more image comparision between W4A4 and BF16 and use LPIPs to quantify the difference? You can refer to #1470

Thanks for the thoughtful review. One thing I want to clarify before we commit to the upstream path: I'm not sure it would reduce the maintenance burden as much as expected.

The glue code here — key mapping, weight layout handling — isn't a consequence of where the code lives architecturally. It's an inherent property of Nunchaku: the library uses a PTX-tiled MMA weight format and diffusion-specific checkpoint naming that isn't standardized and doesn't map cleanly to generic abstractions. These would need to exist in upstream vLLM just as much as here.

I'm genuinely happy to pursue the upstream path if that's the direction. But I think the honest framing is: supporting Nunchaku comes with a fixed complexity floor regardless of where it's integrated. The question is whether the 2x+ speedup on diffusion workloads is worth carrying that cost. If the team decides it's not, I understand — but I'd rather that be an explicit decision than one deferred indefinitely on the assumption that the problem will get easier.

@ultism
Copy link
Copy Markdown
Author

ultism commented Mar 19, 2026

This method is didicated for diffusion models so i think it may not be realistic to implement on vllm side and omni reuse. But, my concern is that there are some invasive changes on model pipeline files and made some customized linear work. @ZJY0516 @Isotr0py PTAL.

After working through this PR, I want to add some more context on why "upstream first" doesn't actually simplify things — and why Nunchaku support may not be viable at all in any framework.

The real problem is that there are only two paths to SVDQuant integration, and both lead to the same wall:

Option A: Implement SVDQuant independently, bypass Nunchaku's weight format.
This breaks compatibility with all existing quantized checkpoints — every SVDQuant checkpoint currently available was produced by Nunchaku's toolchain and uses its tiled MMA layout. There's no community-standard alternative format. Without checkpoint compatibility, there are no users.

Option B: Add weight remapping and layout translation to bridge Nunchaku checkpoints.
The glue code remains exactly as complex. And critically: after remapping, the weights still can't be fed into Nunchaku's kernel, because those operators are built around the tile-based layout. You'd need to reimplement the W4A4 kernels from scratch anyway — at which point you've paid the full cost of Option A on top of the remapping work.

Both paths converge on the same conclusion: Nunchaku's kernel is unusable outside its own layout, so anyone wanting SVDQuant support in a general framework has to write their own W4A4 CUDA kernels for diffusion transformers. That's not a quantization backend integration anymore — it's building a new quantization stack.

@lishunyang12
Copy link
Copy Markdown
Collaborator

lishunyang12 commented Mar 19, 2026

This method is didicated for diffusion models so i think it may not be realistic to implement on vllm side and omni reuse. But, my concern is that there are some invasive changes on model pipeline files and made some customized linear work. @ZJY0516 @Isotr0py PTAL.

After working through this PR, I want to add some more context on why "upstream first" doesn't actually simplify things — and why Nunchaku support may not be viable at all in any framework.

The real problem is that there are only two paths to SVDQuant integration, and both lead to the same wall:

Option A: Implement SVDQuant independently, bypass Nunchaku's weight format. This breaks compatibility with all existing quantized checkpoints — every SVDQuant checkpoint currently available was produced by Nunchaku's toolchain and uses its tiled MMA layout. There's no community-standard alternative format. Without checkpoint compatibility, there are no users.

Option B: Add weight remapping and layout translation to bridge Nunchaku checkpoints. The glue code remains exactly as complex. And critically: after remapping, the weights still can't be fed into Nunchaku's kernel, because those operators are built around the tile-based layout. You'd need to reimplement the W4A4 kernels from scratch anyway — at which point you've paid the full cost of Option A on top of the remapping work.

Both paths converge on the same conclusion: Nunchaku's kernel is unusable outside its own layout, so anyone wanting SVDQuant support in a general framework has to write their own W4A4 CUDA kernels for diffusion transformers. That's not a quantization backend integration anymore — it's building a new quantization stack.

Thanks for your insight :) I think that is quite a lot of work if we need to write our own W4A4 CUDA kernels. Can you please help check how many parts we can resuse from vLLM upstream and what parts are unavoidable to be customized on our side?

@ultism
Copy link
Copy Markdown
Author

ultism commented Mar 19, 2026

Thanks for the question. Here's a breakdown of what can be reused from vLLM upstream and what must stay on the omni side.

Architecture-wise, Nunchaku's integration pattern is analogous to how vLLM upstream already integrates external quantization kernel libraries like Marlin (for GPTQ/AWQ) or DeepGEMM (for FP8). The overall structure — QuantizationConfig + LinearMethodBase with create_weights() / apply() / process_weights_after_loading() — is fully reused. The custom parts are algorithm-specific: SVDQuant's low-rank correction requires 6–9 parameters per layer (vs. Marlin's 4), and uses Nunchaku's own CUDA kernels (svdq_gemm_w4a4_cuda, etc.). This is the same kind of external kernel dependency that Marlin and DeepGEMM already represent in upstream vLLM. Nunchaku is scoped entirely under diffusion/, touching only 2–3 core files.

Here's how the components break down:

Component Where it should live Reasoning
Key mapping (diffusers → vLLM naming) vLLM-Omni (here) Diffusion-specific logic. Diffusers checkpoint naming conventions don't exist in the LLM world, so this naturally belongs on the omni side. Currently hardcoded per-model; a future improvement could auto-derive the mapping from Nunchaku's checkpoint metadata (by matching param shapes + suffixes between the Nunchaku model and vLLM model on meta device). This necessarily import nunchaku — but if you don't have nunchaku installed, you can't run SVDQuant anyway, so the dependency is already mandatory.
Quantization kernel calls (GEMM, activation quantization) Could live in vLLM upstream These follow the standard LinearMethodBase.apply() pattern, same as Marlin/DeepGEMM. Could be upstreamed as a quantization backend if desired. This achieves the decoupling and lightweight-invocation goal.

The one component that cannot be cleanly decoupled is the SwiGLU activation order reversal.

The root cause: Nunchaku checkpoints are quantized from diffusers models, where gated activations use hidden, gate = proj(x).chunk(2) (linear-first). vLLM uses silu(x[:d]) * x[d:] (activation-first). The weight rows need to be swapped — but the quantized weights are stored in Nunchaku's tiled MMA layout, so you can't simply reorder the rows.

There are three options, and none is perfect:

  1. Current approach: implicit assumption in apply() — Swap the output halves at runtime when the layer is MergedColumnParallelLinear. This assumes all such layers are gated FFNs. Minimal pipeline invasion, but creates a hidden functional coupling. This is what the current PR does.

  2. Explicit pipeline modification — Modify the model code to handle the activation order difference directly. This is what Nunchaku's own repository does (it explicitly patches the model's forward pass). More correct, but invasive to pipeline files — exactly the concern raised.

  3. Weight re-layout at load time — Rearrange both the tiled NvFP4 quantized matrix (qweight) and the tiled BF16 scale/projection matrices simultaneously during weight loading to match vLLM's convention. This would eliminate the runtime assumption, but the code to reverse-engineer and re-tile Nunchaku's MMA-packed layout is fragile and tightly coupled to Nunchaku's internal packing format — any upstream format change would silently break it.

My recommendation is to keep option 1 (the current approach) for now. The implicit assumption holds for all current diffusion models (Z-Image, Flux, HunyuanImage, etc.) and is documented in the code. If a future model breaks this assumption, we can handle it per-model at that point.

Copy link
Copy Markdown
Collaborator

@lishunyang12 lishunyang12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few comments:

  1. Dead code: transform_weight (quantization/base.py, quantization/nunchaku.py) — defined but nothing in the loading pipeline calls it. Actual key remapping is hardcoded in z_image_transformer.py:load_weights. Either wire it into the loader or remove it.

  2. Strict weight validation disabled for ALL quantized models (diffusers_loader.py) — a genuinely missing weight in a quantized checkpoint would be silently ignored. Consider having the quant config declare expected extra/missing params instead of blanket-disabling.

  3. Nunchaku-specific key remapping in model code (z_image_transformer.py) — couples the model to a specific quantization backend. Should live in the quant config or at least be conditioned on whether Nunchaku is active.

@ultism
Copy link
Copy Markdown
Author

ultism commented Mar 23, 2026

Submitted an RFC to vLLM upstream to host the core quantization backend (QuantizationConfig + LinearMethodBase): vllm-project/vllm#37908

If accepted, this PR would be simplified to only the diffusion-specific glue code (key mapping, SwiGLU activation order handling, model integration).

# create extra parameters (e.g. wtscale, wcscales) that don't have
# corresponding entries in the checkpoint.
has_quant = self.od_config and getattr(self.od_config, "quantization_config", None) is not None
if loaded_weights is not None and not has_quant:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO it would be better to keep the checks strict where possible and not loosen it for every quant method

# Each entry: "source_key_fragment": ("target_key_fragment", swap_swiglu)
# - swap_swiglu=True: swap the two halves of the merged gate+up weight
# to account for SwiGLU activation order difference.
_MODEL_KEY_MAPPING: dict[str, dict[str, tuple[str, bool]]] = {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having this here and stored on the DiTs is confusing, it would be best to avoid duplicating this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[RFC]: Integrate Nunchaku to Support SVDQuant (W4A4) for Diffusion Models

3 participants