From 36ebb643a35ec57cd4b3c6b5c99342944b308c46 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 08:32:30 +0000 Subject: [PATCH 1/5] feat: add batch_invariant option to both MLA and non-MLA trtllm decode Add batch_invariant parameter to trtllm_batch_decode_with_kv_cache_mla and trtllm_batch_decode_with_kv_cache that disables multi-CTA optimization in the generation kernel. This ensures output is invariant to batch size, allowing per-request processing without a for loop while maintaining consistent results. Changes: - Updated C++ launcher to accept batch_invariant parameter - Modified generation kernel to use: use_multi_block = !batch_invariant - Added batch_invariant parameter to both Python APIs with documentation - When batch_invariant=true, uses Persistent scheduler instead of Static Co-authored-by: Zihao Ye --- csrc/trtllm_fmha_kernel_launcher.cu | 17 ++++++++++------- flashinfer/decode.py | 8 ++++++++ flashinfer/mla.py | 7 +++++++ 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 3d5e8956e8..a3f0dc24fc 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -84,8 +84,8 @@ void trtllm_paged_attention_launcher( int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, - int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, int64_t workspace_size, - cudaStream_t stream) { + int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, bool batch_invariant, + int64_t workspace_size, cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << 
num_kv_heads @@ -165,7 +165,7 @@ void trtllm_paged_attention_launcher( // one tokenQ in those cases, so dense mask works the same as causal mask. runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal; runner_params.mKernelType = FmhaKernelType::Generation; - bool use_multi_block = true; + bool use_multi_block = !batch_invariant; runner_params.mTileScheduler = use_multi_block ? TileScheduler::Static : TileScheduler::Persistent; runner_params.mMultiCtasKvMode = use_multi_block; @@ -226,7 +226,8 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size, int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, - int64_t workspace_size, Optional attention_sinks, + bool batch_invariant, int64_t workspace_size, + Optional attention_sinks, Optional cum_seq_lens_q) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); @@ -306,7 +307,7 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, - sparse_mla_top_k, sm_count, enable_pdl, workspace_size, stream); + sparse_mla_top_k, sm_count, enable_pdl, batch_invariant, workspace_size, stream); } void trtllm_paged_attention_context( @@ -316,7 +317,8 @@ void trtllm_paged_attention_context( Variant bmm1_scale, Variant bmm2_scale, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, - bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { + bool enable_pdl, bool batch_invariant, int64_t workspace_size, + Optional attention_sinks) { auto q_data_type = 
dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); @@ -390,7 +392,8 @@ void trtllm_paged_attention_context( head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, - sum_seq_q, /*sparse_mla_top_k=*/0, sm_count, enable_pdl, workspace_size, stream); + sum_seq_q, /*sparse_mla_top_k=*/0, sm_count, enable_pdl, batch_invariant, workspace_size, + stream); } void trtllm_ragged_attention_launcher( diff --git a/flashinfer/decode.py b/flashinfer/decode.py index a879caf338..137c6b2702 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2118,6 +2118,7 @@ def trtllm_batch_decode_with_kv_cache( sinks: Optional[List[torch.Tensor]] = None, kv_layout: str = "HND", enable_pdl: Optional[bool] = None, + batch_invariant: bool = False, backend: str = "auto", q_len_per_req: Optional[int] = 1, o_scale: Optional[float] = 1.0, @@ -2185,6 +2186,12 @@ def trtllm_batch_decode_with_kv_cache( Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization When set to ``None``, the backend will be chosen based on the device architecture and kernel availability. + batch_invariant : bool = False + When set to True, disables multi-CTA optimization in the generation kernel. + This ensures the output is invariant to batch size, allowing per-request + processing without a for loop while maintaining consistent results. + Only supported by trtllm-gen backend. Defaults to False. + backend : str = "auto" The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. 
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. @@ -2384,6 +2391,7 @@ def trtllm_batch_decode_with_kv_cache( 0, # sparse_mla_top_k sm_count, enable_pdl, + batch_invariant, workspace_buffer.numel() * workspace_buffer.element_size(), sinks, cum_seq_lens_q, diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 83415521c3..48c2dcfa7a 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -527,6 +527,7 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm2_scale: Union[float, torch.Tensor] = 1.0, sinks: Optional[List[torch.Tensor]] = None, enable_pdl: bool = None, + batch_invariant: bool = False, backend: str = "auto", ) -> torch.Tensor: """ @@ -548,6 +549,11 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm2_scale: fused scale for mla bmm2 input. when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. sinks: additional value per head in the denominator of the softmax. + batch_invariant : bool = False + When set to True, disables multi-CTA optimization in the generation kernel. + This ensures the output is invariant to batch size, allowing per-request + processing without a for loop while maintaining consistent results. + Only supported by trtllm-gen backend. Defaults to False. backend : str = "auto" The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. 
@@ -675,6 +681,7 @@ def trtllm_batch_decode_with_kv_cache_mla( sparse_mla_top_k, sm_count, enable_pdl, + batch_invariant, workspace_buffer.numel() * workspace_buffer.element_size(), sinks, None, # cum_seq_lens_q From e5478a30ae023f896df495b49cc8d75d721d6376 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 9 Jan 2026 08:37:19 +0000 Subject: [PATCH 2/5] lint --- csrc/trtllm_fmha_kernel_launcher.cu | 41 ++++++++++++++--------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index a3f0dc24fc..925f7f777b 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -217,18 +217,15 @@ inline Data_type dl_dtype_to_tllm_data_type(const DLDataType dtype) { inline bool is_4bit(Data_type data_type) { return data_type == Data_type::DATA_TYPE_E2M1; } -void trtllm_paged_attention_decode(TensorView out, Optional out_scale_factor, - TensorView query, TensorView key_cache, TensorView value_cache, - TensorView workspace_buffer, TensorView block_tables, - TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, - Variant bmm1_scale, - Variant bmm2_scale, double o_sf_scale, - int64_t o_sf_vec_size, int64_t o_sf_start_index, - int64_t batch_size, int64_t window_left, - int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, - bool batch_invariant, int64_t workspace_size, - Optional attention_sinks, - Optional cum_seq_lens_q) { +void trtllm_paged_attention_decode( + TensorView out, Optional out_scale_factor, TensorView query, TensorView key_cache, + TensorView value_cache, TensorView workspace_buffer, TensorView block_tables, + TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, + Variant bmm1_scale, Variant bmm2_scale, + double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size, + int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, + bool batch_invariant, int64_t workspace_size, 
Optional attention_sinks, + Optional cum_seq_lens_q) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -310,15 +307,17 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal sparse_mla_top_k, sm_count, enable_pdl, batch_invariant, workspace_size, stream); } -void trtllm_paged_attention_context( - TensorView out, Optional out_scale_factor, TensorView query, TensorView key_cache, - TensorView value_cache, TensorView workspace_buffer, TensorView block_tables, - TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, - Variant bmm1_scale, Variant bmm2_scale, - double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t batch_size, - int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, - bool enable_pdl, bool batch_invariant, int64_t workspace_size, - Optional attention_sinks) { +void trtllm_paged_attention_context(TensorView out, Optional out_scale_factor, + TensorView query, TensorView key_cache, TensorView value_cache, + TensorView workspace_buffer, TensorView block_tables, + TensorView seq_lens, int64_t max_q_len, int64_t max_kv_len, + Variant bmm1_scale, + Variant bmm2_scale, double o_sf_scale, + int64_t o_sf_vec_size, int64_t o_sf_start_index, + int64_t batch_size, int64_t window_left, + TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, + int64_t sm_count, bool enable_pdl, bool batch_invariant, + int64_t workspace_size, Optional attention_sinks) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); From 2bfde00d7c51a0c3da7522a8233b70d1960a7929 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Fri, 9 Jan 2026 08:49:20 +0000 Subject: [PATCH 3/5] Fix Python/C++ API 
consistency and add batch_invariant tests - Add batch_invariant parameter to low-level wrapper functions - flashinfer/decode.py: TrtllmGenDecodeModule._paged_run - flashinfer/prefill.py: get_trtllm_gen_prefill_module()._paged_run - Add batch_invariant parameter to high-level API - flashinfer/prefill.py: trtllm_batch_context_with_kv_cache - Document that batch_invariant has no effect in context mode - Add unit tests for batch_invariant in tests/attention/test_batch_invariant.py - Test non-MLA decode with batch_invariant - Test MLA decode with batch_invariant Fixes API consistency issues identified in code review Co-authored-by: Zihao Ye --- csrc/trtllm_fmha_kernel_launcher.cu | 2 + flashinfer/decode.py | 2 + flashinfer/prefill.py | 9 + tests/attention/test_batch_invariant.py | 303 ++++++++++++++++++++++++ 4 files changed, 316 insertions(+) create mode 100644 tests/attention/test_batch_invariant.py diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 925f7f777b..b8d897730a 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -151,6 +151,8 @@ void trtllm_paged_attention_launcher( AlignedAllocator float_allocator(workspace_buffer, workspace_size); if (mode == TllmPagedAttentionMode::Context) { + // Note: batch_invariant parameter has no effect in context mode, as context mode + // always uses Persistent scheduler and disables multi-CTA optimization. 
runner_params.mMaskType = TrtllmGenAttentionMaskType::Causal; runner_params.mKernelType = FmhaKernelType::Context; runner_params.mTileScheduler = TileScheduler::Persistent; diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 137c6b2702..9d66325e0b 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1918,6 +1918,7 @@ def _paged_run( workspace_size: int, window_left: int = -1, enable_pdl: bool = None, + batch_invariant: bool = False, out: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -1958,6 +1959,7 @@ def _paged_run( 0, # sparse_mla_top_k self._sm_count, enable_pdl, + batch_invariant, workspace_size, sinks, None, # cum_seq_lens_q diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index bfdcfd9048..624fc794ff 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -219,6 +219,7 @@ def _paged_run( cum_seq_lens_q: torch.Tensor, cum_seq_lens_kv: torch.Tensor, enable_pdl: bool, + batch_invariant: bool, workspace_size: int, window_left: int = -1, out: Optional[torch.Tensor] = None, @@ -254,6 +255,7 @@ def _paged_run( cum_seq_lens_kv, sm_count, enable_pdl, + batch_invariant, workspace_size, sinks, ) @@ -3486,6 +3488,7 @@ def trtllm_batch_context_with_kv_cache( o_sf_vec_size: Optional[int] = None, kv_layout: str = "HND", enable_pdl: Optional[bool] = None, + batch_invariant: bool = False, sinks: Optional[List[torch.Tensor]] = None, ) -> Union[torch.Tensor, FP4Tensor]: """ @@ -3535,6 +3538,11 @@ def trtllm_batch_context_with_kv_cache( enable_pdl : Optional[bool] = None Whether to enable Programmatic Dependent Launch (PDL). See https://docs.nvidia.com/cuda/cuda-c-programming-guide/#programmatic-dependent-launch-and-synchronization Defaults to ``None``, which means it will be enabled if the device supports PDL. + batch_invariant : bool = False + Whether to disable multi-CTA optimization to ensure output is invariant to batch size. 
+ When True, uses Persistent scheduler instead of Static scheduler. Note that this parameter + has no effect in context mode (context mode always uses Persistent scheduler). + Defaults to ``False``. kv_layout : str = "HND" Layout of kv-cache, can be "HND" or "NHD", default is "HND". sinks : Optional[List[torch.Tensor]] = None @@ -3671,6 +3679,7 @@ def trtllm_batch_context_with_kv_cache( cum_seq_lens_kv, sm_count, enable_pdl, + batch_invariant, workspace_size, sinks, ) diff --git a/tests/attention/test_batch_invariant.py b/tests/attention/test_batch_invariant.py new file mode 100644 index 0000000000..2aed27ea70 --- /dev/null +++ b/tests/attention/test_batch_invariant.py @@ -0,0 +1,303 @@ +"""Tests for batch_invariant parameter in trtllm decode functions.""" +import pytest +import torch + +import flashinfer +from flashinfer.utils import get_compute_capability + +DTYPE_MAP = { + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +GPU_DEVICE = "cuda:0" + +global_trtllm_gen_fmha_workspace_buffer = None +workspace_size = 256 * 1024 * 1024 + + +def create_workspace_buffers(device): + """Create workspace buffers for testing.""" + global global_trtllm_gen_fmha_workspace_buffer + if global_trtllm_gen_fmha_workspace_buffer is None: + global_trtllm_gen_fmha_workspace_buffer = torch.zeros( + workspace_size, dtype=torch.int8, device=device + ) + return global_trtllm_gen_fmha_workspace_buffer + + +@pytest.mark.parametrize("kv_layout", ["HND"]) +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (4, 1, 16, 8, 4), + ], +) +@pytest.mark.parametrize("window_left", [-1]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("bf16", "bf16", "bf16"), + ], +) +@pytest.mark.parametrize("enable_pdl", [None]) +@pytest.mark.parametrize("max_in_kv_len", [2048]) +@pytest.mark.parametrize("head_dim", [128]) +def test_trtllm_batch_decode_batch_invariant( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + 
head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + max_in_kv_len, + head_dim, +): + """Test that batch_invariant parameter produces consistent results across different batch sizes.""" + compute_capability = get_compute_capability(torch.device(device="cuda")) + + # trtllm-gen backend requires SM100 and SM103 GPUs + if compute_capability[0] != 10: + pytest.skip("trtllm-gen backend requires SM100 and SM103 GPUs.") + + torch.manual_seed(42) # Fixed seed for reproducibility + + num_qo_heads = num_kv_heads * head_grp_size + q_dtype_torch = DTYPE_MAP[q_dtype] + + # Create two simple requests with the same content + seq_len1 = 128 # Fixed KV seq length + + # Single request test + q_single = torch.randn(1, num_qo_heads, head_dim, device=GPU_DEVICE, dtype=q_dtype_torch) + seq_lens_single = torch.tensor([seq_len1], dtype=torch.int32, device=GPU_DEVICE) + + # Create KV cache for single request + num_pages_single = (seq_len1 + page_size - 1) // page_size + kv_cache_single = torch.randn( + num_pages_single, 2, num_kv_heads, page_size, head_dim, + device=GPU_DEVICE, dtype=q_dtype_torch + ) + page_table_single = torch.arange(num_pages_single, dtype=torch.int32, device=GPU_DEVICE).unsqueeze(0) + + workspace_buffer_single = create_workspace_buffers(GPU_DEVICE) + + sm_scale = float(1.0 / (head_dim**0.5)) + bmm1_scale = sm_scale + bmm2_scale = 1.0 + + # Run with batch_invariant=True for single request + output_single = flashinfer.decode.trtllm_batch_decode_with_kv_cache( + q_single, + kv_cache_single, + workspace_buffer_single, + page_table_single, + seq_lens_single, + seq_len1, + bmm1_scale, + bmm2_scale, + window_left, + kv_layout=kv_layout, + enable_pdl=enable_pdl, + backend="trtllm-gen", + batch_invariant=True, + ) + + # Now test with a batch containing the same request replicated + q_batch = q_single.repeat(batch_size, 1, 1) + seq_lens_batch = torch.full((batch_size,), seq_len1, dtype=torch.int32, device=GPU_DEVICE) + + # Create KV cache for batch 
(replicate the same pages) + kv_cache_batch = kv_cache_single.repeat(batch_size, 1, 1, 1, 1) + + # Create page table for batch + page_table_batch = torch.zeros(batch_size, num_pages_single, dtype=torch.int32, device=GPU_DEVICE) + for i in range(batch_size): + page_table_batch[i] = torch.arange( + i * num_pages_single, (i+1) * num_pages_single, dtype=torch.int32, device=GPU_DEVICE + ) + + workspace_buffer_batch = create_workspace_buffers(GPU_DEVICE) + + # Run with batch_invariant=True for batch + output_batch = flashinfer.decode.trtllm_batch_decode_with_kv_cache( + q_batch, + kv_cache_batch, + workspace_buffer_batch, + page_table_batch, + seq_lens_batch, + seq_len1, + bmm1_scale, + bmm2_scale, + window_left, + kv_layout=kv_layout, + enable_pdl=enable_pdl, + backend="trtllm-gen", + batch_invariant=True, + ) + + # Compare: the first output in the batch should match the single output + rtol, atol = 1e-3, 1e-3 # Tight tolerance since we expect identical results + + torch.testing.assert_close( + output_single[0], + output_batch[0], + rtol=rtol, + atol=atol, + msg="Output with batch_invariant=True should be identical for same request in different batch sizes" + ) + + # Also verify all batch outputs are identical (since we replicated the same request) + for i in range(1, batch_size): + torch.testing.assert_close( + output_batch[0], + output_batch[i], + rtol=rtol, + atol=atol, + msg=f"All outputs in batch should be identical when using same input (batch index {i})" + ) + + +@pytest.mark.parametrize("kv_layout", ["HND"]) +@pytest.mark.parametrize( + "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + [ + (4, 1, 16, 8, 4), + ], +) +@pytest.mark.parametrize("window_left", [-1]) +@pytest.mark.parametrize( + "q_dtype,kv_dtype,o_dtype", + [ + ("bf16", "bf16", "bf16"), + ], +) +@pytest.mark.parametrize("enable_pdl", [None]) +@pytest.mark.parametrize("max_in_kv_len", [2048]) +@pytest.mark.parametrize("head_dim", [128]) +def 
test_trtllm_mla_batch_decode_batch_invariant( + kv_layout, + batch_size, + q_len_per_req, + page_size, + num_kv_heads, + head_grp_size, + window_left, + q_dtype, + o_dtype, + kv_dtype, + enable_pdl, + max_in_kv_len, + head_dim, +): + """Test that batch_invariant parameter works for MLA decode functions.""" + compute_capability = get_compute_capability(torch.device(device="cuda")) + + # MLA requires SM100+ + if compute_capability[0] < 10: + pytest.skip("MLA attention requires SM100+ GPUs.") + + torch.manual_seed(42) + + num_qo_heads = num_kv_heads * head_grp_size + q_dtype_torch = DTYPE_MAP[q_dtype] + + # MLA uses different head dims + head_dim_qk = 192 + head_dim_vo = 128 + + seq_len1 = 128 + + # Single request + q_single = torch.randn(1, num_qo_heads, head_dim_qk, device=GPU_DEVICE, dtype=q_dtype_torch) + seq_lens_single = torch.tensor([seq_len1], dtype=torch.int32, device=GPU_DEVICE) + + num_pages_single = (seq_len1 + page_size - 1) // page_size + # For MLA: K has head_dim_qk, V has head_dim_vo + k_cache_single = torch.randn( + num_pages_single, num_kv_heads, page_size, head_dim_qk, + device=GPU_DEVICE, dtype=q_dtype_torch + ) + v_cache_single = torch.randn( + num_pages_single, num_kv_heads, page_size, head_dim_vo, + device=GPU_DEVICE, dtype=q_dtype_torch + ) + page_table_single = torch.arange(num_pages_single, dtype=torch.int32, device=GPU_DEVICE).unsqueeze(0) + + workspace_buffer_single = create_workspace_buffers(GPU_DEVICE) + + sm_scale = float(1.0 / (head_dim_qk**0.5)) + bmm1_scale = sm_scale + bmm2_scale = 1.0 + + # Run with batch_invariant=True for single request + output_single = flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla( + q_single, + (k_cache_single, v_cache_single), + workspace_buffer_single, + page_table_single, + seq_lens_single, + seq_len1, + bmm1_scale, + bmm2_scale, + window_left, + kv_layout=kv_layout, + enable_pdl=enable_pdl, + backend="trtllm-gen", + batch_invariant=True, + ) + + # Batch test + q_batch = q_single.repeat(batch_size, 1, 
1) + seq_lens_batch = torch.full((batch_size,), seq_len1, dtype=torch.int32, device=GPU_DEVICE) + + k_cache_batch = k_cache_single.repeat(batch_size, 1, 1, 1) + v_cache_batch = v_cache_single.repeat(batch_size, 1, 1, 1) + + page_table_batch = torch.zeros(batch_size, num_pages_single, dtype=torch.int32, device=GPU_DEVICE) + for i in range(batch_size): + page_table_batch[i] = torch.arange( + i * num_pages_single, (i+1) * num_pages_single, dtype=torch.int32, device=GPU_DEVICE + ) + + workspace_buffer_batch = create_workspace_buffers(GPU_DEVICE) + + output_batch = flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla( + q_batch, + (k_cache_batch, v_cache_batch), + workspace_buffer_batch, + page_table_batch, + seq_lens_batch, + seq_len1, + bmm1_scale, + bmm2_scale, + window_left, + kv_layout=kv_layout, + enable_pdl=enable_pdl, + backend="trtllm-gen", + batch_invariant=True, + ) + + rtol, atol = 1e-3, 1e-3 + + torch.testing.assert_close( + output_single[0], + output_batch[0], + rtol=rtol, + atol=atol, + msg="MLA output with batch_invariant=True should be identical for same request in different batch sizes" + ) + + for i in range(1, batch_size): + torch.testing.assert_close( + output_batch[0], + output_batch[i], + rtol=rtol, + atol=atol, + msg=f"All MLA outputs in batch should be identical when using same input (batch index {i})" + ) From 4ef50039780425783b1dfa835c31b23530e5dbf5 Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sat, 10 Jan 2026 04:23:13 +0000 Subject: [PATCH 4/5] Fix formatting and linting issues in test_batch_invariant.py - Remove unused parametrized arguments (q_len_per_req, o_dtype, kv_dtype, max_in_kv_len) - Fix compute capability device check to use GPU_DEVICE constant - Apply proper code formatting (line breaks, trailing commas, spacing) Co-authored-by: Zihao Ye --- tests/attention/test_batch_invariant.py | 113 +++++++++++++----------- 1 file changed, 63 insertions(+), 50 deletions(-) diff --git 
a/tests/attention/test_batch_invariant.py b/tests/attention/test_batch_invariant.py index 2aed27ea70..f04bfba1e3 100644 --- a/tests/attention/test_batch_invariant.py +++ b/tests/attention/test_batch_invariant.py @@ -28,38 +28,28 @@ def create_workspace_buffers(device): @pytest.mark.parametrize("kv_layout", ["HND"]) @pytest.mark.parametrize( - "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + "batch_size,page_size,num_kv_heads,head_grp_size", [ - (4, 1, 16, 8, 4), + (4, 16, 8, 4), ], ) @pytest.mark.parametrize("window_left", [-1]) -@pytest.mark.parametrize( - "q_dtype,kv_dtype,o_dtype", - [ - ("bf16", "bf16", "bf16"), - ], -) +@pytest.mark.parametrize("q_dtype", ["bf16"]) @pytest.mark.parametrize("enable_pdl", [None]) -@pytest.mark.parametrize("max_in_kv_len", [2048]) @pytest.mark.parametrize("head_dim", [128]) def test_trtllm_batch_decode_batch_invariant( kv_layout, batch_size, - q_len_per_req, page_size, num_kv_heads, head_grp_size, window_left, q_dtype, - o_dtype, - kv_dtype, enable_pdl, - max_in_kv_len, head_dim, ): """Test that batch_invariant parameter produces consistent results across different batch sizes.""" - compute_capability = get_compute_capability(torch.device(device="cuda")) + compute_capability = get_compute_capability(torch.device(GPU_DEVICE)) # trtllm-gen backend requires SM100 and SM103 GPUs if compute_capability[0] != 10: @@ -74,16 +64,25 @@ def test_trtllm_batch_decode_batch_invariant( seq_len1 = 128 # Fixed KV seq length # Single request test - q_single = torch.randn(1, num_qo_heads, head_dim, device=GPU_DEVICE, dtype=q_dtype_torch) + q_single = torch.randn( + 1, num_qo_heads, head_dim, device=GPU_DEVICE, dtype=q_dtype_torch + ) seq_lens_single = torch.tensor([seq_len1], dtype=torch.int32, device=GPU_DEVICE) # Create KV cache for single request num_pages_single = (seq_len1 + page_size - 1) // page_size kv_cache_single = torch.randn( - num_pages_single, 2, num_kv_heads, page_size, head_dim, - device=GPU_DEVICE, 
dtype=q_dtype_torch + num_pages_single, + 2, + num_kv_heads, + page_size, + head_dim, + device=GPU_DEVICE, + dtype=q_dtype_torch, ) - page_table_single = torch.arange(num_pages_single, dtype=torch.int32, device=GPU_DEVICE).unsqueeze(0) + page_table_single = torch.arange( + num_pages_single, dtype=torch.int32, device=GPU_DEVICE + ).unsqueeze(0) workspace_buffer_single = create_workspace_buffers(GPU_DEVICE) @@ -110,16 +109,23 @@ def test_trtllm_batch_decode_batch_invariant( # Now test with a batch containing the same request replicated q_batch = q_single.repeat(batch_size, 1, 1) - seq_lens_batch = torch.full((batch_size,), seq_len1, dtype=torch.int32, device=GPU_DEVICE) + seq_lens_batch = torch.full( + (batch_size,), seq_len1, dtype=torch.int32, device=GPU_DEVICE + ) # Create KV cache for batch (replicate the same pages) kv_cache_batch = kv_cache_single.repeat(batch_size, 1, 1, 1, 1) # Create page table for batch - page_table_batch = torch.zeros(batch_size, num_pages_single, dtype=torch.int32, device=GPU_DEVICE) + page_table_batch = torch.zeros( + batch_size, num_pages_single, dtype=torch.int32, device=GPU_DEVICE + ) for i in range(batch_size): page_table_batch[i] = torch.arange( - i * num_pages_single, (i+1) * num_pages_single, dtype=torch.int32, device=GPU_DEVICE + i * num_pages_single, + (i + 1) * num_pages_single, + dtype=torch.int32, + device=GPU_DEVICE, ) workspace_buffer_batch = create_workspace_buffers(GPU_DEVICE) @@ -149,7 +155,7 @@ def test_trtllm_batch_decode_batch_invariant( output_batch[0], rtol=rtol, atol=atol, - msg="Output with batch_invariant=True should be identical for same request in different batch sizes" + msg="Output with batch_invariant=True should be identical for same request in different batch sizes", ) # Also verify all batch outputs are identical (since we replicated the same request) @@ -159,44 +165,32 @@ def test_trtllm_batch_decode_batch_invariant( output_batch[i], rtol=rtol, atol=atol, - msg=f"All outputs in batch should be identical 
when using same input (batch index {i})" + msg=f"All outputs in batch should be identical when using same input (batch index {i})", ) @pytest.mark.parametrize("kv_layout", ["HND"]) @pytest.mark.parametrize( - "batch_size,q_len_per_req,page_size,num_kv_heads,head_grp_size", + "batch_size,page_size,num_kv_heads,head_grp_size", [ - (4, 1, 16, 8, 4), + (4, 16, 8, 4), ], ) @pytest.mark.parametrize("window_left", [-1]) -@pytest.mark.parametrize( - "q_dtype,kv_dtype,o_dtype", - [ - ("bf16", "bf16", "bf16"), - ], -) +@pytest.mark.parametrize("q_dtype", ["bf16"]) @pytest.mark.parametrize("enable_pdl", [None]) -@pytest.mark.parametrize("max_in_kv_len", [2048]) -@pytest.mark.parametrize("head_dim", [128]) def test_trtllm_mla_batch_decode_batch_invariant( kv_layout, batch_size, - q_len_per_req, page_size, num_kv_heads, head_grp_size, window_left, q_dtype, - o_dtype, - kv_dtype, enable_pdl, - max_in_kv_len, - head_dim, ): """Test that batch_invariant parameter works for MLA decode functions.""" - compute_capability = get_compute_capability(torch.device(device="cuda")) + compute_capability = get_compute_capability(torch.device(GPU_DEVICE)) # MLA requires SM100+ if compute_capability[0] < 10: @@ -214,20 +208,32 @@ def test_trtllm_mla_batch_decode_batch_invariant( seq_len1 = 128 # Single request - q_single = torch.randn(1, num_qo_heads, head_dim_qk, device=GPU_DEVICE, dtype=q_dtype_torch) + q_single = torch.randn( + 1, num_qo_heads, head_dim_qk, device=GPU_DEVICE, dtype=q_dtype_torch + ) seq_lens_single = torch.tensor([seq_len1], dtype=torch.int32, device=GPU_DEVICE) num_pages_single = (seq_len1 + page_size - 1) // page_size # For MLA: K has head_dim_qk, V has head_dim_vo k_cache_single = torch.randn( - num_pages_single, num_kv_heads, page_size, head_dim_qk, - device=GPU_DEVICE, dtype=q_dtype_torch + num_pages_single, + num_kv_heads, + page_size, + head_dim_qk, + device=GPU_DEVICE, + dtype=q_dtype_torch, ) v_cache_single = torch.randn( - num_pages_single, num_kv_heads, page_size, 
head_dim_vo, - device=GPU_DEVICE, dtype=q_dtype_torch + num_pages_single, + num_kv_heads, + page_size, + head_dim_vo, + device=GPU_DEVICE, + dtype=q_dtype_torch, ) - page_table_single = torch.arange(num_pages_single, dtype=torch.int32, device=GPU_DEVICE).unsqueeze(0) + page_table_single = torch.arange( + num_pages_single, dtype=torch.int32, device=GPU_DEVICE + ).unsqueeze(0) workspace_buffer_single = create_workspace_buffers(GPU_DEVICE) @@ -254,15 +260,22 @@ def test_trtllm_mla_batch_decode_batch_invariant( # Batch test q_batch = q_single.repeat(batch_size, 1, 1) - seq_lens_batch = torch.full((batch_size,), seq_len1, dtype=torch.int32, device=GPU_DEVICE) + seq_lens_batch = torch.full( + (batch_size,), seq_len1, dtype=torch.int32, device=GPU_DEVICE + ) k_cache_batch = k_cache_single.repeat(batch_size, 1, 1, 1) v_cache_batch = v_cache_single.repeat(batch_size, 1, 1, 1) - page_table_batch = torch.zeros(batch_size, num_pages_single, dtype=torch.int32, device=GPU_DEVICE) + page_table_batch = torch.zeros( + batch_size, num_pages_single, dtype=torch.int32, device=GPU_DEVICE + ) for i in range(batch_size): page_table_batch[i] = torch.arange( - i * num_pages_single, (i+1) * num_pages_single, dtype=torch.int32, device=GPU_DEVICE + i * num_pages_single, + (i + 1) * num_pages_single, + dtype=torch.int32, + device=GPU_DEVICE, ) workspace_buffer_batch = create_workspace_buffers(GPU_DEVICE) @@ -290,7 +303,7 @@ def test_trtllm_mla_batch_decode_batch_invariant( output_batch[0], rtol=rtol, atol=atol, - msg="MLA output with batch_invariant=True should be identical for same request in different batch sizes" + msg="MLA output with batch_invariant=True should be identical for same request in different batch sizes", ) for i in range(1, batch_size): @@ -299,5 +312,5 @@ def test_trtllm_mla_batch_decode_batch_invariant( output_batch[i], rtol=rtol, atol=atol, - msg=f"All MLA outputs in batch should be identical when using same input (batch index {i})" + msg=f"All MLA outputs in batch should 
be identical when using same input (batch index {i})", ) From aec0e1f1d2ba2c4b65e9386136caee6b826f3c7b Mon Sep 17 00:00:00 2001 From: "claude[bot]" <41898282+claude[bot]@users.noreply.github.com> Date: Sun, 25 Jan 2026 23:54:21 +0000 Subject: [PATCH 5/5] docs: clarify MLA batch_invariant limitation due to split-KV reduction The batch_invariant flag disables multi-CTA in the main generation kernel, but MLA attention uses an additional reduction kernel that combines partial results from split-KV optimization. The split count heuristic depends on batch size (split_kv ~ sm_count / batch_size), which means different batch sizes may still produce different numerical results due to different reduction patterns. Added documentation in: - Python docstring for trtllm_batch_decode_with_kv_cache_mla - C++ comments in set_split_kv function Co-authored-by: Zihao Ye --- flashinfer/mla.py | 8 ++++++++ .../flashinfer/attention/blackwell/device/sm100_mla.hpp | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 48c2dcfa7a..3a9e17ecb8 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -554,6 +554,14 @@ def trtllm_batch_decode_with_kv_cache_mla( This ensures the output is invariant to batch size, allowing per-request processing without a for loop while maintaining consistent results. Only supported by trtllm-gen backend. Defaults to False. + + **Important**: For MLA attention, batch invariance may not be fully guaranteed + even with this flag enabled. The MLA implementation uses a reduction kernel + that combines partial results from split-KV optimization, and the number of + splits is determined by a heuristic that depends on batch size + (split_kv ~ sm_count / batch_size). This means different batch sizes may still + produce slightly different numerical results due to different reduction patterns, + even though multi-CTA is disabled in the main generation kernel. 
backend : str = "auto" The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. diff --git a/include/flashinfer/attention/blackwell/device/sm100_mla.hpp b/include/flashinfer/attention/blackwell/device/sm100_mla.hpp index 9f72482569..7eaa593fef 100644 --- a/include/flashinfer/attention/blackwell/device/sm100_mla.hpp +++ b/include/flashinfer/attention/blackwell/device/sm100_mla.hpp @@ -116,6 +116,10 @@ class MLA { auto [H, K, D, B] = args.problem_shape; int sm_count = args.hw_info.sm_count; int max_splits = ceil_div(K, 128); + // NOTE: This heuristic depends on batch size B, which means the split count + // (and thus the reduction kernel behavior) varies with batch size. + // This is why the batch_invariant flag may not guarantee full batch invariance + // for MLA attention - different batch sizes lead to different split counts. int sms_per_batch = max(1, sm_count / B); int split_heur = min(max_splits, sms_per_batch); int waves = ceil_div(B * split_heur, sm_count);