Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
ee38784
[ModelOpt] Add NVFP4 W4A16 (4-bit weights, fp16/bf16 acts) support
juhi10071998 May 5, 2026
30c4713
[ModelOpt] W4A16: route through MarlinNvFp4LinearKernel adapter
juhi10071998 May 5, 2026
15d1619
[ModelOpt] W4A16: tolerate input_scale tensors from W4A4 checkpoints
juhi10071998 May 5, 2026
115a4f9
[ModelOpt] Rename quant_algo NVFP4_W4A16 -> W4A16_NVFP4
juhi10071998 May 5, 2026
98e34ae
[ModelOpt] Default ModelOptNvFp4Config args + dispatch unit tests
juhi10071998 May 6, 2026
6c8124d
Merge branch 'main' into w4a16_modelopt_support
pavanimajety May 6, 2026
ef68525
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 6, 2026
2f39dcb
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 7, 2026
19154e0
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 7, 2026
45410cb
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 7, 2026
2183889
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 7, 2026
9b13362
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 8, 2026
d5b9f88
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 8, 2026
c0ece4b
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 8, 2026
b272da6
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 8, 2026
3d75d6b
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 8, 2026
4ee5cfe
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 9, 2026
3abe3a6
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 9, 2026
aa530b2
Merge branch 'main' into w4a16_modelopt_support
juhi10071998 May 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions tests/quantization/test_modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,49 @@ def check_model(model):
output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
assert output
print(f"ModelOpt FP8_PB_WO output: {output}")


def test_modelopt_nvfp4_config_dispatches_w4a4_method():
"""``quant_method="NVFP4"`` (W4A4 default) routes to the existing
``ModelOptNvFp4LinearMethod``."""
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config,
ModelOptNvFp4LinearMethod,
)

config = ModelOptNvFp4Config(
quant_method="NVFP4",
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
)
assert config.LinearMethodCls is ModelOptNvFp4LinearMethod
assert config.quant_method == "NVFP4"


def test_modelopt_nvfp4_config_dispatches_w4a16_method():
"""``quant_method="W4A16_NVFP4"`` routes to the new
``ModelOptNvFp4W4A16LinearMethod`` instead of the W4A4 sibling.

Mirrors the FP8 dispatch precedent (``ModelOptFp8Config`` selects
one of three FP8 LinearMethods on ``quant_method``); a regression
here would mean a W4A16 NVFP4 checkpoint silently loaded under the
W4A4 method, which would try to register an ``input_scale`` runtime
parameter and (more importantly) call the cutlass W4A4 NVFP4 GEMM
instead of FP4 Marlin.
"""
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config,
ModelOptNvFp4LinearMethod,
ModelOptNvFp4W4A16LinearMethod,
)

config = ModelOptNvFp4Config(
quant_method="W4A16_NVFP4",
is_checkpoint_nvfp4_serialized=True,
kv_cache_quant_algo=None,
exclude_modules=[],
)
assert config.LinearMethodCls is ModelOptNvFp4W4A16LinearMethod
assert config.LinearMethodCls is not ModelOptNvFp4LinearMethod
assert config.quant_method == "W4A16_NVFP4"
1 change: 1 addition & 0 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
"ModelOptFp8PbWoLinearMethod",
"QuarkLinearMethod",
"ModelOptNvFp4LinearMethod",
"ModelOptNvFp4W4A16LinearMethod",
"HummingLinearMethod",
]

Expand Down
183 changes: 177 additions & 6 deletions vllm/model_executor/layers/quantization/modelopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
MarlinNvFp4LinearKernel,
NvFp4LinearLayerConfig,
init_fp8_linear_kernel,
init_mxfp8_linear_kernel,
init_nvfp4_linear_kernel,
Expand Down Expand Up @@ -89,6 +91,7 @@
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
Expand All @@ -107,8 +110,10 @@
"FP8_PER_CHANNEL_PER_TOKEN",
# FP8 per-block weight-only (ModelOpt may emit this as lowercase).
"FP8_PB_WO",
# FP4
# NVFP4 W4A4 (4-bit float weights AND 4-bit float activations).
"NVFP4",
# W4A16 NVFP4 (4-bit float weights, fp16/bf16 activations).
"W4A16_NVFP4",
# MXFP8
"MXFP8",
# MIXED_PRECISION,
Expand Down Expand Up @@ -1003,22 +1008,41 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):

def __init__(
self,
is_checkpoint_nvfp4_serialized: bool,
kv_cache_quant_algo: str | None,
exclude_modules: list[str],
quant_method: str = "NVFP4",
is_checkpoint_nvfp4_serialized: bool = False,
kv_cache_quant_algo: str | None = None,
exclude_modules: list[str] | None = None,
group_size: int = 16,
) -> None:
if exclude_modules is None:
exclude_modules = []
super().__init__(exclude_modules)
self.quant_method = quant_method
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
"Detected ModelOpt NVFP4 checkpoint. Please note that"
" the format is experimental and could change in future."
"Detected ModelOpt NVFP4 checkpoint (quant_algo=%s). Please "
"note that the format is experimental and could change in "
"future.",
quant_method,
)

self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo

