bugfix: fix multi-CTA top-k implementation when k value differs across rows (#2325)

yzh119 merged 4 commits into flashinfer-ai:main
Conversation
Fixed a race condition in the multi-CTA radix top-k kernel where stale histogram data corrupted results when processing batches with mixed k values.

Bug: When k >= vocab_size iterations skip RadixSelectFindPivot, they don't clear histogram buffers. Subsequent k < vocab_size iterations then use atomicAdd on stale histogram data, causing incorrect pivot selection.

Changes:
- Add __syncthreads() before red_release() in RadixSelectOneRound and RadixSelectFromSharedMemory to ensure histogram clearing completes before signaling the barrier
- Clear the first round's histogram at the start of RadixSelectFromSharedMemory to handle stale data from skipped k >= vocab iterations

Added regression tests:
- test_top_k_renorm_probs_mixed_k_persistent_loop
- test_top_k_mask_logits_mixed_k_persistent_loop

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
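The failure mode described above can be reproduced with a small CPU model of the triple-buffered histogram rotation. This is an illustrative simplification with hypothetical names (`run`, `dirty`); the real kernel runs four radix rounds over three rotating shared histogram buffers and accumulates with `atomicAdd`, which the boolean dirty flags stand in for:

```python
# CPU model of the stale-histogram rotation bug (illustrative only; the
# names and boolean "dirty" flags are hypothetical stand-ins for the
# kernel's three shared histogram buffers and atomicAdd accumulation).
NUM_ROUNDS, NUM_BUFS = 4, 3  # 4 radix rounds, 3 rotating histogram buffers

def run(skipped_iters, fix):
    """skipped_iters[i] is True when iteration i has k >= vocab_size
    (radix select skipped). Returns, for each select iteration, whether
    its round-0 histogram buffer was clean on entry."""
    dirty = [False] * NUM_BUFS
    clean_on_entry = []
    for i, skipped in enumerate(skipped_iters):
        first = (i * NUM_ROUNDS) % NUM_BUFS
        nxt = ((i + 1) * NUM_ROUNDS) % NUM_BUFS
        if skipped:
            if fix:
                dirty[nxt] = False  # the fix: clear the next iteration's first buffer
        else:
            clean_on_entry.append(not dirty[first])
            for r in range(NUM_ROUNDS):  # each round accumulates into a buffer
                dirty[(i * NUM_ROUNDS + r) % NUM_BUFS] = True
            dirty[nxt] = False  # tail clear for the next iteration's round 0
    return clean_on_entry

pattern = [False, True, False]  # select, skip (k >= vocab), select
assert run(pattern, fix=False) == [True, False]  # stale buffer survives the skip
assert run(pattern, fix=True) == [True, True]    # skip path now clears it
assert run([False, False], fix=False) == [True, True]  # back-to-back selects are fine
```

The model shows why consecutive select iterations were unaffected (each iteration's tail clear prepares the next one) while a skipped iteration left a buffer dirtied two iterations earlier in the rotation.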
📝 Walkthrough

Adds CTA-leader-guarded per-round histogram clears and barrier-aware placements across multiple RadixTopK multi-CTA kernel paths; also adds two tests exercising mixed-k behavior in persistent (multi-CTA) loops for top-k renorm/prob and mask-logits kernels.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant CTA0 as CTA 0 (leader)
    participant CTAn as Other CTAs
    participant Hist as Histogram Buffer
    participant Barrier as CTA Barrier (sync)
    CTA0->>Hist: compute/identify next_first_hist_idx
    CTA0->>Hist: clear current & next-round histogram (leader-only)
    CTA0->>Barrier: wait (__syncthreads / inter-CTA ordering)
    CTAn->>Barrier: wait
    Barrier->>CTA0: release
    Barrier->>CTAn: release
    CTA0->>Hist: proceed with next-round writes
    CTAn->>Hist: proceed with next-round writes
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a subtle but significant data corruption bug within the FlashInfer top-k sampling kernels. The fix ensures the integrity of histogram data across different

Highlights
Code Review
This pull request addresses a critical bug in the multi-CTA radix top-k kernel related to histogram data corruption in persistent loop mode with mixed k-values. The changes correctly add histogram clearing logic to the paths that skip radix selection, preventing stale data from corrupting subsequent iterations. The added __syncthreads calls also correctly fix race conditions within the CTAs. The new regression tests in Python are well-designed to catch this specific bug. My review includes suggestions to improve code maintainability in both the CUDA kernel and the Python tests by reducing code duplication.
```cpp
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}
```
This logic to clear the histogram for the next iteration is duplicated in 5 places across RadixTopKKernel_Unified, RadixTopKMaskLogitsKernel_MultiCTA, and RadixTopKRenormProbKernel_MultiCTA. To improve maintainability and reduce code duplication, consider refactoring this into a __device__ __forceinline__ helper function.
Also, the calculation for NUM_ROUNDS can be simplified. sizeof(OrderedType) * 8 / 8 is equivalent to just sizeof(OrderedType).
Here's an example of how you could define and use the helper function:
```cpp
// Placed somewhere before its first use, e.g., after RadixRowState struct
template <typename OrderedType, uint32_t BLOCK_THREADS>
__device__ __forceinline__ void ClearNextIterationHistogram(
    RadixRowState* state, uint32_t iter, uint32_t cta_in_group, uint32_t tx) {
  constexpr uint32_t RADIX = 256;
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType);
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}
```

```cpp
// ... inside the kernel ...
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  ClearNextIterationHistogram<OrderedType, BLOCK_THREADS>(state, iter, cta_in_group, tx);
}
```

This would make the code cleaner and easier to maintain.
```cpp
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType);
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}
```
```python
torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
    1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
    0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
```
The torch.Generator is created and seeded multiple times within this test function. You can create it just once at the beginning of the test and reuse it for all random tensor creations to make the code cleaner and avoid redundancy.
```python
torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

# Generate k values: mix of small k and k == vocab_size
k_values = torch.randint(
    1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
mask = torch.randint(
    0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
```

```python
torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
logits = logits.to(dtype)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
    1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
    0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
```
The torch.Generator is created and seeded multiple times within this test function. You can create it just once at the beginning of the test and reuse it for all random tensor creations to make the code cleaner and avoid redundancy.
```python
torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
logits = logits.to(dtype)

# Generate k values: mix of small k and k == vocab_size
k_values = torch.randint(
    1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
mask = torch.randint(
    0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
```
Actionable comments posted: 4
🤖 Fix all issues with AI agents
In @include/flashinfer/topk.cuh:
- Around line 918-927: Multiple threads in the leading CTA write
state->histogram for the next iteration without an intra-CTA barrier, allowing
tx==0 to proceed and signal the inter-CTA barrier before all clears complete; in
each affected block (the branches guarded by if constexpr (!SINGLE_CTA) that
compute NUM_ROUNDS, next_first_hist_idx and clear state->histogram using tx and
BLOCK_THREADS) insert a __syncthreads() after the clearing loop to ensure all
threads in the CTA finish clearing (and any preceding atomicAdds) before tx==0
or other threads can advance to the inter-CTA synchronization; apply the same
insertion to the other listed locations (blocks around lines 937-947, 955-965,
1136-1147, 1425-1437) referencing the same symbols (state->histogram, tx,
cta_in_group, BLOCK_THREADS, iter, NUM_ROUNDS).
In @tests/utils/test_sampling.py:
- Around line 516-591: The test test_top_k_renorm_probs_mixed_k_persistent_loop
allocates huge tensors and does many GPU-syncing per-row ops which will
OOM/timeout and uses torch.randint(..., dtype=torch.bool) which is unsafe;
reduce memory and syncs by shrinking batch_size and vocab_size to CI-safe
values, generate mask with torch.randint(0,2,(batch_size,), device=...,
generator=...).bool() or torch.rand(device=...) < 0.5 instead of
dtype=torch.bool, replace the Python for-loop that calls .item() and per-row
torch.topk with vectorized tensor operations to compute nonzero_counts and
thresholds (operate on entire tensors and use torch.topk once/batched or
torch.kthvalue), and add a conditional skip using flashinfer.utils (e.g.,
skip_if_gpu_too_small or similar) at the top of the test to avoid running on
unsupported GPUs.
- Around line 677-737: The test test_top_k_mask_logits_mixed_k_persistent_loop
currently checks k>=vocab_size by counting finite values which is weaker and can
miss cases; update the k>=vocab_size branch to assert a direct equality copy for
those rows by comparing masked_logits[i] == logits[i] (or .equal() /
torch.allclose after casting to the same dtype) for each i where k_values[i] >=
vocab_size, referencing flashinfer.sampling.top_k_mask_logits, masked_logits,
logits, and k_values to locate the code.
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
include/flashinfer/topk.cuh
tests/utils/test_sampling.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
tests/**/*.py: Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
For testing with `mpirun` on multi-GPU systems, use the pattern: `mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function`
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Files:
tests/utils/test_sampling.py
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
include/**/*.cuh: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/topk.cuh
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
include/flashinfer/topk.cuh
🧬 Code graph analysis (1)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (4)
top_k_renorm_probs (391-410)
top_k_renorm_probs (1350-1420)
top_k_mask_logits (426-446)
top_k_mask_logits (1427-1495)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/topk.cuh (1)
347-375: Good fix: block-wide sync before signaling the inter-CTA barrier.

This prevents `tx==0` from advancing `arrival_counter` while other threads in the same CTA are still doing `atomicAdd` / clearing the next histogram.
include/flashinfer/topk.cuh
Outdated
```cpp
// CTA 0 clears output counter and first histogram AFTER barrier
// Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
// For iter>0, k>=vocab iterations clear the next histogram at their end
// Per-round clearing handles subsequent rounds within the same iteration
if (cta_in_group == 0) {
  if (iter == 0) {
    // First iteration: clear first round's histogram (buffer might be uninitialized)
    // Per-round clearing will handle histograms for rounds 1-3
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[0][i] = 0;
    }
  }
  if (tx == 0) {
    st_release(&state->output_counter, 0);
  }
}
__syncthreads();  // Ensure histogram clearing completes before any CTA proceeds
}
```
Potential race on iter==0: histogram[0] is cleared after the barrier, but other CTAs can start atomicAdd into it immediately.
If state->histogram[0] can be non-zero at kernel start, CTA0 clearing it must be followed by an inter-CTA sync before round-0 histogram accumulation begins.
One way to fix: add a one-time (iter==0) inter-CTA barrier after the clear
```diff
 if constexpr (!SINGLE_CTA) {
   if (tx == 0) {
     red_release(&state->arrival_counter, 1);
   }
   int target = (barrier_phase + 1) * ctas_per_group;
   wait_ge(&state->arrival_counter, target, tx);
   barrier_phase++;
   __syncthreads();
   // CTA 0 clears output counter and first histogram AFTER barrier
   // Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
   // For iter>0, k>=vocab iterations clear the next histogram at their end
   // Per-round clearing handles subsequent rounds within the same iteration
   if (cta_in_group == 0) {
     if (iter == 0) {
       // First iteration: clear first round's histogram (buffer might be uninitialized)
       // Per-round clearing will handle histograms for rounds 1-3
       for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
         state->histogram[0][i] = 0;
       }
     }
     if (tx == 0) {
       st_release(&state->output_counter, 0);
     }
   }
-  __syncthreads(); // Ensure histogram clearing completes before any CTA proceeds
+  __syncthreads();
+
+  // One-time sync so no CTA starts round-0 atomicAdds before histogram[0] is cleared.
+  if (iter == 0) {
+    if (tx == 0) {
+      red_release(&state->arrival_counter, 1);
+    }
+    int target2 = (barrier_phase + 1) * ctas_per_group;
+    wait_ge(&state->arrival_counter, target2, tx);
+    barrier_phase++;
+    __syncthreads();
+  }
 }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
// CTA 0 clears output counter and first histogram AFTER barrier
// Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
// For iter>0, k>=vocab iterations clear the next histogram at their end
// Per-round clearing handles subsequent rounds within the same iteration
if (cta_in_group == 0) {
  if (iter == 0) {
    // First iteration: clear first round's histogram (buffer might be uninitialized)
    // Per-round clearing will handle histograms for rounds 1-3
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[0][i] = 0;
    }
  }
  if (tx == 0) {
    st_release(&state->output_counter, 0);
  }
}
__syncthreads();

// One-time sync so no CTA starts round-0 atomicAdds before histogram[0] is cleared.
if (iter == 0) {
  if (tx == 0) {
    red_release(&state->arrival_counter, 1);
  }
  int target2 = (barrier_phase + 1) * ctas_per_group;
  wait_ge(&state->arrival_counter, target2, tx);
  barrier_phase++;
  __syncthreads();
}
}
```
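The arrival-counter barrier the suggestion builds on (an atomic add via `red_release`, then spinning in `wait_ge` until the counter reaches `(barrier_phase + 1) * ctas_per_group`) can be sketched on the CPU. The class and names below are hypothetical stand-ins for the kernel's primitives, using a condition variable in place of a busy-wait:

```python
import threading

CTAS_PER_GROUP = 4

class ArrivalBarrier:
    """CPU model of the kernel's arrival-counter barrier: red_release is an
    atomic add, wait_ge blocks until the counter reaches the phase target."""
    def __init__(self):
        self.counter = 0
        self.cond = threading.Condition()

    def red_release(self):
        with self.cond:
            self.counter += 1
            self.cond.notify_all()

    def wait_ge(self, target):
        with self.cond:
            self.cond.wait_for(lambda: self.counter >= target)

barrier = ArrivalBarrier()
phases_seen = []

def cta(idx):
    barrier_phase = 0
    for _ in range(3):  # three barrier crossings per "CTA"
        barrier.red_release()
        target = (barrier_phase + 1) * CTAS_PER_GROUP
        barrier.wait_ge(target)
        barrier_phase += 1
    phases_seen.append(barrier_phase)

threads = [threading.Thread(target=cta, args=(i,)) for i in range(CTAS_PER_GROUP)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The counter only grows, so each phase's target is unambiguous even though
# CTAs from different phases add to the same counter.
assert phases_seen == [3, 3, 3, 3]
assert barrier.counter == 3 * CTAS_PER_GROUP
```

The monotonically increasing counter is what makes reusing one barrier across phases safe: a CTA in phase p waits for `(p + 1) * ctas_per_group` arrivals in total, which cannot be satisfied early by arrivals from earlier phases alone.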
> If state->histogram[0] can be non-zero at kernel start

They are initialized to 0 on the Python side when first passed to the kernel, and we guarantee they are 0 after the kernel finishes.
@yzh119, thank you for the clarification! That makes sense.

If the state->histogram buffers are guaranteed to be zero at kernel entry (Python-side initialization) and zero at kernel exit (the cleanup code at the end of the kernel), then:

- The `iter==0` clearing is indeed redundant: the histogram is already clean from Python initialization
- No inter-CTA race concern: since the buffer is already zero, the clearing is essentially a no-op
- The real fix (as described in the PR) is clearing histograms in the `k >= vocab_size` paths for subsequent iterations (`iter > 0`), which prevents stale data from affecting later rows

The code is correct with this guarantee. One minor suggestion: you might want to update the comment to reflect that the clearing is defensive/redundant rather than necessary:

```cpp
// First iteration: defensively clear first round's histogram
// (already initialized to 0 by Python, but explicit clear for safety)
```

This would make the code's intent clearer for future readers.
🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
After a second thought, I suppose the change we made here is redundant.
@yzh119, you're absolutely right! Given that:
- Histograms are zero-initialized from Python before the first kernel launch
- Histograms are guaranteed zero at kernel exit (via the cleanup code at the end)
- Per-round clearing handles histogram buffers during the iteration
- The k≥vocab_size paths clear the next iteration's first histogram
The iter==0 clearing in RadixSelectFromSharedMemory is indeed redundant — histogram[0] is already guaranteed to be zero from Python initialization.
Suggestion: You could remove the `iter==0` block entirely:

```cpp
if (cta_in_group == 0) {
  if (tx == 0) {
    st_release(&state->output_counter, 0);
  }
}
__syncthreads();
```

The actual fix for the bug (#2320) is the histogram clearing in the k≥vocab_size early-exit paths across the kernels (RadixTopKKernel_Unified, RadixTopKMaskLogitsKernel_MultiCTA, RadixTopKRenormProbKernel_MultiCTA), which prevents stale histogram data from affecting subsequent iterations in persistent loop mode. The RadixSelectFromSharedMemory change is unnecessary defensive code.
```cpp
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}
```
Add an intra-CTA sync after “next-iteration histogram clear” to avoid partial clears being “released” early.
In these blocks, multiple threads in CTA0 write state->histogram[...], but there’s no __syncthreads() to prevent CTA0’s tx==0 from running ahead into the next iteration and signaling the next inter-CTA barrier before all clears finish (similar to the issue you already fixed in RadixSelectOneRound).
Minimal pattern to apply in each location
```diff
 if (cta_in_group == 0) {
   for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
     state->histogram[next_first_hist_idx][i] = 0;
   }
+  __syncthreads(); // ensure CTA0 finished clearing before the next iteration can signal
 }
-// No sync needed - next iteration's barrier will ensure visibility
+// Next iteration's inter-CTA barrier will handle cross-CTA visibility/order.
}
```

Also: the comment "Ensure all threads in leading CTA finish clearing…" (Line 364) is slightly misleading: this __syncthreads() is also ensuring all threads finish the atomicAdds before tx==0 signals the barrier.
Also applies to: 937-947, 955-965, 1136-1147, 1425-1437
🤖 Prompt for AI Agents
In @include/flashinfer/topk.cuh around lines 918 - 927, Multiple threads in the
leading CTA write state->histogram for the next iteration without an intra-CTA
barrier, allowing tx==0 to proceed and signal the inter-CTA barrier before all
clears complete; in each affected block (the branches guarded by if constexpr
(!SINGLE_CTA) that compute NUM_ROUNDS, next_first_hist_idx and clear
state->histogram using tx and BLOCK_THREADS) insert a __syncthreads() after the
clearing loop to ensure all threads in the CTA finish clearing (and any
preceding atomicAdds) before tx==0 or other threads can advance to the inter-CTA
synchronization; apply the same insertion to the other listed locations (blocks
around lines 937-947, 955-965, 1136-1147, 1425-1437) referencing the same
symbols (state->histogram, tx, cta_in_group, BLOCK_THREADS, iter, NUM_ROUNDS).
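The interaction described above (the leader signaling before sibling threads finish clearing) can be modeled with plain CPU threads. This is only an analogy for the CUDA primitives: `threading.Barrier` stands in for `__syncthreads()`, and an `Event` for the leader's `red_release` signal; all names here are hypothetical:

```python
import threading

RADIX = 256
NUM_WORKERS = 8

histogram = [1] * RADIX  # pretend it holds stale counts
intra_cta = threading.Barrier(NUM_WORKERS)  # stands in for __syncthreads()
ready = threading.Event()  # stands in for the inter-CTA release

def worker(tx):
    # Each thread clears a strided slice, like:
    #   for (i = tx; i < RADIX; i += BLOCK_THREADS)
    for i in range(tx, RADIX, NUM_WORKERS):
        histogram[i] = 0
    intra_cta.wait()  # without this, tx == 0 could signal before other slices are cleared
    if tx == 0:
        ready.set()  # analogous to red_release on the arrival counter

threads = [threading.Thread(target=worker, args=(t,)) for t in range(NUM_WORKERS)]
for t in threads:
    t.start()

ready.wait()
# Every worker passed the barrier before the leader signaled, so any observer
# woken by `ready` is guaranteed to see a fully cleared histogram.
assert all(v == 0 for v in histogram)
for t in threads:
    t.join()
```

Dropping the `intra_cta.wait()` would allow the leader to set `ready` while other workers are still writing, which is the CPU analogue of the partial-clear release the review comment warns about.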
```python
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
    """Test top_k_renorm_probs with mixed k values in persistent loop (multi-CTA mode).

    This test catches a specific bug where:
    - Large batch size triggers the persistent loop (multiple iterations per CTA group)
    - Large vocab_size triggers multi-CTA mode (multiple CTAs per row)
    - Mixed k values: some rows have k >= vocab_size (skip radix select),
      others have k < vocab_size (use radix select)

    The bug was that k >= vocab_size iterations would skip radix select
    without clearing the histogram buffers, leaving stale data that corrupted
    subsequent k < vocab_size iterations.
    """
    batch_size = 1024  # Large batch to trigger persistent loop
    vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode

    torch.manual_seed(42)
    generator = torch.Generator(device="cuda:0").manual_seed(42)

    # Generate random logits
    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

    # Generate k values: mix of small k and k == vocab_size
    generator = torch.Generator(device="cuda:0").manual_seed(42)
    k_values = torch.randint(
        1, 1000, (batch_size,), device="cuda:0", generator=generator
    )

    # Randomly set some rows to k == vocab_size (about 50%)
    generator = torch.Generator(device="cuda:0").manual_seed(42)
    mask = torch.randint(
        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
    )
    k_values.masked_fill_(mask, vocab_size)

    # Convert to probs
    probs = torch.softmax(logits, dim=-1).to(dtype)

    # Run FlashInfer top_k_renorm_probs
    renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k_values)

    # Verify output dtype
    assert renorm_probs.dtype == dtype

    # Verify sum to 1
    sums = renorm_probs.float().sum(dim=-1)
    torch.testing.assert_close(sums, torch.ones_like(sums), rtol=1e-2, atol=1e-2)

    # Verify non-zero count matches k for each row
    nonzero_counts = (renorm_probs > 0).sum(dim=-1)

    # For rows with k >= vocab_size, all elements should be non-zero
    # For rows with k < vocab_size, non-zero count should be >= k (may be more due to ties)
    for i in range(batch_size):
        k = k_values[i].item()
        count = nonzero_counts[i].item()

        if k >= vocab_size:
            # All elements should be non-zero
            assert count == vocab_size, (
                f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
            )
        else:
            # Count should be at least k (may be more due to ties at the threshold)
            row_probs = probs[i].float()
            topk_vals, _ = torch.topk(row_probs, k, sorted=True)
            threshold = topk_vals[-1]
            expected_ge_threshold = (row_probs >= threshold).sum().item()

            # Allow small tolerance for floating point
            assert count >= k - 1, f"Row {i}: k={k} but only {count} non-zero elements"
            assert count <= expected_ge_threshold + 1, (
                f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}"
            )
```
These new tests are likely to OOM / timeout in CI; also avoid `torch.randint(..., dtype=torch.bool)` here.

- Memory: `1024 x 131072` for logits + probs + output is very likely to exceed GPU memory (especially for `float32`).
- Runtime: the `for i in range(batch_size)` loop does repeated GPU syncs via `.item()` and runs `torch.topk` up to 1024 times.
- `torch.randint(..., dtype=torch.bool, ...)` is risky; prefer generating int/binary then `.bool()` (or `torch.rand() < p`).
- Per `tests/**/*.py` guidelines, please add `flashinfer.utils`-based skipping for unsupported GPUs and size defensively.
Proposed tightening (reduce OOM risk + reduce syncs)
```diff
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+    # Per tests guideline: consider using flashinfer.utils helpers to skip unsupported archs.
+
-    batch_size = 1024  # Large batch to trigger persistent loop
-    vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
+    # Still triggers multi-CTA (vocab_size > 57344) but avoids common OOM.
+    vocab_size = 65536
+    batch_size = 256 if dtype == torch.float32 else 512
     torch.manual_seed(42)
     generator = torch.Generator(device="cuda:0").manual_seed(42)
     # Generate random logits
     logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
@@
-    mask = torch.randint(
-        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
-    )
+    mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
     k_values.masked_fill_(mask, vocab_size)
@@
-    for i in range(batch_size):
+    # Only spot-check a small subset of k < vocab rows (avoid 1000x topk calls + .item() sync).
+    small_rows = torch.nonzero(k_values < vocab_size, as_tuple=False).flatten()[:16].tolist()
+    for i in small_rows:
         k = k_values[i].item()
         count = nonzero_counts[i].item()
@@
-        if k >= vocab_size:
-            # All elements should be non-zero
-            assert count == vocab_size, (
-                f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
-            )
-        else:
+        if k < vocab_size:
             # Count should be at least k (may be more due to ties at the threshold)
             row_probs = probs[i].float()
             topk_vals, _ = torch.topk(row_probs, k, sorted=True)
             threshold = topk_vals[-1]
             expected_ge_threshold = (row_probs >= threshold).sum().item()
```

🤖 Prompt for AI Agents
In @tests/utils/test_sampling.py around lines 516 - 591, The test
test_top_k_renorm_probs_mixed_k_persistent_loop allocates huge tensors and does
many GPU-syncing per-row ops which will OOM/timeout and uses torch.randint(...,
dtype=torch.bool) which is unsafe; reduce memory and syncs by shrinking
batch_size and vocab_size to CI-safe values, generate mask with
torch.randint(0,2,(batch_size,), device=..., generator=...).bool() or
torch.rand(device=...) < 0.5 instead of dtype=torch.bool, replace the Python
for-loop that calls .item() and per-row torch.topk with vectorized tensor
operations to compute nonzero_counts and thresholds (operate on entire tensors
and use torch.topk once/batched or torch.kthvalue), and add a conditional skip
using flashinfer.utils (e.g., skip_if_gpu_too_small or similar) at the top of
the test to avoid running on unsupported GPUs.
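The per-row invariant the prompt asks to vectorize (the k-th largest value is the cutoff, and ties at the cutoff set the upper bound on surviving non-zeros) can be stated without torch at all. `nonzero_bounds` below is a hypothetical helper mirroring the test's logic, with `sorted(...)[k-1]` playing the role of `torch.kthvalue`:

```python
def nonzero_bounds(row, k):
    """Hypothetical helper mirroring the test's per-row check: the k-th
    largest value is the cutoff, and ties at the cutoff bound how many
    elements a correct top-k renorm may leave non-zero."""
    threshold = sorted(row, reverse=True)[k - 1]  # torch.kthvalue analogue
    upper = sum(1 for v in row if v >= threshold)
    return k, upper

# Ties at the cutoff: for k == 2, three values tie at 0.4, so up to 3 survive.
assert nonzero_bounds([0.4, 0.4, 0.4, 0.2, 0.1], k=2) == (2, 3)
# No ties: exactly k elements survive.
assert nonzero_bounds([5, 4, 3, 2, 1], k=3) == (3, 3)
```

Computing the threshold once per row (or batched, via `torch.kthvalue` along `dim=-1`) and comparing the whole tensor against it removes both the Python loop and the per-row `.item()` synchronizations.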
We have enough GPU memory on CI.
@yzh119, understood! Since you've confirmed the CI environment has sufficient GPU memory for these tests, that addresses the main concern. Feel free to keep the test as-is.
🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures
```python
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
    """Test top_k_mask_logits with mixed k values in persistent loop (multi-CTA mode).

    This test catches the same bug as test_top_k_renorm_probs_mixed_k_persistent_loop
    but for the mask_logits variant.
    """
    batch_size = 1024  # Large batch to trigger persistent loop
    vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode

    torch.manual_seed(42)
    generator = torch.Generator(device="cuda:0").manual_seed(42)

    # Generate random logits
    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
    logits = logits.to(dtype)

    # Generate k values: mix of small k and k == vocab_size
    generator = torch.Generator(device="cuda:0").manual_seed(42)
    k_values = torch.randint(
        1, 1000, (batch_size,), device="cuda:0", generator=generator
    )

    # Randomly set some rows to k == vocab_size (about 50%)
    generator = torch.Generator(device="cuda:0").manual_seed(42)
    mask = torch.randint(
        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
    )
    k_values.masked_fill_(mask, vocab_size)

    # Run FlashInfer top_k_mask_logits
    masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k_values)

    # Verify output dtype
    assert masked_logits.dtype == dtype

    # Verify finite count matches k for each row
    finite_counts = torch.isfinite(masked_logits).sum(dim=-1)

    for i in range(batch_size):
        k = k_values[i].item()
        count = finite_counts[i].item()

        if k >= vocab_size:
            # All elements should be finite
            assert count == vocab_size, (
                f"Row {i}: k >= vocab_size but finite count={count} != {vocab_size}"
            )
        else:
            # Count should be at least k (may be more due to ties at the threshold)
            row_logits = logits[i].float()
            topk_vals, _ = torch.topk(row_logits, k, sorted=True)
            threshold = topk_vals[-1]
            expected_ge_threshold = (row_logits >= threshold).sum().item()

            # Allow small tolerance for floating point
            assert count >= k - 1, f"Row {i}: k={k} but only {count} finite elements"
            assert count <= expected_ge_threshold + 1, (
                f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}"
            )
```
Same OOM/timeout concerns apply; also consider validating the k>=vocab branch via equality rather than counts.
For top_k_mask_logits, when k >= vocab_size the kernel should be a pure copy; asserting masked_logits == logits for those rows is typically a stronger (and cheaper) check than counting finites, and avoids edge cases if inputs change (e.g., allow -inf later).
Minimal adjustments mirroring the other test
```diff
 def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
     batch_size = 1024  # Large batch to trigger persistent loop
     vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
@@
-    mask = torch.randint(
-        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
-    )
+    mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
     k_values.masked_fill_(mask, vocab_size)
+
+    # Stronger + cheaper: rows with k>=vocab are copies
+    full_rows = (k_values >= vocab_size).nonzero(as_tuple=False).flatten()
+    if full_rows.numel() > 0:
+        torch.testing.assert_close(masked_logits[full_rows], logits[full_rows])
```

Committable suggestion skipped: line range outside the PR's diff.
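A tiny numpy reference makes the reviewer's point concrete: when k >= vocab_size, top-k masking degenerates to a pure copy, so row-wise equality is the natural assertion. This is a sketch of the masking semantics only, not the FlashInfer kernel:

```python
import numpy as np

def topk_mask_ref(logits: np.ndarray, k: int) -> np.ndarray:
    """Reference: keep the k largest entries, set the rest to -inf."""
    if k >= logits.shape[-1]:
        return logits.copy()           # trivial path: pure copy, nothing masked
    kth = np.sort(logits)[..., -k]     # k-th largest value
    return np.where(logits >= kth, logits, -np.inf)

rng = np.random.default_rng(0)
row = rng.random(32)

# k >= vocab: output must equal the input exactly (stronger than counting finites)
assert np.array_equal(topk_mask_ref(row, 32), row)
assert np.array_equal(topk_mask_ref(row, 100), row)

# k < vocab: at least k finite survivors (ties at the threshold can add more)
masked = topk_mask_ref(row, 5)
assert np.isfinite(masked).sum() >= 5
```

Equality also survives later input changes (e.g. rows that already contain `-inf`), where a finite-count check would silently weaken.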
🤖 Prompt for AI Agents
In @tests/utils/test_sampling.py around lines 677 - 737, The test
test_top_k_mask_logits_mixed_k_persistent_loop currently checks k>=vocab_size by
counting finite values which is weaker and can miss cases; update the
k>=vocab_size branch to assert a direct equality copy for those rows by
comparing masked_logits[i] == logits[i] (or .equal() / torch.allclose after
casting to the same dtype) for each i where k_values[i] >= vocab_size,
referencing flashinfer.sampling.top_k_mask_logits, masked_logits, logits, and
k_values to locate the code.
/bot run

[SUCCESS] Pipeline #41463402: 17/20 passed

/bot run

Warning: Failed to sync latest changes. Please try again.

/bot run
Actionable comments posted: 0
🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)
919-946: PageTableTransform and RaggedTransform paths correctly apply the same fix.

Both paths properly clear the histogram for the next iteration when the current row is trivially handled (length <= top_k_val).
Consider extracting the histogram clearing into a helper lambda or inline function to reduce the code duplication across all three modes:
♻️ Optional refactor to reduce duplication
```cpp
// Define once before the mode-specific branches
auto clear_next_iteration_histogram = [&]() {
  if constexpr (!SINGLE_CTA) {
    constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
    uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
    if (cta_in_group == 0) {
      for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
        state->histogram[next_first_hist_idx][i] = 0;
      }
    }
  }
};

// Then in each mode's trivial case:
if (k >= length) {
  // ... copy output ...
  clear_next_iteration_histogram();
  continue;
}
```
📜 Review details
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
include/flashinfer/topk.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh
📄 CodeRabbit inference engine (CLAUDE.md)
`include/**/*.cuh`: Torch headers MUST NOT be included in files within the `include/` directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in `include/flashinfer/` is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly
Files:
include/flashinfer/topk.cuh
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Applied to files:
include/flashinfer/topk.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (3)
include/flashinfer/topk.cuh (3)
900-910: Histogram clearing for skip path looks correct.

The calculation `((iter + 1) * NUM_ROUNDS) % 3` correctly identifies the histogram buffer that would be used by the next iteration's first radix round. This mirrors what the last round of a normal iteration would clear via `next_hist = state->histogram[(global_round + 1) % 3]`.

Minor observation: The pattern `sizeof(OrderedType) * 8 / 8` simplifies to `sizeof(OrderedType)`. While equivalent to `ORDERED_BITS / RADIX_BITS`, using a named constant could improve clarity:

```cpp
constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<8>();  // Use existing trait
```

This is optional since the current approach works correctly.
1117-1129: Histogram clearing for MaskLogits k >= vocab_size path is correct.

The comment on line 1128 ("No sync needed - next iteration's barrier will ensure visibility") accurately explains why no explicit synchronization is required after the clearing. The next iteration's initial barrier in `RadixSelectFromSharedMemory` will synchronize all CTAs before they access the histogram.
1406-1419: RenormProb histogram clearing correctly handles the k >= vocab_size path.

This path is slightly different from MaskLogits because it still performs barrier-synchronized sum reduction before the histogram clearing. The placement after both the sum computation and output writing is correct.

The comment on line 1409 correctly notes that next iteration's barrier ensures visibility, which is consistent with the barrier tracking design (cumulative arrival counter).
📌 Description
This PR fixes a bug in the multi-CTA radix top-k kernel where histogram data corruption occurs when processing batches with mixed k values (some k >= vocab_size, others k < vocab_size) in persistent loop mode.
Root Cause: When a row has k >= vocab_size, the radix select phase is skipped entirely. However, this leaves stale histogram data in the RadixRowState buffer. When a subsequent row in the same persistent loop iteration has k < vocab_size, it uses atomicAdd on these stale histograms, causing incorrect pivot selection and wrong top-k results.
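The failure mode can be modeled in a few lines of plain Python: a pass that accumulates into a shared histogram produces inflated digit counts when a skipped row leaves the buffer uncleared. This is a toy model of the race, not the kernel code:

```python
def radix_digit_counts(values, histogram, clear_first):
    """Accumulate counts of one radix digit into a shared histogram buffer."""
    if clear_first:
        for d in range(len(histogram)):
            histogram[d] = 0
    for v in values:
        histogram[(v >> 8) & 0xFF] += 1   # toy: bucket by the second byte
    return list(histogram)

hist = [0] * 256
row_a = [0x0100, 0x0180, 0x0200]          # digits 1, 1, 2
radix_digit_counts(row_a, hist, clear_first=True)

# Row B "skips" its pass (k >= vocab_size), leaving row A's counts in the buffer.
# Row C then accumulates on top of the stale data:
row_c = [0x0200, 0x0300]                  # digits 2, 3
stale = radix_digit_counts(row_c, hist, clear_first=False)
assert stale[2] == 2                      # corrupted: 1 (stale from row A) + 1 (row C)

# With the fix - clearing before reuse - the counts are correct:
fixed = radix_digit_counts(row_c, hist, clear_first=True)
assert fixed[2] == 1 and fixed[3] == 1
```

In the real kernel the "clear" is the leader CTA zeroing the next round's histogram buffer, and the inflated counts shift the selected pivot, which is why wrong top-k results only appear for the rows that follow a skipped one.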
Conditions to trigger the bug:

- Multi-CTA mode (vocab_size large enough that multiple CTAs cooperate per row)
- Persistent loop mode (batch_size larger than the number of CTA groups, so groups process multiple rows)
- Mixed k values in the batch: some rows with k >= vocab_size, others with k < vocab_size
Changes
🔍 Related Issues
#2320
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Tests
✏️ Tip: You can customize this high-level summary in your review settings.