22 changes: 16 additions & 6 deletions csrc/persistent_topk.cuh
@@ -887,12 +887,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
uint32_t* shared_ordered =
reinterpret_cast<uint32_t*>(smem_raw + kFixedSmemLarge);

// RadixRowState for multi-CTA cooperative radix.
// Zero-initialization is done host-side via cudaMemsetAsync in topk.cu
// before launch — that gives a stream-ordered happens-before edge for all
// CTAs, which the previous in-kernel init (CTA-0 only + intra-CTA
// __syncthreads) did not provide and which manifested as a race against
// CTA-1+'s first red_release on arrival_counter.
RadixRowState* state = &params.row_states[group_id];

int barrier_phase = 0;
@@ -930,6 +924,22 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
local_histogram, suffix_sum, shared_scalars, shared_ordered, state,
cta_in_group, ctas_per_group, barrier_phase, iter, tx);
}

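// End-of-launch cleanup for the radix path: zero this group's RadixRowState
// (histograms and arrival_counter) so the next call, or a CUDA-graph replay,
// starts from a clean barrier state. arrival_counter only accumulates within
// a launch; leaving it non-zero would let the next launch's first wait see
// counter >= target immediately and break the cooperative barrier.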
if (params.max_seq_len > RADIX_THRESHOLD) {
if (cta_in_group == 0) {
for (uint32_t buf = 0; buf < 3; buf++) {
for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
int* ptr = reinterpret_cast<int*>(&state->histogram[buf][i]);
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(ptr), "r"(0));
}
}

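// All of CTA 0's threads must finish zeroing the histograms before a single
// thread publishes the zeroed arrival_counter with a release store.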
__syncthreads();
if (tx == 0) {
st_release(&state->arrival_counter, 0);
}
}
}
}

} // namespace persistent
23 changes: 0 additions & 23 deletions csrc/topk.cu
@@ -153,29 +153,6 @@ void launch_persistent_topk(const torch::Tensor& logits,
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
"workspace too small, need ", state_bytes, " bytes");

// Zero the per-group RadixRowState region before launch — only when the
// radix path will actually run (max_seq_len > RADIX_THRESHOLD). The
// RadixRowState fields (arrival_counter, histograms) are only touched by
// radix_topk; the decode/medium paths inside the persistent kernel
// operate purely in shared memory and never read these globals, so a
// stale workspace is harmless for them.
//
// Why we need the memset (when needs_cooperative is true):
// 1. arrival_counter accumulates within a launch and is never reset,
// so a prior call leaves it at a large positive value. Without this
// reset, the very first wait_ge in the next call sees counter >>
// target and returns instantly, breaking the barrier.
// 2. The previous in-kernel init only ran in CTA-0 with intra-CTA
// __syncthreads(), so it had no happens-before edge to CTA-1+'s
// first red_release. cudaMemsetAsync is stream-ordered: the zero
// is globally visible before any CTA runs.
if (needs_cooperative) {
cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr<uint8_t>(), 0,
state_bytes, stream);
TORCH_CHECK(mz_err == cudaSuccess,
"row_states memset failed: ", cudaGetErrorString(mz_err));
}

P::PersistentTopKParams params;
params.input = logits.data_ptr<float>();
params.output = output.data_ptr<int32_t>();
60 changes: 57 additions & 3 deletions tests/kernels/test_top_k_per_row.py
@@ -366,7 +366,7 @@ def test_deepseek_persistent_topk(
offsets = torch.arange(next_n, device=logits.device, dtype=torch.int32)
lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()

workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")
max_seq_len = int(seq_lens.max().item())
torch.ops._C.persistent_topk(
logits, lengths, indices, workspace, top_k, max_seq_len
@@ -449,7 +449,7 @@ def run_large_context_topk_test(
# Create output tensor
indices = torch.empty((num_rows, top_k), dtype=torch.int32, device="cuda")

workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")
max_seq_len = max(seq_lens)
torch.ops._C.persistent_topk(
logits, lengths, indices, workspace, top_k, max_seq_len
@@ -818,7 +818,7 @@ def test_persistent_topk_padded_stride(top_k: int) -> None:

lengths = torch.tensor(actual_seq_lens, dtype=torch.int32, device="cuda")
indices = torch.empty((batch_size, top_k), dtype=torch.int32, device="cuda")
workspace = torch.empty(1024 * 1024, dtype=torch.uint8, device="cuda")
workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")

torch.ops._C.persistent_topk(
logits, lengths, indices, workspace, top_k, max(actual_seq_lens)
@@ -843,3 +843,57 @@ def test_persistent_topk_padded_stride(top_k: int) -> None:
f"Row {i}: persistent_topk with padded stride doesn't match. "
f"seq_len={sl}, stride={padded_stride}"
)


@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize(
"batch_size,seq_len,top_k",
[
pytest.param(1, 65536, 512, id="bs1_64k_k512"),
pytest.param(4, 65536, 512, id="bs4_64k_k512"),
pytest.param(8, 128000, 512, id="bs8_128k_k512"),
pytest.param(1, 96000, 1024, id="bs1_96k_k1024"),
pytest.param(8, 65536, 1024, id="bs8_64k_k1024"),
pytest.param(4, 128000, 2048, id="bs4_128k_k2048"),
pytest.param(8, 163840, 2048, id="bs8_164k_k2048"),
],
)
def test_persistent_topk_stale_workspace(
batch_size: int, seq_len: int, top_k: int
) -> None:
"""Verify persistent_topk produces correct results across repeated calls
with the same workspace buffer (stale RadixRowState from prior calls).

The radix path (seq_len > 32768) synchronizes the CTAs of a group through a
cooperative barrier built on arrival_counter in global memory. Without a
reset between calls, a stale counter value lets the barrier release early
and the kernel returns wrong indices.
"""
num_iters = 100

logits = torch.randn((batch_size, seq_len), dtype=torch.float32, device="cuda")

min_len = int(seq_len * 0.8)
lengths = torch.randint(
min_len, seq_len + 1, (batch_size,), dtype=torch.int32, device="cuda"
)

# Mask invalid positions
positions = torch.arange(seq_len, device="cuda", dtype=torch.int32).unsqueeze(0)
mask = positions >= lengths.unsqueeze(1)
logits = logits.masked_fill(mask, float("-inf"))

output = torch.empty(batch_size, top_k, dtype=torch.int32, device="cuda")
workspace = torch.zeros(1024 * 1024, dtype=torch.uint8, device="cuda")

# Reference
_, ref_indices = torch.topk(logits, top_k, dim=-1)
ref_sorted = ref_indices.sort(dim=-1).values

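# Reuse the same workspace buffer for every iteration: if the kernel left a
# stale arrival_counter behind, a later call's cooperative barrier would
# release early and the sorted indices below would diverge from torch.topk.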
for i in range(num_iters):
torch.ops._C.persistent_topk(logits, lengths, output, workspace, top_k, seq_len)
out_sorted = output.sort(dim=-1).values
assert torch.equal(ref_sorted, out_sorted), (
f"Stale workspace race at iter {i}: "
f"bs={batch_size} seq_len={seq_len} k={top_k} "
f"({(ref_sorted != out_sorted).sum().item()} indices differ)"
)
9 changes: 9 additions & 0 deletions vllm/model_executor/layers/sparse_attn_indexer.py
@@ -325,6 +325,15 @@ def sparse_attn_indexer(
(topk_workspace,) = workspace_manager.get_simultaneous(
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
)
# Workspace must be zeroed on first use; the kernel resets it
# at the end of each launch for subsequent calls/graph replays.
# Re-zero if the buffer was reallocated (different data_ptr).
if (
getattr(sparse_attn_indexer, "_topk_ws_ptr", None)
!= topk_workspace.data_ptr()
):
topk_workspace.zero_()
sparse_attn_indexer._topk_ws_ptr = topk_workspace.data_ptr() # type: ignore[attr-defined]
torch.ops._C.persistent_topk(
logits,
seq_lens,