Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class ArtifactPath:
When compiling new cubins for backend directories, update the corresponding path.
"""

TRTLLM_GEN_FMHA: str = "1d876ee612888821b168c25ffa75a9dcbb963aaa/fmha/trtllm-gen/"
TRTLLM_GEN_FMHA: str = "5d4df6c2647e860992d1cc57ced05204b55f3787/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"c21ddd11585c1eea5764927465d0be15dd957e45/batched_gemm-91e0ba0-da44fdf/"
)
Expand All @@ -157,7 +157,7 @@ class CheckSumHash:
"""

TRTLLM_GEN_FMHA: str = (
"1abeea012a8779c6df5b84332fad43c6cfc3b257fe5ab883c8ea501464010d16"
"681e69c9c0215b4780eaf92f897c2dc94285a9143b90f765085eba83af5afa5b"
)
TRTLLM_GEN_BMM: str = (
"4a3ed9c3dc6547ea3eed01ebda75b0e4322f6c01fc40cd2a4978e4deaba2732a"
Expand Down
67 changes: 52 additions & 15 deletions include/flashinfer/trtllm/fmha/fmhaKernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ constexpr bool isSMCompatible(int gpuSM, int kernelSM) {
////////////////////////////////////////////////////////////////////////////////////////////////////
class TllmGenFmhaKernel {
public:
static constexpr int kDynamicNumTokensPerPageThreshold = 128;
static constexpr int kDynamicNumTokensPerPageKernelKey = 128;

// The parameters for launching the kernel.
// maxNumCtasQ, maxNumCtasKv, numCtasX, numCtasY, numCtasZ, clusterDimX
struct CtaLaunchParams {
Expand Down Expand Up @@ -160,9 +163,9 @@ class TllmGenFmhaKernel {
"Expect (32 <= headDim <= 1024), got headDimPerCtaV=%d, headDimQk=%d, "
"headDimV=%d",
headDimPerCtaV, headDimQk, headDimV);
// The numTokensPerPage must be power of 2.
FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0,
"The numTokensPerPage must be power of 2.");
// The numTokensPerPage must be 0 (unused for non-paged kernels) or power of 2.
FLASHINFER_CHECK(numTokensPerPage == 0 || ((numTokensPerPage & (numTokensPerPage - 1)) == 0),
"The numTokensPerPage must be 0 or power of 2.");
FLASHINFER_CHECK(tileSizeQ <= 128 && tileSizeKv <= 128,
"The tileSizeQ and tileSizeKv must be <= 128.");
FLASHINFER_CHECK((tileSizeQ & (tileSizeQ - 1)) == 0 && (tileSizeKv & (tileSizeKv - 1)) == 0,
Expand All @@ -184,14 +187,15 @@ class TllmGenFmhaKernel {
// Bit 54 - 54: uses2CtaMma.
// Bit 55 - 55: sparseMla.
// Bit 56 - 56: skipsSoftmax.
uint64_t const numTokensPerPageLog2 =
numTokensPerPage == 0 ? 0 : static_cast<uint64_t>(log2(numTokensPerPage));
return (static_cast<uint64_t>(qkvLayout) << 0) | (static_cast<uint64_t>(maskType) << 4) |
(static_cast<uint64_t>(kernelType) << 8) | (static_cast<uint64_t>(scheduler) << 12) |
(static_cast<uint64_t>(multiCtasKvMode) << 16) |
(static_cast<uint64_t>(headDimPerCtaV >> 3) << 18) |
(static_cast<uint64_t>(headDimQk >> 3) << 26) |
(static_cast<uint64_t>(headDimV >> 3) << 34) |
(static_cast<uint64_t>(tileSizeKv >> 6) << 42) |
(static_cast<uint64_t>(log2(numTokensPerPage)) << 44) |
(static_cast<uint64_t>(tileSizeKv >> 6) << 42) | (numTokensPerPageLog2 << 44) |
(static_cast<uint64_t>(log2(tileSizeQ)) << 49) |
(static_cast<uint64_t>(reuseSmemKForV) << 53) |
(static_cast<uint64_t>(uses2CtaMma) << 54) | (static_cast<uint64_t>(sparseMla) << 55) |
Expand Down Expand Up @@ -362,7 +366,34 @@ class TllmGenFmhaKernel {

// Is it MLA generation kernel ?
inline bool isMlaGenKernel(RunnerParams const& params) const {
return params.mHeadDimQk == 576 && params.mHeadDimV == 512;
return (params.mHeadDimQk == 576 && params.mHeadDimV == 512) ||
(params.mHeadDimQk == 320 && params.mHeadDimV == 256);
}

// Returns true only when all of the following hold for the runner params:
//   - the QKV layout is a paged-KV layout,
//   - sparse MLA is disabled,
//   - more than one Q head maps to each KV head,
//   - the QK and V head dimensions are equal,
//   - numTokensPerPage is at least kDynamicNumTokensPerPageThreshold.
inline bool useDynamicNumTokensPerPage(RunnerParams const& params) const {
  // Layout and sparse-MLA gates are evaluated first, in the same order as a short-circuit chain.
  if (!isPagedKv(params.mQkvLayout)) {
    return false;
  }
  if (params.mSparseMla) {
    return false;
  }
  bool const hasGroupedQueryHeads = params.mNumHeadsQPerKv > 1;
  bool const hasMatchingHeadDims = params.mHeadDimQk == params.mHeadDimV;
  bool const reachesPageThreshold = params.mNumTokensPerPage >= kDynamicNumTokensPerPageThreshold;
  return hasGroupedQueryHeads && hasMatchingHeadDims && reachesPageThreshold;
}

// Fills selectKernelParams.mNumTokensPerPage and selectKernelParams.mDynamicNumTokensPerPage
// from the runner params. Exactly one of four cases applies, checked in this order:
// sparse MLA, non-paged KV layout, statically keyed page size, dynamic page size.
void selectNumTokensPerPage(RunnerParams const& params,
                            SelectKernelParams& selectKernelParams) const {
  // Reset the flag up front; only the dynamic case at the bottom turns it back on.
  selectKernelParams.mDynamicNumTokensPerPage = false;

  if (params.mSparseMla) {
    // SparseMla kernels use a fixed numTokensPerPage = 1.
    selectKernelParams.mNumTokensPerPage = 1;
    return;
  }

  if (!isPagedKv(params.mQkvLayout)) {
    // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
    selectKernelParams.mNumTokensPerPage = 0;
    return;
  }

  if (!useDynamicNumTokensPerPage(params)) {
    // Static case: forward the caller's page size unchanged.
    selectKernelParams.mNumTokensPerPage = params.mNumTokensPerPage;
    return;
  }

  // Dynamic case: validate the page size before touching the selection params, then record the
  // shared kernel key instead of the actual page size.
  FLASHINFER_CHECK((params.mNumTokensPerPage & (params.mNumTokensPerPage - 1)) == 0,
                   "Dynamic numTokensPerPage requires a power-of-2 page size, got %d.",
                   params.mNumTokensPerPage);
  selectKernelParams.mDynamicNumTokensPerPage = true;
  selectKernelParams.mNumTokensPerPage = kDynamicNumTokensPerPageKernelKey;
}

// Compute the number of CTAs in X, Y and Z dimension and the cluster size in the X dimension.
Expand Down Expand Up @@ -813,9 +844,22 @@ class TllmGenFmhaKernel {

// Select a kernel based on the heuristic.
void selectKernel(RunnerParams const& params, SelectKernelParams& selectKernelParams) const {
// Normalize this before heuristic probing; some GQA-generation heuristics load candidate
// kernels while selecting tileSizeQ.
selectNumTokensPerPage(params, selectKernelParams);
bool const isMlaGeneration = isGenerationKernel(params.mKernelType) && isMlaGenKernel(params);

// Select the kernel based on the kernel type.
if (isGenerationKernel(params.mKernelType) && isMlaGenKernel(params)) {
if (isMlaGeneration) {
selectMlaGenerationKernel(params, selectKernelParams);
// TRTLLM-GEN MLA generation kernels are exported with dense mask metadata. Each generation
// CTA processes a bounded query tile, so the runtime sequence lengths/indices provide the
// effective masking.
selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::Dense;
Comment thread
PerkzZheng marked this conversation as resolved.
FLASHINFER_CHECK(
params.mMaxSeqLenKv <= params.mAttentionWindowSize &&
params.mChunkedAttentionSize == INT_MAX,
"TRTLLM-GEN MLA generation does not support sliding-window or chunked attention.");
} else if (isGenerationKernel(params.mKernelType)) {
selectGqGenerationKernel(params, selectKernelParams);
}
Expand All @@ -841,14 +885,6 @@ class TllmGenFmhaKernel {
"Sliding window attention and chunked attention should not be used together");
selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
}

// SparseMla kernels use a fixed numTokensPerPage = 1.
if (params.mSparseMla) {
selectKernelParams.mNumTokensPerPage = 1;
} else if (!isPagedKv(params.mQkvLayout)) {
// NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels.
selectKernelParams.mNumTokensPerPage = 0;
}
}

std::pair<uint64_t, std::string> hashFromRunnerParams(
Expand All @@ -867,6 +903,7 @@ class TllmGenFmhaKernel {
", tileSizeQ=" + std::to_string(selectKernelParams.mTileSizeQ) +
", tileSizeKv=" + std::to_string(selectKernelParams.mTileSizeKv) +
", numTokensPerPage=" + std::to_string(selectKernelParams.mNumTokensPerPage) +
", dynamicNumTokensPerPage=" + std::to_string(selectKernelParams.mDynamicNumTokensPerPage) +
", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) +
", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) +
", sparseMla=" + std::to_string(params.mSparseMla) +
Expand Down
3 changes: 3 additions & 0 deletions include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,8 @@ struct TllmGenSelectKernelParams {
TrtllmGenAttentionMaskType mMaskType;
// The number of tokens per page.
int mNumTokensPerPage;
// Whether a dynamic tokens-per-page cubin is selected.
bool mDynamicNumTokensPerPage;
// Reuse smemK for V or not (only works with MLA generation kernels).
bool mReuseSmemKForV;
// Do we need to select a new kernel as the parameters have been updated.
Expand All @@ -398,6 +400,7 @@ struct TllmGenSelectKernelParams {
mForceGmemReduction(false),
mMaskType(params.mMaskType),
mNumTokensPerPage(params.mNumTokensPerPage),
mDynamicNumTokensPerPage(false),
mReuseSmemKForV(false),
mSelectNewKernel(false),
mSkipsSoftmaxWhenPossible(params.mSkipsSoftmaxWhenPossible),
Expand Down
5 changes: 5 additions & 0 deletions include/flashinfer/trtllm/fmha/kernelParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct KernelParams {
CUtensorMap tmaQ_;
// TMA descriptor for K.
CUtensorMap tmaK_;
// TMA descriptor for DSv4 sparse MLA sliding-window KV pool. Same format as tmaK_.
CUtensorMap tmaKSlidingWindowKvPool_;
Comment on lines +47 to +48
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Find all sites that set tmaKSlidingWindowKvPool_ to verify caller-side initialization.
rg -n "tmaKSlidingWindowKvPool_" --type cpp --type h -C 4

Repository: flashinfer-ai/flashinfer

Length of output: 800


🏁 Script executed:

# Find the setKernelParams function and buildNdTmaDescriptor calls
rg -n "setKernelParams|buildNdTmaDescriptor" --type cpp --type h -A 3 -B 1

Repository: flashinfer-ai/flashinfer

Length of output: 9188


🏁 Script executed:

# Check the memset call around line 639 in kernelParams.h
head -660 include/flashinfer/trtllm/fmha/kernelParams.h | tail -50

Repository: flashinfer-ai/flashinfer

Length of output: 2785


🏁 Script executed:

# Search for where tmaKSlidingWindowKvPool_ is actually used/read
rg -n "tmaKSlidingWindowKvPool_" --type cpp --type cu

Repository: flashinfer-ai/flashinfer

Length of output: 95


🏁 Script executed:

# Read the full setKernelParams function (roughly lines 634-850)
sed -n '634,850p' include/flashinfer/trtllm/fmha/kernelParams.h

Repository: flashinfer-ai/flashinfer

Length of output: 11415


🏁 Script executed:

# Search for any usage/reads of tmaKSlidingWindowKvPool_ without file type restriction
rg "tmaKSlidingWindowKvPool_" -C 4

Repository: flashinfer-ai/flashinfer

Length of output: 773


🏁 Script executed:

# Look for sparse MLA related code to understand if this field is relevant
rg "sparse.*MLA|SlidingWindow" -C 2

Repository: flashinfer-ai/flashinfer

Length of output: 6475


🏁 Script executed:

# Search for tmaKSlidingWindowKvPool_ in all files, focusing on kernel code
find . -type f \( -name "*.cuh" -o -name "*.cu" \) -exec grep -l "tmaKSlidingWindowKvPool_" {} \;

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Also search in any generated or header files for kernel access
rg "tmaKSlidingWindowKvPool_" . --type-list

Repository: flashinfer-ai/flashinfer

Length of output: 458


🏁 Script executed:

# Look for the sparse MLA kernel to see if it uses this field
rg -n "sparse.*kernel|slidingWindow" --type-list 2>/dev/null | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Search for kernel code more broadly - look for .cuh and .cu files
find . -type f \( -name "*.cuh" -o -name "*.cu" \) | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 668


🏁 Script executed:

# Search all files for any reference to this field with broader patterns
rg "SlidingWindowKvPool" .

Repository: flashinfer-ai/flashinfer

Length of output: 156


🏁 Script executed:

# Check git history to see if this field was recently added
git log --oneline --all -- "include/flashinfer/trtllm/fmha/kernelParams.h" | head -10

Repository: flashinfer-ai/flashinfer

Length of output: 842


🏁 Script executed:

# See the actual commit that added this field
git log -p --all -S "tmaKSlidingWindowKvPool_" -- "include/flashinfer/trtllm/fmha/kernelParams.h" | head -100

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Look at recent commits to kernelParams.h to find when this field was added
git log -p --all -- "include/flashinfer/trtllm/fmha/kernelParams.h" | grep -B 10 -A 10 "tmaKSlidingWindowKvPool_" | head -50

Repository: flashinfer-ai/flashinfer

Length of output: 50


🏁 Script executed:

# Also check if there are any TODOs or comments about this field being unused
rg -B 5 -A 5 "tmaKSlidingWindowKvPool_"

Repository: flashinfer-ai/flashinfer

Length of output: 893


🏁 Script executed:

# Search for any test or reference that might use sparse MLA with sliding window
rg "sparse.*sliding|slidingWindow.*sparse" -C 3

Repository: flashinfer-ai/flashinfer

Length of output: 633


🏁 Script executed:

# Check the fmhaKernels.cuh file that handles sparse MLA selection
cat csrc/trtllm_fmha_kernel_launcher.cu | grep -A 20 -B 5 "sparse_mla"

Repository: flashinfer-ai/flashinfer

Length of output: 8495


Remove tmaKSlidingWindowKvPool_ or implement its initialization for sparse MLA sliding-window support.

The field is declared but never populated via buildNdTmaDescriptor and is unused throughout the codebase. Either remove it as dead code, or if intended for future sparse MLA sliding-window support, add the corresponding buildNdTmaDescriptor call inside setKernelParams (guarded by the sparse-MLA path, similar to how tmaK_ and tmaV_ are constructed).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@include/flashinfer/trtllm/fmha/kernelParams.h` around lines 47 - 48, The
field CUtensorMap tmaKSlidingWindowKvPool_ is declared but never initialized or
used; either remove this dead member or initialize it in setKernelParams by
calling buildNdTmaDescriptor in the same sparse-MLA conditional where tmaK_ and
tmaV_ are built (mirror their construction/guards), ensuring the descriptor
parameters match the DSv4 sparse MLA sliding-window KV pool layout if you intend
to support that path; update any related cleanup/usage accordingly.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I'm missing something, but what's the reason for this never actually being initialized?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for any TODO/FIXME near this field or sparse MLA sliding window
rg -n "TODO|FIXME|tmaKSlidingWindowKvPool_|SlidingWindowKvPool|sparse.*sliding|sliding.*sparse" -C 3

# Also check if the cubin kernels might reference struct offsets that imply this field must exist
rg -n "sparseMla|sparse_mla|mSparseMla" --type cpp --type h -C 3

# Check the full sparse MLA block in setKernelParams
grep -n "sparseMla\|SlidingWindow" include/flashinfer/trtllm/fmha/kernelParams.h

Length of output: 210871


@saltyminty Great question! The root cause is binary ABI compatibility with the precompiled cubins.

KernelParams is the struct passed directly to the TRTLLM-GEN GPU kernels, so its memory layout (field offsets) must exactly match what the precompiled cubins expect. This PR updates the cubin artifacts to a new version (as noted in the PR description), and the new cubins appear to include tmaKSlidingWindowKvPool_ in their expected KernelParams layout. If the field were absent from the host-side struct, every field declared after it would be at the wrong byte offset, causing silent data corruption in the kernel.

The memset zero-initialization acts as a safe "disabled" sentinel — the kernel can check whether the descriptor is valid before using it, and a zeroed CUtensorMap signals that this feature path is inactive.

You can see the same pattern with ptrSparseMlaTopKLens (line 122): it's declared, zero-initialized by memset, and explicitly pinned to nullptr at line 867 — all without a full host-side implementation. Both fields are effectively forward-ported struct layout stubs that maintain cubin ABI compatibility while the actual initialization logic is deferred to a follow-up.

So to summarize: the field must exist to keep the struct binary-compatible with the new cubins, and the zero-initialization from memset is the correct default for the "feature not active" case.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually used by a new kind of kernel; see #3269.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

// TMA descriptor for V.
CUtensorMap tmaV_;
// The descriptor for O.
Expand Down Expand Up @@ -117,6 +119,8 @@ struct KernelParams {

// The softmax stats buffer.
float2* ptrSoftmaxStats;
// The variable sparseMla topK lengths with shape of [numTokensQ].
int32_t const* ptrSparseMlaTopKLens;

// The attention window size for sliding window attention.
int32_t mAttentionWindowSize;
Expand Down Expand Up @@ -860,6 +864,7 @@ struct KernelParams {
params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
params.mScaleSfKv = options.mScaleSfKv;
params.ptrSoftmaxStats = options.softmaxStatsPtr;
params.ptrSparseMlaTopKLens = nullptr;
// The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the
// indices.
FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,
Expand Down
2 changes: 1 addition & 1 deletion tests/attention/test_attention_sink_blackwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_blackwell_trtllm_gen_context_attention_sink(
)

if dtype == torch.float16:
atol, rtol = 1e-3, 1e-3
atol, rtol = 2e-3, 1e-3
elif dtype == torch.bfloat16:
atol, rtol = 1e-2, 1e-2
else:
Expand Down
67 changes: 67 additions & 0 deletions tests/attention/test_trtllm_gen_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
workspace_size = 256 * 1024 * 1024


def _skip_if_not_blackwell() -> None:
    """Skip the calling test unless the CUDA device's compute-capability major version is 10."""
    major_version = get_compute_capability(torch.device(device="cuda"))[0]
    if major_version == 10:
        return
    pytest.skip("Dynamic tokensPerPage tests require SM100 or SM103 Blackwell GPUs.")


def flip_coin(*args, **kwargs):
# Use any test parameters to deterministically decide branch
# This makes test configurations go through different paths
Expand Down Expand Up @@ -1015,6 +1023,34 @@ def test_trtllm_batch_prefill_bs1(
)


@pytest.mark.parametrize("page_size", [128, 256, 512, 1024])
@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False])
def test_trtllm_batch_prefill_dynamic_page_size_gqa(
    page_size: int,
    uses_shared_paged_kv_idx: bool,
) -> None:
    """Run the batch-prefill helper with page sizes of 128 and above on a GQA configuration.

    Skipped unless the device's compute-capability major version is 10.
    """
    _skip_if_not_blackwell()

    # Settings shared by every parametrization of this test.
    fixed_config = dict(
        batch_size=4,
        num_kv_heads=2,
        head_grp_size=5,
        causal=True,
        window_left=-1,
        q_dtype="bf16",
        o_dtype="bf16",
        kv_dtype="bf16",
        enable_pdl=None,
        enable_sink=False,
        max_q_len=257,
        max_kv_len=1024,
        device_scale=False,
        head_dim=128,
    )
    _test_trtllm_batch_prefill(
        "HND",
        page_size=page_size,
        uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
        **fixed_config,
    )


def _test_trtllm_batch_decode(
backend: str,
kv_layout: str,
Expand Down Expand Up @@ -1705,6 +1741,37 @@ def test_trtllm_batch_decode_long_sequence_length(
)


@pytest.mark.parametrize("page_size", [128, 256, 512, 1024])
@pytest.mark.parametrize("q_len_per_req", [1, 2])
@pytest.mark.parametrize("window_left", [-1, 127])
@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False])
def test_trtllm_batch_decode_dynamic_page_size_gqa(
    page_size: int,
    q_len_per_req: int,
    window_left: int,
    uses_shared_paged_kv_idx: bool,
) -> None:
    """Run the batch-decode helper (trtllm-gen backend) with page sizes of 128 and above on GQA.

    Skipped unless the device's compute-capability major version is 10.
    """
    _skip_if_not_blackwell()

    # Settings shared by every parametrization of this test.
    fixed_config = dict(
        batch_size=4,
        num_kv_heads=2,
        head_grp_size=5,
        q_dtype="bf16",
        o_dtype="bf16",
        kv_dtype="bf16",
        enable_pdl=None,
        enable_sink=False,
        max_in_kv_len=1024,
        head_dim=128,
    )
    _test_trtllm_batch_decode(
        "trtllm-gen",
        "HND",
        q_len_per_req=q_len_per_req,
        page_size=page_size,
        window_left=window_left,
        uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
        **fixed_config,
    )


@pytest.mark.parametrize("kv_layout", ["HND", "NHD"])
@pytest.mark.parametrize(
"batch_size,page_size,num_kv_heads,head_grp_size",
Expand Down
Loading