From a6b90879d8523828682b04583d6a142aed0a27c2 Mon Sep 17 00:00:00 2001 From: Duncan Moss Date: Wed, 13 May 2026 22:16:16 +0000 Subject: [PATCH] Update trtllm FMHA public cubins Point FMHA at the newer public trtllm cubin publish, align the FMHA parameter ABI, and use dense mask selection for MLA decode kernels. --- csrc/trtllm_fmha_kernel_launcher.cu | 9 ++--- flashinfer/artifacts.py | 4 +-- include/flashinfer/trtllm/fmha/kernelParams.h | 35 +++++++++++++------ 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index cadb51dee3..fbb7db5d9e 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -168,9 +168,9 @@ void trtllm_paged_attention_launcher( // The sparse MLA parameters. runner_params.mSparseMla = sparse_mla_top_k > 0; runner_params.mSparseMlaTopK = sparse_mla_top_k; - TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || - (head_dim_qk == 320 && head_dim_vo == 256) || sparse_mla_top_k <= 0) - << "Only decode MLA supports sparse MLA"; + bool const is_mla_decode = + (head_dim_qk == 576 && head_dim_vo == 512) || (head_dim_qk == 320 && head_dim_vo == 256); + TVM_FFI_ICHECK(is_mla_decode || sparse_mla_top_k <= 0) << "Only decode MLA supports sparse MLA"; AlignedAllocator float_allocator(workspace_buffer, workspace_size); if (mode == TllmPagedAttentionMode::Context) { @@ -187,7 +187,8 @@ void trtllm_paged_attention_launcher( // Note that kernel names are still labeled as using a dense mask even when maskType is // specified as causal, this is expected for better performance as each CTA will only process // one tokenQ in those cases, so dense mask works the same as causal mask. - runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal; + runner_params.mMaskType = + is_mla_decode ? TrtllmGenAttentionMaskType::Dense : TrtllmGenAttentionMaskType::Causal; runner_params.mKernelType = FmhaKernelType::Generation; bool use_multi_block = true; runner_params.mTileScheduler = diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index ad1f861bd9..d43a7e6cb6 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -135,7 +135,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "5d4df6c2647e860992d1cc57ced05204b55f3787/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "158f6fa11ef139a098cfddcdddce73ca99d164ad/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "c21ddd11585c1eea5764927465d0be15dd957e45/batched_gemm-91e0ba0-da44fdf/" ) @@ -157,7 +157,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "681e69c9c0215b4780eaf92f897c2dc94285a9143b90f765085eba83af5afa5b" + "c2d9399b2537be785882354a4f9902ed6c03136c0ea341e201eac40c3923e1dc" ) TRTLLM_GEN_BMM: str = ( "4a3ed9c3dc6547ea3eed01ebda75b0e4322f6c01fc40cd2a4978e4deaba2732a" diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 0f41f6c007..7067243f29 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -44,14 +44,12 @@ struct KernelParams { CUtensorMap tmaQ_; // TMA descriptor for K. CUtensorMap tmaK_; - // TMA descriptor for DSv4 sparse MLA sliding-window KV pool. Same format as tmaK_. - CUtensorMap tmaKSlidingWindowKvPool_; - // TMA descriptor for V. - CUtensorMap tmaV_; // The descriptor for O. CUtensorMap tmaO_; - - // For FP4 KV cache, additional scaling factors are needed. + // TMA descriptor for V. + CUtensorMap tmaV_; + // TMA descriptor for output scaling factor. + CUtensorMap tmaOSf_; // TMA descriptor for K scaling factor. CUtensorMap tmaKSf_; // TMA descriptor for V scaling factor. @@ -119,9 +117,9 @@ struct KernelParams { // The softmax stats buffer. float2* ptrSoftmaxStats; - // The variable sparseMla topK lengths with shape of [numTokensQ]. - int32_t const* ptrSparseMlaTopKLens; + // Reserved scalar ABI state expected by newer trtllm-gen cubins. + int32_t mReservedAttentionWindowState[2]{}; // The attention window size for sliding window attention. int32_t mAttentionWindowSize; // The batch size @@ -163,6 +161,8 @@ struct KernelParams { int32_t mNumTokensPerCtaQ; // The number of tokens per page (used if dynamic numTokensPerPage is enabled). int32_t mNumTokensPerPageLog2; + // The runtime K/V TMA box reshape factor selected by host descriptor setup. + int32_t mReshapeFactorKv{}; // The output scale for FP8 quantization. float mOutputScale; // The scaling factor for softmax (multiplied by log2 to use faster exp2). @@ -440,6 +440,19 @@ struct KernelParams { return std::make_tuple(shape, stride); } + // Check whether reshaping the K/V TMA box can merge consecutive token rows without changing + // which elements are loaded. This requires the token stride, in descriptor element units, to be + // exactly one descriptor head row. NHD paged-cache views fail this check because the next + // contiguous row is the next head at the same token, not the next token for the same head. + template + static bool canUseTmaKvReshape(FmhaOptions const& options, Data_type dtypeKv, bool isK) { + int32_t const strideKeys = std::get<0>(makeStrideKv(options, isK)); + int32_t const headDim = isK ? options.mHeadDimQk : options.mHeadDimV; + int32_t const colIdxDivisor = dtypeKv == DATA_TYPE_E2M1 ? 2 : 1; + int32_t const physicalHeadDim = headDim / colIdxDivisor; + return strideKeys / colIdxDivisor == physicalHeadDim; + } + // Create the TMA shape/stride for KV scaling factors (block scales for NVFP4 KV cache). // // Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim // 16] @@ -688,7 +701,9 @@ struct KernelParams { bool const swizzleKv{storeTransformedKvInTmem || !transformsKv}; // Whether we can reshape the TMA box for K/V to widen it to 128B. bool const canReshapeTmaKv{isPagedKv(options.mQkvLayout) && - options.mHeadDimQk == options.mHeadDimV && !swizzleKv}; + options.mHeadDimQk == options.mHeadDimV && !swizzleKv && + canUseTmaKvReshape(options, kernelMeta.mDataTypeK, /*isK*/ true) && + canUseTmaKvReshape(options, kernelMeta.mDataTypeV, /*isK*/ false)}; // The reshape factor for K/V TMA box: aim for 128B box width. // - 128 / maxHeadDimKv: keeps first-dim tile <= 128 elts (CU_TENSOR_MAP_SWIZZLE_128B limit). // - 128 / (maxHeadDimKv * bytesPerElt): factor needed to reach 128B box width. @@ -702,6 +717,7 @@ struct KernelParams { static_cast(get_size_in_bits(kernelMeta.mDataTypeK)) / 8), numKeysPerTile})) : 1}; + params.mReshapeFactorKv = reshapeFactorKv; // Shape/stride for gmem tensor Kv. auto [shapeK, strideK] = makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeK, @@ -864,7 +880,6 @@ struct KernelParams { params.mStartTokenIdxSfO = options.mSfStartTokenIdx; params.mScaleSfKv = options.mScaleSfKv; params.ptrSoftmaxStats = options.softmaxStatsPtr; - params.ptrSparseMlaTopKLens = nullptr; // The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the // indices. FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,