From 2e4b84387f734406cff5f36584ae17db9889b3f9 Mon Sep 17 00:00:00 2001
From: Yongye Zhu
Date: Mon, 4 May 2026 21:12:08 +0000
Subject: [PATCH] fix

Signed-off-by: Yongye Zhu
---
 csrc/topk.cu | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/csrc/topk.cu b/csrc/topk.cu
index 68352629ef02..c5bffb32856d 100644
--- a/csrc/topk.cu
+++ b/csrc/topk.cu
@@ -153,14 +153,23 @@ void launch_persistent_topk(const torch::Tensor& logits,
   TORCH_CHECK(workspace.size(0) >= static_cast(state_bytes),
               "workspace too small, need ", state_bytes, " bytes");
 
-  // Zero the per-group RadixRowState region before launch — only when the
-  // radix path will actually run (max_seq_len > RADIX_THRESHOLD). The
-  // RadixRowState fields (arrival_counter, histograms) are only touched by
-  // radix_topk; the decode/medium paths inside the persistent kernel
-  // operate purely in shared memory and never read these globals, so a
-  // stale workspace is harmless for them.
+  // Zero the per-group RadixRowState region before launch.
   //
-  // Why we need the memset (when needs_cooperative is true):
+  // Issued UNCONDITIONALLY so the memset is captured as its own node in
+  // the cudagraph (a separate cudaMemsetAsync node, sequenced before the
+  // persistent_topk_kernel launch on the same stream). The previous
+  // host-side guard `if (needs_cooperative)` was evaluated at capture time;
+  // when capture-time max_seq_len <= RADIX_THRESHOLD (always true under
+  // FULL_DECODE_ONLY with max_model_len < 32K) the memset would NOT be
+  // captured, leaving the workspace state to accumulate across replays.
+  // That is a latent correctness bug whenever runtime data takes the
+  // radix path; issuing the memset unconditionally also removes one
+  // variable while debugging hangs in the decode/medium paths.
+  //
+  // The cost is sub-microsecond: state_bytes = num_groups * sizeof(RadixRowState),
+  // roughly 3 KB per group and ~100 KB total for the largest grids on this hardware.
+  //
+  // Why the memset is required (regardless of which path the kernel takes):
   // 1. arrival_counter accumulates within a launch and is never reset,
   //    so a prior call leaves it at a large positive value. Without this
   //    reset, the very first wait_ge in the next call sees counter >>
@@ -169,7 +178,7 @@ void launch_persistent_topk(const torch::Tensor& logits,
   //    __syncthreads(), so it had no happens-before edge to CTA-1+'s
   //    first red_release. cudaMemsetAsync is stream-ordered: the zero
   //    is globally visible before any CTA runs.
-  if (needs_cooperative) {
+  {
     cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr(), 0,
                                          state_bytes, stream);
     TORCH_CHECK(mz_err == cudaSuccess,
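
For reference, below is a minimal standalone sketch of the capture-time behavior the new comment describes: a host-side `if` around cudaMemsetAsync is evaluated once while the stream is being captured, so the memset node is either baked into the cudagraph or omitted forever, regardless of what the guard evaluates to at replay time. Everything in the sketch (the `ws` buffer, the sizes, the node-count check) is illustrative and assumed; it is not code from csrc/topk.cu.

// Sketch: demonstrates that a host-side guard around cudaMemsetAsync is
// resolved at cudagraph capture time, not at replay time.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  void* ws = nullptr;                // stand-in for the RadixRowState workspace
  const size_t state_bytes = 4096;   // illustrative size only
  cudaMalloc(&ws, state_bytes);

  const bool needs_cooperative = false;  // value observed at CAPTURE time

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  if (needs_cooperative) {
    // Not executed during capture, so no memset node is recorded. Flipping
    // needs_cooperative afterwards cannot add the node back: replays of the
    // captured graph will never zero the workspace.
    cudaMemsetAsync(ws, 0, state_bytes, stream);
  }
  // ... the persistent kernel launch would be captured here ...
  cudaStreamEndCapture(stream, &graph);

  size_t num_nodes = 0;
  cudaGraphGetNodes(graph, nullptr, &num_nodes);
  std::printf("captured nodes: %zu\n", num_nodes);  // prints 0: the guard elided the memset

  cudaGraphDestroy(graph);
  cudaFree(ws);
  cudaStreamDestroy(stream);
  return 0;
}

Because the branch is resolved on the host during recording, the only way to guarantee the zeroing survives graph replay is to record the memset unconditionally, which is what the first hunk above does.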