diff --git a/csrc/persistent_topk.cuh b/csrc/persistent_topk.cuh
index 8b9d10ff83dd..11e639797ffd 100644
--- a/csrc/persistent_topk.cuh
+++ b/csrc/persistent_topk.cuh
@@ -887,12 +887,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
   uint32_t* shared_ordered =
       reinterpret_cast<uint32_t*>(smem_raw + kFixedSmemLarge);
 
-  // RadixRowState for multi-CTA cooperative radix.
-  // Zero-initialization is done host-side via cudaMemsetAsync in topk.cu
-  // before launch — that gives a stream-ordered happens-before edge for all
-  // CTAs, which the previous in-kernel init (CTA-0 only + intra-CTA
-  // __syncthreads) did not provide and which manifested as a race against
-  // CTA-1+'s first red_release on arrival_counter.
   RadixRowState* state = &params.row_states[group_id];
 
   int barrier_phase = 0;
@@ -930,6 +924,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
         local_histogram, suffix_sum, shared_scalars, shared_ordered, state,
         cta_in_group, ctas_per_group, barrier_phase, iter, tx);
   }
+
+  if (params.max_seq_len > RADIX_THRESHOLD) {
+    if (cta_in_group == 0) {
+      for (uint32_t buf = 0; buf < 3; buf++) {
+        for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
+          int* ptr = reinterpret_cast<int*>(&state->histogram[buf][i]);
+          asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(ptr), "r"(0));
+        }
+      }
+
+      __syncthreads();
+      if (tx == 0) {
+        st_release(&state->arrival_counter, 0);
+      }
+    }
+  }
 }
 
 }  // namespace persistent
diff --git a/csrc/topk.cu b/csrc/topk.cu
index 68352629ef02..b0f612ba6e4b 100644
--- a/csrc/topk.cu
+++ b/csrc/topk.cu
@@ -153,29 +153,6 @@ void launch_persistent_topk(const torch::Tensor& logits,
   TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
               "workspace too small, need ", state_bytes, " bytes");
 
-  // Zero the per-group RadixRowState region before launch — only when the
-  // radix path will actually run (max_seq_len > RADIX_THRESHOLD). The
-  // RadixRowState fields (arrival_counter, histograms) are only touched by
-  // radix_topk; the decode/medium paths inside the persistent kernel
-  // operate purely in shared memory and never read these globals, so a
-  // stale workspace is harmless for them.
-  //
-  // Why we need the memset (when needs_cooperative is true):
-  // 1. arrival_counter accumulates within a launch and is never reset,
-  //    so a prior call leaves it at a large positive value. Without this
-  //    reset, the very first wait_ge in the next call sees counter >>
-  //    target and returns instantly, breaking the barrier.
-  // 2. The previous in-kernel init only ran in CTA-0 with intra-CTA
-  //    __syncthreads(), so it had no happens-before edge to CTA-1+'s
-  //    first red_release. cudaMemsetAsync is stream-ordered: the zero
-  //    is globally visible before any CTA runs.
-  if (needs_cooperative) {
-    cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr(), 0,
-                                         state_bytes, stream);
-    TORCH_CHECK(mz_err == cudaSuccess,
-                "row_states memset failed: ", cudaGetErrorString(mz_err));
-  }
-
   P::PersistentTopKParams params;
   params.input = logits.data_ptr();
   params.output = output.data_ptr();
