Merged
csrc/topk.cu (25 changes: 17 additions & 8 deletions)
@@ -153,14 +153,23 @@ void launch_persistent_topk(const torch::Tensor& logits,
TORCH_CHECK(workspace.size(0) >= static_cast<int64_t>(state_bytes),
"workspace too small, need ", state_bytes, " bytes");

-// Zero the per-group RadixRowState region before launch — only when the
-// radix path will actually run (max_seq_len > RADIX_THRESHOLD). The
-// RadixRowState fields (arrival_counter, histograms) are only touched by
-// radix_topk; the decode/medium paths inside the persistent kernel
-// operate purely in shared memory and never read these globals, so a
-// stale workspace is harmless for them.
+// Zero the per-group RadixRowState region before launch.
+//
+// Why the memset is unconditional (even when needs_cooperative is false):
+// Issued UNCONDITIONALLY so the memset is captured as its own node in
+// the cudagraph (a separate cudaMemsetAsync node, sequenced before the
+// persistent_topk_kernel launch on the same stream). The previous
+// host-side guard `if (needs_cooperative)` was evaluated at capture time;
+// when capture-time max_seq_len <= RADIX_THRESHOLD (always true under
+// FULL_DECODE_ONLY with max_model_len < 32 K) the memset would NOT be
+// captured, leaving the workspace state to accumulate across replays.
+// That's a latent correctness bug if the runtime data ever takes the
+// radix path, and zeroing unconditionally removes one variable while
+// debugging hangs in the decode/medium paths.
+//
+// Cost is sub-microsecond: state_bytes = num_groups * sizeof(RadixRowState)
+// is ~3 KB per group, ~100 KB for the largest grids on this hardware.
Collaborator commented on lines +158 to +170:
I don't think we need quite so much context in the comment.

Suggested change:
-// Issued UNCONDITIONALLY so the memset is captured as its own node in
-// the cudagraph (a separate cudaMemsetAsync node, sequenced before the
-// persistent_topk_kernel launch on the same stream). The previous
-// host-side guard `if (needs_cooperative)` was evaluated at capture time;
-// when capture-time max_seq_len <= RADIX_THRESHOLD (always true under
-// FULL_DECODE_ONLY with max_model_len < 32 K) the memset would NOT be
-// captured, leaving the workspace state to accumulate across replays.
-// That's a latent correctness bug if the runtime data ever takes the
-// radix path, and zeroing unconditionally removes one variable while
-// debugging hangs in the decode/medium paths.
-//
-// Cost is sub-microsecond: state_bytes = num_groups * sizeof(RadixRowState)
-// is ~3 KB per group, ~100 KB for the largest grids on this hardware.
+// Issued UNCONDITIONALLY so the memset is captured as its own node in
+// the cudagraph (a separate cudaMemsetAsync node, sequenced before the
+// persistent_topk_kernel launch on the same stream).

//
// Why the memset is required (regardless of which path the kernel takes):
// 1. arrival_counter accumulates within a launch and is never reset,
// so a prior call leaves it at a large positive value. Without this
// reset, the very first wait_ge in the next call sees counter >>
@@ -169,7 +178,7 @@ void launch_persistent_topk(const torch::Tensor& logits,
// __syncthreads(), so it had no happens-before edge to CTA-1+'s
// first red_release. cudaMemsetAsync is stream-ordered: the zero
// is globally visible before any CTA runs.
-if (needs_cooperative) {
+{
cudaError_t mz_err = cudaMemsetAsync(workspace.data_ptr<uint8_t>(), 0,
state_bytes, stream);
TORCH_CHECK(mz_err == cudaSuccess,
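
For readers unfamiliar with the capture-time pitfall the new comment describes, here is a minimal sketch of it, not taken from this PR: stream capture records whichever side of a host-side `if` executes during capture, so a guard that is false at capture time silently drops the memset node from every replay. Only `needs_cooperative` and `cudaMemsetAsync` come from the diff; the function name, buffer arguments, and surrounding structure are invented for illustration.

```cpp
// Hypothetical illustration (not the PR's code) of how a host-side guard
// gets frozen into a cudagraph at capture time.
#include <cuda_runtime.h>

cudaGraphExec_t capture_topk_sequence(uint8_t* workspace, size_t state_bytes,
                                      bool needs_cooperative,
                                      cudaStream_t stream) {
  cudaGraph_t graph;
  cudaGraphExec_t exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

  // BUG pattern: this branch runs once, on the host, during capture.
  // If needs_cooperative is false here, no cudaMemsetAsync node is
  // recorded, and every later cudaGraphLaunch replays the sequence
  // without the zeroing -- even if the runtime data would need it.
  if (needs_cooperative) {
    cudaMemsetAsync(workspace, 0, state_bytes, stream);
  }
  // ... persistent_topk_kernel launch would be recorded here ...

  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&exec, graph, /*flags=*/0);  // CUDA 12 signature;
  cudaGraphDestroy(graph);                          // error checks omitted
  return exec;
}
```

Dropping the guard, as the PR does, turns the memset into an unconditional node of the captured graph, so every replay starts from a zeroed RadixRowState region.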
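
The happens-before argument in the retained comment is easier to follow against a sketch of the handshake it describes. The shapes below are assumptions: only the names `RadixRowState`, `arrival_counter`, `wait_ge`, and `red_release` come from the comments, while the struct layout and protocol details are not shown in this diff.

```cpp
// Sketch (assumed layout) of the cross-CTA arrival protocol the comment
// reasons about, using libcu++ atomics for the acquire/release edges.
#include <cuda/atomic>

struct RadixRowState {
  // Accumulates monotonically within one launch; nothing in the kernel
  // resets it, which is why the host-side memset between launches matters.
  cuda::atomic<unsigned int, cuda::thread_scope_device> arrival_counter;
  // ... per-digit histograms would live here ...
};

__device__ void red_release(RadixRowState* s) {
  // Release: makes this CTA's prior global writes visible to any thread
  // that later observes the incremented counter with an acquire load.
  s->arrival_counter.fetch_add(1, cuda::std::memory_order_release);
}

__device__ void wait_ge(RadixRowState* s, unsigned int expected) {
  // Acquire: spin until `expected` arrivals are visible. If a previous
  // launch left the counter at a large value, this returns immediately
  // and the caller reads histogram data that was never written -- the
  // staleness the unconditional cudaMemsetAsync rules out.
  while (s->arrival_counter.load(cuda::std::memory_order_acquire) < expected) {
  }
}
```

Because cudaMemsetAsync is stream-ordered, the zero completes before the first CTA of the next launch runs, which supplies the happens-before edge the comment says the old CTA-0 plus `__syncthreads()` initialization lacked.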