
Conversation

@bkryu (Collaborator) commented Nov 7, 2025

📌 Description

Apply optimizations similar to #2044 to max/min functions.

top_p_renorm_probs, top_k_renorm_probs, and top_k_mask_logits see 14%, 7%, and 19% speedups, respectively, on B200 (logs from bench_sampling.py: bench_sampling_before.txt and bench_sampling_after.txt).

🔍 Related Issues

#2044

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Performance Improvements
    • Improved sampling performance by reducing per-iteration synchronization and temporary storage, deferring aggregate reductions until after iterative work completes. This lowers runtime overhead and memory churn, yielding faster and more efficient processing for sampling operations.

@coderabbitai (Contributor, bot) commented Nov 7, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This change refactors the sampling kernel's min/max reduction: threads accumulate thread-local min/max across loop iterations, then perform a single block-wide reduction after the loop to compute final min/max, removing per-iteration reductions, intermediate in-data arrays, and in-loop synchronizations.

Changes

Cohort / File(s) Summary
Min/Max Reduction Optimization
include/flashinfer/sampling.cuh
Reworked reduction to use thread-local accumulation of thread_max/thread_min across iterations; removed per-iteration in_data_ temporaries and in-loop block reductions / __syncthreads; added a single post-loop BlockReduce-based reduction to produce final min/max values and store them via temp_storage
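The refactor summarized above is a standard deferred-reduction pattern. The sketch below is an illustration only, written against CUB's BlockReduce directly; FlashInfer's actual GetMinMaxValue uses its own BlockReduce wrapper, vectorized loads, and a shared temp_storage union, so the function name `BlockMinMax` and its signature here are hypothetical:

```cuda
#include <cub/block/block_reduce.cuh>
#include <cuda/std/limits>

// Hypothetical sketch of the deferred min/max reduction pattern.
// Each thread first scans its strided slice of the row, keeping
// thread-local extrema; one block-wide reduction then runs after the
// loop, instead of two reductions plus barriers on every iteration.
template <uint32_t BLOCK_THREADS>
__device__ void BlockMinMax(const float* row, uint32_t d,
                            float* block_min, float* block_max) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  float thread_max = -cuda::std::numeric_limits<float>::infinity();
  float thread_min = cuda::std::numeric_limits<float>::infinity();
  for (uint32_t i = threadIdx.x; i < d; i += BLOCK_THREADS) {
    thread_max = max(thread_max, row[i]);
    thread_min = min(thread_min, row[i]);
  }

  // CUB returns the block-wide result on thread 0 only; a barrier is
  // needed between the two reductions because temp_storage is reused.
  float max_val = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
  __syncthreads();
  float min_val = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());
  if (threadIdx.x == 0) {
    *block_max = max_val;
    *block_min = min_val;
  }
}
```

Initializing the accumulators with the reduction identities (negative/positive infinity) means threads that read no elements contribute neutrally to the final reduction.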

Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Host
    participant Kernel as SamplingKernel
    participant Threads as Thread[block]
    participant BlockReduce

    Host->>Kernel: launch with data pointers
    Kernel->>Threads: start loop over iterations
    alt per-iteration (old)
        Note over Threads: per-iteration reduction & sync\n(store intermediate in_data_)
        Threads->>Threads: per-iteration min/max reduction
        Threads->>Threads: __syncthreads()
    end
    Note over Threads: new flow — each thread updates\nthread_local_min / thread_local_max per iteration
    Threads->>BlockReduce: after loop, single block-wide reduce
    BlockReduce-->>Threads: block_min / block_max
    Threads->>Kernel: write final min/max to temp_storage
    Kernel-->>Host: kernel completes
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Areas needing extra attention:
    • Initialization of thread_max / thread_min for all threads (including inactive lanes)
    • Correctness of post-loop BlockReduce usage and reduction identity values
    • Any removed __syncthreads() — ensure no remaining data races for values read/written across threads
    • Edge cases: empty ranges, single-iteration loops, and boundary-thread handling
    • Ensure removal of intermediate buffers doesn't change memory visibility or lifetime assumptions

Poem

🐇 I hopped through loops with values small and tall,
Kept my max and min tucked close, no stalls at all.
One big sync at sunset, then I share what I found—
A neat little tally, stored safe and sound. 🥕✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Title check (✅ Passed): The title 'perf: Optimize helper max/minmax function in sampling.cuh' directly matches the core change: optimizing the max/minmax helper functions, with the file name specified.
  • Description check (✅ Passed): The description follows the template structure, includes related issue #2044, marks all checklist items complete, and provides context that this applies optimizations similar to PR #2044.
  • Docstring Coverage (✅ Passed): No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e4ea76d and 69f5e5f.

📒 Files selected for processing (1)
  • include/flashinfer/sampling.cuh (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs


@gemini-code-assist (Contributor, bot)
Summary of Changes

Hello @bkryu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces performance optimizations to the max and minmax helper functions within sampling.cuh. The core change involves restructuring the reduction logic to minimize synchronization overhead, which is expected to improve the efficiency of these CUDA kernel operations. This is a work-in-progress, with verification of performance and bug fixes still pending.

Highlights

  • Performance Optimization: Refactored GetMinMaxValue and GetMaxValue functions in sampling.cuh to improve performance by deferring block-wide reductions.
  • Reduced Synchronization: Moved BlockReduce operations and associated __syncthreads() calls outside the main data processing loop, performing a single reduction after thread-local accumulation.
  • Thread-Local Accumulation: Introduced thread_max and thread_min variables to accumulate values locally within each thread before a final block-wide reduction.
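The highlights above can be condensed into a short device-code sketch. This is an illustration only, not the exact FlashInfer code: names such as `BlockRowMax` are hypothetical, while `d`, `VEC_SIZE`, and `BLOCK_THREADS` mirror the parameters discussed in this thread. The key details are the identity initialization and the per-element bounds check:

```cuda
#include <cub/block/block_reduce.cuh>
#include <cuda/std/limits>

constexpr uint32_t ceil_div_u32(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Illustrative sketch: tile-strided thread-local max accumulation with
// the reduction identity (-inf) as initial value and a per-element
// bounds check, followed by one post-loop block-wide reduction.
template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE>
__device__ float BlockRowMax(const float* row, uint32_t d) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  float thread_max = -cuda::std::numeric_limits<float>::infinity();
  for (uint32_t i = 0; i < ceil_div_u32(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    const uint32_t base = (i * BLOCK_THREADS + threadIdx.x) * VEC_SIZE;
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      if (base + j < d) {  // ignore padding past the end of the row
        thread_max = max(thread_max, row[base + j]);
      }
    }
  }
  // Result valid on thread 0 only, as with plain CUB BlockReduce.
  return BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
}
```

Because the accumulator starts at the identity, lanes that never pass the bounds check contribute neutrally to the final reduction and need no extra masking.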

@coderabbitai (Contributor, bot) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 63cf562 and e4ea76d.

📒 Files selected for processing (1)
  • include/flashinfer/sampling.cuh (2 hunks)
🔇 Additional comments (1)
include/flashinfer/sampling.cuh (1)

252-275: Optimization looks correct for GetMinMaxValue.

The thread-local accumulation pattern followed by a single block reduction is a good optimization that reduces synchronization overhead. The initialization values (negative infinity for max, positive infinity for min) are correct, and the logic properly reuses temp_storage with appropriate synchronization between the max and min reductions.

@gemini-code-assist (Contributor, bot) left a comment

Code Review

This pull request aims to optimize the max/minmax helper functions by deferring block-wide reductions until after the main loop. This is a solid performance optimization strategy. However, the current implementation introduces/preserves critical correctness bugs related to out-of-bounds data handling and incorrect initialization. Specifically, the reduction loops process padded zeros and potentially uninitialized data, and one of the functions incorrectly initializes its accumulator. I've provided suggestions to fix these critical issues to ensure correctness while retaining the performance benefits.

Comment on lines 256 to 266:

```cuda
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    in_data_vec.fill(0);
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      in_data_vec.cast_load(in_data + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
    }
    float in_data_[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      in_data_[j] = in_data_vec[j];
      thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
      thread_min = min(thread_min, static_cast<float>(in_data_vec[j]));
    }
    max_val = max(
        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
    __syncthreads();
    min_val = min(
        min_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .Reduce<VEC_SIZE>(in_data_, MinReduceOp{}));
    __syncthreads();
  }
```

critical

The reduction loop has correctness issues with out-of-bounds data handling. It processes padded zeros for threads that are entirely out of bounds and doesn't handle tail elements for partially out-of-bounds vectors. This can lead to incorrect min/max values.

The accumulation should only happen for valid data. The following change corrects the logic:

Please note that cast_load may still read out of bounds if the input is not padded. A fully robust solution might require masked loads or scalar access for tail elements, but the suggested change fixes the logical error in reduction.

```cuda
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    const uint32_t base_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
    if (base_offset < d) {
      in_data_vec.cast_load(in_data + row_idx * d + base_offset);
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        if (base_offset + j < d) {
          thread_max = max(thread_max, in_data_vec[j]);
          thread_min = min(thread_min, in_data_vec[j]);
        }
      }
    }
  }
```

Comment on lines +295 to 306:

```cuda
  // Thread-local max accumulation (deferred reduction)
  float thread_max = 0.0f;
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    in_data_vec.fill(0);
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      in_data_vec.cast_load(in_data + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
    }
    float in_data_[VEC_SIZE];
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      in_data_[j] = in_data_vec[j];
      thread_max = max(thread_max, static_cast<float>(in_data_vec[j]));
    }
    max_val = max(
        max_val, BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                     .template Reduce<VEC_SIZE>(in_data_, MaxReduceOp{}));
    __syncthreads();
  }
```

critical

This function has two correctness issues:

  1. Incorrect Initialization: thread_max is initialized to 0.0f on line 296. If all input values are negative, this will cause the function to incorrectly return 0 instead of the true maximum. It should be initialized to -infinity.

  2. Faulty Reduction Logic: The loop structure from lines 297-306 suffers from the same issues as in GetMinMaxValue. It processes padded zeros for out-of-bounds threads and doesn't handle tail elements correctly, leading to incorrect results.

Here is a suggested fix that addresses both issues:

Please note that cast_load may still read out of bounds if the input is not padded. A fully robust solution might require masked loads or scalar access for tail elements, but the suggested change fixes the logical error in reduction.

```cuda
  // Thread-local max accumulation (deferred reduction)
  float thread_max = -cuda::std::numeric_limits<float>::infinity();
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    const uint32_t base_offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
    if (base_offset < d) {
      in_data_vec.cast_load(in_data + row_idx * d + base_offset);
#pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        if (base_offset + j < d) {
          thread_max = max(thread_max, in_data_vec[j]);
        }
      }
    }
  }
```

@bkryu (Collaborator, Author) commented Nov 7, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !118 has been created, and the CI pipeline #38046670 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu self-assigned this Nov 7, 2025
@flashinfer-bot (Collaborator)

[FAILED] Pipeline #38046670: 8/17 passed

@yzh119 yzh119 changed the title from "[wip] perf: Optimize helper max/minmax function in sampling.cuh" to "perf: Optimize helper max/minmax function in sampling.cuh" on Nov 7, 2025
@yzh119 (Collaborator) commented Nov 7, 2025

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !118 has been updated with latest changes, and the CI pipeline #38069158 is currently running. I'll report back once the pipeline job completes.

@yzh119 (Collaborator) left a comment

LGTM, should be ready to merge when gitlab CI passed.

@yzh119 yzh119 merged commit f588d96 into flashinfer-ai:main Nov 7, 2025
4 checks passed
@bkryu bkryu deleted the sampling_minmax_perf_optimization branch November 7, 2025 23:44