perf: improve sampling/mask/softmax performance (part 1/2) #2044
Conversation
Summary of Changes

Hello @yzh119, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces significant performance improvements for sampling, masking, and softmax operations within the FlashInfer library. The core strategy involves optimizing CUDA kernels by deferring cross-thread reductions until the end of processing loops, which reduces synchronization overhead and improves execution efficiency. To validate these optimizations, new and expanded benchmarking scripts have been added, providing detailed performance comparisons and visualizations. This is the initial phase of a two-part effort to boost the performance of these critical operations.
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Three distinct changes: benchmarks/bench_sampling.py receives repetitive benchmarking blocks for new sampling variants; benchmarks/bench_softmax.py is introduced as a new script for comparing PyTorch and FlashInfer softmax performance; include/flashinfer/sampling.cuh refactors kernel internals to replace block-wide reductions with per-thread accumulators across multiple sampling kernels.
Sequence Diagram

```mermaid
sequenceDiagram
    participant main as main()
    participant run as run_benchmark()
    participant torch_bench as benchmark_torch_softmax()
    participant fi_bench as benchmark_flashinfer_softmax()
    participant plot as plot_heatmap()
    main->>run: For each (batch_size, hidden_size) pair
    run->>torch_bench: Generate logits, measure torch.softmax GPU time
    torch_bench-->>run: torch_time (median)
    run->>fi_bench: Measure flashinfer.sampling.softmax GPU time
    fi_bench-->>run: flashinfer_time (median)
    run->>run: Compute speedup & bandwidth
    run-->>main: Return speedups, batch_sizes, hidden_sizes arrays
    main->>plot: Create heatmap with speedup annotations
    plot->>plot: Generate trend plots (speedup vs. batch, vs. hidden)
    plot-->>main: Save PNG visualizations
    main->>main: Print summary statistics (mean, median, min, max speedups)
```
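The flow above boils down to a few lines per (batch_size, hidden_size) pair. Here is a minimal sketch (the run_one helper is hypothetical; bench_gpu_time and its dry_run_time_ms/repeat_time_ms arguments come from the quoted benchmark code, while the exact flashinfer.sampling.softmax call shape is assumed and plotting is omitted):

```python
import numpy as np
import torch
import flashinfer
from flashinfer.testing.utils import bench_gpu_time


def run_one(batch_size: int, hidden_size: int):
    logits = torch.randn(batch_size, hidden_size, device="cuda")

    # Median GPU time for the PyTorch baseline.
    torch_ms = np.median(
        bench_gpu_time(
            lambda: torch.softmax(logits, dim=-1),
            dry_run_time_ms=100,
            repeat_time_ms=1000,
        )
    )
    # Median GPU time for FlashInfer's softmax (call shape assumed; see
    # benchmarks/bench_softmax.py for the exact invocation).
    fi_ms = np.median(
        bench_gpu_time(
            lambda: flashinfer.sampling.softmax(logits),
            dry_run_time_ms=100,
            repeat_time_ms=1000,
        )
    )

    # Effective bandwidth: one read + one write of the tensor, reported in GB/s
    # (bytes * 1e-6 / ms == GB/s), matching the quoted benchmark code.
    io_bytes = logits.numel() * logits.element_size() * 2
    bandwidth = io_bytes * 1e-6 / fi_ms
    return torch_ms / fi_ms, bandwidth  # speedup (x), GB/s
```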
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Code Review
This pull request improves the performance of several sampling, masking, and softmax CUDA kernels by deferring cross-thread reductions until the end of the vocabulary-chunk loop. This reduces synchronization overhead. The changes are applied consistently across multiple kernels and appear correct. Additionally, new benchmarks have been added for top-p/top-k renorm and top-k mask operations, along with a new comprehensive benchmark script for softmax performance. My review focuses on the new benchmark code, suggesting a refactoring to reduce code duplication.
| print("---") | ||
| print("top-p renorm probs") | ||
| for vocab_size in [128512]: | ||
| for batch_size in [1, 16, 32, 64, 128, 256, 512]: | ||
| torch.manual_seed(42) | ||
| for distrib in [ | ||
| normal_distribution(1), | ||
| normal_distribution(5), | ||
| gumbel_distribution(0.1), | ||
| gumbel_distribution(1), | ||
| ]: | ||
| for p in [0.1, 0.5, 0.9]: | ||
| logits = distrib((batch_size, vocab_size), device="cuda") | ||
| probs = torch.softmax(logits, dim=-1) | ||
| measurements = bench_gpu_time( | ||
| lambda: flashinfer.sampling.top_p_renorm_probs(probs, p), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| io = probs.numel() * probs.element_size() * 2 | ||
| bandwidth = io * 1e-6 / ms | ||
| print( | ||
| f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" | ||
| ) | ||
|
|
||
| print("---") | ||
| print("top-k renorm probs") | ||
| for vocab_size in [128512]: | ||
| for batch_size in [1, 16, 32, 64, 128, 256, 512]: | ||
| torch.manual_seed(42) | ||
| for distrib in [ | ||
| normal_distribution(1), | ||
| normal_distribution(5), | ||
| gumbel_distribution(0.1), | ||
| gumbel_distribution(1), | ||
| ]: | ||
| for k in [10, 100, 1000, 5000]: | ||
| logits = distrib((batch_size, vocab_size), device="cuda") | ||
| probs = torch.softmax(logits, dim=-1) | ||
| measurements = bench_gpu_time( | ||
| lambda: flashinfer.sampling.top_k_renorm_probs(probs, k), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| io = probs.numel() * probs.element_size() * 2 | ||
| bandwidth = io * 1e-6 / ms | ||
| print( | ||
| f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" | ||
| ) | ||
|
|
||
| print("---") | ||
| print("top-k mask logits") | ||
| for vocab_size in [128512]: | ||
| for batch_size in [1, 16, 32, 64, 128, 256, 512]: | ||
| torch.manual_seed(42) | ||
| for distrib in [ | ||
| normal_distribution(1), | ||
| normal_distribution(5), | ||
| gumbel_distribution(0.1), | ||
| gumbel_distribution(1), | ||
| ]: | ||
| for k in [10, 100, 1000, 5000]: | ||
| logits = distrib((batch_size, vocab_size), device="cuda") | ||
| measurements = bench_gpu_time( | ||
| lambda: flashinfer.sampling.top_k_mask_logits(logits, k), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| io = logits.numel() * logits.element_size() * 2 | ||
| bandwidth = io * 1e-6 / ms | ||
| print( | ||
| f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" | ||
| ) | ||
|
|
There is significant code duplication in the newly added benchmark tests. The three benchmark loops for top-p renorm probs, top-k renorm probs, and top-k mask logits are very similar. To improve maintainability and reduce redundancy, you could refactor this into a helper function.
Here's an example of how you could structure the helper function (to be placed before main):
```python
def _run_renorm_benchmark(
    title: str,
    func,
    param_name: str,
    param_values: list,
    use_probs: bool,
):
    print("---")
    print(title)
    for vocab_size in [128512]:
        for batch_size in [1, 16, 32, 64, 128, 256, 512]:
            torch.manual_seed(42)
            for distrib in [
                normal_distribution(1),
                normal_distribution(5),
                gumbel_distribution(0.1),
                gumbel_distribution(1),
            ]:
                for param_val in param_values:
                    logits = distrib((batch_size, vocab_size), device="cuda")
                    if use_probs:
                        tensor_to_bench = torch.softmax(logits, dim=-1)
                    else:
                        tensor_to_bench = logits
                    measurements = bench_gpu_time(
                        lambda: func(tensor_to_bench, param_val),
                        dry_run_time_ms=100,
                        repeat_time_ms=1000,
                    )
                    ms = np.median(measurements)
                    io = tensor_to_bench.numel() * tensor_to_bench.element_size() * 2
                    bandwidth = io * 1e-6 / ms
                    print(
                        f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, {param_name}: {param_val}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s"
                    )
```

You would then call this helper from main() as suggested below.
```python
_run_renorm_benchmark(
    "top-p renorm probs",
    flashinfer.sampling.top_p_renorm_probs,
    "p",
    [0.1, 0.5, 0.9],
    use_probs=True,
)
_run_renorm_benchmark(
    "top-k renorm probs",
    flashinfer.sampling.top_k_renorm_probs,
    "k",
    [10, 100, 1000, 5000],
    use_probs=True,
)
_run_renorm_benchmark(
    "top-k mask logits",
    flashinfer.sampling.top_k_mask_logits,
    "k",
    [10, 100, 1000, 5000],
    use_probs=False,
)
```
Code Review
This pull request introduces performance improvements for sampling, mask, and softmax operators by deferring cross-thread reductions in the CUDA kernels. This is a sound optimization strategy that should reduce synchronization overhead. The changes across the various CUDA kernels are consistent and appear to be correctly implemented. The addition of new benchmark scripts is also a valuable contribution for measuring the impact of these changes. My review includes a couple of suggestions for the new benchmark scripts to improve maintainability and plotting consistency. Overall, this is a solid pull request that focuses on performance.
| print("---") | ||
| print("top-p renorm probs") | ||
| for vocab_size in [128512]: | ||
| for batch_size in [1, 16, 32, 64, 128, 256, 512]: | ||
| torch.manual_seed(42) | ||
| for distrib in [ | ||
| normal_distribution(1), | ||
| normal_distribution(5), | ||
| gumbel_distribution(0.1), | ||
| gumbel_distribution(1), | ||
| ]: | ||
| for p in [0.1, 0.5, 0.9]: | ||
| logits = distrib((batch_size, vocab_size), device="cuda") | ||
| probs = torch.softmax(logits, dim=-1) | ||
| measurements = bench_gpu_time( | ||
| lambda: flashinfer.sampling.top_p_renorm_probs(probs, p), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| io = probs.numel() * probs.element_size() * 2 | ||
| bandwidth = io * 1e-6 / ms | ||
| print( | ||
| f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, p: {p}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" | ||
| ) | ||
|
|
||
| print("---") | ||
| print("top-k renorm probs") | ||
| for vocab_size in [128512]: | ||
| for batch_size in [1, 16, 32, 64, 128, 256, 512]: | ||
| torch.manual_seed(42) | ||
| for distrib in [ | ||
| normal_distribution(1), | ||
| normal_distribution(5), | ||
| gumbel_distribution(0.1), | ||
| gumbel_distribution(1), | ||
| ]: | ||
| for k in [10, 100, 1000, 5000]: | ||
| logits = distrib((batch_size, vocab_size), device="cuda") | ||
| probs = torch.softmax(logits, dim=-1) | ||
| measurements = bench_gpu_time( | ||
| lambda: flashinfer.sampling.top_k_renorm_probs(probs, k), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| io = probs.numel() * probs.element_size() * 2 | ||
| bandwidth = io * 1e-6 / ms | ||
| print( | ||
| f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" | ||
| ) | ||
|
|
||
| print("---") | ||
| print("top-k mask logits") | ||
| for vocab_size in [128512]: | ||
| for batch_size in [1, 16, 32, 64, 128, 256, 512]: | ||
| torch.manual_seed(42) | ||
| for distrib in [ | ||
| normal_distribution(1), | ||
| normal_distribution(5), | ||
| gumbel_distribution(0.1), | ||
| gumbel_distribution(1), | ||
| ]: | ||
| for k in [10, 100, 1000, 5000]: | ||
| logits = distrib((batch_size, vocab_size), device="cuda") | ||
| measurements = bench_gpu_time( | ||
| lambda: flashinfer.sampling.top_k_mask_logits(logits, k), | ||
| dry_run_time_ms=100, | ||
| repeat_time_ms=1000, | ||
| ) | ||
| ms = np.median(measurements) | ||
|
|
||
| io = logits.numel() * logits.element_size() * 2 | ||
| bandwidth = io * 1e-6 / ms | ||
| print( | ||
| f"vocab_size: {vocab_size}, batch_size: {batch_size}, distrib: {distrib.__name__}, k: {k}, duration: {ms * 1e3:.2f} us, effective bandwidth: {bandwidth:.2f} GB/s" | ||
| ) | ||
|
|
The three new benchmark sections for top-p renorm probs, top-k renorm probs, and top-k mask logits contain significant code duplication. To improve maintainability and readability, consider refactoring this logic into a single helper function. This function could be parameterized by the operation name, the function to benchmark, the parameter name ('p' or 'k'), the parameter values, and whether the function takes logits or probabilities as input.
```python
ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold")
ax1.grid(True, alpha=0.3)
ax1.legend(fontsize=9)
ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)
```
For consistency with the second subplot (ax2), consider adding a label to this axhline call. This will ensure the 'No speedup' line is labeled in the legend for both plots.
```diff
- ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)
+ ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
```
Actionable comments posted: 1
🧹 Nitpick comments (2)
benchmarks/bench_sampling.py (1)
238-299: Bind loop variables when building the benchmark lambdas

Ruff's B023 warning here is legitimate: the lambdas capture probs, p, and k by reference, so a future change that defers execution (or reuses the callable after the loop advances) would end up timing the wrong tensors/thresholds. Binding the current values via default arguments silences the warning and makes the benchmarks future-proof. Please apply the following update in all three loops.

```diff
- lambda: flashinfer.sampling.top_p_renorm_probs(probs, p),
+ lambda probs=probs, p=p: flashinfer.sampling.top_p_renorm_probs(probs, p),
  ...
- lambda: flashinfer.sampling.top_k_renorm_probs(probs, k),
+ lambda probs=probs, k=k: flashinfer.sampling.top_k_renorm_probs(probs, k),
  ...
- lambda: flashinfer.sampling.top_k_mask_logits(logits, k),
+ lambda logits=logits, k=k: flashinfer.sampling.top_k_mask_logits(logits, k),
```

benchmarks/bench_softmax.py (1)

105-121: Silence the unused fig binding

fig isn't used after plt.subplots, which raises RUF059. Prefix it with an underscore (or drop it) to keep linters quiet.

```diff
- fig, ax = plt.subplots(figsize=(12, 8))
+ _fig, ax = plt.subplots(figsize=(12, 8))
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- benchmarks/bench_sampling.py (1 hunks)
- benchmarks/bench_softmax.py (1 hunks)
- include/flashinfer/sampling.cuh (18 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
benchmarks/bench_sampling.py (3)
- benchmarks/bench_renorm.py (2): normal_distribution (8-13), gumbel_distribution (16-23)
- flashinfer/sampling.py (8): softmax (52-72), softmax (505-559), top_p_renorm_probs (325-341), top_p_renorm_probs (1166-1226), top_k_renorm_probs (354-368), top_k_renorm_probs (1232-1291), top_k_mask_logits (381-395), top_k_mask_logits (1297-1351)
- flashinfer/testing/utils.py (1): bench_gpu_time (972-1033)

benchmarks/bench_softmax.py (1)
- flashinfer/testing/utils.py (1): bench_gpu_time (972-1033)
🪛 Ruff (0.14.3)
benchmarks/bench_sampling.py
238-238: Function definition does not bind loop variable probs
(B023)
238-238: Function definition does not bind loop variable p
(B023)
265-265: Function definition does not bind loop variable probs
(B023)
265-265: Function definition does not bind loop variable k
(B023)
291-291: Function definition does not bind loop variable logits
(B023)
291-291: Function definition does not bind loop variable k
(B023)
benchmarks/bench_softmax.py
105-105: Unpacked variable fig is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
benchmarks/bench_softmax.py (quoted for review):

```python
for j, hidden_size in enumerate(hidden_sizes):
    ax2.plot(
        batch_sizes,
        speedups[:, j],
        marker="o",
        label=f"Hidden={hidden_size // 1000}K",
        linewidth=2,
    )

ax2.set_xlabel("Batch Size", fontsize=12, fontweight="bold")
ax2.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
ax2.set_title("Speedup vs Batch Size", fontsize=13, fontweight="bold")
ax2.set_xscale("log", base=2)
ax2.grid(True, alpha=0.3)
ax2.legend(fontsize=9)
ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")

# Plot 1: Speedup trends across hidden sizes
for i, batch_size in enumerate(batch_sizes[::2]):  # Sample every other batch size
    idx = i * 2
    ax1.plot(
        [h // 1000 for h in hidden_sizes],
        speedups[idx, :],
        marker="s",
        label=f"Batch={batch_size}",
        linewidth=2,
    )

ax1.set_xlabel("Hidden Size (K)", fontsize=12, fontweight="bold")
ax1.set_ylabel("Speedup (x)", fontsize=12, fontweight="bold")
ax1.set_title("Speedup vs Hidden Size", fontsize=13, fontweight="bold")
ax1.grid(True, alpha=0.3)
ax1.legend(fontsize=9)
ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)

plt.tight_layout()
comparison_path = save_path.replace(".png", "_trends.png")
plt.savefig(comparison_path, dpi=300, bbox_inches="tight")
```
Ensure the “No speedup” guide lines appear in the legends
ax2.legend(...) (and ax1.legend(...)) is called before each axhline, so the horizontal “No speedup” lines never reach the legend despite the label argument. Move the legend calls after the guide lines (or call legend again afterwards) so the reference line is actually shown.
```diff
- ax2.legend(fontsize=9)
- ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
+ ax2.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
+ ax2.legend(fontsize=9)
  ...
- ax1.legend(fontsize=9)
- ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5)
+ ax1.axhline(y=1.0, color="red", linestyle="--", alpha=0.5, label="No speedup")
+ ax1.legend(fontsize=9)
```

🤖 Prompt for AI Agents
+ ax1.legend(fontsize=9)🤖 Prompt for AI Agents
In benchmarks/bench_softmax.py around lines 139 to 176 the horizontal "No
speedup" axhline calls are placed after ax2.legend(...) and ax1.legend(...), so
their labeled guide lines don't appear in the legends; fix by moving the
legend() calls to after the axhline(...) calls (or call legend() again after
adding the axhline), i.e., ensure each axis adds its axhline before creating the
legend so the guide line with label="No speedup" is included.
/bot run

[SUCCESS] Pipeline #37995303: 13/17 passed
bkryu left a comment
LGTM. Unit tests are passing, and it looks like the reduced number of synchronizations is applied correctly throughout.
Nothing to do for this PR, but I was checking whether there are any further optimizations to make and came up with #2058 as a follow-up for completeness.
<!-- .github/pull_request_template.md --> ## 📌 Description Apply optimizations similar to #2044 to max/min functions. <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> #2044 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Performance Improvements** * Improved sampling performance by reducing per-iteration synchronization and temporary storage, deferring aggregate reductions until after iterative work completes. This lowers runtime overhead and memory churn, yielding faster and more efficient processing for sampling operations. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: yzh119 <[email protected]>
📌 Description
This is the first part of the performance-improvement work for the sampling/mask/softmax operators. In this PR, we defer the cross-thread reduction until the end of the loop (similar to how FA2 handles the softmax denominator) to reduce the number of shuffle and thread-sync instructions.
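To sketch why deferring the reduction is numerically safe, here is a small NumPy illustration of the idea (an illustration only, not the CUDA kernel in include/flashinfer/sampling.cuh): each lane keeps its own running (max, sum) pair across chunks, and one cross-lane combine at the end yields the same softmax denominator that per-chunk cross-thread reductions would have produced.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, lanes, chunk = 4096, 32, 32              # toy sizes; one "lane" models one thread
x = rng.normal(size=vocab).astype(np.float32)

# Per-lane online-softmax state, updated independently (no cross-lane sync in the loop).
m = np.full(lanes, -np.inf, dtype=np.float64)   # running max per lane
d = np.zeros(lanes, dtype=np.float64)           # running sum of exp(x - m) per lane

for start in range(0, vocab, lanes * chunk):
    tile = x[start:start + lanes * chunk].reshape(lanes, chunk)
    m_new = np.maximum(m, tile.max(axis=1))
    d = d * np.exp(m - m_new) + np.exp(tile - m_new[:, None]).sum(axis=1)
    m = m_new

# Single deferred cross-lane reduction, FA2-style: combine the (max, sum) pairs once.
m_all = m.max()                                  # equals the global max of x
denom = (d * np.exp(m - m_all)).sum()            # combined softmax denominator

# Reference denominator from a plain global reduction.
ref = np.exp(x.astype(np.float64) - x.max()).sum()
assert np.isclose(denom, ref)
```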
For the second part of the PR, we will implement the Radix TopK algorithm to improve top-k mask logits performance when K is small.
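For reference, radix top-k selection amounts to a most-significant-digit-first descent over order-preserving integer keys. A small NumPy sketch of that selection step follows (an illustration only; the part-2 CUDA kernel may differ):

```python
import numpy as np


def float_to_ordered_key(x: np.ndarray) -> np.ndarray:
    """Map float32 to uint32 keys whose unsigned order matches the float order."""
    u = x.astype(np.float32).view(np.uint32)
    neg = (u & np.uint32(0x80000000)) != 0
    return np.where(neg, ~u, u | np.uint32(0x80000000)).astype(np.uint32)


def kth_largest(x: np.ndarray, k: int) -> float:
    """Radix-select the k-th largest value via MSB-first passes over 8-bit digits."""
    vals, keys = x, float_to_ordered_key(x)
    remaining = k
    for shift in (24, 16, 8, 0):
        digits = (keys >> np.uint32(shift)) & np.uint32(0xFF)
        hist = np.bincount(digits.astype(np.int64), minlength=256)
        for d in range(255, -1, -1):        # walk buckets from the largest digit down
            if remaining <= hist[d]:
                keep = digits == d           # the k-th largest lives in this bucket
                vals, keys = vals[keep], keys[keep]
                break
            remaining -= hist[d]
    return float(vals[0])                    # survivors all share the full 32-bit key


logits = np.random.default_rng(0).normal(size=128512).astype(np.float32)
k = 100
assert kth_largest(logits, k) == np.sort(logits)[-k]
# Masking would then keep logits >= this threshold and set the rest to -inf.
```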
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Tests
Chores