
Conversation

@Edenzzzz Edenzzzz (Contributor) commented Oct 30, 2025

📌 Description

Use realistic head sizes and sequence lengths, and add a comparison with sequential prefill + decode.
Results on H100 (without overlap, which only adds ~150 GB/s for the persistent kernel):
[benchmark results screenshot]
cc @yzh119

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

Release Notes

  • New Features

    • Added comprehensive performance benchmarking for batch attention operations with detailed timing measurements.
    • Introduced sequential dual-kernel benchmark path with extended memory bandwidth reporting.
  • Tests

    • Updated benchmark test configurations to use deterministic, fixed values for improved reproducibility.
    • Adjusted benchmark parameters for consistency across test iterations.

@coderabbitai coderabbitai bot (Contributor) commented Oct 30, 2025

Walkthrough

The benchmark file adds new persistent BatchAttention and sequential two-kernel benchmark paths with dedicated timing measurements. Randomized test fixture generation is replaced with fixed deterministic configuration sets. Benchmark outputs are extended to report timing and bandwidth calculations for the new paths. Configuration parameters are updated: num_kv_heads increased from 4 to 8 and num_qo_heads from 28 to 32.

Changes

Benchmark Infrastructure and Configuration Updates (benchmarks/bench_mixed_attention.py):
Added persistent BatchAttention workflow with timing measurements; introduced a sequential two-kernel benchmark path (single_prefill_with_kv_cache followed by BatchDecodeWithPagedKVCacheWrapper); extended elapsed-time and bandwidth reporting for the new paths; replaced randomized test fixture generation with fixed deterministic configuration sets; updated public benchmark parameters (num_kv_heads: 4→8, num_qo_heads: 28→32).
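
A minimal sketch of how the two added paths are timed, assuming the wrappers have already been planned as in the benchmark and that bench_gpu_time returns a list of per-run times in milliseconds. Names follow the snippets quoted in the reviews below; this is illustrative, not the exact diff.

import numpy as np
import flashinfer


def time_new_paths(wrapper_persistent, wrapper_decode, bench_gpu_time,
                   q, kv_data, q_p, k_p, v_p, q_d, kv_d, causal=True):
    """Return (ms_persistent, ms_seq_two_kernels) for the two new benchmark paths."""
    # Persistent BatchAttention: warm up once, then time repeated runs.
    wrapper_persistent.run(q, kv_data)
    ms_persistent = np.median(bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data)))

    # Sequential two kernels: single prefill followed by tensor-core batch decode,
    # each timed on its own and summed (one of the review comments below suggests
    # timing the pair inside a single callable instead).
    def run_prefill():
        return flashinfer.prefill.single_prefill_with_kv_cache(
            q_p, k_p, v_p, causal=causal, pos_encoding_mode="NONE", backend="fa2"
        )

    ms_prefill = np.median(bench_gpu_time(run_prefill))
    ms_decode = np.median(bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d)))
    return ms_persistent, ms_prefill + ms_decode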

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Attention required:
    • Verification that new persistent and sequential two-kernel measurement paths correctly capture timing and integrate with existing benchmark infrastructure
    • Validation of the deterministic configuration replacement logic and confirmation that removed randomization doesn't compromise test coverage
    • Review of updated parameter values (num_kv_heads and num_qo_heads changes) and their impact on benchmark representativeness
    • Correctness of bandwidth calculations for the new benchmark paths (a generic sketch of this calculation follows after this list)
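
As a generic reference for that last check, benchmark bandwidth figures of this kind are usually computed as bytes touched divided by elapsed time. The sketch below assumes a simple accounting (query, paged KV cache, and output counted once each), which may differ from the exact formula used in the benchmark.

def attention_bandwidth_gbps(q, kv_data, o, elapsed_ms):
    # Rough achieved-bandwidth estimate in GB/s: bytes read/written over elapsed time.
    total_bytes = (
        q.numel() * q.element_size()
        + kv_data.numel() * kv_data.element_size()
        + o.numel() * o.element_size()
    )
    return total_bytes / (elapsed_ms * 1e-3) / 1e9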

Suggested reviewers

  • yzh119

Poem

🐰 Deterministic hops through benchmark lanes,
Fixed configs now, no random refrains,
Two kernels dancing in sequential time,
Persistent workflows and metrics sublime,
Measurements refined, the benchmarks now shine!

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title Check ✅ Passed: The title "More realistic bench for POD Attn" is directly related to the changeset's primary objective. The raw summary confirms that the changes involve using real head sizes and sequence lengths for benchmarks and adding a sequential prefill + decode comparison path. While the title is concise and doesn't enumerate all details (like the sequential comparison), it clearly captures the main intent of making the benchmark more realistic. The title is specific enough that a teammate scanning history would understand the benchmark was updated for greater realism, and per the guidelines, a title doesn't need to cover every detail of the changeset.
  • Description Check ✅ Passed: The pull request description follows the repository's template structure with all major sections present: a Description section that explains the changes (using real head sizes, seq lens, and adding sequential prefill + decode comparison), a Related Issues section (though empty), and the Pull Request Checklist framework. The Description section, while concise, is specific and substantive, supported by an attached benchmark results image from H100. Although the Related Issues section lacks entries and the checklist items remain unchecked, the core description content is clear and complete enough to convey what the PR accomplishes and why.

Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist gemini-code-assist bot (Contributor) commented

Summary of Changes

Hello @Edenzzzz, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the realism and scope of attention mechanism benchmarks. It updates the test configurations to reflect more practical head sizes and sequence lengths, and crucially, adds new performance comparisons against Persistent BatchAttention and a Sequential two-kernel prefill-decode strategy. This provides a more thorough and relevant performance analysis for different attention implementations, aiming to better represent real-world usage on hardware like the H100.

Highlights

  • Benchmarking Realistic Scenarios: Updated attention benchmarks to use more realistic head sizes and sequence lengths, moving away from dynamically generated irregular test cases to fixed, representative configurations.
  • New Performance Comparisons: Introduced benchmarks for Persistent BatchAttention and a Sequential two-kernel approach (single prefill + batch decode) to provide a more comprehensive performance evaluation.
  • Configuration Updates: Increased num_kv_heads from 4 to 8 and num_qo_heads from 28 to 32 to simulate larger model configurations in the benchmarks.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review: /gemini review performs a code review for the current pull request in its current state.
  • Pull Request Summary: /gemini summary provides a summary of the current pull request in its current state.
  • Comment: @gemini-code-assist responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help: /gemini help displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the mixed attention benchmark to use more realistic head sizes and sequence lengths, and adds comparisons with a persistent BatchAttention kernel and a sequential prefill-decode implementation. The changes are a good step towards more representative benchmarking. I have a couple of suggestions to improve consistency and maintainability in the benchmark code.

    )
    o_persistent, _ = wrapper_persistent.run(q, kv_data)
    measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
    ms_persistent = np.mean(measurements_persistent)

medium

For consistency with the other measurements in this benchmark, it's better to use np.median instead of np.mean. np.median is more robust to outliers, which can be common in performance measurements.

Suggested change
ms_persistent = np.mean(measurements_persistent)
ms_persistent = np.median(measurements_persistent)
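
As a quick illustration of the difference (hypothetical timing values, not taken from this benchmark):

import numpy as np

# A single slow iteration (e.g. a clock-frequency dip) skews the mean but not the median.
timings_ms = [1.02, 1.01, 1.03, 1.02, 5.40]
print(np.mean(timings_ms))    # ~1.90 ms, dominated by the outlier
print(np.median(timings_ms))  # 1.02 ms, robust to the outlier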

            num_kv_heads,
            head_dim,
            page_block_size,
            data_type=torch.bfloat16,

medium

The data_type parameter in BatchDecodeWithPagedKVCacheWrapper.plan is deprecated. Please use kv_data_type instead for clarity and to avoid using deprecated APIs.

Suggested change
data_type=torch.bfloat16,
kv_data_type=torch.bfloat16,
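
For context, a sketch of the full plan() call with the suggested keyword applied; the remaining arguments follow the benchmark code quoted later in this thread, and this is illustrative rather than a committed change:

wrapper_decode.plan(
    d_kv_indptr.to(device),
    kv_indices_d.to(device),
    last_page_len_d,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_block_size,
    kv_data_type=torch.bfloat16,
    q_data_type=torch.bfloat16,
)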


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b9287c9 and c5ef2ca.

📒 Files selected for processing (1)
  • benchmarks/bench_mixed_attention.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
benchmarks/bench_mixed_attention.py (3)
flashinfer/attention.py (1)
  • BatchAttention (42-198)
flashinfer/decode.py (9)
  • plan (810-1102)
  • plan (1603-1726)
  • run (1132-1145)
  • run (1148-1161)
  • run (1163-1374)
  • run (1728-1852)
  • BatchDecodeWithPagedKVCacheWrapper (581-1410)
  • use_tensor_cores (779-780)
  • use_tensor_cores (1576-1577)
flashinfer/prefill.py (11)
  • plan (1523-1919)
  • plan (2489-2777)
  • run (1950-1962)
  • run (1965-1977)
  • run (1979-2206)
  • run (2807-2817)
  • run (2820-2830)
  • run (2832-2978)
  • single_prefill_with_kv_cache (911-932)
  • single_prefill_with_kv_cache (936-957)
  • single_prefill_with_kv_cache (960-1195)
🪛 Ruff (0.14.2)
benchmarks/bench_mixed_attention.py

90-90: Unpacked variable o_persistent is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

Comment on lines +90 to +92
    o_persistent, _ = wrapper_persistent.run(q, kv_data)
    measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
    ms_persistent = np.mean(measurements_persistent)

⚠️ Potential issue | 🟡 Minor

Drop unused persistent output.

Line 90 binds o_persistent, but the value is never read and Ruff emits RUF059. Please discard the binding (for example, call wrapper_persistent.run(q, kv_data) without assignment or bind to _) so the warm-up still happens without leaving an unused variable.

-    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    wrapper_persistent.run(q, kv_data)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (proposed replacement for the quoted lines above):
    wrapper_persistent.run(q, kv_data)
    measurements_persistent = bench_gpu_time(lambda: wrapper_persistent.run(q, kv_data))
    ms_persistent = np.mean(measurements_persistent)
🧰 Tools
🪛 Ruff (0.14.2)

90-90: Unpacked variable o_persistent is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In benchmarks/bench_mixed_attention.py around lines 90 to 92, the first call
assigns o_persistent which is never used (RUF059); remove the unused variable by
calling wrapper_persistent.run(q, kv_data) without assignment or assign the
result to _ so the warm-up call still executes but no unused binding remains.

Comment on lines +145 to +178
        # Sequential two kernels: single prefill + batch decode (tensor cores)
        # Prefill using single_prefill_with_kv_cache
        def _run_single_prefill():
            return flashinfer.prefill.single_prefill_with_kv_cache(
                q_p,
                k_p,
                v_p,
                causal=causal,
                pos_encoding_mode="NONE",
                backend="fa2",
            )

        measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
        ms_prefill = np.median(measurements_prefill)

        # Batch decode using tensor cores
        wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
            workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
        )
        wrapper_decode.plan(
            d_kv_indptr.to(device),
            kv_indices_d.to(device),
            last_page_len_d,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_block_size,
            data_type=torch.bfloat16,
            q_data_type=torch.bfloat16,
        )
        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
        ms_decode = np.median(measurements_decode)
        ms_seq_two_kernels = ms_prefill + ms_decode

⚠️ Potential issue | 🟠 Major

Measure sequential path in one benchmarked call.

Lines 158-177 derive ms_seq_two_kernels by summing medians from two completely separate benchmark runs. Because bench_gpu_time synchronizes around each callable, that sum omits the synchronization gap between kernels and hides any stream/data dependency penalties when prefill hands off to decode. As a result, the reported “Sequential two kernels” latency is optimistic and not directly comparable to the single-call POD/persistent timings. Benchmark the sequential path inside a single callable and use that median instead so the printed number reflects the real pipeline cost.

         measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
         ms_prefill = np.median(measurements_prefill)
 
         # Batch decode using tensor cores
         wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
             workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
         )
@@
         )
-        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
-        ms_decode = np.median(measurements_decode)
-        ms_seq_two_kernels = ms_prefill + ms_decode
+        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
+        ms_decode = np.median(measurements_decode)
+
+        def _run_prefill_and_decode():
+            _run_single_prefill()
+            return wrapper_decode.run(q_d, kv_d)
+
+        measurements_seq = bench_gpu_time(_run_prefill_and_decode)
+        ms_seq_two_kernels = np.median(measurements_seq)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (proposed replacement for the quoted block above, with the sequential path timed as a single callable):

        # Sequential two kernels: single prefill + batch decode (tensor cores)
        # Prefill using single_prefill_with_kv_cache
        def _run_single_prefill():
            return flashinfer.prefill.single_prefill_with_kv_cache(
                q_p,
                k_p,
                v_p,
                causal=causal,
                pos_encoding_mode="NONE",
                backend="fa2",
            )

        measurements_prefill = bench_gpu_time(lambda: _run_single_prefill())
        ms_prefill = np.median(measurements_prefill)

        # Batch decode using tensor cores
        wrapper_decode = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
            workspace_buffer, kv_layout=kv_layout, use_tensor_cores=True
        )
        wrapper_decode.plan(
            d_kv_indptr.to(device),
            kv_indices_d.to(device),
            last_page_len_d,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_block_size,
            data_type=torch.bfloat16,
            q_data_type=torch.bfloat16,
        )
        measurements_decode = bench_gpu_time(lambda: wrapper_decode.run(q_d, kv_d))
        ms_decode = np.median(measurements_decode)

        def _run_prefill_and_decode():
            _run_single_prefill()
            return wrapper_decode.run(q_d, kv_d)

        measurements_seq = bench_gpu_time(_run_prefill_and_decode)
        ms_seq_two_kernels = np.median(measurements_seq)
🤖 Prompt for AI Agents
In benchmarks/bench_mixed_attention.py around lines 145 to 178, the sequential
two-kernel latency is computed by summing medians from two separate
bench_gpu_time runs (prefill and decode), which omits inter-kernel
synchronization and handoff cost; instead, wrap the whole sequential sequence
(call single_prefill_with_kv_cache followed immediately by wrapper_decode.run)
in a single callable passed to bench_gpu_time so the synchronization overhead
between kernels is measured, take the median of that single measurement as
ms_seq_two_kernels, and use that value wherever the combined sequential latency
is reported.

@yzh119 yzh119 (Collaborator) left a comment


I suppose the benefit of POD mainly comes from overlapping?

LGTM overall. We will revamp the OSS attention code in the coming release; let's check the performance again then.

@yzh119 yzh119 enabled auto-merge (squash) October 30, 2025 21:23
@Edenzzzz Edenzzzz (Contributor, Author) commented Oct 30, 2025

I suppose the benefit of POD mainly comes from overlapping?

LGTM overall. We will revamp the OSS attention code in the coming release; let's check the performance again then.

Yes, though when I tested the overlapped persistent kernel, the benefit was not as significant as in POD, which is confusing because it should achieve statically planned overlap rather than opportunistic overlap, and its block launch/warm-up overhead should be lower.
I'll put together a benchmark/issue against FA3 later.
[two attached benchmark screenshots]

@yzh119 yzh119 disabled auto-merge October 31, 2025 06:48
@yzh119 yzh119 merged commit 1181c5d into flashinfer-ai:main Oct 31, 2025
4 checks passed
@Edenzzzz Edenzzzz deleted the upd_pod_bench branch October 31, 2025 15:23
@Edenzzzz Edenzzzz (Contributor, Author) commented:
@yzh119 NCU profiling results show that POD has less branching and higher memory throughput.
[three NCU profiling screenshots attached; persistent_vs_pod.ncu-rep.zip]
