-
Notifications
You must be signed in to change notification settings - Fork 1k
Add dynamic tokens-per-page TRTLLM-GEN GQA kernels #3259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,6 +44,8 @@ struct KernelParams { | |
| CUtensorMap tmaQ_; | ||
| // TMA descriptor for K. | ||
| CUtensorMap tmaK_; | ||
| // TMA descriptor for DSv4 sparse MLA sliding-window KV pool. Same format as tmaK_. | ||
| CUtensorMap tmaKSlidingWindowKvPool_; | ||
|
Comment on lines
+47
to
+48
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Find all sites that set tmaKSlidingWindowKvPool_ to verify caller-side initialization.
rg -n "tmaKSlidingWindowKvPool_" --type cpp --type h -C 4Repository: flashinfer-ai/flashinfer Length of output: 800 🏁 Script executed: # Find the setKernelParams function and buildNdTmaDescriptor calls
rg -n "setKernelParams|buildNdTmaDescriptor" --type cpp --type h -A 3 -B 1Repository: flashinfer-ai/flashinfer Length of output: 9188 🏁 Script executed: # Check the memset call around line 639 in kernelParams.h
head -660 include/flashinfer/trtllm/fmha/kernelParams.h | tail -50Repository: flashinfer-ai/flashinfer Length of output: 2785 🏁 Script executed: # Search for where tmaKSlidingWindowKvPool_ is actually used/read
rg -n "tmaKSlidingWindowKvPool_" --type cpp --type cuRepository: flashinfer-ai/flashinfer Length of output: 95 🏁 Script executed: # Read the full setKernelParams function (roughly lines 634-850)
sed -n '634,850p' include/flashinfer/trtllm/fmha/kernelParams.hRepository: flashinfer-ai/flashinfer Length of output: 11415 🏁 Script executed: # Search for any usage/reads of tmaKSlidingWindowKvPool_ without file type restriction
rg "tmaKSlidingWindowKvPool_" -C 4Repository: flashinfer-ai/flashinfer Length of output: 773 🏁 Script executed: # Look for sparse MLA related code to understand if this field is relevant
rg "sparse.*MLA|SlidingWindow" -C 2Repository: flashinfer-ai/flashinfer Length of output: 6475 🏁 Script executed: # Search for tmaKSlidingWindowKvPool_ in all files, focusing on kernel code
find . -type f \( -name "*.cuh" -o -name "*.cu" \) -exec grep -l "tmaKSlidingWindowKvPool_" {} \;Repository: flashinfer-ai/flashinfer Length of output: 50 🏁 Script executed: # Also search in any generated or header files for kernel access
rg "tmaKSlidingWindowKvPool_" . --type-listRepository: flashinfer-ai/flashinfer Length of output: 458 🏁 Script executed: # Look for the sparse MLA kernel to see if it uses this field
rg -n "sparse.*kernel|slidingWindow" --type-list 2>/dev/null | head -20Repository: flashinfer-ai/flashinfer Length of output: 50 🏁 Script executed: # Search for kernel code more broadly - look for .cuh and .cu files
find . -type f \( -name "*.cuh" -o -name "*.cu" \) | head -20Repository: flashinfer-ai/flashinfer Length of output: 668 🏁 Script executed: # Search all files for any reference to this field with broader patterns
rg "SlidingWindowKvPool" .Repository: flashinfer-ai/flashinfer Length of output: 156 🏁 Script executed: # Check git history to see if this field was recently added
git log --oneline --all -- "include/flashinfer/trtllm/fmha/kernelParams.h" | head -10Repository: flashinfer-ai/flashinfer Length of output: 842 🏁 Script executed: # See the actual commit that added this field
git log -p --all -S "tmaKSlidingWindowKvPool_" -- "include/flashinfer/trtllm/fmha/kernelParams.h" | head -100Repository: flashinfer-ai/flashinfer Length of output: 50 🏁 Script executed: # Look at recent commits to kernelParams.h to find when this field was added
git log -p --all -- "include/flashinfer/trtllm/fmha/kernelParams.h" | grep -B 10 -A 10 "tmaKSlidingWindowKvPool_" | head -50Repository: flashinfer-ai/flashinfer Length of output: 50 🏁 Script executed: # Also check if there are any TODOs or comments about this field being unused
rg -B 5 -A 5 "tmaKSlidingWindowKvPool_"Repository: flashinfer-ai/flashinfer Length of output: 893 🏁 Script executed: # Search for any test or reference that might use sparse MLA with sliding window
rg "sparse.*sliding|slidingWindow.*sparse" -C 3Repository: flashinfer-ai/flashinfer Length of output: 633 🏁 Script executed: # Check the fmhaKernels.cuh file that handles sparse MLA selection
cat csrc/trtllm_fmha_kernel_launcher.cu | grep -A 20 -B 5 "sparse_mla"Repository: flashinfer-ai/flashinfer Length of output: 8495 Remove The field is declared but never populated via 🤖 Prompt for AI Agents
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure if I'm missing something, but what's the reason for this never actually being initialized?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: Length of output: 210871
The You can see the same pattern with So to summarize: the field must exist to keep the struct binary-compatible with the new cubins, and the zero-initialization from
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is actually used for a new kind of kernels, see #3269.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| // TMA descriptor for V. | ||
| CUtensorMap tmaV_; | ||
| // The descriptor for O. | ||
|
|
@@ -117,6 +119,8 @@ struct KernelParams { | |
|
|
||
| // The softmax stats buffer. | ||
| float2* ptrSoftmaxStats; | ||
| // The variable sparseMla topK lengths with shape of [numTokensQ]. | ||
| int32_t const* ptrSparseMlaTopKLens; | ||
|
|
||
| // The attention window size for sliding window attention. | ||
| int32_t mAttentionWindowSize; | ||
|
|
@@ -860,6 +864,7 @@ struct KernelParams { | |
| params.mStartTokenIdxSfO = options.mSfStartTokenIdx; | ||
| params.mScaleSfKv = options.mScaleSfKv; | ||
| params.ptrSoftmaxStats = options.softmaxStatsPtr; | ||
| params.ptrSparseMlaTopKLens = nullptr; | ||
| // The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the | ||
| // indices. | ||
| FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.