
Integrate flashinfer b12x MoE and FP4 GEMM kernels for SM120/121#40082

Open
meena-at-work wants to merge 3 commits into vllm-project:main from meena-at-work:integrate-flashinfer-b12x-moe

Conversation

@meena-at-work

@meena-at-work meena-at-work commented Apr 17, 2026

Summary

Adds two new FlashInfer b12x backends targeting SM120/SM121 GPUs (DGX Spark GB10, RTX Pro 6000 Blackwell):

1. b12x fused MoE backend

Uses FlashInfer's cute_dsl_fused_moe_nvfp4 kernel (flashinfer-ai/flashinfer#3066) to accelerate NVFP4 MoE on SM120/SM121. The kernel fuses token dispatch, W1 GEMM, SwiGLU, and W2 GEMM into a single call; BF16 hidden states are passed directly as activation quantization is fused internally.
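For readers unfamiliar with the fused pattern, here is a minimal NumPy sketch of the unfused sequence the kernel collapses into one call (dispatch, W1 GEMM, SwiGLU, W2 GEMM, weighted combine). Names and shapes are illustrative, not the FlashInfer API, and quantization is omitted:

```python
import numpy as np

def swiglu(x):
    # x: [..., 2*I]; split into gate/up halves, apply silu(gate) * up
    gate, up = np.split(x, 2, axis=-1)
    return (gate / (1.0 + np.exp(-gate))) * up

def moe_reference(hidden, w1, w2, topk_ids, topk_weights):
    # hidden: [T, H]; w1: [E, 2I, H]; w2: [E, H, I]
    out = np.zeros_like(hidden)
    for t in range(hidden.shape[0]):
        for e, w in zip(topk_ids[t], topk_weights[t]):
            h = swiglu(hidden[t] @ w1[e].T)   # W1 GEMM + SwiGLU
            out[t] += w * (h @ w2[e].T)       # W2 GEMM, weighted combine
    return out
```

The fused kernel performs this whole loop (plus NVFP4 activation quantization) in a single launch, which is why BF16 hidden states can be passed in directly.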

  • experts/flashinfer_cutedsl_moe.py — new FlashInferCuteDSLSM12xExperts class; process_weights_after_loading normalises FP4 block scales and resets a2_gscale to 1.0 (SM12x uses dynamic per-block quantisation for FC2 input; the calibrated modelopt value causes saturation)
  • vllm/config/kernel.py — adds flashinfer_cutedsl_sm12x to MoEBackend Literal
  • vllm/utils/flashinfer.py — adds has_flashinfer_cutedsl_sm12x_moe() capability check
  • oracle/nvfp4.py — oracle auto-selects the SM12x backend on SM120-family GPUs
  • flashinfer_fp4_moe.py — weight-prep path handles the SM12x backend
  • tests/kernels/moe/test_cutedsl_sm12x_moe.py — 24 unit tests (skipped on non-SM12x hardware)

2. b12x FP4 dense GEMM backend (flashinfer-b12x)

Integrates FlashInfer PR flashinfer-ai/flashinfer#3051 b12x dense GEMM backend into the NVFP4 linear layer path. b12x uses CuTe DSL warp-level MMA with adaptive tile sizing to improve SM utilization on small-M decode shapes.
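As background, the reference semantics of a block-scaled NVFP4 GEMM can be sketched in NumPy: each operand carries one scale per 16-element block along K, and the output equals a plain matmul over the dequantized operands. This is an illustrative model only, not the b12x kernel's implementation:

```python
import numpy as np

BLOCK = 16  # NVFP4 uses one scale factor per 16 elements along K

def dequant(q, sf):
    # q: [M, K] quantized values; sf: [M, K // BLOCK] per-block scales
    return q * np.repeat(sf, BLOCK, axis=1)

def block_scaled_gemm(aq, a_sf, bq, b_sf):
    # Dequantize both operands, then a plain matmul
    # (B stored row-major as [N, K], so transpose for the contraction).
    return dequant(aq, a_sf) @ dequant(bq, b_sf).T
```

The b12x backend computes the same result but keeps operands in FP4 and applies block scales inside warp-level MMA, with tile sizes adapted to small M.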

  • has_flashinfer_b12x_gemm() — availability check via Sm120BlockScaledDenseGemmKernel
  • FlashInferB12xNvFp4LinearKernel — new NvFp4LinearKernel subclass; auto-selected on SM120/SM121, falls back to FLASHINFER_CUTLASS when unavailable
  • Adds flashinfer-b12x to VLLM_NVFP4_GEMM_BACKEND valid choices
  • b12x test cases in test_flashinfer_nvfp4_scaled_mm.py
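The auto-selection described above boils down to a capability gate with a CUTLASS fallback. A hedged sketch of that logic (function and constant names are assumptions, not vLLM's actual oracle code):

```python
FLASHINFER_B12X = "flashinfer-b12x"
FLASHINFER_CUTLASS = "flashinfer-cutlass"

def select_nvfp4_gemm_backend(capability: int, b12x_available: bool) -> str:
    # SM120/SM121 prefer the b12x kernel when it is importable;
    # everything else (or a missing kernel) falls back to CUTLASS.
    if capability in (120, 121) and b12x_available:
        return FLASHINFER_B12X
    return FLASHINFER_CUTLASS
```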

Hardware requirements

  • SM120/SM121 GPUs (DGX Spark GB10, RTX Pro 6000 Blackwell)
  • A FlashInfer build that includes flashinfer-ai/flashinfer#3051 and flashinfer-ai/flashinfer#3066

Test plan

  • pytest tests/kernels/moe/test_cutedsl_sm12x_moe.py on SM120/SM121 hardware (24/24 passing)
  • pytest tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py -k b12x on SM121 hardware (36/36 passing)
  • Full-model throughput on DGX Spark (SM121, Qwen3-30B-A3B-NVFP4):
    • b12x GEMM vs flashinfer-cutlass: +1.8% (1P), +6.0% (8P)
  • Non-SM12x hardware: oracle does not select these backends (falls through to existing path)

@mergify mergify bot added the nvidia label Apr 17, 2026
@meena-at-work meena-at-work changed the title Integrate flashinfer b12x moe Integrate flashinfer b12x MoE kernel for SM120/121 Apr 17, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the FlashInfer CuteDSL fused MoE kernel on SM12x architectures, adding the FlashInferCuteDSLSM12xExperts class and necessary backend configurations. Feedback identifies several issues: in-place modification of a2_gscale causes permanent side effects on model parameters, and performing MMA layout conversions during every forward pass adds unnecessary overhead. Additionally, suggestions were made to ensure correct data types for activation scale placeholders and routing weights when interfacing with the FlashInfer kernel.

    # multiplied by values that large. Force to 1.0 so the kernel uses
    # its own per-block dynamic scale — matching the unit-test convention.
    if self.a2_gscale is not None:
        self.a2_gscale.fill_(1.0)
Contributor


high

Modifying self.a2_gscale in-place using fill_(1.0) has a permanent side effect on the model's quantization parameters in memory. If the model is subsequently used with a different MoE backend (e.g., during profiling, testing, or if the backend is changed via configuration), the results will be incorrect because the original calibrated scales have been destroyed. It is safer to pass a tensor of ones directly to the kernel in the apply method or create a local copy if modification is necessary.
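The difference the reviewer is pointing at can be shown in a few lines, with NumPy standing in for torch (`fill` here plays the role of torch's in-place `fill_`):

```python
import numpy as np

# In-place reset: the calibrated value is gone for every later backend.
param = np.array([448.0])  # illustrative calibrated a2_gscale
param.fill(1.0)

# Safer pattern: keep the parameter intact and hand the kernel a
# fresh tensor of ones instead.
param2 = np.array([448.0])
kernel_scale = np.ones_like(param2)
```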

Comment on lines +510 to +524
w1_sf_mma = flashinfer_convert_sf_to_mma_layout(
    self.w1_scale.reshape(num_experts_w1 * m1, k1_sf),
    m=m1,
    k=k1,
    num_groups=num_experts_w1,
)

num_experts_w2, m2, k2_sf = self.w2_scale.shape
k2 = k2_sf * 16
w2_sf_mma = flashinfer_convert_sf_to_mma_layout(
    self.w2_scale.reshape(num_experts_w2 * m2, k2_sf),
    m=m2,
    k=k2,
    num_groups=num_experts_w2,
)
Contributor


high

Converting scale factors to MMA layout on every forward pass introduces unnecessary overhead. Since the dimensions m (intermediate size) and k (hidden dimension) are fixed for the layer, this conversion should be performed once during weight initialization or in process_weights_after_loading. Even if the conversion is a zero-copy view, it is more efficient to pre-compute these views and store them as attributes of the expert class to avoid redundant metadata creation and function calls in the hot path.
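The suggested restructuring is the standard compute-once-cache-as-attribute pattern. A minimal sketch (class and method names mirror the PR, but `_convert` is a counting stand-in for `flashinfer_convert_sf_to_mma_layout`):

```python
class Sm12xExperts:
    def __init__(self, w1_scale, w2_scale):
        self.w1_scale = w1_scale
        self.w2_scale = w2_scale
        self.conversions = 0  # instrumentation to show the saving

    def _convert(self, sf):
        self.conversions += 1
        return tuple(sf)  # placeholder for the real layout transform

    def process_weights_after_loading(self):
        # One-time: cache the MMA-layout scale views as attributes.
        self.w1_sf_mma = self._convert(self.w1_scale)
        self.w2_sf_mma = self._convert(self.w2_scale)

    def forward(self):
        # Hot path reuses the cached views; no per-call conversion.
        return self.w1_sf_mma, self.w2_sf_mma
```

With this shape, any number of forward passes triggers exactly two conversions, both paid at load time.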

x_sf_placeholder = (
    a1q_scale
    if a1q_scale is not None
    else hidden_states.new_zeros(1)
)
Contributor


high

The x_sf_placeholder should use the correct data type expected by the FlashInfer kernel. The existing FlashInferCuteDSLExperts class (line 157) uses float8_e4m3fn for activation scales. Using hidden_states.new_zeros(1) will create a bfloat16 tensor, which may cause a type mismatch or incorrect behavior in the kernel even if the values are ignored.

Suggested change
else hidden_states.new_zeros(1)
else hidden_states.new_zeros(1, dtype=torch.float8_e4m3fn)

x=hidden_states,
x_sf=x_sf_placeholder,
token_selected_experts=topk_ids.to(torch.int32),
token_final_scales=topk_weights,
Contributor


high

The token_final_scales argument should be converted to float32 before being passed to the FlashInfer kernel. The existing FlashInferCuteDSLExperts class (line 166) explicitly performs this conversion, and FlashInfer fused MoE kernels typically expect routing weights in float32 for accumulation precision. Passing topk_weights directly (which is bfloat16) is likely to cause type errors or precision issues.

Suggested change
token_final_scales=topk_weights,
token_final_scales=topk_weights.float(),

meena-at-work and others added 2 commits April 17, 2026 22:59
Adds FlashInferCuteDSLSM12xExperts targeting SM120/SM121 (RTX Pro
6000 / DGX Spark) using cute_dsl_fused_moe_nvfp4 from FlashInfer
PRs flashinfer-ai/flashinfer#3051 and flashinfer-ai/flashinfer#3066. The kernel fuses token dispatch, W1 GEMM, SwiGLU,
and W2 GEMM into a single call; BF16 hidden states are passed directly
as activation quantization is fused internally.

- vllm/utils/flashinfer.py: lazy import wrappers for
  cute_dsl_fused_moe_nvfp4 and convert_sf_to_mma_layout; adds
  has_flashinfer_cutedsl_sm12x_moe() availability probe
- experts/flashinfer_cutedsl_moe.py: FlashInferCuteDSLSM12xExperts
  with TODO to adopt plan/run() API from flashinfer-ai/flashinfer#3066
- oracle/nvfp4.py: FLASHINFER_CUTEDSL_SM12X backend enum and routing;
  falls back to FLASHINFER_CUTLASS on SM12x when PRs are absent
- flashinfer_fp4_moe.py: SM12X added to FI weight-prep path and
  w1/w3 → w3/w1 reorder list
- tests/kernels/moe/test_cutedsl_sm12x_moe.py: correctness tests vs
  BF16 torch reference; module-level skip when SM12x hardware or the
  FlashInfer PRs are absent

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Integrates FlashInfer PR flashinfer-ai/flashinfer#3051 b12x dense GEMM backend into the NVFP4
linear layer path. b12x uses CuTe DSL warp-level MMA with adaptive tile
sizing to improve SM utilization on small-M decode shapes.

Changes:
- has_flashinfer_b12x_gemm(): availability check via Sm120BlockScaledDenseGemmKernel
- FlashInferB12xNvFp4LinearKernel: new NvFp4LinearKernel subclass
- Auto-selects b12x on SM120/SM121 (has_device_capability(120)), falls
  back to FLASHINFER_CUTLASS when unavailable
- Adds "flashinfer-b12x" to VLLM_NVFP4_GEMM_BACKEND valid choices
- b12x test cases in test_flashinfer_nvfp4_scaled_mm.py

Measured on DGX Spark (SM121, Qwen3-30B-A3B-NVFP4, same MoE backend):
  b12x:               71.81 out tok/s (1P), 229.24 (8P)
  flashinfer-cutlass:  70.52 out tok/s (1P), 216.28 (8P)
  (+1.8% 1P, +6.0% 8P)

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@meena-at-work meena-at-work force-pushed the integrate-flashinfer-b12x-moe branch from 23ec850 to 6473f3b Compare April 17, 2026 23:04
@meena-at-work meena-at-work changed the title Integrate flashinfer b12x MoE kernel for SM120/121 Integrate flashinfer b12x MoE and FP4 GEMM kernels for SM120/121 Apr 17, 2026
- Preserve a2_gscale; pass torch.ones_like(a2_gscale) to kernel instead
  of fill_(1.0) which destroyed the calibrated value in-place
- Precompute w1_sf_mma/w2_sf_mma in process_weights_after_loading
  instead of converting on every forward pass
- Fix x_sf_placeholder dtype: float8_e4m3fn (was bfloat16)
- Pass topk_weights.float() for float32 routing weights as kernel expects

Signed-off-by: Meenakshi Venkataraman <meenakshiv@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>