diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 57a6c5df97..ad1f861bd9 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -135,7 +135,7 @@ class ArtifactPath: When compiling new cubins for backend directories, update the corresponding path. """ - TRTLLM_GEN_FMHA: str = "1d876ee612888821b168c25ffa75a9dcbb963aaa/fmha/trtllm-gen/" + TRTLLM_GEN_FMHA: str = "5d4df6c2647e860992d1cc57ced05204b55f3787/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( "c21ddd11585c1eea5764927465d0be15dd957e45/batched_gemm-91e0ba0-da44fdf/" ) @@ -157,7 +157,7 @@ class CheckSumHash: """ TRTLLM_GEN_FMHA: str = ( - "1abeea012a8779c6df5b84332fad43c6cfc3b257fe5ab883c8ea501464010d16" + "681e69c9c0215b4780eaf92f897c2dc94285a9143b90f765085eba83af5afa5b" ) TRTLLM_GEN_BMM: str = ( "4a3ed9c3dc6547ea3eed01ebda75b0e4322f6c01fc40cd2a4978e4deaba2732a" diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index 0920390bcd..a84a9dde30 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -76,6 +76,9 @@ constexpr bool isSMCompatible(int gpuSM, int kernelSM) { //////////////////////////////////////////////////////////////////////////////////////////////////// class TllmGenFmhaKernel { public: + static constexpr int kDynamicNumTokensPerPageThreshold = 128; + static constexpr int kDynamicNumTokensPerPageKernelKey = 128; + // The parameters for launching the kernel. // maxNumCtasQ, maxNumCtasKv, numCtasX, numCtasY, numCtasZ, clusterDimX struct CtaLaunchParams { @@ -160,9 +163,9 @@ class TllmGenFmhaKernel { "Expect (32 <= headDim <= 1024), got headDimPerCtaV=%d, headDimQk=%d, " "headDimV=%d", headDimPerCtaV, headDimQk, headDimV); - // The numTokensPerPage must be power of 2. - FLASHINFER_CHECK((numTokensPerPage & (numTokensPerPage - 1)) == 0, - "The numTokensPerPage must be power of 2."); + // The numTokensPerPage must be 0 (unused for non-paged kernels) or power of 2. + FLASHINFER_CHECK(numTokensPerPage == 0 || ((numTokensPerPage & (numTokensPerPage - 1)) == 0), + "The numTokensPerPage must be 0 or power of 2."); FLASHINFER_CHECK(tileSizeQ <= 128 && tileSizeKv <= 128, "The tileSizeQ and tileSizeKv must be <= 128."); FLASHINFER_CHECK((tileSizeQ & (tileSizeQ - 1)) == 0 && (tileSizeKv & (tileSizeKv - 1)) == 0, @@ -184,14 +187,15 @@ class TllmGenFmhaKernel { // Bit 54 - 54: uses2CtaMma. // Bit 55 - 55: sparseMla. // Bit 56 - 56: skipsSoftmax. + uint64_t const numTokensPerPageLog2 = + numTokensPerPage == 0 ? 0 : static_cast(log2(numTokensPerPage)); return (static_cast(qkvLayout) << 0) | (static_cast(maskType) << 4) | (static_cast(kernelType) << 8) | (static_cast(scheduler) << 12) | (static_cast(multiCtasKvMode) << 16) | (static_cast(headDimPerCtaV >> 3) << 18) | (static_cast(headDimQk >> 3) << 26) | (static_cast(headDimV >> 3) << 34) | - (static_cast(tileSizeKv >> 6) << 42) | - (static_cast(log2(numTokensPerPage)) << 44) | + (static_cast(tileSizeKv >> 6) << 42) | (numTokensPerPageLog2 << 44) | (static_cast(log2(tileSizeQ)) << 49) | (static_cast(reuseSmemKForV) << 53) | (static_cast(uses2CtaMma) << 54) | (static_cast(sparseMla) << 55) | @@ -362,7 +366,34 @@ class TllmGenFmhaKernel { // Is it MLA generation kernel ? inline bool isMlaGenKernel(RunnerParams const& params) const { - return params.mHeadDimQk == 576 && params.mHeadDimV == 512; + return (params.mHeadDimQk == 576 && params.mHeadDimV == 512) || + (params.mHeadDimQk == 320 && params.mHeadDimV == 256); + } + + inline bool useDynamicNumTokensPerPage(RunnerParams const& params) const { + return isPagedKv(params.mQkvLayout) && !params.mSparseMla && params.mNumHeadsQPerKv > 1 && + params.mHeadDimQk == params.mHeadDimV && + params.mNumTokensPerPage >= kDynamicNumTokensPerPageThreshold; + } + + void selectNumTokensPerPage(RunnerParams const& params, + SelectKernelParams& selectKernelParams) const { + selectKernelParams.mDynamicNumTokensPerPage = false; + if (params.mSparseMla) { + // SparseMla kernels use a fixed numTokensPerPage = 1. + selectKernelParams.mNumTokensPerPage = 1; + } else if (!isPagedKv(params.mQkvLayout)) { + // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels. + selectKernelParams.mNumTokensPerPage = 0; + } else if (useDynamicNumTokensPerPage(params)) { + FLASHINFER_CHECK((params.mNumTokensPerPage & (params.mNumTokensPerPage - 1)) == 0, + "Dynamic numTokensPerPage requires a power-of-2 page size, got %d.", + params.mNumTokensPerPage); + selectKernelParams.mDynamicNumTokensPerPage = true; + selectKernelParams.mNumTokensPerPage = kDynamicNumTokensPerPageKernelKey; + } else { + selectKernelParams.mNumTokensPerPage = params.mNumTokensPerPage; + } } // Compute the number of CTAs in X, Y and Z dimension and the cluster size in the X dimension. @@ -813,9 +844,22 @@ class TllmGenFmhaKernel { // Select a kernel based on the heuristic. void selectKernel(RunnerParams const& params, SelectKernelParams& selectKernelParams) const { + // Normalize this before heuristic probing; some GQA-generation heuristics load candidate + // kernels while selecting tileSizeQ. + selectNumTokensPerPage(params, selectKernelParams); + bool const isMlaGeneration = isGenerationKernel(params.mKernelType) && isMlaGenKernel(params); + // Select the kernel based on the kernel type. - if (isGenerationKernel(params.mKernelType) && isMlaGenKernel(params)) { + if (isMlaGeneration) { selectMlaGenerationKernel(params, selectKernelParams); + // TRTLLM-GEN MLA generation kernels are exported with dense mask metadata. Each generation + // CTA processes a bounded query tile, so the runtime sequence lengths/indices provide the + // effective masking. + selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::Dense; + FLASHINFER_CHECK( + params.mMaxSeqLenKv <= params.mAttentionWindowSize && + params.mChunkedAttentionSize == INT_MAX, + "TRTLLM-GEN MLA generation does not support sliding-window or chunked attention."); } else if (isGenerationKernel(params.mKernelType)) { selectGqGenerationKernel(params, selectKernelParams); } @@ -841,14 +885,6 @@ class TllmGenFmhaKernel { "Sliding window attention and chunked attention should not be used together"); selectKernelParams.mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal; } - - // SparseMla kernels use a fixed numTokensPerPage = 1. - if (params.mSparseMla) { - selectKernelParams.mNumTokensPerPage = 1; - } else if (!isPagedKv(params.mQkvLayout)) { - // NumTokensPerPage is set to 0 when not selecting pagedKv-layout kernels. - selectKernelParams.mNumTokensPerPage = 0; - } } std::pair hashFromRunnerParams( @@ -867,6 +903,7 @@ class TllmGenFmhaKernel { ", tileSizeQ=" + std::to_string(selectKernelParams.mTileSizeQ) + ", tileSizeKv=" + std::to_string(selectKernelParams.mTileSizeKv) + ", numTokensPerPage=" + std::to_string(selectKernelParams.mNumTokensPerPage) + + ", dynamicNumTokensPerPage=" + std::to_string(selectKernelParams.mDynamicNumTokensPerPage) + ", reuseSmemKForV=" + std::to_string(selectKernelParams.mReuseSmemKForV) + ", uses2CtaMma=" + std::to_string(selectKernelParams.mUses2CtaMma) + ", sparseMla=" + std::to_string(params.mSparseMla) + diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index 4ed4ee5213..a2687f5d4c 100644 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -372,6 +372,8 @@ struct TllmGenSelectKernelParams { TrtllmGenAttentionMaskType mMaskType; // The number of tokens per page. int mNumTokensPerPage; + // Whether a dynamic tokens-per-page cubin is selected. + bool mDynamicNumTokensPerPage; // Reuse smemK for V or not (only work with MLA generation kernels). bool mReuseSmemKForV; // Do we need to select a new kernel as the parameters have been updated. @@ -398,6 +400,7 @@ struct TllmGenSelectKernelParams { mForceGmemReduction(false), mMaskType(params.mMaskType), mNumTokensPerPage(params.mNumTokensPerPage), + mDynamicNumTokensPerPage(false), mReuseSmemKForV(false), mSelectNewKernel(false), mSkipsSoftmaxWhenPossible(params.mSkipsSoftmaxWhenPossible), diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index ffb6cab0e9..0f41f6c007 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -44,6 +44,8 @@ struct KernelParams { CUtensorMap tmaQ_; // TMA descriptor for K. CUtensorMap tmaK_; + // TMA descriptor for DSv4 sparse MLA sliding-window KV pool. Same format as tmaK_. + CUtensorMap tmaKSlidingWindowKvPool_; // TMA descriptor for V. CUtensorMap tmaV_; // The descriptor for O. @@ -117,6 +119,8 @@ struct KernelParams { // The softmax stats buffer. float2* ptrSoftmaxStats; + // The variable sparseMla topK lengths with shape of [numTokensQ]. + int32_t const* ptrSparseMlaTopKLens; // The attention window size for sliding window attention. int32_t mAttentionWindowSize; @@ -860,6 +864,7 @@ struct KernelParams { params.mStartTokenIdxSfO = options.mSfStartTokenIdx; params.mScaleSfKv = options.mScaleSfKv; params.ptrSoftmaxStats = options.softmaxStatsPtr; + params.ptrSparseMlaTopKLens = nullptr; // The sparseMlaTopK needs to be a multiple of 4 as we use 16B cpAsync instructions for the // indices. FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0, diff --git a/tests/attention/test_attention_sink_blackwell.py b/tests/attention/test_attention_sink_blackwell.py index d9fa320c2c..4cbe0fe4f5 100644 --- a/tests/attention/test_attention_sink_blackwell.py +++ b/tests/attention/test_attention_sink_blackwell.py @@ -227,7 +227,7 @@ def test_blackwell_trtllm_gen_context_attention_sink( ) if dtype == torch.float16: - atol, rtol = 1e-3, 1e-3 + atol, rtol = 2e-3, 1e-3 elif dtype == torch.bfloat16: atol, rtol = 1e-2, 1e-2 else: diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index d3b4bcb669..88dae5b325 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -35,6 +35,14 @@ workspace_size = 256 * 1024 * 1024 +def _skip_if_not_blackwell() -> None: + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] != 10: + pytest.skip( + "Dynamic tokensPerPage tests require SM100 or SM103 Blackwell GPUs." + ) + + def flip_coin(*args, **kwargs): # Use any test parameters to deterministically decide branch # This makes test configurations go through different paths @@ -1015,6 +1023,34 @@ def test_trtllm_batch_prefill_bs1( ) +@pytest.mark.parametrize("page_size", [128, 256, 512, 1024]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) +def test_trtllm_batch_prefill_dynamic_page_size_gqa( + page_size: int, + uses_shared_paged_kv_idx: bool, +) -> None: + _skip_if_not_blackwell() + _test_trtllm_batch_prefill( + "HND", + batch_size=4, + page_size=page_size, + num_kv_heads=2, + head_grp_size=5, + causal=True, + window_left=-1, + q_dtype="bf16", + o_dtype="bf16", + kv_dtype="bf16", + enable_pdl=None, + enable_sink=False, + max_q_len=257, + max_kv_len=1024, + device_scale=False, + head_dim=128, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + ) + + def _test_trtllm_batch_decode( backend: str, kv_layout: str, @@ -1705,6 +1741,37 @@ def test_trtllm_batch_decode_long_sequence_length( ) +@pytest.mark.parametrize("page_size", [128, 256, 512, 1024]) +@pytest.mark.parametrize("q_len_per_req", [1, 2]) +@pytest.mark.parametrize("window_left", [-1, 127]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) +def test_trtllm_batch_decode_dynamic_page_size_gqa( + page_size: int, + q_len_per_req: int, + window_left: int, + uses_shared_paged_kv_idx: bool, +) -> None: + _skip_if_not_blackwell() + _test_trtllm_batch_decode( + "trtllm-gen", + "HND", + batch_size=4, + q_len_per_req=q_len_per_req, + page_size=page_size, + num_kv_heads=2, + head_grp_size=5, + window_left=window_left, + q_dtype="bf16", + o_dtype="bf16", + kv_dtype="bf16", + enable_pdl=None, + enable_sink=False, + max_in_kv_len=1024, + head_dim=128, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + ) + + @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize( "batch_size,page_size,num_kv_heads,head_grp_size",