Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ void trtllm_paged_attention_launcher(
// The sparse MLA parameters.
runner_params.mSparseMla = sparse_mla_top_k > 0;
runner_params.mSparseMlaTopK = sparse_mla_top_k;
TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) ||
(head_dim_qk == 320 && head_dim_vo == 256) || sparse_mla_top_k <= 0)
<< "Only decode MLA supports sparse MLA";
bool const is_mla_decode =
(head_dim_qk == 576 && head_dim_vo == 512) || (head_dim_qk == 320 && head_dim_vo == 256);
TVM_FFI_ICHECK(is_mla_decode || sparse_mla_top_k <= 0) << "Only decode MLA supports sparse MLA";

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
if (mode == TllmPagedAttentionMode::Context) {
Expand All @@ -187,7 +187,8 @@ void trtllm_paged_attention_launcher(
// Note that kernel names are still labeled as using a dense mask even when maskType is
// specified as causal. This is expected for better performance: each CTA only processes
// one tokenQ in those cases, so a dense mask works the same as a causal mask.
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
runner_params.mMaskType =
is_mla_decode ? TrtllmGenAttentionMaskType::Dense : TrtllmGenAttentionMaskType::Causal;
runner_params.mKernelType = FmhaKernelType::Generation;
bool use_multi_block = true;
runner_params.mTileScheduler =
Expand Down
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class ArtifactPath:
When compiling new cubins for backend directories, update the corresponding path.
"""

TRTLLM_GEN_FMHA: str = "5d4df6c2647e860992d1cc57ced05204b55f3787/fmha/trtllm-gen/"
TRTLLM_GEN_FMHA: str = "158f6fa11ef139a098cfddcdddce73ca99d164ad/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"c21ddd11585c1eea5764927465d0be15dd957e45/batched_gemm-91e0ba0-da44fdf/"
)
Expand All @@ -157,7 +157,7 @@ class CheckSumHash:
"""

TRTLLM_GEN_FMHA: str = (
"681e69c9c0215b4780eaf92f897c2dc94285a9143b90f765085eba83af5afa5b"
"c2d9399b2537be785882354a4f9902ed6c03136c0ea341e201eac40c3923e1dc"
)
TRTLLM_GEN_BMM: str = (
"4a3ed9c3dc6547ea3eed01ebda75b0e4322f6c01fc40cd2a4978e4deaba2732a"
Expand Down
35 changes: 25 additions & 10 deletions include/flashinfer/trtllm/fmha/kernelParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,12 @@ struct KernelParams {
CUtensorMap tmaQ_;
// TMA descriptor for K.
CUtensorMap tmaK_;
// TMA descriptor for DSv4 sparse MLA sliding-window KV pool. Same format as tmaK_.
CUtensorMap tmaKSlidingWindowKvPool_;
// TMA descriptor for V.
CUtensorMap tmaV_;
// The descriptor for O.
CUtensorMap tmaO_;

// For FP4 KV cache, additional scaling factors are needed.
// TMA descriptor for V.
CUtensorMap tmaV_;
// TMA descriptor for output scaling factor.
CUtensorMap tmaOSf_;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new tmaOSf_ member is added to the KernelParams struct to align with the newer cubin ABI, but it is not initialized in the setKernelParams function. If the newer cubins expect a valid TMA descriptor for output scaling factors (e.g., when performing FP4 quantization on output), this will lead to undefined behavior or crashes as the descriptor will be all zeros. Please add the necessary logic in setKernelParams to build the TMA descriptor for tmaOSf_ when options.oSfPtr is provided, similar to how tmaKSf_ and tmaVSf_ are handled.

Copy link
Copy Markdown
Contributor

@PerkzZheng PerkzZheng May 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@djmmoss I am not quite sure why those changes were made. Please add me as a reviewer next time. Thanks!
And no worries about that. I will revert them in my MR.

// TMA descriptor for K scaling factor.
CUtensorMap tmaKSf_;
// TMA descriptor for V scaling factor.
Expand Down Expand Up @@ -119,9 +117,9 @@ struct KernelParams {

// The softmax stats buffer.
float2* ptrSoftmaxStats;
// The variable sparseMla topK lengths with shape of [numTokensQ].
int32_t const* ptrSparseMlaTopKLens;

// Reserved scalar ABI state expected by newer trtllm-gen cubins.
int32_t mReservedAttentionWindowState[2]{};
// The attention window size for sliding window attention.
int32_t mAttentionWindowSize;
// The batch size
Expand Down Expand Up @@ -163,6 +161,8 @@ struct KernelParams {
int32_t mNumTokensPerCtaQ;
// The number of tokens per page (used if dynamic numTokensPerPage is enabled).
int32_t mNumTokensPerPageLog2;
// The runtime K/V TMA box reshape factor selected by host descriptor setup.
int32_t mReshapeFactorKv{};
// The output scale for FP8 quantization.
float mOutputScale;
// The scaling factor for softmax (multiplied by log2 to use faster exp2).
Expand Down Expand Up @@ -440,6 +440,19 @@ struct KernelParams {
return std::make_tuple(shape, stride);
}

// Decide whether the K/V TMA box may be reshaped to merge consecutive token rows while still
// loading exactly the same elements. The reshape is only valid when the token stride, measured
// in descriptor element units, equals one descriptor head row. NHD paged-cache views do not
// satisfy this: the next contiguous row there belongs to the next head at the same token rather
// than to the next token of the same head.
//
// options: FMHA options; read for mHeadDimQk / mHeadDimV and forwarded to makeStrideKv.
// dtypeKv: data type of the K (or V) cache; DATA_TYPE_E2M1 halves both quantities compared.
// isK:     true to check the K tensor (uses mHeadDimQk), false to check V (uses mHeadDimV).
// Returns true when the scaled token stride matches the scaled head dimension.
template <class FmhaOptions>
static bool canUseTmaKvReshape(FmhaOptions const& options, Data_type dtypeKv, bool isK) {
  // First component of the K/V stride tuple, treated here as the stride between tokens.
  int32_t const tokenStride = std::get<0>(makeStrideKv(options, isK));
  // Divisor converting logical elements to descriptor elements: 2 for E2M1, otherwise 1
  // (presumably because two 4-bit values share one descriptor element -- confirm).
  int32_t const packing = (dtypeKv == DATA_TYPE_E2M1) ? 2 : 1;
  // Logical head dimension of the tensor being checked.
  int32_t const logicalHeadDim = isK ? options.mHeadDimQk : options.mHeadDimV;

  int32_t const descriptorTokenStride = tokenStride / packing;
  int32_t const descriptorHeadRow = logicalHeadDim / packing;
  return descriptorTokenStride == descriptorHeadRow;
}

// Create the TMA shape/stride for KV scaling factors (block scales for NVFP4 KV cache).
//
// Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim // 16]
Expand Down Expand Up @@ -688,7 +701,9 @@ struct KernelParams {
bool const swizzleKv{storeTransformedKvInTmem || !transformsKv};
// Whether we can reshape the TMA box for K/V to widen it to 128B.
bool const canReshapeTmaKv{isPagedKv(options.mQkvLayout) &&
options.mHeadDimQk == options.mHeadDimV && !swizzleKv};
options.mHeadDimQk == options.mHeadDimV && !swizzleKv &&
canUseTmaKvReshape(options, kernelMeta.mDataTypeK, /*isK*/ true) &&
canUseTmaKvReshape(options, kernelMeta.mDataTypeV, /*isK*/ false)};
// The reshape factor for K/V TMA box: aim for 128B box width.
// - 128 / maxHeadDimKv: keeps first-dim tile <= 128 elts (CU_TENSOR_MAP_SWIZZLE_128B limit).
// - 128 / (maxHeadDimKv * bytesPerElt): factor needed to reach 128B box width.
Expand All @@ -702,6 +717,7 @@ struct KernelParams {
static_cast<int32_t>(get_size_in_bits(kernelMeta.mDataTypeK)) / 8),
numKeysPerTile}))
: 1};
params.mReshapeFactorKv = reshapeFactorKv;
// Shape/stride for gmem tensor Kv.
auto [shapeK, strideK] =
makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeK,
Expand Down Expand Up @@ -864,7 +880,6 @@ struct KernelParams {
params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
params.mScaleSfKv = options.mScaleSfKv;
params.ptrSoftmaxStats = options.softmaxStatsPtr;
params.ptrSparseMlaTopKLens = nullptr;
// The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the
// indices.
FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,
Expand Down
Loading