fix: ensure each CTA processes full numHeadsQPerKv for trtllm decode kernel #2380
📝 Walkthrough
A filtering condition in the GQA tile-size heuristic is tightened: candidate tileSizeQ values must satisfy both tileSizeQ ≤ defaultTileSizeQ and tileSizeQ ≥ mNumHeadsQPerKv, changing which candidates are evaluated during tile-size selection.
e4e6234 to 728252e
Hi @dongjiyingdjy, would you mind explaining the context of this pull request in the PR description?
/bot run
This PR ensures that all Q heads within the same group are processed by the same CTA. The previous tile-selection strategy did not account for this, so the Q heads of a single group could be split across multiple CTAs, leading to incorrect results.
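To make the splitting concrete, here is a minimal sketch. It assumes that one CTA covers tileSizeQ consecutive Q heads of a group; that mapping is an assumption made for illustration, and ctasSpannedByGroup is a hypothetical helper, not a function from the kernel source.

```cpp
#include <cassert>

// Hypothetical helper (not from the FlashInfer/trtllm source).
// Assumption: one CTA covers `tileSizeQ` consecutive Q heads of a group.
// Returns how many CTAs are needed to cover one group of `numHeadsQPerKv`
// Q heads. Any result above 1 means the group is split across CTAs.
int ctasSpannedByGroup(int numHeadsQPerKv, int tileSizeQ) {
    assert(numHeadsQPerKv > 0 && tileSizeQ > 0);
    // Ceiling division without floating point.
    return (numHeadsQPerKv + tileSizeQ - 1) / tileSizeQ;
}
```

Under this assumption, a group of 16 Q heads with tileSizeQ = 8 spans two CTAs, while any tileSizeQ ≥ 16 keeps the group in one CTA.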
📌 Description
Skip candidates where kernelMeta.mStepQ < params.mNumHeadsQPerKv in GQA tile selection. Such candidates yield numTokensPerCtaQ = 0, which results in a divide-by-zero crash.
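The filter can be sketched as follows. This is an illustration only: it assumes numTokensPerCtaQ is obtained by integer division of the step size by the number of Q heads per KV head, and the function names and the flat list of candidates are hypothetical rather than taken from the actual heuristic.

```cpp
#include <cassert>
#include <vector>

// Assumption: tokens per CTA come from integer division, so a step size
// smaller than the group size truncates to 0.
int numTokensPerCtaQ(int stepQ, int numHeadsQPerKv) {
    assert(numHeadsQPerKv > 0);
    return stepQ / numHeadsQPerKv;
}

// Hypothetical stand-in for the candidate filter in the GQA tile-size
// heuristic: keep only tile sizes that are no larger than the default
// and large enough to hold a full group of Q heads.
std::vector<int> filterTileSizeQCandidates(const std::vector<int>& candidates,
                                           int defaultTileSizeQ,
                                           int numHeadsQPerKv) {
    std::vector<int> kept;
    for (int tileSizeQ : candidates) {
        if (tileSizeQ > defaultTileSizeQ) continue;  // pre-existing upper bound
        if (tileSizeQ < numHeadsQPerKv) continue;    // new lower bound from this PR
        kept.push_back(tileSizeQ);
    }
    return kept;
}
```

With numHeadsQPerKv = 16, a candidate of 8 would give numTokensPerCtaQ(8, 16) == 0, and any later division by that value would fail; the lower bound removes such candidates before they are evaluated.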
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
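I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.
🧪 Tests
All tests are passing (unittest, etc.).
Reviewer Notes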