Skip to content

[B12x][NVFP4] Add W4A16 support to FlashInfer SM12x MoE wrapper#43341

Open
askliar wants to merge 2 commits into
vllm-project:mainfrom
askliar:feat/b12x_w4a16
Open

[B12x][NVFP4] Add W4A16 support to FlashInfer SM12x MoE wrapper#43341
askliar wants to merge 2 commits into
vllm-project:mainfrom
askliar:feat/b12x_w4a16

Conversation

@askliar
Copy link
Copy Markdown
Contributor

@askliar askliar commented May 21, 2026

Purpose

Adds W4A16 NVFP4 support to the FlashInfer SM12x B12x MoE wrapper (FlashInferB12xExperts) for compressed-tensors nvfp4-pack-quantized checkpoints, plus the supporting plumbing the kernel needs.

This is part 1/3 of splitting #43333 (the combined B12x + Nemotron-3.5 + Qwen3.5 PR). Parts 2/3 and 3/3 are filed separately so each can be reviewed independently.

Key changes

  • FlashInferB12xExperts (vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py)

    • activation_precision plumbing ("fp4" / "bf16") auto-detected from quant_config.a1_gscale; no env-var flag.
    • _supports_quant_scheme accepts the compressed-tensors W4A16 NVFP4 key shape (uint8 weights + kNvfp4StaticGroupScale + kStaticTensorScale, no activation key).
    • source_format ("modelopt" / "compressed_tensors") read from FusedMoEQuantConfig and forwarded to B12xMoEWrapper so the kernel can interpret the FP4 byte payload correctly.
    • MMA-layout views (w1_sf_mma / w2_sf_mma) precomputed in process_weights_after_loading.
  • FusedMoEQuantConfig (vllm/model_executor/layers/fused_moe/config.py, oracle/nvfp4.py)

    • New source_format field, plumbed through nvfp4_moe_quant_config / nvfp4_w4a16_moe_quant_config / make_nvfp4_moe_quant_config. ModelOptNvFp4FusedMoE sets "modelopt"; CompressedTensorsW4A4Nvfp4MoEMethod sets "compressed_tensors".
  • prepare_nvfp4_moe_layer_for_fi_or_cutlass: bump intermediate_size_per_partition when pad is applied so downstream shape consumers see the padded dimension.

  • prepare_fp4_layer_for_marlin: fall back to torch.get_default_dtype() when layer.params_dtype is absent (ParallelLMHead / VocabParallelEmbedding).

  • VocabParallelEmbedding.weight_loader: reshape (not assert) when AutoQuantize emits scalar FP4 scales (shape () on disk vs (1,) from PerTensorScaleParameter).

  • ModelOptMixedPrecisionConfig.get_quant_method: route ParallelLMHead through the FP8/NVFP4 linear methods (previously fell through to unquantized).

  • FlashInfer bumped to 0.6.11.post3 (docker/versions.json, requirements/cuda.txt).

Relation to existing open PRs (overlap disclosure)

This PR overlaps with several open PRs. Posting anyway as part of the #43333 split — reviewers may want to coordinate landing order:

Unique to this PR (not covered by any of the above):

  • All B12x kernel-side changes in flashinfer_b12x_moe.py (activation_precision, source_format plumbing, MMA-view caching, W4A16 compressed-tensors scheme acceptance).
  • FusedMoEQuantConfig.source_format field and its propagation through nvfp4_moe_quant_config / nvfp4_w4a16_moe_quant_config / make_nvfp4_moe_quant_config.
  • MoE intermediate_size_per_partition padding bump.
  • prepare_fp4_layer_for_marlin params_dtype fallback.

Test Plan

  • .venv/bin/python -m pytest tests/kernels/moe/test_flashinfer_b12x_moe.py -v on SM120 / SM121.
  • End-to-end serve Nemotron-H 3.5 compressed-tensors nvfp4-pack-quantized W4A16 on RTX Pro 6000 / DGX Spark with --moe-backend=flashinfer_b12x.

Fill in actual test results before review per AGENTS.md.

AI assistance

This change was developed with AI assistance (Claude). Every changed line was reviewed locally by the submitter.

🤖 Generated with Claude Code

Extends FlashInferB12xExperts to handle both:
  - W4A4 NVFP4 from modelopt checkpoints (existing path)
  - W4A16 NVFP4 from compressed-tensors `nvfp4-pack-quantized`

Key changes
- FlashInferB12xExperts:
  * activation_precision plumbing ("fp4" / "bf16") auto-detected from
    quant_config.a1_gscale; no env-var flag.
  * _supports_quant_scheme accepts the compressed-tensors W4A16 key
    shape (uint8 weights + kNvfp4StaticGroupScale + kStaticTensorScale,
    no activation key).
  * source_format ("modelopt" / "compressed_tensors") read from
    FusedMoEQuantConfig and forwarded to B12xMoEWrapper.
  * Precomputed MMA-layout views (w1_sf_mma / w2_sf_mma) cached in
    process_weights_after_loading.

