From 9936ea8d5f0da71c6c72c6fe09a2c255220f5288 Mon Sep 17 00:00:00 2001 From: ultism Date: Thu, 21 May 2026 01:22:00 +0800 Subject: [PATCH] [Feature][Quantization] Add SVDQuant W4A4 (nunchaku backend) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add SVDQuant (https://arxiv.org/abs/2411.05007), the practical 4-bit-weight + 4-bit-activation quantization with low-rank SVD residual that drives most modern diffusion-transformer quantization. Layout: canonical row-major NVFP4 on disk. The nunchaku consumer-GPU kernel (Turing through consumer Blackwell SM_120) repacks once at load time into its PTX-MMA fragment layout in `SVDQuantLinearMethod.process_weights_after_loading`. Pack/unpack pair is bit-exact, verified against nunchaku.ops.gemm.svdq_gemm_w4a4_cuda. Apply path: svdq_quantize_w4a4_act_fuse_lora_cuda → svdq_gemm_w4a4_cuda; scalar alpha and act_unsigned are plumbed through. Hardware gate in utils/svdquant_dispatch.py::assert_svdquant_supported. Hopper SM_90 is unsupported by design — nunchaku targets older PTX-MMA shapes that the SM_90 tensor unit does not implement. Datacenter Blackwell SM_100/103 (B200/GB300) is out of scope here; that path is planned in FlashInfer so SGLang can share the primitive. vllm.utils.nunchaku provides lazy import wrappers so non-CUDA / non-consumer hosts never pull in the nunchaku package at module load. Co-authored-by: Claude (Anthropic) Signed-off-by: ultism --- tests/quantization/test_svdquant.py | 208 +++++++++ .../layers/quantization/__init__.py | 3 + .../layers/quantization/svdquant.py | 436 ++++++++++++++++++ .../quantization/utils/svdquant_dispatch.py | 72 +++ .../utils/svdquant_nvfp4_layout.py | 126 +++++ vllm/utils/nunchaku.py | 136 ++++++ 6 files changed, 981 insertions(+) create mode 100644 tests/quantization/test_svdquant.py create mode 100644 vllm/model_executor/layers/quantization/svdquant.py create mode 100644 vllm/model_executor/layers/quantization/utils/svdquant_dispatch.py create mode 100644 vllm/model_executor/layers/quantization/utils/svdquant_nvfp4_layout.py create mode 100644 vllm/utils/nunchaku.py diff --git a/tests/quantization/test_svdquant.py b/tests/quantization/test_svdquant.py new file mode 100644 index 000000000000..512a1fd7bd5c --- /dev/null +++ b/tests/quantization/test_svdquant.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Smoke tests for the SVDQuant quantization plugin. + +Real W4A4 numerics live on top of an actual quantized checkpoint and +require a CUDA capability that the kernel backend supports. These +tests cover the boundary that vLLM owns: the registry wiring, the +config / linear method shape, and the hardware-keyed backend +selection. +""" + +import pytest +import torch + +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.layers.quantization import ( + QUANTIZATION_METHODS, + get_quantization_config, +) +from vllm.model_executor.layers.quantization.svdquant import ( + SVDQuantConfig, + SVDQuantLinearMethod, +) +from vllm.model_executor.layers.quantization.utils.svdquant_dispatch import ( + assert_svdquant_supported, +) +from vllm.platforms import current_platform +from vllm.platforms.interface import DeviceCapability +from vllm.utils.nunchaku import has_nunchaku_w4a4 + + +def test_svdquant_is_registered() -> None: + assert "svdquant" in QUANTIZATION_METHODS + cls = get_quantization_config("svdquant") + assert cls is SVDQuantConfig + assert cls.get_name() == "svdquant" + + +def test_config_from_dict_int4() -> None: + cfg = SVDQuantConfig.from_config( + {"rank": 32, "precision": "int4", "act_unsigned": False} + ) + assert cfg.rank == 32 + assert cfg.precision == "int4" + assert cfg.group_size == 64 + assert cfg.act_unsigned is False + assert cfg.modules_to_not_convert == [] + + +def test_config_from_dict_nvfp4() -> None: + cfg = SVDQuantConfig.from_config( + { + "rank": 64, + "precision": "nvfp4", + "modules_to_not_convert": ["embedder", "final_layer"], + } + ) + assert cfg.precision == "nvfp4" + assert cfg.group_size == 16 # NVFP4 tcgen05 scale block + assert cfg.modules_to_not_convert == ["embedder", "final_layer"] + + +def test_config_rejects_unknown_precision() -> None: + with pytest.raises(ValueError, match="precision"): + SVDQuantConfig(precision="fp8") # type: ignore[arg-type] + + +@pytest.mark.skipif( + not current_platform.is_cuda(), reason="hardware gate is CUDA-specific" +) +def test_hardware_gate_accepts_consumer_gpus() -> None: + if not has_nunchaku_w4a4(): + pytest.skip("nunchaku not installed") + major, _ = current_platform.get_device_capability() + if major == 9: + pytest.skip("Hopper is intentionally unsupported") + if major == 10: + pytest.skip("Datacenter Blackwell is out of scope (FlashInfer planned)") + # Turing/Ampere/Ada (SM_75-89) and consumer Blackwell SM_120 are + # accepted by the gate for int4. + assert_svdquant_supported("int4") + + +def test_hardware_gate_rejects_hopper(monkeypatch: pytest.MonkeyPatch) -> None: + """Hopper SM_90 must raise.""" + # Patch the class (not the instance): classmethods in Platform call + # cls.get_device_capability(), bypassing instance attribute lookup. + cls = type(current_platform) + monkeypatch.setattr(cls, "is_cuda", classmethod(lambda c: True)) + monkeypatch.setattr( + cls, + "get_device_capability", + classmethod(lambda c, *a, **k: DeviceCapability(9, 0)), + ) + with pytest.raises(RuntimeError, match="Hopper"): + assert_svdquant_supported("int4") + + +def test_hardware_gate_rejects_datacenter_blackwell( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """SM_100/103 is out of scope here (FlashInfer-planned); must raise.""" + cls = type(current_platform) + monkeypatch.setattr(cls, "is_cuda", classmethod(lambda c: True)) + monkeypatch.setattr( + cls, + "get_device_capability", + classmethod(lambda c, *a, **k: DeviceCapability(10, 0)), + ) + with pytest.raises(RuntimeError, match="FlashInfer"): + assert_svdquant_supported("nvfp4") + + +def test_hardware_gate_rejects_nvfp4_on_pre_blackwell( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """NVFP4 needs SM_100+ tensor units; SM_8x must raise cleanly.""" + if not has_nunchaku_w4a4(): + pytest.skip("nunchaku not installed") + cls = type(current_platform) + monkeypatch.setattr(cls, "is_cuda", classmethod(lambda c: True)) + monkeypatch.setattr( + cls, + "get_device_capability", + classmethod(lambda c, *a, **k: DeviceCapability(8, 9)), + ) + with pytest.raises(ValueError, match="NVFP4"): + assert_svdquant_supported("nvfp4") + + +@pytest.mark.skipif( + not (current_platform.is_cuda() and has_nunchaku_w4a4()), + reason="requires CUDA + nunchaku for create_weights smoke", +) +def test_linear_method_create_weights_int4() -> None: + """Validate the parameter layout without invoking the kernel. + + Only checks that `create_weights` populates the layer with + correctly-shaped, correctly-dtyped tensors. + """ + cfg = SVDQuantConfig(rank=32, precision="int4") + method = SVDQuantLinearMethod(cfg) + + # Mimic a 4096-in / 4096-out column-parallel layer with TP=1. + layer = torch.nn.Module() + method.create_weights( + layer, + input_size_per_partition=4096, + output_partition_sizes=[4096], + input_size=4096, + output_size=4096, + params_dtype=torch.bfloat16, + ) + + assert layer.qweight.shape == (4096, 4096 // 2) + assert layer.qweight.dtype == torch.int8 + assert layer.wscales.shape == (4096 // 64, 4096) + assert layer.wscales.dtype == torch.bfloat16 + assert layer.proj_down.shape == (4096, 32) + assert layer.proj_up.shape == (4096, 32) + assert layer.smooth_factor.shape == (4096,) + assert layer.wcscales is None + assert layer.wtscale is None + + +@pytest.mark.skipif( + not (current_platform.is_cuda() and has_nunchaku_w4a4()), + reason="requires CUDA + nunchaku for create_weights smoke", +) +def test_linear_method_create_weights_nvfp4_has_per_channel_scales() -> None: + cfg = SVDQuantConfig(rank=32, precision="nvfp4") + try: + assert_svdquant_supported("nvfp4") + except (RuntimeError, ValueError, ImportError) as exc: + pytest.skip(f"nvfp4 unsupported on this box: {exc}") + method = SVDQuantLinearMethod(cfg) + layer = torch.nn.Module() + method.create_weights( + layer, + input_size_per_partition=2048, + output_partition_sizes=[2048], + input_size=2048, + output_size=2048, + params_dtype=torch.bfloat16, + ) + assert layer.wscales.dtype == torch.float8_e4m3fn + assert layer.wcscales is not None + assert layer.wcscales.shape == (2048,) + assert layer.wtscale is not None + assert layer.wtscale.shape == (1,) + + +def test_get_quant_method_skips_listed_modules() -> None: + cfg = SVDQuantConfig(modules_to_not_convert=["embedder"]) + if not has_nunchaku_w4a4(): + # SVDQuantLinearMethod ctor would call assert_svdquant_supported() + # and raise; in that case we can only check the skip path. + pytest.skip("nunchaku not installed") + fake_layer = torch.nn.Linear(8, 8) + # Subclass to satisfy isinstance(layer, LinearBase). + fake_layer.__class__ = type( + "FakeLinear", (torch.nn.Linear, LinearBase), {} + ) + + from vllm.model_executor.layers.linear import UnquantizedLinearMethod + + method = cfg.get_quant_method(fake_layer, "model.embedder.proj") + assert isinstance(method, UnquantizedLinearMethod) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 11489feb9d74..4208487a560d 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -34,6 +34,7 @@ "mxfp4", "gpt_oss_mxfp4", "deepseek_v4_fp8", + "svdquant", "cpu_awq", "online", # Below are online quant shorthand names (see vllm.config.quantization). @@ -139,6 +140,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .moe_wna16 import MoeWNA16Config from .mxfp4 import GptOssMxfp4Config, Mxfp4Config from .online.base import OnlineQuantizationConfig + from .svdquant import SVDQuantConfig from .torchao import TorchAOConfig method_to_config: dict[str, type[QuantizationConfig]] = { @@ -166,6 +168,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "mxfp4": Mxfp4Config, "gpt_oss_mxfp4": GptOssMxfp4Config, "deepseek_v4_fp8": DeepseekV4FP8Config, + "svdquant": SVDQuantConfig, "cpu_awq": CPUAWQConfig, "humming": HummingConfig, "online": OnlineQuantizationConfig, diff --git a/vllm/model_executor/layers/quantization/svdquant.py b/vllm/model_executor/layers/quantization/svdquant.py new file mode 100644 index 000000000000..66e3ba0f19eb --- /dev/null +++ b/vllm/model_executor/layers/quantization/svdquant.py @@ -0,0 +1,436 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SVDQuant W4A4 quantization with low-rank correction. + +SVDQuant (https://arxiv.org/abs/2411.05007) is a 4-bit weight, 4-bit +activation quantization scheme paired with a low-rank residual that +absorbs the quantization error. It is the dominant practical +quantization method for diffusion transformers, delivering >2x +speedup vs BF16 with minimal quality loss. + +The in-tree GEMM path uses the external `nunchaku` package, covering +consumer NVIDIA GPUs (Turing SM_75 through consumer Blackwell +SM_120). Hopper SM_90 is unsupported; datacenter Blackwell SM_100/103 +is out of scope here (planned via FlashInfer). + +Diffusion-specific weight key remapping (e.g. diffusers naming +conventions) is not handled here; downstream pipelines remap before +loading. Checkpoints are expected to already store gated-activation +halves in `[gate; hidden]` order — produce that ordering at +quantization time, not at runtime. +""" + +from typing import TYPE_CHECKING, Any + +import torch +from torch.nn import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped, +) +from vllm.model_executor.layers.quantization.utils.svdquant_dispatch import ( + SVDQuantPrecision, + assert_svdquant_supported, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.utils.nunchaku import ( + svdq_gemm_w4a4, + svdq_quantize_w4a4_act_fuse_lora, +) + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods + +logger = init_logger(__name__) + +# Group sizes are dictated by the kernel's scaled-MMA tile: +# * NVFP4 uses tcgen05's 16-element scale block. +# * INT4 uses Nunchaku's 64-element block. +_GROUP_SIZE_BY_PRECISION: dict[str, int] = {"int4": 64, "nvfp4": 16} + + +class SVDQuantConfig(QuantizationConfig): + """Configuration for SVDQuant W4A4 quantization. + + Parameters mirror what's on disk in a Nunchaku-produced checkpoint: + + Args: + rank: SVD low-rank correction dimension. Typical values are + 16, 32, or 64; the checkpoint dictates the value. + precision: 4-bit format, either "int4" or "nvfp4". + act_unsigned: Whether activations are quantized as unsigned + (saves the sign bit at a small accuracy cost). Per + checkpoint config. + modules_to_not_convert: Layer names (or substring patterns) + that should keep their unquantized weight, e.g. embedders + and adaLN-modulation projections in diffusion models. + """ + + def __init__( + self, + rank: int = 32, + precision: SVDQuantPrecision = "int4", + act_unsigned: bool = False, + modules_to_not_convert: list[str] | None = None, + ) -> None: + super().__init__() + if precision not in _GROUP_SIZE_BY_PRECISION: + raise ValueError( + f"SVDQuant precision must be one of " + f"{set(_GROUP_SIZE_BY_PRECISION)}; got {precision!r}" + ) + self.rank = rank + self.precision = precision + self.group_size = _GROUP_SIZE_BY_PRECISION[precision] + self.act_unsigned = act_unsigned + self.modules_to_not_convert = modules_to_not_convert or [] + + def __repr__(self) -> str: + return ( + f"SVDQuantConfig(rank={self.rank}, precision={self.precision!r}, " + f"act_unsigned={self.act_unsigned})" + ) + + @classmethod + def get_name(cls) -> "QuantizationMethods": + return "svdquant" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + # SM_75 (Turing) is the floor; the dispatcher rejects SM_90 and + # routes SM_100+ separately. + return 75 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["quantization_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "SVDQuantConfig": + return cls( + rank=config.get("rank", 32), + precision=config.get("precision", "int4"), + act_unsigned=config.get("act_unsigned", False), + modules_to_not_convert=config.get("modules_to_not_convert"), + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + if not isinstance(layer, LinearBase): + return None + if is_layer_skipped( + prefix, + self.modules_to_not_convert, + self.packed_modules_mapping, + skip_with_substr=True, + ): + return UnquantizedLinearMethod() + return SVDQuantLinearMethod(self) + + +class SVDQuantLinearMethod(LinearMethodBase): + """Linear method for SVDQuant W4A4. + + The same parameter layout serves both the int4 and nvfp4 paths; + only the dtypes of `wscales` and the LoRA matrices differ. The + active platform is checked at `__init__` time and an unsupported + GPU raises here, before any weights are allocated. + """ + + _hardware_logged = False + + def __init__(self, quant_config: SVDQuantConfig) -> None: + self.quant_config = quant_config + assert_svdquant_supported(quant_config.precision) + if not SVDQuantLinearMethod._hardware_logged: + logger.info( + "Using nunchaku backend for SVDQuantLinearMethod (precision=%s)", + quant_config.precision, + ) + SVDQuantLinearMethod._hardware_logged = True + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del extra_weight_attrs # weight_loader is set explicitly per-param. + output_size_per_partition = sum(output_partition_sizes) + + config = self.quant_config + rank = config.rank + group_size = config.group_size + precision = config.precision + + # The LoRA matrices and the smooth factor must be in the same + # dtype as the kernel's accumulator. Nunchaku's nvfp4 path + # locks this to bf16 regardless of the model's params_dtype; + # the int4 path inherits params_dtype. + lora_dtype = torch.bfloat16 if precision == "nvfp4" else params_dtype + + wscales_dtype = ( + torch.float8_e4m3fn if precision == "nvfp4" else params_dtype + ) + + # qweight: 4-bit weights packed two-per-byte along the input + # axis. Shape (out_per_partition, in_per_partition // 2). + qweight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + _set_attrs( + qweight, + input_dim=1, + output_dim=0, + weight_loader=default_weight_loader, + ) + + # wscales: per-(group_size) input-column scale, + # shape (in_per_partition // group_size, out_per_partition). + wscales = Parameter( + torch.empty( + input_size_per_partition // group_size, + output_size_per_partition, + dtype=wscales_dtype, + ), + requires_grad=False, + ) + _set_attrs( + wscales, + input_dim=0, + output_dim=1, + weight_loader=default_weight_loader, + ) + + # SVD low-rank correction matrices. + proj_down = Parameter( + torch.empty(input_size_per_partition, rank, dtype=lora_dtype), + requires_grad=False, + ) + _set_attrs( + proj_down, + input_dim=0, + output_dim=1, + weight_loader=default_weight_loader, + ) + + proj_up = Parameter( + torch.empty(output_size_per_partition, rank, dtype=lora_dtype), + requires_grad=False, + ) + _set_attrs( + proj_up, + input_dim=1, + output_dim=0, + weight_loader=default_weight_loader, + ) + + # Smooth-quant factors. Live on the input axis: replicated for + # column-parallel layers, sharded for row-parallel. + smooth_factor = Parameter( + torch.empty(input_size_per_partition, dtype=lora_dtype), + requires_grad=False, + ) + _set_attrs( + smooth_factor, + input_dim=0, + weight_loader=default_weight_loader, + ) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("wscales", wscales) + layer.register_parameter("proj_down", proj_down) + layer.register_parameter("proj_up", proj_up) + layer.register_parameter("smooth_factor", smooth_factor) + + if precision == "nvfp4": + # Per-output-channel BF16 scale; sharded with the output dim. + wcscales = Parameter( + torch.ones(output_size_per_partition, dtype=lora_dtype), + requires_grad=False, + ) + _set_attrs( + wcscales, + output_dim=0, + weight_loader=default_weight_loader, + ) + # Per-tensor global scale (shape (1,) on disk). + wtscale = Parameter( + torch.ones(1, dtype=lora_dtype), + requires_grad=False, + ) + _set_attrs(wtscale, weight_loader=default_weight_loader) + layer.register_parameter("wcscales", wcscales) + layer.register_parameter("wtscale", wtscale) + else: + # Keep the attributes present so apply() can branch + # uniformly without `hasattr` checks. + layer.wcscales = None + layer.wtscale = None + + # Stash for apply(). + layer.in_features = input_size + layer.out_features = output_size + layer.out_features_per_partition = output_size_per_partition + layer.precision = precision + layer.act_unsigned = config.act_unsigned + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Post-load weight prep. + + On-disk format is canonical row-major NVFP4; the nunchaku kernel + wants a PTX-MMA fragment-permuted layout, so for NVFP4 we repack + here once. Bit-preserving pack/unpack pair (see + `utils/svdquant_nvfp4_layout.py`); round-trip verified bit-exact. + + Also caches the kernel's `alpha` from the per-tensor `wtscale`. + Do NOT collapse `wcscales` into a scalar `alpha` — the kernel + applies them as `(accumulator * alpha) * wcscales`, and + conflating the two double-counts the per-channel factors. + + All parameters are produced by our quantization pipeline and + must be loaded by the time we get here; a meta tensor at this + point is a checkpoint bug, not a missing-shard case to paper + over. + """ + if layer.precision == "nvfp4": + self._pack_nvfp4_to_nunchaku_fragment(layer) + + alpha: float = 1.0 + wtscale = getattr(layer, "wtscale", None) + if wtscale is not None: + value = float(wtscale.detach().cpu().item()) + if abs(value - 1.0) > 1e-6: + alpha = value + layer._svdquant_alpha = alpha + + @staticmethod + def _pack_nvfp4_to_nunchaku_fragment(layer: torch.nn.Module) -> None: + """Repack row-major NVFP4 params in-place to nunchaku fragment layout. + + On-disk (canonical row-major): + * qweight : [N, K/2] int8/uint8 (FP4 nibbles, low = even-k) + * wscales : [K/16, N] fp8_e4m3fn + * proj_up : [N, R] + * proj_down : [K, R] + + After repack (nunchaku PTX-MMA fragment): + * qweight : [N, K/2] int8 (permuted into MMA fragment) + * wscales : [K/16, N] fp8 (permuted into MMA fragment) + * proj_up : [N, R] (permuted into MMA fragment) + * proj_down : [K, R] (permuted into MMA fragment) + """ + # Lazy imports: nunchaku is a soft dep on non-consumer hardware, + # and the layout helpers pull in torch ops we only need here. + from nunchaku.lora.flux.nunchaku_converter import pack_lowrank_weight + + from vllm.model_executor.layers.quantization.utils.svdquant_nvfp4_layout import ( # noqa: E501 + _unpack_nibbles, + pack_nunchaku_qweight_fp4, + pack_nunchaku_wscales_fp4, + ) + + device = layer.qweight.device + + # qweight: stored as [N, K/2] packed-nibble bytes (low = even-k). + # `pack_nunchaku_qweight_fp4` expects [N, K] one-nibble-per-byte — + # unpack to that form first, then pack to nunchaku fragment. + qw_rm_packed = layer.qweight.data.view(torch.uint8) # [N, K/2] + qw_rm_nibs = _unpack_nibbles(qw_rm_packed) # [N, K] + layer.qweight.data = pack_nunchaku_qweight_fp4(qw_rm_nibs).to(device) + + # wscales: pack pair operates in fp8_e4m3fn. + layer.wscales.data = pack_nunchaku_wscales_fp4(layer.wscales.data).to(device) + + # proj_up: row-major [N, R] → nunchaku frag [N, R]. down=False. + layer.proj_up.data = pack_lowrank_weight(layer.proj_up.data, down=False).to(device) + + # proj_down: canonical row-major is [K, R]; nunchaku's pack expects + # [R, K] (transpose-quirk on the down=True path). Transpose then pack; + # output is fragment [K, R]. + pd = layer.proj_down.data + pd_rk = pd.transpose(0, 1).contiguous() + layer.proj_down.data = pack_lowrank_weight(pd_rk, down=True).to(device) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + orig_shape = x.shape + x_2d = x.reshape(-1, orig_shape[-1]) + + is_fp4 = layer.precision == "nvfp4" + out_features = layer.out_features_per_partition + + quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora( + x_2d, + lora_down=layer.proj_down, + smooth=layer.smooth_factor, + fp4=is_fp4, + pad_size=256, + ) + + # The quantize kernel may pad the batch dim up to a multiple + # of `pad_size`; the GEMM consumes the padded shape, then we + # trim back to the real batch size below. + out_2d = torch.empty( + quantized_x.shape[0], + out_features, + dtype=layer.proj_up.dtype, + device=x_2d.device, + ) + + svdq_gemm_w4a4( + act=quantized_x, + wgt=layer.qweight, + out=out_2d, + ascales=ascales, + wscales=layer.wscales, + lora_act_in=lora_act_out, + lora_up=layer.proj_up, + bias=bias, + fp4=is_fp4, + alpha=getattr(layer, "_svdquant_alpha", 1.0), + wcscales=layer.wcscales, + act_unsigned=layer.act_unsigned, + ) + + actual_batch = x_2d.shape[0] + if out_2d.shape[0] > actual_batch: + out_2d = out_2d[:actual_batch] + + return out_2d.reshape(*orig_shape[:-1], out_features) + + +def _set_attrs(param: torch.nn.Parameter, **attrs: Any) -> None: + for key, value in attrs.items(): + setattr(param, key, value) + + +__all__ = ["SVDQuantConfig", "SVDQuantLinearMethod"] diff --git a/vllm/model_executor/layers/quantization/utils/svdquant_dispatch.py b/vllm/model_executor/layers/quantization/utils/svdquant_dispatch.py new file mode 100644 index 000000000000..362d2d883844 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/svdquant_dispatch.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Hardware gate for SVDQuant W4A4. + +The only in-tree backend is the `nunchaku` pip package, covering +consumer NVIDIA GPUs (Turing SM_75 through consumer Blackwell SM_120). +Hopper SM_90 is intentionally unsupported: the kernel families +nunchaku targets are PTX-MMA on older arches, and SM_90's tensor unit +shape has no validated SVDQuant kernel. + +Datacenter Blackwell SM_100/103 (B200/GB300) is out of scope here — +the planned datacenter path is to be hosted in FlashInfer so SGLang +and vLLM can share the same primitive. +""" + +from typing import Literal + +from vllm.platforms import current_platform +from vllm.utils.nunchaku import has_nunchaku_w4a4 + +SVDQuantPrecision = Literal["int4", "nvfp4"] + + +def assert_svdquant_supported(precision: SVDQuantPrecision) -> None: + """Raise if the active platform cannot run SVDQuant at this precision.""" + if not current_platform.is_cuda(): + raise RuntimeError( + f"SVDQuant has no available backend on platform " + f"{current_platform.device_name!r}. CUDA + nunchaku required." + ) + + cap = current_platform.get_device_capability() + sm = f"SM_{cap.to_int()}" if cap is not None else "" + + if current_platform.is_device_capability_family(90): + raise RuntimeError( + "SVDQuant W4A4 is not supported on Hopper (SM_90). Use a " + "consumer GPU (SM_75–SM_89, SM_120) with nunchaku, or wait " + "for the datacenter Blackwell (SM_100/103) path planned in " + "FlashInfer." + ) + + if current_platform.is_device_capability_family(100): + raise RuntimeError( + f"SVDQuant on {sm} (B200/GB300) is not supported in-tree; " + "the datacenter path is planned in FlashInfer." + ) + + if not current_platform.has_device_capability((7, 5)): + raise RuntimeError( + f"Unsupported CUDA compute capability for SVDQuant: {sm}" + ) + + # nvfp4 needs SM_100+ tensor units; pre-Blackwell consumer cards + # (Turing/Ampere/Ada) cannot run it. + if precision == "nvfp4" and not current_platform.has_device_capability(100): + raise ValueError( + f"NVFP4 SVDQuant requires SM_100+ or SM_120; got {sm}. " + f"Use precision='int4'." + ) + + if not has_nunchaku_w4a4(): + # The PyPI `nunchaku` is an unrelated Bayesian library; the + # SVDQuant kernels ship as GitHub release wheels only. + raise ImportError( + f"SVDQuant on {sm} requires nunchaku-ai's W4A4 wheels from " + "https://github.com/nunchaku-ai/nunchaku/releases " + "(not `pip install nunchaku`, which is a different project)." + ) + + +__all__ = ["SVDQuantPrecision", "assert_svdquant_supported"] diff --git a/vllm/model_executor/layers/quantization/utils/svdquant_nvfp4_layout.py b/vllm/model_executor/layers/quantization/utils/svdquant_nvfp4_layout.py new file mode 100644 index 000000000000..732393a6bd39 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/svdquant_nvfp4_layout.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""nunchaku NVFP4 SVDQuant fragment-layout adapters. + +Bridge between the canonical row-major SVDQuant NVFP4 on-disk format +and nunchaku's PTX-MMA-tile fragment layout. Bit-preserving pure +view+permute chain — no quant/dequant. + +Used in two directions: + +* Checkpoint conversion (vllm-omni converter): a nunchaku-published + checkpoint is unpacked to canonical row-major for writing to disk. +* Load-time pack (vLLM `SVDQuantLinearMethod.process_weights_after_loading`): + for the nunchaku kernel backend, repack the row-major on-disk tensors + into fragment layout before the kernel sees them. + +Verified against `svdq_gemm_w4a4_cuda(fp4=True)`: round-trip is +bit-exact, and half-swap via unpack→swap→pack reproduces the permuted +nunchaku output bit-exactly. Workbench source: +SVDQuant kernel `baseline/kernels/_nvfp4.py`. + +Pair semantics: + * `unpack_nunchaku_wscales_fp4(s_nun)` `[K/16, N] fragment → row-major` + * `pack_nunchaku_wscales_fp4(s_row)` `[K/16, N] row-major → fragment` + * `unpack_nunchaku_qweight_fp4(q_nun)` `[N, K/2] fragment → row-major uint8 nibble bytes` + * `pack_nunchaku_qweight_fp4(nibs_row)` `[N, K] nibbles → [N, K/2] fragment int8` + +These plus `nunchaku.lora.flux.nunchaku_converter.{pack,unpack}_lowrank_weight` +cover every fragment-layout param needed for SVDQuant W4A4 NVFP4 +half-swap (qweight, wscales, proj_up). + +Constants assume `NunchakuWeightPacker(bits=4, warp_n=128)`: + wscales: s_pack_size=4, num_s_lanes=32, num_s_packs=1, insn_k/group=4 + qweight: num_n_packs=8, n_pack_size=2, num_n_lanes=8, reg_n=1, + num_k_packs=1, k_pack_size=2, num_k_lanes=4, reg_k=8 +""" +from __future__ import annotations + +import torch + +_WARP_N = 128 +_INSN_K = 64 +_GROUP = 16 + + +def _pack_nibbles(nibs: torch.Tensor) -> torch.Tensor: + """`[*, K] uint8 nibbles → [*, K/2] uint8`. Low nibble = even k.""" + assert nibs.shape[-1] % 2 == 0 + lo = nibs[..., 0::2] + hi = nibs[..., 1::2] + return (lo | (hi << 4)).to(torch.uint8) + + +def _unpack_nibbles(packed: torch.Tensor) -> torch.Tensor: + """`[*, K/2] uint8 → [*, K] uint8 nibbles`. Inverse of `_pack_nibbles`.""" + lo = packed & 0x0F + hi = (packed >> 4) & 0x0F + out = torch.stack([lo, hi], dim=-1) + return out.view(*packed.shape[:-1], packed.shape[-1] * 2) + + +def _wscale_view_shape(N: int, K: int) -> tuple[int, ...]: + assert N % _WARP_N == 0, f"N ({N}) must be multiple of {_WARP_N}" + assert K % _INSN_K == 0, f"K ({K}) must be multiple of {_INSN_K}" + return (N // _WARP_N, 1, 4, 4, 8, K // _INSN_K, 4) + + +def pack_nunchaku_wscales_fp4(scales_row: torch.Tensor) -> torch.Tensor: + """Row-major `[K/16, N]` fp8 → nunchaku fragment `[K/16, N]` fp8.""" + KG, N = scales_row.shape + K = KG * _GROUP + s = scales_row.transpose(0, 1).contiguous() + s = s.view(*_wscale_view_shape(N, K)) + s = s.permute(0, 5, 1, 4, 3, 2, 6).contiguous() + return s.view(-1, N) + + +def unpack_nunchaku_wscales_fp4(scales_nun: torch.Tensor) -> torch.Tensor: + """nunchaku fragment `[K/16, N]` fp8 → row-major `[K/16, N]` fp8.""" + KG, N = scales_nun.shape + K = KG * _GROUP + s = scales_nun.view(N // _WARP_N, K // _INSN_K, 1, 8, 4, 4, 4) + # Inverse of permute (0, 5, 1, 4, 3, 2, 6) is (0, 2, 5, 4, 3, 1, 6). + s = s.permute(0, 2, 5, 4, 3, 1, 6).contiguous() + s = s.view(N, K // _GROUP) + return s.transpose(0, 1).contiguous() + + +def pack_nunchaku_qweight_fp4(nibs_row: torch.Tensor) -> torch.Tensor: + """`[N, K] uint8 nibbles → [N, K/2] nunchaku fragment int8`.""" + N, K = nibs_row.shape + assert N % _WARP_N == 0, f"N ({N}) must be multiple of {_WARP_N}" + assert K % _INSN_K == 0, f"K ({K}) must be multiple of {_INSN_K}" + n_tiles, k_tiles = N // _WARP_N, K // _INSN_K + w = nibs_row.to(torch.int32) + w = w.reshape(n_tiles, 8, 2, 8, 1, k_tiles, 1, 2, 4, 8) + w = w.permute(0, 5, 6, 1, 3, 8, 2, 7, 4, 9).contiguous() + w = w & 0xF + shift = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device) + w = (w << shift).sum(dim=-1, dtype=torch.int32) + return w.view(dtype=torch.int8).view(N, -1).contiguous() + + +def unpack_nunchaku_qweight_fp4(q_nun: torch.Tensor) -> torch.Tensor: + """`[N, K/2] nunchaku fragment int8 → [N, K/2] uint8` (low nibble = even k).""" + N, K2 = q_nun.shape + K = K2 * 2 + assert N % _WARP_N == 0 + assert K % _INSN_K == 0 + n_tiles, k_tiles = N // _WARP_N, K // _INSN_K + q_int = q_nun.contiguous().view(dtype=torch.int32) + q_int = q_int.reshape(n_tiles, k_tiles, 1, 8, 8, 4, 2, 2, 1) + shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=q_int.device) + nibs = ((q_int.unsqueeze(-1) >> shifts) & 0xF).to(torch.uint8) + # Inverse of permute (0, 5, 6, 1, 3, 8, 2, 7, 4, 9) is (0, 3, 6, 4, 8, 1, 2, 7, 5, 9). + nibs = nibs.permute(0, 3, 6, 4, 8, 1, 2, 7, 5, 9).contiguous() + nibs = nibs.view(N, K) + return _pack_nibbles(nibs) + + +__all__ = [ + "pack_nunchaku_qweight_fp4", + "unpack_nunchaku_qweight_fp4", + "pack_nunchaku_wscales_fp4", + "unpack_nunchaku_wscales_fp4", +] diff --git a/vllm/utils/nunchaku.py b/vllm/utils/nunchaku.py new file mode 100644 index 000000000000..c0fa040b4808 --- /dev/null +++ b/vllm/utils/nunchaku.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for the optional `nunchaku` dependency. + +`nunchaku` ships SVDQuant W4A4 / W4A16 CUDA kernels for diffusion +transformers on consumer NVIDIA GPUs (Turing through consumer Blackwell). +This module collects the lazy availability checks and lazy-imported call +wrappers so the rest of vLLM never imports `nunchaku` at module load +time. + +Mirrors the structure of `vllm/utils/flashinfer.py` — `has_*` for +capability detection, `_lazy_import_wrapper` for the call boundary. +""" + +import functools +import importlib +import importlib.util +from collections.abc import Callable +from typing import Any, NoReturn + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@functools.cache +def has_nunchaku() -> bool: + """Return True if the `nunchaku` package is importable.""" + if importlib.util.find_spec("nunchaku") is None: + logger.debug_once("Nunchaku unavailable: package not installed") + return False + return True + + +def _get_submodule(module_name: str) -> Any | None: + """Safely import a submodule, or return None if unavailable.""" + try: + return importlib.import_module(module_name) + except (ImportError, ModuleNotFoundError): + return None + + +@functools.cache +def has_nunchaku_w4a4() -> bool: + """Return True if Nunchaku's W4A4 GEMM + activation-quantize ops exist. + + Both ops are required for SVDQuant: the activation-side fused + quantize+LoRA preprocessing and the main W4A4 scaled GEMM. + """ + if not has_nunchaku(): + return False + required = [ + ("nunchaku.ops.gemm", "svdq_gemm_w4a4_cuda"), + ("nunchaku.ops.quantize", "svdq_quantize_w4a4_act_fuse_lora_cuda"), + ] + for module_name, attr_name in required: + mod = _get_submodule(module_name) + if mod is None or not hasattr(mod, attr_name): + logger.debug_once( + "Nunchaku W4A4 unavailable: missing %s.%s", module_name, attr_name + ) + return False + return True + + +@functools.cache +def has_nunchaku_w4a16() -> bool: + """Return True if Nunchaku's W4A16 AWQ GEMV op exists. + + Used for batch-1 / decode-style paths where activations stay in + fp16/bf16 and only the weight is 4-bit. + """ + if not has_nunchaku(): + return False + mod = _get_submodule("nunchaku.ops.gemv") + return mod is not None and hasattr(mod, "awq_gemv_w4a16_cuda") + + +def _missing(*_: Any, **__: Any) -> NoReturn: + # The PyPI `nunchaku` package is an unrelated Bayesian library; the + # SVDQuant kernels are published only on the nunchaku-ai GitHub + # releases page. + raise RuntimeError( + "Nunchaku is not installed. SVDQuant requires the nunchaku-ai " + "wheels from https://github.com/nunchaku-ai/nunchaku/releases " + "(do NOT `pip install nunchaku` — that pulls an unrelated PyPI " + "package). Source: https://github.com/nunchaku-ai/nunchaku" + ) + + +def _lazy_import_wrapper( + module_name: str, attr_name: str, fallback_fn: Callable[..., Any] = _missing +): + """Build a lazy wrapper around a single nunchaku function. + + The first call resolves the underlying op via `importlib`; subsequent + calls hit the cached resolution. The wrapper raises a clear error + if the op was never resolved. + """ + + @functools.cache + def _get_impl(): + if not has_nunchaku(): + return None + mod = _get_submodule(module_name) + return getattr(mod, attr_name, None) if mod else None + + def wrapper(*args, **kwargs): + impl = _get_impl() + if impl is None: + return fallback_fn(*args, **kwargs) + return impl(*args, **kwargs) + + wrapper.__name__ = attr_name + wrapper.__qualname__ = f"nunchaku::{attr_name}" + return wrapper + + +# Public lazy-call surface. Each wrapper has the same signature as the +# underlying nunchaku op (we don't re-document the signature here; see +# the upstream nunchaku source). +svdq_gemm_w4a4 = _lazy_import_wrapper("nunchaku.ops.gemm", "svdq_gemm_w4a4_cuda") +svdq_quantize_w4a4_act_fuse_lora = _lazy_import_wrapper( + "nunchaku.ops.quantize", "svdq_quantize_w4a4_act_fuse_lora_cuda" +) +awq_gemv_w4a16 = _lazy_import_wrapper("nunchaku.ops.gemv", "awq_gemv_w4a16_cuda") + + +__all__ = [ + "has_nunchaku", + "has_nunchaku_w4a4", + "has_nunchaku_w4a16", + "svdq_gemm_w4a4", + "svdq_quantize_w4a4_act_fuse_lora", + "awq_gemv_w4a16", +]