Skip to content

[Quantization] Add ModelOpt NVFP4 W4A16 (4-bit weights, fp16/bf16 activations) support#41769

Merged
pavanimajety merged 19 commits into
vllm-project:mainfrom
juhi10071998:w4a16_modelopt_support
May 9, 2026
Merged

[Quantization] Add ModelOpt NVFP4 W4A16 (4-bit weights, fp16/bf16 activations) support#41769
pavanimajety merged 19 commits into
vllm-project:mainfrom
juhi10071998:w4a16_modelopt_support

Conversation

@juhi10071998

@juhi10071998 juhi10071998 commented May 6, 2026

Copy link
Copy Markdown
Contributor

Summary

Add native ModelOpt NVFP4 W4A16 (4-bit NVFP4 weights, fp16/bf16 activations) checkpoint support to vLLM, mirroring the existing FP8 dispatch precedent (ModelOptFp8Config selecting between three FP8 LinearMethods on quant_method). Checkpoints produced by NVIDIA/Model-Optimizer's W4A16 recipe (NVIDIA/Model-Optimizer#1313) can now be loaded directly and run through the marlin_fp4 kernel.

quant_method="W4A16_NVFP4" in hf_quant_config.json routes to a new ModelOptNvFp4W4A16LinearMethod that registers W4A16 weights/scales, tolerates input_scale tensors carried by W4A4 sibling checkpoints (silently ingested-and-discarded), and runs the GEMM through the existing FP4 Marlin kernel adapter.

Why I had to create input_scale for this ModelOptNvfp4W4A16LinearMethod?
To consume pure nvfp4 ckpt (even on unsupported architecture) once we have correct routing through CLI (currently we do through manually updating quant_algo NVFP4-> W4A16_NVFP4).
The input scale is needed to avoid keyerror issues in create_weights/ weight loading. For instance in model definition since the if branch has no fall back path if the loaded weight does not have the corresponding key- https://github.com/vllm-project/vllm/blob/22a3cbe1520bc8a3b19ace0abe497c514b3129ea/vllm/model_executor/models/qwen2.py#L485], I feel it may not be trivial to add this to all the models, hence I opted for the approach for creating the placeholder param for input_scales in create_weights so weight loading happens as expected and then remove the param in the process_weights_after_loading function.

Validation

Unit tests (in-tree, this PR)

$ pytest tests/quantization/test_modelopt.py -k "config_dispatches" -v
tests/quantization/test_modelopt.py::test_modelopt_nvfp4_config_dispatches_w4a4_method PASSED [ 50%]
tests/quantization/test_modelopt.py::test_modelopt_nvfp4_config_dispatches_w4a16_method PASSED [100%]
================= 2 passed, 3 deselected, 16 warnings in 2.18s =================

The W4A16 test asserts both is (positive) and is not (against W4A4 sibling), so a regression that silently routed a W4A16 checkpoint through the W4A4 method (calling cutlass W4A4 NVFP4 GEMM with no input_scale, instead of FP4 Marlin) would fail loudly.

End-to-end validation against real W4A16 ckpts (out-of-tree, manual)

We don't have a public W4A16 NVFP4 ckpt to add as an HF-Hub-downloaded integration test (mirroring the FP8 pc_pt / pb_wo siblings) — see "Test gaps" below. The W4A16 path was instead validated against ModelOpt-generated checkpoints:

1. Dense (Qwen3-8B), W4A4-recipe ckpt loaded as W4A16 vs native W4A16: validate the input_scale ingest-and-discard path (commit "tolerate input_scale tensors from W4A4 checkpoints") by loading a W4A4 sibling ckpt with quant_method rewritten to W4A16_NVFP4.

=== overall ===
  global max |Δlogprob|       : 0
  argmax-mismatch positions   : 0   (47/47 positions agree, 100% top-K rank match)
  ✅ Logprobs are BIT-IDENTICAL at every position across the top-K.

Strongest possible parity result: same on-disk weights+scales, same LinearMethod, same kernel, same KV-cache config → byte-for-byte computation.

Lint / formatting

$ pre-commit run --files vllm/model_executor/layers/quantization/modelopt.py \
                          vllm/model_executor/layers/linear.py \
                          tests/quantization/test_modelopt.py
ruff check..........Passed
ruff format.........Passed
typos...............Passed
mypy-3.10...........Passed
SPDX headers........Passed
[all hooks passed]

