Skip to content

[MoE] Fix warp-shfl UB in grouped_topk renormalize #23774

Closed

Kangyan-Zhou wants to merge 1 commit into main from fix_grouped_topk_warp_shfl

Conversation


@Kangyan-Zhou (Collaborator) commented Apr 26, 2026

Summary

Fix the illegal memory access (IMA) in TestNvidiaNemotron3Nano30BFP8.test_lm_eval caused by CUDA warp-shuffle undefined behavior (UB) in grouped_topk_single_group_kernel's renormalize path, and re-enable the test that was disabled in #23720 as a stop-gap.

  • 4 functional lines + a comment in python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh.
  • Removes the disabled=... kwarg from register_cuda_ci in test/registered/models/test_nvidia_nemotron_3_nano.py.

Root cause

Phase 3 of grouped_topk_single_group_kernel called warp_sum_f32 (which does __shfl_xor_sync(0xffffffff, ...)) from inside if (lane_id < topk):

```cuda
if (lane_id < topk) {                                  // only lanes 0..K-1 enter
    ...
    if (renormalize) {
      float partial = (lane_id < topk) ? weight : 0.0f;  // tautology — already inside the if
      float total   = warp_sum_f32(partial);             // <-- UB
      ...
    }
}
```

For nemotron-3-nano (topk=6), only lanes 0..5 reach the shfl_xor, but the mask 0xffffffff declares all 32 lanes participate. Per the CUDA programming guide, every lane named in the mask must execute the intrinsic at the same site, otherwise the result is undefined.
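For reference, warp_sum_f32 is presumably the usual full-mask butterfly reduction; a minimal sketch of that shape is below (the actual helper in grouped_topk.cuh may differ):

```cuda
// Minimal sketch of a full-warp butterfly sum (assumed shape of warp_sum_f32;
// the real implementation in grouped_topk.cuh may differ).
__device__ __forceinline__ float warp_sum_f32(float v) {
  // Every lane named in the 0xffffffff mask must execute each shuffle below,
  // which is why calling this from inside `if (lane_id < topk)` is UB.
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_xor_sync(0xffffffff, v, offset);
  }
  return v;  // every participating lane ends up holding the full 32-lane sum
}
```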

Empirically the UB returns values from the absent lanes' registers — producing wrong renormalized weights, with 2 of 6 weights per token coming out ~1.5× too large because they weren't divided by the full sum. The author's (lane_id < topk) ? weight : 0.0f ternary captured the right intent, but it was inside an if (lane_id < topk) so it's a tautology — the lanes that needed to contribute 0 never reached it.

Why the symptom was an IMA at PCG replay

The kernel UB itself does not raise a hardware fault — it produces wrong values. Verified: with the buggy kernel, --disable-cuda-graph --disable-piecewise-cuda-graph, and CUDA_LAUNCH_BLOCKING=1, the full GSM8K eval completes cleanly with accuracy within rtol.
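As a toy illustration (not part of this PR; the kernel and names below are hypothetical), a standalone program shows the same effect: the divergent call yields an arbitrary sum without faulting, while the converged call is correct.

```cuda
// Toy repro sketch (hypothetical, not from the PR). Build with: nvcc repro.cu
#include <cstdio>

__device__ float butterfly_sum(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(0xffffffff, v, offset);
  return v;
}

__global__ void repro(int topk) {
  int lane = threadIdx.x;                 // single warp per block
  float w = (lane < topk) ? 1.0f : 0.0f;

  // Converged: all 32 lanes execute the shuffles; inactive lanes add 0.
  float good = butterfly_sum(w);

  // Divergent: only lanes 0..topk-1 execute shuffles named in the full mask -> UB.
  float bad = 0.0f;
  if (lane < topk) bad = butterfly_sum(w);

  if (lane == 0)
    printf("converged sum = %f, divergent sum = %f (undefined)\n", good, bad);
}

int main() {
  repro<<<1, 32>>>(/*topk=*/6);
  cudaDeviceSynchronize();
  return 0;
}
```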

The IMA only manifests under PCG replay because:

  • The captured graph hot-loops the buggy kernel hundreds of times per replay with no host-side sync between launches.
  • UB outcome is non-deterministic (depends on absent lanes' register state, which varies replay-to-replay).
  • Eventually a replay produces values extreme enough that downstream FP8 quant or MoE permutation arithmetic produces a NaN-derived offset → real OOB → IMA blamed on the captured replay() call.

Fix

Move the warp_sum out of the divergent context. All 32 lanes of warp 0 reach warp_sum_f32 together; inactive lanes (lane_id >= topk) contribute the additive identity (0). Output writes remain gated by if (lane_id < topk).

```diff
-  // Phase 3: renormalize and write output
+  // Phase 3: renormalize and write output. `__shfl_xor_sync` requires every
+  // lane named in the mask (0xffffffff) to execute it, so we pad inactive
+  // lanes with 0 and run warp_sum_f32 on all 32 lanes uniformly.
+  float weight = (lane_id < topk) ? selected_weights[lane_id] : 0.0f;
+  float divisor = renormalize ? warp_sum_f32(weight) + 1e-20f : 1.0f;
   if (lane_id < topk) {
-    float weight = selected_weights[lane_id];
-    float final_weight = weight * scaling_factor;
-    if (renormalize) {
-      float partial = (lane_id < topk) ? weight : 0.0f;
-      float total = warp_sum_f32(partial);
-      final_weight = weight * scaling_factor / (total + 1e-20f);
-    }
     out_ids[lane_id] = selected_ids[lane_id];
-    out_vals[lane_id] = final_weight;
+    out_vals[lane_id] = weight * scaling_factor / divisor;
   }
```

Phase 2's warp_max_u64 also uses __shfl_xor_sync(0xffffffff, ...) but is safe — it runs after if (warp_id != 0) return;, so all 32 lanes of warp 0 are converged at the call site.
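For contrast, a rough sketch of that safe call pattern (names and structure here are approximate stand-ins, not the actual Phase 2 code from grouped_topk.cuh):

```cuda
// Sketch of the converged-call pattern described above (approximate names).
#include <cstdint>

__device__ __forceinline__ uint64_t warp_max_u64(uint64_t v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    uint64_t other = __shfl_xor_sync(0xffffffff, v, offset);
    v = other > v ? other : v;
  }
  return v;
}

__global__ void phase2_sketch(const uint64_t* packed_scores, uint64_t* out) {
  int warp_id = threadIdx.x / 32;
  int lane_id = threadIdx.x % 32;
  if (warp_id != 0) return;   // warps 1..N-1 exit uniformly; warp 0 stays converged
  // All 32 lanes of warp 0 reach the full-mask shuffle together, so the
  // 0xffffffff mask is satisfied even though the other warps have returned.
  uint64_t best = warp_max_u64(packed_scores[lane_id]);
  if (lane_id == 0) out[0] = best;
}
```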

Test plan

  • Unit-level numerical sweep on H200: E ∈ {16, 32, 64, 128, 192, 256, 384, 512}, K ∈ {1, 2, 4, 6, 7, 8}, N ∈ {1..128}. Patched kernel matches reference biased_grouped_topk_impl with max diff < 1e-7. Buggy kernel fails by ~0.08 absolute on 2 weights per row.
  • TestNvidiaNemotron3Nano30BFP8.test_lm_eval on 2× H200, both graph modes default-on:
    • Buggy kernel → IMA at piecewise_cuda_graph_runner.py:794 (matches the original failure).
    • Patched kernel → test_lm_eval ... ok, gsm8k strict=0.839 (target 0.847, rtol=0.08), flexible=0.542 (target 0.556, rtol=0.08).
  • CI: stage-b-test-2-gpu-large runs the re-enabled test and reports gsm8k≈0.85.

🤖 Generated with Claude Code


@gemini-code-assist (Bot) left a comment


Code Review

This pull request fixes potential undefined behavior in the grouped_topk_single_group_kernel CUDA kernel by ensuring all warp lanes participate in the warp_sum_f32 reduction, avoiding divergence during synchronization. It also re-enables the nvidia_nemotron_3_nano test in the CI suite. I have no feedback to provide as there were no review comments to evaluate.

Kangyan-Zhou changed the title from "[MoE] Fix warp-shfl UB in grouped_topk renormalize (alternative to #23758)" to "[MoE] Fix warp-shfl UB in grouped_topk renormalize, re-enable nemotron test" on Apr 26, 2026
Kangyan-Zhou changed the title from "[MoE] Fix warp-shfl UB in grouped_topk renormalize, re-enable nemotron test" to "[MoE] Fix warp-shfl UB in grouped_topk renormalize" on Apr 26, 2026
@Kangyan-Zhou (Collaborator, Author) commented:

/tag-and-rerun-ci

Commit message:

The Phase-3 renormalize block in `grouped_topk_single_group_kernel` called
`warp_sum_f32` (which uses `__shfl_xor_sync(0xffffffff, ...)`) from inside
`if (lane_id < topk)`. With `topk` < 32 (e.g. nemotron-3-nano: topk=6), only
lanes 0..topk-1 reached the intrinsic, but the mask 0xffffffff named all 32
lanes. CUDA spec: every lane named in the mask must execute the intrinsic
at the same site, otherwise the result is undefined.

Empirically the UB returned values from the absent lanes' registers,
producing wrong renormalized weights — 2 of 6 weights per token were
unnormalized (~1.5x too large). The wrong values were tolerated in eager
inference, but under piecewise CUDA graph replay they cascaded into a
downstream OOB that surfaced as IMA at `piecewise_cuda_graph_runner.py:794`
on `TestNvidiaNemotron3Nano30BFP8.test_lm_eval`.

Fix: move the warp_sum out of the divergent `if`, have all 32 lanes
participate, with inactive lanes contributing the additive identity (0).
Output writes remain gated by `if (lane_id < topk)`.

Validated:
- Unit sweep across E in {16..512}, K in {1..8}, N in {1..128}: matches
  reference biased_grouped_topk_impl with max diff < 1e-7.
- 2x H200 e2e: TestNvidiaNemotron3Nano30BFP8.test_lm_eval passes
  (gsm8k strict=0.839, flexible=0.542, both within rtol=0.08).
- Buggy kernel + eager (no graphs) also passes — confirming the kernel
  itself doesn't fault, only the cascade-under-graph-replay does.

This is the surgical alternative to #23758, which reverts the entire
#23533 (~4000 lines). The model code, tool/reasoning parsers, and tuned
MoE configs from #23533 are not part of the bug.

Also re-enables `test_nvidia_nemotron_3_nano` (the stop-gap disable was
added in #23720 when this IMA started showing up).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Kangyan-Zhou force-pushed the fix_grouped_topk_warp_shfl branch from 8ad064e to 95bf063 on April 26, 2026 at 21:05
Kangyan-Zhou reopened this on Apr 27, 2026
Kangyan-Zhou marked this pull request as draft on April 27, 2026 at 03:35