- FusedMoEQuantConfig: new source_format field and plumbing through
  nvfp4_moe_quant_config / nvfp4_w4a16_moe_quant_config /
  make_nvfp4_moe_quant_config; ModelOptNvFp4FusedMoE and
  CompressedTensorsW4A4Nvfp4MoEMethod set it explicitly.

- prepare_nvfp4_moe_layer_for_fi_or_cutlass: bump
  intermediate_size_per_partition when pad is applied.

- prepare_fp4_layer_for_marlin: fall back to torch.get_default_dtype()
  when layer.params_dtype is absent (ParallelLMHead / VocabParallelEmbedding).

- VocabParallelEmbedding.weight_loader: reshape (not assert) when
  AutoQuantize emits scalar FP4 scales (shape () vs (1,)).

- ModelOptMixedPrecisionConfig.get_quant_method: route ParallelLMHead
  through the FP8/NVFP4 linear methods (previously fell through to
  unquantized).

- FlashInfer bumped to 0.6.11.post3.

Test plan
- pytest tests/kernels/moe/test_flashinfer_b12x_moe.py -v on SM120/SM121.

AI assistance was used for refactoring and commit-message drafting;
every changed line was reviewed locally by the submitter.

Co-authored-by: Claude
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates FlashInfer to version 0.6.11.post3 and transitions the FlashInfer MoE implementation to use the new B12xMoEWrapper class. It introduces support for the RELU2_NO_MUL activation function used in Nemotron-H models and adds a source_format field to FusedMoEQuantConfig to distinguish between different quantization providers like ModelOpt and compressed-tensors. Additionally, it enables ModelOpt quantization for the ParallelLMHead. Feedback focuses on performance and correctness: the apply method in the FlashInfer expert should pass the output tensor directly to avoid unnecessary copies, the in-place modification of moe_config padding is not idempotent and could lead to shape accumulation across layers, and ParallelLMHead support should be extended to the base ModelOpt quantization class to ensure consistency.

Comment on lines +282 to +294
result = self._wrapper.run(
x=hidden_states,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
num_experts=global_num_experts,
top_k=top_k,
num_local_experts=self.num_local_experts,
output_dtype=self.out_dtype,
output=output,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
)
output.copy_(result)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of apply performs an unnecessary tensor allocation and copy by not passing the output tensor directly to the B12xMoEWrapper.run method. FlashInfer's run method supports an out parameter which allows writing the results directly into the pre-allocated output tensor, significantly improving efficiency in the MoE hot path by avoiding extra memory operations.

Suggested change
result = self._wrapper.run(
x=hidden_states,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
num_experts=global_num_experts,
top_k=top_k,
num_local_experts=self.num_local_experts,
output_dtype=self.out_dtype,
output=output,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
)
output.copy_(result)
self._wrapper.run(
x=hidden_states,
w1_weight=w1,
w1_weight_sf=self.w1_sf_mma,
w1_alpha=self.g1_alphas,
fc2_input_scale=self.a2_gscale,
w2_weight=w2,
w2_weight_sf=self.w2_sf_mma,
w2_alpha=self.g2_alphas,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
out=output,
)

Comment on lines +2362 to +2367
if isinstance(layer, ParallelLMHead):
if quant_algo == "FP8":
return ModelOptFp8LinearMethod(self.fp8_config)
if quant_algo == "NVFP4":
return ModelOptNvFp4LinearMethod(self.nvfp4_config)
return UnquantizedLinearMethod()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While ParallelLMHead support was added here for mixed precision, it appears to be missing in the base ModelOptQuantConfigBase.get_quant_method. This means ParallelLMHead will remain unquantized when using pure ModelOptFp8Config or ModelOptNvFp4Config (which inherit from the base class), even if the checkpoint contains quantized weights for the LM head. Consider adding this check to the base class as well to ensure consistent behavior across all ModelOpt configurations.

Comment on lines +391 to +393
layer.moe_config.intermediate_size_per_partition = (
layer.moe_config.intermediate_size_per_partition + pad_size
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Modifying layer.moe_config.intermediate_size_per_partition in-place is risky because the moe_config object is often shared across all MoE layers in a model. If multiple layers require padding, this addition will accumulate on every call to prepare_nvfp4_moe_layer_for_fi_or_cutlass, leading to incorrect shapes and potential crashes in subsequent layers. This logic should be made idempotent (e.g., by checking if padding was already applied) or the padded size should be handled locally within the expert implementation.

Adds activation_precision in {"fp4", "bf16"} parametrization to both
existing FlashInferB12xExperts tests:
  - test_flashinfer_b12x_moe (SiLU + Mul)
  - test_flashinfer_b12x_moe_relu2 (ReLU2 + non-gated, Nemotron-H 3.5)

For the W4A16 branch, a1_gscale and a2_gscale are passed as None and
source_format is set to "compressed_tensors", matching the production
path used by compressed-tensors `nvfp4-pack-quantized` checkpoints.

Each test now also asserts that FlashInferB12xExperts auto-detects
activation_precision from a1_gscale and reads source_format off the
quant config, locking in the contract added in the parent PR.

Co-authored-by: Claude
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 21, 2026

Hi @askliar, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented May 23, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @askliar.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

1 participant