Test gaps (acknowledged)

  • No public W4A16 NVFP4 checkpoint exists yet to add an HF-Hub-downloaded integration test mirroring test_modelopt_fp8_pc_pt_checkpoint_setup / test_modelopt_fp8_pb_wo_checkpoint_setup. We'll add that test in a follow-up PR once such a checkpoint is published.

Follow-up PRs

This PR establishes the load path; two intentionally-deferred follow-ups:

  1. Expose a CLI flag to load W4A4 ckpts through the W4A16 method — avoids the current manual workaround of rewriting hf_quant_config.json:quantization.quant_algo from "NVFP4" to "W4A16_NVFP4". The flag would let users opt a W4A4 ckpt into W4A16 inference at load time without modifying the on-disk config. This PR's "tolerate input_scale from W4A4 checkpoints" change (commit ed46c63) was added specifically to validate that flow ahead of the flag landing — confirmed bit-identical to native W4A16 on Qwen3-8B (case 2 above). Estimated scope: ~30–50 LOC + a new ModelOptNvFp4W4A16Config subclass and registry entry.
  2. Extend W4A16 to cover lm_head — vLLM's ParallelLMHead does not currently route through a quantization LinearMethod, so a quantized lm_head in a W4A16 NVFP4 checkpoint is not loadable end-to-end without separate plumbing. Out of scope for this PR; will be addressed once ParallelLMHead's dispatch path is extended.

AI assistance disclosure

This PR was developed with AI assistance (Claude). The submitting human reviewed every changed line, drove the validation runs, and reproduced all test results locally. AI-assistance attribution is in each commit's Co-authored-by: trailer. DCO sign-off is in place on every commit.

Test plan/ unchecked ones are non-gating and could be a follow-up

  • Unit tests pass (pytest tests/quantization/test_modelopt.py -k config_dispatches -v)
  • Pre-commit clean on touched files
  • End-to-end load + generation validated on Qwen3-8B (modelopt-direct vs CT-bridged, argmax-stable)
  • End-to-end load + generation validated on Qwen3-8B (W4A4-as-W4A16 vs native W4A16, bit-identical)
  • Public-checkpoint integration test — deferred until a W4A16 NVFP4 ckpt is published to HF Hub
  • CLI flag to route W4A4 ckpts through the W4A16 method — deferred to follow-up PR Fix a bug in tying OPT embeddings #1 (see "Follow-up PRs")
  • Quantized lm_head end-to-end — gated on ParallelLMHead plumbing; deferred to follow-up PR Support tensor parallel #2

🤖 Generated with Claude Code

Supplementary Changes (3 files, +224 / −6)

  • vllm/model_executor/layers/quantization/modelopt.py
    • ModelOptNvFp4Config.__init__ made fully defaultable (quant_method: str = "NVFP4", etc.) — preserves backward-compatibility for existing direct-instantiation callers under tests/{compile,distributed,kernels}/....
    • Per-algo dispatch: quant_method="NVFP4" → existing ModelOptNvFp4LinearMethod (W4A4); quant_method="W4A16_NVFP4" → new ModelOptNvFp4W4A16LinearMethod (W4A16).
    • New ModelOptNvFp4W4A16LinearMethod: registers weight (uint8 packed), weight_scale (fp8 group), weight_scale_2 (fp32 per-tensor), and a placeholder input_scale (discarded post-load to absorb tensors found in W4A4-recipe ckpts). process_weights_after_loading renames weight_scale_2 → weight_global_scale (no reciprocation) and delegates to the existing MarlinNvFp4LinearKernel adapter.
  • vllm/model_executor/layers/linear.py
    • Adds "ModelOptNvFp4W4A16LinearMethod" to WEIGHT_LOADER_V2_SUPPORTED so fused QKV / gate_up loading dispatches through the V2 loader (required for PerTensorScaleParameter handling on stacked layers).
  • tests/quantization/test_modelopt.py
    • Two dispatch unit tests guarding against silent-W4A4-on-W4A16 regression — see "Validation".

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for ModelOpt NVFP4 W4A16 quantization, enabling 4-bit weights with 16-bit activations. The changes include the implementation of the ModelOptNvFp4W4A16LinearMethod class, which utilizes the Marlin NVFP4 kernel, and updates to ModelOptNvFp4Config to handle dispatching between W4A4 and W4A16 methods. Additionally, unit tests were added to verify the configuration routing. Feedback was provided to address a potential accuracy issue during weight processing where differing global scales across fused layers were not being correctly compensated for via group scale rescaling.

