diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index a2bfdf6727..a0997ddf2d 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -87,7 +87,8 @@ void trtllm_paged_attention_launcher( int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax, bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size, - cudaStream_t stream) { + int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads, + int64_t v_sf_stride_batch, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -126,6 +127,10 @@ void trtllm_paged_attention_launcher( runner_params.vStrideKeysValues = kv_stride_keys_values; runner_params.vStrideHeads = kv_stride_heads; runner_params.vStrideBatch = kv_stride_batch; + runner_params.kSfStrideHeads = k_sf_stride_heads; + runner_params.kSfStrideBatch = k_sf_stride_batch; + runner_params.vSfStrideHeads = v_sf_stride_heads; + runner_params.vSfStrideBatch = v_sf_stride_batch; runner_params.mNumPagesInMemPool = num_pages_in_mem_pool; runner_params.stream = stream; // the scaleSoftmaxLog2Ptr and outputScalePtr have higher priority than the scaleSoftmaxLog2 and @@ -299,6 +304,19 @@ void trtllm_paged_attention_decode( const void* v_block_scales_ptr = value_block_scales.has_value() ? value_block_scales.value().data_ptr() : nullptr; + // Read actual scale factor strides from the scale tensors (HND layout: [pages, heads, N, D/16]). + // These are passed separately to the kernel instead of being derived from KV data strides. 
+ int k_sf_stride_heads = 0, k_sf_stride_batch = 0; + int v_sf_stride_heads = 0, v_sf_stride_batch = 0; + if (key_block_scales.has_value()) { + k_sf_stride_heads = key_block_scales.value().stride(-3); + k_sf_stride_batch = key_block_scales.value().stride(0); + } + if (value_block_scales.has_value()) { + v_sf_stride_heads = value_block_scales.value().stride(-3); + v_sf_stride_batch = value_block_scales.value().stride(0); + } + const auto stream = get_stream(query.device()); void* output_sf_ptr = out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr; @@ -345,7 +363,8 @@ void trtllm_paged_attention_decode( max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax, - uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream); + uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, + k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); } void trtllm_paged_attention_context( @@ -407,6 +426,18 @@ void trtllm_paged_attention_context( const void* v_block_scales_ptr = value_block_scales.has_value() ? value_block_scales.value().data_ptr() : nullptr; + // Read actual scale factor strides from the scale tensors (HND layout: [pages, heads, N, D/16]). + int k_sf_stride_heads = 0, k_sf_stride_batch = 0; + int v_sf_stride_heads = 0, v_sf_stride_batch = 0; + if (key_block_scales.has_value()) { + k_sf_stride_heads = key_block_scales.value().stride(-3); + k_sf_stride_batch = key_block_scales.value().stride(0); + } + if (value_block_scales.has_value()) { + v_sf_stride_heads = value_block_scales.value().stride(-3); + v_sf_stride_batch = value_block_scales.value().stride(0); + } + const auto stream = get_stream(query.device()); void* output_sf_ptr = out_scale_factor.has_value() ? 
out_scale_factor.value().data_ptr() : nullptr; @@ -455,7 +486,8 @@ void trtllm_paged_attention_context( kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax, - uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream); + uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, + k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); } void trtllm_ragged_attention_launcher( diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 8107782382..e8f44d98c7 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2245,6 +2245,11 @@ def trtllm_batch_decode_with_kv_cache( or [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``. The first tensor is the key cache, and the second tensor is the value cache. + **Contiguity requirements (trtllm-gen backend):** + + - The ``head_dim`` (last dim) **must** have stride 1. This is a TMA hardware constraint. + - The head and batch/page dims can have arbitrary strides. + workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. workspace @@ -2290,6 +2295,10 @@ def trtllm_batch_decode_with_kv_cache( kv_layout : str = "HND" The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. Defaults to ``HND``. + For the trtllm-gen backend with NVFP4 KV cache, using ``NHD`` will trigger an + automatic transpose and ``.contiguous()`` copy of both the KV data and block scale + tensors to convert them to HND layout. This incurs extra memory allocation and + data copy overhead. Use ``HND`` for better performance. enable_pdl : Optional[bool] = None Whether to enable Programmatic Dependent Launch (PDL). 
See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization @@ -2317,6 +2326,20 @@ def trtllm_batch_decode_with_kv_cache( Only supported by trtllm-gen backend. Must be provided together with ``max_q_len``. When None, all requests use uniform query length specified by ``q_len_per_req``. + kv_block_scales : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None + Per-block scale factors for NVFP4 KV cache. Either a tuple of (k_scales, v_scales) or + a single tensor with shape ``[num_pages, 2, ...]`` that will be unbound along dim=1. + Each scale tensor has shape ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` + in HND layout, with dtype ``torch.float8_e4m3fn``. + + **Contiguity requirements (trtllm-gen backend):** + + - The last two dims (``page_size``, ``head_dim // 16``) **must** be contiguous + (i.e., ``stride[-1] == 1`` and ``stride[-2] == head_dim // 16``). This is because + the kernel reshapes them into ``(16, page_size * head_dim / 16 / 16)`` to satisfy + TMA's 16-byte box width minimum. + - The head and batch/page dims can have arbitrary strides. + skip_softmax_threshold_scale_factor: Optional[float] = None threshold scale factor for skipping softmax operations. Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087 diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 5b1b33ea64..9f120aab3e 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3748,6 +3748,11 @@ def trtllm_batch_context_with_kv_cache( If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is "HND", or [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is "NHD". The first tensor is the key cache, the second tensor is the value cache. 
+ + **Contiguity requirements (trtllm-gen backend):** + + - The ``head_dim`` (last dim) **must** have stride 1. This is a TMA hardware constraint. + - The head and batch/page dims can have arbitrary strides. workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. workspace block_tables : torch.Tensor @@ -3789,11 +3794,25 @@ def trtllm_batch_context_with_kv_cache( Defaults to ``None``, which means it will be enabled if the device supports PDL. kv_layout : str = "HND" Layout of kv-cache, can be "HND" or "NHD", default is "HND". + For the trtllm-gen backend with NVFP4 KV cache, using ``NHD`` will trigger an + automatic transpose and ``.contiguous()`` copy of both the KV data and block scale + tensors to convert them to HND layout. This incurs extra memory allocation and + data copy overhead. Use ``HND`` for better performance. sinks : Optional[List[torch.Tensor]] = None additional value per head in the denominator of the softmax. kv_block_scales : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None Per-block scale factors for NVFP4 KV cache. Either a tuple of (k_scales, v_scales) or - a single tensor with shape [num_pages, 2, ...] that will be unbound along dim=1. + a single tensor with shape ``[num_pages, 2, ...]`` that will be unbound along dim=1. + Each scale tensor has shape ``[num_pages, num_kv_heads, page_size, head_dim // 16]`` + in HND layout, with dtype ``torch.float8_e4m3fn``. + + **Contiguity requirements (trtllm-gen backend):** + + - The last two dims (``page_size``, ``head_dim // 16``) **must** be contiguous + (i.e., ``stride[-1] == 1`` and ``stride[-2] == head_dim // 16``). This is because + the kernel reshapes them into ``(16, page_size * head_dim / 16 / 16)`` to satisfy + TMA's 16-byte box width minimum. + - The head and batch/page dims can have arbitrary strides. skip_softmax_threshold_scale_factor: Optional[float] = None threshold scale factor for skipping softmax operations. 
Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087 diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index 44c7a749c1..7945b1fab2 100644 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -251,6 +251,15 @@ struct TllmGenFmhaRunnerParams { // The stride between different batches for V. int vStrideBatch; + // The stride between different heads for K scaling factors. + int kSfStrideHeads; + // The stride between different batches for K scaling factors. + int kSfStrideBatch; + // The stride between different heads for V scaling factors. + int vSfStrideHeads; + // The stride between different batches for V scaling factors. + int vSfStrideBatch; + // Head dimension for Q and K. int mHeadDimQk; // Head dimension for V. diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 2bbdb1800a..f38792ce30 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -407,7 +407,13 @@ struct KernelParams { return std::make_tuple(strideKeysVals, strideHeads, strideBatch); } - // Create the TMA shape/stride for K. + // Create the TMA shape/stride for K/V data tensors. + // + // Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim] + // - head_dim (last dim) MUST have stride 1. This is a TMA hardware constraint: + // cuTensorMapEncodeTiled does not accept a stride for dim 0 and implicitly assumes 1. + // - Other dimensions (heads, batch/pages) can have arbitrary strides; the actual + // strides are read from the tensor and passed to the TMA descriptor. 
template static auto makeTmaShapeStrideKv(FmhaOptions const& options, KernelParams const& params, Data_type dtypeKv, bool isK, bool storeTransformedKvInTmem) { @@ -446,14 +452,23 @@ struct KernelParams { return std::make_tuple(shape, stride); } - // Create the TMA shape/stride for KV scaling factors. + // Create the TMA shape/stride for KV scaling factors (block scales for NVFP4 KV cache). + // + // Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim // 16] + // - The last two dims (page_size, head_dim // 16) MUST be contiguous (stride[-1] = 1, + // stride[-2] = head_dim // 16). This is because we reshape them into + // (16, page_size * head_dim / 16 / 16) with hardcoded stride[1] = 16 to satisfy TMA's + // 16-byte box width requirement. Each scale factor is 1 byte (FP8), and head_dim // 16 + // can be < 16 (e.g., 8 for head_dim=128), so we must merge with page_size to reach 16. + // - The head and batch/page strides are read from the actual scale tensors (kSfStrideHeads, + // kSfStrideBatch) and can differ from the KV data strides. + // - cuTensorMapEncodeTiled requires all non-dim0 strides to be multiples of 16 bytes, so + // sfStrideHeads and sfStrideBatch must each be a multiple of 16. template static auto makeTmaShapeStrideKvSf(FmhaOptions const& options, KernelParams const& params, bool isK) { // The shape elements. auto [numKeys, numHeadsQPerKv, batchSize] = makeShapeKv(options, params); - // The stride elements. - auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK); // The headDim. // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv. @@ -464,6 +479,10 @@ struct KernelParams { // The number of elements per SF. int32_t NumEltsPerSf = 16; + // Use actual scale factor strides instead of deriving from KV strides. + int32_t sfStrideHeads = isK ? options.kSfStrideHeads : options.vSfStrideHeads; + int32_t sfStrideBatch = isK ? 
options.kSfStrideBatch : options.vSfStrideBatch; + // The KV shape is: (headDim, numKeys, numHeadsKv, batchSize) // Therefore, the KV SF shape should be (headDim / NumEltsPerSf, numKeys, numHeadsKv, // batchSize). Considering the TMA requires box width to be multiple of 16B, without changing @@ -476,8 +495,8 @@ struct KernelParams { auto shape = std::vector{ 16, static_cast(numKeys * headDim / NumEltsPerSf / 16), static_cast(options.mNumHeadsKv), static_cast(batchSize)}; - auto stride = std::vector{1, 16, static_cast(strideHeads / NumEltsPerSf), - static_cast(strideBatch / NumEltsPerSf)}; + auto stride = std::vector{1, 16, static_cast(sfStrideHeads), + static_cast(sfStrideBatch)}; return std::make_tuple(shape, stride); }