Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions csrc/persistent_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -887,27 +887,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock, 2)
uint32_t* shared_ordered =
reinterpret_cast<uint32_t*>(smem_raw + kFixedSmemLarge);

// RadixRowState for multi-CTA cooperative radix
// RadixRowState for multi-CTA cooperative radix.
// Zero-initialization is done host-side via cudaMemsetAsync in topk.cu
// before launch, which gives a stream-ordered happens-before edge for all
// CTAs. The previous in-kernel init (CTA-0 only + intra-CTA __syncthreads)
// provided no such edge, and its absence manifested as a race against
// CTA-1+'s first red_release on arrival_counter.
RadixRowState* state = &params.row_states[group_id];

// -- Initialize RadixRowState (only needed if large rows exist) --
if (params.max_seq_len > RADIX_THRESHOLD) {
if (cta_in_group == 0) {
for (uint32_t buf = 0; buf < 3; buf++) {
for (uint32_t i = tx; i < RADIX; i += kThreadsPerBlock) {
state->histogram[buf][i] = 0;
}
}
if (tx == 0) {
state->remaining_k = 0;
state->prefix = 0;
state->arrival_counter = 0;
state->output_counter = 0;
}
}
__syncthreads();
}

int barrier_phase = 0;
const uint32_t total_iters = (params.num_rows + num_groups - 1) / num_groups;

Expand Down
23 changes: 23 additions & 0 deletions csrc/topk.cu
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,29 @@ void launch_persistent_topk(const torch::Tensor& logits,
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
"workspace too small, need ", state_bytes, " bytes");

// Zero the per-group RadixRowState region before launch — only when the
// radix path will actually run (max_seq_len > RADIX_THRESHOLD). The
// RadixRowState fields (arrival_counter, histograms) are only touched by
// radix_topk; the decode/medium paths inside the persistent kernel
// operate purely in shared memory and never read these globals, so a
// stale workspace is harmless for them.
//
// Why we need the memset (when needs_cooperative is true):
// 1. arrival_counter accumulates within a launch and is never reset
// in-kernel, so a prior call leaves it at a large positive value.
// Without this host-side reset, the very first wait_ge in the next
// call sees a counter far above its target and returns instantly,
// breaking the barrier.
// 2. The previous in-kernel init only ran in CTA-0 with intra-CTA
// __syncthreads(), so it had no happens-before edge to CTA-1+'s
// first red_release. cudaMemsetAsync is stream-ordered: the zero
// is globally visible before any CTA runs.
if (needs_cooperative) {
cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr<uint8_t>(), 0,
state_bytes, stream);
TORCH_CHECK(mz_err == cudaSuccess,
"row_states memset failed: ", cudaGetErrorString(mz_err));
}

P::PersistentTopKParams params;
params.input = logits.data_ptr<float>();
params.output = output.data_ptr<int32_t>();
Expand Down
Loading