diff --git a/benchmarks/diffusion/quantization_quality.py b/benchmarks/diffusion/quantization_quality.py index b47fae41f2a..f02c9444bba 100644 --- a/benchmarks/diffusion/quantization_quality.py +++ b/benchmarks/diffusion/quantization_quality.py @@ -53,6 +53,20 @@ --height 1024 --width 1024 \ --num-inference-steps 50 --seed 42 +Offline-quantized checkpoint (SVDQuant / MXFP8 offline / ModelOpt FP4 etc.): + Use `--baseline-model` for the BF16 reference and `--quantization auto` + so the quantized run honors the on-disk `transformer/config.json + ["quantization_config"]` instead of overriding it. + + python benchmarks/diffusion/quantization_quality.py \ + --baseline-model Tongyi-MAI/Z-Image-Turbo \ + --model ultranationalism/Z-Image-Turbo-SVDQuant-NVFP4 \ + --task t2i \ + --quantization auto \ + --prompts "a cup of coffee on a wooden table, morning light" \ + --height 1024 --width 1024 \ + --num-inference-steps 20 --seed 42 + Output directory structure (--output-dir, default: ./quant_bench_output): quant_bench_output/ baseline/ # BF16 outputs @@ -136,8 +150,55 @@ def compute_lpips_video( return float(np.mean(scores)) -def _build_omni_kwargs(args, quantization=None): - """Build kwargs dict for Omni() constructor.""" +def _gpu_memory_gib(device_index: int = 0) -> float: + """Total GPU memory currently used on `device_index`, in GiB. + + Uses `nvidia-smi` because vllm-omni spawns model workers in separate + processes; `torch.cuda.memory_allocated()` from the bench driver + process reports 0 since the workers hold the model in their own CUDA + contexts. `nvidia-smi`'s `memory.used` aggregates all processes on the + device, which on a single-GPU benchmark gives us the right number. + + Returns 0.0 if nvidia-smi is unavailable or fails — callers should + treat 0.0 as "unknown" (the markdown summary guards against divide- + by-zero in the reduction column). 
+ """ + import shutil + import subprocess + + if not shutil.which("nvidia-smi"): + return 0.0 + try: + out = subprocess.run( + [ + "nvidia-smi", + f"--id={device_index}", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits", + ], + capture_output=True, + text=True, + timeout=5, + check=True, + ) + return int(out.stdout.strip()) / 1024 # MiB -> GiB + except (subprocess.CalledProcessError, ValueError, subprocess.TimeoutExpired): + return 0.0 + + +def _build_omni_kwargs(args, quantization=None, model_override=None): + """Build kwargs dict for Omni() constructor. + + `model_override` lets the baseline and quantized runs target different + on-disk paths (offline-quantized checkpoints like SVDQuant ship their + own pipeline tree separate from the BF16 reference model). + + `quantization="auto"` is a sentinel meaning "do not pass + `quantization_config` to Omni; let the on-disk + `transformer/config.json["quantization_config"]` drive the choice". + Useful for offline-quantized checkpoints where the method + per-layer + skip list are baked into the config file. + """ from vllm_omni.diffusion.data import DiffusionParallelConfig parallel_config = DiffusionParallelConfig( @@ -146,23 +207,35 @@ def _build_omni_kwargs(args, quantization=None): tensor_parallel_size=args.tensor_parallel_size, ) kwargs = { - "model": args.model, + "model": model_override if model_override is not None else args.model, "parallel_config": parallel_config, "enforce_eager": args.enforce_eager, } - if quantization: + if quantization and quantization != "auto": kwargs["quantization_config"] = quantization return kwargs def _generate_image(omni, args, prompt, seed): - """Generate a single image and return (PIL.Image, time_seconds, memory_gib).""" + """Generate a single image; returns (PIL.Image, elapsed_s, peak_gib, weights_gib, activations_gib). 
+ + `weights_gib` is GPU memory snapshotted right before `generate()` + (the engine holds model params + persistent buffers; few transient + activations should be live). `peak_gib` is sampled right after + `generate()` returns — close to but not necessarily exactly the + high-water mark, since the diffusion loop may have freed temporary + buffers before returning. `activations_gib` = peak - weights. + + Uses `nvidia-smi` instead of `torch.cuda.max_memory_allocated()` so + we capture the worker process's memory (vllm-omni runs the model in + a separate spawn-mode worker). See `_gpu_memory_gib`. + """ from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) - torch.accelerator.reset_peak_memory_stats() + weights_mem = _gpu_memory_gib() start = time.perf_counter() outputs = omni.generate( {"prompt": prompt}, @@ -174,23 +247,27 @@ def _generate_image(omni, args, prompt, seed): ), ) elapsed = time.perf_counter() - start - peak_mem = torch.accelerator.max_memory_allocated() / (1024**3) + peak_mem = _gpu_memory_gib() + activations_mem = max(peak_mem - weights_mem, 0.0) req_out = OmniRequestOutput.unwrap_result(outputs) if not req_out.images: raise ValueError("Could not extract image output from result.") img = req_out.images[0] - return img, elapsed, peak_mem + return img, elapsed, peak_mem, weights_mem, activations_mem def _generate_video(omni, args, prompt, seed): - """Generate a video and return (np.ndarray [F,H,W,C], time_seconds, memory_gib).""" + """Generate a video; returns (np.ndarray [F,H,W,C], elapsed_s, peak_gib, weights_gib, activations_gib). + + See `_generate_image` for what each memory number means. 
+ """ from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.outputs import OmniRequestOutput from vllm_omni.platforms import current_omni_platform generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) - torch.accelerator.reset_peak_memory_stats() + weights_mem = _gpu_memory_gib() start = time.perf_counter() outputs = omni.generate( {"prompt": prompt, "negative_prompt": ""}, @@ -204,7 +281,8 @@ def _generate_video(omni, args, prompt, seed): ), ) elapsed = time.perf_counter() - start - peak_mem = torch.accelerator.max_memory_allocated() / (1024**3) + peak_mem = _gpu_memory_gib() + activations_mem = max(peak_mem - weights_mem, 0.0) first = outputs[0] if hasattr(first, "request_output") and isinstance(first.request_output, list): @@ -241,7 +319,7 @@ def _generate_video(omni, args, prompt, seed): if frames_array.ndim == 5: frames_array = frames_array[0] - return frames_array, elapsed, peak_mem + return frames_array, elapsed, peak_mem, weights_mem, activations_mem def _free_gpu_memory(): @@ -274,20 +352,28 @@ def run_benchmark(args): print("\n" + "=" * 60) print("Running BF16 baseline...") print("=" * 60) - bl_kwargs = _build_omni_kwargs(args, quantization=None) + # `--baseline-model` lets offline-quantized methods (SVDQuant etc.) + # point the baseline at a separate BF16 pipeline tree, while `--model` + # points at the quantized checkpoint. 
+ baseline_model = getattr(args, "baseline_model", None) or args.model + bl_kwargs = _build_omni_kwargs(args, quantization=None, model_override=baseline_model) omni_bl = Omni(**bl_kwargs) - baseline_outputs = {} # prompt -> (output, time, mem) + baseline_outputs = {} # prompt -> (output, time, peak_mem, weights_mem, activations_mem) for prompt in prompts: - print(f" Generating: {prompt[:60]}...") + print(f" Generating: {prompt[:60]}...", flush=True) if is_video: - out, t, mem = _generate_video(omni_bl, args, prompt, seed) + out, t, mem, weights, acts = _generate_video(omni_bl, args, prompt, seed) else: - out, t, mem = _generate_image(omni_bl, args, prompt, seed) - baseline_outputs[prompt] = (out, t, mem) + out, t, mem, weights, acts = _generate_image(omni_bl, args, prompt, seed) + baseline_outputs[prompt] = (out, t, mem, weights, acts) + print(f" -> {t:.2f}s weights={weights:.2f}GiB peak={mem:.2f}GiB", flush=True) bl_avg_time = np.mean([v[1] for v in baseline_outputs.values()]) - bl_mem = baseline_outputs[prompts[0]][2] # use first prompt's memory + # First prompt's memory snapshot is canonical (matches PR #1470 convention). 
+ bl_mem = baseline_outputs[prompts[0]][2] + bl_weights = baseline_outputs[prompts[0]][3] + bl_acts = baseline_outputs[prompts[0]][4] omni_bl.shutdown() del omni_bl _free_gpu_memory() @@ -321,15 +407,18 @@ def run_benchmark(args): qt_outputs = {} for prompt in prompts: - print(f" Generating: {prompt[:60]}...") + print(f" Generating: {prompt[:60]}...", flush=True) if is_video: - out, t, mem = _generate_video(omni_qt, args, prompt, seed) + out, t, mem, weights, acts = _generate_video(omni_qt, args, prompt, seed) else: - out, t, mem = _generate_image(omni_qt, args, prompt, seed) - qt_outputs[prompt] = (out, t, mem) + out, t, mem, weights, acts = _generate_image(omni_qt, args, prompt, seed) + qt_outputs[prompt] = (out, t, mem, weights, acts) + print(f" -> {t:.2f}s weights={weights:.2f}GiB peak={mem:.2f}GiB", flush=True) qt_avg_time = np.mean([v[1] for v in qt_outputs.values()]) qt_mem = qt_outputs[prompts[0]][2] + qt_weights = qt_outputs[prompts[0]][3] + qt_acts = qt_outputs[prompts[0]][4] omni_qt.shutdown() del omni_qt _free_gpu_memory() @@ -359,7 +448,7 @@ def run_benchmark(args): mean_lpips = np.mean([p["lpips"] for p in per_prompt]) speedup = bl_avg_time / qt_avg_time if qt_avg_time > 0 else float("inf") - mem_reduction = (bl_mem - qt_mem) / bl_mem * 100 + mem_reduction = (bl_mem - qt_mem) / bl_mem * 100 if bl_mem > 0 else 0.0 all_results.append( { @@ -367,6 +456,8 @@ def run_benchmark(args): "avg_time": qt_avg_time, "speedup": speedup, "memory_gib": qt_mem, + "weights_gib": qt_weights, + "activations_gib": qt_acts, "mem_reduction_pct": mem_reduction, "mean_lpips": mean_lpips, "per_prompt": per_prompt, @@ -404,6 +495,30 @@ def run_benchmark(args): lines.append("> LPIPS < 0.01 = imperceptible, > 0.1 = clearly noticeable.") lines.append("") + # Memory profiling table (mirrors PR #1470 layout) + tp = args.tensor_parallel_size + lines.append("### Memory Profiling") + lines.append("") + lines.append( + f"First-prompt snapshot at {args.height}x{args.width}, " + 
f"{args.num_inference_steps} steps. " + "Weights = `memory_allocated()` before `generate()`; " + "Peak = `max_memory_allocated()` during `generate()`; " + "Activations = Peak − Weights." + ) + lines.append("") + lines.append("| Config | Weights | Activations | Peak | Total Reduction |") + lines.append("|--------|---------|-------------|------|-----------------|") + lines.append(f"| BF16, TP={tp} | {bl_weights:.2f} GiB | {bl_acts:.2f} GiB | {bl_mem:.2f} GiB | — |") + for r in all_results: + reduction_pct = (bl_mem - r["memory_gib"]) / bl_mem * 100 if bl_mem > 0 else 0.0 + lines.append( + f"| {r['config']}, TP={tp} | {r['weights_gib']:.2f} GiB " + f"| {r['activations_gib']:.2f} GiB | {r['memory_gib']:.2f} GiB " + f"| **{reduction_pct:.0f}%** |" + ) + lines.append("") + # Per-prompt table if len(prompts) > 1: lines.append("### Per-Prompt LPIPS") @@ -442,6 +557,16 @@ def parse_args(): formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("--model", required=True, help="Model name or local path.") + parser.add_argument( + "--baseline-model", + default=None, + help=( + "Optional BF16 model path/name for the baseline run. Defaults to " + "`--model`. Use this when the quantized checkpoint is a separate " + "on-disk pipeline tree (e.g. SVDQuant, MXFP8 offline) rather than " + "an online-quantized variant of the same model." + ), + ) parser.add_argument( "--task", default="t2i", @@ -452,7 +577,13 @@ def parse_args(): "--quantization", nargs="+", required=True, - help="One or more quantization methods to benchmark (e.g. fp8 int8 bitsandbytes).", + help=( + "One or more quantization methods to benchmark (e.g. fp8 int8 " + "bitsandbytes). Use the sentinel `auto` for offline-quantized " + "checkpoints (SVDQuant etc.) where the method + per-layer skip " + 'list are baked into `transformer/config.json["quantization_config"]` ' + "— the bench will not override it and the on-disk config drives the run." 
+ ), ) parser.add_argument( "--prompts", diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index c0fd337bd93..85d7b1d1b9e 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -17,6 +17,42 @@ from vllm_omni.platforms import current_omni_platform +def _resolve_quantization_label(cli_quantization: str | None, model: str) -> str: + """Label that reflects what the engine will actually run, not just CLI args. + + `--quantization` overrides whatever is on disk. Otherwise vllm-omni + auto-detects from `transformer/config.json["quantization_config"]` + (see `OmniDiffusionConfig._propagate_quantization_from_tf_config`). + Mirror that lookup so the printed banner doesn't say "None (BF16)" + for a checkpoint that's actually going to load quantized. + """ + if cli_quantization: + return cli_quantization + try: + from vllm.transformers_utils.config import get_hf_file_to_dict + + cfg = get_hf_file_to_dict("transformer/config.json", model) + except Exception: + cfg = None + if isinstance(cfg, dict): + qc = cfg.get("quantization_config") + if isinstance(qc, dict): + method = qc.get("quant_method") or qc.get("method") + if method == "component" or ( + method is None + and any(isinstance(v, dict) and (v.get("quant_method") or v.get("method")) for v in qc.values()) + ): + default = qc.get("default") + if isinstance(default, dict): + inner = default.get("quant_method") or default.get("method") + if inner: + return f"{inner} (per-component, from checkpoint)" + return "component (from checkpoint)" + if method: + return f"{method} (from checkpoint)" + return "None (BF16)" + + def is_nextstep_model(model_name: str) -> bool: """Check if the model is a NextStep model by reading its config.""" from vllm.transformers_utils.config import get_hf_file_to_dict @@ -412,7 +448,7 @@ def main(): print(f" Model: {args.model}") print(f" 
Inference steps: {args.num_inference_steps}") print(f" Cache backend: {cache_backend if cache_backend else 'None (no acceleration)'}") - print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}") + print(f" Quantization: {_resolve_quantization_label(args.quantization, args.model)}") if ignored_layers: print(f" Ignored layers: {ignored_layers}") print( diff --git a/tests/diffusion/quantization/test_svdquant_config.py b/tests/diffusion/quantization/test_svdquant_config.py new file mode 100644 index 00000000000..8d7e47f8324 --- /dev/null +++ b/tests/diffusion/quantization/test_svdquant_config.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-Omni project +"""Smoke tests for the SVDQuant quantization plugin. + +Real W4A4 numerics live on top of an actual quantized checkpoint and +require a CUDA capability that the kernel backend supports. These +tests cover the boundary that vllm-omni owns: factory wiring, the +config shape, and the hardware-keyed backend selection. 
+""" + +import pytest +import torch +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability + +from vllm_omni.quantization import build_quant_config +from vllm_omni.quantization.factory import SUPPORTED_QUANTIZATION_METHODS +from vllm_omni.quantization.svdquant_config import ( + DiffusionSVDQuantConfig, + DiffusionSVDQuantLinearMethod, +) +from vllm_omni.quantization.svdquant_dispatch import assert_svdquant_supported +from vllm_omni.quantization.svdquant_nunchaku import has_nunchaku_w4a4 + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +def test_svdquant_is_registered() -> None: + assert "svdquant" in SUPPORTED_QUANTIZATION_METHODS + + +def test_config_from_dict_int4() -> None: + cfg = DiffusionSVDQuantConfig.from_config({"rank": 32, "precision": "int4", "act_unsigned": False}) + assert cfg.rank == 32 + assert cfg.precision == "int4" + assert cfg.group_size == 64 + assert cfg.act_unsigned is False + assert cfg.modules_to_not_convert == [] + assert cfg.get_name() == "svdquant" + + +def test_config_from_dict_nvfp4() -> None: + cfg = DiffusionSVDQuantConfig.from_config( + { + "rank": 64, + "precision": "nvfp4", + "modules_to_not_convert": ["embedder", "final_layer"], + } + ) + assert cfg.precision == "nvfp4" + assert cfg.group_size == 16 # NVFP4 tcgen05 scale block + assert cfg.modules_to_not_convert == ["embedder", "final_layer"] + + +def test_config_rejects_unknown_precision() -> None: + with pytest.raises(ValueError, match="precision"): + DiffusionSVDQuantConfig(precision="fp8") # type: ignore[arg-type] + + +def test_build_quant_config_routes_to_svdquant() -> None: + """Factory should pick DiffusionSVDQuantConfig for "svdquant".""" + cfg = build_quant_config("svdquant", rank=16, precision="int4") + assert isinstance(cfg, DiffusionSVDQuantConfig) + assert cfg.rank == 16 + assert cfg.precision == "int4" + + 
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="hardware gate is CUDA-specific") +def test_hardware_gate_accepts_consumer_gpus() -> None: + if not has_nunchaku_w4a4(): + pytest.skip("nunchaku not installed") + major, _ = current_platform.get_device_capability() + if major == 9: + pytest.skip("Hopper is intentionally unsupported") + if major == 10: + pytest.skip("Datacenter Blackwell is out of scope (FlashInfer planned)") + # Turing/Ampere/Ada (SM_75-89) and consumer Blackwell SM_120 are + # accepted by the gate for int4. + assert_svdquant_supported("int4") + + +def test_hardware_gate_rejects_hopper(monkeypatch: pytest.MonkeyPatch) -> None: + """Hopper SM_90 must raise.""" + cls = type(current_platform) + monkeypatch.setattr(cls, "is_cuda", classmethod(lambda c: True)) + monkeypatch.setattr( + cls, + "get_device_capability", + classmethod(lambda c, *a, **k: DeviceCapability(9, 0)), + ) + with pytest.raises(RuntimeError, match="Hopper"): + assert_svdquant_supported("int4") + + +def test_hardware_gate_rejects_datacenter_blackwell( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """SM_100/103 is out of scope here (FlashInfer-planned); must raise.""" + cls = type(current_platform) + monkeypatch.setattr(cls, "is_cuda", classmethod(lambda c: True)) + monkeypatch.setattr( + cls, + "get_device_capability", + classmethod(lambda c, *a, **k: DeviceCapability(10, 0)), + ) + with pytest.raises(RuntimeError, match="FlashInfer"): + assert_svdquant_supported("nvfp4") + + +def test_hardware_gate_rejects_nvfp4_on_pre_blackwell( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """NVFP4 needs SM_100+ tensor units; SM_8x must raise cleanly.""" + if not has_nunchaku_w4a4(): + pytest.skip("nunchaku not installed") + cls = type(current_platform) + monkeypatch.setattr(cls, "is_cuda", classmethod(lambda c: True)) + monkeypatch.setattr( + cls, + "get_device_capability", + classmethod(lambda c, *a, **k: DeviceCapability(8, 9)), + ) + with pytest.raises(ValueError, match="NVFP4"): 
+ assert_svdquant_supported("nvfp4") + + +@pytest.mark.skipif( + not (current_platform.is_cuda() and has_nunchaku_w4a4()), + reason="requires CUDA + nunchaku for create_weights smoke", +) +def test_linear_method_create_weights_int4() -> None: + """Validate the parameter layout without invoking the kernel. + + Only checks that `create_weights` populates the layer with + correctly-shaped, correctly-dtyped tensors. + """ + cfg = DiffusionSVDQuantConfig(rank=32, precision="int4") + method = DiffusionSVDQuantLinearMethod(cfg) + + # Mimic a 4096-in / 4096-out column-parallel layer with TP=1. + layer = torch.nn.Module() + method.create_weights( + layer, + input_size_per_partition=4096, + output_partition_sizes=[4096], + input_size=4096, + output_size=4096, + params_dtype=torch.bfloat16, + ) + + assert layer.qweight.shape == (4096, 4096 // 2) + assert layer.qweight.dtype == torch.int8 + assert layer.wscales.shape == (4096 // 64, 4096) + assert layer.wscales.dtype == torch.bfloat16 + assert layer.proj_down.shape == (4096, 32) + assert layer.proj_up.shape == (4096, 32) + assert layer.smooth_factor.shape == (4096,) + assert layer.wcscales is None + assert layer.wtscale is None + + +@pytest.mark.skipif( + not (current_platform.is_cuda() and has_nunchaku_w4a4()), + reason="requires CUDA + nunchaku for create_weights smoke", +) +def test_linear_method_create_weights_nvfp4_has_per_channel_scales() -> None: + cfg = DiffusionSVDQuantConfig(rank=32, precision="nvfp4") + try: + assert_svdquant_supported("nvfp4") + except (RuntimeError, ValueError, ImportError) as exc: + pytest.skip(f"nvfp4 unsupported on this box: {exc}") + method = DiffusionSVDQuantLinearMethod(cfg) + layer = torch.nn.Module() + method.create_weights( + layer, + input_size_per_partition=2048, + output_partition_sizes=[2048], + input_size=2048, + output_size=2048, + params_dtype=torch.bfloat16, + ) + assert layer.wscales.dtype == torch.float8_e4m3fn + assert layer.wcscales is not None + assert layer.wcscales.shape == 
(2048,) + assert layer.wtscale is not None + assert layer.wtscale.shape == (1,) + + +def test_get_quant_method_skips_listed_modules() -> None: + cfg = DiffusionSVDQuantConfig(modules_to_not_convert=["embedder"]) + if not has_nunchaku_w4a4(): + # DiffusionSVDQuantLinearMethod ctor calls assert_svdquant_supported() + # and raises; in that case we can only check the skip path. + pytest.skip("nunchaku not installed") + fake_layer = torch.nn.Linear(8, 8) + fake_layer.__class__ = type("FakeLinear", (torch.nn.Linear, LinearBase), {}) + + method = cfg.get_quant_method(fake_layer, "model.embedder.proj") + assert isinstance(method, UnquantizedLinearMethod) diff --git a/vllm_omni/diffusion/models/z_image/z_image_transformer.py b/vllm_omni/diffusion/models/z_image/z_image_transformer.py index 1399eb0f0b3..0d755f71806 100644 --- a/vllm_omni/diffusion/models/z_image/z_image_transformer.py +++ b/vllm_omni/diffusion/models/z_image/z_image_transformer.py @@ -1036,9 +1036,13 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: (".to_qkv.", ".to_q.", "q"), (".to_qkv.", ".to_k.", "k"), (".to_qkv.", ".to_v.", "v"), - # ffn - (".w13", ".w1", 0), - (".w13", ".w3", 1), + # ffn — trailing dots required: `.w1` without a boundary is a + # substring of `.w13.`, so a pre-fused-on-disk checkpoint with + # `.w13.qweight` keys (e.g. the SVDQuant NVFP4 converter that + # half-swaps gate/up at quant time) would otherwise match + # `.w1` and rewrite `.w13.proj_down` → `.w133.proj_down`. + (".w13.", ".w1.", 0), + (".w13.", ".w3.", 1), ] # Expose packed shard mappings for LoRA handling of fused projections. 
self.stacked_params_mapping = stacked_params_mapping diff --git a/vllm_omni/quantization/component_config.py b/vllm_omni/quantization/component_config.py index f9286079be1..a7bebbc3e56 100644 --- a/vllm_omni/quantization/component_config.py +++ b/vllm_omni/quantization/component_config.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any import torch +from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, ) @@ -78,7 +79,12 @@ def get_name(self) -> str: def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> QuantizeMethodBase | None: config = self.resolve(prefix) if config is None: - return None + # `None` in the component dict means "this component is + # unquantized" (per the class docstring's `"vae": None` + # example). vLLM's `LinearBase.__init__` treats a `None` + # return as "linear doesn't support this quant method" and + # raises; instead surface an explicit unquantized method. 
+ return UnquantizedLinearMethod() return config.get_quant_method(layer, prefix) @classmethod diff --git a/vllm_omni/quantization/factory.py b/vllm_omni/quantization/factory.py index 3766e4596cd..412018dc30e 100644 --- a/vllm_omni/quantization/factory.py +++ b/vllm_omni/quantization/factory.py @@ -71,6 +71,13 @@ def _build_mxfp4_dualscale(**kw: Any) -> QuantizationConfig: return DiffusionMXFP4DualScaleMixedConfig(**kw) +def _build_svdquant(**kw: Any) -> QuantizationConfig: + """Lazy import for SVDQuant W4A4 diffusion config (CUDA, nunchaku backend).""" + from .svdquant_config import DiffusionSVDQuantConfig + + return DiffusionSVDQuantConfig(**kw) + + def _build_inc(**kw: Any) -> QuantizationConfig: """Lazy import for INC/AutoRound config with checkpoint kwarg normalization.""" from .inc_config import OmniINCConfig @@ -91,6 +98,7 @@ def _build_inc(**kw: Any) -> QuantizationConfig: "mxfp8": _build_mxfp8, "mxfp4": _build_mxfp4, "mxfp4_dualscale": _build_mxfp4_dualscale, + "svdquant": _build_svdquant, "inc": _build_inc, "auto-round": _build_inc, "auto_round": _build_inc, diff --git a/vllm_omni/quantization/svdquant_config.py b/vllm_omni/quantization/svdquant_config.py new file mode 100644 index 00000000000..0e99bce1500 --- /dev/null +++ b/vllm_omni/quantization/svdquant_config.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-Omni project +"""SVDQuant W4A4 config + LinearMethod for diffusion transformers. + +SVDQuant (https://arxiv.org/abs/2411.05007) is a 4-bit weight, 4-bit +activation quantization scheme paired with a low-rank residual that +absorbs the quantization error. It is the dominant practical +quantization method for diffusion transformers, delivering >2x +speedup vs BF16 with minimal quality loss. + +This module owns the on-disk parameter layout (canonical row-major +NVFP4 / INT4-nibble) and the vLLM `LinearMethodBase` plumbing. 
+Backend-specific kernel calls and weight prep live in sibling modules +(`svdquant_nunchaku.py`, future `svdquant_flashinfer.py`); the active +backend is selected at `__init__` via `svdquant_dispatch.select_backend`. + +Diffusion-specific weight key remapping (e.g. diffusers naming +conventions) is not handled here; downstream pipelines remap before +loading. Checkpoints are expected to already store gated-activation +halves in `[gate; hidden]` order — produced at quantization time, not +at runtime. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import torch +from torch.nn import Parameter +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from .svdquant_dispatch import ( + SVDQuantPrecision, + assert_svdquant_supported, + select_backend, +) + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods + +logger = init_logger(__name__) + +# Group sizes are dictated by the kernel's scaled-MMA tile: +# * NVFP4 uses tcgen05's 16-element scale block. +# * INT4 uses Nunchaku's 64-element block. +_GROUP_SIZE_BY_PRECISION: dict[str, int] = {"int4": 64, "nvfp4": 16} + + +class DiffusionSVDQuantConfig(QuantizationConfig): + """Configuration for SVDQuant W4A4 quantization. + + Parameters mirror what's on disk in a SVDQuant-produced checkpoint: + + Args: + rank: SVD low-rank correction dimension. Typical values are + 16, 32, or 64; the checkpoint dictates the value. + precision: 4-bit format, either "int4" or "nvfp4". 
+ act_unsigned: Whether activations are quantized as unsigned + (saves the sign bit at a small accuracy cost). Per + checkpoint config. + modules_to_not_convert: Layer names (or substring patterns) + that should keep their unquantized weight, e.g. embedders + and adaLN-modulation projections in diffusion models. + """ + + def __init__( + self, + rank: int = 32, + precision: SVDQuantPrecision = "int4", + act_unsigned: bool = False, + modules_to_not_convert: list[str] | None = None, + ) -> None: + super().__init__() + if precision not in _GROUP_SIZE_BY_PRECISION: + raise ValueError(f"SVDQuant precision must be one of {set(_GROUP_SIZE_BY_PRECISION)}; got {precision!r}") + self.rank = rank + self.precision = precision + self.group_size = _GROUP_SIZE_BY_PRECISION[precision] + self.act_unsigned = act_unsigned + self.modules_to_not_convert = modules_to_not_convert or [] + + def __repr__(self) -> str: + return ( + f"DiffusionSVDQuantConfig(rank={self.rank}, precision={self.precision!r}, act_unsigned={self.act_unsigned})" + ) + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "svdquant" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + # SM_75 (Turing) is the floor; the dispatcher rejects SM_90 and + # routes SM_100+ separately. 
class DiffusionSVDQuantLinearMethod(LinearMethodBase):
    """Backend-agnostic LinearMethod for SVDQuant W4A4.

    The same parameter layout serves both the int4 and nvfp4 paths;
    only the dtypes of `wscales` and the LoRA matrices differ. The
    active platform is checked at `__init__` time and an unsupported
    GPU raises here, before any weights are allocated.

    All backend-specific behavior (weight prep, GEMM call) is
    delegated to the module returned by
    `svdquant_dispatch.select_backend`. The on-disk layout is fixed
    and shared across backends.
    """

    # Class-level flag: the backend-selection message is logged only by
    # the first instance constructed in this process.
    _hardware_logged = False

    def __init__(self, quant_config: DiffusionSVDQuantConfig) -> None:
        """Gate on hardware support, then bind the kernel backend.

        Raises whatever `assert_svdquant_supported` / `select_backend`
        raise when the active platform cannot serve
        `quant_config.precision`.
        """
        self.quant_config = quant_config
        assert_svdquant_supported(quant_config.precision)
        self._backend = select_backend(quant_config.precision)
        if not DiffusionSVDQuantLinearMethod._hardware_logged:
            logger.info(
                "SVDQuant backend selected: %s (precision=%s)",
                self._backend.__name__,
                quant_config.precision,
            )
            DiffusionSVDQuantLinearMethod._hardware_logged = True

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate and register the SVDQuant parameters on `layer`.

        Registers `qweight`, `wscales`, `proj_down`, `proj_up` and
        `smooth_factor`; for nvfp4 additionally `wcscales` and `wtscale`
        (for int4 those two attributes are set to None). Also stashes
        the feature sizes, precision and `act_unsigned` on the layer
        for the backend's `apply()`.

        Raises:
            ValueError: if `input_size_per_partition` is odd or is not a
                multiple of the configured `group_size`.
        """
        del extra_weight_attrs  # weight_loader is set explicitly per-param.
        output_size_per_partition = sum(output_partition_sizes)

        config = self.quant_config
        rank = config.rank
        group_size = config.group_size
        precision = config.precision

        # The buffer shapes below use floor division. Reject sizes that
        # do not divide evenly instead of silently allocating truncated
        # `qweight` / `wscales` tensors.
        if input_size_per_partition % 2 != 0:
            raise ValueError(
                "SVDQuant packs two 4-bit weights per byte; "
                f"input_size_per_partition={input_size_per_partition} must be even."
            )
        if group_size <= 0 or input_size_per_partition % group_size != 0:
            raise ValueError(
                f"SVDQuant input_size_per_partition={input_size_per_partition} "
                f"must be a positive multiple of group_size={group_size}."
            )

        # The LoRA matrices and the smooth factor must be in the same
        # dtype as the kernel's accumulator. Nunchaku's nvfp4 path
        # locks this to bf16 regardless of the model's params_dtype;
        # the int4 path inherits params_dtype.
        lora_dtype = torch.bfloat16 if precision == "nvfp4" else params_dtype

        wscales_dtype = torch.float8_e4m3fn if precision == "nvfp4" else params_dtype

        # qweight: 4-bit weights packed two-per-byte along the input
        # axis. Shape (out_per_partition, in_per_partition // 2).
        qweight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        _set_attrs(
            qweight,
            input_dim=1,
            output_dim=0,
            weight_loader=default_weight_loader,
        )

        # wscales: per-(group_size) input-column scale,
        # shape (in_per_partition // group_size, out_per_partition).
        wscales = Parameter(
            torch.empty(
                input_size_per_partition // group_size,
                output_size_per_partition,
                dtype=wscales_dtype,
            ),
            requires_grad=False,
        )
        _set_attrs(
            wscales,
            input_dim=0,
            output_dim=1,
            weight_loader=default_weight_loader,
        )

        # SVD low-rank correction matrices.
        proj_down = Parameter(
            torch.empty(input_size_per_partition, rank, dtype=lora_dtype),
            requires_grad=False,
        )
        _set_attrs(
            proj_down,
            input_dim=0,
            output_dim=1,
            weight_loader=default_weight_loader,
        )

        proj_up = Parameter(
            torch.empty(output_size_per_partition, rank, dtype=lora_dtype),
            requires_grad=False,
        )
        _set_attrs(
            proj_up,
            input_dim=1,
            output_dim=0,
            weight_loader=default_weight_loader,
        )

        # Smooth-quant factors. Live on the input axis: replicated for
        # column-parallel layers, sharded for row-parallel.
        smooth_factor = Parameter(
            torch.empty(input_size_per_partition, dtype=lora_dtype),
            requires_grad=False,
        )
        _set_attrs(
            smooth_factor,
            input_dim=0,
            weight_loader=default_weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("wscales", wscales)
        layer.register_parameter("proj_down", proj_down)
        layer.register_parameter("proj_up", proj_up)
        layer.register_parameter("smooth_factor", smooth_factor)

        if precision == "nvfp4":
            # Per-output-channel BF16 scale; sharded with the output dim.
            wcscales = Parameter(
                torch.ones(output_size_per_partition, dtype=lora_dtype),
                requires_grad=False,
            )
            _set_attrs(
                wcscales,
                output_dim=0,
                weight_loader=default_weight_loader,
            )
            # Per-tensor global scale (shape (1,) on disk).
            wtscale = Parameter(
                torch.ones(1, dtype=lora_dtype),
                requires_grad=False,
            )
            _set_attrs(wtscale, weight_loader=default_weight_loader)
            layer.register_parameter("wcscales", wcscales)
            layer.register_parameter("wtscale", wtscale)
        else:
            # Keep the attributes present so backend.apply() can branch
            # uniformly without `hasattr` checks.
            layer.wcscales = None
            layer.wtscale = None

        # Stashed for backend.apply() to consume.
        layer.in_features = input_size
        layer.out_features = output_size
        layer.out_features_per_partition = output_size_per_partition
        layer.precision = precision
        layer.act_unsigned = config.act_unsigned

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Delegate post-load weight prep to the active backend.

        All parameters are produced by our quantization pipeline and
        must be loaded by the time we get here; a meta tensor at this
        point is a checkpoint bug, not a missing-shard case to paper
        over.
        """
        self._backend.prepare_weights(layer, self.quant_config.precision)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the forward pass by forwarding to the backend's `apply`."""
        return self._backend.apply(layer, x, bias)


def _set_attrs(param: torch.nn.Parameter, **attrs: Any) -> None:
    """Attach each keyword in `attrs` to `param` as a plain attribute."""
    for key, value in attrs.items():
        setattr(param, key, value)
def _candidate_backends() -> list[ModuleType]:
    """List the backend modules to probe, highest priority first.

    When FlashInfer lands, prepend it here so it takes precedence on
    its supported caps before falling back to nunchaku.
    """
    # Imported inside the function, as in the original layout, so the
    # backend module is only loaded when a backend is actually needed.
    from . import svdquant_nunchaku as nunchaku_backend

    ordered: list[ModuleType] = [nunchaku_backend]
    return ordered
def assert_svdquant_supported(precision: SVDQuantPrecision) -> None:
    """Raise a precise error if the active platform cannot run SVDQuant.

    Checks, in order: CUDA platform, Hopper (SM_90 family), datacenter
    Blackwell (SM_100 family), minimum capability SM_75, nvfp4 hardware
    requirement, nunchaku W4A4 availability, and finally the nunchaku
    backend's own `supports()` allow-list.

    Raises:
        RuntimeError: non-CUDA platform, an excluded or too-old compute
            capability, or a capability the backend does not list.
        ValueError: `precision == "nvfp4"` on a pre-SM_100 device.
        ImportError: the nunchaku W4A4 ops are not importable.
    """
    if not current_platform.is_cuda():
        raise RuntimeError(
            f"SVDQuant has no available backend on platform "
            f"{current_platform.device_name!r}. CUDA + a SVDQuant backend "
            "(nunchaku for consumer GPUs, FlashInfer for SM_100/103 — "
            "planned) is required."
        )

    cap = current_platform.get_device_capability()
    # Never interpolate an empty string into the messages below.
    sm = f"SM_{cap.to_int()}" if cap is not None else "unknown"

    if current_platform.is_device_capability_family(90):
        raise RuntimeError(
            "SVDQuant W4A4 is not supported on Hopper (SM_90). Use a "
            "consumer GPU (SM_75–SM_89, SM_120) with nunchaku, or wait "
            "for the datacenter Blackwell (SM_100/103) path planned in "
            "FlashInfer."
        )

    if current_platform.is_device_capability_family(100):
        raise RuntimeError(
            f"SVDQuant on {sm} (B200/GB300) is not yet integrated; the datacenter path is planned in FlashInfer."
        )

    if not current_platform.has_device_capability((7, 5)):
        raise RuntimeError(f"Unsupported CUDA compute capability for SVDQuant: {sm}")

    # nvfp4 needs SM_100+ tensor units; pre-Blackwell consumer cards
    # (Turing/Ampere/Ada) cannot run it.
    if precision == "nvfp4" and not current_platform.has_device_capability(100):
        raise ValueError(f"NVFP4 SVDQuant requires SM_100+ or SM_120; got {sm}. Use precision='int4'.")

    # Backend-level missing-package check (current single backend = nunchaku).
    from . import svdquant_nunchaku

    if not svdquant_nunchaku.has_nunchaku_w4a4():
        # The PyPI `nunchaku` package is an unrelated Bayesian library;
        # SVDQuant kernels ship as GitHub release wheels only.
        raise ImportError(
            f"SVDQuant on {sm} requires nunchaku-ai's W4A4 wheels from "
            "https://github.com/nunchaku-ai/nunchaku/releases "
            "(not `pip install nunchaku`, which is a different project)."
        )

    # The range checks above are coarser than the backend's exact
    # capability allow-list. Fail here, naming the device, rather than
    # letting select_backend() raise its generic error that points back
    # to this function.
    if not svdquant_nunchaku.supports(cap, precision):
        raise RuntimeError(
            f"SVDQuant has no kernel backend for {sm} with precision={precision!r}."
        )
def _get_submodule(module_name: str) -> Any | None:
    """Import `module_name` and return it, or None if the import fails."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so a single
        # clause covers both "not installed" and "failed while importing".
        return None
# Compute capabilities the nunchaku PTX-MMA family targets. Hopper SM_90
# and datacenter Blackwell SM_100/103 are deliberately absent.
_SUPPORTED_CAPS: set[tuple[int, int]] = {
    (7, 5),  # Turing
    (8, 0),  # Ampere A100
    (8, 6),  # Ampere consumer (RTX 30xx)
    (8, 9),  # Ada (RTX 40xx)
    (12, 0),  # Consumer Blackwell (RTX 5090)
}


def supports(cap: DeviceCapability | None, precision: str) -> bool:
    """Return True iff this backend can serve (cap, precision).

    Returns False when `cap` is None, when `precision` is neither
    "int4" nor "nvfp4", when the capability is not in
    `_SUPPORTED_CAPS`, when nvfp4 is requested on anything other than
    (12, 0), or when the nunchaku W4A4 ops are not importable.
    """
    if cap is None:
        return False
    # An unrecognised precision string is never supported.
    if precision not in ("int4", "nvfp4"):
        return False
    sm = (cap.major, cap.minor)
    if sm not in _SUPPORTED_CAPS:
        return False
    # nvfp4 needs tcgen05's SM_100+ tensor units; in this backend that
    # means consumer Blackwell only.
    if precision == "nvfp4" and sm != (12, 0):
        return False
    # Probe the package last: the capability checks above are cheaper
    # and need no import machinery.
    return has_nunchaku_w4a4()
def _pack_nvfp4_to_nunchaku_fragment(layer: torch.nn.Module) -> None:
    """Rewrite the layer's row-major NVFP4 tensors into nunchaku's layout.

    Replaces `.data` on four parameters, in this order: `qweight`,
    `wscales`, `proj_up`, `proj_down`. Each result is moved to the
    device `qweight` lived on when the function was entered.

    Layout before the call (canonical row-major):
      * qweight   : [N, K/2] int8/uint8 (FP4 nibbles, low = even-k)
      * wscales   : [K/16, N] fp8_e4m3fn
      * proj_up   : [N, R]
      * proj_down : [K, R]

    After the call the shapes are the same, but the contents are
    permuted into nunchaku's PTX-MMA fragment order.
    """
    # Lazy imports: nunchaku is a soft dep on non-consumer hardware,
    # and the layout helpers pull in torch ops we only need here.
    from nunchaku.lora.flux.nunchaku_converter import pack_lowrank_weight

    from vllm_omni.quantization.tools.svdquant_nvfp4_layout import (
        _unpack_nibbles,
        pack_nunchaku_qweight_fp4,
        pack_nunchaku_wscales_fp4,
    )

    target_device = layer.qweight.device

    # qweight is stored two nibbles per byte; the fragment packer wants
    # one nibble per byte, so expand [N, K/2] -> [N, K] first.
    packed_bytes = layer.qweight.data.view(torch.uint8)
    one_nibble_per_byte = _unpack_nibbles(packed_bytes)
    fragment_qweight = pack_nunchaku_qweight_fp4(one_nibble_per_byte)
    layer.qweight.data = fragment_qweight.to(target_device)

    # wscales go through the fp8 scale packer unchanged in dtype.
    fragment_wscales = pack_nunchaku_wscales_fp4(layer.wscales.data)
    layer.wscales.data = fragment_wscales.to(target_device)

    # proj_up is already [N, R]; pack with down=False.
    fragment_up = pack_lowrank_weight(layer.proj_up.data, down=False)
    layer.proj_up.data = fragment_up.to(target_device)

    # proj_down is stored [K, R], but nunchaku's down=True packer takes
    # [R, K], so hand it a contiguous transpose.
    down_r_by_k = layer.proj_down.data.transpose(0, 1).contiguous()
    fragment_down = pack_lowrank_weight(down_r_by_k, down=True)
    layer.proj_down.data = fragment_down.to(target_device)
+ out_2d = torch.empty( + quantized_x.shape[0], + out_features, + dtype=layer.proj_up.dtype, + device=x_2d.device, + ) + + _svdq_gemm_w4a4( + act=quantized_x, + wgt=layer.qweight, + out=out_2d, + ascales=ascales, + wscales=layer.wscales, + lora_act_in=lora_act_out, + lora_up=layer.proj_up, + bias=bias, + fp4=is_fp4, + alpha=getattr(layer, "_svdquant_alpha", 1.0), + wcscales=layer.wcscales, + act_unsigned=layer.act_unsigned, + ) + + actual_batch = x_2d.shape[0] + if out_2d.shape[0] > actual_batch: + out_2d = out_2d[:actual_batch] + + return out_2d.reshape(*orig_shape[:-1], out_features) + + +__all__ = [ + "has_nunchaku", + "has_nunchaku_w4a4", + "has_nunchaku_w4a16", + "supports", + "prepare_weights", + "apply", +] diff --git a/vllm_omni/quantization/tools/convert_nunchaku_to_svdquant.py b/vllm_omni/quantization/tools/convert_nunchaku_to_svdquant.py new file mode 100644 index 00000000000..e3da24f44d4 --- /dev/null +++ b/vllm_omni/quantization/tools/convert_nunchaku_to_svdquant.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +"""Convert a nunchaku-published merged-safetensors NVFP4 SVDQuant checkpoint +into a vLLM-loadable diffusers pipeline folder. + +On-disk format: canonical row-major + FP4 nibble pack. This is the +layout the SM_100 native (CuTe) kernel consumes directly. For the +nunchaku kernel backend (consumer GPUs), vLLM repacks to fragment at +load time via `SVDQuantLinearMethod.process_weights_after_loading`. +The on-disk format is the same regardless of target backend; users do +not need to know about nunchaku-vs-native layout. + +What this does: + 1. Resolve inputs (local paths or HuggingFace repo ids → snapshot_download). + 2. Stream tensors from the nunchaku merged safetensors, grouping by linear + layer prefix (those with a `.qweight` sibling). + 3. 
+ Unpack fragment layout → row-major for every layer: + qweight via `unpack_nunchaku_qweight_fp4` → [N, K/2] uint8 (FP4 nibbles) + wscales via `unpack_nunchaku_wscales_fp4` → [K/16, N] fp8_e4m3fn + proj_up via `unpack_lowrank_weight(down=False)` → [N, R] + proj_down via `unpack_lowrank_weight(down=True)` → unpack returns + [R, K] (transpose-quirk in nunchaku), transpose back to [K, R] + 4. For each fused gate-up linear (suffix `.feed_forward.net.0.proj` in + Z-Image), do a bit-preserving N-axis half-swap so the on-disk layout + matches vLLM's standard `[gate; hidden]` SiluAndMul convention. + 5. Ensure each NVFP4 linear's state-dict block carries `wcscales` + (default ones if missing) and `wtscale` as a `(1,)` bf16 tensor + (default 1.0 if missing). + 6. Write a complete diffusers folder at `--output-dir`: + {output_dir}/model_index.json (linked from base) + {output_dir}/scheduler/, text_encoder/, ... (linked from base) + {output_dir}/transformer/config.json (base config + injected + "quantization_config" field + so vllm-omni auto-picks SVDQuant) + {output_dir}/transformer/diffusion_pytorch_model.safetensors (converted weights) + +All transforms are pure permute+view (bit-preserving). Round-trip +verified: `pack(unpack(x)) == x` bit-exactly for proj_down and proj_up +across shape stress; half-swap pipeline verified end-to-end against +`svdq_gemm_w4a4_cuda`. + +Usage: + python -m vllm_omni.quantization.tools.convert_nunchaku_to_svdquant \\ + --nunchaku-checkpoint nunchaku-tech/nunchaku-z-image-turbo/svdq-fp4_r128-z-image-turbo.safetensors \\ + --base-pipeline Tongyi-MAI/Z-Image-Turbo \\ + --output-dir ~/.cache/huggingface/hub/models--ultranationalism--nunchaku-z-image-turbo-svdq/ + +Both `--nunchaku-checkpoint` and `--base-pipeline` accept either a local +path or an HF repo id; HF auto-download happens only on cache miss. + +Non-transformer subfolders + top-level files are hard-linked from the +base pipeline by default (saves 35+ GB). `huggingface_hub.upload_folder` +reads file content, so hard links upload fine. Use `--copy` to disable. 
# Fused gate-up linear suffix in nunchaku-published Z-Image NVFP4 checkpoints.
# diffusers `FeedForward(activation_fn="swiglu")` stores the fused up
# projection at `.net.0.proj`; the `feed_forward` parent name is from
# Z-Image's transformer block.
ZIMAGE_FUSED_GATE_UP_TAIL = "feed_forward.net.0.proj"


def is_fused_gate_up_zimage(layer_prefix: str) -> bool:
    """Return True iff `layer_prefix` names a fused gate-up linear.

    The tail must be the whole prefix or be preceded by a dot, so a
    module whose name merely ends in the same characters (for example
    `pre_feed_forward.net.0.proj`) is not mistaken for a gate-up layer.
    """
    tail = ZIMAGE_FUSED_GATE_UP_TAIL
    return layer_prefix == tail or layer_prefix.endswith(f".{tail}")
def _pack_qweight_row_major(nibs: torch.Tensor) -> torch.Tensor:
    """`[N, K] uint8 nibbles → [N, K/2] uint8`, low nibble = even k.

    Inverse of `_unpack_nibbles`. Element `2*j` of the last axis goes
    into the low four bits of output byte `j`, element `2*j + 1` into
    the high four bits.

    Raises:
        ValueError: if the last dimension of `nibs` is odd.
    """
    # A real exception, not `assert`: under `python -O` an odd K would
    # otherwise reach the slicing below, where mismatched halves can
    # broadcast and pack garbage.
    if nibs.shape[-1] % 2 != 0:
        raise ValueError(
            f"nibble tensor needs an even last dimension; got shape {tuple(nibs.shape)}"
        )
    # Cast before shifting so `hi << 4` is computed in uint8 and does
    # not depend on wrap-around in a signed input dtype.
    as_u8 = nibs.to(torch.uint8)
    lo = as_u8[..., 0::2]
    hi = as_u8[..., 1::2]
    return (lo | (hi << 4)).to(torch.uint8)
+ """ + unpack_lowrank_weight = _lowrank_unpack() + + qweight = params["qweight"] # [N, K/2] int8 (nunchaku fragment) + wscales = params["wscales"] # [K/16, N] fp8 (nunchaku fragment) + proj_up = params["proj_up"] # [N, R] bf16 (nunchaku fragment) + proj_down = params["proj_down"] # [K, R] bf16 (nunchaku fragment) + wcscales = params.get("wcscales") # [N] bf16 (optional) + bias = params.get("bias") # [N] bf16 (optional) + + N = qweight.shape[0] + if half_swap_n: + assert N % 2 == 0, f"fused gate-up N must be even; got {N}" + half = N // 2 + + # qweight: unpack fragment → [N, K/2] uint8 nibble bytes (low = even-k); + # then `_unpack_nibbles` → [N, K] full-nibble form so we can slice on N + # then repack to [N, K/2] for storage. + qw_rm = unpack_nunchaku_qweight_fp4(qweight.view(torch.int8)) # [N, K/2] uint8 + if half_swap_n: + nibs = _unpack_nibbles(qw_rm) # [N, K] uint8 + nibs = torch.cat([nibs[half:], nibs[:half]], dim=0).contiguous() + qw_rm = _pack_qweight_row_major(nibs) + qweight_out = qw_rm.contiguous() + + # wscales: unpack to [K/16, N] row-major fp8. + ws_rm = unpack_nunchaku_wscales_fp4(wscales) + if half_swap_n: + ws_rm = torch.cat([ws_rm[:, half:], ws_rm[:, :half]], dim=1).contiguous() + wscales_out = ws_rm.contiguous() + + # proj_up: down=False → unpack returns [N, R] directly. + pu_rm = unpack_lowrank_weight(proj_up, down=False) + if half_swap_n: + pu_rm = torch.cat([pu_rm[half:], pu_rm[:half]], dim=0).contiguous() + proj_up_out = pu_rm.contiguous() + + # proj_down: down=True. nunchaku's unpack returns [R, K]; canonical + # row-major is [K, R] (matches SM_100 CuTe kernel's expected layout). + # Transpose to [K, R]. 
def _resolve_nunchaku_checkpoint(arg: str) -> Path:
    """Resolve `arg` to a local checkpoint file.

    An existing local file is returned as-is. Otherwise `arg` is read
    as an HF spec: the first two `/`-separated components form the repo
    id and everything after them is the file path inside that repo. The
    file is then fetched with `hf_hub_download`.

    Raises:
        ValueError: if `arg` is not an existing file and does not have
            at least three non-empty `/`-separated components.
    """
    p = Path(arg)
    if p.is_file():
        return p
    # Treat as HF spec: split into (repo_id, filename).
    parts = arg.split("/")
    # `not all(parts)` rejects empty components, e.g. a missing absolute
    # path ("/a/b.safetensors") that would otherwise be sent to the Hub
    # as a repo id.
    if len(parts) < 3 or not all(parts):
        raise ValueError(
            f"--nunchaku-checkpoint {arg!r} is not a local file and not a "
            "repo_id/filename spec (need owner/name/file.safetensors)"
        )
    repo_id = "/".join(parts[:2])
    filename = "/".join(parts[2:])
    from huggingface_hub import hf_hub_download

    print(f"resolving nunchaku checkpoint: repo={repo_id} file={filename}")
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    return Path(path)
def _link_or_copy_file(src: Path, dst: Path, prefer_copy: bool) -> None:
    """Place the content of `src` at `dst` by hard link, or by copy.

    `src` is resolved first, so a symlinked source (the HF cache links
    snapshots to blobs) contributes the file it points at. Anything
    already at `dst`, including a dangling symlink, is removed first.
    A copy is made when `prefer_copy` is true or when linking raises
    `OSError` (e.g. cross-filesystem or permissions).
    """
    blob = src.resolve()
    if dst.is_symlink() or dst.exists():
        dst.unlink()
    if not prefer_copy:
        try:
            os.link(blob, dst)
        except OSError:
            pass  # fall through to the copy below
        else:
            return
    shutil.copy2(blob, dst)


def _link_or_copy_tree(src: Path, dst: Path, prefer_copy: bool) -> None:
    """Recursively mirror directory `src` into `dst`, creating `dst` if needed."""
    dst.mkdir(parents=True, exist_ok=True)
    for entry in src.iterdir():
        target = dst / entry.name
        mirror = _link_or_copy_tree if entry.is_dir() else _link_or_copy_file
        mirror(entry, target, prefer_copy)
# Suffixes present next to quantized linears in the nunchaku checkpoint
# that are filtered out while grouping keys by layer.
_DROPPED_NUNCHAKU_SUFFIXES: frozenset[str] = frozenset({"smooth_factor_orig"})


def _group_keys_by_layer(
    keys: list[str],
) -> tuple[dict[str, list[str]], list[str]]:
    """Return (layer_prefix → list-of-suffixes, leftover-keys).

    A "linear" is any key prefix that has a `.qweight` sibling. Suffixes
    in `_DROPPED_NUNCHAKU_SUFFIXES` are filtered out entirely. Keys
    whose prefix is not a linear are returned in `leftover`, in input
    order.

    Layer prefixes appear in the returned dict in the order their
    `.qweight` key occurs in `keys`. Seeding the dict from a set would
    make that order depend on string hash randomisation.
    """
    layer_to_suffixes: dict[str, list[str]] = {
        key.rsplit(".", 1)[0]: [] for key in keys if key.endswith(".qweight")
    }
    leftover: list[str] = []
    for key in keys:
        prefix, _, suffix = key.rpartition(".")
        bucket = layer_to_suffixes.get(prefix)
        if bucket is None:
            leftover.append(key)
        elif suffix not in _DROPPED_NUNCHAKU_SUFFIXES:
            bucket.append(suffix)
    return layer_to_suffixes, leftover
See module docstring for behavior.""" + output_dir.mkdir(parents=True, exist_ok=True) + + # ----- Mirror base pipeline (everything except transformer/) ----- + base_top_level = sorted(base_pipeline.iterdir(), key=lambda p: p.name) + for item in base_top_level: + if item.name == "transformer": + continue + d = output_dir / item.name + if item.is_dir(): + _link_or_copy_tree(item, d, prefer_copy) + else: + _link_or_copy_file(item, d, prefer_copy) + print(f"mirrored {len(base_top_level) - 1} top-level entries from base ({'copy' if prefer_copy else 'hard-link'})") + + # ----- transformer/ ----- + transformer_dir = output_dir / "transformer" + transformer_dir.mkdir(exist_ok=True) + base_transformer = base_pipeline / "transformer" + + # ----- Scan nunchaku checkpoint ----- + with safe_open(str(nunchaku_checkpoint), framework="pt", device="cpu") as f: + keys = list(f.keys()) + metadata = f.metadata() or {} + layer_to_suffixes, leftover = _group_keys_by_layer(keys) + if not layer_to_suffixes: + raise RuntimeError("no quantized linears found (no .qweight keys)") + sample_prefix = next(iter(layer_to_suffixes)) + rank, precision = _detect_rank_precision(f, sample_prefix) + + n_linears = len(layer_to_suffixes) + n_fused = sum(1 for p in layer_to_suffixes if is_fused_gate_up(p)) + print( + f"nunchaku checkpoint: {n_linears} quantized linears, " + f"{n_fused} fused gate-up (to swap); {len(leftover)} other keys" + ) + print(f"detected rank={rank} precision={precision}") + if "model_class" in metadata: + print(f"nunchaku metadata model_class={metadata['model_class']!r}") + + # ----- Build output state_dict via streaming reads ----- + out_sd: dict[str, torch.Tensor] = {} + + for i, (prefix, suffixes) in enumerate(sorted(layer_to_suffixes.items())): + params: dict[str, torch.Tensor] = {} + for suf in suffixes: + params[suf] = f.get_tensor(f"{prefix}.{suf}") + + # ---- normalize: make the per-layer state-dict self-contained ---- + # vllm-omni's diffusers_loader doesn't whitelist 
SVDQuant-specific + # suffixes in `_QUANTIZED_WEIGHT_SUFFIXES`, so missing wcscales / + # wtscale would be treated as unexpected_missing → ValueError. + # Fill with the kernel-identity defaults vLLM uses in create_weights. + qweight = params["qweight"] + N = qweight.shape[0] + lora_dtype = params["proj_up"].dtype # bf16 for NVFP4 per vLLM convention + + # wcscales (NVFP4 only): default ones = identity per-channel scale + if precision == "nvfp4" and "wcscales" not in params: + params["wcscales"] = torch.ones(N, dtype=lora_dtype) + # wtscale (NVFP4 only): default 1.0 = identity per-tensor scale. + # Also normalize 0-D → 1-D for the entries that are present. + if precision == "nvfp4": + if "wtscale" not in params: + params["wtscale"] = torch.tensor([1.0], dtype=lora_dtype) + elif params["wtscale"].dim() == 0: + params["wtscale"] = params["wtscale"].view(1).contiguous() + + # ---- transform: nunchaku fragment → row-major (+ SwiGLU + # half-swap for fused gate-up layers). On-disk format is + # canonical row-major regardless of target backend; nunchaku + # backend repacks at load time in vLLM. + params = unpack_nvfp4_layer(params, half_swap_n=is_fused_gate_up(prefix)) + + # ---- emit: rename source prefix to vllm-omni's param naming ---- + out_prefix = rename_key(f"{prefix}.dummy")[: -len(".dummy")] + for suf, t in params.items(): + out_sd[f"{out_prefix}.{suf}"] = t + if progress and (i % 20 == 0 or i == n_linears - 1): + print(f" [{i + 1}/{n_linears}] {prefix}" + (f" -> {out_prefix}" if out_prefix != prefix else "")) + + # Leftover (unquantized) keys: rename too (most are no-ops; safer to apply uniformly). + for k in leftover: + out_sd[rename_key(k)] = f.get_tensor(k) + + # ----- transformer/config.json: inject quantization_config ----- + # vllm-omni reads `transformer/config.json["quantization_config"]` to + # auto-detect the quant method (see `OmniDiffusionConfig` / + # `TransformerConfig.from_dict`); a sidecar `quantization_config.json` + # is *not* consulted. 
Mirror what `merge_mxfp8_checkpoint.py` does: + # load base config.json, inject the dict, write back. + # + # Per-component routing (`ComponentQuantizationConfig`): SVDQuant is an + # *offline* method — checkpoints have `.qweight`/`.wscales`/... keys, + # not `.weight`. If we apply it to ZImagePipeline globally, the Qwen3 + # text encoder (BF16, ships with `.weight` keys) gets its linears + # wrapped in SVDQuant slots too and refuses to load. The PR #1034 + # (Z-Image FP8) path got away with this because FP8 *online* mode + # accepts plain `.weight` and converts at load-time; SVDQuant can't. + # + # Scope: prefix `"model"` matches the Qwen3 text encoder layers + # (`model.layers.X.{self_attn,mlp}.*`) — masked to None. Everything + # else (Z-Image DiT prefixes `layers.X.*` / `noise_refiner.X.*` / + # `context_refiner.X.*`) falls through to the `default` SVDQuant rule. + with open(base_transformer / "config.json") as fp: + tf_config = json.load(fp) + tf_config["quantization_config"] = { + # Text encoder (Qwen3): unquantized. Its only nn.Linear instances + # live under `model.layers.*` and get prefixed accordingly by + # `recursive_replace_linear` (utils.py:96 starts prefix=""). + "model": None, + "default": { + # Use `quant_method` (HF convention; matches MXFP8 path). + # `vllm-omni`'s factory accepts either `quant_method` or `method`. + "quant_method": "svdquant", + "rank": rank, + "precision": precision, + "act_unsigned": False, + # `lm_head` covers Qwen3's text-encoder language-modeling + # head, which lives at the **top level** of `Qwen3ForCausalLM` + # — *not* under `model.*` — so it is not caught by the + # `"model": None` prefix rule above. Without this substring + # skip it falls through to SVDQuant and hits a tied-weight + # `data_ptr` error on the first forward (see vllm-omni + # diffusers_loader handling of Qwen3 text encoder). 
+ # + # Other precision-sensitive Z-Image DiT linears + # (cap_embedder, x_embedder, adaLN_modulation, t_embedder, + # FinalLayer.linear) already pass `quant_config=None` in + # the model class itself, so they need no entry here. + "modules_to_not_convert": ["lm_head"], + }, + } + out_config_path = transformer_dir / "config.json" + # Defensive: a previous run may have hard-linked config.json from the base + # snapshot; open(..., "w") would truncate the shared inode and corrupt + # the base's cached blob. Unlink first to detach. + if out_config_path.exists() or out_config_path.is_symlink(): + out_config_path.unlink() + with open(out_config_path, "w") as fp: + json.dump(tf_config, fp, indent=2) + print(f"wrote {out_config_path} (with embedded quantization_config)") + + # ----- Write the converted single safetensors ----- + out_path = transformer_dir / "diffusion_pytorch_model.safetensors" + # Preserve nunchaku metadata so downstream can still inspect provenance. + out_metadata = {k: v for k, v in metadata.items() if isinstance(v, str)} + out_metadata["conversion"] = json.dumps( + { + "tool": "vllm_omni.quantization.tools.convert_nunchaku_to_svdquant", + "layout": "row_major", # canonical; vLLM repacks for nunchaku backend at load + "half_swapped_layers": [p for p in layer_to_suffixes if is_fused_gate_up(p)], + } + ) + save_file(out_sd, str(out_path), metadata=out_metadata) + print(f"wrote {out_path} ({out_path.stat().st_size / 2**30:.2f} GiB)") + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main() -> None: + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description=__doc__, + ) + parser.add_argument( + "--nunchaku-checkpoint", + required=True, + help="Local path to nunchaku merged .safetensors OR HF spec " + "/ (e.g. 
nunchaku-tech/nunchaku-z-image-turbo" + "/svdq-fp4_r128-z-image-turbo.safetensors).", + ) + parser.add_argument( + "--base-pipeline", + default="Tongyi-MAI/Z-Image-Turbo", + help="Local diffusers folder OR HF repo id of the unquantized base " + "pipeline. Default: Tongyi-MAI/Z-Image-Turbo.", + ) + parser.add_argument( + "--output-dir", + required=True, + help="Output diffusers folder path.", + ) + parser.add_argument( + "--copy", + action="store_true", + help="Copy non-transformer files instead of hard-linking (slower, " + "uses ~35 GiB extra). Default: hard-link (HF upload-safe).", + ) + args = parser.parse_args() + + nunchaku_path = _resolve_nunchaku_checkpoint(args.nunchaku_checkpoint) + base_path = _resolve_base_pipeline(args.base_pipeline) + output_dir = Path(args.output_dir).expanduser() + + print(f"nunchaku checkpoint: {nunchaku_path}") + print(f"base pipeline: {base_path}") + print(f"output: {output_dir}") + print() + + convert( + nunchaku_checkpoint=nunchaku_path, + base_pipeline=base_path, + output_dir=output_dir, + prefer_copy=args.copy, + ) + print("\ndone.") + + +if __name__ == "__main__": + main() diff --git a/vllm_omni/quantization/tools/svdquant_nvfp4_layout.py b/vllm_omni/quantization/tools/svdquant_nvfp4_layout.py new file mode 100644 index 00000000000..096dae163b0 --- /dev/null +++ b/vllm_omni/quantization/tools/svdquant_nvfp4_layout.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM-Omni project +"""nunchaku NVFP4 SVDQuant fragment-layout adapters. + +Bridge between the canonical row-major SVDQuant NVFP4 on-disk format +and nunchaku's PTX-MMA-tile fragment layout. Bit-preserving pure +view+permute chain — no quant/dequant. + +Used in two directions: + +* Checkpoint conversion (vllm-omni converter): a nunchaku-published + checkpoint is unpacked to canonical row-major for writing to disk. 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-Omni project
"""nunchaku NVFP4 SVDQuant fragment-layout adapters.

Bridge between the canonical row-major SVDQuant NVFP4 on-disk format
and nunchaku's PTX-MMA-tile fragment layout. Bit-preserving pure
view+permute chain — no quant/dequant.

Used in two directions:

* Checkpoint conversion (vllm-omni converter): a nunchaku-published
  checkpoint is unpacked to canonical row-major for writing to disk.
* Load-time pack (`svdquant_nunchaku.prepare_weights`): for the nunchaku
  kernel backend, repack the row-major on-disk tensors into fragment
  layout before the kernel sees them.

Verified against `svdq_gemm_w4a4_cuda(fp4=True)`: round-trip is
bit-exact, and half-swap via unpack→swap→pack reproduces the permuted
nunchaku output bit-exactly. Workbench source:
SVDQuant kernel `baseline/kernels/_nvfp4.py`.

Pair semantics:
  * `unpack_nunchaku_wscales_fp4(s_nun)` `[K/16, N] fragment → row-major`
  * `pack_nunchaku_wscales_fp4(s_row)` `[K/16, N] row-major → fragment`
  * `unpack_nunchaku_qweight_fp4(q_nun)` `[N, K/2] fragment → row-major uint8 nibble bytes`
  * `pack_nunchaku_qweight_fp4(nibs_row)` `[N, K] nibbles → [N, K/2] fragment int8`

Note the qweight pair is asymmetric: `pack` consumes one nibble per element,
`unpack` returns two nibbles per byte (low nibble = even k).

These plus `nunchaku.lora.flux.nunchaku_converter.{pack,unpack}_lowrank_weight`
cover every fragment-layout param needed for SVDQuant W4A4 NVFP4
half-swap (qweight, wscales, proj_up).

Constants assume `NunchakuWeightPacker(bits=4, warp_n=128)`:
  wscales: s_pack_size=4, num_s_lanes=32, num_s_packs=1, insn_k/group=4
  qweight: num_n_packs=8, n_pack_size=2, num_n_lanes=8, reg_n=1,
           num_k_packs=1, k_pack_size=2, num_k_lanes=4, reg_k=8
"""

from __future__ import annotations

import torch

_WARP_N = 128
_INSN_K = 64
_GROUP = 16


def _assert_tileable(N: int, K: int) -> None:  # noqa: N803
    """Check that `[N, K]` splits into whole `_WARP_N` x `_INSN_K` tiles."""
    assert N % _WARP_N == 0, f"N ({N}) must be multiple of {_WARP_N}"
    assert K % _INSN_K == 0, f"K ({K}) must be multiple of {_INSN_K}"


def _pack_nibbles(nibs: torch.Tensor) -> torch.Tensor:
    """`[*, K] uint8 nibbles → [*, K/2] uint8`. Low nibble = even k.

    Inputs are masked to 4 bits so a stray high bit in an even-k element
    cannot corrupt its odd-k neighbour's nibble.
    """
    assert nibs.shape[-1] % 2 == 0
    lo = nibs[..., 0::2].to(torch.uint8) & 0x0F
    hi = nibs[..., 1::2].to(torch.uint8) & 0x0F
    return lo | (hi << 4)


def _unpack_nibbles(packed: torch.Tensor) -> torch.Tensor:
    """`[*, K/2] uint8 → [*, K] uint8 nibbles`. Inverse of `_pack_nibbles`.

    The result is always uint8, including when `packed` is int8.
    """
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F  # mask also clears sign-extension bits for int8 input
    out = torch.stack([lo, hi], dim=-1).to(torch.uint8)
    return out.view(*packed.shape[:-1], packed.shape[-1] * 2)


def _wscale_view_shape(N: int, K: int) -> tuple[int, ...]:  # noqa: N803
    """7-D shape that a contiguous `[N, K/16]` scale tensor is viewed as before the pack permute."""
    _assert_tileable(N, K)
    return (N // _WARP_N, 1, 4, 4, 8, K // _INSN_K, 4)


def pack_nunchaku_wscales_fp4(scales_row: torch.Tensor) -> torch.Tensor:
    """Row-major `[K/16, N]` fp8 → nunchaku fragment `[K/16, N]` fp8."""
    KG, N = scales_row.shape
    K = KG * _GROUP
    s = scales_row.transpose(0, 1).contiguous()
    s = s.view(*_wscale_view_shape(N, K))
    s = s.permute(0, 5, 1, 4, 3, 2, 6).contiguous()
    return s.view(-1, N)


def unpack_nunchaku_wscales_fp4(scales_nun: torch.Tensor) -> torch.Tensor:
    """nunchaku fragment `[K/16, N]` fp8 → row-major `[K/16, N]` fp8.

    Accepts non-contiguous input and applies the same tile-divisibility
    checks as `pack_nunchaku_wscales_fp4`.
    """
    KG, N = scales_nun.shape
    K = KG * _GROUP
    _assert_tileable(N, K)
    # `.view` needs a contiguous buffer; a strided input would raise here.
    s = scales_nun.contiguous().view(N // _WARP_N, K // _INSN_K, 1, 8, 4, 4, 4)
    # Inverse of permute (0, 5, 1, 4, 3, 2, 6) is (0, 2, 5, 4, 3, 1, 6).
    s = s.permute(0, 2, 5, 4, 3, 1, 6).contiguous()
    s = s.view(N, K // _GROUP)
    return s.transpose(0, 1).contiguous()


def pack_nunchaku_qweight_fp4(nibs_row: torch.Tensor) -> torch.Tensor:
    """`[N, K] uint8 nibbles → [N, K/2] nunchaku fragment int8`."""
    N, K = nibs_row.shape
    _assert_tileable(N, K)
    n_tiles, k_tiles = N // _WARP_N, K // _INSN_K
    w = nibs_row.to(torch.int32)
    w = w.reshape(n_tiles, 8, 2, 8, 1, k_tiles, 1, 2, 4, 8)
    w = w.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous()
    w = w & 0xF
    # Fold the trailing 8 nibbles into one int32, nibble j at bit offset 4*j.
    shift = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device)
    w = (w << shift).sum(dim=-1, dtype=torch.int32)
    return w.view(dtype=torch.int8).view(N, -1).contiguous()


def unpack_nunchaku_qweight_fp4(q_nun: torch.Tensor) -> torch.Tensor:
    """`[N, K/2] nunchaku fragment int8 → [N, K/2] uint8` (low nibble = even k)."""
    N, K2 = q_nun.shape
    K = K2 * 2
    _assert_tileable(N, K)
    n_tiles, k_tiles = N // _WARP_N, K // _INSN_K
    q_int = q_nun.contiguous().view(dtype=torch.int32)
    q_int = q_int.reshape(n_tiles, k_tiles, 1, 8, 8, 4, 2, 2, 1)
    # Split each int32 back into its 8 nibbles (new trailing axis).
    shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=q_int.device)
    nibs = ((q_int.unsqueeze(-1) >> shifts) & 0xF).to(torch.uint8)
    # Inverse of permute (0, 5, 6, 1, 3, 8, 2, 7, 4, 9) is (0, 3, 6, 4, 8, 1, 2, 7, 5, 9).
    nibs = nibs.permute(0, 3, 6, 4, 8, 1, 2, 7, 5, 9).contiguous()
    nibs = nibs.view(N, K)
    return _pack_nibbles(nibs)


__all__ = [
    "pack_nunchaku_qweight_fp4",
    "unpack_nunchaku_qweight_fp4",
    "pack_nunchaku_wscales_fp4",
    "unpack_nunchaku_wscales_fp4",
]