Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ void trtllm_paged_attention_launcher(
int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
cudaStream_t stream) {
int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, bool batch_invariant,
int64_t workspace_size, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -151,6 +151,8 @@ void trtllm_paged_attention_launcher(

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
if (mode == TllmPagedAttentionMode::Context) {
// Note: batch_invariant parameter has no effect in context mode, as context mode
// always uses Persistent scheduler and disables multi-CTA optimization.
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
runner_params.mKernelType = FmhaKernelType::Context;
runner_params.mTileScheduler = TileScheduler::Persistent;
Expand All @@ -165,7 +167,7 @@ void trtllm_paged_attention_launcher(
// one tokenQ in those cases, so dense mask works the same as causal mask.
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal;
runner_params.mKernelType = FmhaKernelType::Generation;
bool use_multi_block = true;
bool use_multi_block = !batch_invariant;
runner_params.mTileScheduler =
use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
runner_params.mMultiCtasKvMode = use_multi_block;
Expand Down Expand Up @@ -217,17 +219,15 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) {

inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; }

void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scale_factor,
TensorView query, TensorView key_cache, TensorView value_cache,
TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t batch_size, int64_t window_left,
int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q) {
void trtllm_paged_attention_decode(
TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
bool batch_invariant, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -306,17 +306,20 @@ void trtllm_paged_attention_decode(TensorView out, Optional<TensorView> out_scal
q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
sparse_mla_top_k, sm_count, enable_pdl, workspace_size, stream);
sparse_mla_top_k, sm_count, enable_pdl, batch_invariant, workspace_size, stream);
}

void trtllm_paged_attention_context(
TensorView out, Optional<TensorView> out_scale_factor, TensorView query, TensorView key_cache,
TensorView value_cache, TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale, Variant<double, ffi::Tensor> bmm2_scale,
double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size,
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks) {
void trtllm_paged_attention_context(TensorView out, Optional<TensorView> out_scale_factor,
TensorView query, TensorView key_cache, TensorView value_cache,
TensorView workspace_buffer, TensorView block_tables,
TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len,
Variant<double, ffi::Tensor> bmm1_scale,
Variant<double, ffi::Tensor> bmm2_scale, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index,
int64_t batch_size, int64_t window_left,
TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv,
int64_t sm_count, bool enable_pdl, bool batch_invariant,
int64_t workspace_size, Optional<TensorView> attention_sinks) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand Down Expand Up @@ -390,7 +393,8 @@ void trtllm_paged_attention_context(
head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values,
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
sum_seq_q, /*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream);
sum_seq_q, /*sparse_mla_top_k=*/0, sm_count, enable_pdl, batch_invariant, workspace_size,
stream);
Comment on lines +396 to +397
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Following the removal of batch_invariant from trtllm_paged_attention_context's signature, this call should be updated to pass false. The batch_invariant flag does not affect context-phase kernels, so false is a safe default.

      sum_seq_q, /*sparse_mla_top_k=*/0, sm_count, enable_pdl, /*batch_invariant=*/false, workspace_size,
      stream);

}

void trtllm_ragged_attention_launcher(
Expand Down
10 changes: 10 additions & 0 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1918,6 +1918,7 @@ def _paged_run(
workspace_size: int,
window_left: int = -1,
enable_pdl: bool = None,
batch_invariant: bool = False,
out: Optional[torch.Tensor] = None,
sinks: Optional[torch.Tensor] = None,
) -> torch.Tensor:
Expand Down Expand Up @@ -1958,6 +1959,7 @@ def _paged_run(
0, # sparse_mla_top_k
self._sm_count,
enable_pdl,
batch_invariant,
workspace_size,
sinks,
None, # cum_seq_lens_q
Expand Down Expand Up @@ -2118,6 +2120,7 @@ def trtllm_batch_decode_with_kv_cache(
sinks: Optional[List[torch.Tensor]] = None,
kv_layout: str = "HND",
enable_pdl: Optional[bool] = None,
batch_invariant: bool = False,
backend: str = "auto",
q_len_per_req: Optional[int] = 1,
o_scale: Optional[float] = 1.0,
Expand Down Expand Up @@ -2185,6 +2188,12 @@ def trtllm_batch_decode_with_kv_cache(
Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
When set to ``None``, the backend will be chosen based on the device architecture and kernel availability.

batch_invariant : bool = False
When set to True, disables multi-CTA optimization in the generation kernel.
This ensures the output is invariant to batch size, allowing per-request
processing without a for loop while maintaining consistent results.
Only supported by trtllm-gen backend. Defaults to False.

backend : str = "auto"
The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
Expand Down Expand Up @@ -2384,6 +2393,7 @@ def trtllm_batch_decode_with_kv_cache(
0, # sparse_mla_top_k
sm_count,
enable_pdl,
batch_invariant,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
cum_seq_lens_q,
Expand Down
15 changes: 15 additions & 0 deletions flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm2_scale: Union[float, torch.Tensor] = 1.0,
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: bool = None,
batch_invariant: bool = False,
backend: str = "auto",
) -> torch.Tensor:
"""
Expand All @@ -548,6 +549,19 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm2_scale: fused scale for mla bmm2 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
sinks: additional value per head in the denominator of the softmax.
batch_invariant : bool = False
When set to True, disables multi-CTA optimization in the generation kernel.
This ensures the output is invariant to batch size, allowing per-request
processing without a for loop while maintaining consistent results.
Only supported by trtllm-gen backend. Defaults to False.

**Important**: For MLA attention, batch invariance may not be fully guaranteed
even with this flag enabled. The MLA implementation uses a reduction kernel
that combines partial results from split-KV optimization, and the number of
splits is determined by a heuristic that depends on batch size
(split_kv ~ sm_count / batch_size). This means different batch sizes may still
produce slightly different numerical results due to different reduction patterns,
even though multi-CTA is disabled in the main generation kernel.
backend : str = "auto"
The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
Expand Down Expand Up @@ -675,6 +689,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
sparse_mla_top_k,
sm_count,
enable_pdl,
batch_invariant,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
None, # cum_seq_lens_q
Expand Down
9 changes: 9 additions & 0 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def _paged_run(
cum_seq_lens_q: torch.Tensor,
cum_seq_lens_kv: torch.Tensor,
enable_pdl: bool,
batch_invariant: bool,
workspace_size: int,
window_left: int = -1,
out: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -254,6 +255,7 @@ def _paged_run(
cum_seq_lens_kv,
sm_count,
enable_pdl,
batch_invariant,
workspace_size,
sinks,
)
Expand Down Expand Up @@ -3486,6 +3488,7 @@ def trtllm_batch_context_with_kv_cache(
o_sf_vec_size: Optional[int] = None,
kv_layout: str = "HND",
enable_pdl: Optional[bool] = None,
batch_invariant: bool = False,
sinks: Optional[List[torch.Tensor]] = None,
) -> Union[torch.Tensor, FP4Tensor]:
"""
Expand Down Expand Up @@ -3535,6 +3538,11 @@ def trtllm_batch_context_with_kv_cache(
enable_pdl : Optional[bool] = None
Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization
Defaults to ``None``, which means it will be enabled if the device supports PDL.
batch_invariant : bool = False
Whether to disable multi-CTA optimization to ensure output is invariant to batch size.
When True, uses Persistent scheduler instead of Static scheduler. Note that this parameter
has no effect in context mode (context mode always uses Persistent scheduler).
Defaults to ``False``.
kv_layout : str = "HND"
Layout of kv-cache, can be "HND" or "NHD", default is "HND".
sinks : Optional[List[torch.Tensor]] = None
Expand Down Expand Up @@ -3671,6 +3679,7 @@ def trtllm_batch_context_with_kv_cache(
cum_seq_lens_kv,
sm_count,
enable_pdl,
batch_invariant,
workspace_size,
sinks,
)
Expand Down
4 changes: 4 additions & 0 deletions include/flashinfer/attention/blackwell/device/sm100_mla.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ class MLA {
auto [H, K, D, B] = args.problem_shape;
int sm_count = args.hw_info.sm_count;
int max_splits = ceil_div(K, 128);
// NOTE: This heuristic depends on batch size B, which means the split count
// (and thus the reduction kernel behavior) varies with batch size.
// This is why batch_invariant flag may not guarantee full batch invariance
// for MLA attention - different batch sizes lead to different split counts.
int sms_per_batch = max(1, sm_count / B);
int split_heur = min(max_splits, sms_per_batch);
int waves = ceil_div(B * split_heur, sm_count);
Expand Down
Loading
Loading