Add dynamic tokens-per-page TRTLLM-GEN GQA kernels #3259
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
📥 Commits
Reviewing files that changed from the base of the PR and between dffd4b83dfd4da812afcbe277cee082dfeb1d45b and eb503ea.
📒 Files selected for processing (6)
🚧 Files skipped from review as they are similar to previous changes (5)
📝 Walkthrough
Adds dynamic numTokensPerPage selection for TRTLLM FMHA kernels, extends parameter structs for sparse-MLA, updates kernel selection and hash encoding (including MLA Dense-mask handling), bumps the TRTLLM_GEN_FMHA artifact path and checksum, and adds Blackwell-gated tests plus a float16 tolerance tweak.
Changes
Dynamic numTokensPerPage Selection for TRTLLM FMHA
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant Selector
participant Hasher
participant KernelPicker
Caller->>Selector: selectNumTokensPerPage(params)
Selector-->>Caller: set mNumTokensPerPage, mDynamicNumTokensPerPage
Caller->>Hasher: build/validate kernel hash (numTokensPerPageLog2)
Hasher-->>KernelPicker: candidate hashes
KernelPicker-->>Caller: selected cubin (Dense mask for MLA gen)
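The diagram mentions that the kernel hash carries `numTokensPerPageLog2`. As a rough illustration of why a log2 value is convenient in a packed key, here is a minimal Python sketch. Everything in it is an assumption made for illustration: the function names, the 5-bit field width, and the shift of 8 are hypothetical and are not taken from FlashInfer or TRTLLM-GEN.

```python
# Hypothetical sketch: the field width, shift, and names below are assumptions,
# not the actual FlashInfer / TRTLLM-GEN hash layout.

PAGE_FIELD_BITS = 5   # assumed width of the log2(tokens per page) field
PAGE_FIELD_SHIFT = 8  # assumed position of that field inside the key
PAGE_FIELD_MASK = ((1 << PAGE_FIELD_BITS) - 1) << PAGE_FIELD_SHIFT


def log2_exact(n: int) -> int:
    """Return log2(n) for a positive power of two, else raise ValueError."""
    if n <= 0 or n & (n - 1) != 0:
        raise ValueError(f"{n} is not a positive power of two")
    return n.bit_length() - 1


def with_tokens_per_page(key: int, num_tokens_per_page: int) -> int:
    """Return a copy of key whose page field encodes num_tokens_per_page."""
    log2_value = log2_exact(num_tokens_per_page)
    if log2_value >= (1 << PAGE_FIELD_BITS):
        raise ValueError("page size does not fit in the assumed bit field")
    # Clear the old field, then write the new log2 value into it.
    return (key & ~PAGE_FIELD_MASK) | (log2_value << PAGE_FIELD_SHIFT)


def extract_tokens_per_page(key: int) -> int:
    """Decode the page field back into a token count."""
    log2_value = (key & PAGE_FIELD_MASK) >> PAGE_FIELD_SHIFT
    return 1 << log2_value
```

Storing the exponent rather than the raw page size keeps the field small, and it makes non-power-of-two page sizes fail loudly at encode time instead of silently matching the wrong kernel.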
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
3ddae55 to 8b17dc6
/bot run
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@include/flashinfer/trtllm/fmha/kernelParams.h`:
- Around line 47-48: The field CUtensorMap tmaKSlidingWindowKvPool_ is declared
but never initialized or used; either remove this dead member or initialize it
in setKernelParams by calling buildNdTmaDescriptor in the same sparse-MLA
conditional where tmaK_ and tmaV_ are built (mirror their construction/guards),
ensuring the descriptor parameters match the DSv4 sparse MLA sliding-window KV
pool layout if you intend to support that path; update any related cleanup/usage
accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: efb8b457-2332-4fab-9813-3e51975b0e44
📥 Commits
Reviewing files that changed from the base of the PR and between 3ddae5562724a157a2f9e06ba150a05f72a32cca and 8b17dc664bacc1e5685345ba66e92b4411465d84.
📒 Files selected for processing (6)
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_attention_sink_blackwell.py
- tests/attention/test_trtllm_gen_attention.py
🚧 Files skipped from review as they are similar to previous changes (4)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- tests/attention/test_trtllm_gen_attention.py
// TMA descriptor for DSv4 sparse MLA sliding-window KV pool. Same format as tmaK_.
CUtensorMap tmaKSlidingWindowKvPool_;
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find all sites that set tmaKSlidingWindowKvPool_ to verify caller-side initialization.
rg -n "tmaKSlidingWindowKvPool_" --type cpp --type h -C 4
Repository: flashinfer-ai/flashinfer
Length of output: 800
🏁 Script executed:
# Find the setKernelParams function and buildNdTmaDescriptor calls
rg -n "setKernelParams|buildNdTmaDescriptor" --type cpp --type h -A 3 -B 1
Repository: flashinfer-ai/flashinfer
Length of output: 9188
🏁 Script executed:
# Check the memset call around line 639 in kernelParams.h
head -660 include/flashinfer/trtllm/fmha/kernelParams.h | tail -50
Repository: flashinfer-ai/flashinfer
Length of output: 2785
🏁 Script executed:
# Search for where tmaKSlidingWindowKvPool_ is actually used/read
rg -n "tmaKSlidingWindowKvPool_" --type cpp --type cu
Repository: flashinfer-ai/flashinfer
Length of output: 95
🏁 Script executed:
# Read the full setKernelParams function (roughly lines 634-850)
sed -n '634,850p' include/flashinfer/trtllm/fmha/kernelParams.h
Repository: flashinfer-ai/flashinfer
Length of output: 11415
🏁 Script executed:
# Search for any usage/reads of tmaKSlidingWindowKvPool_ without file type restriction
rg "tmaKSlidingWindowKvPool_" -C 4
Repository: flashinfer-ai/flashinfer
Length of output: 773
🏁 Script executed:
# Look for sparse MLA related code to understand if this field is relevant
rg "sparse.*MLA|SlidingWindow" -C 2
Repository: flashinfer-ai/flashinfer
Length of output: 6475
🏁 Script executed:
# Search for tmaKSlidingWindowKvPool_ in all files, focusing on kernel code
find . -type f \( -name "*.cuh" -o -name "*.cu" \) -exec grep -l "tmaKSlidingWindowKvPool_" {} \;
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Also search in any generated or header files for kernel access
rg "tmaKSlidingWindowKvPool_" . --type-list
Repository: flashinfer-ai/flashinfer
Length of output: 458
🏁 Script executed:
# Look for the sparse MLA kernel to see if it uses this field
rg -n "sparse.*kernel|slidingWindow" --type-list 2>/dev/null | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Search for kernel code more broadly - look for .cuh and .cu files
find . -type f \( -name "*.cuh" -o -name "*.cu" \) | head -20
Repository: flashinfer-ai/flashinfer
Length of output: 668
🏁 Script executed:
# Search all files for any reference to this field with broader patterns
rg "SlidingWindowKvPool" .
Repository: flashinfer-ai/flashinfer
Length of output: 156
🏁 Script executed:
# Check git history to see if this field was recently added
git log --oneline --all -- "include/flashinfer/trtllm/fmha/kernelParams.h" | head -10
Repository: flashinfer-ai/flashinfer
Length of output: 842
🏁 Script executed:
# See the actual commit that added this field
git log -p --all -S "tmaKSlidingWindowKvPool_" -- "include/flashinfer/trtllm/fmha/kernelParams.h" | head -100
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Look at recent commits to kernelParams.h to find when this field was added
git log -p --all -- "include/flashinfer/trtllm/fmha/kernelParams.h" | grep -B 10 -A 10 "tmaKSlidingWindowKvPool_" | head -50
Repository: flashinfer-ai/flashinfer
Length of output: 50
🏁 Script executed:
# Also check if there are any TODOs or comments about this field being unused
rg -B 5 -A 5 "tmaKSlidingWindowKvPool_"
Repository: flashinfer-ai/flashinfer
Length of output: 893
🏁 Script executed:
# Search for any test or reference that might use sparse MLA with sliding window
rg "sparse.*sliding|slidingWindow.*sparse" -C 3
Repository: flashinfer-ai/flashinfer
Length of output: 633
🏁 Script executed:
# Check the fmhaKernels.cuh file that handles sparse MLA selection
cat csrc/trtllm_fmha_kernel_launcher.cu | grep -A 20 -B 5 "sparse_mla"
Repository: flashinfer-ai/flashinfer
Length of output: 8495
Remove tmaKSlidingWindowKvPool_ or implement its initialization for sparse MLA sliding-window support.
The field is declared but never populated via buildNdTmaDescriptor and is unused throughout the codebase. Either remove it as dead code, or if intended for future sparse MLA sliding-window support, add the corresponding buildNdTmaDescriptor call inside setKernelParams (guarded by the sparse-MLA path, similar to how tmaK_ and tmaV_ are constructed).
Not sure if I'm missing something, but what's the reason for this never actually being initialized?
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Search for any TODO/FIXME near this field or sparse MLA sliding window
rg -n "TODO|FIXME|tmaKSlidingWindowKvPool_|SlidingWindowKvPool|sparse.*sliding|sliding.*sparse" -C 3
# Also check if the cubin kernels might reference struct offsets that imply this field must exist
rg -n "sparseMla|sparse_mla|mSparseMla" --type cpp --type h -C 3
# Check the full sparse MLA block in setKernelParams
grep -n "sparseMla\|SlidingWindow" include/flashinfer/trtllm/fmha/kernelParams.h
Length of output: 210871
@saltyminty Great question! The root cause is binary ABI compatibility with the precompiled cubins.
KernelParams is the struct passed directly to the TRTLLM-GEN GPU kernels, so its memory layout (field offsets) must exactly match what the precompiled cubins expect. This PR updates the cubin artifacts to a new version (as noted in the PR description), and the new cubins appear to include tmaKSlidingWindowKvPool_ in their expected KernelParams layout. If the field were absent from the host-side struct, every field declared after it would be at the wrong byte offset, causing silent data corruption in the kernel.
The memset zero-initialization acts as a safe "disabled" sentinel — the kernel can check whether the descriptor is valid before using it, and a zeroed CUtensorMap signals that this feature path is inactive.
You can see the same pattern with ptrSparseMlaTopKLens (line 122): it's declared, zero-initialized by memset, and explicitly pinned to nullptr at line 867 — all without a full host-side implementation. Both fields are effectively forward-ported struct layout stubs that maintain cubin ABI compatibility while the actual initialization logic is deferred to a follow-up.
So to summarize: the field must exist to keep the struct binary-compatible with the new cubins, and the zero-initialization from memset is the correct default for the "feature not active" case.
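The layout argument in this reply can be checked with a small Python `ctypes` model. This is only a sketch under stated assumptions: `TensorMapStub` stands in for `CUtensorMap` (an opaque block of 16 64-bit words in the CUDA driver API), the two structs are toy versions of `KernelParams`, and the trailing field `ptrAfter` is a hypothetical name. It shows the general mechanism only; it does not establish what the actual cubins expect.

```python
import ctypes


class TensorMapStub(ctypes.Structure):
    # Stand-in for CUtensorMap: 16 opaque 64-bit words, 128 bytes in total.
    # The real type is 64-byte aligned; ctypes aligns this stub to 8 bytes,
    # which gives the same offsets here because every offset is a multiple of 64.
    _fields_ = [("opaque", ctypes.c_uint64 * 16)]


class ParamsWithField(ctypes.Structure):
    # Toy layout that includes the extra descriptor.
    _fields_ = [
        ("tmaK", TensorMapStub),
        ("tmaKSlidingWindowKvPool", TensorMapStub),
        ("ptrAfter", ctypes.c_void_p),  # hypothetical field declared afterwards
    ]


class ParamsWithoutField(ctypes.Structure):
    # Same toy layout with the extra descriptor removed.
    _fields_ = [
        ("tmaK", TensorMapStub),
        ("ptrAfter", ctypes.c_void_p),
    ]


# Dropping the descriptor moves every later field 128 bytes earlier.
offset_with = ParamsWithField.ptrAfter.offset
offset_without = ParamsWithoutField.ptrAfter.offset

# Zero the whole struct, as a memset over the parameter block would.
params = ParamsWithField()
params.tmaKSlidingWindowKvPool.opaque[0] = 0xDEADBEEF
ctypes.memset(ctypes.addressof(params), 0, ctypes.sizeof(params))
descriptor_is_zeroed = bytes(params.tmaKSlidingWindowKvPool) == bytes(128)
```

In this model `offset_with` is 256 and `offset_without` is 128, so a consumer compiled against one layout would read `ptrAfter` from the wrong place if handed the other. Whether a zeroed descriptor is treated as "inactive" depends on the kernel and is not something this sketch can confirm.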
This is actually used by a new kind of kernel; see #3269.
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
92c6075 to dffd4b8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/attention/test_attention_sink_blackwell.py (1)
215-215: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Remove debug print
Line 215 (print(q.shape, k.shape, v.shape)) is a debug artifact that produces noise in CI output. It should be removed before merging.
🐛 Proposed fix
- print(q.shape, k.shape, v.shape)
-
  o_ref = sink_attention_unified(
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/attention/test_attention_sink_blackwell.py` at line 215, remove the
debug print of the tensor shapes: delete the `print(q.shape, k.shape, v.shape)`
line from the test so it no longer emits noisy output during CI; leave the
surrounding test logic and any assertions intact.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f785676b-3ed8-4591-90db-4b98cb5c0c67
📥 Commits
Reviewing files that changed from the base of the PR and between 8b17dc664bacc1e5685345ba66e92b4411465d84 and 92c6075b657b93a1c865835647f534560e9137c1.
📒 Files selected for processing (6)
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- include/flashinfer/trtllm/fmha/kernelParams.h
- tests/attention/test_attention_sink_blackwell.py
- tests/attention/test_trtllm_gen_attention.py
✅ Files skipped from review due to trivial changes (2)
- flashinfer/artifacts.py
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
- tests/attention/test_trtllm_gen_attention.py
saltyminty left a comment
Approved but see comment about tmaKSlidingWindowKvPool_
dffd4b8 to eb503ea
/bot run
@qsang-nv @saltyminty The CI failures appear to come from machine-allocation issues rather than from this change. Feel free to merge if it looks good to you. The B300 tests passed, so it should be good. Thanks!
/bot run