From ee38784c7c306ddc0ef0fa823e82b5af0f903e32 Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Tue, 5 May 2026 19:07:28 +0000 Subject: [PATCH 1/5] [ModelOpt] Add NVFP4 W4A16 (4-bit weights, fp16/bf16 acts) support Adds first-class loading for ModelOpt-exported NVFP4_W4A16 checkpoints (`quant_algo: "NVFP4_W4A16"`). Today vLLM can only consume such ckpts after rewriting them into the compressed-tensors format on disk; this change lets the ModelOpt loader feed the FP4 Marlin GEMM directly, without an on-disk conversion. Plumbing (no new config class): - `QUANT_ALGOS`: register `"NVFP4_W4A16"`. Existing `ModelOptNvFp4Config.override_quantization_method` substring check (`"NVFP4" in algo or "FP4" in algo`) already routes it to the same config class as `"NVFP4"` -- mirrors the established FP8 pattern in this file where one ModelOptFp8Config dispatches to three FP8 LinearMethods based on the algo string. - `ModelOptNvFp4Config.__init__` now takes `quant_method` and selects `self.LinearMethodCls` per algo: NVFP4 -> ModelOptNvFp4LinearMethod (existing W4A4) NVFP4_W4A16 -> ModelOptNvFp4W4A16LinearMethod (new) - `_from_config` threads `quant_method` to the constructor. New class `ModelOptNvFp4W4A16LinearMethod`: - Loads ModelOpt-style names directly (no on-disk renames): weight uint8 packed NVFP4 weight_scale fp8-e4m3 per 16-elem group along input dim weight_scale_2 fp32 per-tensor global = amax / (6.0 * 448.0) - process_weights_after_loading: rename weight_scale_2 -> weight_global_scale **without reciprocation**. ModelOpt already stores amax/2688 which is the form Marlin's nvfp4_marlin_process_global_scale consumes; the CT W4A16 path reciprocates only because CT stores 1/x on disk. Then call prepare_fp4_layer_for_marlin(layer). - apply: dispatches to apply_fp4_marlin_linear -- same call as CompressedTensorsW4A16Fp4. 
linear.py: add "ModelOptNvFp4W4A16LinearMethod" to WEIGHT_LOADER_V2_SUPPORTED so the linear layer uses weight_loader_v2 for our params (especially needed for PerTensorScaleParameter on fused QKV/gate-up; without v2 the legacy loader hits a shape assert). Validation (case 1, controlled equivalence): - Native W4A16 load: ModelOpt qwen3-8b W4A16 ckpt via this method -> 6.01 GiB / 2.27 s, FLASHINFER attention, fp8 KV cache, ~57 tok/s decode on enforce_eager. Outputs coherent. - CT-converted W4A16 load: same source ckpt, run through the conversion script, loaded via CompressedTensorsW4A16Fp4. Same attention backend (FLASHINFER), same KV cache dtype (fp8), same KV cache slot count (1,051,632), token-for-token identical greedy completions. Bit-identical layer state via different code routes -> same FP4 Marlin kernel call -> same output. Two-axes apples-to-apples. Validation matrix tracked in juhim/w4a16_modelopt_vllm/logs_and_results/log.md. AI-assisted: prepared with Claude (Anthropic). Human review and on-machine validation by juhim before any PR. 
Co-authored-by: Claude Signed-off-by: Juhi Mittal --- vllm/model_executor/layers/linear.py | 1 + .../layers/quantization/modelopt.py | 154 +++++++++++++++++- 2 files changed, 152 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 06da4261345d..9d08ab19a474 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -60,6 +60,7 @@ "ModelOptFp8PbWoLinearMethod", "QuarkLinearMethod", "ModelOptNvFp4LinearMethod", + "ModelOptNvFp4W4A16LinearMethod", "HummingLinearMethod", ] diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0862efbea294..d1e8e29abeff 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -67,6 +67,10 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + apply_fp4_marlin_linear, + prepare_fp4_layer_for_marlin, +) from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( MXFP8_BLOCK_SIZE, MXFP8_SCALE_DTYPE, @@ -89,6 +93,7 @@ from vllm.model_executor.parameter import ( BlockQuantScaleParameter, ChannelQuantScaleParameter, + GroupQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) @@ -107,8 +112,10 @@ "FP8_PER_CHANNEL_PER_TOKEN", # FP8 per-block weight-only (ModelOpt may emit this as lowercase). "FP8_PB_WO", - # FP4 + # NVFP4 W4A4 (4-bit float weights AND 4-bit float activations). "NVFP4", + # NVFP4 W4A16 (4-bit float weights, fp16/bf16 activations). 
+ "NVFP4_W4A16", # MXFP8 "MXFP8", # MIXED_PRECISION, @@ -1003,22 +1010,39 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase): def __init__( self, + quant_method: str, is_checkpoint_nvfp4_serialized: bool, kv_cache_quant_algo: str | None, exclude_modules: list[str], group_size: int = 16, ) -> None: super().__init__(exclude_modules) + self.quant_method = quant_method self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: logger.warning( - "Detected ModelOpt NVFP4 checkpoint. Please note that" - " the format is experimental and could change in future." + "Detected ModelOpt NVFP4 checkpoint (quant_algo=%s). Please " + "note that the format is experimental and could change in " + "future.", + quant_method, ) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo + # Select LinearMethod implementation based on quant_algo (FP8 pattern). + # NVFP4 -> W4A4: cutlass NVFP4 GEMM with input quantization + # NVFP4_W4A16 -> W4A16: FP4 Marlin GEMM with bf16/fp16 activations + if quant_method == "NVFP4": + self.LinearMethodCls = ModelOptNvFp4LinearMethod + elif quant_method == "NVFP4_W4A16": + self.LinearMethodCls = ModelOptNvFp4W4A16LinearMethod + else: + raise ValueError( + f"Unsupported ModelOpt NVFP4 quant_algo: {quant_method}. " + "Supported: NVFP4 / NVFP4_W4A16." + ) + def get_name(self) -> QuantizationMethods: return "modelopt_fp4" @@ -1069,6 +1093,7 @@ def _from_config( ) return cls( + quant_method, is_checkpoint_nvfp4_serialized, kv_cache_quant_method, exclude_modules, @@ -1208,6 +1233,129 @@ def apply( return self.kernel.apply_weights(layer=layer, x=x, bias=bias) +class ModelOptNvFp4W4A16LinearMethod(LinearMethodBase): + """Linear method for ModelOpt NVFP4 W4A16. + + 4-bit NVFP4 weights, fp16/bf16 activations. 
Loads ModelOpt-style names + directly (no on-disk conversion) and dispatches to the FP4 Marlin GEMM: + + weight uint8 packed NVFP4 (2 nibbles/byte along input dim) + weight_scale fp8-e4m3 per 16-elem group along input dim + weight_scale_2 fp32 per-tensor global scale = amax / (6.0 * 448.0) + + No activation quantization; no input_scale on disk. Marlin expects the + global scale in the same form ModelOpt stores (amax/2688), so we rename + weight_scale_2 -> weight_global_scale **without reciprocation** -- the CT + W4A16 path reciprocates only because CT stores the inverse on disk. + """ + + backend = "marlin" + + def __init__(self, quant_config: ModelOptNvFp4Config) -> None: + self.quant_config = quant_config + # Set externally by ModelOptNvFp4Config.get_quant_method when backend + # is "marlin"; left as None if the dispatch path hasn't run. + self.marlin_input_dtype = None + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError( + "NVFP4_W4A16 quantization was selected; " + "dynamic quantization is not supported." + ) + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if input_size_per_partition % 16 != 0: + raise ValueError( + "Unsupported model: input feature size is not a multiple of 16." + ) + + # Packed NVFP4 weights: uint8, 2 nibbles per byte along the input dim. 
+ weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # Per-tensor global weight scale (fp32). ModelOpt stores + # amax / (NVFP4_max * fp8_e4m3_max) = amax / 2688. PerTensorScaleParameter + # holds one entry per fused output partition (e.g. q/k/v in a fused QKV). + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + # Per-group fp8 weight scale. + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if torch.unique(layer.weight_scale_2).numel() != 1: + logger.warning_once( + "In NVFP4_W4A16 linear, the global weight scale " + "(weight_scale_2) differs across fused parallel layers " + "(e.g. q/k/v_proj). This will likely reduce accuracy. " + "Consider a checkpoint with a shared global scale." + ) + + # Rename weight_scale_2 -> weight_global_scale. NO reciprocation: + # ModelOpt already stores amax/2688, which is exactly what Marlin + # consumes via prepare_fp4_layer_for_marlin / nvfp4_marlin_process_global_scale. 
+ layer.weight_global_scale = Parameter( + layer.weight_scale_2.max().to(torch.float32), requires_grad=False + ) + del layer.weight_scale_2 + + prepare_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return apply_fp4_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_global_scale=layer.weight_global_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + + class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. From 30c4713099e6fb3b1341297a9a155a96249ab967 Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Tue, 5 May 2026 19:21:00 +0000 Subject: [PATCH 2/5] [ModelOpt] W4A16: route through MarlinNvFp4LinearKernel adapter Refactor only -- same kernel calls, same byte-identical output. Aligns ModelOptNvFp4W4A16LinearMethod's code shape with the existing W4A4 sibling (ModelOptNvFp4LinearMethod) by going through a kernel-adapter abstraction instead of calling Marlin functions directly: __init__: self.kernel = MarlinNvFp4LinearKernel(NvFp4LinearLayerConfig()) process_weights_after_loading: self.kernel.process_weights_after_loading(layer) apply: self.kernel.apply_weights(layer=layer, x=x, bias=bias) We deliberately direct-instantiate MarlinNvFp4LinearKernel rather than go through init_nvfp4_linear_kernel(): the shared selector's first-pick on this hardware is a cutlass W4A4 kernel that quantizes activations, which would silently break our W4A16 path (no input_scale registered). For W4A16 there is exactly one valid kernel, so we pin it. Also drops the dead `backend = "marlin"` class attribute. 
The framework gate at ModelOptQuantConfigBase.get_quant_method only sets marlin_input_dtype when backend == "marlin"; we no longer need that because our adapter calls prepare_fp4_layer_for_marlin without an input_dtype, and that argument only affects an is_a_8bit branch that NVFP4 W4A16 never enters (bf16/fp16 acts always have itemsize > 1, and fp8/int8 acts on NVFP4 weights are explicitly rejected). The vestigial `self.marlin_input_dtype = None` slot is kept to mirror the W4A4 method's __init__ shape. Imports: drop now-unused apply_fp4_marlin_linear and prepare_fp4_layer_for_marlin from modelopt.py; add MarlinNvFp4LinearKernel and NvFp4LinearLayerConfig. Re-validated case 1 on qwen3-8b W4A16 ckpt: - FLASHINFER attention backend, candidate set [FLASHINFER, TRITON_ATTN] - fp8_e4m3 KV cache, 1,051,632 slot count - model load 6.01 GiB / 2.18 s - token-for-token identical greedy completions to the pre-refactor run Co-authored-by: Claude Signed-off-by: Juhi Mittal --- .../layers/quantization/modelopt.py | 36 +++++++++---------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index d1e8e29abeff..a8841e59acb3 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -11,6 +11,8 @@ from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.kernels.linear import ( + MarlinNvFp4LinearKernel, + NvFp4LinearLayerConfig, init_fp8_linear_kernel, init_mxfp8_linear_kernel, init_nvfp4_linear_kernel, @@ -67,10 +69,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - apply_fp4_marlin_linear, - prepare_fp4_layer_for_marlin, -) from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( MXFP8_BLOCK_SIZE, 
MXFP8_SCALE_DTYPE, @@ -1249,13 +1247,19 @@ class ModelOptNvFp4W4A16LinearMethod(LinearMethodBase): W4A16 path reciprocates only because CT stores the inverse on disk. """ - backend = "marlin" - def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config - # Set externally by ModelOptNvFp4Config.get_quant_method when backend - # is "marlin"; left as None if the dispatch path hasn't run. + # Vestigial slot mirrored from ModelOptNvFp4LinearMethod: the parent + # config's get_quant_method only fills marlin_input_dtype when + # backend == "marlin"; we don't set that since we pin the kernel + # below, but we keep the attribute for shape parity. self.marlin_input_dtype = None + # Direct-instantiate the Marlin NVFP4 adapter rather than going through + # init_nvfp4_linear_kernel(): the latter's priority list returns a + # cutlass W4A4 kernel as first-pick on this hardware, which would + # silently try to quantize activations (we have no input_scale). For + # W4A16 there is exactly one valid kernel, so we pin it. + self.kernel = MarlinNvFp4LinearKernel(NvFp4LinearLayerConfig()) def create_weights( self, @@ -1330,13 +1334,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Rename weight_scale_2 -> weight_global_scale. NO reciprocation: # ModelOpt already stores amax/2688, which is exactly what Marlin - # consumes via prepare_fp4_layer_for_marlin / nvfp4_marlin_process_global_scale. + # consumes via nvfp4_marlin_process_global_scale (called inside the + # Marlin adapter's process_weights_after_loading). 
layer.weight_global_scale = Parameter( layer.weight_scale_2.max().to(torch.float32), requires_grad=False ) del layer.weight_scale_2 - prepare_fp4_layer_for_marlin(layer, input_dtype=self.marlin_input_dtype) + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -1344,16 +1349,7 @@ def apply( x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return apply_fp4_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - weight_global_scale=layer.weight_global_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) + return self.kernel.apply_weights(layer=layer, x=x, bias=bias) class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): From 15d161925927a72749b944b9e5171039ae3fd730 Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Tue, 5 May 2026 21:19:13 +0000 Subject: [PATCH 3/5] [ModelOpt] W4A16: tolerate input_scale tensors from W4A4 checkpoints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The W4A16 method previously didn't register input_scale, on the assumption that vLLM's loader would silently skip on-disk *_proj.input_scale keys when the W4A4-shaped variant of an NVFP4 checkpoint was loaded under this method. That's only true in the qwen2 loader's "else" branch (post-stacked, has an explicit `if name not in params_dict: continue` guard). The "stacked" branch -- which handles q_proj/k_proj/v_proj/gate_proj/up_proj shards -- has no such guard, and unconditionally does `params_dict[name]` after renaming e.g. `q_proj.input_scale` to `qkv_proj.input_scale`. Without a `qkv_proj.input_scale` parameter registered, that lookup raises KeyError and engine init fails. This trips the moment a user tries to load an NVFP4 (W4A4) checkpoint under the W4A16 method (the eventual phase-2 use case for the --quantization=modelopt_fp4_w4a16 override, and the immediate phase-1 case-3 validation test). 
Fix: register a placeholder PerTensorScaleParameter named input_scale in create_weights so the loader can place per-shard input_scale tensors here without KeyError on the merged-name lookup. We discard it in process_weights_after_loading -- W4A16 mode does not quantize activations, so the value is never used. For native W4A16 checkpoints (no input_scale on disk) the placeholder stays uninitialized and is simply deleted; harmless. Validated end-to-end on qwen3-8b: - Case 1 (native W4A16): unchanged, ~102 tok/s, FLASHINFER, fp8 KV. - Case 2 (W4A4 regression): unchanged, ~42 tok/s, existing ModelOptNvFp4LinearMethod path. - Case 3 (W4A4 ckpt with quant_algo file-edited to NVFP4_W4A16, loaded via this method): now succeeds; outputs token-identical to case 1; logits **bit-identical** to case 1 (max |Δlogprob| = 0 across 47 captured positions, top-20 ranks 100% match), confirming same weight bits -> same Marlin kernel -> same computation. Side effect: makes the W4A16 method intrinsically robust to either NVFP4 checkpoint shape (W4A4 or W4A16). The eventual ModelOptNvFp4W4A16Config phase-2 override is then pure routing -- the underlying method already handles both shapes correctly. Co-authored-by: Claude Signed-off-by: Juhi Mittal --- .../layers/quantization/modelopt.py | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index a8841e59acb3..7456e12c8d39 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1241,10 +1241,16 @@ class ModelOptNvFp4W4A16LinearMethod(LinearMethodBase): weight_scale fp8-e4m3 per 16-elem group along input dim weight_scale_2 fp32 per-tensor global scale = amax / (6.0 * 448.0) - No activation quantization; no input_scale on disk. 
Marlin expects the - global scale in the same form ModelOpt stores (amax/2688), so we rename - weight_scale_2 -> weight_global_scale **without reciprocation** -- the CT - W4A16 path reciprocates only because CT stores the inverse on disk. + No activation quantization. Marlin expects the global scale in the same + form ModelOpt stores (amax/2688), so we rename weight_scale_2 -> + weight_global_scale **without reciprocation** -- the CT W4A16 path + reciprocates only because CT stores the inverse on disk. + + We also register a placeholder input_scale parameter so that W4A4-shaped + checkpoints (which contain *_proj.input_scale tensors) can be loaded + under this method without the per-shard loader hitting a KeyError on + the merged-name lookup. The placeholder is discarded in + process_weights_after_loading -- its value is never used. """ def __init__(self, quant_config: ModelOptNvFp4Config) -> None: @@ -1323,7 +1329,26 @@ def create_weights( ) layer.register_parameter("weight_scale", weight_scale) + # Placeholder input_scale param so W4A4-shaped checkpoints can be + # loaded under this method without KeyError on the merged-name + # lookup (qwen2-style stacked-loader path renames *_proj.input_scale + # to e.g. qkv_proj.input_scale and looks it up unconditionally). + # Discarded in process_weights_after_loading; never read by the kernel. + # For native W4A16 checkpoints (no input_scale on disk) the param + # stays uninitialized and is simply deleted. + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("input_scale", input_scale) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Discard the input_scale placeholder. Whether it carries values + # (W4A4 ckpt loaded as W4A16) or is uninitialized (native W4A16 + # ckpt), W4A16 mode does not quantize activations, so this is unused. 
+ if hasattr(layer, "input_scale"): + del layer.input_scale + if torch.unique(layer.weight_scale_2).numel() != 1: logger.warning_once( "In NVFP4_W4A16 linear, the global weight scale " From 115a4f935cfc385748f36769babe0e4fabc47261 Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Tue, 5 May 2026 23:24:14 +0000 Subject: [PATCH 4/5] [ModelOpt] Rename quant_algo NVFP4_W4A16 -> W4A16_NVFP4 ModelOpt PR #1313 (commit 0fede961 on origin/hungyueh/modelopt-nvfp4-w4a16, "nvfp4_w4a16 -> w4a16_nvfp4") renamed the qformat / on-disk quant_algo string. Six string-literal edits in vllm/.../modelopt.py to match: - QUANT_ALGOS entry. - Dispatch in ModelOptNvFp4Config.__init__. - Error message in __init__. - 3 docstring / log-warning labels. No registry change: override_quantization_method's substring check ("NVFP4" in algo or "FP4" in algo) still matches "W4A16_NVFP4" because it contains "NVFP4". The LinearMethod class name ModelOptNvFp4W4A16LinearMethod is kept as-is -- it describes the concept ("NVFP4 weights, W4A16 mode"), not the on-disk algo string. Renaming the class would touch WEIGHT_LOADER_V2_SUPPORTED in linear.py and add review surface that isn't earned by the on-disk rename. Smoke-tested both qwen3-8b and Nemotron-Nano-4B after the rename; both still load + generate cleanly. The on-disk ckpt configs were patched in place (only the quant_algo JSON field, no safetensors regen) -- documented in the gitlab notes repo's log.md. Co-authored-by: Claude Signed-off-by: Juhi Mittal --- .../model_executor/layers/quantization/modelopt.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 7456e12c8d39..05d3c90da789 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -112,8 +112,8 @@ "FP8_PB_WO", # NVFP4 W4A4 (4-bit float weights AND 4-bit float activations). 
"NVFP4", - # NVFP4 W4A16 (4-bit float weights, fp16/bf16 activations). - "NVFP4_W4A16", + # W4A16 NVFP4 (4-bit float weights, fp16/bf16 activations). + "W4A16_NVFP4", # MXFP8 "MXFP8", # MIXED_PRECISION, @@ -1030,15 +1030,15 @@ def __init__( # Select LinearMethod implementation based on quant_algo (FP8 pattern). # NVFP4 -> W4A4: cutlass NVFP4 GEMM with input quantization - # NVFP4_W4A16 -> W4A16: FP4 Marlin GEMM with bf16/fp16 activations + # W4A16_NVFP4 -> W4A16: FP4 Marlin GEMM with bf16/fp16 activations if quant_method == "NVFP4": self.LinearMethodCls = ModelOptNvFp4LinearMethod - elif quant_method == "NVFP4_W4A16": + elif quant_method == "W4A16_NVFP4": self.LinearMethodCls = ModelOptNvFp4W4A16LinearMethod else: raise ValueError( f"Unsupported ModelOpt NVFP4 quant_algo: {quant_method}. " - "Supported: NVFP4 / NVFP4_W4A16." + "Supported: NVFP4 / W4A16_NVFP4." ) def get_name(self) -> QuantizationMethods: @@ -1280,7 +1280,7 @@ def create_weights( del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: raise ValueError( - "NVFP4_W4A16 quantization was selected; " + "W4A16_NVFP4 quantization was selected; " "dynamic quantization is not supported." ) output_size_per_partition = sum(output_partition_sizes) @@ -1351,7 +1351,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if torch.unique(layer.weight_scale_2).numel() != 1: logger.warning_once( - "In NVFP4_W4A16 linear, the global weight scale " + "In W4A16_NVFP4 linear, the global weight scale " "(weight_scale_2) differs across fused parallel layers " "(e.g. q/k/v_proj). This will likely reduce accuracy. " "Consider a checkpoint with a shared global scale." From 98e34aeba47dde0ef171756cc92a5769d7202198 Mon Sep 17 00:00:00 2001 From: Juhi Mittal Date: Wed, 6 May 2026 00:01:02 +0000 Subject: [PATCH 5/5] [ModelOpt] Default ModelOptNvFp4Config args + dispatch unit tests Two related changes: 1. 
Make ModelOptNvFp4Config.__init__ args defaultable so existing tests / callers that construct the config without passing quant_method (and friends) keep working unchanged. Previously adding `quant_method` as a required positional arg silently broke three test sites under tests/{compile,distributed,kernels}/... that build the config directly to exercise downstream code (eplb, MLA fusion, MoE layer). Defaults match the W4A4 path, which is what those tests were exercising: quant_method: str = "NVFP4" is_checkpoint_nvfp4_serialized: bool = False kv_cache_quant_algo: str | None = None exclude_modules: list[str] | None = None # treated as [] _from_config still passes all five explicitly when loading a real checkpoint, so the defaults only affect direct constructor users. 2. Add two unit tests under tests/quantization/test_modelopt.py that exercise the per-algo LinearMethodCls dispatch in ModelOptNvFp4Config without needing a checkpoint: - test_modelopt_nvfp4_config_dispatches_w4a4_method quant_method="NVFP4" -> ModelOptNvFp4LinearMethod - test_modelopt_nvfp4_config_dispatches_w4a16_method quant_method="W4A16_NVFP4" -> ModelOptNvFp4W4A16LinearMethod The W4A16 test asserts both `is` (positive) and `is not` against the W4A4 sibling so a regression that silently routes a W4A16 checkpoint under the W4A4 method (and then calls the cutlass W4A4 NVFP4 GEMM instead of FP4 Marlin, with no input_scale) would fail loudly. Test result locally: 2 passed, ~2.2 s. 
Co-authored-by: Claude Signed-off-by: Juhi Mittal --- tests/quantization/test_modelopt.py | 46 +++++++++++++++++++ .../layers/quantization/modelopt.py | 10 ++-- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py index 120b2cde0f35..593075e9d491 100644 --- a/tests/quantization/test_modelopt.py +++ b/tests/quantization/test_modelopt.py @@ -239,3 +239,49 @@ def check_model(model): output = llm.generate_greedy(["Hello my name is"], max_tokens=4) assert output print(f"ModelOpt FP8_PB_WO output: {output}") + + +def test_modelopt_nvfp4_config_dispatches_w4a4_method(): + """``quant_method="NVFP4"`` (W4A4 default) routes to the existing + ``ModelOptNvFp4LinearMethod``.""" + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config, + ModelOptNvFp4LinearMethod, + ) + + config = ModelOptNvFp4Config( + quant_method="NVFP4", + is_checkpoint_nvfp4_serialized=True, + kv_cache_quant_algo=None, + exclude_modules=[], + ) + assert config.LinearMethodCls is ModelOptNvFp4LinearMethod + assert config.quant_method == "NVFP4" + + +def test_modelopt_nvfp4_config_dispatches_w4a16_method(): + """``quant_method="W4A16_NVFP4"`` routes to the new + ``ModelOptNvFp4W4A16LinearMethod`` instead of the W4A4 sibling. + + Mirrors the FP8 dispatch precedent (``ModelOptFp8Config`` selects + one of three FP8 LinearMethods on ``quant_method``); a regression + here would mean a W4A16 NVFP4 checkpoint silently loaded under the + W4A4 method, which would try to register an ``input_scale`` runtime + parameter and (more importantly) call the cutlass W4A4 NVFP4 GEMM + instead of FP4 Marlin. 
+ """ + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config, + ModelOptNvFp4LinearMethod, + ModelOptNvFp4W4A16LinearMethod, + ) + + config = ModelOptNvFp4Config( + quant_method="W4A16_NVFP4", + is_checkpoint_nvfp4_serialized=True, + kv_cache_quant_algo=None, + exclude_modules=[], + ) + assert config.LinearMethodCls is ModelOptNvFp4W4A16LinearMethod + assert config.LinearMethodCls is not ModelOptNvFp4LinearMethod + assert config.quant_method == "W4A16_NVFP4" diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 05d3c90da789..72178a6e6dd3 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1008,12 +1008,14 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase): def __init__( self, - quant_method: str, - is_checkpoint_nvfp4_serialized: bool, - kv_cache_quant_algo: str | None, - exclude_modules: list[str], + quant_method: str = "NVFP4", + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: str | None = None, + exclude_modules: list[str] | None = None, group_size: int = 16, ) -> None: + if exclude_modules is None: + exclude_modules = [] super().__init__(exclude_modules) self.quant_method = quant_method self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized