Integrate flashinfer b12x MoE and FP4 GEMM kernels for SM120/121#40082
meena-at-work wants to merge 3 commits into vllm-project:main
Conversation
Code Review
This pull request introduces support for the FlashInfer CuteDSL fused MoE kernel on SM12x architectures, adding the FlashInferCuteDSLSM12xExperts class and necessary backend configurations. Feedback identifies several issues: in-place modification of a2_gscale causes permanent side effects on model parameters, and performing MMA layout conversions during every forward pass adds unnecessary overhead. Additionally, suggestions were made to ensure correct data types for activation scale placeholders and routing weights when interfacing with the FlashInfer kernel.
```python
# multiplied by values that large. Force to 1.0 so the kernel uses
# its own per-block dynamic scale — matching the unit-test convention.
if self.a2_gscale is not None:
    self.a2_gscale.fill_(1.0)
```
Modifying self.a2_gscale in-place using fill_(1.0) has a permanent side effect on the model's quantization parameters in memory. If the model is subsequently used with a different MoE backend (e.g., during profiling, testing, or if the backend is changed via configuration), the results will be incorrect because the original calibrated scales have been destroyed. It is safer to pass a tensor of ones directly to the kernel in the apply method or create a local copy if modification is necessary.
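To make the hazard concrete, here is a minimal sketch in plain Python (lists stand in for tensors; the `Experts` class and method names are illustrative, not the PR's API):

```python
# Minimal sketch of the in-place side-effect hazard. Plain Python lists
# stand in for tensors; all names here are illustrative only.
class Experts:
    def __init__(self, a2_gscale):
        # Calibrated quantization scale: shared model state.
        self.a2_gscale = a2_gscale

    def apply_inplace(self):
        # Mirrors fill_(1.0): overwrites the calibrated scale for every
        # later user of this parameter (other backends, profiling runs).
        for i in range(len(self.a2_gscale)):
            self.a2_gscale[i] = 1.0
        return self.a2_gscale

    def apply_safe(self):
        # Pass a local tensor of ones to the kernel instead;
        # the calibrated value survives.
        return [1.0] * len(self.a2_gscale)


experts = Experts(a2_gscale=[0.37, 0.41])
experts.apply_safe()
print(experts.a2_gscale)   # [0.37, 0.41] -- original preserved

experts.apply_inplace()
print(experts.a2_gscale)   # [1.0, 1.0] -- calibration destroyed
```

In the real code the safe variant corresponds to passing something like `torch.ones_like(self.a2_gscale)` at call time rather than mutating the parameter.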
```python
w1_sf_mma = flashinfer_convert_sf_to_mma_layout(
    self.w1_scale.reshape(num_experts_w1 * m1, k1_sf),
    m=m1,
    k=k1,
    num_groups=num_experts_w1,
)

num_experts_w2, m2, k2_sf = self.w2_scale.shape
k2 = k2_sf * 16
w2_sf_mma = flashinfer_convert_sf_to_mma_layout(
    self.w2_scale.reshape(num_experts_w2 * m2, k2_sf),
    m=m2,
    k=k2,
    num_groups=num_experts_w2,
)
```
Converting scale factors to MMA layout on every forward pass introduces unnecessary overhead. Since the dimensions m (intermediate size) and k (hidden dimension) are fixed for the layer, this conversion should be performed once during weight initialization or in process_weights_after_loading. Even if the conversion is a zero-copy view, it is more efficient to pre-compute these views and store them as attributes of the expert class to avoid redundant metadata creation and function calls in the hot path.
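The hoist-to-init pattern can be sketched as follows; a counting stub replaces the real FlashInfer conversion, and every name here is illustrative rather than the PR's actual code:

```python
# Sketch of moving the layout conversion out of the forward hot path.
# A counting stub stands in for the real FlashInfer conversion function.
CONVERT_CALLS = {"count": 0}

def convert_sf_to_mma_layout(sf, m, k, num_groups):
    CONVERT_CALLS["count"] += 1
    return sf  # the real function would return an MMA-layout view

class SM12xExperts:
    def __init__(self, w1_scale):
        self.w1_scale = w1_scale
        self.w1_sf_mma = None

    def process_weights_after_loading(self):
        # m and k are fixed per layer, so convert exactly once here
        # and cache the result as an attribute.
        self.w1_sf_mma = convert_sf_to_mma_layout(
            self.w1_scale, m=4, k=64, num_groups=2
        )

    def forward(self, x):
        # The hot path only reads the cached view; no per-call conversion.
        return x, self.w1_sf_mma

experts = SM12xExperts(w1_scale=[1.0, 2.0, 3.0])
experts.process_weights_after_loading()
for _ in range(3):
    experts.forward([0.0])
print(CONVERT_CALLS["count"])  # 1 -- converted once, not once per forward
```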
```python
x_sf_placeholder = (
    a1q_scale
    if a1q_scale is not None
    else hidden_states.new_zeros(1)
)
```
The x_sf_placeholder should use the correct data type expected by the FlashInfer kernel. The existing FlashInferCuteDSLExperts class (line 157) uses float8_e4m3fn for activation scales. Using hidden_states.new_zeros(1) will create a bfloat16 tensor, which may cause a type mismatch or incorrect behavior in the kernel even if the values are ignored.
```diff
-    else hidden_states.new_zeros(1)
+    else hidden_states.new_zeros(1, dtype=torch.float8_e4m3fn)
```
```python
    x=hidden_states,
    x_sf=x_sf_placeholder,
    token_selected_experts=topk_ids.to(torch.int32),
    token_final_scales=topk_weights,
```
The token_final_scales argument should be converted to float32 before being passed to the FlashInfer kernel. The existing FlashInferCuteDSLExperts class (line 166) explicitly performs this conversion, and FlashInfer fused MoE kernels typically expect routing weights in float32 for accumulation precision. Passing topk_weights directly (which is bfloat16) is likely to cause type errors or precision issues.
```diff
-    token_final_scales=topk_weights,
+    token_final_scales=topk_weights.float(),
```
Adds `FlashInferCuteDSLSM12xExperts` targeting SM120/SM121 (RTX Pro 6000 / DGX Spark) using `cute_dsl_fused_moe_nvfp4` from FlashInfer PRs vllm-project#3051 and vllm-project#3066. The kernel fuses token dispatch, W1 GEMM, SwiGLU, and W2 GEMM into a single call; BF16 hidden states are passed directly as activation quantization is fused internally.

- `vllm/utils/flashinfer.py`: lazy import wrappers for `cute_dsl_fused_moe_nvfp4` and `convert_sf_to_mma_layout`; adds `has_flashinfer_cutedsl_sm12x_moe()` availability probe
- `experts/flashinfer_cutedsl_moe.py`: `FlashInferCuteDSLSM12xExperts` with TODO to adopt `plan()`/`run()` API from PR vllm-project#3066
- `oracle/nvfp4.py`: `FLASHINFER_CUTEDSL_SM12X` backend enum and routing; falls back to `FLASHINFER_CUTLASS` on SM12x when PRs are absent
- `flashinfer_fp4_moe.py`: SM12X added to FI weight-prep path and w1/w3 → w3/w1 reorder list
- `tests/kernels/moe/test_cutedsl_sm12x_moe.py`: correctness tests vs BF16 torch reference; module-level skip when SM120 hw or FlashInfer PRs are absent

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
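An availability probe like `has_flashinfer_cutedsl_sm12x_moe()` typically follows a lazy-import pattern. This is a sketch under assumptions: the real check in `vllm/utils/flashinfer.py` may differ, and the attribute name is taken from the PR text, not verified against the FlashInfer API:

```python
import importlib.util
from functools import cache

@cache
def has_flashinfer_cutedsl_sm12x_moe() -> bool:
    """Sketch of an availability probe: True only if FlashInfer is
    importable and exposes the SM12x CuteDSL fused-MoE entry point.
    The real probe may additionally gate on device capability."""
    if importlib.util.find_spec("flashinfer") is None:
        return False
    import flashinfer
    return hasattr(flashinfer, "cute_dsl_fused_moe_nvfp4")

# Degrades cleanly to False when FlashInfer (or the kernel PR) is absent,
# letting the caller fall back to the CUTLASS backend.
print(has_flashinfer_cutedsl_sm12x_moe())
```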
Integrates FlashInfer PR vllm-project#3051 b12x dense GEMM backend into the NVFP4 linear layer path. b12x uses CuTe DSL warp-level MMA with adaptive tile sizing to improve SM utilization on small-M decode shapes.

Changes:

- `has_flashinfer_b12x_gemm()`: availability check via `Sm120BlockScaledDenseGemmKernel`
- `FlashInferB12xNvFp4LinearKernel`: new `NvFp4LinearKernel` subclass
- Auto-selects b12x on SM120/SM121 (`has_device_capability(120)`), falls back to `FLASHINFER_CUTLASS` when unavailable
- Adds `"flashinfer-b12x"` to `VLLM_NVFP4_GEMM_BACKEND` valid choices
- b12x test cases in `test_flashinfer_nvfp4_scaled_mm.py`

Measured on DGX Spark (SM121, Qwen3-30B-A3B-NVFP4, same MoE backend):

- b12x: 71.81 out tok/s (1P), 229.24 (8P)
- flashinfer-cutlass: 70.52 out tok/s (1P), 216.28 (8P)
- (+1.8% 1P, +6.0% 8P)

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed from 23ec850 to 6473f3b
- Preserve `a2_gscale`; pass `torch.ones_like(a2_gscale)` to kernel instead of `fill_(1.0)`, which destroyed the calibrated value in-place
- Precompute `w1_sf_mma`/`w2_sf_mma` in `process_weights_after_loading` instead of converting on every forward pass
- Fix `x_sf_placeholder` dtype: `float8_e4m3fn` (was bfloat16)
- Pass `topk_weights.float()` for float32 routing weights as kernel expects

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
Adds two new FlashInfer b12x backends targeting SM120/SM121 GPUs (DGX Spark GB10, RTX Pro 6000 Blackwell):
1. b12x fused MoE backend

Uses FlashInfer's `cute_dsl_fused_moe_nvfp4` kernel (flashinfer-ai/flashinfer#3066) to accelerate NVFP4 MoE on SM120/SM121. The kernel fuses token dispatch, W1 GEMM, SwiGLU, and W2 GEMM into a single call; BF16 hidden states are passed directly as activation quantization is fused internally.

- `experts/flashinfer_cutedsl_moe.py` — new `FlashInferCuteDSLSM12xExperts` class; `process_weights_after_loading` normalises FP4 block scales and resets `a2_gscale` to 1.0 (SM12x uses dynamic per-block quantisation for FC2 input; the calibrated modelopt value causes saturation)
- `vllm/config/kernel.py` — adds `flashinfer_cutedsl_sm12x` to `MoEBackendLiteral`
- `vllm/utils/flashinfer.py` — adds `has_flashinfer_cutedsl_sm12x_moe()` capability check
- `oracle/nvfp4.py` — oracle auto-selects the SM12x backend on SM120-family GPUs
- `flashinfer_fp4_moe.py` — weight-prep path handles the SM12x backend
- `tests/kernels/moe/test_cutedsl_sm12x_moe.py` — 24 unit tests (skipped on non-SM12x hardware)

2. b12x FP4 dense GEMM backend (`flashinfer-b12x`)

Integrates FlashInfer PR flashinfer-ai/flashinfer#3051 b12x dense GEMM backend into the NVFP4 linear layer path. b12x uses CuTe DSL warp-level MMA with adaptive tile sizing to improve SM utilization on small-M decode shapes.

- `has_flashinfer_b12x_gemm()` — availability check via `Sm120BlockScaledDenseGemmKernel`
- `FlashInferB12xNvFp4LinearKernel` — new `NvFp4LinearKernel` subclass; auto-selected on SM120/SM121, falls back to `FLASHINFER_CUTLASS` when unavailable
- Adds `flashinfer-b12x` to `VLLM_NVFP4_GEMM_BACKEND` valid choices
- b12x test cases in `test_flashinfer_nvfp4_scaled_mm.py`

Hardware requirements
Test plan

- `pytest tests/kernels/moe/test_cutedsl_sm12x_moe.py` on SM120/SM121 hardware (24/24 passing)
- `pytest tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py -k b12x` on SM121 hardware (36/36 passing)