[MoE] Fix warp-shfl UB in grouped_topk renormalize#23774
Closed
Kangyan-Zhou wants to merge 1 commit into main from
Conversation
Contributor
Code Review
This pull request fixes potential undefined behavior in the grouped_topk_single_group_kernel CUDA kernel by ensuring all warp lanes participate in the warp_sum_f32 reduction, avoiding divergence during synchronization. It also re-enables the nvidia_nemotron_3_nano test in the CI suite. I have no feedback to provide as there were no review comments to evaluate.
Collaborator
Author
/tag-and-rerun-ci
…n test
The Phase-3 renormalize block in `grouped_topk_single_group_kernel` called
`warp_sum_f32` (which uses `__shfl_xor_sync(0xffffffff, ...)`) from inside
`if (lane_id < topk)`. With `topk` < 32 (e.g. nemotron-3-nano: topk=6), only
lanes 0..topk-1 reached the intrinsic, but the mask 0xffffffff named all 32
lanes. CUDA spec: every lane named in the mask must execute the intrinsic
at the same site, otherwise the result is undefined.
Empirically the UB returned values from the absent lanes' registers,
producing wrong renormalized weights — 2 of 6 weights per token were
unnormalized (~1.5x too large). The wrong values were tolerated in eager
inference, but under piecewise CUDA graph replay they cascaded into a
downstream OOB that surfaced as IMA at `piecewise_cuda_graph_runner.py:794`
on `TestNvidiaNemotron3Nano30BFP8.test_lm_eval`.
Fix: move the warp_sum out of the divergent `if`, have all 32 lanes
participate, with inactive lanes contributing the additive identity (0).
Output writes remain gated by `if (lane_id < topk)`.
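A minimal standalone sketch of the fixed shape (assuming a butterfly `warp_sum_f32` like the kernel's; the kernel shell and names other than `lane_id`/`topk` are illustrative, not the real `grouped_topk.cuh` code):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Full-warp butterfly sum; every lane named by the 0xffffffff mask must
// execute each shuffle, so the caller must not be divergent.
__device__ __forceinline__ float warp_sum_f32(float v) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_xor_sync(0xffffffffu, v, offset);
  }
  return v;
}

// Renormalize the first `topk` weights of a 32-wide row in place.
__global__ void renormalize_topk(float* weights, int topk) {
  int lane_id = threadIdx.x & 31;

  // Every lane produces a contribution; lanes >= topk use the additive identity.
  float w = (lane_id < topk) ? weights[lane_id] : 0.0f;

  // All 32 lanes reach the reduction together: no divergence at the shuffle.
  float total = warp_sum_f32(w);

  // Output writes stay gated to the active lanes.
  if (lane_id < topk) {
    weights[lane_id] = w / total;
  }
}

int main() {
  float h[32] = {0.6f, 0.4f, 0.4f, 0.2f, 0.2f, 0.2f};  // topk = 6, sums to 2.0
  float* d = nullptr;
  cudaMalloc(&d, sizeof(h));
  cudaMemcpy(d, h, sizeof(h), cudaMemcpyHostToDevice);
  renormalize_topk<<<1, 32>>>(d, /*topk=*/6);
  cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
  for (int i = 0; i < 6; ++i) printf("w[%d] = %f\n", i, h[i]);  // each original / 2.0
  cudaFree(d);
  return 0;
}
```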
Validated:
- Unit sweep across E in {16..512}, K in {1..8}, N in {1..128}: matches
reference biased_grouped_topk_impl with max diff < 1e-7.
- 2x H200 e2e: TestNvidiaNemotron3Nano30BFP8.test_lm_eval passes
(gsm8k strict=0.839, flexible=0.542, both within rtol=0.08).
- Buggy kernel + eager (no graphs) also passes — confirming the kernel
itself doesn't fault, only the cascade-under-graph-replay does.
This is the surgical alternative to #23758, which reverts the entire
#23533 (~4000 lines). The model code, tool/reasoning parsers, and tuned
MoE configs from #23533 are not part of the bug.
Also re-enables `test_nvidia_nemotron_3_nano` (the stop-gap disable was
added in #23720 when this IMA started showing up).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
8ad064e to 95bf063
Summary

Fix the IMA in `TestNvidiaNemotron3Nano30BFP8.test_lm_eval` caused by a CUDA warp-shuffle UB in `grouped_topk_single_group_kernel`'s renormalize path, and re-enable the test that was disabled in #23720 as a stop-gap.

- Kernel fix in `python/sglang/jit_kernel/csrc/moe/grouped_topk.cuh`.
- Drops the `disabled=...` kwarg from `register_cuda_ci` in `test/registered/models/test_nvidia_nemotron_3_nano.py`.

Root cause
Phase 3 of `grouped_topk_single_group_kernel` called `warp_sum_f32` (which does `__shfl_xor_sync(0xffffffff, ...)`) from inside `if (lane_id < topk)`. For nemotron-3-nano (topk=6), only lanes 0..5 reach the shfl_xor, but the mask `0xffffffff` declares that all 32 lanes participate. Per the CUDA programming guide, every lane named in the mask must execute the intrinsic at the same site, otherwise the result is undefined.

Empirically the UB returns values from the absent lanes' registers, producing wrong renormalized weights: 2 of 6 weights per token come out ~1.5× too large because they were not divided by the full sum. The author's `(lane_id < topk) ? weight : 0.0f` ternary captured the right intent, but it sat inside an `if (lane_id < topk)`, so it is a tautology: the lanes that needed to contribute 0 never reached it.
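A minimal sketch of that buggy shape (an illustrative device function, not the literal kernel; only `lane_id`, `topk`, and the full-warp `0xffffffff` mask come from the description above):

```cuda
// Buggy shape: the divergent guard makes the ternary a tautology, and lanes
// with lane_id >= topk never execute the shuffles even though the
// 0xffffffff mask names them, so the reduction result is undefined.
__device__ void renormalize_buggy(float* weights, int lane_id, int topk) {
  if (lane_id < topk) {                              // divergent guard
    float w = weights[lane_id];
    float total = (lane_id < topk) ? w : 0.0f;       // always w inside this block
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
      total += __shfl_xor_sync(0xffffffffu, total, offset);  // UB: lanes >= topk absent
    }
    weights[lane_id] = w / total;                    // wrong whenever topk < 32
  }
}
```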
Why the symptom was an IMA at PCG replay

The kernel UB itself does not raise a hardware fault; it produces wrong values. Verified: with the buggy kernel, `--disable-cuda-graph --disable-piecewise-cuda-graph`, and `CUDA_LAUNCH_BLOCKING=1`, the full GSM8K eval completes cleanly with accuracy within rtol.

The IMA only manifests under PCG replay: the wrong renormalized weights cascade into a downstream out-of-bounds access that surfaces during the `replay()` call, at `piecewise_cuda_graph_runner.py:794`.

Fix
Move the warp_sum out of the divergent context. All 32 lanes of warp 0 reach `warp_sum_f32` together; inactive lanes (`lane_id >= topk`) contribute the additive identity (0). Output writes remain gated by `if (lane_id < topk)`.

Phase 2's `warp_max_u64` also uses `__shfl_xor_sync(0xffffffff, ...)` but is safe: it runs after `if (warp_id != 0) return;`, so all 32 lanes of warp 0 are converged at the call site.
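As a sketch of why that Phase-2 call is converged (the kernel shell and variable names here are illustrative; only `warp_max_u64`, the early return, and the full-warp mask come from the description above):

```cuda
#include <cuda_runtime.h>

// Full-warp max reduction over unsigned 64-bit keys via butterfly shuffles.
__device__ __forceinline__ unsigned long long warp_max_u64(unsigned long long v) {
#pragma unroll
  for (int offset = 16; offset > 0; offset >>= 1) {
    unsigned long long other = __shfl_xor_sync(0xffffffffu, v, offset);
    v = (other > v) ? other : v;
  }
  return v;
}

// Illustrative kernel shell: warps other than warp 0 retire as whole warps,
// so every lane named by the 0xffffffff mask reaches the shuffle together.
__global__ void phase2_like(const unsigned long long* keys, unsigned long long* out) {
  int warp_id = threadIdx.x >> 5;
  int lane_id = threadIdx.x & 31;

  if (warp_id != 0) return;  // whole-warp exit: warp 0 stays fully converged

  unsigned long long m = warp_max_u64(keys[lane_id]);
  if (lane_id == 0) *out = m;
}
```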
Test plan

- Unit sweep (E in {16..512}, K in {1..8}, N in {1..128}): matches the reference `biased_grouped_topk_impl` with max diff < 1e-7. The buggy kernel fails by ~0.08 absolute on 2 weights per row.
- `TestNvidiaNemotron3Nano30BFP8.test_lm_eval` on 2× H200, both graph modes default-on:
  - Buggy kernel: IMA at `piecewise_cuda_graph_runner.py:794` (matches the original failure).
  - Fixed kernel: `test_lm_eval ... ok`, gsm8k strict=0.839 (target 0.847, rtol=0.08), flexible=0.542 (target 0.556, rtol=0.08).
- CI: `stage-b-test-2-gpu-large` runs the re-enabled test and reports gsm8k ≈ 0.85.

🤖 Generated with Claude Code