Skip to content

b12x nvfp4 w4a16 use a16 fix#43929

Draft
meena-at-work wants to merge 5 commits into
vllm-project:mainfrom
meena-at-work:meenakshiv/b12x-nvfp4-w4a16-use-a16-fix
Draft

b12x nvfp4 w4a16 use a16 fix#43929
meena-at-work wants to merge 5 commits into
vllm-project:mainfrom
meena-at-work:meenakshiv/b12x-nvfp4-w4a16-use-a16-fix

Conversation

@meena-at-work
Copy link
Copy Markdown
Contributor

Summary

make_nvfp4_moe_quant_config only routes W4A16 to nvfp4_w4a16_moe_quant_config when backend == NvFp4MoeBackend.MARLIN. With FlashInferB12xExperts now accepting (kNvfp4Static, None) (per #43332 / #43341), W4A16 checkpoints reach make_nvfp4_moe_quant_config with the activation scales still loaded — real calibrated values (for modelopt) or uninitialized torch.empty() (for compressed-tensors). The function then either:

  • modelopt: silently dispatches the W4A4 branch with calibrated activation scales — running W4A4 inference on a W4A16-labeled checkpoint (silent correctness bug; no crash);
  • compressed-tensors: computes 1.0 / uninitialized_tensor, then fails FlashInfer's kernel-side source_format='compressed_tensors' requires quant_mode='w4a16' check.

Changes

  • make_nvfp4_moe_quant_config: route to nvfp4_w4a16_moe_quant_config when a13_scale is None, regardless of backend (previously MARLIN-only). Type hints on a13_scale / a2_scale widened to Tensor | None.
  • ModelOptNvFp4FusedMoE and CompressedTensorsW4A4Nvfp4MoEMethod: when self.use_a16, pass a13_scale=None / a2_scale=None at both call sites (process_weights_after_loading + get_fused_moe_quant_config). CT MoE also stores use_a16 on self (the constructor argument was previously unstored).
  • prepare_nvfp4_moe_layer_for_fi_or_cutlass: skip .max() reductions on None activation scales.
  • FlashInferB12xExperts.__init__: assert on weight_quant_dtype == "nvfp4" instead of quant_dtype (which is None for W4A16).
  • New no-GPU regression test: tests/kernels/moe/test_w4a16_propagation.py.

Dependencies

This PR is stacked on #43332 (B12x _supports_quant_scheme accepts W4A16 + apply() fc2_input_scale fallback) and #43341 (activation_precision + source_format plumbing). Both deps are currently included in this branch as commits 1–4; the actual fix is the final commit. Once #43332 and #43341 merge, this PR rebases to a clean diff against main.

Marked Draft until those dependencies land.

Duplicate-work check

No open PR addresses the make_nvfp4_moe_quant_config routing or the modelopt / CT use_a16 propagation. The neighboring PRs (#43332, #43341, #43333) extend B12x's acceptance of W4A16 but leave the routing-and-propagation gap unresolved — without this PR, even with those merged, modelopt W4A16 silently runs W4A4 and compressed-tensors W4A16 crashes.

Test plan

Unit (no GPU):

pytest tests/kernels/moe/test_w4a16_propagation.py -v

3 tests covering W4A16 B12x / W4A4 B12x / W4A16 MARLIN routing — all pass.

E2E on SM121 (B12x) with #43332+#43341 applied:

Verified end-to-end on representative W4A16 NVFP4 MoE checkpoints — both modelopt (quant_algo=W4A16_NVFP4) and compressed-tensors (nvfp4-pack-quantized). Each loads, compiles, captures CUDA graphs, and generates sensible output on B12x. With FLASHINFER_LOGLEVEL=3, every B12xMoEWrapper.__init__ call per run reports activation_precision='bf16' and the expected source_format for the checkpoint origin ('modelopt' vs 'compressed_tensors').

AI assistance disclosure

AI assistance was used for commit-message drafting and resolving conflicts when rebasing onto #43332 / #43341. Every changed line was reviewed locally by the submitter, and the unit + E2E tests were run end-to-end before opening this PR.

ECMGit and others added 5 commits May 28, 2026 21:41
… supports check

`FlashInferB12xExperts._supports_quant_scheme` currently requires the
activation key to be `kNvfp4Dynamic`, which makes the dispatcher reject
every W4A16 NVFP4 checkpoint (activation_key == None) -- e.g.
`nvidia/Qwen3.6-35B-A3B-2.06GB-per-token`. This forces such checkpoints
onto Marlin even though the b12x kernel itself is W4A16-compatible.

Per the class docstring: "Input quantization (BF16->FP4) is performed
inside the kernel so BF16 hidden states are passed directly." -- i.e.
the kernel already handles the BF16-activation case correctly. This
change only loosens the metadata gate; no kernel-side changes.

PR vllm-project#42566 ("W4A16 NVFP4 fused MoE + mixed-precision dispatch") only
touches quantization/modelopt.py and acknowledges the gap in its own
commit message: "their _supports_quant_scheme requires
(kNvfp4Static, kNvfp4Dynamic) exactly... only Marlin survives." That PR
deliberately routes W4A16 to Marlin as a workaround; this PR is the
fix on the b12x side. The two are complementary and can land
independently -- once both land, W4A16 NVFP4 prefers b12x (fast path).

Failure mode without this PR:
  ValueError: NvFp4 MoE backend 'FLASHINFER_B12X' does not support the
  deployment configuration since kernel does not support quantization
  scheme QuantKey(u8, scale(f8e4m3fn, static, GroupShape(row=1, col=16)),
  scale2(f32, static, per_tensor), symmetric) x None.

Tested on DGX Spark (GB10, sm_121a) with vllm/vllm-openai:nightly-aarch64
+ this PR + the FP8-backend-env companion PR.
Model: nvidia/Qwen3.6-35B-A3B-2.06GB-per-token (modelopt-native,
mixed NVFP4 + FP8 experts).
aiperf K=3 AL=3.12, BS=1, ISL=2048+32K prefix=34,831, OSL=1024,
60 measured + 10 warmup, 0 errors:
  Output Token Throughput        : 91.00 tok/s
  Output Token Throughput / user : 97.42 tok/s/user
  TTFT                           : 746.81 ms
  ITL                            : 10.27 ms
  Request Latency                : 11,249.37 ms
  MTP acceptance length          : 3.15 (target 3.12)

For reference on the same workload:
  Marlin (current W4A16 fallback) : OTT 92.26, TTFT 798.72
  b12x on dgx-fork (matched cubin): OTT 95.15, TTFT 758.90

Without this change b12x rejects the checkpoint at engine init; with it
b12x runs and matches/exceeds Marlin on the b12x-fast path.

Signed-off-by: Junhao Shen <junshen@nvidia.com>
…hout activation quant metadata

Addresses review feedback on the preceding commit (supports-check loosening
for W4A16).

`FlashInferB12xExperts.apply` previously asserted `self.a2_gscale is not None`
unconditionally. For W4A16 NVFP4 checkpoints lacking static
activation-quant metadata (e.g. compressed-tensors W4A16-CT layouts),
`a2_gscale` is legitimately None and the assert fires at the first
forward pass -- strictly worse than the engine-init rejection we just
removed at the dispatcher gate.

`process_weights_after_loading` already tolerates `a2_gscale is None`
(the `if self.a2_gscale is not None: ...` guard at the top of this file),
so the assert is the inconsistency. The b12x kernel performs dynamic
per-block FC2-input quantization internally, so a uniform 1.0 scale per
expert is semantically equivalent to the bake-in done for static-quant
checkpoints. Construct the default in apply() instead of asserting.

Signed-off-by: Junhao Shen <junshen@nvidia.com>
Extends FlashInferB12xExperts to handle both:
  - W4A4 NVFP4 from modelopt checkpoints (existing path)
  - W4A16 NVFP4 from compressed-tensors `nvfp4-pack-quantized`

Cherry-picked from PR vllm-project#43341 (vllm-project#43341)
onto current main + vllm-project#43332. Conflicts in modelopt.py, oracle/nvfp4.py,
compressed_tensors_moe/compressed_tensors_moe_w4a4_nvfp4.py,
fused_moe/config.py, vocab_parallel_embedding.py, and
flashinfer_b12x_moe.py were resolved to take vllm-project#43341's intent while keeping
unrelated upstream changes intact (gemm1_clamp_limit / swiglu_limit).

Co-authored-by: Claude
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Cherry-picked from PR vllm-project#43341 (vllm-project#43341).

Signed-off-by: Andrii Skliar <askliar@nvidia.com>
…6_moe_quant_config

The W4A16 plumbing added in vllm-project#43332/vllm-project#43341 makes FlashInferB12xExperts
accept (kNvfp4Static, None) checkpoints, but the upstream activation
scales loaded by ModelOptNvFp4FusedMoE and CompressedTensorsW4A4Nvfp4MoEMethod
still flow into make_nvfp4_moe_quant_config as a13_scale/a2_scale. The
function then either (modelopt) silently dispatches the W4A4 branch with
real calibrated activation scales — running W4A4 on a W4A16-labeled
checkpoint — or (compressed-tensors) divides by uninitialized
torch.empty() memory and trips FlashInfer's source_format/quant_mode
mismatch check kernel-side.

Changes:
- make_nvfp4_moe_quant_config: route to nvfp4_w4a16_moe_quant_config
  when a13_scale is None, regardless of backend (was MARLIN-only).
  Type hints widened to Tensor | None on a13_scale/a2_scale.
- ModelOptNvFp4FusedMoE and CompressedTensorsW4A4Nvfp4MoEMethod: when
  self.use_a16, pass a13_scale=None / a2_scale=None at both call sites.
- prepare_nvfp4_moe_layer_for_fi_or_cutlass: skip .max() reductions on
  None activation scales; type hints widened.
- FlashInferB12xExperts.__init__: assert on weight_quant_dtype ('nvfp4'
  for both W4A4 and W4A16) rather than quant_dtype.

Tests:
- tests/kernels/moe/test_w4a16_propagation.py — no-GPU unit test.

Empirical verification with FLASHINFER_LOGLEVEL=3 on top of vllm-project#43332+vllm-project#43341:
40/40 B12xMoEWrapper.__init__ calls report activation_precision='bf16'
on both nvidia/Qwen3.6-35B-A3B-2.06GB-per-token (modelopt W4A16_NVFP4)
and nvidia/Qwen3.6-35B-A3B-2.06GB-per-token-CT (compressed-tensors W4A16).

Depends on: vllm-project#43332, vllm-project#43341.

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
@meena-at-work meena-at-work changed the title Meenakshiv/b12x nvfp4 w4a16 use a16 fix b12x nvfp4 w4a16 use a16 fix May 28, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Jun 3, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @meena-at-work.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

2 participants