[Bugfix] Fix condition to clear persistent topk so that it can be captured regardless #41665
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces an unconditional cudaMemsetAsync to zero the RadixRowState workspace before the persistent_topk_kernel launch, ensuring correct state initialization for CUDA graph replays and inter-CTA synchronization. Feedback identifies a potential race condition where the existing in-kernel initialization of the same state might conflict with the new host-side memset, and it is recommended to remove the redundant in-kernel reset to ensure correctness.
```cpp
{
  cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr<uint8_t>(), 0,
                                       state_bytes, stream);
  TORCH_CHECK(mz_err == cudaSuccess,
              "row_states memset failed: ", cudaGetErrorString(mz_err));
}
```
The addition of `cudaMemsetAsync` correctly ensures that the `RadixRowState` is zeroed and globally visible before the kernel launch, which is necessary for CUDA graph capture and correct inter-CTA synchronization. However, the `persistent_topk_kernel` in `csrc/persistent_topk.cuh` (lines 894-909) still contains an in-kernel initialization block that resets `arrival_counter` and other fields.
Because there is no global barrier between that in-kernel initialization and the subsequent work loop, a race condition exists: other CTAs (e.g., `cta_in_group > 0`) may start executing `radix_topk`, incrementing the `arrival_counter`, only to have `cta_in_group == 0` reset it to zero. This will lead to synchronization failures or hangs.
Now that `cudaMemsetAsync` is used, the initialization block in `csrc/persistent_topk.cuh` should be removed to ensure correctness and avoid this race condition.
Force-pushed aee0c46 to 2e4b843
@claude review
LGTM — minimal, well-justified fix to ensure the workspace memset is captured as a graph node regardless of capture-time max_seq_len.
Overview
Single-line change in csrc/topk.cu that removes a host-side if (needs_cooperative) guard around cudaMemsetAsync. The memset zeroes the per-group RadixRowState workspace before kernel launch. The guard caused the memset to be skipped at CUDA-graph capture time when capture-time max_seq_len <= RADIX_THRESHOLD, so the zeroing was never recorded into the captured graph, and stale workspace state could persist across replays if the runtime data later took the radix path.
Security risks
None. This is a CUDA kernel launch correctness fix in the DeepSeek V3 sparse attention indexer. No auth, crypto, network I/O, deserialization, or user-input handling is touched.
Level of scrutiny
Low. The change is mechanical (drop a conditional, keep the memset body), the rationale is fully captured in the new comment, the cost is negligible (~3 KB per group, sub-microsecond), and the author included gsm8k accuracy and serving benchmark numbers showing correct behavior on B300 MTP2.
Other factors
The Gemini bot flagged a potential in-kernel-init vs. memset race, but that is stale: `csrc/persistent_topk.cuh` no longer has an in-kernel initialization block for `arrival_counter` (the comment at lines 890-895 explicitly notes the init was moved host-side for exactly this happens-before reason, and a grep for `arrival_counter` shows only `red_release`/`wait_ge` uses inside the kernel). Mergify also flagged conflicts, but rebasing is the author's responsibility and orthogonal to review correctness.
```cpp
// Issued UNCONDITIONALLY so the memset is captured as its own node in
// the cudagraph (a separate cudaMemsetAsync node, sequenced before the
// persistent_topk_kernel launch on the same stream). The previous
// host-side guard `if (needs_cooperative)` was evaluated at capture time;
// when capture-time max_seq_len <= RADIX_THRESHOLD (always true under
// FULL_DECODE_ONLY with max_model_len < 32 K) the memset would NOT be
// captured, leaving the workspace state to accumulate across replays.
// That's a latent correctness bug if the runtime data ever takes the
// radix path; issuing the memset unconditionally also removes one
// variable while debugging hangs in the decode/medium paths.
//
// Cost is sub-microsecond: state_bytes = num_groups * sizeof(RadixRowState)
// is ~3 KB per group, ~100 KB for the largest grids on this hardware.
```
I don't think we need quite so much context in the comment.
Suggested change:

```cpp
// Issued UNCONDITIONALLY so the memset is captured as its own node in
// the cudagraph (a separate cudaMemsetAsync node, sequenced before the
// persistent_topk_kernel launch on the same stream).
```
Purpose

The external memset was introduced in #41444, but only when `need_cooperative` is true. `need_cooperative` only triggers when `max_seq_len` is greater than a certain threshold, which means the condition will never fire at CUDA graph capture time. This PR removes the `need_cooperative` check so that the memset is always issued at capture time.

Test Plan
gsm8k v4-pro DEP4 on B300 MTP2
Test Result
Speedbench 8k/1k result