Conversation

@yzh119 (Collaborator) commented Nov 5, 2025

📌 Description

This is the first part of the performance improvement work for the sampling/mask/softmax operators. In this PR, we defer the cross-thread reduction until the end of the loop (similar to how FA2 handles the denominator) to reduce the number of shuffle and thread-sync instructions.
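
For readers unfamiliar with the pattern, here is a minimal sketch of the idea, not the actual sampling.cuh code: the kernel name, the BLOCK_THREADS parameter, and the flat row layout are assumptions for illustration only. It contrasts a per-chunk block reduction with a thread-local accumulator that is reduced across threads only once, after the loop.

#include <cub/block/block_reduce.cuh>

// Hypothetical kernel illustrating the deferred-reduction pattern; not a real
// FlashInfer kernel. Each thread block sums one row of `probs` of length `d`.
template <uint32_t BLOCK_THREADS>
__global__ void RowSumSketch(const float* probs, float* row_sum, uint32_t d) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

  // Old pattern: call BlockReduce(...).Sum(chunk_val) plus __syncthreads() in
  // every chunk iteration, paying one shuffle/sync round per chunk.
  // New pattern: accumulate in a register, then reduce across threads once.
  float threadlocal_sum = 0.f;
  for (uint32_t i = tx; i < d; i += BLOCK_THREADS) {
    threadlocal_sum += probs[bx * d + i];  // no cross-thread traffic inside the loop
  }
  // Single deferred cross-thread reduction after the loop.
  float block_sum = BlockReduce(temp_storage).Sum(threadlocal_sum);
  if (tx == 0) {
    row_sum[bx] = block_sum;
  }
}

The real kernels carry more per-row state than a plain sum, but the change in this PR is to the synchronization structure shown above: per-iteration reductions become a single reduction after the loop.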

For the second part of the PR, we will implement the Radix TopK algorithm to improve top-k mask logits performance when K is small.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests

    • Added comprehensive benchmarking suite for sampling and softmax operations with performance comparison and visualization tools.
  • Chores

    • Optimized internal kernel execution strategies for improved performance efficiency.

@gemini-code-assist bot (Contributor) commented:

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant performance improvements for sampling, masking, and softmax operations within the FlashInfer library. The core strategy involves optimizing CUDA kernels by deferring cross-thread reductions until the end of processing loops, which reduces synchronization overhead and enhances execution efficiency. To validate these optimizations, new and expanded benchmarking scripts have been added, providing detailed performance comparisons and visualizations. This is the initial phase of a two-part effort to boost the performance of these critical operations.

Highlights

  • Softmax Kernel Optimization: The OnlineSoftmaxFusedKernel and OnlineSoftmaxMapKernel have been refactored to defer cross-thread reduction of the denominator until the end of the loop. This change aims to reduce the number of shuffling and thread synchronization instructions, improving performance (see the sketch after this list).
  • Sampling and Masking Kernel Optimization: Several sampling and masking kernels, including TopKSamplingFromProbKernel, TopPSamplingFromProbKernel, MinPSamplingFromProbKernel, TopKTopPSamplingFromProbKernel, TopPRenormProbKernel, TopKMaskLogitsKernel, and TopKRenormProbKernel, have been optimized. Similar to the softmax kernels, these now use thread-local aggregation before a final block-wide reduction, minimizing synchronization overhead.
  • New Softmax Benchmarking Script: A new benchmark script, bench_softmax.py, has been added to compare the performance of flashinfer.softmax against torch.softmax. This script generates heatmaps and trend plots to visualize speedups across various batch sizes and hidden dimensions.
  • Expanded Sampling Benchmarks: The existing bench_sampling.py script has been updated to include new benchmarks for top_p_renorm_probs, top_k_renorm_probs, and top_k_mask_logits, allowing for more comprehensive performance evaluation of these sampling methods.
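
As a companion to the first bullet, the following sketch shows one way an FA2-style online-softmax denominator can be kept thread-local and reduced only after the loop. It is an illustration under stated assumptions, not the OnlineSoftmaxFusedKernel implementation: the kernel name, signature, and use of cub::BlockReduce are hypothetical.

#include <cfloat>
#include <cub/block/block_reduce.cuh>

// Illustrative kernel only: each block computes the softmax row max and
// denominator for one row of `logits` of length `d`, keeping a running
// (max, rescaled sum) pair in registers and reducing across threads once.
template <uint32_t BLOCK_THREADS>
__global__ void OnlineSoftmaxDenomSketch(const float* logits, float* row_max,
                                         float* denom, uint32_t d) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ float smem_max;
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;

  float m = -FLT_MAX;  // thread-local running max
  float s = 0.f;       // thread-local running sum of exp(x - m)
  for (uint32_t i = tx; i < d; i += BLOCK_THREADS) {
    float x = logits[bx * d + i];
    float m_new = fmaxf(m, x);
    s = s * __expf(m - m_new) + __expf(x - m_new);  // rescale old sum, add new term
    m = m_new;
  }

  // Deferred cross-thread reduction: one max reduction, then one sum reduction
  // after rescaling each thread's partial sum to the block-wide max.
  float block_max = BlockReduce(temp_storage).Reduce(m, cub::Max());
  if (tx == 0) smem_max = block_max;
  __syncthreads();  // publish block_max and allow temp_storage reuse
  float block_sum = BlockReduce(temp_storage).Sum(s * __expf(m - smem_max));
  if (tx == 0) {
    row_max[bx] = smem_max;
    denom[bx] = block_sum;
  }
}

Keeping (m, s) in registers means the loop itself issues no shuffles or __syncthreads(); the only cross-thread work is the pair of reductions at the end, which is the cost reduction the bullet describes.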
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@coderabbitai bot (Contributor) commented Nov 5, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Three distinct changes: benchmarks/bench_sampling.py receives repetitive benchmarking blocks for new sampling variants; benchmarks/bench_softmax.py is introduced as a new script for comparing PyTorch and FlashInfer softmax performance; include/flashinfer/sampling.cuh refactors kernel internals to replace block-wide reductions with per-thread accumulators across multiple sampling kernels.

Changes

Cohort / File(s) | Summary
Sampling Benchmarks (benchmarks/bench_sampling.py) | Adds extensive benchmarking blocks evaluating new sampling functions (top_p_renorm_probs, top_k_renorm_probs, top_k_mask_logits) across multiple vocab sizes, batch sizes, and hyperparameters. Each block seeds the RNG, constructs logits/probs, measures GPU execution time, estimates bandwidth, and prints results. No API changes.
Softmax Benchmark Suite (benchmarks/bench_softmax.py) | New file introducing benchmark_torch_softmax(), benchmark_flashinfer_softmax(), run_benchmark(), and plot_heatmap() functions to compare softmax performance across batch and hidden dimensions. Generates heatmap and trend-plot visualizations with speedup metrics and bandwidth calculations.
Kernel Optimization (include/flashinfer/sampling.cuh) | Refactors reduction patterns across multiple sampling kernels (OnlineSoftmaxFusedKernel, TopKSamplingFromProbKernel, TopPSamplingFromProbKernel, etc.) to use per-thread local accumulators (threadlocal_*) instead of block-wide reductions. Defers final reductions to later synchronization points, reducing intermediate __syncthreads() calls. Maintains functional correctness.

Sequence Diagram

sequenceDiagram
    participant main as main()
    participant run as run_benchmark()
    participant torch_bench as benchmark_torch_softmax()
    participant fi_bench as benchmark_flashinfer_softmax()
    participant plot as plot_heatmap()
    
    main->>run: For each (batch_size, hidden_size) pair
    run->>torch_bench: Generate logits, measure torch.softmax GPU time
    torch_bench-->>run: torch_time (median)
    run->>fi_bench: Measure flashinfer.sampling.softmax GPU time
    fi_bench-->>run: flashinfer_time (median)
    run->>run: Compute speedup & bandwidth
    run-->>main: Return speedups, batch_sizes, hidden_sizes arrays
    main->>plot: Create heatmap with speedup annotations
    plot->>plot: Generate trend plots (speedup vs. batch, vs. hidden)
    plot-->>main: Save PNG visualizations
    main->>main: Print summary statistics (mean, median, min, max speedups)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Specific areas requiring attention:
    • include/flashinfer/sampling.cuh: Complex kernel refactoring with per-thread accumulator patterns applied across ~10+ kernels; verify threadlocal variables are correctly scoped, initialized, and reduced; ensure __syncthreads() removal doesn't introduce data races or correctness regressions.
    • benchmarks/bench_softmax.py: New public functions with multiple dependencies (matplotlib, numpy, torch); verify plotting logic correctness, ensure heatmap/trend-plot generation handles edge cases and saves files correctly.
    • benchmarks/bench_sampling.py: While mostly repetitive, confirm seeding, timing measurements, and bandwidth calculations are consistent across all benchmark blocks.

Poem

🐰 Per-thread accumulators hop with glee,
No more syncs block the GPU spree,
Softmax speeds dance in heatmaps bright,
Benchmarks measure efficiency's flight,
Swift reductions, threads unbound—
FlashInfer's leap to faster ground!

Pre-merge checks

✅ Passed checks (2 passed)
Check name | Status | Explanation
Title check | ✅ Passed | The title clearly indicates a performance improvement for sampling/mask/softmax and accurately reflects the kernel-level optimization changes across multiple files.
Description check | ✅ Passed | The PR description clearly explains the performance improvements and the deferred cross-thread reduction logic, and mentions a planned second part, with all required sections completed and checklist items checked.

@gemini-code-assist bot (Contributor) commented:

Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request, the first part of a two-part series, focuses on enhancing the performance of sampling, masking, and softmax operations within the FlashInfer library. The primary method involves optimizing CUDA kernels by deferring cross-thread reductions to minimize synchronization overhead. Additionally, new and expanded benchmarking tools have been introduced to thoroughly evaluate these performance improvements.

Highlights

  • Performance Optimization Strategy: Implemented a strategy to defer cross-thread reductions until the end of loops in sampling, masking, and softmax operations. This aims to reduce the number of shuffling and thread synchronization instructions, similar to how FlashAttention 2 handles denominators.
  • Softmax Kernel Refinement: Modified OnlineSoftmaxFusedKernel and OnlineSoftmaxMapKernel to use thread-local accumulation for denominator calculations, reducing global synchronization points.
  • Sampling and Masking Kernel Improvements: Applied the deferred reduction technique across various sampling and masking kernels, including TopKSamplingFromProbKernel, TopPSamplingFromProbKernel, MinPSamplingFromProbKernel, TopKTopPSamplingFromProbKernel, TopPRenormProbKernel, TopKMaskLogitsKernel, and TopKRenormProbKernel, by introducing thread-local aggregate variables.
  • New Softmax Benchmarking Script: Added a dedicated benchmarking script (bench_softmax.py) to compare the performance of flashinfer.softmax against torch.softmax across different batch and hidden sizes, including heatmap visualization of speedups.
  • Expanded Sampling Benchmarks: Extended the existing bench_sampling.py script to include new benchmarks for top_p_renorm_probs, top_k_renorm_probs, and top_k_mask_logits to evaluate their performance.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request improves the performance of several sampling, masking, and softmax CUDA kernels by deferring cross-thread reductions until the end of the vocabulary-chunk loop. This reduces synchronization overhead. The changes are applied consistently across multiple kernels and appear correct. Additionally, new benchmarks have been added for top-p/top-k renorm and top-k mask operations, along with a new comprehensive benchmark script for softmax performance. My review focuses on the new benchmark code, suggesting a refactoring to reduce code duplication.

Comment on lines +223 to +302
print("---")
print("top-p renorm probs")
for vocab_size in [128512]:
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
torch.manual_seed(42)
for distrib in [
normal_distribution(1),
normal_distribution(5),
gumbel_distribution(0.1),
gumbel_distribution(1),
]:
for p in [0.1, 0.5, 0.9]:
logits = distrib((batch_size, vocab_size), device="cuda")
probs = torch.softmax(logits, dim=-1)
measurements = bench_gpu_time(
lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
ms = np.median(measurements)

io = probs.numel() * probs.element_size() * 2
bandwidth = io * 1e-6 / ms
print(
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
)

print("---")
print("top-k renorm probs")
for vocab_size in [128512]:
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
torch.manual_seed(42)
for distrib in [
normal_distribution(1),
normal_distribution(5),
gumbel_distribution(0.1),
gumbel_distribution(1),
]:
for k in [10, 100, 1000, 5000]:
logits = distrib((batch_size, vocab_size), device="cuda")
probs = torch.softmax(logits, dim=-1)
measurements = bench_gpu_time(
lambda: flashinfer.sampling.top_k_renorm_probs(probs, k),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
ms = np.median(measurements)

io = probs.numel() * probs.element_size() * 2
bandwidth = io * 1e-6 / ms
print(
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
)

print("---")
print("top-k mask logits")
for vocab_size in [128512]:
for batch_size in [1, 16, 32, 64, 128, 256, 512]:
torch.manual_seed(42)
for distrib in [
normal_distribution(1),
normal_distribution(5),
gumbel_distribution(0.1),
gumbel_distribution(1),
]:
for k in [10, 100, 1000, 5000]:
logits = distrib((batch_size, vocab_size), device="cuda")
measurements = bench_gpu_time(
lambda: flashinfer.sampling.top_k_mask_logits(logits, k),
dry_run_time_ms=100,
repeat_time_ms=1000,
)
ms = np.median(measurements)

io = logits.numel() * logits.element_size() * 2
bandwidth = io * 1e-6 / ms
print(
f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
)

Severity: medium

There is significant code duplication in the newly added benchmark tests. The three benchmark loops for top-p renorm probs, top-k renorm probs, and top-k mask logits are very similar. To improve maintainability and reduce redundancy, you could refactor this into a helper function.

Here's an example of how you could structure the helper function (to be placed before main):

def _run_renorm_benchmark(
    title: str,
    func,
    param_name: str,
    param_values: list,
    use_probs: bool,
):
    print("---")
    print(title)
    for vocab_size in [128512]:
        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
            torch.manual_seed(42)
            for distrib in [
                normal_distribution(1),
                normal_distribution(5),
                gumbel_distribution(0.1),
                gumbel_distribution(1),
            ]:
                for param_val in param_values:
                    logits = distrib((batch_size, vocab_size), device="cuda")
                    if use_probs:
                        tensor_to_bench = torch.softmax(logits, dim=-1)
                    else:
                        tensor_to_bench = logits

                    measurements = bench_gpu_time(
                        lambda: func(tensor_to_bench, param_val),
                        dry_run_time_ms=100,
                        repeat_time_ms=1000,
                    )
                    ms = np.median(measurements)

                    io = tensor_to_bench.numel() * tensor_to_bench.element_size() * 2
                    bandwidth = io * 1e-6 / ms
                    print(
                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, {param_name}: {param_val}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
                    )

You would then call this helper from main() as suggested below.

    _run_renorm_benchmark(
        "top-p renorm probs",
        flashinfer.sampling.top_p_renorm_probs,
        "p",
        [0.1, 0.5, 0.9],
        use_probs=True,
    )

    _run_renorm_benchmark(
        "top-k renorm probs",
        flashinfer.sampling.top_k_renorm_probs,
        "k",
        [10, 100, 1000, 5000],
        use_probs=True,
    )

    _run_renorm_benchmark(
        "top-k mask logits",
        flashinfer.sampling.top_k_mask_logits,
        "k",
        [10, 100, 1000, 5000],
        use_probs=False,
    )

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces performance improvements for sampling, mask, and softmax operators by deferring cross-thread reductions in the CUDA kernels. This is a sound optimization strategy that should reduce synchronization overhead. The changes across the various CUDA kernels are consistent and appear to be correctly implemented. The addition of new benchmark scripts is also a valuable contribution for measuring the impact of these changes. My review includes a couple of suggestions for the new benchmark scripts to improve maintainability and plotting consistency. Overall, this is a solid pull request that focuses on performance.

Comment on lines +223 to +302 (the same benchmark hunk quoted in the previous review)

Severity: medium

The three new benchmark sections for top-p renorm probs, top-k renorm probs, and top-k mask logits contain significant code duplication. To improve maintainability and readability, consider refactoring this logic into a single helper function. This function could be parameterized by the operation name, the function to benchmark, the parameter name ('p' or 'k'), the parameter values, and whether the function takes logits or probabilities as input.

    ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold")
    ax1.grid(True, alpha=0.3)
    ax1.legend(fontsize=9)
    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)
Severity: medium

For consistency with the second subplot (ax2), consider adding a label to this axhline call. This will ensure the 'No speedup' line is labeled in the legend for both plots.

Suggested change

-    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)
+    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
benchmarks/bench_sampling.py (1)

238-299: Bind loop variables when building the benchmark lambdas

Ruff’s B023 warning here is legitimate: the lambdas capture probs, p, and k by reference, so a future change that defers execution (or reuses the callable after the loop advances) would end up timing the wrong tensors/thresholds. Binding the current values via default arguments silences the warning and makes the benchmarks future-proof. Please apply the following update in all three loops.

-                        lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
+                        lambda probs=probs, p=p: flashinfer.sampling.top_p_renorm_probs(probs, p),
...
-                        lambda: flashinfer.sampling.top_k_renorm_probs(probs, k),
+                        lambda probs=probs, k=k: flashinfer.sampling.top_k_renorm_probs(probs, k),
...
-                        lambda: flashinfer.sampling.top_k_mask_logits(logits, k),
+                        lambda logits=logits, k=k: flashinfer.sampling.top_k_mask_logits(logits, k),
benchmarks/bench_softmax.py (1)

105-121: Silence the unused fig binding

fig isn’t used after plt.subplots, which raises RUF059. Prefix it with an underscore (or drop it) to keep linters quiet.

-    fig, ax = plt.subplots(figsize=(12, 8))
+    _fig, ax = plt.subplots(figsize=(12, 8))
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9bc5bd5 and 491dd3c.

📒 Files selected for processing (3)
  • benchmarks/bench_sampling.py (1 hunks)
  • benchmarks/bench_softmax.py (1 hunks)
  • include/flashinfer/sampling.cuh (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/bench_sampling.py (3)
benchmarks/bench_renorm.py (2)
  • normal_distribution (8-13)
  • gumbel_distribution (16-23)
flashinfer/sampling.py (8)
  • softmax (52-72)
  • softmax (505-559)
  • top_p_renorm_probs (325-341)
  • top_p_renorm_probs (1166-1226)
  • top_k_renorm_probs (354-368)
  • top_k_renorm_probs (1232-1291)
  • top_k_mask_logits (381-395)
  • top_k_mask_logits (1297-1351)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
benchmarks/bench_softmax.py (1)
flashinfer/testing/utils.py (1)
  • bench_gpu_time (972-1033)
🪛 Ruff (0.14.3)
benchmarks/bench_sampling.py

238-238: Function definition does not bind loop variable probs

(B023)


238-238: Function definition does not bind loop variable p

(B023)


265-265: Function definition does not bind loop variable probs

(B023)


265-265: Function definition does not bind loop variable k

(B023)


291-291: Function definition does not bind loop variable logits

(B023)


291-291: Function definition does not bind loop variable k

(B023)

benchmarks/bench_softmax.py

105-105: Unpacked variable fig is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs

Comment on lines +139 to +176
    for j, hidden_size in enumerate(hidden_sizes):
        ax2.plot(
            batch_sizes,
            speedups[:, j],
            marker="o",
            label=f"Hidden={hidden_size // 1000}K",
            linewidth=2,
        )

    ax2.set_xlabel("Batch Size", fontsize=12, fontweight="bold")
    ax2.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
    ax2.set_title("Speedup vs Batch Size", fontsize=13, fontweight="bold")
    ax2.set_xscale("log", base=2)
    ax2.grid(True, alpha=0.3)
    ax2.legend(fontsize=9)
    ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")

    # Plot 1: Speedup trends across hidden sizes
    for i, batch_size in enumerate(batch_sizes[::2]):  # Sample every other batch size
        idx = i * 2
        ax1.plot(
            [h // 1000 for h in hidden_sizes],
            speedups[idx, :],
            marker="s",
            label=f"Batch={batch_size}",
            linewidth=2,
        )

    ax1.set_xlabel("Hidden Size (K)", fontsize=12, fontweight="bold")
    ax1.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
    ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold")
    ax1.grid(True, alpha=0.3)
    ax1.legend(fontsize=9)
    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)

    plt.tight_layout()
    comparison_path = save_path.replace(".png", "_trends.png")
    plt.savefig(comparison_path, dpi=300, bbox_inches="tight")

⚠️ Potential issue | 🟡 Minor

Ensure the “No speedup” guide lines appear in the legends

ax2.legend(...) (and ax1.legend(...)) is called before each axhline, so the horizontal “No speedup” lines never reach the legend despite the label argument. Move the legend calls after the guide lines (or call legend again afterwards) so the reference line is actually shown.

-    ax2.legend(fontsize=9)
-    ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
+    ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
+    ax2.legend(fontsize=9)
...
-    ax1.legend(fontsize=9)
-    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)
+    ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
+    ax1.legend(fontsize=9)
🤖 Prompt for AI Agents
In benchmarks/bench_softmax.py around lines 139 to 176 the horizontal "No
speedup" axhline calls are placed after ax2.legend(...) and ax1.legend(...), so
their labeled guide lines don't appear in the legends; fix by moving the
legend() calls to after the axhline(...) calls (or call legend() again after
adding the axhline), i.e., ensure each axis adds its axhline before creating the
legend so the guide line with label="No speedup" is included.

@yzh119 (Collaborator, Author) commented Nov 6, 2025

/bot run

@flashinfer-bot (Collaborator) commented:

GitLab MR !117 has been created, and the CI pipeline #37995303 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented:

[SUCCESS] Pipeline #37995303: 13/17 passed

@bkryu (Collaborator) left a comment

LGTM. Unit tests are passing, and the reduced number of synchronizations appears to be applied correctly throughout.

Nothing to do for this PR, but I was checking whether there are any further optimizations and came up with #2058 as a follow-up for completeness.

@yzh119 yzh119 merged commit adcc5dd into flashinfer-ai:main Nov 7, 2025
4 checks passed
yzh119 added a commit that referenced this pull request Nov 7, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

Apply optimizations similar to #2044 to max/min functions.

<!-- What does this PR do? Briefly describe the changes and why they’re
needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->
#2044 

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Performance Improvements**
* Improved sampling performance by reducing per-iteration
synchronization and temporary storage, deferring aggregate reductions
until after iterative work completes. This lowers runtime overhead and
memory churn, yielding faster and more efficient processing for sampling
operations.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Co-authored-by: yzh119 <[email protected]>