From c5b673a637271dd7d258f91a9fe85da8d7905294 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Wed, 24 Dec 2025 14:31:40 +0000 Subject: [PATCH 1/6] update trtllm-gen to support groups tokens and headsQ --- csrc/trtllm_fmha_kernel_launcher.cu | 7 +- flashinfer/artifacts.py | 4 +- .../flashinfer/trtllm/fmha/fmhaKernels.cuh | 539 ++++++++++++------ .../flashinfer/trtllm/fmha/fmhaRunnerParams.h | 6 + include/flashinfer/trtllm/fmha/kernelParams.h | 78 ++- tests/attention/test_trtllm_gen_attention.py | 3 + 6 files changed, 456 insertions(+), 181 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 4daa03ef20..3d5e8956e8 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -159,8 +159,11 @@ void trtllm_paged_attention_launcher( runner_params.cumSeqLensQPtr = cum_seq_lens_q; runner_params.cumSeqLensKvPtr = cum_seq_lens_kv; } else { - // ForGen - runner_params.mMaskType = TrtllmGenAttentionMaskType::Dense; + // Generation. + // Note that kernel names are still labeled as using a dense mask even when maskType is + // specified as causal, this is expected for better performance as each CTA will only process + // one tokenQ in those cases, so dense mask works the same as causal mask. + runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal; runner_params.mKernelType = FmhaKernelType::Generation; bool use_multi_block = true; runner_params.mTileScheduler = diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 717524bc9e..f9f1f719ab 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -87,7 +87,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "9f1b6ddaa1592a8339a82fcab7d27a57eff445fd/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "81d3504ccf84d3ea0ff2ff4e2b15df2b63fb4160/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" ) @@ -107,7 +107,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "a5a60600a80076317703695f56bbef2f0a44075ef4e24d7b06ba67ff68bc9da2" + "376d4de5a1bbb2a651bfd3c11d62cd55a0fe919c4669671675fc80c9934cd845" ) TRTLLM_GEN_BMM: str = ( "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index 7fb695ed6d..9b9619af2f 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -51,6 +51,16 @@ std::string getCubin(const std::string& kernelName, const std::string& sha256); } // namespace flashinfer::trtllm_cubin_loader using flashinfer::trtllm_cubin_loader::getCubin; +// Check if two SM values are family/specific versions of the same architecture +// Returns true only if one is a family version and the other is a compatible specific version +constexpr bool isFamilySpecificSMPair(int sm1, int sm2) { + if ((sm1 == kSM_100f && (sm2 == kSM_100 || sm2 == kSM_103)) || + (sm2 == kSM_100f && (sm1 == kSM_100 || sm1 == kSM_103))) { + return true; + } + return false; +} + constexpr bool isSMCompatible(int gpuSM, int kernelSM) { if (gpuSM == kSM_103) { return kernelSM == kSM_100f || kernelSM == kSM_103; @@ -63,6 +73,24 @@ constexpr bool isSMCompatible(int gpuSM, int kernelSM) { //////////////////////////////////////////////////////////////////////////////////////////////////// class TllmGenFmhaKernel { + public: + // The parameters for launching the kernel. + // maxNumCtasQ, maxNumCtasKv, numCtasX, numCtasY, numCtasZ, clusterDimX + struct CtaLaunchParams { + // The maximum number of CTAs in Q dimension. + int mMaxNumCtasQ; + // The maximum number of CTAs in Kv dimension. + int mMaxNumCtasKv; + // The number of CTAs in X dimension. + int mNumCtasX; + // The number of CTAs in Y dimension. + int mNumCtasY; + // The number of CTAs in Z dimension. + int mNumCtasZ; + // The cluster size in the X dimension. + int mClusterDimX; + }; + public: using KernelMeta = tensorrt_llm::kernels::TllmGenFmhaKernelMetaInfo; using RunnerParams = TllmGenFmhaRunnerParams; @@ -86,7 +114,23 @@ class TllmGenFmhaKernel { kernelMeta.mDataTypeKv == mDtypeKv && kernelMeta.mDataTypeO == mDtypeOut) { // Store metadata for later use. IKL_LOG_DEBUG("Adding tllmgen attention kernel %s", kernelMeta.mFuncName); - mKernelMetaMap[hashID(kernelMeta)] = i; + // Check for hash conflicts. + uint64_t hash = hashID(kernelMeta); + if (mKernelMetaMap.find(hash) != mKernelMetaMap.end()) { + // The kernelMeta of the existing kernel. + auto const& existingKernelMeta = mKernelMeta[mKernelMetaMap.at(hash)]; + // Allow conflicts only if they are family/specific versions of the same architecture. + FLASHINFER_CHECK(isFamilySpecificSMPair(existingKernelMeta.mSM, kernelMeta.mSM), + "Hash conflicts exist between %s and %s.", existingKernelMeta.mFuncName, + kernelMeta.mFuncName); + + // Prefer specific SM version over family version (replace if existing is family). + if (existingKernelMeta.mSM == kSM_100f) { + mKernelMetaMap[hash] = i; + } + } else { + mKernelMetaMap[hash] = i; + } } } } @@ -95,8 +139,8 @@ class TllmGenFmhaKernel { inline uint64_t hashID(int qkvLayout, int maskType, int kernelType, int scheduler, int multiCtasKvMode, int headDimPerCtaV, int headDimQk, int headDimV, - int tileSizeKv, int numTokensPerPage, int maxNumHeadsQPerKvInCta, - bool reuseSmemKForV, bool uses2CtaMma, bool sparseMla) const { + int tileSizeQ, int tileSizeKv, int numTokensPerPage, bool reuseSmemKForV, + bool uses2CtaMma, bool sparseMla) const { FLASHINFER_CHECK((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && (headDimPerCtaV <= 1024) && (headDimQk <= 1024) && (headDimV <= 1024), "Expect (32 <= headDim <= 1024), got headDimPerCtaV=%d, headDimQk=%d, " @@ -105,8 +149,10 @@ class TllmGenFmhaKernel { // The numTokensPerPage must be power of 2. FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0, "The numTokensPerPage must be power of 2."); - FLASHINFER_CHECK(maxNumHeadsQPerKvInCta <= 128, - "The maxNumHeadsQPerKvInCta <= 128 is required."); + FLASHINFER_CHECK(tileSizeQ <= 128 && tileSizeKv <= 128, + "The tileSizeQ and tileSizeKv must be <= 128."); + FLASHINFER_CHECK((tileSizeQ & (tileSizeQ - 1)) == 0 && (tileSizeKv & (tileSizeKv - 1)) == 0, + "The tileSizeQ and tileSizeKv must be power of 2."); FLASHINFER_CHECK(tileSizeKv == 64 || tileSizeKv == 128, "The tileSizeKv must be 64 or 128."); // Format of the hash key: // Bit 0 - 3 : qkvLayout. @@ -119,10 +165,10 @@ class TllmGenFmhaKernel { // Bit 34 - 41: (headDimV >> 3). // Bit 42 - 43: (tileSizeKv >> 6). // Bit 44 - 48: (log2(numTokensPerPage)). - // Bit 49 - 56: maxNumHeadsQPerKvInCta. - // Bit 57 - 57: reuseSmemKForV. - // Bit 58 - 58: uses2CtaMma. - // Bit 59 - 59: sparseMla. + // Bit 49 - 52: (log2(tileSizeQ)). + // Bit 53 - 53: reuseSmemKForV. + // Bit 54 - 54: uses2CtaMma. + // Bit 55 - 55: sparseMla. return (static_cast(qkvLayout) << 0) | (static_cast(maskType) << 4) | (static_cast(kernelType) << 8) | (static_cast(scheduler) << 12) | (static_cast(multiCtasKvMode) << 16) | @@ -131,30 +177,35 @@ class TllmGenFmhaKernel { (static_cast(headDimV >> 3) << 34) | (static_cast(tileSizeKv >> 6) << 42) | (static_cast(log2(numTokensPerPage)) << 44) | - (static_cast(maxNumHeadsQPerKvInCta) << 49) | - (static_cast(reuseSmemKForV) << 57) | - (static_cast(uses2CtaMma) << 58) | (static_cast(sparseMla) << 59); + (static_cast(log2(tileSizeQ)) << 49) | + (static_cast(reuseSmemKForV) << 53) | + (static_cast(uses2CtaMma) << 54) | (static_cast(sparseMla) << 55); } uint64_t hashID(KernelMeta const& kernelMeta) const { return hashID(kernelMeta.mQkvLayout, kernelMeta.mMaskType, kernelMeta.mKernelType, kernelMeta.mTileScheduler, kernelMeta.mMultiCtasKvMode, kernelMeta.mHeadDimPerCtaV, kernelMeta.mHeadDimQk, kernelMeta.mHeadDimV, - kernelMeta.mTileSizeKv, kernelMeta.mNumTokensPerPage, - kernelMeta.mMaxNumHeadsQPerKvInCta, kernelMeta.mReuseSmemKForV, - kernelMeta.m2CtaMma, kernelMeta.mSparseMla); + kernelMeta.mTileSizeQ, kernelMeta.mTileSizeKv, kernelMeta.mNumTokensPerPage, + kernelMeta.mReuseSmemKForV, kernelMeta.m2CtaMma, kernelMeta.mSparseMla); } std::pair checkIfKernelExist(RunnerParams const& params) const { // The selectKernelParams that might be updated. SelectKernelParams selectKernelParams{params}; + // Select the kernel. + selectKernel(params, selectKernelParams); + // Hash the runner params. auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams); return std::make_pair(mKernelMetaMap.find(hashId) != mKernelMetaMap.end(), info); } + // start here void run(RunnerParams const& params) const { // The selectKernelParams that might be updated. SelectKernelParams selectKernelParams{params}; + // The parameters for launching the kernel. + CtaLaunchParams ctaLaunchParams; // The iteration index (used to detect a deadlock of selecting new kernels). int selectKernelIter = 0; // While loop. @@ -163,27 +214,15 @@ class TllmGenFmhaKernel { // might have more complicated heuristic in the future. FLASHINFER_CHECK(selectKernelIter < 8, "A deadlock is detected when selecting trtllm-gen kernels."); - auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams); - auto const findMetaIter = mKernelMetaMap.find(hashId); - - // Add debug info when kernels are not found. - FLASHINFER_CHECK(findMetaIter != mKernelMetaMap.end(), - "Trtllm-gen kernels not found: " + info); - - // auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; - auto const findFuncIter = mFunctions.find(hashId); - if (findFuncIter == mFunctions.end()) { - // Load the kernel on-demand. - loadKernel(hashId, findMetaIter->second); - } - // Retrieve the loaded kernel. - auto const& kernelInfo = mFunctions.at(hashId); - auto const& kernelMeta = mKernelMeta[kernelInfo.mMetaInfoIndex]; - CUfunction func = kernelInfo.mDeviceFunction; + + // Select the kernel. + selectKernel(params, selectKernelParams); + // Load the kernel. + auto [func, kernelMeta] = loadKernel(params, selectKernelParams); // Compute the number of CTAs in X, Y and Z dimension and the cluster size in the X dimension. - auto [maxNumCtasQ, maxNumCtasKv, numCtasX, numCtasY, numCtasZ, clusterDimX] = - computeCtaAndClusterConfig(params, kernelMeta, selectKernelParams); + computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParams); + // Need to select a new kernel if mSelectNewKernel is true. if (selectKernelParams.mSelectNewKernel) { selectKernelIter++; @@ -191,8 +230,8 @@ class TllmGenFmhaKernel { } // Prepare the kernel parameters. - auto kernelParams = - KernelParams::setKernelParams(params, kernelMeta, maxNumCtasQ, maxNumCtasKv); + auto kernelParams = KernelParams::setKernelParams( + params, kernelMeta, ctaLaunchParams.mMaxNumCtasQ, ctaLaunchParams.mMaxNumCtasKv); // Prepare kernel parameters list for cuLaunchKernelEx. void* kernelParamsList[] = {&kernelParams}; @@ -200,9 +239,9 @@ class TllmGenFmhaKernel { launch_config.blockDimX = kernelMeta.mThreadsPerCTA; launch_config.blockDimY = 1; launch_config.blockDimZ = 1; - launch_config.gridDimX = numCtasX; - launch_config.gridDimY = numCtasY; - launch_config.gridDimZ = numCtasZ; + launch_config.gridDimX = ctaLaunchParams.mNumCtasX; + launch_config.gridDimY = ctaLaunchParams.mNumCtasY; + launch_config.gridDimZ = ctaLaunchParams.mNumCtasZ; launch_config.hStream = params.stream; launch_config.sharedMemBytes = kernelMeta.mSharedMemBytes; @@ -218,24 +257,25 @@ class TllmGenFmhaKernel { params.mBatchSize, static_cast(params.mKernelType)); IKL_LOG_DEBUG( "TRTLLM-Gen launch info: numCtasX = %d, numCtasY = %d, numCtasZ = %d, clusterDimX = %d", - numCtasX, numCtasY, numCtasZ, clusterDimX); + ctaLaunchParams.mNumCtasX, ctaLaunchParams.mNumCtasY, ctaLaunchParams.mNumCtasZ, + ctaLaunchParams.mClusterDimX); CUlaunchAttribute launch_attribute[3]; launch_attribute[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - launch_attribute[0].value.clusterDim.x = clusterDimX; + launch_attribute[0].value.clusterDim.x = ctaLaunchParams.mClusterDimX; launch_attribute[0].value.clusterDim.y = 1; launch_attribute[0].value.clusterDim.z = 1; launch_attribute[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE; launch_attribute[1].value.clusterSchedulingPolicyPreference = - clusterDimX > 1 ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD - : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; + ctaLaunchParams.mClusterDimX > 1 ? CU_CLUSTER_SCHEDULING_POLICY_SPREAD + : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT; launch_attribute[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; launch_attribute[2].value.programmaticStreamSerializationAllowed = params.enable_pdl; launch_config.attrs = launch_attribute; launch_config.numAttrs = 3; // Add setting for non-portable cluster size. - if (clusterDimX > 8) { + if (ctaLaunchParams.mClusterDimX > 8) { cuErrCheck(cuFuncSetAttribute(func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1 // Enable non-portable cluster sizes )); @@ -249,7 +289,8 @@ class TllmGenFmhaKernel { int maxActiveClusters = 1; cuErrCheck(cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &launch_config)); // Use the GmemReduction instead if it needs more than one wave. - if (maxActiveClusters * clusterDimX < (numCtasX * numCtasY * numCtasZ)) { + if (maxActiveClusters * ctaLaunchParams.mClusterDimX < + (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ)) { selectKernelParams.mForceGmemReduction = true; selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReduction; // continue to select a new kernel. @@ -279,11 +320,9 @@ class TllmGenFmhaKernel { } // Compute the number of CTAs in X, Y and Z dimension and the cluster size in the X dimension. - using CtaClusterInfo = std::tuple; - - CtaClusterInfo computeCtaAndClusterConfig(RunnerParams const& params, - KernelMeta const& kernelMeta, - SelectKernelParams& selectKernelParams) const { + void computeCtaAndClusterConfig(CtaLaunchParams& ctaLaunchParams, RunnerParams const& params, + KernelMeta const& kernelMeta, + SelectKernelParams& selectKernelParams) const { bool isDsv3MinLatencyMode = params.mBatchSize == 1 && params.mMaxSeqLenQ >= 1 && params.mMaxSeqLenQ <= 16 && params.mHeadDimQk == 576 && params.mHeadDimV == 512; @@ -292,17 +331,23 @@ class TllmGenFmhaKernel { // The number of Ctas per Q sequence. int numCtasPerSeqQ = (params.mMaxSeqLenQ + kernelMeta.mStepQ - 1) / kernelMeta.mStepQ; - // Each CTA handles one tokenQ by default for spec-decoding generation kernel, which is used to - // emulate causal masking (like MTP or Eagle3). Note this will be changed later when the - // high-throughput spec-decoding generation kernels are integrated. + // The generation-phase kernels might need to group both tokensQ and headsQ into one CTA. if (params.mMaxSeqLenQ > 1 && !isContextKernel(params.mKernelType)) { - numCtasPerSeqQ = params.mMaxSeqLenQ; + // Each CTA handles one tokenQ by default for spec-decoding generation kernel. + if (!kernelMeta.mGroupsTokensHeadsQ) { + numCtasPerSeqQ = params.mMaxSeqLenQ; + } else { + // Compute numTokensPerCtaQ where each CTA must process complete numGroupedHeadsQ. + // Note that each CTA must process complete numHeadsQPerKv. + int numTokensPerCtaQ = kernelMeta.mStepQ / params.mNumHeadsQPerKv; + // Group both headsQ and tokensQ into one CTA. + numCtasPerSeqQ = flashinfer::ceil_div(params.mMaxSeqLenQ, numTokensPerCtaQ); + } } // Compute the grid dimension Y. - int numHeadsPerCta = kernelMeta.mGroupsHeadsQ - ? std::min(params.mNumHeadsQPerKv, kernelMeta.mMaxNumHeadsQPerKvInCta) - : 1; + int numHeadsPerCta = + kernelMeta.mGroupsHeadsQ ? std::min(params.mNumHeadsQPerKv, kernelMeta.mStepQ) : 1; int numCtasForAllHeadsQ = params.mNumHeadsQ / numHeadsPerCta; FLASHINFER_CHECK(numHeadsPerCta * numCtasForAllHeadsQ == params.mNumHeadsQ, "The numHeadsQ/numHeadsKv is not supported."); @@ -351,8 +396,10 @@ class TllmGenFmhaKernel { } // The maximum number Ctas per Kv sequence, which makes sure that each CtaKv has work to do. + // The factor of 2 is applied here to ensure the reduction overhead does not outweigh the + // benefits of a shorter mainloop. int const maxNumCtasPerSeqKv = - (maxAttentionWindow + kernelMeta.mStepKv - 1) / kernelMeta.mStepKv; + (maxAttentionWindow + 2 * kernelMeta.mStepKv - 1) / (2 * kernelMeta.mStepKv); // Compute numCtasPerSeqKv. numCtasPerSeqKv = std::min( maxNumCtasPerSeqKv, @@ -427,9 +474,13 @@ class TllmGenFmhaKernel { } } - // Return the number of CTAs for X, Y and Z dimension and the cluster size in the X dimension. - return std::make_tuple(numCtasPerSeqQ, numCtasPerSeqKv, numCtasX, numCtasY, numCtasZ, - clusterDimX); + // Update the parameters for launching the kernel. + ctaLaunchParams.mMaxNumCtasQ = numCtasPerSeqQ; + ctaLaunchParams.mMaxNumCtasKv = numCtasPerSeqKv; + ctaLaunchParams.mNumCtasX = numCtasX; + ctaLaunchParams.mNumCtasY = numCtasY; + ctaLaunchParams.mNumCtasZ = numCtasZ; + ctaLaunchParams.mClusterDimX = clusterDimX; } // Determine if we should use the SwapsMmaAbForGeneration kernel for MLA generation. @@ -447,7 +498,7 @@ class TllmGenFmhaKernel { ; // The number of Ctas. int const numCtas = static_cast(params.mBatchSize * params.mMaxSeqLenQ * - divUp(params.mNumHeadsQPerKv, 16)); + flashinfer::ceil_div(params.mNumHeadsQPerKv, 16)); // Compute numCtasPerSeqKv. int const numCtasPerSeqKv = std::min(maxNumCtasPerSeqKv, std::max(1, int32_t(params.mMultiProcessorCount / numCtas))); @@ -457,71 +508,203 @@ class TllmGenFmhaKernel { return seqLenPerCtaKv <= 1024 && numCtas <= params.mMultiProcessorCount; } - std::pair hashFromRunnerParams( - RunnerParams const& params, SelectKernelParams& selectKernelParams) const { - // The updated kernel type. + // Select the MLA generation kernel. + void selectMlaGenerationKernel(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the + // following conditions are met: + // 1. The number of headsQPerKv is <= 32. + // 2. The number of headsQPerKv is < 128 for sparseMla. + // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned later) + // and + // the numCtas (after splitting the heads across multiple CTAs) <= + // params.mMultiProcessorCount. + // The sparseMla kernel will always use the 2CTA high-throughput kernel. + + // The kernel type. FmhaKernelType& kernelType = selectKernelParams.mKernelType; - // Generation kernelType will use either SwapsMmaAbForGeneration or KeepsMmaAbForGeneration. - if (isGenerationKernel(params.mKernelType) && isMlaGenKernel(params)) { - // We use the low-latency kernel (SwapsMmaAbForGeneration with tileSizeQ = 16) when any of the - // following conditions are met: - // 1. The number of headsQPerKv is <= 32. - // 2. The number of headsQPerKv is < 128 for sparseMla. - // 3. The seqLenPerCtaKv <= 1024 based on the benchmark results (this might be fine-tuned - // later) and - // the numCtas (after splitting the heads across multiple CTAs) <= - // params.mMultiProcessorCount. - - // Check the conditions. - if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) || - useSwapsMmaAbMlaGenKernel(params)) { - kernelType = FmhaKernelType::SwapsMmaAbForGeneration; - } else { - // Otherwise, we use the high-throughput kernel. - kernelType = FmhaKernelType::KeepsMmaAbForGeneration; - // Always use the separate reduction kernel. - if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { - selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; - } - // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128. - FLASHINFER_CHECK( - !params.mSparseMla || params.mNumHeadsQPerKv == 128, - "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128"); - // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128. - if (params.mNumHeadsQPerKv == 128) { - selectKernelParams.mUses2CtaMma = true; - // Each Cta only handles 256 headDimV. - selectKernelParams.mHeadDimPerCtaV = 256; - } + // The tile size for Q. + int& tileSizeQ = selectKernelParams.mTileSizeQ; + + // Check the conditions. + if (params.mNumHeadsQPerKv <= 32 || (params.mSparseMla && params.mNumHeadsQPerKv < 128) || + useSwapsMmaAbMlaGenKernel(params)) { + kernelType = FmhaKernelType::SwapsMmaAbForGeneration; + // Currently, only tileSizeQ = 8 or 16 are supported. + tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16; + } else { + // Otherwise, we use the high-throughput kernel. + kernelType = FmhaKernelType::KeepsMmaAbForGeneration; + // Use the tileSizeQ = 64 for MLA high-throughput generation kernels. + tileSizeQ = 64; + // Always use the separate reduction kernel. + if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) { + selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel; } - } else if (isGenerationKernel(params.mKernelType)) { - kernelType = (params.mNumHeadsQPerKv <= 16 && params.mHeadDimQk != 32) - ? FmhaKernelType::SwapsMmaAbForGeneration - : FmhaKernelType::KeepsMmaAbForGeneration; + // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128. + FLASHINFER_CHECK( + !params.mSparseMla || params.mNumHeadsQPerKv == 128, + "The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128, got %d", + params.mNumHeadsQPerKv); + // The 2CTA keepsMmaAbForGeneration kernel is used when the numHeadsQPerKv is 128. + if (params.mNumHeadsQPerKv == 128) { + selectKernelParams.mUses2CtaMma = true; + // Each Cta only handles 256 headDimV. + selectKernelParams.mHeadDimPerCtaV = 256; + } + } + } + + // Selects a heuristic tileSizeQ if groupsTokensHeadsQ is true. + void selectTileSizeQForGqaGeneration(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // Define the per-tile mainloop cost model for different tileSizeQ choices. + std::unordered_map kernelMainloopCost = { + {128, 2.2}, // Cost factor when tileSizeQ = 128 + {64, 1.68}, // Cost factor when tileSizeQ = 64 + {32, 1.48}, // Cost factor when tileSizeQ = 32 + {16, 1.2}, // Cost factor when tileSizeQ = 16 + {8, 1.0} // Cost factor when tileSizeQ = 8 + }; + + // Define the per-tile reduction cost model for different tileSizeQ choices. + std::unordered_map kernelReductionCost = { + {128, 1.32}, // Reduction cost factor when tileSizeQ = 128 + {64, 1.2}, // Reduction cost factor when tileSizeQ = 64 + {32, 1.08}, // Reduction cost factor when tileSizeQ = 32 + {16, 1.03}, // Reduction cost factor when tileSizeQ = 16 + {8, 1.0} // Reduction cost factor when tileSizeQ = 8 + }; + + // The reduction cost emulated as a sequence length factor. + float const kernelReductionSeqLenFactor = 128.0f; + + // The parameters for launching the kernel. + CtaLaunchParams ctaLaunchParams; + // The copy of the selectKernelParams, which makes sure it won't modify the original + // selectKernelParams when computing the number of CTAs. + SelectKernelParams selectKernelParamsCopy = selectKernelParams; + // Load the kernel. + auto [func, kernelMeta] = loadKernel(params, selectKernelParamsCopy); + // Compute numCtasX, numCtasY and numCtasZ. + computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy); + + // If there are no free SMs or tileSizeQ is already the smallest one, skip the heuristic + // selection. + if (ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ * 2 > + params.mMultiProcessorCount || + selectKernelParamsCopy.mTileSizeQ <= 8) { + // No need to select the kernel further. + return; + } + + // Candidate tile sizes for tileSizeQ to explore. + int const candidateTileSizesQ[] = {128, 64, 32, 16, 8}; + + // The default tileSizeQ. + int defaultTileSizeQ = selectKernelParamsCopy.mTileSizeQ; + // The selected tileSizeQ. + int selectedTileSizeQ = selectKernelParamsCopy.mTileSizeQ; + + // The minimum modeling kernel time. + float globalModelingKernelTime = FLT_MAX; + // Loop over each candidate tile size. + for (int tileSizeQ : candidateTileSizesQ) { + // Only consider candidates <= default tileSizeQ. + if (tileSizeQ > defaultTileSizeQ) { + continue; + } + + // Compute the number of CTAs. + computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy); + + // Compute the seqLenPerCtaKv. + int32_t seqLenPerCtaKv = + flashinfer::ceil_div(flashinfer::ceil_div(params.mMaxSeqLenKv, kernelMeta.mStepKv), + ctaLaunchParams.mMaxNumCtasKv) * + kernelMeta.mStepKv; + + // Compute the modeling kernel time = mainloop cost + reduction cost. + float modelingKernelTime = kernelMainloopCost[tileSizeQ] * seqLenPerCtaKv + + kernelReductionCost[tileSizeQ] * kernelReductionSeqLenFactor * + ctaLaunchParams.mMaxNumCtasKv; + + // Compute the total number of CTAs. + int32_t numCtas = + ctaLaunchParams.mNumCtasX * ctaLaunchParams.mNumCtasY * ctaLaunchParams.mNumCtasZ; + // Compute the number of waves. + int32_t numWaves = flashinfer::ceil_div(numCtas, params.mMultiProcessorCount); + // Compute the total modeling kernel time. + modelingKernelTime *= numWaves; + + // If this candidate has a lower time than the global minimum, update the global minimum. + if (modelingKernelTime < globalModelingKernelTime) { + globalModelingKernelTime = modelingKernelTime; + selectedTileSizeQ = tileSizeQ; + } + } + + // Update the tileSizeQ. + selectKernelParams.mTileSizeQ = selectedTileSizeQ; + // Update the kernel type. + if (selectKernelParams.mTileSizeQ >= 64) { + selectKernelParams.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration; + } else { + selectKernelParams.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration; } + } - // The maximum number of headsQPerKv that the kernel can support in one Cta. - int maxNumHeadsQPerKvInCta = 1; - if (isSwapsMmaAbForGenerationKernel(kernelType)) { - // Set the corresponding maxNumHeadsQPerKvInCta (tileSizeQ) for low-latency generation - // kernels. - maxNumHeadsQPerKvInCta = (params.mNumHeadsQPerKv <= 8) ? 8 : 16; - FLASHINFER_CHECK((maxNumHeadsQPerKvInCta == 8 || maxNumHeadsQPerKvInCta == 16) && - (params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || - params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), - "Not supported"); - } else if (isKeepsMmaAbForGenerationKernel(kernelType)) { - // Use the maxNumHeadsQPerKvInCta (tileSizeQ) = 64 for MLA high-throughput generation kernels. - maxNumHeadsQPerKvInCta = isMlaGenKernel(params) ? 64 : 32; - FLASHINFER_CHECK((params.mNumHeadsQPerKv < maxNumHeadsQPerKvInCta || - params.mNumHeadsQPerKv % maxNumHeadsQPerKvInCta == 0), - "Not supported"); - } else if (isContextKernel(kernelType)) { - FLASHINFER_CHECK(maxNumHeadsQPerKvInCta == 1, "Not supported"); + // Selects a heuristic kernel for GQA generation. + void selectGqGenerationKernel(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // The kernel type. + FmhaKernelType& kernelType = selectKernelParams.mKernelType; + // The tile size for Q. + int& tileSizeQ = selectKernelParams.mTileSizeQ; + + // Mixed precision kernels don't work with groupsTokensHeadsQ = true for now. + if (mDtypeQ != mDtypeKv || mDtypeOut == DATA_TYPE_E2M1) { + tileSizeQ = params.mNumHeadsQPerKv <= 8 ? 8 : 16; + kernelType = FmhaKernelType::SwapsMmaAbForGeneration; + return; + } + + // The number of tokensQ and headsQ that can be grouped into one CTA. + int numTokensHeadsQ = params.mNumHeadsQPerKv * params.mMaxSeqLenQ; + // When numHeadsQPerKv >= 64, use KeepsMmaAbForGeneration kernel. + if (numTokensHeadsQ <= 8) { + tileSizeQ = 8; + kernelType = FmhaKernelType::SwapsMmaAbForGeneration; + } else if (numTokensHeadsQ <= 16) { + tileSizeQ = 16; + kernelType = FmhaKernelType::SwapsMmaAbForGeneration; + } else if (numTokensHeadsQ <= 32) { + tileSizeQ = 32; + kernelType = FmhaKernelType::SwapsMmaAbForGeneration; + } else if (numTokensHeadsQ <= 64) { + tileSizeQ = 64; + kernelType = FmhaKernelType::KeepsMmaAbForGeneration; + } else { + tileSizeQ = 128; + kernelType = FmhaKernelType::KeepsMmaAbForGeneration; + } + + // When maxSeqLenQ > 1, use an experimental kernel-timing model to select the best kernel that + // groups both tokensQ and headsQ into one CTA. + if (params.mMaxSeqLenQ > 1) { + selectTileSizeQForGqaGeneration(params, selectKernelParams); + } + } + + // Select a kernel based on the heuristic. + void selectKernel(RunnerParams const& params, SelectKernelParams& selectKernelParams) const { + // Select the kernel based on the kernel type. + if (isGenerationKernel(params.mKernelType) && isMlaGenKernel(params)) { + selectMlaGenerationKernel(params, selectKernelParams); + } else if (isGenerationKernel(params.mKernelType)) { + selectGqGenerationKernel(params, selectKernelParams); } - // The mask type. - selectKernelParams.mMaskType = params.mMaskType; // Enable sliding window or chunked causal if the max kv sequence length exceeds attention // window size or chunked attention size. This is supported by causal-mask context kernels and // generation-phase kernels. @@ -536,30 +719,31 @@ class TllmGenFmhaKernel { selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal; } - // The number of tokens per page. - int numTokensPerPage = params.mNumTokensPerPage; // SparseMla kernels use a fixed numTokensPerPage = 1. if (params.mSparseMla) { - numTokensPerPage = 1; + selectKernelParams.mNumTokensPerPage = 1; } else if (!isPagedKv(params.mQkvLayout)) { // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels. - numTokensPerPage = 0; + selectKernelParams.mNumTokensPerPage = 0; } + } + std::pair hashFromRunnerParams( + RunnerParams const& params, SelectKernelParams& selectKernelParams) const { // Debug info. std::string info = "qkvLayout=" + std::to_string(static_cast(params.mQkvLayout)) + ", maskType=" + std::to_string(static_cast(selectKernelParams.mMaskType)) + - ", kernelType=" + std::to_string(static_cast(kernelType)) + + ", kernelType=" + std::to_string(static_cast(selectKernelParams.mKernelType)) + ", tileScheduler=" + std::to_string(static_cast(selectKernelParams.mTileScheduler)) + ", multiCtasKvMode=" + std::to_string(static_cast(selectKernelParams.mMultiCtasKvMode)) + ", headDimPerCtaV=" + std::to_string(selectKernelParams.mHeadDimPerCtaV) + ", headDimQk=" + std::to_string(params.mHeadDimQk) + ", headDimV=" + std::to_string(params.mHeadDimV) + + ", tileSizeQ=" + std::to_string(selectKernelParams.mTileSizeQ) + ", tileSizeKv=" + std::to_string(selectKernelParams.mTileSizeKv) + - ", numTokensPerPage=" + std::to_string(numTokensPerPage) + - ", maxNumHeadsQPerKvInCta=" + std::to_string(maxNumHeadsQPerKvInCta) + + ", numTokensPerPage=" + std::to_string(selectKernelParams.mNumTokensPerPage) + ", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) + ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) + ", sparseMla=" + std::to_string(params.mSparseMla); @@ -570,55 +754,78 @@ class TllmGenFmhaKernel { return std::make_pair( hashID(static_cast(params.mQkvLayout), static_cast(selectKernelParams.mMaskType), - static_cast(kernelType), static_cast(selectKernelParams.mTileScheduler), + static_cast(selectKernelParams.mKernelType), + static_cast(selectKernelParams.mTileScheduler), static_cast(selectKernelParams.mMultiCtasKvMode), selectKernelParams.mHeadDimPerCtaV, params.mHeadDimQk, params.mHeadDimV, - selectKernelParams.mTileSizeKv, numTokensPerPage, maxNumHeadsQPerKvInCta, - selectKernelParams.mReuseSmemKForV, selectKernelParams.mUses2CtaMma, - params.mSparseMla), + selectKernelParams.mTileSizeQ, selectKernelParams.mTileSizeKv, + selectKernelParams.mNumTokensPerPage, selectKernelParams.mReuseSmemKForV, + selectKernelParams.mUses2CtaMma, params.mSparseMla), info); } // Load a single kernel (called by `run()` when needed). - void loadKernel(uint64_t hashId, unsigned int metaIndex) const { - auto const& kernelMeta = mKernelMeta[metaIndex]; - CUmodule hmod{0}; - std::string kernelName(kernelMeta.mFuncName); - - // Check if the module is already loaded. - auto findModuleIter = mModules.find(kernelMeta.mFuncName); - auto capitalizeFirst = [](std::string str) { - if (!str.empty()) { - str[0] = std::toupper(str[0]); - } - return str; - }; - if (findModuleIter == mModules.end()) { - // Load the module. - std::string cubin_path = tllm_gen_fmha_cubin_path + "/" + kernelMeta.mFuncName + ".cubin"; - std::string cubin = getCubin(cubin_path, kernelMeta.sha256); - if (cubin.empty()) { - throw std::runtime_error("Failed to load cubin for " + kernelName); + std::pair loadKernel(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + // Hash the runner params. + auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams); + auto const findMetaIter = mKernelMetaMap.find(hashId); + // The meta index. + auto const metaIndex = findMetaIter->second; + + // Add debug info when kernels are not found. + FLASHINFER_CHECK(findMetaIter != mKernelMetaMap.end(), "Trtllm-gen kernels not found: " + info); + + // Load the function if not found. + if (mFunctions.find(hashId) == mFunctions.end()) { + // Load the kernel on-demand. + auto const& kernelMeta = mKernelMeta[metaIndex]; + CUmodule hmod{0}; + std::string kernelName(kernelMeta.mFuncName); + + // Check if the module is already loaded. + auto findModuleIter = mModules.find(kernelMeta.mFuncName); + auto capitalizeFirst = [](std::string str) { + if (!str.empty()) { + str[0] = std::toupper(str[0]); + } + return str; + }; + if (findModuleIter == mModules.end()) { + // Load the module. + std::string cubin_path = tllm_gen_fmha_cubin_path + "/" + kernelMeta.mFuncName + ".cubin"; + std::string cubin = getCubin(cubin_path, kernelMeta.sha256); + if (cubin.empty()) { + throw std::runtime_error("Failed to load cubin for " + kernelName); + } + cuErrCheck(cuModuleLoadData(&hmod, cubin.data())); + mModules[kernelName] = hmod; + } else { + hmod = findModuleIter->second; } - cuErrCheck(cuModuleLoadData(&hmod, cubin.data())); - mModules[kernelName] = hmod; - } else { - hmod = findModuleIter->second; - } - // Load the function. - KernelInfo funcInfo; - funcInfo.mMetaInfoIndex = metaIndex; - cuErrCheck(cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName)); + // Load the function. + KernelInfo funcInfo; + funcInfo.mMetaInfoIndex = metaIndex; + cuErrCheck(cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName)); + + if (kernelMeta.mSharedMemBytes >= 48 * 1024) { + cuErrCheck(cuFuncSetAttribute(funcInfo.mDeviceFunction, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + kernelMeta.mSharedMemBytes)); + } - if (kernelMeta.mSharedMemBytes >= 48 * 1024) { - cuErrCheck(cuFuncSetAttribute(funcInfo.mDeviceFunction, - CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - kernelMeta.mSharedMemBytes)); + // Cache the loaded function. + mFunctions[hashId] = funcInfo; } - // Cache the loaded function. - mFunctions[hashId] = funcInfo; + // Retrieve the loaded kernel. + auto const& kernelInfo = mFunctions.at(hashId); + auto const& kernelMeta = mKernelMeta[kernelInfo.mMetaInfoIndex]; + CUfunction func = kernelInfo.mDeviceFunction; + + // Return the function and kernelMeta. + return std::make_pair(func, kernelMeta); } Data_type mDtypeQ, mDtypeKv, mDtypeOut; diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index 693240306d..576fae4d4d 100755 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -348,12 +348,16 @@ struct TllmGenSelectKernelParams { bool mForceGmemReduction; // The mask type. TrtllmGenAttentionMaskType mMaskType; + // The number of tokens per page. + int mNumTokensPerPage; // Reuse smemK for V or not (only work with MLA generation kernels). bool mReuseSmemKForV; // Do we need to select a new kernel as the parameters have been updated. bool mSelectNewKernel; // The tile scheduler. TileScheduler mTileScheduler; + // The tile size for Q. + int mTileSizeQ; // The tile size for Kv. int mTileSizeKv; // Use 2 CTA MMA or not. @@ -369,9 +373,11 @@ struct TllmGenSelectKernelParams { : MultiCtasKvMode::Disabled), mForceGmemReduction(false), mMaskType(params.mMaskType), + mNumTokensPerPage(params.mNumTokensPerPage), mReuseSmemKForV(false), mSelectNewKernel(false), mTileScheduler(params.mTileScheduler), + mTileSizeQ(128), mTileSizeKv(128), mUses2CtaMma(false) {}; }; diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index fb8c4c1482..9b48d15d14 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -31,6 +31,41 @@ #include "../common.h" #include "fmhaRunnerParams.h" +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// CCCL >= 3.1.0 (CUDA CTK 13.1) introduces the fast_mod_div math operations. +// The following code makes sure that the host initialization works with older CUDA CTK versions. +// + +// Refer to +// https://github.com/NVIDIA/cccl/blob/main/libcudacxx/include/cuda/__cmath/fast_modulo_division.h#L76-L81 +// about how to compute the fast modulo division. +struct FastModDivInt32 { + public: + FastModDivInt32(int32_t divisor) : mDivisor(divisor) { + mShift = ceilLog2(mDivisor) - 1; + mMultiplier = static_cast( + ceilDiv(uint64_t(1) << (32 + mShift), static_cast(mDivisor))); + } + + private: + template + T ceilDiv(T a, T b) { + return (a + b - 1) / b; + } + + int32_t ceilLog2(int32_t value) const { + return static_cast(std::ceil(std::log2(value))); + } + + private: + int32_t mDivisor = 1; + uint32_t mMultiplier = 0; + uint32_t mAdd = 0; + int32_t mShift = 0; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// using Dtype = Data_type; @@ -115,6 +150,9 @@ struct KernelParams { int32_t mBatchSize; // The chunked attention size in log2. int32_t mChunkedAttentionSizeLog2; + // The factor to add to the maximum value to increase the probability + // of skip correction during next iterations. + float mInflateMax; // The log of the Sage Attention block size for K. int32_t mLogNumEltsPerSageAttnBlkK; // The log of the Sage Attention block size for P. @@ -137,10 +175,14 @@ struct KernelParams { int32_t mNumHeadsQ; // The number of Q heads per K/V head (i.e. mNumHeadsQ / mNumHeadsKv). int32_t mNumHeadsQPerKv; + // The number of headsQ per K/V head as a fast_mod_div divisor. + FastModDivInt32 mNumHeadsQPerKvDivisor{1}; // The hidden size of O. int64_t mNumHiddenEltsO; // The total number of pages in the paged-kv memory pool. int32_t mNumPagesInMemPool; + // The number of tokensQ per CTA (used for groupsHeadsTokensQ generation kernel). + int32_t mNumTokensPerCtaQ; // The number of tokens per page (used if dynamic numTokensPerPage is enabled). int32_t mNumTokensPerPageLog2; // The output scale for FP8 quantization. @@ -165,7 +207,8 @@ struct KernelParams { // Create the TMA shape/stride for Q. template - static auto makeTmaShapeStrideQ(FmhaOptions const& options, bool groupsHeadsQ, int32_t tileSizeQ, + static auto makeTmaShapeStrideQ(FmhaOptions const& options, bool groupsHeadsQ, + bool groupsTokensHeadsQ, int32_t tileSizeQ, int32_t numEltsInClampedHeadDimQ) { // // The Q has shape of [numTokens * numHeadsQPerKv, numHeadsKv * 1, headDim] @@ -236,19 +279,26 @@ struct KernelParams { // The tile shape for TMA. auto tileShapes = std::vector{static_cast(numEltsInClampedHeadDimQ), 1, 1, static_cast(tileSizeQ)}; + // The number of tokensQ per CTA. + int32_t numTokensPerCtaQ{tileSizeQ}; + // Re-compute the number of tokensQ per CTA if groupsHeadsQ is enabled. if (groupsHeadsQ) { - if (isSpecDecodingGenerationKernel(options.mKernelType)) { - FLASHINFER_CHECK((tileSizeQ % numGroupedHeads == 0), "internal error"); - tileShapes = std::vector{static_cast(numEltsInClampedHeadDimQ), - static_cast(numGroupedHeads), 1, - static_cast(tileSizeQ / numGroupedHeads)}; + if (groupsTokensHeadsQ) { + // Currently, it requires each CTA to process complete headsQ (i.e. numGroupedHeads) at a + // time, so it allows paddings in the end. Removing paddings needs re-organizing the Q + // tensor to [numTokensQ, numGroupedHeads, numHeads, headDimQ] and we might want to revisit + // this in the future. + numTokensPerCtaQ = static_cast(numTokensPerCtaQ / numGroupedHeads); } else { - tileShapes = std::vector{static_cast(numEltsInClampedHeadDimQ), - static_cast(tileSizeQ), 1, 1}; + numGroupedHeads = tileSizeQ; + numTokensPerCtaQ = 1; } + tileShapes = std::vector{static_cast(numEltsInClampedHeadDimQ), + static_cast(numGroupedHeads), 1, + static_cast(numTokensPerCtaQ)}; } - return std::make_tuple(shape, stride, tileShapes); + return std::make_tuple(shape, stride, tileShapes, numTokensPerCtaQ); } // Create the TMA shape/stride for O. @@ -541,6 +591,9 @@ struct KernelParams { if (result != CUDA_SUCCESS) { char const* err_str; cuGetErrorString(result, &err_str); + // Note that the error is thrown out before launching fmha kernels, so it is highly possible + // that the errors are broadcasted by previous kernels. Please enable CUDA_LAUNCH_BLOCKING or + // use cuda-gdb for more details. std::cerr << "Error: Failed to initialize the TMA descriptor due to " << err_str << std::endl; std::cerr << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; @@ -588,8 +641,9 @@ struct KernelParams { int32_t numEltsInClampedHeadDimQ = std::min(numEltsIn128BQ, options.mHeadDimQk); // Shape/stride for gmem tensor Q. - auto [shapeQ, strideQ, tileShapeQ] = makeTmaShapeStrideQ( - options, kernelMeta.mGroupsHeadsQ, kernelMeta.mTileSizeQ, numEltsInClampedHeadDimQ); + auto [shapeQ, strideQ, tileShapeQ, numTokensPerCtaQ] = + makeTmaShapeStrideQ(options, kernelMeta.mGroupsHeadsQ, kernelMeta.mGroupsTokensHeadsQ, + kernelMeta.mTileSizeQ, numEltsInClampedHeadDimQ); // Build tma descriptor for Q. params.tmaQ_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeQ, shapeQ, strideQ, tileShapeQ, const_cast(qPtr)); @@ -744,7 +798,9 @@ struct KernelParams { params.mNumHeadsQ = options.mNumHeadsQ; params.mNumHeadsKv = options.mNumHeadsKv; params.mNumHeadsQPerKv = options.mNumHeadsQPerKv; + params.mNumHeadsQPerKvDivisor = FastModDivInt32(options.mNumHeadsQPerKv); params.mNumHiddenEltsO = options.mNumHeadsQ * options.mHeadDimQk; + params.mNumTokensPerCtaQ = numTokensPerCtaQ; params.mOutputScale = options.outputScale; params.mScaleSoftmaxLog2 = options.scaleSoftmaxLog2; params.mStartTokenIdxSfO = options.mSfStartTokenIdx; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 39b46423e0..c6e8920c26 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -1279,6 +1279,9 @@ def test_trtllm_batch_decode_head_dim_256( (1, 1, 32, 2, 5), (1, 3, 64, 2, 1), (1, 4, 64, 4, 1), + (32, 4, 16, 2, 8), + (32, 8, 16, 2, 8), + (32, 16, 16, 2, 8), ], ) @pytest.mark.parametrize("window_left", [-1]) From e7f9ba05557ddb175d5da39fa08c799444731306 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Wed, 24 Dec 2025 15:06:23 +0000 Subject: [PATCH 2/6] address comments --- include/flashinfer/trtllm/fmha/fmhaKernels.cuh | 14 ++++++++++++-- include/flashinfer/trtllm/fmha/kernelParams.h | 7 +------ 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index 9b9619af2f..d53e318d91 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -615,6 +615,16 @@ class TllmGenFmhaKernel { continue; } + selectKernelParamsCopy.mTileSizeQ = tileSizeQ; + if (tileSizeQ >= 64) { + selectKernelParamsCopy.mKernelType = FmhaKernelType::KeepsMmaAbForGeneration; + } else { + selectKernelParamsCopy.mKernelType = FmhaKernelType::SwapsMmaAbForGeneration; + } + + // Load the kernel. + std::tie(func, kernelMeta) = loadKernel(params, selectKernelParamsCopy); + // Compute the number of CTAs. computeCtaAndClusterConfig(ctaLaunchParams, params, kernelMeta, selectKernelParamsCopy); @@ -625,8 +635,8 @@ class TllmGenFmhaKernel { kernelMeta.mStepKv; // Compute the modeling kernel time = mainloop cost + reduction cost. - float modelingKernelTime = kernelMainloopCost[tileSizeQ] * seqLenPerCtaKv + - kernelReductionCost[tileSizeQ] * kernelReductionSeqLenFactor * + float modelingKernelTime = kernelMainloopCost.at(tileSizeQ) * seqLenPerCtaKv + + kernelReductionCost.at(tileSizeQ) * kernelReductionSeqLenFactor * ctaLaunchParams.mMaxNumCtasKv; // Compute the total number of CTAs. diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 9b48d15d14..ccb339aba6 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -46,15 +46,10 @@ struct FastModDivInt32 { FastModDivInt32(int32_t divisor) : mDivisor(divisor) { mShift = ceilLog2(mDivisor) - 1; mMultiplier = static_cast( - ceilDiv(uint64_t(1) << (32 + mShift), static_cast(mDivisor))); + flashinfer::ceil_div(uint64_t(1) << (32 + mShift), static_cast(mDivisor))); } private: - template - T ceilDiv(T a, T b) { - return (a + b - 1) / b; - } - int32_t ceilLog2(int32_t value) const { return static_cast(std::ceil(std::log2(value))); } From e2734cd79cada16216a7a45191ccd1ff5c9c7ab1 Mon Sep 17 00:00:00 2001 From: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com> Date: Wed, 24 Dec 2025 15:18:51 +0000 Subject: [PATCH 3/6] small fix --- include/flashinfer/trtllm/fmha/kernelParams.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index ccb339aba6..295647245a 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -44,7 +44,7 @@ struct FastModDivInt32 { public: FastModDivInt32(int32_t divisor) : mDivisor(divisor) { - mShift = ceilLog2(mDivisor) - 1; + mShift = std::max(ceilLog2(mDivisor) - 1, 0); mMultiplier = static_cast( flashinfer::ceil_div(uint64_t(1) << (32 + mShift), static_cast(mDivisor))); } From b0a6e2c466aca843a9a30f4b19c38c6e5b7624b2 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Wed, 7 Jan 2026 00:30:53 -0800 Subject: [PATCH 4/6] fix --- flashinfer/artifacts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index f9f1f719ab..4cda56e438 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -87,7 +87,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "81d3504ccf84d3ea0ff2ff4e2b15df2b63fb4160/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "75d477a640f268ea9ad117cc596eb39245713b9e/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841" ) @@ -107,7 +107,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "376d4de5a1bbb2a651bfd3c11d62cd55a0fe919c4669671675fc80c9934cd845" + "7406291bfe582f1bfa4f9c02f13960187fba0d4915dacd2c38c9e51dda7bf9b0" ) TRTLLM_GEN_BMM: str = ( "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" From 45075988583d4b3fcfc97477bb8573b6f0456daf Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 7 Jan 2026 03:33:00 -0500 Subject: [PATCH 5/6] fix --- flashinfer/artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 4cda56e438..292a35e3b2 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -107,7 +107,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "7406291bfe582f1bfa4f9c02f13960187fba0d4915dacd2c38c9e51dda7bf9b0" + "e014d7a54c396733ef012b223603c1be2861019f88faa5dcc882ed1ecfe5c2d9" ) TRTLLM_GEN_BMM: str = ( "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc" From 78448ee138312a65c04c64655593288246ab9fac Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 7 Jan 2026 12:09:01 -0500 Subject: [PATCH 6/6] add missing header --- include/flashinfer/trtllm/fmha/fmhaKernels.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index d53e318d91..ca71bb0b88 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -19,6 +19,7 @@ #include #include +#include #include #include #include