Skip to content

[Bug]: Fix FlashInfer CUTLASS BF16 + CUDA graphs IMA#39593

Closed
yzong-rh wants to merge 9 commits into
vllm-project:mainfrom
yzong-rh:yzong-rh/routing_nan_fix
Closed

[Bug]: Fix FlashInfer CUTLASS BF16 + CUDA graphs IMA#39593
yzong-rh wants to merge 9 commits into
vllm-project:mainfrom
yzong-rh:yzong-rh/routing_nan_fix

Conversation

@yzong-rh
Copy link
Copy Markdown
Contributor

@yzong-rh yzong-rh commented Apr 12, 2026

Purpose

Fix illegal memory access crash in FlashInfer CUTLASS BF16 MoE with DP/DEP + CUDA graphs.

Closes #37758

Root cause: When CUDA graphs are enabled under data parallelism, all ranks pad
to the same batch size. Attention often produces NaN for padding tokens. In BF16,
NaN propagates through RMSNorm, projections, and residual adds until it reaches the
MoE router. The topK kernel picks the same expert index K times when encountering
NaN values. FlashInfer Cutlass MoE kernels do not expect duplicate expert indices
per token — the resulting out-of-bounds access causes a hard CUDA error.

NaNs in padding tokens post-attention

Instrumentation of attention output (Qwen3-30B-A3B BF16, piecewise CUDA graphs,
FlashInfer CUTLASS MoE, GSM8K eval) confirms NaN is present in padding rows on virtually every
padded step:

Config Steps with padding NaN in padding Inf in padding Mean pad tokens
1 GPU 62 62 (100%) 5 4.0
DP=2 832 829 (99.6%) 22 16.7
DEP=2 729 729 (100%) 11 14.2

DP/DEP produce ~13× more padding events with much larger pad regions than single-GPU
bucket rounding, which is why the bug manifests under DP but not single GPU.

Flashinfer CUTLASS kernel IMA with duplicate expert indices

cutlass_fp8_bf16_simple.py shows that calling FlashInfer's cutlass_fused_moe in eager
mode with all tokens routed to the same expert (simulating the duplicate-ID condition)
crashes with an illegal memory access on both BF16 and FP8 paths.

cutlass_fp8_bf16_simple.py

FP8 was immune because the FP8 quantization kernel (scaled_fp8_conversion)
uses fminf/fmaxf, which silently clamp NaN to ±448.0, acting as an implicit NaN
firewall at every linear layer.

Fix: Sanitize NaN in the fused softmax/sigmoid+topk CUDA kernel
(csrc/moe/topk_softmax_kernels.cu). This is applied in three places:

  1. moeSoftmax (fallback path for non-power-of-2 expert counts): NaN is
    replaced with -FLT_MAX in the max-reduce, exp-sum, and final softmax loops.
  2. moeSigmoid (fallback sigmoid path): NaN is replaced before the sigmoid
    computation.
  3. topkGating (fused warp-level path for power-of-2 / multiple-of-64 expert
    counts): NaN is replaced in-register immediately after loading, before
    softmax/sigmoid and argmax.

topK now returns K distinct indices, never duplicates.

Test Plan

  1. tests/kernels/moe/test_routing.py

    • Includes additional tests test_topk_nan_row_distinct_experts and
      test_grouped_topk_nan_row_distinct_experts, which inject all-NaN rows into
      top-k. They verify that NaN rows return distinct values and that non NaN rows
      match the baseline.
  2. End-to-end GSM8K correctness (Qwen3-30B-A3B BF16 and Qwen3.5-35B-A3B BF16,
    FlashInfer CUTLASS MoE, piecewise CUDA graphs):

    Model Measured Expected
    Qwen3-30B-A3B BF16 (DP=2) 0.8870 0.8800
    Qwen3-30B-A3B BF16 (DEP=2) 0.8901 0.8800
    Qwen3.5-35B-A3B BF16 (DP=2) 0.8408 0.8400
    Qwen3.5-35B-A3B BF16 (DEP=2) 0.8560 0.8400
  3. Micro-benchmark (benchmark_router_select_experts.py, 128 tokens × 128 experts,
    top_k=8, bfloat16, 10% NaNs, B200, mirrors Qwen3-30B):

    The NaN check is fused into the existing kernel rather than added as a
    separate torch.nan_to_num(-float('inf')) call before routing.
    A separate nan_to_num launch adds ~30% overhead:

    Variant Kernel p20 p50 p80
    Before fix topk_softmax 10.21 µs 10.27 µs 10.37 µs
    Before fix topk_sigmoid 10.24 µs 12.13 µs 12.32 µs
    After fix (fused) topk_softmax 10.11 µs 10.24 µs 10.30 µs
    After fix (fused) topk_sigmoid 10.24 µs 12.16 µs 10.32 µs
    torch.nan_to_num topk_softmax 11.42 µs 13.28 µs 13.44 µs
    torch.nan_to_num topk_sigmoid 13.15 µs 13.31 µs 13.44 µs

benchmark_router_select_experts.py

  1. E2E Latency:
vllm serve Qwen/Qwen3-30B-A3B \
  --data-parallel-size 2 \
  --max-model-len 8192

