[Bug]: Fix FlashInfer CUTLASS BF16 + CUDA graphs IMA #39593
Code Review
This pull request enhances the robustness of MoE routing kernels by clamping input values to -FLT_MAX in the CUDA implementation, preventing issues when encountering NaNs or extreme values. Additionally, it introduces new test cases and helper functions to verify that routing logic correctly handles all-NaN rows by selecting distinct expert IDs for both fused and grouped top-k routing. I have no feedback to provide as there were no review comments to evaluate.
Signed-off-by: Yifan <yzong@redhat.com>
yewentao256 left a comment
Thanks for the work! Could you also benchmark e2e performance using vllm bench ...?
Added vLLM serve latency test on B200
We usually benchmark with
Updated benchmark results to use
Other failures due to infra issues.
#39391 fixes the same issue
Purpose
Fix illegal memory access crash in FlashInfer CUTLASS BF16 MoE with DP/DEP + CUDA graphs.
Closes #37758
Root cause: When CUDA graphs are enabled under data parallelism, all ranks pad
to the same batch size. Attention often produces NaN for padding tokens. In BF16,
NaN propagates through RMSNorm, projections, and residual adds until it reaches the
MoE router. When a row contains NaN, the topK kernel picks the same expert index
K times. The FlashInfer CUTLASS MoE kernels do not expect duplicate expert indices
per token, and the resulting out-of-bounds access causes a hard CUDA error.
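The duplicate-index behavior can be reproduced with a small model of an iterative argmax top-k. This is a minimal sketch in plain Python, not the actual CUDA kernel: `naive_topk` is a hypothetical name, and the loop only assumes that selection relies on `>` comparisons, which are always false when either operand is NaN.

```python
import math


def naive_topk(scores, k):
    """Pick k indices by repeated argmax, masking each winner with -inf.

    Illustrative model only (an assumption about the failure mode,
    not the code in topk_softmax_kernels.cu).
    """
    vals = list(scores)
    picked = []
    for _ in range(k):
        best_idx, best_val = 0, -math.inf
        for i, v in enumerate(vals):
            # Any comparison against NaN is False, so a NaN entry never wins
            # and best_idx stays at its initial value of 0.
            if v > best_val:
                best_idx, best_val = i, v
        picked.append(best_idx)
        vals[best_idx] = -math.inf  # mask the winner for the next round
    return picked


print(naive_topk([0.1, 0.7, 0.2, 0.5], 2))  # finite row: two distinct experts
print(naive_topk([math.nan] * 8, 4))        # all-NaN row: index 0 repeated
```

On a finite row the masking step guarantees distinct winners. On an all-NaN row no comparison ever succeeds, so every round falls back to index 0 and the result contains the same expert K times.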
NaNs in padding tokens post-attention
Instrumentation of attention output (Qwen3-30B-A3B BF16, piecewise CUDA graphs,
FlashInfer CUTLASS MoE, GSM8K eval) confirms NaN is present in padding rows on virtually every
padded step:
DP/DEP produce ~13× more padding events with much larger pad regions than single-GPU
bucket rounding, which is why the bug manifests under DP but not single GPU.
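The difference in pad size can be sketched as follows. This is an illustration under stated assumptions: `CAPTURE_SIZES` is a hypothetical list of CUDA graph batch sizes, and the DP case assumes every rank pads to the bucket of the largest per-rank batch. Neither is taken from vLLM's actual padding code.

```python
import bisect

# Hypothetical CUDA graph capture sizes (assumption for illustration).
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]


def bucket(num_tokens):
    """Smallest capture size that fits num_tokens (must not exceed the largest size)."""
    return CAPTURE_SIZES[bisect.bisect_left(CAPTURE_SIZES, num_tokens)]


def single_gpu_padding(num_tokens):
    """Pad rows added by bucket rounding alone."""
    return bucket(num_tokens) - num_tokens


def dp_padding(tokens_per_rank):
    """Pad rows per rank when all ranks pad to the largest rank's bucket."""
    target = bucket(max(tokens_per_rank))
    return [target - n for n in tokens_per_rank]


print(single_gpu_padding(30))       # a single GPU pads 30 tokens up to 32
print(dp_padding([30, 3, 1, 12]))   # lightly loaded ranks pad far more
```

In this sketch a rank holding a single token ends up with 31 padding rows, while the same batch on a single GPU would need none.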
FlashInfer CUTLASS kernel IMA with duplicate expert indices
cutlass_fp8_bf16_simple.py shows that calling FlashInfer's cutlass_fused_moe in eager mode with all tokens routed to the same expert (simulating the duplicate-ID condition)
crashes with an illegal memory access on both BF16 and FP8 paths.
FP8 was immune because the FP8 quantization kernel (scaled_fp8_conversion) uses
fminf/fmaxf, which silently clamp NaN to ±448.0, acting as an implicit NaN
firewall at every linear layer.
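The clamping effect follows from the C semantics of fminf/fmaxf, which return the non-NaN operand when exactly one argument is NaN. Below is a minimal Python emulation of that behavior. The helper names mirror the C functions, but the operand order inside `clamp_like_fp8` is an assumption and is not copied from scaled_fp8_conversion.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value of the FP8 E4M3 format


def fminf(a, b):
    """Emulate C fminf: if one operand is NaN, return the other."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return min(a, b)


def fmaxf(a, b):
    """Emulate C fmaxf: if one operand is NaN, return the other."""
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)


def clamp_like_fp8(x):
    """Clamp x to the FP8 range; the operand order here is an assumption."""
    return fmaxf(-FP8_E4M3_MAX, fminf(x, FP8_E4M3_MAX))


print(clamp_like_fp8(1.5))       # in-range values pass through
print(clamp_like_fp8(1000.0))    # large values saturate
print(clamp_like_fp8(math.nan))  # NaN is replaced by a finite bound
```

Which bound a NaN lands on depends on the order of the two calls. In either order the output is finite, so NaN does not survive the quantization step.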
Fix: Sanitize NaN in the fused softmax/sigmoid+topk CUDA kernel
(csrc/moe/topk_softmax_kernels.cu). This is applied in three places:
- moeSoftmax (fallback path for non-power-of-2 expert counts): NaN is replaced with -FLT_MAX in the max-reduce, exp-sum, and final softmax loops.
- moeSigmoid (fallback sigmoid path): NaN is replaced before the sigmoid computation.
- topkGating (fused warp-level path for power-of-2 / multiple-of-64 expert counts): NaN is replaced in-register immediately after loading, before softmax/sigmoid and argmax.
topK now returns K distinct indices, never duplicates.
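The effect of the sanitization can be sketched in plain Python. This is a simplified model under stated assumptions, not the CUDA implementation: `sanitize`, `softmax`, `topk`, and `route` are hypothetical helper names, and the top-k loop is a generic masked argmax.

```python
import math

FLT_MAX = 3.4028234663852886e38  # largest finite float32 value


def sanitize(row):
    """Replace NaN with -FLT_MAX, mirroring the replacement described above."""
    return [-FLT_MAX if math.isnan(v) else v for v in row]


def softmax(row):
    """Numerically stable softmax: subtract the row max before exponentiating."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def topk(probs, k):
    """Repeated argmax with -inf masking (generic sketch, not the warp-level code)."""
    vals = list(probs)
    picked = []
    for _ in range(k):
        best = 0
        for i in range(1, len(vals)):
            if vals[i] > vals[best]:
                best = i
        picked.append(best)
        vals[best] = -math.inf
    return picked


def route(row, k):
    return topk(softmax(sanitize(row)), k)


print(route([math.nan] * 8, 4))            # all-NaN row: four distinct experts
print(route([0.1, math.nan, 0.7, 0.2], 2))  # NaN entry loses to finite logits
```

After sanitization an all-NaN row becomes a row of equal finite values, so the softmax is uniform and the masking step in top-k can separate the winners. A NaN entry in an otherwise finite row gets probability zero and is never selected ahead of a finite logit.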
Test Plan
tests/kernels/moe/test_routing.py: test_topk_nan_row_distinct_experts and test_grouped_topk_nan_row_distinct_experts inject all-NaN rows into top-k.
They verify that NaN rows return distinct expert IDs and that non-NaN rows
match the baseline.
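The two properties these tests check can be expressed as a small invariant checker. This is a hedged sketch: `check_routing` and its arguments are hypothetical and do not reproduce the helpers added in test_routing.py.

```python
def check_routing(topk_ids, baseline_ids, nan_rows, num_experts):
    """Return a list of violations; an empty list means the check passed.

    topk_ids / baseline_ids: per-token lists of selected expert IDs.
    nan_rows: set of row indices whose router logits were all NaN.
    """
    problems = []
    for row, ids in enumerate(topk_ids):
        if any(not 0 <= e < num_experts for e in ids):
            problems.append(f"row {row}: expert id out of range")
        if row in nan_rows:
            # NaN rows only need distinct IDs; their exact values are unspecified.
            if len(set(ids)) != len(ids):
                problems.append(f"row {row}: duplicate expert ids {ids}")
        elif list(ids) != list(baseline_ids[row]):
            # Clean rows must be unaffected by the NaN handling.
            problems.append(f"row {row}: differs from baseline")
    return problems


baseline = [[3, 1], [0, 0]]  # row 1 is the all-NaN row before the fix
print(check_routing([[3, 1], [0, 1]], baseline, {1}, num_experts=8))
print(check_routing([[3, 1], [0, 0]], baseline, {1}, num_experts=8))
```

The first call models the fixed kernel and reports nothing. The second models the old behavior and flags the duplicate IDs in the NaN row.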
End-to-end GSM8K correctness (Qwen3-30B-A3B BF16 and Qwen3.5-35B-A3B BF16,
FlashInfer CUTLASS MoE, piecewise CUDA graphs):
Micro-benchmark (benchmark_router_select_experts.py, 128 tokens × 128 experts,
top_k=8, bfloat16, 10% NaNs, B200, mirrors Qwen3-30B):
The NaN check is fused into the existing kernel rather than added as a
separate torch.nan_to_num(-float('inf')) call before routing.
A separate nan_to_num launch adds ~30% overhead:
Results are within run-to-run variance.
Before fix:
After fix:
torch.nan_to_num
Test Result
See above.