Comment thread vllm/model_executor/layers/quantization/modelopt.py
@juhi10071998 juhi10071998 marked this pull request as draft May 6, 2026 01:07
@juhi10071998 juhi10071998 marked this pull request as ready for review May 6, 2026 02:05
Adds first-class loading for ModelOpt-exported NVFP4_W4A16 checkpoints
(`quant_algo: "NVFP4_W4A16"`). Today vLLM can only consume such ckpts
after rewriting them into the compressed-tensors format on disk; this
change lets the ModelOpt loader feed the FP4 Marlin GEMM directly,
without an on-disk conversion.

Plumbing (no new config class):
- `QUANT_ALGOS`: register `"NVFP4_W4A16"`. Existing
  `ModelOptNvFp4Config.override_quantization_method` substring check
  (`"NVFP4" in algo or "FP4" in algo`) already routes it to the same
  config class as `"NVFP4"` -- mirrors the established FP8 pattern in
  this file where one ModelOptFp8Config dispatches to three FP8
  LinearMethods based on the algo string.
- `ModelOptNvFp4Config.__init__` now takes `quant_method` and selects
  `self.LinearMethodCls` per algo:
    NVFP4         -> ModelOptNvFp4LinearMethod        (existing W4A4)
    NVFP4_W4A16   -> ModelOptNvFp4W4A16LinearMethod   (new)
- `_from_config` threads `quant_method` to the constructor.

New class `ModelOptNvFp4W4A16LinearMethod`:
- Loads ModelOpt-style names directly (no on-disk renames):
    weight          uint8     packed NVFP4
    weight_scale    fp8-e4m3  per 16-elem group along input dim
    weight_scale_2  fp32      per-tensor global = amax / (6.0 * 448.0)
- process_weights_after_loading: rename weight_scale_2 ->
  weight_global_scale **without reciprocation**. ModelOpt already stores
  amax/2688 which is the form Marlin's
  nvfp4_marlin_process_global_scale consumes; the CT W4A16 path
  reciprocates only because CT stores 1/x on disk. Then call
  prepare_fp4_layer_for_marlin(layer).
- apply: dispatches to apply_fp4_marlin_linear -- same call as
  CompressedTensorsW4A16Fp4.

linear.py: add "ModelOptNvFp4W4A16LinearMethod" to
WEIGHT_LOADER_V2_SUPPORTED so the linear layer uses weight_loader_v2
for our params (especially needed for PerTensorScaleParameter on fused
QKV/gate-up; without v2 the legacy loader hits a shape assert).

Validation (case 1, controlled equivalence):
- Native W4A16 load: ModelOpt qwen3-8b W4A16 ckpt via this method ->
  6.01 GiB / 2.27 s, FLASHINFER attention, fp8 KV cache, ~57 tok/s
  decode on enforce_eager. Outputs coherent.
- CT-converted W4A16 load: same source ckpt, run through the
  conversion script, loaded via CompressedTensorsW4A16Fp4. Same
  attention backend (FLASHINFER), same KV cache dtype (fp8), same KV
  cache slot count (1,051,632), token-for-token identical greedy
  completions. Bit-identical layer state via different code routes ->
  same FP4 Marlin kernel call -> same output. Two-axes apples-to-apples.

Validation matrix tracked in juhim/w4a16_modelopt_vllm/logs_and_results/log.md.

AI-assisted: prepared with Claude (Anthropic). Human review and
on-machine validation by juhim before any PR.

Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
Refactor only -- same kernel calls, same byte-identical output. Aligns
ModelOptNvFp4W4A16LinearMethod's code shape with the existing W4A4
sibling (ModelOptNvFp4LinearMethod) by going through a kernel-adapter
abstraction instead of calling Marlin functions directly:

  __init__:                self.kernel = MarlinNvFp4LinearKernel(NvFp4LinearLayerConfig())
  process_weights_after_loading:  self.kernel.process_weights_after_loading(layer)
  apply:                   self.kernel.apply_weights(layer=layer, x=x, bias=bias)

We deliberately direct-instantiate MarlinNvFp4LinearKernel rather than
go through init_nvfp4_linear_kernel(): the shared selector's first-pick
on this hardware is a cutlass W4A4 kernel that quantizes activations,
which would silently break our W4A16 path (no input_scale registered).
For W4A16 there is exactly one valid kernel, so we pin it.

