[Feature] Integrate Nunchaku SVDQuant W4A4 for diffusion models #1986
ultism wants to merge 1 commit into
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: a2cefeb602
    set_weight_attrs(qweight, {
        "input_dim": 1,
        "output_dim": 0,
        "weight_loader": default_weight_loader,
    })
Shard Nunchaku tensors before calling default_weight_loader
When tensor_parallel_size > 1, create_weights() allocates qweight (and the analogous wscales/proj_* tensors below) at partition-local shapes, but ZImageTransformer2DModel.load_weights() later calls each parameter's weight_loader directly and this loader is still default_weight_loader or a thin wrapper around it. default_weight_loader only copies exact-size tensors, so a full Nunchaku checkpoint cannot be loaded into rank-local storage; the new TP path will fail as soon as users enable tensor parallelism.
Acknowledged. This is a known limitation — TP support for Nunchaku quantized layers is not implemented yet. Current diffusion models don't use TP with quantization, so this is deferred to a follow-up PR when TP + quantization is needed.
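For readers following this thread, here is a minimal sketch of the shape bookkeeping the review refers to, written with plain Python lists instead of tensors. `shard_rows` and `exact_copy` are hypothetical names, not vLLM APIs; `exact_copy` only mimics the exact-shape behaviour attributed to default_weight_loader above.

```python
def shard_rows(full, tp_size, tp_rank):
    """Return the block of rows owned by `tp_rank` (hypothetical helper)."""
    rows = len(full)
    if rows % tp_size != 0:
        raise ValueError(f"{rows} rows not divisible by tp_size={tp_size}")
    per_rank = rows // tp_size
    start = tp_rank * per_rank
    return full[start:start + per_rank]


def exact_copy(param, loaded):
    """Copy only when shapes match exactly; otherwise raise."""
    same_rows = len(param) == len(loaded)
    if not same_rows or any(len(p) != len(l) for p, l in zip(param, loaded)):
        raise ValueError("shape mismatch between parameter and checkpoint tensor")
    for dst, src in zip(param, loaded):
        dst[:] = src


# An 8x4 "checkpoint" tensor and rank-local storage for tp_size=2.
full_qweight = [[r * 10 + c for c in range(4)] for r in range(8)]
local_param = [[0] * 4 for _ in range(4)]

# Narrow first, then copy: the full tensor would not fit the local parameter.
exact_copy(local_param, shard_rows(full_qweight, tp_size=2, tp_rank=1))
```

Whether a plain row split is valid for Nunchaku's tiled qweight layout is not established in this thread; the sketch covers only the size mismatch, not the layout question.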
    torch.empty(
        input_size_per_partition,
        rank,
        dtype=lora_dtype,
    ),
Pad or validate low-rank dimensions before creating proj_down
pack_lowrank_weight() pads low-rank weights to a multiple of 16, but proj_down/proj_up are still allocated with the unpadded rank. For any checkpoint whose rank is not already divisible by 16, the packed tensor becomes larger than the destination parameter and default_weight_loader will raise during load. Since the config and example CLI accept arbitrary --rank values, non-16 ranks are currently impossible even though the packing path is explicitly trying to support them.
Not an issue in practice. Nunchaku only produces checkpoints with rank as a multiple of 16 (32, 64, 128). This is a constraint from the upstream Nunchaku library itself — their CUDA kernels require 16-aligned rank. The pad logic is a defensive measure but will never be triggered with real checkpoints.
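If the constraint is meant to hold for every checkpoint, an explicit check at config time would turn a confusing load failure into a clear error. The following is a minimal sketch under that assumption; `padded_rank` and `validate_rank` are hypothetical helpers, not functions from this PR or from Nunchaku, and the alignment of 16 is the value stated in this thread.

```python
RANK_ALIGNMENT = 16  # alignment stated in this thread for Nunchaku's kernels


def padded_rank(rank: int, alignment: int = RANK_ALIGNMENT) -> int:
    """Round `rank` up to the next multiple of `alignment`."""
    if rank <= 0:
        raise ValueError(f"rank must be positive, got {rank}")
    return -(-rank // alignment) * alignment  # ceiling division, then scale back


def validate_rank(rank: int, alignment: int = RANK_ALIGNMENT) -> int:
    """Reject ranks whose packed size would differ from the allocated size."""
    padded = padded_rank(rank, alignment)
    if padded != rank:
        raise ValueError(
            f"rank={rank} is not a multiple of {alignment}; "
            f"packing would pad it to {padded}"
        )
    return rank
```

Calling `validate_rank(40)` raises because packing would produce 48 columns, while `validate_rank(32)` returns 32 unchanged. The alternative is to allocate proj_down/proj_up with `padded_rank(rank)` so that the defensive padding path actually has a destination of the right size.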
    has_quant = self.od_config and getattr(self.od_config, "quantization_config", None) is not None
    if loaded_weights is not None and not has_quant:
Keep strict load checks for required quantized tensors
ZImageTransformer2DModel.load_weights() now skips any checkpoint entry it cannot map, and this new not has_quant guard disables the only global completeness check for every quantized model. That means a mismatched Nunchaku checkpoint, unsupported model, or missing required tensor like qweight will now "load" without an error and leave the layer using whatever placeholder/uninitialized parameter was created. It would be safer to exempt only the truly optional extras instead of turning off strictness wholesale.
This is intentional and required. Nunchaku checkpoints inherently have a non-uniform parameter structure: wtscale exists only for 102 out of 136 quantized layers (absent from QKV layers), and wcscales exists only for 34 layers (QKV only). Our vLLM integration creates these parameters uniformly for all quantized layers and fills missing ones with defaults (1.0) during process_weights_after_loading.
Nunchaku's own from_pretrained (in nunchaku/models/transformers/transformer_zimage.py) uses load_state_dict(strict=True) successfully because its model class defines parameters to exactly match the checkpoint layout. Replicating that exact parameter structure in vLLM would require coupling the model definition to Nunchaku's per-layer-type conventions, which would be far more complex and fragile than the current approach.
The strict check is only disabled for quantized models; non-quantized models retain full strict checking.
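The default-filling step described above can be sketched roughly as follows, using flat Python lists as stand-ins for tensors. `fill_missing_scales` and the layer names are hypothetical and only illustrate the idea; this is not the code in process_weights_after_loading.

```python
OPTIONAL_SCALES = ("wtscale", "wcscales")  # parameter names taken from the comment above


def fill_missing_scales(params, loaded_names, default=1.0):
    """Set optional scale params to `default` when the checkpoint did not supply them.

    `params` maps full parameter names to flat lists of floats.
    Returns the names that were filled so the caller can log them.
    """
    filled = []
    for name, values in params.items():
        is_optional = name.rsplit(".", 1)[-1] in OPTIONAL_SCALES
        if is_optional and name not in loaded_names:
            values[:] = [default] * len(values)
            filled.append(name)
    return filled


# Hypothetical QKV layer: the checkpoint carries wcscales but no wtscale.
params = {
    "layers.0.qkv.wtscale": [0.0],
    "layers.0.qkv.wcscales": [0.0, 0.0],
    "layers.0.qkv.qweight": [7.0, 7.0],
}
loaded = {"layers.0.qkv.wcscales", "layers.0.qkv.qweight"}
filled = fill_missing_scales(params, loaded)
```

Only the optional scale that was absent is overwritten; a missing qweight would be left untouched here, which is why a separate completeness check for required tensors still matters.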
…ion models

Integrate Nunchaku library as a quantization backend for diffusion transformers, enabling W4A4 inference with SVD low-rank correction.

Key changes:
- NunchakuConfig / NunchakuLinearMethod: vLLM quantization plugin that quantizes QKV, MergedColumnParallel, and RowParallel layers using Nunchaku's PTX-optimized W4A4/W4A16 CUDA kernels with SVD low-rank correction. Leaves ReplicatedLinear (adaLN, embedders) unquantized.
- Gated-activation output swap in NunchakuLinearMethod.apply(): Nunchaku checkpoints (quantized from diffusers) store merged gate+up weights in diffusers order [linear ; activation], while vLLM's SiluAndMul expects [activation ; linear]. The qweight uses a tiled MMA layout that prevents row-swap, so we swap the output halves at runtime instead. Applied automatically for all MergedColumnParallelLinear layers.
- DiffusionNunchakuConfig with per-model weight key mapping table for translating diffusers-style naming to vLLM model conventions.
- Z-Image support: key remapping (net.0.proj -> w13, net.2 -> w2) in load_weights, fixed stacked_params_mapping substring collision.
- Example script text_to_image_quant.py for quantized inference.

Verified on RTX 5090: ~2.2x speedup over BF16 with comparable image quality.

Closes vllm-project#507

Signed-off-by: ultranationalism <www913363043@gmail.com>
I think the image quality preservation is pretty good. Can you add more image comparisons between W4A4 and BF16 and use LPIPS to quantify the difference? You can refer to #1470.
Thanks for the thoughtful review. One thing I want to clarify before we commit to the upstream path: I'm not sure it would reduce the maintenance burden as much as expected. The glue code here (key mapping, weight layout handling) isn't a consequence of where the code lives architecturally. It's an inherent property of Nunchaku: the library uses a PTX-tiled MMA weight format and diffusion-specific checkpoint naming that isn't standardized and doesn't map cleanly to generic abstractions. These would need to exist in upstream vLLM just as much as here.

I'm genuinely happy to pursue the upstream path if that's the direction. But I think the honest framing is: supporting Nunchaku comes with a fixed complexity floor regardless of where it's integrated. The question is whether the 2x+ speedup on diffusion workloads is worth carrying that cost. If the team decides it's not, I understand, but I'd rather that be an explicit decision than one deferred indefinitely on the assumption that the problem will get easier.
After working through this PR, I want to add some more context on why "upstream first" doesn't actually simplify things, and why Nunchaku support may not be viable at all in any framework. The real problem is that there are only two paths to SVDQuant integration, and both lead to the same wall:

- Option A: Implement SVDQuant independently, bypass Nunchaku's weight format.
- Option B: Add weight remapping and layout translation to bridge Nunchaku checkpoints.

Both paths converge on the same conclusion: Nunchaku's kernel is unusable outside its own layout, so anyone wanting SVDQuant support in a general framework has to write their own W4A4 CUDA kernels for diffusion transformers. That's not a quantization backend integration anymore; it's building a new quantization stack.
Thanks for your insight :) I think that is quite a lot of work if we need to write our own W4A4 CUDA kernels. Can you please help check how many parts we can reuse from vLLM upstream and which parts unavoidably need to be customized on our side?
Thanks for the question. Here's a breakdown of what can be reused from vLLM upstream and what must stay on the omni side. Architecture-wise, Nunchaku's integration pattern is analogous to how vLLM upstream already integrates external quantization kernel libraries like Marlin (for GPTQ/AWQ) or DeepGEMM (for FP8). The overall structure — Here's how the components break down:
The one component that cannot be cleanly decoupled is the SwiGLU activation order reversal. The root cause: Nunchaku checkpoints are quantized from diffusers models, where gated activations use There are three options, and none is perfect:
My recommendation is to keep option 1 (the current approach) for now. The implicit assumption holds for all current diffusion models (Z-Image, Flux, HunyuanImage, etc.) and is documented in the code. If a future model breaks this assumption, we can handle it per-model at that point.
lishunyang12 left a comment
Left a few comments:
- Dead code: transform_weight (quantization/base.py, quantization/nunchaku.py) is defined but nothing in the loading pipeline calls it. Actual key remapping is hardcoded in z_image_transformer.py:load_weights. Either wire it into the loader or remove it.
- Strict weight validation disabled for ALL quantized models (diffusers_loader.py): a genuinely missing weight in a quantized checkpoint would be silently ignored. Consider having the quant config declare expected extra/missing params instead of blanket-disabling.
- Nunchaku-specific key remapping in model code (z_image_transformer.py): this couples the model to a specific quantization backend. It should live in the quant config or at least be conditioned on whether Nunchaku is active.
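One possible shape for the second suggestion, as a minimal sketch: the quant config would expose a list of optional parameter suffixes, and the loader would stay strict for everything else. `check_loaded` and the `optional_suffixes` argument are hypothetical and are not part of this PR or of vLLM; the parameter names are illustrative.

```python
def check_loaded(expected, loaded, optional_suffixes=()):
    """Raise if any expected parameter is missing, except declared-optional ones.

    `expected` and `loaded` are sets of full parameter names.
    """
    missing = sorted(
        name for name in expected - loaded
        if not name.endswith(tuple(optional_suffixes))
    )
    if missing:
        raise ValueError(f"checkpoint is missing required parameters: {missing}")


expected = {"blk.w13.qweight", "blk.w13.wtscale", "blk.w13.wcscales"}
nunchaku_optional = (".wtscale", ".wcscales")  # what a quant config might declare

# Passes: only an optional scale is absent.
check_loaded(expected, {"blk.w13.qweight", "blk.w13.wcscales"}, nunchaku_optional)
```

With this split, a checkpoint that lacks qweight still fails loudly, while the non-uniform wtscale/wcscales layout is tolerated. Non-quantized models would pass an empty tuple and keep the existing behaviour.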
Submitted an RFC to vLLM upstream to host the core quantization backend (QuantizationConfig + LinearMethodBase): vllm-project/vllm#37908. If accepted, this PR would be simplified to only the diffusion-specific glue code (key mapping, SwiGLU activation order handling, model integration).
    # create extra parameters (e.g. wtscale, wcscales) that don't have
    # corresponding entries in the checkpoint.
    has_quant = self.od_config and getattr(self.od_config, "quantization_config", None) is not None
    if loaded_weights is not None and not has_quant:
IMO it would be better to keep the checks strict where possible and not loosen them for every quant method.
    # Each entry: "source_key_fragment": ("target_key_fragment", swap_swiglu)
    # - swap_swiglu=True: swap the two halves of the merged gate+up weight
    #   to account for SwiGLU activation order difference.
    _MODEL_KEY_MAPPING: dict[str, dict[str, tuple[str, bool]]] = {
Having this here and also stored on the DiTs is confusing; it would be best to avoid duplicating it.
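A rough sketch of what a single source of truth could look like: one module-level table and one lookup function that both the loader and the model call. The table below is simplified to plain strings (the PR's entries also carry a swap_swiglu flag), and `KEY_MAPPING`, `remap_key`, the `"z_image"` key, and the example parameter names are all hypothetical.

```python
KEY_MAPPING = {
    # Fragments quoted in the PR description for Z-Image.
    "z_image": {"net.0.proj": "w13", "net.2": "w2"},
}


def remap_key(model: str, key: str) -> str:
    """Rewrite whole dot-separated segments of `key` using the table for `model`."""
    padded = f".{key}."  # pad so that matches must align with segment boundaries
    for src, dst in KEY_MAPPING.get(model, {}).items():
        needle = f".{src}."
        if needle in padded:
            return padded.replace(needle, f".{dst}.", 1).strip(".")
    return key


remapped = remap_key("z_image", "layers.0.feed_forward.net.0.proj.qweight")
```

Matching on whole segments rather than raw substrings also sidesteps the kind of collision mentioned in the PR description, where a shorter fragment matches inside a longer name.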
Summary
Motivation
SVDQuant (W4A4) provides significant inference speedup and reduced memory footprint for DiT models. Nunchaku's PTX-optimized kernels are community-proven (FLUX, Qwen-Image) and lightweight enough to integrate as an optional backend.
The main blocker during integration was weight key mapping: Nunchaku checkpoints use diffusers-style naming while vLLM models use different conventions, and the naming is not standardized across models (Z-Image: w13/w2, Flux: linear_in/linear_out, QwenImage: no remap needed). This mapping must currently be hardcoded per-model in load_weights, which is the primary effort when adding new model support.

Additionally, Nunchaku's weight format is highly optimized (tiled/interleaved MMA layout via PTX assembly), so the glue code (weight packing, activation swap, shape calculations) is tightly coupled to Nunchaku's internal layout. This means weight-level manipulation (e.g. row-swapping for SwiGLU convention) is not possible, so we handle this via runtime output swap instead.
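To make the runtime output swap concrete, here is a small pure-Python sketch of the idea. It is not the code in NunchakuLinearMethod.apply(); `silu_and_mul` stands in for the gated activation that expects [activation ; linear], and `swap_halves` is a hypothetical helper operating on a flat list instead of the last tensor dimension.

```python
import math


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def silu_and_mul(x):
    """Gated activation on [activation ; linear]: silu(first half) * second half."""
    half = len(x) // 2
    return [silu(a) * b for a, b in zip(x[:half], x[half:])]


def swap_halves(x):
    """Turn [linear ; activation] into [activation ; linear]."""
    half = len(x) // 2
    return x[half:] + x[:half]


linear = [0.5, -1.0]
activation = [2.0, 3.0]
checkpoint_order = linear + activation  # what the merged GEMM emits

expected = [silu(a) * l for a, l in zip(activation, linear)]
result = silu_and_mul(swap_halves(checkpoint_order))
```

Swapping the GEMM output rather than the weight rows gives the same gated result without touching the packed qweight, at the cost of one extra reordering per forward pass.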
Changes
- … (svdq_nunchaku.py): vLLM quantization plugin with W4A4 GEMM + SVD low-rank correction. Quantizes QKV, MergedColumnParallel, and RowParallel layers; leaves ReplicatedLinear (adaLN, embedders) unquantized.
- … [linear ; activation], while vLLM's SiluAndMul expects [activation ; linear]. Applied automatically at runtime in NunchakuLinearMethod.apply() for all MergedColumnParallelLinear layers.
- … (nunchaku.py): Per-model weight key mapping table for translating diffusers-style naming to vLLM conventions.
- … (net.0.proj → w13, net.2 → w2) in load_weights, fixed stacked_params_mapping substring collision (.w1 falsely matching .w13).
- … text_to_image_quant.py.

Quantized Model
Quality Comparison (RTX 5090, seed=42, Z-Image-Turbo 1024x1024)
Follow-up Plans
- … (--rank/--precision). Nunchaku checkpoints embed quantization_config (including rank, group_size, method) and model_class in safetensors metadata, the same mechanism Nunchaku's own from_pretrained uses. This would eliminate the need for users to specify these parameters manually.

Test Plan
Closes #507
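As a note on the follow-up plan to read quantization parameters from checkpoint metadata: a safetensors file starts with an 8-byte little-endian header length followed by a JSON header, and free-form string metadata lives under the `__metadata__` key. The sketch below parses that header with the standard library only. It assumes quantization_config is stored as a JSON-encoded string; the helper names and the example values are placeholders, not taken from a real Nunchaku checkpoint.

```python
import json
import struct


def read_safetensors_metadata(data: bytes) -> dict:
    """Return the `__metadata__` dict from safetensors bytes (header only)."""
    (header_len,) = struct.unpack("<Q", data[:8])
    header = json.loads(data[8:8 + header_len].decode("utf-8"))
    return header.get("__metadata__", {})


def quant_params_from_metadata(metadata: dict) -> dict:
    """Decode the embedded quantization_config (assumed to be a JSON string)."""
    raw = metadata.get("quantization_config")
    return json.loads(raw) if raw else {}


# Build a tiny header-only file in memory to exercise the parser.
header = {
    "__metadata__": {
        "quantization_config": json.dumps({"rank": 32, "group_size": 64}),
        "model_class": "ExampleModel",  # placeholder value
    }
}
blob = json.dumps(header).encode("utf-8")
data = struct.pack("<Q", len(blob)) + blob

config = quant_params_from_metadata(read_safetensors_metadata(data))
```

In practice only the first `8 + header_len` bytes of the file need to be read, so auto-detection would not require loading any weights.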