# Select LinearMethod implementation based on quant_algo (FP8 pattern).
# NVFP4 -> W4A4: cutlass NVFP4 GEMM with input quantization
# W4A16_NVFP4 -> W4A16: FP4 Marlin GEMM with bf16/fp16 activations
if quant_method == "NVFP4":
self.LinearMethodCls = ModelOptNvFp4LinearMethod
elif quant_method == "W4A16_NVFP4":
self.LinearMethodCls = ModelOptNvFp4W4A16LinearMethod
else:
raise ValueError(
f"Unsupported ModelOpt NVFP4 quant_algo: {quant_method}. "
"Supported: NVFP4 / W4A16_NVFP4."
)

def get_name(self) -> QuantizationMethods:
return "modelopt_fp4"

Expand Down Expand Up @@ -1069,6 +1093,7 @@ def _from_config(
)

return cls(
quant_method,
is_checkpoint_nvfp4_serialized,
kv_cache_quant_method,
exclude_modules,
Expand Down Expand Up @@ -1208,6 +1233,152 @@ def apply(
return self.kernel.apply_weights(layer=layer, x=x, bias=bias)


class ModelOptNvFp4W4A16LinearMethod(LinearMethodBase):
"""Linear method for ModelOpt NVFP4 W4A16.

4-bit NVFP4 weights, fp16/bf16 activations. Loads ModelOpt-style names
directly (no on-disk conversion) and dispatches to the FP4 Marlin GEMM:

weight uint8 packed NVFP4 (2 nibbles/byte along input dim)
weight_scale fp8-e4m3 per 16-elem group along input dim
weight_scale_2 fp32 per-tensor global scale = amax / (6.0 * 448.0)

No activation quantization. Marlin expects the global scale in the same
form ModelOpt stores (amax/2688), so we rename weight_scale_2 ->
weight_global_scale **without reciprocation** -- the CT W4A16 path
reciprocates only because CT stores the inverse on disk.

We also register a placeholder input_scale parameter so that W4A4-shaped
checkpoints (which contain *_proj.input_scale tensors) can be loaded
under this method without the per-shard loader hitting a KeyError on
the merged-name lookup. The placeholder is discarded in
process_weights_after_loading -- its value is never used.
"""

def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
self.quant_config = quant_config
# Vestigial slot mirrored from ModelOptNvFp4LinearMethod: the parent
# config's get_quant_method only fills marlin_input_dtype when
# backend == "marlin"; we don't set that since we pin the kernel
# below, but we keep the attribute for shape parity.
self.marlin_input_dtype = None
# Direct-instantiate the Marlin NVFP4 adapter rather than going through
# init_nvfp4_linear_kernel(): the latter's priority list returns a
# cutlass W4A4 kernel as first-pick on this hardware, which would
# silently try to quantize activations (we have no input_scale). For
# W4A16 there is exactly one valid kernel, so we pin it.
self.kernel = MarlinNvFp4LinearKernel(NvFp4LinearLayerConfig())

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError(
"W4A16_NVFP4 quantization was selected; "
"dynamic quantization is not supported."
)
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition

if input_size_per_partition % 16 != 0:
raise ValueError(
"Unsupported model: input feature size is not a multiple of 16."
)

# Packed NVFP4 weights: uint8, 2 nibbles per byte along the input dim.
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)

# Per-tensor global weight scale (fp32). ModelOpt stores
# amax / (NVFP4_max * fp8_e4m3_max) = amax / 2688. PerTensorScaleParameter
# holds one entry per fused output partition (e.g. q/k/v in a fused QKV).
weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale_2", weight_scale_2)

# Per-group fp8 weight scale.
weight_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.group_size,
dtype=torch.float8_e4m3fn,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)

# Placeholder input_scale param so W4A4-shaped checkpoints can be
# loaded under this method without KeyError on the merged-name
# lookup (qwen2-style stacked-loader path renames *_proj.input_scale
# to e.g. qkv_proj.input_scale and looks it up unconditionally).
# Discarded in process_weights_after_loading; never read by the kernel.
# For native W4A16 checkpoints (no input_scale on disk) the param
# stays uninitialized and is simply deleted.
input_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Discard the input_scale placeholder. Whether it carries values
# (W4A4 ckpt loaded as W4A16) or is uninitialized (native W4A16
# ckpt), W4A16 mode does not quantize activations, so this is unused.
if hasattr(layer, "input_scale"):
del layer.input_scale

if torch.unique(layer.weight_scale_2).numel() != 1:
logger.warning_once(
"In W4A16_NVFP4 linear, the global weight scale "
"(weight_scale_2) differs across fused parallel layers "
"(e.g. q/k/v_proj). This will likely reduce accuracy. "
"Consider a checkpoint with a shared global scale."
)

# Rename weight_scale_2 -> weight_global_scale. NO reciprocation:
# ModelOpt already stores amax/2688, which is exactly what Marlin
# consumes via nvfp4_marlin_process_global_scale (called inside the
# Marlin adapter's process_weights_after_loading).
layer.weight_global_scale = Parameter(
layer.weight_scale_2.max().to(torch.float32), requires_grad=False
)
del layer.weight_scale_2

self.kernel.process_weights_after_loading(layer)
Comment thread
juhi10071998 marked this conversation as resolved.

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.kernel.apply_weights(layer=layer, x=x, bias=bias)


class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
"""
MoE Method for FP4 Quantization.
Expand Down
Loading