Also drops the dead `backend = "marlin"` class attribute. The framework
gate at ModelOptQuantConfigBase.get_quant_method only sets
marlin_input_dtype when backend == "marlin"; we no longer need that
because our adapter calls prepare_fp4_layer_for_marlin without an
input_dtype, and that argument only affects an is_a_8bit branch that
NVFP4 W4A16 never enters (bf16/fp16 acts always have itemsize > 1, and
fp8/int8 acts on NVFP4 weights are explicitly rejected). The vestigial
`self.marlin_input_dtype = None` slot is kept to mirror the W4A4
method's __init__ shape.

Imports: drop now-unused apply_fp4_marlin_linear and
prepare_fp4_layer_for_marlin from modelopt.py; add
MarlinNvFp4LinearKernel and NvFp4LinearLayerConfig.

Re-validated case 1 on qwen3-8b W4A16 ckpt:
- FLASHINFER attention backend, candidate set [FLASHINFER, TRITON_ATTN]
- fp8_e4m3 KV cache, 1,051,632 slot count
- model load 6.01 GiB / 2.18 s
- token-for-token identical greedy completions to the pre-refactor run

Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
The W4A16 method previously didn't register input_scale, on the
assumption that vLLM's loader would silently skip on-disk
*_proj.input_scale keys when the W4A4-shaped variant of a NVFP4
checkpoint was loaded under this method. That's only true in the
qwen2 loader's "else" branch (post-stacked, has an explicit
`if name not in params_dict: continue` guard).

The "stacked" branch -- which handles q_proj/k_proj/v_proj/gate_proj/up_proj
shards -- has no such guard, and unconditionally does
`params_dict[name]` after renaming e.g. `q_proj.input_scale` to
`qkv_proj.input_scale`. Without an `qkv_proj.input_scale` parameter
registered, that lookup KeyErrors and engine init fails.

This trips the moment a user tries to load a NVFP4 (W4A4) checkpoint
under the W4A16 method (the eventual phase-2 use case for the
--quantization=modelopt_fp4_w4a16 override, and the immediate
phase-1 case-3 validation test).

Fix: register a placeholder PerTensorScaleParameter named input_scale
in create_weights so the loader can place per-shard input_scale
tensors here without KeyError on the merged-name lookup. We discard
it in process_weights_after_loading -- W4A16 mode does not quantize
activations, so the value is never used. For native W4A16
checkpoints (no input_scale on disk) the placeholder stays
uninitialized and is simply deleted; harmless.

Validated end-to-end on qwen3-8b:
- Case 1 (native W4A16):  unchanged, ~102 tok/s, FLASHINFER, fp8 KV.
- Case 2 (W4A4 regression): unchanged, ~42 tok/s, existing
  ModelOptNvFp4LinearMethod path.
- Case 3 (W4A4 ckpt with quant_algo file-edited to NVFP4_W4A16,
  loaded via this method): now succeeds; outputs token-identical to
  case 1; logits **bit-identical** to case 1 (max |Δlogprob| = 0
  across 47 captured positions, top-20 ranks 100% match), confirming
  same weight bits -> same Marlin kernel -> same computation.

Side effect: makes the W4A16 method intrinsically robust to either
NVFP4 checkpoint shape (W4A4 or W4A16). The eventual
ModelOptNvFp4W4A16Config phase-2 override is then pure routing -- the
underlying method already handles both shapes correctly.

Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
ModelOpt PR vllm-project#1313 (commit 0fede961 on
origin/hungyueh/modelopt-nvfp4-w4a16, "nvfp4_w4a16 -> w4a16_nvfp4")
renamed the qformat / on-disk quant_algo string. Six string-literal
edits in vllm/.../modelopt.py to match:

- QUANT_ALGOS entry.
- Dispatch in ModelOptNvFp4Config.__init__.
- Error message in __init__.
- 3 docstring / log-warning labels.

No registry change: override_quantization_method's substring check
("NVFP4" in algo or "FP4" in algo) still matches "W4A16_NVFP4"
because it contains "NVFP4".

The LinearMethod class name ModelOptNvFp4W4A16LinearMethod is kept
as-is -- it describes the concept ("NVFP4 weights, W4A16 mode"),
not the on-disk algo string. Renaming the class would touch
WEIGHT_LOADER_V2_SUPPORTED in linear.py and add review surface that
isn't earned by the on-disk rename.