diff --git a/tests/kernels/test_top_k_per_row.py b/tests/kernels/test_top_k_per_row.py
index 7b9c11495e8b..016bc47aa053 100644
--- a/tests/kernels/test_top_k_per_row.py
+++ b/tests/kernels/test_top_k_per_row.py
@@ -366,7 +366,7 @@ def test_deepseek_persistent_topk(
     offsets = torch.arange(next_n, device=logits.device, dtype=torch.int32)
     lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()
 
-    workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
+    workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")
     max_seq_len = int(seq_lens.max().item())
     torch.ops._C.persistent_topk(
         logits, lengths, indices, workspace, top_k, max_seq_len
@@ -449,7 +449,7 @@ def run_large_context_topk_test(
     # Create output tensor
    indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")
 
-    workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
+    workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")
     max_seq_len = max(seq_lens)
     torch.ops._C.persistent_topk(
         logits, lengths, indices, workspace, top_k, max_seq_len
@@ -818,7 +818,7 @@ def test_persistent_topk_padded_stride(top_k: int) -> None:
     lengths = torch.tensor(actual_seq_lens, dtype=torch.int32, device="cuda")
     indices = torch.empty((batch_size, top_k), dtype=torch.int32, device="cuda")
 
-    workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
+    workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")
     torch.ops._C.persistent_topk(
         logits, lengths, indices, workspace, top_k, max(actual_seq_lens)
     )
@@ -843,3 +843,57 @@
             f"Row {i}: persistent_topk with padded stride doesn't match. "
             f"seq_len={sl}, stride={padded_stride}"
         )
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
+@pytest.mark.parametrize(
+    "batch_size,seq_len,top_k",
+    [
+        pytest.param(1, 65536, 512, id="bs1_64k_k512"),
+        pytest.param(4, 65536, 512, id="bs4_64k_k512"),
+        pytest.param(8, 128000, 512, id="bs8_128k_k512"),
+        pytest.param(1, 96000, 1024, id="bs1_96k_k1024"),
+        pytest.param(8, 65536, 1024, id="bs8_64k_k1024"),
+        pytest.param(4, 128000, 2048, id="bs4_128k_k2048"),
+        pytest.param(8, 163840, 2048, id="bs8_164k_k2048"),
+    ],
+)
+def test_persistent_topk_stale_workspace(
+    batch_size: int, seq_len: int, top_k: int
+) -> None:
+    """Verify persistent_topk produces correct results across repeated calls
+    with the same workspace buffer (stale RadixRowState from prior calls).
+
+    The radix path (seq_len > 32768) uses cooperative multi-CTA barriers
+    via arrival_counter in global memory. Without proper cleanup between
+    calls, stale counter values cause barrier malfunctions and wrong results.
+    """
+    num_iters = 100
+
+    logits = torch.randn((batch_size, seq_len), dtype=torch.float32, device="cuda")
+
+    min_len = int(seq_len * 0.8)
+    lengths = torch.randint(
+        min_len, seq_len + 1, (batch_size,), dtype=torch.int32, device="cuda"
+    )
+
+    # Mask invalid positions
+    positions = torch.arange(seq_len, device="cuda", dtype=torch.int32).unsqueeze(0)
+    mask = positions >= lengths.unsqueeze(1)
+    logits = logits.masked_fill(mask, float("-inf"))
+
+    output = torch.empty(batch_size, top_k, dtype=torch.int32, device="cuda")
+    workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")
+
+    # Reference
+    _, ref_indices = torch.topk(logits, top_k, dim=-1)
+    ref_sorted = ref_indices.sort(dim=-1).values
+
+    for i in range(num_iters):
+        torch.ops._C.persistent_topk(logits, lengths, output, workspace, top_k, seq_len)
+        out_sorted = output.sort(dim=-1).values
+        assert torch.equal(ref_sorted, out_sorted), (
+            f"Stale workspace race at iter {i}: "
+            f"bs={batch_size} seq_len={seq_len} k={top_k} "
+            f"({(ref_sorted != out_sorted).sum().item()} indices differ)"
+        )
diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 4bf52a49c43f..a134d84ade76 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -325,6 +325,15 @@ def sparse_attn_indexer(
     (topk_workspace,) = workspace_manager.get_simultaneous(
         ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
     )
+    # Workspace must be zeroed on first use; the kernel resets it
+    # at the end of each launch for subsequent calls/graph replays.
+    # Re-zero if the buffer was reallocated (different data_ptr).
+    if (
+        getattr(sparse_attn_indexer, "_topk_ws_ptr", None)
+        != topk_workspace.data_ptr()
+    ):
+        topk_workspace.zero_()
+        sparse_attn_indexer._topk_ws_ptr = topk_workspace.data_ptr()  # type: ignore[attr-defined]
     torch.ops._C.persistent_topk(
         logits,
         seq_lens,