Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ void trtllm_paged_attention_launcher(
int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax,
bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
cudaStream_t stream) {
int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads,
int64_t v_sf_stride_batch, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -126,6 +127,10 @@ void trtllm_paged_attention_launcher(
runner_params.vStrideKeysValues = kv_stride_keys_values;
runner_params.vStrideHeads = kv_stride_heads;
runner_params.vStrideBatch = kv_stride_batch;
runner_params.kSfStrideHeads = k_sf_stride_heads;
runner_params.kSfStrideBatch = k_sf_stride_batch;
runner_params.vSfStrideHeads = v_sf_stride_heads;
runner_params.vSfStrideBatch = v_sf_stride_batch;
runner_params.mNumPagesInMemPool = num_pages_in_mem_pool;
runner_params.stream = stream;
// the scaleSoftmaxLog2Ptr and outputScalePtr have higher priority than the scaleSoftmaxLog2 and
Expand Down Expand Up @@ -299,6 +304,19 @@ void trtllm_paged_attention_decode(
const void* v_block_scales_ptr =
value_block_scales.has_value() ? value_block_scales.value().data_ptr() : nullptr;

// Read actual scale factor strides from the scale tensors (HND layout: [pages, heads, N, D/16]).
// These are passed separately to the kernel instead of being derived from KV data strides.
int k_sf_stride_heads = 0, k_sf_stride_batch = 0;
int v_sf_stride_heads = 0, v_sf_stride_batch = 0;
if (key_block_scales.has_value()) {
k_sf_stride_heads = key_block_scales.value().stride(-3);
k_sf_stride_batch = key_block_scales.value().stride(0);
}
if (value_block_scales.has_value()) {
v_sf_stride_heads = value_block_scales.value().stride(-3);
v_sf_stride_batch = value_block_scales.value().stride(0);
}

const auto stream = get_stream(query.device());
void* output_sf_ptr =
out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr;
Expand Down Expand Up @@ -345,7 +363,8 @@ void trtllm_paged_attention_decode(
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream);
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
}

void trtllm_paged_attention_context(
Expand Down Expand Up @@ -407,6 +426,18 @@ void trtllm_paged_attention_context(
const void* v_block_scales_ptr =
value_block_scales.has_value() ? value_block_scales.value().data_ptr() : nullptr;

// Read actual scale factor strides from the scale tensors (HND layout: [pages, heads, N, D/16]).
int k_sf_stride_heads = 0, k_sf_stride_batch = 0;
int v_sf_stride_heads = 0, v_sf_stride_batch = 0;
if (key_block_scales.has_value()) {
k_sf_stride_heads = key_block_scales.value().stride(-3);
k_sf_stride_batch = key_block_scales.value().stride(0);
}
if (value_block_scales.has_value()) {
v_sf_stride_heads = value_block_scales.value().stride(-3);
v_sf_stride_batch = value_block_scales.value().stride(0);
}

const auto stream = get_stream(query.device());
void* output_sf_ptr =
out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr;
Expand Down Expand Up @@ -455,7 +486,8 @@ void trtllm_paged_attention_context(
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream);
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
}

void trtllm_ragged_attention_launcher(
Expand Down
23 changes: 23 additions & 0 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2245,6 +2245,11 @@ def trtllm_batch_decode_with_kv_cache(
or [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is ``NHD``.
The first tensor is the key cache, and the second tensor is the value cache.

**Contiguity requirements (trtllm-gen backend):**

- The ``head_dim`` (last dim) **must** have stride 1. This is a TMA hardware constraint.
- The head and batch/page dims can have arbitrary strides.

workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
workspace

Expand Down Expand Up @@ -2290,6 +2295,10 @@ def trtllm_batch_decode_with_kv_cache(
kv_layout : str = "HND"
The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
Defaults to ``HND``.
For the trtllm-gen backend with NVFP4 KV cache, using ``NHD`` will trigger an
automatic transpose and ``.contiguous()`` copy of both the KV data and block scale
tensors to convert them to HND layout. This incurs extra memory allocation and
data copy overhead. Use ``HND`` for better performance.

enable_pdl : Optional[bool] = None
Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
Expand Down Expand Up @@ -2317,6 +2326,20 @@ def trtllm_batch_decode_with_kv_cache(
Only supported by trtllm-gen backend. Must be provided together with ``max_q_len``.
When None, all requests use uniform query length specified by ``q_len_per_req``.

kv_block_scales : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None
Per-block scale factors for NVFP4 KV cache. Either a tuple of (k_scales, v_scales) or
a single tensor with shape ``[num_pages, 2, ...]`` that will be unbound along dim=1.
Each scale tensor has shape ``[num_pages, num_kv_heads, page_size, head_dim // 16]``
in HND layout, with dtype ``torch.float8_e4m3fn``.

**Contiguity requirements (trtllm-gen backend):**

- The last two dims (``page_size``, ``head_dim // 16``) **must** be contiguous
(i.e., ``stride[-1] == 1`` and ``stride[-2] == head_dim // 16``). This is because
the kernel reshapes them into ``(16, page_size * head_dim / 16 / 16)`` to satisfy
TMA's 16-byte box width minimum.
- The head and batch/page dims can have arbitrary strides.

skip_softmax_threshold_scale_factor: Optional[float] = None
threshold scale factor for skipping softmax operations.
Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087
Expand Down
21 changes: 20 additions & 1 deletion flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3748,6 +3748,11 @@ def trtllm_batch_context_with_kv_cache(
If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim] if :attr:`kv_layout` is "HND",
or [num_pages, page_size, num_kv_heads, head_dim] if :attr:`kv_layout` is "NHD".
The first tensor is the key cache, the second tensor is the value cache.

**Contiguity requirements (trtllm-gen backend):**

- The ``head_dim`` (last dim) **must** have stride 1. This is a TMA hardware constraint.
- The head and batch/page dims can have arbitrary strides.
workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
workspace
block_tables : torch.Tensor
Expand Down Expand Up @@ -3789,11 +3794,25 @@ def trtllm_batch_context_with_kv_cache(
Defaults to ``None``, which means it will be enabled if the device supports PDL.
kv_layout : str = "HND"
Layout of kv-cache, can be "HND" or "NHD", default is "HND".
For the trtllm-gen backend with NVFP4 KV cache, using ``NHD`` will trigger an
automatic transpose and ``.contiguous()`` copy of both the KV data and block scale
tensors to convert them to HND layout. This incurs extra memory allocation and
data copy overhead. Use ``HND`` for better performance.
sinks : Optional[List[torch.Tensor]] = None
additional value per head in the denominator of the softmax.
kv_block_scales : Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None
Per-block scale factors for NVFP4 KV cache. Either a tuple of (k_scales, v_scales) or
a single tensor with shape [num_pages, 2, ...] that will be unbound along dim=1.
a single tensor with shape ``[num_pages, 2, ...]`` that will be unbound along dim=1.
Each scale tensor has shape ``[num_pages, num_kv_heads, page_size, head_dim // 16]``
in HND layout, with dtype ``torch.float8_e4m3fn``.

**Contiguity requirements (trtllm-gen backend):**

- The last two dims (``page_size``, ``head_dim // 16``) **must** be contiguous
(i.e., ``stride[-1] == 1`` and ``stride[-2] == head_dim // 16``). This is because
the kernel reshapes them into ``(16, page_size * head_dim / 16 / 16)`` to satisfy
TMA's 16-byte box width minimum.
- The head and batch/page dims can have arbitrary strides.
skip_softmax_threshold_scale_factor: Optional[float] = None
threshold scale factor for skipping softmax operations.
Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087
Expand Down
9 changes: 9 additions & 0 deletions include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,15 @@ struct TllmGenFmhaRunnerParams {
// The stride between different batches for V.
int vStrideBatch;

// The stride between different heads for K scaling factors.
int kSfStrideHeads;
// The stride between different batches for K scaling factors.
int kSfStrideBatch;
// The stride between different heads for V scaling factors.
int vSfStrideHeads;
// The stride between different batches for V scaling factors.
int vSfStrideBatch;

// Head dimension for Q and K.
int mHeadDimQk;
// Head dimension for V.
Expand Down
31 changes: 25 additions & 6 deletions include/flashinfer/trtllm/fmha/kernelParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,13 @@ struct KernelParams {
return std::make_tuple(strideKeysVals, strideHeads, strideBatch);
}

// Create the TMA shape/stride for K.
// Create the TMA shape/stride for K/V data tensors.
//
// Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim]
// - head_dim (last dim) MUST have stride 1. This is a TMA hardware constraint:
// cuTensorMapEncodeTiled does not accept a stride for dim 0 and implicitly assumes 1.
// - Other dimensions (heads, batch/pages) can have arbitrary strides; the actual
// strides are read from the tensor and passed to the TMA descriptor.
template <class FmhaOptions>
static auto makeTmaShapeStrideKv(FmhaOptions const& options, KernelParams const& params,
Data_type dtypeKv, bool isK, bool storeTransformedKvInTmem) {
Expand Down Expand Up @@ -446,14 +452,23 @@ struct KernelParams {
return std::make_tuple(shape, stride);
}

// Create the TMA shape/stride for KV scaling factors.
// Create the TMA shape/stride for KV scaling factors (block scales for NVFP4 KV cache).
//
// Layout requirement (HND): [num_pages, num_kv_heads, page_size, head_dim // 16]
// - The last two dims (page_size, head_dim // 16) MUST be contiguous (stride[-1] = 1,
// stride[-2] = head_dim // 16). This is because we reshape them into
// (16, page_size * head_dim / 16 / 16) with hardcoded stride[1] = 16 to satisfy TMA's
// 16-byte box width requirement. Each scale factor is 1 byte (FP8), and head_dim // 16
// can be < 16 (e.g., 8 for head_dim=128), so we must merge with page_size to reach 16.
// - The head and batch/page strides are read from the actual scale tensors (kSfStrideHeads,
// kSfStrideBatch) and can differ from the KV data strides.
// - cuTensorMapEncodeTiled requires all non-dim0 strides to be multiples of 16 bytes, so
// sfStrideHeads and sfStrideBatch must each be a multiple of 16.
template <class FmhaOptions>
static auto makeTmaShapeStrideKvSf(FmhaOptions const& options, KernelParams const& params,
bool isK) {
// The shape elements.
auto [numKeys, numHeadsQPerKv, batchSize] = makeShapeKv(options, params);
// The stride elements.
auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);

// The headDim.
// Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
Expand All @@ -464,6 +479,10 @@ struct KernelParams {
// The number of elements per SF.
int32_t NumEltsPerSf = 16;

// Use actual scale factor strides instead of deriving from KV strides.
int32_t sfStrideHeads = isK ? options.kSfStrideHeads : options.vSfStrideHeads;
int32_t sfStrideBatch = isK ? options.kSfStrideBatch : options.vSfStrideBatch;

// The KV shape is: (headDim, numKeys, numHeadsKv, batchSize)
// Therefore, the KV SF shape should be (headDim / NumEltsPerSf, numKeys, numHeadsKv,
// batchSize). Considering the TMA requires box width to be multiple of 16B, without changing
Expand All @@ -476,8 +495,8 @@ struct KernelParams {
auto shape = std::vector<uint64_t>{
16, static_cast<uint64_t>(numKeys * headDim / NumEltsPerSf / 16),
static_cast<uint64_t>(options.mNumHeadsKv), static_cast<uint64_t>(batchSize)};
auto stride = std::vector<uint64_t>{1, 16, static_cast<uint64_t>(strideHeads / NumEltsPerSf),
static_cast<uint64_t>(strideBatch / NumEltsPerSf)};
auto stride = std::vector<uint64_t>{1, 16, static_cast<uint64_t>(sfStrideHeads),
static_cast<uint64_t>(sfStrideBatch)};

return std::make_tuple(shape, stride);
}
Expand Down
Loading