Smoke-tested both qwen3-8b and Nemotron-Nano-4B after the rename;
both still load + generate cleanly. The on-disk ckpt configs were
patched in place (only the quant_algo JSON field, no safetensors
regen) -- documented in the gitlab notes repo's log.md.

Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
Two related changes:

1. Make ModelOptNvFp4Config.__init__ args defaultable so existing
   tests / callers that construct the config without passing
   quant_method (and friends) keep working unchanged. Previously
   adding `quant_method` as a required positional arg silently broke
   three test sites under tests/{compile,distributed,kernels}/...
   that build the config directly to exercise downstream code (eplb,
   MLA fusion, MoE layer). Defaults match the W4A4 path, which is
   what those tests were exercising:

       quant_method: str = "NVFP4"
       is_checkpoint_nvfp4_serialized: bool = False
       kv_cache_quant_algo: str | None = None
       exclude_modules: list[str] | None = None  # treated as []

   _from_config still passes all five explicitly when loading a real
   checkpoint, so the defaults only affect direct constructor users.

2. Add two unit tests under tests/quantization/test_modelopt.py
   that exercise the per-algo LinearMethodCls dispatch in
   ModelOptNvFp4Config without needing a checkpoint:

   - test_modelopt_nvfp4_config_dispatches_w4a4_method
       quant_method="NVFP4" -> ModelOptNvFp4LinearMethod
   - test_modelopt_nvfp4_config_dispatches_w4a16_method
       quant_method="W4A16_NVFP4" -> ModelOptNvFp4W4A16LinearMethod

   The W4A16 test asserts both `is` (positive) and `is not` against
   the W4A4 sibling so a regression that silently routes a W4A16
   checkpoint under the W4A4 method (and then calls the cutlass W4A4
   NVFP4 GEMM instead of FP4 Marlin, with no input_scale) would fail
   loudly.

Test result locally: 2 passed, ~2.2 s.

Co-authored-by: Claude
Signed-off-by: Juhi Mittal <juhim@nvidia.com>
@juhi10071998 juhi10071998 force-pushed the w4a16_modelopt_support branch from 36109e2 to 98e34ae Compare May 6, 2026 16:28
@pavanimajety pavanimajety added the verified Run pre-commit for new contributors without triggering other tests label May 6, 2026

@pavanimajety pavanimajety left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @juhi10071998

@pavanimajety pavanimajety added the ready ONLY add when PR is ready to merge/full CI is needed label May 6, 2026
@juhi10071998

Copy link
Copy Markdown
Contributor Author

hi @pavanimajety , thanks for approving. I saw some compilation unit tests failing but these seems to be unrelated to this PR.

The failing job (pytorch-fullgraph-smoke-test) is tests/compile/fullgraph/test_multiple_graphs.py::test_multi_graph_piecewise_compile, asserting num_graphs_seen=0, expected diff=2 from vllm/compilation/counter.py. This PR's diff is confined to modelopt.py and it doesn't touch torch.compile, Inductor, or graph capture, and it has no import-time side effects. The test that's failing uses generic torch.compile fixtures (no NVFP4 ckpt), so nothing in this PR's code path executes during that test. should we bypass this?

@pavanimajety pavanimajety enabled auto-merge (squash) May 8, 2026 22:06
@pavanimajety pavanimajety merged commit 7a2b596 into vllm-project:main May 9, 2026
70 checks passed
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request May 11, 2026
…ivations) support (vllm-project#41769)

Signed-off-by: Juhi Mittal <juhim@nvidia.com>
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
…ivations) support (vllm-project#41769)

Signed-off-by: Juhi Mittal <juhim@nvidia.com>
mfylcek pushed a commit to mfylcek/vllm that referenced this pull request May 19, 2026
…ivations) support (vllm-project#41769)

Signed-off-by: Juhi Mittal <juhim@nvidia.com>
jhu960213 pushed a commit to jhu960213/vllm that referenced this pull request May 20, 2026
…ivations) support (vllm-project#41769)

Signed-off-by: Juhi Mittal <juhim@nvidia.com>
mvanhorn pushed a commit to mvanhorn/vllm that referenced this pull request Jun 4, 2026
…ivations) support (vllm-project#41769)

Signed-off-by: Juhi Mittal <juhim@nvidia.com>
Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
knight0528 pushed a commit to knight0528/vllm that referenced this pull request Jun 8, 2026
…ivations) support (vllm-project#41769)

Signed-off-by: Juhi Mittal <juhim@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed verified Run pre-commit for new contributors without triggering other tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants