
bugfix: fix multi-cta top-k implementation when k value is different for different row #2325

Merged
yzh119 merged 4 commits into flashinfer-ai:main from yzh119:fix-topk on Jan 13, 2026

Conversation

@yzh119
Collaborator

@yzh119 yzh119 commented Jan 10, 2026

📌 Description

This PR fixes a bug in the multi-CTA radix top-k kernel where histogram data corruption occurs when processing batches with mixed k values (some k >= vocab_size, others k < vocab_size) in persistent loop mode.

Root Cause: When a row has k >= vocab_size, the radix select phase is skipped entirely. However, this leaves stale histogram data in the RadixRowState buffer. When a subsequent row in the same persistent loop iteration has k < vocab_size, it uses atomicAdd on these stale histograms, causing incorrect pivot selection and wrong top-k results.

Conditions to trigger the bug:

  1. Multi-CTA mode (vocab_size > 57344)
  2. Persistent loop mode (batch_size > num_groups)
  3. Mixed k values across rows (some k >= vocab_size, others k < vocab_size)
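
To make the failure mode concrete, here is a toy, single-threaded Python model of the corruption described above. It is NOT the CUDA kernel: buffer rotation, `atomicAdd`, and the persistent loop are all heavily simplified, and names like `run_batch` are hypothetical. The point is only to show how a skipped select phase (k >= row length) that does not clear the next histogram buffer poisons a later row's pivot selection.

```python
# Toy model of the stale-histogram bug; names and structure are illustrative.
RADIX = 256  # 8-bit digits, as in the radix kernel

def radix_topk_threshold(row, k, hist):
    """Return the smallest value kept in row's top-k (values in [0, 256)).
    Counts are accumulated into `hist`, which is ASSUMED to start zeroed,
    mirroring the kernel's atomicAdd into RadixRowState histograms."""
    for v in row:
        hist[v] += 1  # atomicAdd analogue
    kept = 0
    for bucket in range(RADIX - 1, -1, -1):
        kept += hist[bucket]
        if kept >= k:
            return bucket
    return 0

def run_batch(rows, ks, clear_on_skip):
    """Persistent-loop analogue: two rotating histogram buffers; each select
    pass clears the buffer the next iteration will use. With
    clear_on_skip=False, a skipped (k >= row length) iteration leaves the
    next buffer dirty -- the bug this PR fixes."""
    hists = [[0] * RADIX, [0] * RADIX]
    pivots = []
    for it, (row, k) in enumerate(zip(rows, ks)):
        cur, nxt = hists[it % 2], hists[(it + 1) % 2]
        if k >= len(row):
            pivots.append(0)          # keep everything; select phase skipped
            if clear_on_skip:         # the fix: still clear the next buffer
                nxt[:] = [0] * RADIX
            continue
        pivots.append(radix_topk_threshold(row, k, cur))
        nxt[:] = [0] * RADIX          # normal path clears the next buffer
    return pivots

rows = [[10, 200, 30], [1, 2], [5, 100, 50]]
ks = [1, 5, 1]                        # middle row: k >= row length
print(run_batch(rows, ks, clear_on_skip=False))  # buggy: [200, 0, 200]
print(run_batch(rows, ks, clear_on_skip=True))   # fixed: [200, 0, 100]
```

In the buggy run, the third row's pivot comes out as 200 instead of 100 because the first row's counts were never cleared from the buffer the third row reuses.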

Changes

  • include/flashinfer/topk.cuh:
    • RadixSelectFromSharedMemory: Only clear histogram[0] when iter == 0
    • RadixTopKKernel_Unified: Add histogram clearing in all k >= length paths (Basic, PageTableTransform, RaggedTransform modes)
    • RadixTopKMaskLogitsKernel_MultiCTA: Add histogram clearing in k >= vocab_size path
    • RadixTopKRenormProbKernel_MultiCTA: Add histogram clearing in k >= vocab_size path
  • tests/utils/test_sampling.py:
    • Added regression tests for mixed k values with persistent loop
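
For intuition on the mixed-k semantics the regression tests check, here is a minimal pure-Python reference (a hypothetical helper, not the actual test code): rows with k >= row length pass through unmasked, while other rows mask every entry below the k-th largest value.

```python
import math

def ref_top_k_mask_logits(logits, ks):
    """Hypothetical pure-Python reference for mixed-k top-k logit masking:
    rows with k >= row length are returned unchanged; otherwise entries
    strictly below the k-th largest value are masked to -inf (ties kept)."""
    out = []
    for row, k in zip(logits, ks):
        if k >= len(row):
            out.append(list(row))     # k >= vocab_size: no masking at all
            continue
        pivot = sorted(row, reverse=True)[k - 1]
        out.append([v if v >= pivot else -math.inf for v in row])
    return out

print(ref_top_k_mask_logits([[1.0, 3.0, 2.0], [5.0, 4.0]], [2, 7]))
# → [[-inf, 3.0, 2.0], [5.0, 4.0]]
```

A correct kernel must match this per-row behavior even when both kinds of rows appear in the same persistent-loop batch.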

🔍 Related Issues

#2320

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Improved inter-worker synchronization and explicit histogram buffer clearing across top-k kernels to prevent cross-iteration/barrier data hazards, improving reliability for large batches and persistent (multi-worker) loops.
  • Tests

    • Added mixed-k tests for probability renormalization and logit masking across dtypes and large batch/vocabulary sizes to validate persistent/multi-worker behavior.


yzh119 and others added 2 commits January 9, 2026 19:15
Fixed a race condition in the multi-CTA radix top-k kernel where stale
histogram data corrupted results when processing batches with mixed k values.

Bug: When k >= vocab_size iterations skip RadixSelectFindPivot, they don't
clear histogram buffers. Subsequent k < vocab_size iterations then use
atomicAdd on stale histogram data, causing incorrect pivot selection.

Changes:
- Add __syncthreads() before red_release() in RadixSelectOneRound and
  RadixSelectFromSharedMemory to ensure histogram clearing completes
  before signaling barrier
- Clear first round's histogram at start of RadixSelectFromSharedMemory
  to handle stale data from skipped k>=vocab iterations

Added regression tests:
- test_top_k_renorm_probs_mixed_k_persistent_loop
- test_top_k_mask_logits_mixed_k_persistent_loop

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai bot commented Jan 10, 2026

📝 Walkthrough

Walkthrough

Adds CTA-leader guarded per-round histogram clears and barrier-aware placements across multiple RadixTopK multi-CTA kernel paths; also adds two tests exercising mixed-k behavior in persistent (multi-CTA) loops for top-k renorm/prob and mask-logits kernels.

Changes

Cohort / File(s) Summary
Kernel synchronization & histogram management
include/flashinfer/topk.cuh
Adds leader-only clears of next-round histogram buffers (compute next_first_hist_idx using per-round NUM_ROUNDS) and places clears after key result materialization across multi-CTA RadixTopK variants (Unified, MaskLogitsMultiCTA, RenormProbMultiCTA, PageTable/Ragged epilogues) to avoid cross-barrier data hazards.
Multi-k persistent loop tests
tests/utils/test_sampling.py
Adds test_top_k_renorm_probs_mixed_k_persistent_loop() and test_top_k_mask_logits_mixed_k_persistent_loop() to exercise mixed k (including k >= vocab_size) across dtypes and large batch/vocab sizes to trigger persistent/multi-CTA behavior and validate outputs.

Sequence Diagram(s)

sequenceDiagram
    participant CTA0 as CTA 0 (leader)
    participant CTAn as Other CTAs
    participant Hist as Histogram Buffer
    participant Barrier as CTA Barrier (sync)

    CTA0->>Hist: compute/identify next_first_hist_idx
    CTA0->>Hist: clear current & next-round histogram (leader-only)
    CTA0->>Barrier: wait (__syncthreads / inter-CTA ordering)
    CTAn->>Barrier: wait
    Barrier->>CTA0: release
    Barrier->>CTAn: release
    CTA0->>Hist: proceed with next-round writes
    CTAn->>Hist: proceed with next-round writes

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • IwakuraRein
  • jiahanc
  • kahyunnam
  • cyx-6
  • nvmbreughe

Poem

🐰 Hopping through rounds with a cheerful cheer,
I clear the bins so data's clear.
CTA leaders tap the drum,
Barriers clap — the rounds go hum.
Fresh histos bloom, the loops persevere!

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 50.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: the title accurately describes the main bugfix focus — addressing multi-CTA top-k implementation issues when k values differ across rows, which is the core change.
  • Description check ✅ Passed: the description covers the bug root cause, trigger conditions, affected code locations, and includes regression tests, following the provided template structure.






@gemini-code-assist
Contributor

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a subtle but significant data corruption bug within the FlashInfer top-k sampling kernels. The fix ensures the integrity of histogram data across different k values in batched, multi-CTA, persistent loop scenarios, thereby guaranteeing the correctness of top-k selections in high-performance GPU computations. The changes prevent erroneous results that could arise from stale state information, enhancing the reliability of the sampling operations.

Highlights

  • Bug Fix: Radix Top-K Histogram Corruption: This PR resolves a critical bug in the multi-CTA radix top-k kernel where histogram data corruption occurred. This happened when processing batches with mixed k values (some k >= vocab_size, others k < vocab_size) in persistent loop mode.
  • Root Cause Identification: The bug stemmed from k >= vocab_size rows skipping the radix select phase without clearing stale histogram data in the RadixRowState buffer. Subsequent k < vocab_size rows would then use this corrupted data, leading to incorrect pivot selection and top-k results.
  • Targeted Fixes in Kernel Code: Histogram clearing logic has been added or refined in RadixTopKKernel_Unified, RadixTopKMaskLogitsKernel_MultiCTA, and RadixTopKRenormProbKernel_MultiCTA to ensure that histogram buffers are properly reset when k >= vocab_size rows are processed. Synchronization primitives (__syncthreads()) were also added to ensure proper memory visibility during clearing operations.
  • New Regression Tests: Comprehensive regression tests (test_top_k_renorm_probs_mixed_k_persistent_loop and test_top_k_mask_logits_mixed_k_persistent_loop) have been introduced to specifically validate the fix under conditions that previously triggered the bug (large batch size, large vocab size, mixed k values).


@yzh119 changed the title from "Fix topk" to "bugfix: fix multi-cta top-k implementation when k value is different for different row" on Jan 10, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a critical bug in the multi-CTA radix top-k kernel related to histogram data corruption in persistent loop mode with mixed k-values. The changes correctly add histogram clearing logic to the paths that skip radix selection, preventing stale data from corrupting subsequent iterations. The added __syncthreads calls also correctly fix race conditions within the CTAs. The new regression tests in Python are well-designed to catch this specific bug. My review includes suggestions to improve code maintainability in both the CUDA kernel and the Python tests by reducing code duplication.

Comment on lines +918 to +927
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}
Contributor


medium

This logic to clear the histogram for the next iteration is duplicated in 5 places across RadixTopKKernel_Unified, RadixTopKMaskLogitsKernel_MultiCTA, and RadixTopKRenormProbKernel_MultiCTA. To improve maintainability and reduce code duplication, consider refactoring this into a __device__ __forceinline__ helper function.

Also, the calculation for NUM_ROUNDS can be simplified. sizeof(OrderedType) * 8 / 8 is equivalent to just sizeof(OrderedType).

Here's an example of how you could define and use the helper function:

// Placed somewhere before its first use, e.g., after RadixRowState struct
template <typename OrderedType, uint32_t BLOCK_THREADS>
__device__ __forceinline__ void ClearNextIterationHistogram(
    RadixRowState* state, uint32_t iter, uint32_t cta_in_group, uint32_t tx) {
  constexpr uint32_t RADIX = 256;
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType);
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}

// ... inside the kernel ...
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  ClearNextIterationHistogram<OrderedType, BLOCK_THREADS>(state, iter, cta_in_group, tx);
}

This would make the code cleaner and easier to maintain.

        // Clear histogram for next iteration (in case it's k < length)
        if constexpr (!SINGLE_CTA) {
          constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType);
          uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
          if (cta_in_group == 0) {
            for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
              state->histogram[next_first_hist_idx][i] = 0;
            }
          }
        }

Comment on lines +533 to +550
torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
    1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
    0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
Contributor


medium

The torch.Generator is created and seeded multiple times within this test function. You can create it just once at the beginning of the test and reuse it for all random tensor creations to make the code cleaner and avoid redundancy.

    torch.manual_seed(42)
    generator = torch.Generator(device="cuda:0").manual_seed(42)

    # Generate random logits
    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

    # Generate k values: mix of small k and k == vocab_size
    k_values = torch.randint(
        1, 1000, (batch_size,), device="cuda:0", generator=generator
    )

    # Randomly set some rows to k == vocab_size (about 50%)
    mask = torch.randint(
        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
    )
    k_values.masked_fill_(mask, vocab_size)

Comment on lines +687 to +705
torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
logits = logits.to(dtype)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
    1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
    0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
Contributor


medium

The torch.Generator is created and seeded multiple times within this test function. You can create it just once at the beginning of the test and reuse it for all random tensor creations to make the code cleaner and avoid redundancy.

    torch.manual_seed(42)
    generator = torch.Generator(device="cuda:0").manual_seed(42)

    # Generate random logits
    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
    logits = logits.to(dtype)

    # Generate k values: mix of small k and k == vocab_size
    k_values = torch.randint(
        1, 1000, (batch_size,), device="cuda:0", generator=generator
    )

    # Randomly set some rows to k == vocab_size (about 50%)
    mask = torch.randint(
        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
    )
    k_values.masked_fill_(mask, vocab_size)

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Fix all issues with AI agents
In @include/flashinfer/topk.cuh:
- Around line 918-927: Multiple threads in the leading CTA write
state->histogram for the next iteration without an intra-CTA barrier, allowing
tx==0 to proceed and signal the inter-CTA barrier before all clears complete; in
each affected block (the branches guarded by if constexpr (!SINGLE_CTA) that
compute NUM_ROUNDS, next_first_hist_idx and clear state->histogram using tx and
BLOCK_THREADS) insert a __syncthreads() after the clearing loop to ensure all
threads in the CTA finish clearing (and any preceding atomicAdds) before tx==0
or other threads can advance to the inter-CTA synchronization; apply the same
insertion to the other listed locations (blocks around lines 937-947, 955-965,
1136-1147, 1425-1437) referencing the same symbols (state->histogram, tx,
cta_in_group, BLOCK_THREADS, iter, NUM_ROUNDS).

In @tests/utils/test_sampling.py:
- Around line 516-591: The test test_top_k_renorm_probs_mixed_k_persistent_loop
allocates huge tensors and does many GPU-syncing per-row ops which will
OOM/timeout and uses torch.randint(..., dtype=torch.bool) which is unsafe;
reduce memory and syncs by shrinking batch_size and vocab_size to CI-safe
values, generate mask with torch.randint(0,2,(batch_size,), device=...,
generator=...).bool() or torch.rand(device=...) < 0.5 instead of
dtype=torch.bool, replace the Python for-loop that calls .item() and per-row
torch.topk with vectorized tensor operations to compute nonzero_counts and
thresholds (operate on entire tensors and use torch.topk once/batched or
torch.kthvalue), and add a conditional skip using flashinfer.utils (e.g.,
skip_if_gpu_too_small or similar) at the top of the test to avoid running on
unsupported GPUs.
- Around line 677-737: The test test_top_k_mask_logits_mixed_k_persistent_loop
currently checks k>=vocab_size by counting finite values which is weaker and can
miss cases; update the k>=vocab_size branch to assert a direct equality copy for
those rows by comparing masked_logits[i] == logits[i] (or .equal() /
torch.allclose after casting to the same dtype) for each i where k_values[i] >=
vocab_size, referencing flashinfer.sampling.top_k_mask_logits, masked_logits,
logits, and k_values to locate the code.
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2062dec and 7e8ecb2.

📒 Files selected for processing (2)
  • include/flashinfer/topk.cuh
  • tests/utils/test_sampling.py
🧰 Additional context used
📓 Path-based instructions (2)
tests/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

tests/**/*.py: Test implementations should use flashinfer.utils functions (get_compute_capability, is_sm90a_supported, is_sm100a_supported, etc.) to skip tests on unsupported GPU architectures
For testing with mpirun on multi-GPU systems, use the pattern: mpirun -np <num_gpus> pytest tests/path/to/test.py::test_function
Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - tests/conftest.py provides auto-skipping for OOM tests as a safety net but should not be relied upon

Files:

  • tests/utils/test_sampling.py
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/topk.cuh
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/topk.cuh
🧬 Code graph analysis (1)
tests/utils/test_sampling.py (1)
flashinfer/sampling.py (4)
  • top_k_renorm_probs (391-410)
  • top_k_renorm_probs (1350-1420)
  • top_k_mask_logits (426-446)
  • top_k_mask_logits (1427-1495)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (1)
include/flashinfer/topk.cuh (1)

347-375: Good fix: block-wide sync before signaling the inter-CTA barrier.

This prevents tx==0 from advancing arrival_counter while other threads in the same CTA are still doing atomicAdd / clearing the next histogram.

Comment on lines 505 to 522
// CTA 0 clears output counter and first histogram AFTER barrier
// Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
// For iter>0, k>=vocab iterations clear the next histogram at their end
// Per-round clearing handles subsequent rounds within the same iteration
if (cta_in_group == 0) {
  if (iter == 0) {
    // First iteration: clear first round's histogram (buffer might be uninitialized)
    // Per-round clearing will handle histograms for rounds 1-3
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[0][i] = 0;
    }
  }
  if (tx == 0) {
    st_release(&state->output_counter, 0);
  }
}
__syncthreads();  // Ensure histogram clearing completes before any CTA proceeds
}
Contributor

@coderabbitai coderabbitai bot Jan 10, 2026


⚠️ Potential issue | 🔴 Critical

Potential race on iter==0: histogram[0] is cleared after the barrier, but other CTAs can start atomicAdd into it immediately.

If state->histogram[0] can be non-zero at kernel start, CTA0 clearing it must be followed by an inter-CTA sync before round-0 histogram accumulation begins.

One way to fix: add a one-time (iter==0) inter-CTA barrier after the clear
   if constexpr (!SINGLE_CTA) {
     if (tx == 0) {
       red_release(&state->arrival_counter, 1);
     }
     int target = (barrier_phase + 1) * ctas_per_group;
     wait_ge(&state->arrival_counter, target, tx);
     barrier_phase++;
     __syncthreads();

     // CTA 0 clears output counter and first histogram AFTER barrier
     // Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
     // For iter>0, k>=vocab iterations clear the next histogram at their end
     // Per-round clearing handles subsequent rounds within the same iteration
     if (cta_in_group == 0) {
       if (iter == 0) {
         // First iteration: clear first round's histogram (buffer might be uninitialized)
         // Per-round clearing will handle histograms for rounds 1-3
         for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
           state->histogram[0][i] = 0;
         }
       }
       if (tx == 0) {
         st_release(&state->output_counter, 0);
       }
     }
-    __syncthreads();  // Ensure histogram clearing completes before any CTA proceeds
+    __syncthreads();
+
+    // One-time sync so no CTA starts round-0 atomicAdds before histogram[0] is cleared.
+    if (iter == 0) {
+      if (tx == 0) {
+        red_release(&state->arrival_counter, 1);
+      }
+      int target2 = (barrier_phase + 1) * ctas_per_group;
+      wait_ge(&state->arrival_counter, target2, tx);
+      barrier_phase++;
+      __syncthreads();
+    }
   }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

Before:

// CTA 0 clears output counter and first histogram AFTER barrier
// Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
// For iter>0, k>=vocab iterations clear the next histogram at their end
// Per-round clearing handles subsequent rounds within the same iteration
if (cta_in_group == 0) {
  if (iter == 0) {
    // First iteration: clear first round's histogram (buffer might be uninitialized)
    // Per-round clearing will handle histograms for rounds 1-3
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[0][i] = 0;
    }
  }
  if (tx == 0) {
    st_release(&state->output_counter, 0);
  }
}
__syncthreads(); // Ensure histogram clearing completes before any CTA proceeds
}

After:

// CTA 0 clears output counter and first histogram AFTER barrier
// Only clear on iter==0 (buffer might be uninitialized on first kernel launch)
// For iter>0, k>=vocab iterations clear the next histogram at their end
// Per-round clearing handles subsequent rounds within the same iteration
if (cta_in_group == 0) {
  if (iter == 0) {
    // First iteration: clear first round's histogram (buffer might be uninitialized)
    // Per-round clearing will handle histograms for rounds 1-3
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[0][i] = 0;
    }
  }
  if (tx == 0) {
    st_release(&state->output_counter, 0);
  }
}
__syncthreads();

// One-time sync so no CTA starts round-0 atomicAdds before histogram[0] is cleared.
if (iter == 0) {
  if (tx == 0) {
    red_release(&state->arrival_counter, 1);
  }
  int target2 = (barrier_phase + 1) * ctas_per_group;
  wait_ge(&state->arrival_counter, target2, tx);
  barrier_phase++;
  __syncthreads();
}
}
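
The red_release / wait_ge pair in this suggestion is a counter-based arrive-and-wait barrier: each CTA bumps a shared arrival counter, then spins until it reaches (phase + 1) * ctas_per_group. A rough host-side Python analogue, with threads standing in for CTAs (the class name and structure are illustrative, not the kernel's API):

```python
import threading

class ArrivalBarrier:
    """Counter-based phase barrier: arrive_and_wait bumps a shared counter
    (red_release analogue) and blocks until it reaches
    (phase + 1) * parties (wait_ge analogue). The counter only grows, so
    each phase has a distinct target and no reset is needed."""
    def __init__(self, parties):
        self.parties = parties
        self.counter = 0                      # arrival_counter analogue
        self.cond = threading.Condition()

    def arrive_and_wait(self, phase):
        target = (phase + 1) * self.parties   # same target formula as the kernel
        with self.cond:
            self.counter += 1                 # red_release(&arrival_counter, 1)
            self.cond.notify_all()
            while self.counter < target:      # wait_ge(&arrival_counter, target)
                self.cond.wait()
        return phase + 1                      # caller tracks barrier_phase

barrier = ArrivalBarrier(parties=4)

def worker():
    phase = 0
    for _ in range(3):                        # three persistent-loop iterations
        phase = barrier.arrive_and_wait(phase)

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads: t.start()
for t in threads: t.join()
print(barrier.counter)  # → 12 (4 parties x 3 phases)
```

Because the counter is monotonic, a one-time extra phase (as the suggestion adds for iter == 0) simply advances every participant's target by one step; no counter reset is required.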

Collaborator Author


If state->histogram[0] can be non-zero at kernel start

They are initialized to 0 on the Python side when first passed to the kernel, and we guarantee they are 0 after the kernel finishes.

Contributor


@yzh119, thank you for the clarification! That makes sense.

If the state->histogram buffers are guaranteed to be zero at kernel entry (Python-side initialization) and zero at kernel exit (the cleanup code at the end of the kernel), then:

  1. The iter==0 clearing is indeed redundant – the histogram is already clean from Python initialization
  2. No inter-CTA race concern – since the buffer is already zero, the clearing is essentially a no-op
  3. The real fix (as described in the PR) is clearing histograms in the k >= vocab_size paths for subsequent iterations (iter > 0), which prevents stale data from affecting later rows

The code is correct with this guarantee. One minor suggestion: you might want to update the comment to reflect that the clearing is defensive/redundant rather than necessary:

// First iteration: defensively clear first round's histogram 
// (already initialized to 0 by Python, but explicit clear for safety)

This would make the code's intent clearer for future readers.


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Collaborator Author


After a second thought, I suppose the change we made here is redundant.

Contributor


@yzh119, you're absolutely right! Given that:

  1. Histograms are zero-initialized from Python before the first kernel launch
  2. Histograms are guaranteed zero at kernel exit (via the cleanup code at the end)
  3. Per-round clearing handles histogram buffers during the iteration
  4. The k≥vocab_size paths clear the next iteration's first histogram

The iter==0 clearing in RadixSelectFromSharedMemory is indeed redundant: histogram[0] is already guaranteed to be zero from Python initialization.

Suggestion: You could remove the iter==0 block entirely:

if (cta_in_group == 0) {
  if (tx == 0) {
    st_release(&state->output_counter, 0);
  }
}
__syncthreads();

The actual fix for the bug (#2320) is the histogram clearing in the k≥vocab_size early-exit paths across the kernels (RadixTopKKernel_Unified, RadixTopKMaskLogitsKernel_MultiCTA, RadixTopKRenormProbKernel_MultiCTA), which prevents stale histogram data from affecting subsequent iterations in persistent loop mode. The RadixSelectFromSharedMemory change is unnecessary defensive code.



Comment on lines +918 to +927
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}
Contributor


⚠️ Potential issue | 🔴 Critical

Add an intra-CTA sync after “next-iteration histogram clear” to avoid partial clears being “released” early.

In these blocks, multiple threads in CTA0 write state->histogram[...], but there’s no __syncthreads() to prevent CTA0’s tx==0 from running ahead into the next iteration and signaling the next inter-CTA barrier before all clears finish (similar to the issue you already fixed in RadixSelectOneRound).

Minimal pattern to apply in each location
         if (cta_in_group == 0) {
           for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
             state->histogram[next_first_hist_idx][i] = 0;
           }
+          __syncthreads();  // ensure CTA0 finished clearing before the next iteration can signal
         }
-        // No sync needed - next iteration's barrier will ensure visibility
+        // Next iteration's inter-CTA barrier will handle cross-CTA visibility/order.
       }

Also: the comment “Ensure all threads in leading CTA finish clearing…” (Line 364) is slightly misleading—this __syncthreads() is also ensuring all threads finish the atomicAdds before tx==0 signals the barrier.

Also applies to: 937-947, 955-965, 1136-1147, 1425-1437

🤖 Prompt for AI Agents
In @include/flashinfer/topk.cuh around lines 918 - 927, Multiple threads in the
leading CTA write state->histogram for the next iteration without an intra-CTA
barrier, allowing tx==0 to proceed and signal the inter-CTA barrier before all
clears complete; in each affected block (the branches guarded by if constexpr
(!SINGLE_CTA) that compute NUM_ROUNDS, next_first_hist_idx and clear
state->histogram using tx and BLOCK_THREADS) insert a __syncthreads() after the
clearing loop to ensure all threads in the CTA finish clearing (and any
preceding atomicAdds) before tx==0 or other threads can advance to the inter-CTA
synchronization; apply the same insertion to the other listed locations (blocks
around lines 937-947, 955-965, 1136-1147, 1425-1437) referencing the same
symbols (state->histogram, tx, cta_in_group, BLOCK_THREADS, iter, NUM_ROUNDS).

Comment on lines +516 to +591
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
"""Test top_k_renorm_probs with mixed k values in persistent loop (multi-CTA mode).

This test catches a specific bug where:
- Large batch size triggers the persistent loop (multiple iterations per CTA group)
- Large vocab_size triggers multi-CTA mode (multiple CTAs per row)
- Mixed k values: some rows have k >= vocab_size (skip radix select),
others have k < vocab_size (use radix select)

The bug was that k >= vocab_size iterations would skip radix select
without clearing the histogram buffers, leaving stale data that corrupted
subsequent k < vocab_size iterations.
"""
batch_size = 1024 # Large batch to trigger persistent loop
vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode

torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)

# Convert to probs
probs = torch.softmax(logits, dim=-1).to(dtype)

# Run FlashInfer top_k_renorm_probs
renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k_values)

# Verify output dtype
assert renorm_probs.dtype == dtype

# Verify sum to 1
sums = renorm_probs.float().sum(dim=-1)
torch.testing.assert_close(sums, torch.ones_like(sums), rtol=1e-2, atol=1e-2)

# Verify non-zero count matches k for each row
nonzero_counts = (renorm_probs > 0).sum(dim=-1)

# For rows with k >= vocab_size, all elements should be non-zero
# For rows with k < vocab_size, non-zero count should be >= k (may be more due to ties)
for i in range(batch_size):
k = k_values[i].item()
count = nonzero_counts[i].item()

if k >= vocab_size:
# All elements should be non-zero
assert count == vocab_size, (
f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
)
else:
# Count should be at least k (may be more due to ties at the threshold)
row_probs = probs[i].float()
topk_vals, _ = torch.topk(row_probs, k, sorted=True)
threshold = topk_vals[-1]
expected_ge_threshold = (row_probs >= threshold).sum().item()

# Allow small tolerance for floating point
assert count >= k - 1, f"Row {i}: k={k} but only {count} non-zero elements"
assert count <= expected_ge_threshold + 1, (
f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}"
)

Contributor

@coderabbitai coderabbitai bot Jan 10, 2026

⚠️ Potential issue | 🟠 Major

These new tests are likely to OOM / timeout in CI; also avoid torch.randint(..., dtype=torch.bool) here.

  • Memory: 1024 x 131072 for logits + probs + output is very likely to exceed GPU memory (especially for float32).
  • Runtime: the for i in range(batch_size) loop does repeated GPU syncs via .item() and runs torch.topk up to 1024 times.
  • torch.randint(..., dtype=torch.bool, ...) is risky; prefer generating int/binary then .bool() (or torch.rand()<p).
  • Per tests/**/*.py guidelines, please add flashinfer.utils-based skipping for unsupported GPUs and size defensively.
Proposed tightening (reduce OOM risk + reduce syncs)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+    # Per tests guideline: consider using flashinfer.utils helpers to skip unsupported archs.
+
-    batch_size = 1024  # Large batch to trigger persistent loop
-    vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
+    # Still triggers multi-CTA (vocab_size > 57344) but avoids common OOM.
+    vocab_size = 65536
+    batch_size = 256 if dtype == torch.float32 else 512

     torch.manual_seed(42)
     generator = torch.Generator(device="cuda:0").manual_seed(42)

     # Generate random logits
     logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

@@
-    mask = torch.randint(
-        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
-    )
+    mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
     k_values.masked_fill_(mask, vocab_size)

@@
-    for i in range(batch_size):
+    # Only spot-check a small subset of k < vocab rows (avoid 1000x topk calls + .item() sync).
+    small_rows = torch.nonzero(k_values < vocab_size, as_tuple=False).flatten()[:16].tolist()
+    for i in small_rows:
         k = k_values[i].item()
         count = nonzero_counts[i].item()
@@
-        if k >= vocab_size:
-            # All elements should be non-zero
-            assert count == vocab_size, (
-                f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
-            )
-        else:
+        if k < vocab_size:
             # Count should be at least k (may be more due to ties at the threshold)
             row_probs = probs[i].float()
             topk_vals, _ = torch.topk(row_probs, k, sorted=True)
             threshold = topk_vals[-1]
             expected_ge_threshold = (row_probs >= threshold).sum().item()
🤖 Prompt for AI Agents
In @tests/utils/test_sampling.py around lines 516 - 591, The test
test_top_k_renorm_probs_mixed_k_persistent_loop allocates huge tensors and does
many GPU-syncing per-row ops which will OOM/timeout and uses torch.randint(...,
dtype=torch.bool) which is unsafe; reduce memory and syncs by shrinking
batch_size and vocab_size to CI-safe values, generate mask with
torch.randint(0,2,(batch_size,), device=..., generator=...).bool() or
torch.rand(device=...) < 0.5 instead of dtype=torch.bool, replace the Python
for-loop that calls .item() and per-row torch.topk with vectorized tensor
operations to compute nonzero_counts and thresholds (operate on entire tensors
and use torch.topk once/batched or torch.kthvalue), and add a conditional skip
using flashinfer.utils (e.g., skip_if_gpu_too_small or similar) at the top of
the test to avoid running on unsupported GPUs.

Collaborator Author

We have enough GPU memory on CI.

Contributor

@yzh119, understood! Since you've confirmed the CI environment has sufficient GPU memory for these tests, that addresses the main concern. Feel free to keep the test as-is.


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures

Comment on lines +677 to +737
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
"""Test top_k_mask_logits with mixed k values in persistent loop (multi-CTA mode).

This test catches the same bug as test_top_k_renorm_probs_mixed_k_persistent_loop
but for the mask_logits variant.
"""
batch_size = 1024 # Large batch to trigger persistent loop
vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode

torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
logits = logits.to(dtype)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)

# Run FlashInfer top_k_mask_logits
masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k_values)

# Verify output dtype
assert masked_logits.dtype == dtype

# Verify finite count matches k for each row
finite_counts = torch.isfinite(masked_logits).sum(dim=-1)

for i in range(batch_size):
k = k_values[i].item()
count = finite_counts[i].item()

if k >= vocab_size:
# All elements should be finite
assert count == vocab_size, (
f"Row {i}: k >= vocab_size but finite count={count} != {vocab_size}"
)
else:
# Count should be at least k (may be more due to ties at the threshold)
row_logits = logits[i].float()
topk_vals, _ = torch.topk(row_logits, k, sorted=True)
threshold = topk_vals[-1]
expected_ge_threshold = (row_logits >= threshold).sum().item()

# Allow small tolerance for floating point
assert count >= k - 1, f"Row {i}: k={k} but only {count} finite elements"
assert count <= expected_ge_threshold + 1, (
f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}"
)

Contributor

@coderabbitai coderabbitai bot Jan 10, 2026

⚠️ Potential issue | 🟠 Major

Same OOM/timeout concerns apply; also consider validating the k>=vocab branch via equality rather than counts.

For top_k_mask_logits, when k >= vocab_size the kernel should be a pure copy; asserting masked_logits == logits for those rows is typically a stronger (and cheaper) check than counting finites, and avoids edge cases if inputs change (e.g., allow -inf later).

Minimal adjustments mirroring the other test
 def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
     batch_size = 1024  # Large batch to trigger persistent loop
     vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
@@
-    mask = torch.randint(
-        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
-    )
+    mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
     k_values.masked_fill_(mask, vocab_size)
+
+    # Stronger + cheaper: rows with k>=vocab are copies
+    full_rows = (k_values >= vocab_size).nonzero(as_tuple=False).flatten()
+    if full_rows.numel() > 0:
+        torch.testing.assert_close(masked_logits[full_rows], logits[full_rows])

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In @tests/utils/test_sampling.py around lines 677 - 737, The test
test_top_k_mask_logits_mixed_k_persistent_loop currently checks k>=vocab_size by
counting finite values which is weaker and can miss cases; update the
k>=vocab_size branch to assert a direct equality copy for those rows by
comparing masked_logits[i] == logits[i] (or .equal() / torch.allclose after
casting to the same dtype) for each i where k_values[i] >= vocab_size,
referencing flashinfer.sampling.top_k_mask_logits, masked_logits, logits, and
k_values to locate the code.

Collaborator Author

ditto.

Contributor

@yzh119, understood! I've noted your response.

@yzh119 yzh119 added the v0.6.1 label Jan 10, 2026
@yzh119
Collaborator Author

yzh119 commented Jan 10, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !233 has been created, and the CI pipeline #41463402 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[SUCCESS] Pipeline #41463402: 17/20 passed

@yzh119
Collaborator Author

yzh119 commented Jan 12, 2026

/bot run

@flashinfer-bot
Collaborator

Warning: Failed to sync latest changes. Please try again.

@yzh119
Collaborator Author

yzh119 commented Jan 13, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !233 has been updated with latest changes, and the CI pipeline #41612664 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
include/flashinfer/topk.cuh (1)

919-946: PageTableTransform and RaggedTransform paths correctly apply the same fix.

Both paths properly clear the histogram for the next iteration when the current row is trivially handled (length <= top_k_val).

Consider extracting the histogram clearing into a helper lambda or inline function to reduce the code duplication across all three modes:

♻️ Optional refactor to reduce duplication
// Define once before the mode-specific branches
auto clear_next_iteration_histogram = [&]() {
  if constexpr (!SINGLE_CTA) {
    constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
    uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
    if (cta_in_group == 0) {
      for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
        state->histogram[next_first_hist_idx][i] = 0;
      }
    }
  }
};

// Then in each mode's trivial case:
if (k >= length) {
  // ... copy output ...
  clear_next_iteration_histogram();
  continue;
}
📜 Review details

Configuration used: defaults

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e9968a6 and 308aadc.

📒 Files selected for processing (1)
  • include/flashinfer/topk.cuh
🧰 Additional context used
📓 Path-based instructions (1)
include/**/*.cuh

📄 CodeRabbit inference engine (CLAUDE.md)

include/**/*.cuh: Torch headers MUST NOT be included in files within the include/ directory - keep framework-agnostic CUDA kernels that accept raw pointers
Kernel code in include/flashinfer/ is automatically picked up by JIT compilation on changes - no pip reinstall needed
For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers
Consult the PTX ISA documentation (https://docs.nvidia.com/cuda/parallel-thread-execution/) for low-level instruction details and new GPU architecture features when writing inline PTX assembly

Files:

  • include/flashinfer/topk.cuh
🧠 Learnings (1)
📚 Learning: 2025-12-30T09:34:39.900Z
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to include/**/*.cuh : For performance-critical hot paths, leave comments explaining special algorithmic choices and potential alternatives for future reviewers

Applied to files:

  • include/flashinfer/topk.cuh
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (3)
include/flashinfer/topk.cuh (3)

900-910: Histogram clearing for skip path looks correct.

The calculation ((iter + 1) * NUM_ROUNDS) % 3 correctly identifies the histogram buffer that would be used by the next iteration's first radix round. This mirrors what the last round of a normal iteration would clear via next_hist = state->histogram[(global_round + 1) % 3].

Minor observation: The pattern sizeof(OrderedType) * 8 / 8 simplifies to sizeof(OrderedType). While equivalent to ORDERED_BITS / RADIX_BITS, using a named constant could improve clarity:

constexpr uint32_t NUM_ROUNDS = Traits::template num_rounds<8>();  // Use existing trait

This is optional since the current approach works correctly.
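The buffer-rotation index math is easy to sanity-check in isolation. A small Python check (assuming a 32-bit OrderedType, so NUM_ROUNDS = 4, and the 3-buffer rotation quoted above; the helper names are illustrative, not kernel symbols):

```python
NUM_ROUNDS = 4  # sizeof(OrderedType) * 8 / 8 for a 4-byte ordered key

def first_hist_idx(iteration):
    """Buffer index consumed by an iteration's first radix round."""
    return (iteration * NUM_ROUNDS) % 3

def last_round_clears(iteration):
    """Buffer cleared by an iteration's last round: (global_round + 1) % 3."""
    last_global_round = iteration * NUM_ROUNDS + (NUM_ROUNDS - 1)
    return (last_global_round + 1) % 3

# The skip path's ((iter + 1) * NUM_ROUNDS) % 3 must match both what a normal
# iteration's final round would have cleared and what the next iteration reads.
for it in range(12):
    assert ((it + 1) * NUM_ROUNDS) % 3 == last_round_clears(it) == first_hist_idx(it + 1)
print("rotation indices consistent")
```

This confirms the skip-path clear targets exactly the buffer the next iteration's first round will consume, for every iteration in the persistent loop.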


1117-1129: Histogram clearing for MaskLogits k >= vocab_size path is correct.

The comment on line 1128 ("No sync needed - next iteration's barrier will ensure visibility") accurately explains why no explicit synchronization is required after the clearing. The next iteration's initial barrier in RadixSelectFromSharedMemory will synchronize all CTAs before they access the histogram.


1406-1419: RenormProb histogram clearing correctly handles the k >= vocab_size path.

This path is slightly different from MaskLogits because it still performs barrier-synchronized sum reduction before the histogram clearing. The placement after both the sum computation and output writing is correct.

The comment on line 1409 correctly notes that next iteration's barrier ensures visibility, which is consistent with the barrier tracking design (cumulative arrival counter).
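The "cumulative arrival counter" idea can be illustrated with a host-side sketch (a toy in Python threads under stated assumptions, not FlashInfer's device-side barrier): because the counter only grows, no reset phase is needed between iterations, which is why the next iteration's barrier alone can provide the needed ordering.

```python
import threading

class CumulativeBarrier:
    """Toy inter-CTA barrier modeled on a cumulative arrival counter.

    Each participant increments a shared counter on arrival (the analogue
    of an atomicAdd) and waits until the counter reaches n * generation.
    The counter is never reset, so generation g+1 arrivals can never be
    confused with generation g arrivals.
    """
    def __init__(self, n):
        self.n = n
        self.count = 0
        self.cond = threading.Condition()

    def arrive_and_wait(self, generation):
        with self.cond:
            self.count += 1  # arrival: counter only ever grows
            while self.count < self.n * generation:
                self.cond.wait()  # spin-wait analogue
            self.cond.notify_all()

# Four workers crossing three generations; final count is 4 * 3 = 12.
barrier = CumulativeBarrier(4)

def worker():
    for g in (1, 2, 3):
        barrier.arrive_and_wait(g)

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(barrier.count)  # 12
```

Note that with a reset-style barrier, the clear-then-signal ordering issue flagged earlier in this review would be even more delicate; the monotone counter sidesteps the reset entirely.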

