
Allow BatchDecodeWithPagedKVCacheWrapper for GQA ratio 16 and 32#2895

Open
bkryu wants to merge 2 commits into flashinfer-ai:main from bkryu:fa2_large_grp_sz

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Mar 25, 2026

📌 Description

  • Add group sizes 16 and 32 to DISPATCH_GQA_GROUP_SIZE in include/flashinfer/utils.cuh, fixing the BatchDecodeWithPagedKVCacheWrapper crash for models with high GQA ratios (e.g., Nemotron: 32 QO heads / 2 KV heads = ratio 16)
  • Add num_kv_heads=2 to batch decode test parametrization to cover GQA ratio 16

🔍 Related Issues

#2849

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added support for additional attention group sizes (16 and 32), enabling broader runtime configuration.
  • Tests

    • Expanded batch decode kernel tests to cover more key-value head combinations (now testing 2 and 4 heads) to improve reliability across configurations.

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the FlashInfer library by extending support for Grouped Query Attention (GQA) to a wider range of group sizes. This change specifically addresses and prevents crashes in batch decode operations for models with high GQA ratios, ensuring greater compatibility and stability. Additionally, the associated tests have been updated to cover these new configurations, reinforcing the robustness of the implementation.

Highlights

  • GQA Group Size Expansion: Expanded the DISPATCH_GQA_GROUP_SIZE macro in include/flashinfer/utils.cuh to support group sizes 16, 32, and 64.
  • Crash Fix: Resolved a crash in BatchDecodeWithPagedKVCacheWrapper that occurred with models having high GQA ratios, such as Nemotron 32 QO / 2 KV (ratio 16).
  • Test Coverage Improvement: Added num_kv_heads=2 to the batch decode test parametrizations to ensure comprehensive coverage for GQA ratio 16 scenarios.



@bkryu
Collaborator Author

bkryu commented Mar 25, 2026

/bot run

@coderabbitai
Contributor

coderabbitai bot commented Mar 25, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ad4e64e3-5a41-481f-9ee5-ef979522ccaf

📥 Commits

Reviewing files that changed from the base of the PR and between cb8f3d0 and afacf80.

📒 Files selected for processing (1)
  • include/flashinfer/utils.cuh

📝 Walkthrough

Walkthrough

Added compile-time dispatch branches for GQA group sizes 16 and 32 in a CUDA header; updated four decode kernel tests to parametrize num_kv_heads with [2, 4] instead of [4]. (No public API changes.)

Changes

Cohort / File(s) — Summary

  • GQA Macro Extension (include/flashinfer/utils.cuh): Added DISPATCH_GQA_GROUP_SIZE branches for group_size == 16 and group_size == 32, defining constexpr size_t GROUP_SIZE accordingly and expanding __VA_ARGS__.
  • Test Parameter Coverage (tests/attention/test_batch_decode_kernels.py): Changed four @pytest.mark.parametrize("num_kv_heads", ...) entries from [4] to [2, 4], widening the KV-head parameter sweep in paged KV decode tests.
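The test-side change can be sketched roughly as below; the test name and the fixed 32 QO heads are illustrative assumptions, not the actual test code:

```python
import pytest


def gqa_ratio(num_qo_heads: int, num_kv_heads: int) -> int:
    # The "group size" the dispatch macro sees: QO heads per KV head.
    assert num_qo_heads % num_kv_heads == 0
    return num_qo_heads // num_kv_heads


# Widening num_kv_heads from [4] to [2, 4]: with 32 QO heads (Nemotron-like),
# the sweep now exercises group sizes 16 and 8 rather than only 8.
@pytest.mark.parametrize("num_kv_heads", [2, 4])  # previously [4]
def test_covers_group_size_16(num_kv_heads):
    assert gqa_ratio(32, num_kv_heads) in (8, 16)
```

With the old [4]-only parametrization, the newly added group-size-16 branch was never reached by these tests.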

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested labels

v0.6.2

Suggested reviewers

  • cyx-6
  • nvmbreughe
  • jimmyzho
  • yzh119

Poem

🐰 I twitch my whiskers, twiddle my paws,

Two new group sizes hop into the cause.
Tests now probe both two and four heads true,
Compile-time branches bloom—hippity-hoo!

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check — ✅ Passed: The title clearly and specifically describes the main change: enabling support for GQA ratios 16 and 32 in BatchDecodeWithPagedKVCacheWrapper.
  • Description check — ✅ Passed: The description includes concrete details about the DISPATCH_GQA_GROUP_SIZE macro changes and test parametrization and references issue #2849, but has incomplete checklist items that may indicate pending work.
  • Docstring Coverage — ✅ Passed: No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request expands the supported group sizes in flashinfer/utils.cuh to include 16, 32, and 64, in addition to the existing 8. It also enhances test coverage for batch decode kernels by adding num_kv_heads = 2 to the parameterized tests in test_batch_decode_kernels.py. There is no feedback to provide.

@flashinfer-bot
Collaborator

GitLab MR !463 has been created, and the CI pipeline #47010006 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@include/flashinfer/utils.cuh`:
- Around line 153-161: The GROUP_SIZE=64 branch can produce illegal thread-block
sizes for some HEAD_DIM values; add a guard that rejects invalid
GROUP_SIZE/HEAD_DIM combos by inserting a compile-time static_assert in the
kernel template (where GROUP_SIZE and HEAD_DIM are template parameters) to
validate (e.g. compute threads_per_block from HEAD_DIM and GROUP_SIZE and assert
<= 1024), and also add a runtime check in the kernel launcher to return/error if
the chosen GROUP_SIZE (from the macro that defines GROUP_SIZE) together with the
provided HEAD_DIM would create threads_per_block > 1024; reference GROUP_SIZE
and HEAD_DIM (and the macro block in include/flashinfer/utils.cuh) so the check
is colocated with the branch that sets GROUP_SIZE and the kernel launch path.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ea2e2b74-041b-45e2-8b56-7debd67de09e

📥 Commits

Reviewing files that changed from the base of the PR and between e82d33d and cb8f3d0.

📒 Files selected for processing (2)
  • include/flashinfer/utils.cuh
  • tests/attention/test_batch_decode_kernels.py

@bkryu
Collaborator Author

bkryu commented Mar 25, 2026

/bot stop

@bkryu
Collaborator Author

bkryu commented Mar 25, 2026

/bot run

@bkryu bkryu changed the title Fix CUDA-cores batch decode for GQA ratio >= 16 Allow BatchDecodeWithPagedKVCacheWrapper for GQA ratio 16 and 32 Mar 25, 2026
@flashinfer-bot
Collaborator

The GitLab CI pipeline #47010006 has been cancelled.

@flashinfer-bot
Collaborator

GitLab MR !463 has been updated with latest changes, and the CI pipeline #47010961 is currently running. I'll report back once the pipeline job completes.

@bkryu bkryu self-assigned this Mar 25, 2026
@flashinfer-bot
Collaborator

[FAILED] Pipeline #47010961: 11/20 passed

Collaborator

@yzh119 yzh119 left a comment


It's recommended to enable tensor cores (which do not rely on this macro) for large GQA shapes; when group size = 16/32, the CUDA-cores implementation is very slow, see #2684 (review).

@bkryu
Collaborator Author

bkryu commented Mar 27, 2026

It's recommended to enable tensor cores (which do not rely on this macro) for large GQA shapes; when group size = 16/32, the CUDA-cores implementation is very slow, see #2684 (review).

Right. I left a comment about this in #2849: on SM121, use_tensor_cores=True, or even cudnn_batch_decode_with_kv_cache or trtllm_batch_decode_with_kv_cache, should be more performant, and I'd like to know whether there are issues with using one of them.

Waiting for a response from @TrevorS who filed the issue.
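As a hedged illustration of the advice above: use_tensor_cores is an existing flag on the wrapper constructor, while the routing helper and its >= 16 threshold below are assumptions for illustration, based on the reviewer's note that the CUDA-cores path is very slow at group sizes 16/32:

```python
def prefer_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    """Route large GQA ratios to the tensor-core decode path.

    The >= 16 threshold is an assumption, not a flashinfer heuristic.
    """
    assert num_qo_heads % num_kv_heads == 0
    return num_qo_heads // num_kv_heads >= 16


# Wrapper construction (requires a CUDA device, so shown as a sketch only):
#
#   import torch, flashinfer
#   workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
#   wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
#       workspace, "NHD",
#       use_tensor_cores=prefer_tensor_cores(32, 2),  # Nemotron-like, ratio 16
#   )
```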


3 participants