[B12x][NVFP4] Add W4A16 support to FlashInfer SM12x MoE wrapper#43341
[B12x][NVFP4] Add W4A16 support to FlashInfer SM12x MoE wrapper#43341askliar wants to merge 2 commits into
Conversation
Extends FlashInferB12xExperts to handle both:
- W4A4 NVFP4 from modelopt checkpoints (existing path)
- W4A16 NVFP4 from compressed-tensors `nvfp4-pack-quantized`
Key changes
- FlashInferB12xExperts:
* activation_precision plumbing ("fp4" / "bf16") auto-detected from
quant_config.a1_gscale; no env-var flag.
* _supports_quant_scheme accepts the compressed-tensors W4A16 key
shape (uint8 weights + kNvfp4StaticGroupScale + kStaticTensorScale,
no activation key).
* source_format ("modelopt" / "compressed_tensors") read from
FusedMoEQuantConfig and forwarded to B12xMoEWrapper.
* Precomputed MMA-layout views (w1_sf_mma / w2_sf_mma) cached in
process_weights_after_loading.
- FusedMoEQuantConfig: new source_format field and plumbing through
nvfp4_moe_quant_config / nvfp4_w4a16_moe_quant_config /
make_nvfp4_moe_quant_config; ModelOptNvFp4FusedMoE and
CompressedTensorsW4A4Nvfp4MoEMethod set it explicitly.
- prepare_nvfp4_moe_layer_for_fi_or_cutlass: bump
intermediate_size_per_partition when pad is applied.
- prepare_fp4_layer_for_marlin: fall back to torch.get_default_dtype()
when layer.params_dtype is absent (ParallelLMHead / VocabParallelEmbedding).
- VocabParallelEmbedding.weight_loader: reshape (not assert) when
AutoQuantize emits scalar FP4 scales (shape () vs (1,)).
- ModelOptMixedPrecisionConfig.get_quant_method: route ParallelLMHead
through the FP8/NVFP4 linear methods (previously fell through to
unquantized).
- FlashInfer bumped to 0.6.11.post3.
Test plan
- pytest tests/kernels/moe/test_flashinfer_b12x_moe.py -v on SM120/SM121.
AI assistance was used for refactoring and commit-message drafting;
every changed line was reviewed locally by the submitter.
Co-authored-by: Claude
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
There was a problem hiding this comment.
Code Review
This pull request updates FlashInfer to version 0.6.11.post3 and transitions the FlashInfer MoE implementation to use the new B12xMoEWrapper class. It introduces support for the RELU2_NO_MUL activation function used in Nemotron-H models and adds a source_format field to FusedMoEQuantConfig to distinguish between different quantization providers like ModelOpt and compressed-tensors. Additionally, it enables ModelOpt quantization for the ParallelLMHead. Feedback focuses on performance and correctness: the apply method in the FlashInfer expert should pass the output tensor directly to avoid unnecessary copies, the in-place modification of moe_config padding is not idempotent and could lead to shape accumulation across layers, and ParallelLMHead support should be extended to the base ModelOpt quantization class to ensure consistency.
| result = self._wrapper.run( | ||
| x=hidden_states, | ||
| token_selected_experts=topk_ids.to(torch.int32), | ||
| token_final_scales=topk_weights, | ||
| w1_weight=w1, | ||
| w1_weight_sf=self.w1_sf_mma, | ||
| w1_alpha=self.g1_alphas, | ||
| fc2_input_scale=self.a2_gscale, | ||
| w2_weight=w2, | ||
| w2_weight_sf=self.w2_sf_mma, | ||
| w2_alpha=self.g2_alphas, | ||
| num_experts=global_num_experts, | ||
| top_k=top_k, | ||
| num_local_experts=self.num_local_experts, | ||
| output_dtype=self.out_dtype, | ||
| output=output, | ||
| token_selected_experts=topk_ids.to(torch.int32), | ||
| token_final_scales=topk_weights, | ||
| ) | ||
| output.copy_(result) |
There was a problem hiding this comment.
The current implementation of apply performs an unnecessary tensor allocation and copy by not passing the output tensor directly to the B12xMoEWrapper.run method. FlashInfer's run method supports an out parameter which allows writing the results directly into the pre-allocated output tensor, significantly improving efficiency in the MoE hot path by avoiding extra memory operations.
| result = self._wrapper.run( | |
| x=hidden_states, | |
| token_selected_experts=topk_ids.to(torch.int32), | |
| token_final_scales=topk_weights, | |
| w1_weight=w1, | |
| w1_weight_sf=self.w1_sf_mma, | |
| w1_alpha=self.g1_alphas, | |
| fc2_input_scale=self.a2_gscale, | |
| w2_weight=w2, | |
| w2_weight_sf=self.w2_sf_mma, | |
| w2_alpha=self.g2_alphas, | |
| num_experts=global_num_experts, | |
| top_k=top_k, | |
| num_local_experts=self.num_local_experts, | |
| output_dtype=self.out_dtype, | |
| output=output, | |
| token_selected_experts=topk_ids.to(torch.int32), | |
| token_final_scales=topk_weights, | |
| ) | |
| output.copy_(result) | |
| self._wrapper.run( | |
| x=hidden_states, | |
| w1_weight=w1, | |
| w1_weight_sf=self.w1_sf_mma, | |
| w1_alpha=self.g1_alphas, | |
| fc2_input_scale=self.a2_gscale, | |
| w2_weight=w2, | |
| w2_weight_sf=self.w2_sf_mma, | |
| w2_alpha=self.g2_alphas, | |
| token_selected_experts=topk_ids.to(torch.int32), | |
| token_final_scales=topk_weights, | |
| out=output, | |
| ) |
| if isinstance(layer, ParallelLMHead): | ||
| if quant_algo == "FP8": | ||
| return ModelOptFp8LinearMethod(self.fp8_config) | ||
| if quant_algo == "NVFP4": | ||
| return ModelOptNvFp4LinearMethod(self.nvfp4_config) | ||
| return UnquantizedLinearMethod() |
There was a problem hiding this comment.
While ParallelLMHead support was added here for mixed precision, it appears to be missing in the base ModelOptQuantConfigBase.get_quant_method. This means ParallelLMHead will remain unquantized when using pure ModelOptFp8Config or ModelOptNvFp4Config (which inherit from the base class), even if the checkpoint contains quantized weights for the LM head. Consider adding this check to the base class as well to ensure consistent behavior across all ModelOpt configurations.
| layer.moe_config.intermediate_size_per_partition = ( | ||
| layer.moe_config.intermediate_size_per_partition + pad_size | ||
| ) |
There was a problem hiding this comment.
Modifying layer.moe_config.intermediate_size_per_partition in-place is risky because the moe_config object is often shared across all MoE layers in a model. If multiple layers require padding, this addition will accumulate on every call to prepare_nvfp4_moe_layer_for_fi_or_cutlass, leading to incorrect shapes and potential crashes in subsequent layers. This logic should be made idempotent (e.g., by checking if padding was already applied) or the padded size should be handled locally within the expert implementation.
Adds activation_precision in {"fp4", "bf16"} parametrization to both
existing FlashInferB12xExperts tests:
- test_flashinfer_b12x_moe (SiLU + Mul)
- test_flashinfer_b12x_moe_relu2 (ReLU2 + non-gated, Nemotron-H 3.5)
For the W4A16 branch, a1_gscale and a2_gscale are passed as None and
source_format is set to "compressed_tensors", matching the production
path used by compressed-tensors `nvfp4-pack-quantized` checkpoints.
Each test now also asserts that FlashInferB12xExperts auto-detects
activation_precision from a1_gscale and reads source_format off the
quant config, locking in the contract added in the parent PR.
Co-authored-by: Claude
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
|
Hi @askliar, the pre-commit checks have failed. Please run: uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-filesThen, commit the changes and push to your branch. For future commits, Tip Is
|
|
This pull request has merge conflicts that must be resolved before it can be |
Purpose
Adds W4A16 NVFP4 support to the FlashInfer SM12x B12x MoE wrapper (
FlashInferB12xExperts) for compressed-tensorsnvfp4-pack-quantizedcheckpoints, plus the supporting plumbing the kernel needs.This is part 1/3 of splitting #43333 (the combined B12x + Nemotron-3.5 + Qwen3.5 PR). Parts 2/3 and 3/3 are filed separately so each can be reviewed independently.
Key changes
FlashInferB12xExperts(vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py)activation_precisionplumbing ("fp4"/"bf16") auto-detected fromquant_config.a1_gscale; no env-var flag._supports_quant_schemeaccepts the compressed-tensors W4A16 NVFP4 key shape (uint8weights +kNvfp4StaticGroupScale+kStaticTensorScale, no activation key).source_format("modelopt"/"compressed_tensors") read fromFusedMoEQuantConfigand forwarded toB12xMoEWrapperso the kernel can interpret the FP4 byte payload correctly.w1_sf_mma/w2_sf_mma) precomputed inprocess_weights_after_loading.FusedMoEQuantConfig(vllm/model_executor/layers/fused_moe/config.py,oracle/nvfp4.py)source_formatfield, plumbed throughnvfp4_moe_quant_config/nvfp4_w4a16_moe_quant_config/make_nvfp4_moe_quant_config.ModelOptNvFp4FusedMoEsets"modelopt";CompressedTensorsW4A4Nvfp4MoEMethodsets"compressed_tensors".prepare_nvfp4_moe_layer_for_fi_or_cutlass: bumpintermediate_size_per_partitionwhen pad is applied so downstream shape consumers see the padded dimension.prepare_fp4_layer_for_marlin: fall back totorch.get_default_dtype()whenlayer.params_dtypeis absent (ParallelLMHead/VocabParallelEmbedding).VocabParallelEmbedding.weight_loader: reshape (not assert) when AutoQuantize emits scalar FP4 scales (shape()on disk vs(1,)fromPerTensorScaleParameter).ModelOptMixedPrecisionConfig.get_quant_method: routeParallelLMHeadthrough the FP8/NVFP4 linear methods (previously fell through to unquantized).FlashInfer bumped to
0.6.11.post3(docker/versions.json,requirements/cuda.txt).Relation to existing open PRs (overlap disclosure)
This PR overlaps with several open PRs. Posting anyway as part of the #43333 split — reviewers may want to coordinate landing order:
Bump flashinfer to v0.6.11.post3) — Same version bump indocker/versions.jsonandrequirements/cuda.txt; [CI/Build] Bump flashinfer to v0.6.11.post3 #43251 additionally bumps the two Dockerfiles. If [CI/Build] Bump flashinfer to v0.6.11.post3 #43251 lands first, I'll drop the version-bump hunks here.Accept W4A16 (kNvfp4Static, None)for B12x) — Different W4A16 scheme. [MoE/b12x] Accept W4A16 (kNvfp4Static, None) in FlashInferB12xExperts supports check #43332 accepts(kNvfp4Static, None)for modelopt-native W4A16; this PR accepts the compressed-tensorsnvfp4-pack-quantizedkey shape (uint8weights +kNvfp4StaticGroupScale+kStaticTensorScale). Both could land in any order; together they cover both W4A16 sources.Add LM head quantization support for ModelOpt) — Overlaps on theisinstance(layer, ParallelLMHead)routing inmodelopt.pyand thevocab_parallel_embedding.pyscalar-scale reshape. If Add LM head quantization support for ModelOpt #42124 lands first, those two hunks become redundant and I'll rebase.Support native W4A16 NVFP4 checkpoints) — Samemodelopt.py+vocab_parallel_embedding.pyoverlap as Add LM head quantization support for ModelOpt #42124, plus W4A16 backend selection logic that this PR does not duplicate.Unique to this PR (not covered by any of the above):
flashinfer_b12x_moe.py(activation_precision, source_format plumbing, MMA-view caching, W4A16 compressed-tensors scheme acceptance).FusedMoEQuantConfig.source_formatfield and its propagation throughnvfp4_moe_quant_config/nvfp4_w4a16_moe_quant_config/make_nvfp4_moe_quant_config.intermediate_size_per_partitionpadding bump.prepare_fp4_layer_for_marlinparams_dtypefallback.Test Plan
.venv/bin/python -m pytest tests/kernels/moe/test_flashinfer_b12x_moe.py -von SM120 / SM121.nvfp4-pack-quantizedW4A16 on RTX Pro 6000 / DGX Spark with--moe-backend=flashinfer_b12x.AI assistance
This change was developed with AI assistance (Claude). Every changed line was reviewed locally by the submitter.
🤖 Generated with Claude Code