vllm bench serve \
  --backend vllm \
  --model Qwen/Qwen3-30B-A3B \
  --endpoint /v1/completions \
  --num-prompts 128 \
  --random-input-len 2 \
  --random-output-len 512 \
  --num-warmups 128 \
  --request-rate inf \
  --temperature 0

Results within run to run variance

Before fix:
============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  4.74      
Total input tokens:                      256       
Total generated tokens:                  65536     
Request throughput (req/s):              27.01     
Output token throughput (tok/s):         13826.70  
Peak output token throughput (tok/s):    14336.00  
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          13880.71  
---------------Time to First Token----------------
Mean TTFT (ms):                          97.14     
Median TTFT (ms):                        99.54     
P99 TTFT (ms):                           116.43    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.06      
Median TPOT (ms):                        9.06      
P99 TPOT (ms):                           9.06      
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.06      
Median ITL (ms):                         9.06      
P99 ITL (ms):                            10.59     
==================================================
After fix:
============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  4.72      
Total input tokens:                      256       
Total generated tokens:                  65536     
Request throughput (req/s):              27.11     
Output token throughput (tok/s):         13881.52  
Peak output token throughput (tok/s):    14336.00  
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          13935.74  
---------------Time to First Token----------------
Mean TTFT (ms):                          95.84     
Median TTFT (ms):                        97.08     
P99 TTFT (ms):                           109.43    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.03      
Median TPOT (ms):                        9.03      
P99 TPOT (ms):                           9.03      
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.03      
Median ITL (ms):                         9.00      
P99 ITL (ms):                            11.15     
==================================================
torch.nan_to_num
============ Serving Benchmark Result ============
Successful requests:                     128       
Failed requests:                         0         
Benchmark duration (s):                  4.83      
Total input tokens:                      256       
Total generated tokens:                  65536     
Request throughput (req/s):              26.48     
Output token throughput (tok/s):         13556.04  
Peak output token throughput (tok/s):    14031.00  
Peak concurrent requests:                128.00    
Total token throughput (tok/s):          13608.99  
---------------Time to First Token----------------
Mean TTFT (ms):                          97.80     
Median TTFT (ms):                        104.58    
P99 TTFT (ms):                           115.65    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          9.23      
Median TPOT (ms):                        9.23      
P99 TPOT (ms):                           9.24      
---------------Inter-token Latency----------------
Mean ITL (ms):                           9.23      
Median ITL (ms):                         9.22      
P99 ITL (ms):                            11.01     
==================================================

Test Result

See above.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing (anything written below this line will be removed by GitHub Actions)

Signed-off-by: Yifan Zong <yzong@redhat.com>
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enhances the robustness of MoE routing kernels by clamping input values to -FLT_MAX in the CUDA implementation, preventing issues when encountering NaNs or extreme values. Additionally, it introduces new test cases and helper functions to verify that routing logic correctly handles all-NaN rows by selecting distinct expert IDs for both fused and grouped top-k routing. I have no feedback to provide as there were no review comments to evaluate.

@yzong-rh yzong-rh marked this pull request as ready for review April 12, 2026 01:07
@yzong-rh
Copy link
Copy Markdown
Contributor Author

@yzong-rh yzong-rh changed the title Fix MoE routing NaN crash (FlashInfer CUTLASS BF16, DP/DEP + CUDA graphs) [BUG] Fix MoE routing NaN crash (FlashInfer CUTLASS BF16, DP/DEP + CUDA graphs) Apr 12, 2026
@mergify mergify Bot added the bug Something isn't working label Apr 12, 2026
@yzong-rh yzong-rh changed the title [BUG] Fix MoE routing NaN crash (FlashInfer CUTLASS BF16, DP/DEP + CUDA graphs) [Bug]: Fix FlashInfer CUTLASS BF16 + CUDA graphs IMA Apr 12, 2026
Signed-off-by: Yifan <yzong@redhat.com>
@yzong-rh yzong-rh requested a review from pavanimajety as a code owner April 12, 2026 01:46
Copy link
Copy Markdown
Member

@yewentao256 yewentao256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the work! Could you also benchmark e2e performance using vllm bench ...?

@yzong-rh
Copy link
Copy Markdown
Contributor Author

Added vLLM serve latency test on B200

@yewentao256
Copy link
Copy Markdown
Member

Added vLLM serve latency test on B200

We usually benchmark with vllm bench serve --model $MODEL so we have better metrics

@yzong-rh
Copy link
Copy Markdown
Contributor Author

Updated benchmark results to use vllm bench serve

@yewentao256 yewentao256 added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 15, 2026
@yzong-rh
Copy link
Copy Markdown
Contributor Author

tests/distributed/test_elastic_ep.py::test_elastic_ep_scaling also failing on main.
kernels/moe/test_moe_layer.py::test_moe_layer[False-deepep_high_throughput-2-1-True] passes for me on 2xH200.

Other failures due to infra issues.

@yzong-rh
Copy link
Copy Markdown
Contributor Author

#39391 fixes the same issue

@yzong-rh yzong-rh closed this Apr 22, 2026
@github-project-automation github-project-automation Bot moved this to Done in NVIDIA Apr 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: FLASHINFER_CUTLASS and FLASHINFER_TRTLLM do not work for Qwen3.5 Bf16 DP/EP

2 participants