diff --git a/csrc/persistent_topk.cuh b/csrc/persistent_topk.cuh
index d6162d52998b..8b9d10ff83dd 100644
--- a/csrc/persistent_topk.cuh
+++ b/csrc/persistent_topk.cuh
@@ -887,27 +887,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
   uint32_t* shared_ordered =
       reinterpret_cast<uint32_t*>(smem_raw + kFixedSmemLarge);
 
-  // RadixRowState for multi-CTA cooperative radix
+  // RadixRowState for multi-CTA cooperative radix.
+  // Zero-initialization is done host-side via cudaMemsetAsync in topk.cu
+  // before launch; that gives a stream-ordered happens-before edge for all
+  // CTAs, which the previous in-kernel init (CTA-0 only + intra-CTA
+  // __syncthreads) did not provide and which manifested as a race against
+  // CTA-1+'s first red_release on arrival_counter.
   RadixRowState* state = &params.row_states[group_id];
 
-  // -- Initialize RadixRowState (only needed if large rows exist) --
-  if (params.max_seq_len > RADIX_THRESHOLD) {
-    if (cta_in_group == 0) {
-      for (uint32_t buf = 0; buf < 3; buf++) {
-        for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
-          state->histogram[buf][i] = 0;
-        }
-      }
-      if (tx == 0) {
-        state->remaining_k = 0;
-        state->prefix = 0;
-        state->arrival_counter = 0;
-        state->output_counter = 0;
-      }
-    }
-    __syncthreads();
-  }
-
   int barrier_phase = 0;
   const uint32_t total_iters = (params.num_rows + num_groups - 1) / num_groups;
 
diff --git a/csrc/topk.cu b/csrc/topk.cu
index b0f612ba6e4b..68352629ef02 100644
--- a/csrc/topk.cu
+++ b/csrc/topk.cu
@@ -153,6 +153,29 @@ void launch_persistent_topk(const torch::Tensor& logits,
   TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
               "workspace too small, need ", state_bytes, " bytes");
 
+  // Zero the per-group RadixRowState region before launch, but only when
+  // the radix path will actually run (max_seq_len > RADIX_THRESHOLD). The
+  // RadixRowState fields (arrival_counter, histograms) are only touched by
+  // radix_topk; the decode/medium paths inside the persistent kernel
+  // operate purely in shared memory and never read these globals, so a
+  // stale workspace is harmless for them.
+  //
+  // Why we need the memset (when needs_cooperative is true):
+  // 1. arrival_counter accumulates within a launch and is never reset,
+  //    so a prior call leaves it at a large positive value. Without this
+  //    reset, the very first wait_ge in the next call sees counter >>
+  //    target and returns instantly, breaking the barrier.
+  // 2. The previous in-kernel init only ran in CTA-0 with intra-CTA
+  //    __syncthreads(), so it had no happens-before edge to CTA-1+'s
+  //    first red_release. cudaMemsetAsync is stream-ordered: the zero
+  //    is globally visible before any CTA runs.
+  if (needs_cooperative) {
+    cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr(), 0,
+                                         state_bytes, stream);
+    TORCH_CHECK(mz_err == cudaSuccess,
+                "row_states memset failed: ", cudaGetErrorString(mz_err));
+  }
+
   P::PersistentTopKParams params;
   params.input = logits.data_ptr();
   params.output = output.data_ptr();
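
Note (illustrative, not part of the patch): the comments above lean on the semantics of the device-side barrier helpers red_release and wait_ge, which this diff does not show. The sketch below illustrates that arrive/wait pattern using libcu++ device-scope atomics; the helper signatures and the RadixRowState field types here are assumptions for illustration, not the definitions in persistent_topk.cuh.

#include <cuda/atomic>
#include <cstdint>

// Sketch only: layout and signatures are assumed, the repo's real
// helpers in persistent_topk.cuh may differ.
using device_u32 = cuda::atomic<uint32_t, cuda::thread_scope_device>;

// Arrive with release semantics: writes made by this CTA before
// arriving (e.g. its histogram contribution) become visible to any
// thread that later observes the incremented counter with acquire.
__device__ void red_release(device_u32& arrival_counter) {
  arrival_counter.fetch_add(1, cuda::memory_order_release);
}

// Spin until at least `target` arrivals are visible. Two failure modes
// match the ones described in the patch comments:
//   (1) a counter left over from a previous launch already satisfies
//       `>= target`, so the wait returns instantly and the barrier is
//       vacuous;
//   (2) if CTA 0 zeroes the counter concurrently with another CTA's
//       fetch_add, that arrival can be overwritten and the wait hangs.
// A stream-ordered memset before launch rules out both.
__device__ void wait_ge(const device_u32& arrival_counter,
                        uint32_t target) {
  while (arrival_counter.load(cuda::memory_order_acquire) < target) {
    __nanosleep(64);  // back off a little while spinning (sm_70+)
  }
}

Under this model, the host-side cudaMemsetAsync guarantees every CTA's first fetch_add starts from a true zero, since stream ordering places the memset before any kernel work; that is exactly the happens-before edge the removed in-kernel init could not provide for CTAs other than CTA 0.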