[Quantization] Add ModelOpt NVFP4 W4A16 (4-bit weights, fp16/bf16 activations) support #41769
Code Review
This pull request introduces support for ModelOpt NVFP4 W4A16 quantization, enabling 4-bit weights with 16-bit activations. The changes include the implementation of the ModelOptNvFp4W4A16LinearMethod class, which utilizes the Marlin NVFP4 kernel, and updates to ModelOptNvFp4Config to handle dispatching between W4A4 and W4A16 methods. Additionally, unit tests were added to verify the configuration routing. Feedback was provided to address a potential accuracy issue during weight processing where differing global scales across fused layers were not being correctly compensated for via group scale rescaling.
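The compensation the review asks for can be illustrated with plain floats. This is a minimal sketch under an assumed convention (dequantized weight = FP4 code × group scale × global scale, with the fused layer sharing the largest per-shard global scale). The function name `rescale_group_scales` is hypothetical and is not vLLM code.

```python
def rescale_group_scales(group_scales, global_scales):
    """Fold per-shard global scales into the group scales.

    group_scales: one list of group scales per fused shard.
    global_scales: one per-tensor global scale per shard.
    Returns (rescaled_group_scales, shared_global_scale) such that
    rescaled * shared equals the original group_scale * global_scale.
    """
    shared = max(global_scales)
    rescaled = [
        [s * g / shared for s in shard]
        for shard, g in zip(group_scales, global_scales)
    ]
    return rescaled, shared


# Two fused shards (for example q and k) with different global scales.
group_scales = [[0.5, 1.0], [2.0, 4.0]]
global_scales = [0.01, 0.02]
rescaled, shared = rescale_group_scales(group_scales, global_scales)
```

The identity holds exactly only in high precision. If the rescaled group scales were stored back in fp8-e4m3, rounding would make the result approximate.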
Adds first-class loading for ModelOpt-exported NVFP4_W4A16 checkpoints
(`quant_algo: "NVFP4_W4A16"`). Today vLLM can only consume such ckpts
after rewriting them into the compressed-tensors format on disk; this
change lets the ModelOpt loader feed the FP4 Marlin GEMM directly,
without an on-disk conversion.
Plumbing (no new config class):
- `QUANT_ALGOS`: register `"NVFP4_W4A16"`. Existing
`ModelOptNvFp4Config.override_quantization_method` substring check
(`"NVFP4" in algo or "FP4" in algo`) already routes it to the same
config class as `"NVFP4"` -- mirrors the established FP8 pattern in
this file where one ModelOptFp8Config dispatches to three FP8
LinearMethods based on the algo string.
- `ModelOptNvFp4Config.__init__` now takes `quant_method` and selects
`self.LinearMethodCls` per algo:
NVFP4 -> ModelOptNvFp4LinearMethod (existing W4A4)
NVFP4_W4A16 -> ModelOptNvFp4W4A16LinearMethod (new)
- `_from_config` threads `quant_method` to the constructor.
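The per-algo selection described above can be sketched roughly as follows. The two method classes are empty stand-ins for the real vLLM classes, and `NvFp4ConfigSketch` is a hypothetical, trimmed-down config that models only the dispatch. The real constructor takes more arguments.

```python
class ModelOptNvFp4LinearMethod:
    """Stand-in for the existing W4A4 method."""


class ModelOptNvFp4W4A16LinearMethod:
    """Stand-in for the new W4A16 method."""


class NvFp4ConfigSketch:
    """Hypothetical config that only models the per-algo dispatch."""

    _METHODS = {
        "NVFP4": ModelOptNvFp4LinearMethod,
        "NVFP4_W4A16": ModelOptNvFp4W4A16LinearMethod,
    }

    def __init__(self, quant_method: str) -> None:
        if quant_method not in self._METHODS:
            raise ValueError(f"Unsupported NVFP4 quant_algo: {quant_method!r}")
        self.quant_method = quant_method
        # One config class, several LinearMethods: pick by algo string.
        self.LinearMethodCls = self._METHODS[quant_method]


w4a4 = NvFp4ConfigSketch("NVFP4")
w4a16 = NvFp4ConfigSketch("NVFP4_W4A16")
```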
New class `ModelOptNvFp4W4A16LinearMethod`:
- Loads ModelOpt-style names directly (no on-disk renames):
weight uint8 packed NVFP4
weight_scale fp8-e4m3 per 16-elem group along input dim
weight_scale_2 fp32 per-tensor global = amax / (6.0 * 448.0)
- process_weights_after_loading: rename weight_scale_2 ->
weight_global_scale **without reciprocation**. ModelOpt already stores
amax/2688 which is the form Marlin's
nvfp4_marlin_process_global_scale consumes; the CT W4A16 path
reciprocates only because CT stores 1/x on disk. Then call
prepare_fp4_layer_for_marlin(layer).
- apply: dispatches to apply_fp4_marlin_linear -- same call as
CompressedTensorsW4A16Fp4.
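The scale arithmetic above can be checked with a few lines of plain Python. This is an illustrative sketch only: the helper names are hypothetical, and the statement that compressed-tensors stores the reciprocal is taken from the commit message rather than verified here. The constants 6.0 and 448.0 are the largest magnitudes of FP4 E2M1 and FP8 E4M3 respectively.

```python
FP4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
FP8_E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3


def modelopt_global_scale(amax: float) -> float:
    """weight_scale_2 as described above: amax / (6.0 * 448.0) = amax / 2688."""
    return amax / (FP4_E2M1_MAX * FP8_E4M3_MAX)


def ct_on_disk_global_scale(amax: float) -> float:
    """Per the commit message, compressed-tensors stores 1/x on disk."""
    return 1.0 / modelopt_global_scale(amax)


def to_kernel_global_scale(value: float, stored_as_reciprocal: bool) -> float:
    """Return the amax/2688 form; reciprocate only if the disk value is 1/x."""
    return 1.0 / value if stored_as_reciprocal else value


amax = 5.376
from_modelopt = to_kernel_global_scale(
    modelopt_global_scale(amax), stored_as_reciprocal=False
)
from_ct = to_kernel_global_scale(
    ct_on_disk_global_scale(amax), stored_as_reciprocal=True
)
```

Both routes arrive at the same value, which is why the ModelOpt path can rename the tensor without reciprocating it.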
linear.py: add "ModelOptNvFp4W4A16LinearMethod" to
WEIGHT_LOADER_V2_SUPPORTED so the linear layer uses weight_loader_v2
for our params (especially needed for PerTensorScaleParameter on fused
QKV/gate-up; without v2 the legacy loader hits a shape assert).
Validation (case 1, controlled equivalence):
- Native W4A16 load: ModelOpt qwen3-8b W4A16 ckpt via this method ->
6.01 GiB / 2.27 s, FLASHINFER attention, fp8 KV cache, ~57 tok/s
decode on enforce_eager. Outputs coherent.
- CT-converted W4A16 load: same source ckpt, run through the
conversion script, loaded via CompressedTensorsW4A16Fp4. Same
attention backend (FLASHINFER), same KV cache dtype (fp8), same KV
cache slot count (1,051,632), token-for-token identical greedy
completions. Bit-identical layer state via different code routes ->
same FP4 Marlin kernel call -> same output. Two-axes apples-to-apples.
Validation matrix tracked in juhim/w4a16_modelopt_vllm/logs_and_results/log.md.
AI-assisted: prepared with Claude (Anthropic). Human review and
on-machine validation by juhim before any PR.
Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
Refactor only -- same kernel calls, same byte-identical output.
Aligns ModelOptNvFp4W4A16LinearMethod's code shape with the existing W4A4 sibling (ModelOptNvFp4LinearMethod) by going through a kernel-adapter abstraction instead of calling Marlin functions directly:
- __init__: self.kernel = MarlinNvFp4LinearKernel(NvFp4LinearLayerConfig())
- process_weights_after_loading: self.kernel.process_weights_after_loading(layer)
- apply: self.kernel.apply_weights(layer=layer, x=x, bias=bias)
We deliberately direct-instantiate MarlinNvFp4LinearKernel rather than go through init_nvfp4_linear_kernel(): the shared selector's first-pick on this hardware is a cutlass W4A4 kernel that quantizes activations, which would silently break our W4A16 path (no input_scale registered). For W4A16 there is exactly one valid kernel, so we pin it.
Also drops the dead `backend = "marlin"` class attribute. The framework gate at ModelOptQuantConfigBase.get_quant_method only sets marlin_input_dtype when backend == "marlin"; we no longer need that because our adapter calls prepare_fp4_layer_for_marlin without an input_dtype, and that argument only affects an is_a_8bit branch that NVFP4 W4A16 never enters (bf16/fp16 acts always have itemsize > 1, and fp8/int8 acts on NVFP4 weights are explicitly rejected). The vestigial `self.marlin_input_dtype = None` slot is kept to mirror the W4A4 method's __init__ shape.
Imports: drop now-unused apply_fp4_marlin_linear and prepare_fp4_layer_for_marlin from modelopt.py; add MarlinNvFp4LinearKernel and NvFp4LinearLayerConfig.
Re-validated case 1 on qwen3-8b W4A16 ckpt:
- FLASHINFER attention backend, candidate set [FLASHINFER, TRITON_ATTN]
- fp8_e4m3 KV cache, 1,051,632 slot count
- model load 6.01 GiB / 2.18 s
- token-for-token identical greedy completions to the pre-refactor run
Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
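The reason for pinning the kernel rather than asking a shared selector can be shown with stand-ins. Everything below is hypothetical: the two kernel classes, the priority list, and `select_first_supported` only mimic the behavior the commit message describes and are not vLLM's actual selector.

```python
class CutlassW4A4Kernel:
    """Stand-in: quantizes activations, so it needs an input_scale."""
    quantizes_activations = True


class MarlinW4A16Kernel:
    """Stand-in: weight-only kernel, activations stay bf16/fp16."""
    quantizes_activations = False


# Assumed priority order in which the shared selector tries kernels.
PRIORITY = [CutlassW4A4Kernel, MarlinW4A16Kernel]


def select_first_supported(candidates):
    """Hypothetical shared selector: return the first candidate."""
    return candidates[0]


class W4A16MethodSketch:
    def __init__(self) -> None:
        # Pin the only kernel that is valid for W4A16 instead of
        # letting the selector pick an activation-quantizing one.
        self.kernel = MarlinW4A16Kernel()


picked_by_selector = select_first_supported(PRIORITY)
method = W4A16MethodSketch()
```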
The W4A16 method previously didn't register input_scale, on the assumption that vLLM's loader would silently skip on-disk *_proj.input_scale keys when the W4A4-shaped variant of an NVFP4 checkpoint was loaded under this method. That's only true in the qwen2 loader's "else" branch (post-stacked, has an explicit `if name not in params_dict: continue` guard). The "stacked" branch -- which handles q_proj/k_proj/v_proj/gate_proj/up_proj shards -- has no such guard, and unconditionally does `params_dict[name]` after renaming e.g. `q_proj.input_scale` to `qkv_proj.input_scale`. Without a `qkv_proj.input_scale` parameter registered, that lookup raises KeyError and engine init fails.
This trips the moment a user tries to load an NVFP4 (W4A4) checkpoint under the W4A16 method (the eventual phase-2 use case for the --quantization=modelopt_fp4_w4a16 override, and the immediate phase-1 case-3 validation test).
Fix: register a placeholder PerTensorScaleParameter named input_scale in create_weights so the loader can place per-shard input_scale tensors here without KeyError on the merged-name lookup. We discard it in process_weights_after_loading -- W4A16 mode does not quantize activations, so the value is never used. For native W4A16 checkpoints (no input_scale on disk) the placeholder stays uninitialized and is simply deleted; harmless.
Validated end-to-end on qwen3-8b:
- Case 1 (native W4A16): unchanged, ~102 tok/s, FLASHINFER, fp8 KV.
- Case 2 (W4A4 regression): unchanged, ~42 tok/s, existing ModelOptNvFp4LinearMethod path.
- Case 3 (W4A4 ckpt with quant_algo file-edited to NVFP4_W4A16, loaded via this method): now succeeds; outputs token-identical to case 1; logits **bit-identical** to case 1 (max |Δlogprob| = 0 across 47 captured positions, top-20 ranks 100% match), confirming same weight bits -> same Marlin kernel -> same computation.
Side effect: makes the W4A16 method intrinsically robust to either NVFP4 checkpoint shape (W4A4 or W4A16). The eventual ModelOptNvFp4W4A16Config phase-2 override is then pure routing -- the underlying method already handles both shapes correctly.
Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
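The failure mode and the placeholder fix can be reproduced with plain dictionaries. This is a simplified sketch, not the qwen2 loader: `load_stacked` and the `STACKED` mapping are hypothetical stand-ins that only mimic an unguarded `params_dict[name]` lookup after the shard name is rewritten to the fused name.

```python
STACKED = {"q_proj": "qkv_proj", "k_proj": "qkv_proj", "v_proj": "qkv_proj"}


def load_stacked(params_dict, checkpoint):
    """Stand-in for a 'stacked' loader branch with no membership guard."""
    for name, tensor in checkpoint.items():
        for shard, fused in STACKED.items():
            if shard in name:
                merged = name.replace(shard, fused)
                # Unconditional lookup: KeyError if nothing is registered.
                params_dict[merged].append((shard, tensor))
                break


checkpoint = {"q_proj.weight": "wq", "q_proj.input_scale": 0.5}

# Without the placeholder, the merged-name lookup fails.
params_without = {"qkv_proj.weight": []}
try:
    load_stacked(params_without, checkpoint)
    failed = False
except KeyError:
    failed = True

# With a placeholder registered, loading succeeds.
params_with = {"qkv_proj.weight": [], "qkv_proj.input_scale": []}
load_stacked(params_with, checkpoint)
# Mimics discarding the placeholder after loading: the value is never used.
loaded_scale = params_with.pop("qkv_proj.input_scale")
```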
ModelOpt PR NVIDIA/Model-Optimizer#1313 (commit 0fede961 on origin/hungyueh/modelopt-nvfp4-w4a16, "nvfp4_w4a16 -> w4a16_nvfp4") renamed the qformat / on-disk quant_algo string. Six string-literal edits in vllm/.../modelopt.py to match:
- QUANT_ALGOS entry.
- Dispatch in ModelOptNvFp4Config.__init__.
- Error message in __init__.
- 3 docstring / log-warning labels.
No registry change: override_quantization_method's substring check ("NVFP4" in algo or "FP4" in algo) still matches "W4A16_NVFP4" because it contains "NVFP4".
The LinearMethod class name ModelOptNvFp4W4A16LinearMethod is kept as-is -- it describes the concept ("NVFP4 weights, W4A16 mode"), not the on-disk algo string. Renaming the class would touch WEIGHT_LOADER_V2_SUPPORTED in linear.py and add review surface that isn't earned by the on-disk rename.
Smoke-tested both qwen3-8b and Nemotron-Nano-4B after the rename; both still load + generate cleanly. The on-disk ckpt configs were patched in place (only the quant_algo JSON field, no safetensors regen) -- documented in the gitlab notes repo's log.md.
Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
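That the rename needs no registry change follows directly from the substring check quoted above. A small sketch, where `routes_to_nvfp4_config` is a hypothetical wrapper around that expression rather than the actual vLLM method:

```python
def routes_to_nvfp4_config(algo: str) -> bool:
    """Mirror of the substring check quoted in the commit message."""
    return "NVFP4" in algo or "FP4" in algo


# Old and new on-disk algo strings, plus the existing W4A4 name.
names = ["NVFP4", "NVFP4_W4A16", "W4A16_NVFP4"]
results = {name: routes_to_nvfp4_config(name) for name in names}
```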
Two related changes:
1. Make ModelOptNvFp4Config.__init__ args defaultable so existing
tests / callers that construct the config without passing
quant_method (and friends) keep working unchanged. Previously
adding `quant_method` as a required positional arg silently broke
three test sites under tests/{compile,distributed,kernels}/...
that build the config directly to exercise downstream code (eplb,
MLA fusion, MoE layer). Defaults match the W4A4 path, which is
what those tests were exercising:
quant_method: str = "NVFP4"
is_checkpoint_nvfp4_serialized: bool = False
kv_cache_quant_algo: str | None = None
exclude_modules: list[str] | None = None # treated as []
_from_config still passes all five explicitly when loading a real
checkpoint, so the defaults only affect direct constructor users.
2. Add two unit tests under tests/quantization/test_modelopt.py
that exercise the per-algo LinearMethodCls dispatch in
ModelOptNvFp4Config without needing a checkpoint:
- test_modelopt_nvfp4_config_dispatches_w4a4_method
quant_method="NVFP4" -> ModelOptNvFp4LinearMethod
- test_modelopt_nvfp4_config_dispatches_w4a16_method
quant_method="W4A16_NVFP4" -> ModelOptNvFp4W4A16LinearMethod
The W4A16 test asserts both `is` (positive) and `is not` against
the W4A4 sibling so a regression that silently routes a W4A16
checkpoint under the W4A4 method (and then calls the cutlass W4A4
NVFP4 GEMM instead of FP4 Marlin, with no input_scale) would fail
loudly.
Test result locally: 2 passed, ~2.2 s.
Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
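The defaultable constructor and the two dispatch tests can be sketched together. This is an approximation under stated assumptions: the method classes are empty stand-ins, `NvFp4ConfigSketch` is hypothetical and carries only the four arguments listed in the commit message, and `Optional`/`List` are used in place of the `str | None` spelling so the sketch runs on older Python versions.

```python
from typing import List, Optional


class ModelOptNvFp4LinearMethod:
    """Stand-in for the existing W4A4 method."""


class ModelOptNvFp4W4A16LinearMethod:
    """Stand-in for the new W4A16 method."""


class NvFp4ConfigSketch:
    """Hypothetical config: every argument has a W4A4-path default."""

    def __init__(
        self,
        quant_method: str = "NVFP4",
        is_checkpoint_nvfp4_serialized: bool = False,
        kv_cache_quant_algo: Optional[str] = None,
        exclude_modules: Optional[List[str]] = None,
    ) -> None:
        self.quant_method = quant_method
        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
        self.kv_cache_quant_algo = kv_cache_quant_algo
        self.exclude_modules = exclude_modules or []  # None is treated as []
        self.LinearMethodCls = (
            ModelOptNvFp4W4A16LinearMethod
            if quant_method == "W4A16_NVFP4"
            else ModelOptNvFp4LinearMethod
        )


def test_dispatches_w4a4_method():
    # Direct construction with no arguments keeps the W4A4 behavior.
    assert NvFp4ConfigSketch().LinearMethodCls is ModelOptNvFp4LinearMethod


def test_dispatches_w4a16_method():
    cls = NvFp4ConfigSketch(quant_method="W4A16_NVFP4").LinearMethodCls
    assert cls is ModelOptNvFp4W4A16LinearMethod
    # Guard against silently routing W4A16 through the W4A4 method.
    assert cls is not ModelOptNvFp4LinearMethod
```

Asserting both `is` and `is not` is redundant for two distinct classes, but it keeps the intent of the test explicit if the dispatch is later refactored.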
pavanimajety left a comment:
LGTM, thanks @juhi10071998
Hi @pavanimajety, thanks for approving. I saw some compilation unit tests failing, but these seem to be unrelated to this PR. The failing job (pytorch-fullgraph-smoke-test) is tests/compile/fullgraph/test_multiple_graphs.py::test_multi_graph_piecewise_compile, asserting num_graphs_seen=0, expected diff=2 from vllm/compilation/counter.py. This PR's diff is confined to modelopt.py; it doesn't touch torch.compile, Inductor, or graph capture, and it has no import-time side effects. The failing test uses generic torch.compile fixtures (no NVFP4 ckpt), so nothing in this PR's code path executes during that test. Should we bypass this?
…ivations) support (vllm-project#41769) Signed-off-by: Juhi Mittal <juhim@nvidia.com>
…ivations) support (vllm-project#41769) Signed-off-by: Juhi Mittal <juhim@nvidia.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Summary
Add native ModelOpt NVFP4 W4A16 (4-bit NVFP4 weights, fp16/bf16 activations) checkpoint support to vLLM, mirroring the existing FP8 dispatch precedent (`ModelOptFp8Config` selecting between three FP8 LinearMethods on `quant_method`). Checkpoints produced by NVIDIA/Model-Optimizer's W4A16 recipe (NVIDIA/Model-Optimizer#1313) can now be loaded directly and run through the `marlin_fp4` kernel. `quant_method="W4A16_NVFP4"` in `hf_quant_config.json` routes to a new `ModelOptNvFp4W4A16LinearMethod` that registers W4A16 weights/scales, tolerates `input_scale` tensors carried by W4A4 sibling checkpoints (silently ingested-and-discarded), and runs the GEMM through the existing FP4 Marlin kernel adapter.
Why did I have to create input_scale for `ModelOptNvFp4W4A16LinearMethod`?
To consume a pure NVFP4 ckpt (even on an unsupported architecture) once we have correct routing through the CLI (currently we do this by manually updating quant_algo from NVFP4 to W4A16_NVFP4).
The input scale is needed to avoid KeyError issues during weight loading. For instance, in the model definition the `if` branch has no fallback path when the loaded weight has no corresponding parameter key (https://github.com/vllm-project/vllm/blob/22a3cbe1520bc8a3b19ace0abe497c514b3129ea/vllm/model_executor/models/qwen2.py#L485). It may not be trivial to add such a guard to all the models, so I opted to create a placeholder param for input_scale in create_weights so that weight loading proceeds as expected, and then remove the param in process_weights_after_loading.
Validation
Unit tests (in-tree, this PR)
The W4A16 test asserts both `is` (positive) and `is not` (against the W4A4 sibling), so a regression that silently routed a W4A16 checkpoint through the W4A4 method (calling cutlass W4A4 NVFP4 GEMM with no `input_scale`, instead of FP4 Marlin) would fail loudly.
End-to-end validation against real W4A16 ckpts (out-of-tree, manual)
We don't have a public W4A16 NVFP4 ckpt to add as an HF-Hub-downloaded integration test (mirroring the FP8 `pc_pt` / `pb_wo` siblings); see "Test gaps" below. The W4A16 path was instead validated against ModelOpt-generated checkpoints:
1. Dense (Qwen3-8B), W4A4-recipe ckpt loaded as W4A16 vs native W4A16: validate the `input_scale` ingest-and-discard path (commit "tolerate input_scale tensors from W4A4 checkpoints") by loading a W4A4 sibling ckpt with `quant_method` rewritten to `W4A16_NVFP4`.
Strongest possible parity result: same on-disk weights+scales, same LinearMethod, same kernel, same KV-cache config → byte-for-byte computation.
Lint / formatting
Test gaps (acknowledged)
`test_modelopt_fp8_pc_pt_checkpoint_setup` / `test_modelopt_fp8_pb_wo_checkpoint_setup`. We'll add that test in a follow-up PR once such a checkpoint is published.
Follow-up PRs
This PR establishes the load path; two follow-ups are intentionally deferred:
- `hf_quant_config.json:quantization.quant_algo` from `"NVFP4"` to `"W4A16_NVFP4"`. The flag would let users opt a W4A4 ckpt into W4A16 inference at load time without modifying the on-disk config. This PR's "tolerate `input_scale` from W4A4 checkpoints" change (commit ed46c63) was added specifically to validate that flow ahead of the flag landing; it was confirmed bit-identical to native W4A16 on Qwen3-8B (case 2 above). Estimated scope: ~30–50 LOC + a new `ModelOptNvFp4W4A16Config` subclass and registry entry.
- `lm_head`: vLLM's `ParallelLMHead` does not currently route through a quantization `LinearMethod`, so a quantized `lm_head` in a W4A16 NVFP4 checkpoint is not loadable end-to-end without separate plumbing. Out of scope for this PR; will be addressed once `ParallelLMHead`'s dispatch path is extended.
AI assistance disclosure
This PR was developed with AI assistance (Claude). The submitting human reviewed every changed line, drove the validation runs, and reproduced all test results locally. AI-assistance attribution is in each commit's `Co-authored-by:` trailer. DCO sign-off is in place on every commit.
Test plan (unchecked items are non-gating and could be a follow-up)
- `pytest tests/quantization/test_modelopt.py -k config_dispatches -v`)
- `lm_head` end-to-end: gated on `ParallelLMHead` plumbing; deferred to follow-up PR #2
🤖 Generated with Claude Code
Supplementary Changes (3 files, +224 / −6)
vllm/model_executor/layers/quantization/modelopt.py
- `ModelOptNvFp4Config.__init__` made fully defaultable (`quant_method: str = "NVFP4"`, etc.); preserves backward-compatibility for existing direct-instantiation callers under `tests/{compile,distributed,kernels}/...`.
- `quant_method="NVFP4"` → existing `ModelOptNvFp4LinearMethod` (W4A4); `quant_method="W4A16_NVFP4"` → new `ModelOptNvFp4W4A16LinearMethod` (W4A16).
- `ModelOptNvFp4W4A16LinearMethod`: registers `weight` (uint8 packed), `weight_scale` (fp8 group), `weight_scale_2` (fp32 per-tensor), and a placeholder `input_scale` (discarded post-load to absorb tensors found in W4A4-recipe ckpts). `process_weights_after_loading` renames `weight_scale_2 → weight_global_scale` (no reciprocation) and delegates to the existing `MarlinNvFp4LinearKernel` adapter.
vllm/model_executor/layers/linear.py
- Adds `"ModelOptNvFp4W4A16LinearMethod"` to `WEIGHT_LOADER_V2_SUPPORTED` so fused QKV / gate_up loading dispatches through the V2 loader (required for `PerTensorScaleParameter` handling on stacked layers).
tests/quantization/test_modelopt.py