Add GQA group_size 5, 6, 7 to DISPATCH_GQA_GROUP_SIZE#2986
arbi-dev wants to merge 1 commit into flashinfer-ai:main
Conversation
The macro only dispatched group sizes 1, 2, 3, 4, 8 — any other value
hit a runtime error ("Unsupported group_size"). This breaks several
popular models with non-power-of-2 GQA ratios:
- group_size 6: Qwen3.5-27B (24Q/4KV), InternLM2.5-20B (48Q/8KV)
- group_size 7: Qwen2.5-7B (28Q/4KV), Yi-1.5-34B (56Q/8KV)
Add explicit constexpr cases for 5, 6, and 7 so all group sizes 1-8
are supported. Each adds one template instantiation per call site.
The error manifests as:
RuntimeError: Unsupported group_size: 6
when calling BatchDecodeWithPagedKVCache or similar kernel dispatch
paths that go through DISPATCH_GQA_GROUP_SIZE.
Code Review
This pull request expands the DISPATCH_GQA_GROUP_SIZE macro in include/flashinfer/utils.cuh to include support for group sizes 5, 6, and 7. The review feedback suggests refactoring the macro's if-else if chain into a switch statement to improve evaluation efficiency and maintain consistency with other dispatch macros in the project.
```cpp
} else if (group_size == 5) {        \
  constexpr size_t GROUP_SIZE = 5;   \
  __VA_ARGS__                        \
} else if (group_size == 6) {        \
  constexpr size_t GROUP_SIZE = 6;   \
  __VA_ARGS__                        \
} else if (group_size == 7) {        \
  constexpr size_t GROUP_SIZE = 7;   \
  __VA_ARGS__                        \
```
With the addition of more group sizes, the if-else if chain in DISPATCH_GQA_GROUP_SIZE is becoming increasingly long. Consider refactoring the macro to use a switch statement. This would ensure the group_size expression is evaluated only once and would improve consistency with other dispatch macros in this file (such as DISPATCH_CTA_TILE_Q and DISPATCH_HEAD_DIM) that already use switch for exact value matching.
Hi @arbi-dev, can you see my comments in #2684 (review)?
Summary
DISPATCH_GQA_GROUP_SIZE only handles group sizes 1, 2, 3, 4, and 8. Any other value hits a runtime error:

RuntimeError: Unsupported group_size: 6

This breaks several popular models with non-power-of-2 GQA ratios:
- group_size 6: Qwen3.5-27B (24Q/4KV), InternLM2.5-20B (48Q/8KV)
- group_size 7: Qwen2.5-7B (28Q/4KV), Yi-1.5-34B (56Q/8KV)
This PR adds explicit constexpr cases for group sizes 5, 6, and 7, so all sizes 1-8 are supported. Each adds one template instantiation per call site, matching the existing dispatch pattern.

Why this hasn't been reported widely
Most users access FlashInfer through vLLM or SGLang, which use BatchDecodeWithPagedKVCacheWrapper. That wrapper handles GQA at the Python level and doesn't go through DISPATCH_GQA_GROUP_SIZE. The error only manifests when calling the lower-level C++ kernel dispatch directly (e.g., custom attention backends or quantized KV cache implementations that bypass the Python wrapper).

Test plan
AI-assisted: Claude Opus 4.6 assisted with code generation and model survey.