From 1cf52e1af73cf232f12da8e553fad2eca1211560 Mon Sep 17 00:00:00 2001 From: Dom Brown <3886319+DomBrown@users.noreply.github.com> Date: Thu, 12 Feb 2026 16:30:46 +0000 Subject: [PATCH 1/5] Support skip softmax for deepseek prefill --- csrc/trtllm_fmha_kernel_launcher.cu | 18 +++++++++++++++--- flashinfer/prefill.py | 8 ++++++++ tests/attention/test_trtllm_gen_attention.py | 13 ++++++++++--- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index d4737c9aa1..663e8a61cb 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -417,7 +417,8 @@ void trtllm_ragged_attention_launcher( int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal, int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch, int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch, - int64_t workspace_size, cudaStream_t stream) { + float skip_softmax_threshold_scale_factor, bool skips_softmax, int64_t workspace_size, + cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -489,6 +490,9 @@ void trtllm_ragged_attention_launcher( runner_params.multiCtasKvScratchPtr = float_allocator.aligned_alloc(0, 16, "trtllm_gen_scratch_workspace"); + runner_params.mSkipsSoftmaxWhenPossible = skips_softmax; + runner_params.mSkipSoftmaxThresholdScaleFactor = skip_softmax_threshold_scale_factor; + auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params); if (!foundKernels) { std::ostringstream err_msg; @@ -506,7 +510,9 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, bool is_causal, int64_t 
workspace_size, - Optional attention_sinks, Optional lse) { + Optional attention_sinks, + Optional skip_softmax_threshold_scale_factor, + Optional lse) { float* attention_sinks_ptr = nullptr; if (attention_sinks.has_value()) { TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32) @@ -559,6 +565,11 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value() ? static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) : nullptr; + + bool const skips_softmax = skip_softmax_threshold_scale_factor.has_value(); + float const skip_softmax_threshold_scale_factor_value = + skip_softmax_threshold_scale_factor.value_or(0.0f); + trtllm_ragged_attention_launcher( out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(), workspace_buffer.data_ptr(), static_cast(seq_lens.data_ptr()), @@ -567,7 +578,8 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch, - v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream); + v_stride_keys_values, v_stride_heads, v_stride_batch, + skip_softmax_threshold_scale_factor_value, skips_softmax, workspace_size, stream); } namespace trtllm_cubin_loader { diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 618f6fbbf6..1154db998f 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3434,6 +3434,7 @@ def trtllm_ragged_attention_deepseek( is_causal: bool, return_lse: bool, attention_sinks: Optional[torch.Tensor] = None, + skip_softmax_threshold_scale_factor: Optional[float] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 
@@ -3476,6 +3477,12 @@ def trtllm_ragged_attention_deepseek( is causal attention_sinks : Optional[torch.Tensor] attention sinks + skip_softmax_threshold_scale_factor : Optional[float] + threshold scale factor for skipping softmax operations. + Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087 + If no value is provided, then standard attention is used. + Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. + The actual threshold value equals the provided threshold_scale_factor divided by the context length. out : Optional[torch.Tensor] output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]] lse : Optional[torch.Tensor] @@ -3541,6 +3548,7 @@ def trtllm_ragged_attention_deepseek( is_causal, workspace_size, attention_sinks, + skip_softmax_threshold_scale_factor, lse, ) if return_lse: diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index ad3f077676..d3b8a4df98 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -1383,8 +1383,9 @@ def test_trtllm_batch_decode_long_sequence_length( @pytest.mark.parametrize("num_kv_heads", [16, 32]) @pytest.mark.parametrize("head_grp_size", [1, 5, 8]) @pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("skips_softmax", [False, True]) def test_trtllm_gen_prefill_deepseek( - batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal + batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax ): compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: @@ -1473,6 +1474,10 @@ def test_trtllm_gen_prefill_deepseek( bmm1_scale = scale bmm2_scale = 1.0 + + # Using 0.0 threshold should give the same result as normal attention. 
+ skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None + output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek( q, k_cache, @@ -1491,6 +1496,7 @@ def test_trtllm_gen_prefill_deepseek( False, causal, True, + skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, out=output, ) torch.testing.assert_close( @@ -1516,11 +1522,12 @@ def test_trtllm_gen_prefill_deepseek( @pytest.mark.parametrize("num_kv_heads", [128]) @pytest.mark.parametrize("head_grp_size", [1]) @pytest.mark.parametrize("causal", [True, False]) +@pytest.mark.parametrize("skips_softmax", [False, True]) def test_trtllm_gen_prefill_deepseek_bs1( - batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal + batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax ): test_trtllm_gen_prefill_deepseek( - batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal + batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax ) From e077fd564aae5a82a181ee2e300da137fbf4e1d2 Mon Sep 17 00:00:00 2001 From: Dom Brown <3886319+DomBrown@users.noreply.github.com> Date: Thu, 12 Feb 2026 16:49:38 +0000 Subject: [PATCH 2/5] test_trtllm_batch_decode_mla enabled and passing --- flashinfer/mla.py | 8 +++++++- tests/attention/test_trtllm_gen_mla.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 4c7deb9a49..d79fd055b6 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -534,6 +534,7 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm1_scale: Union[float, torch.Tensor] = 1.0, bmm2_scale: Union[float, torch.Tensor] = 1.0, sinks: Optional[List[torch.Tensor]] = None, + skip_softmax_threshold_scale_factor: Optional[float] = None, enable_pdl: bool = None, backend: str = "auto", ) -> torch.Tensor: @@ -556,6 +557,11 @@ def trtllm_batch_decode_with_kv_cache_mla( bmm2_scale: fused scale for mla bmm2 input. 
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32. sinks: additional value per head in the denominator of the softmax. + skip_softmax_threshold_scale_factor: threshold scale factor for skipping softmax operations. + Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087 + If no value is provided, then standard attention is used. + Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. + The actual threshold value equals the provided threshold_scale_factor divided by the context length. backend : str = "auto" The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``. When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. @@ -688,7 +694,7 @@ def trtllm_batch_decode_with_kv_cache_mla( workspace_buffer.numel() * workspace_buffer.element_size(), sinks, None, # cum_seq_lens_q - None, # skip_softmax_threshold_scale_factor + skip_softmax_threshold_scale_factor, ) return out diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index d71e8cb386..fec54d311e 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -219,6 +219,7 @@ def trtllm_batch_decode_mla( enable_pdl: bool, backend: str, MAX_SEQ_LEN: int, + skips_softmax: bool, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if backend == "xqa": @@ -234,6 +235,9 @@ def trtllm_batch_decode_mla( if dynamic_scale and dtype != torch.float8_e4m3fn: pytest.skip("Dynamic scale is not supported for non-fp8 dtype") + if skips_softmax and backend != "trtllm-gen": + pytest.skip("skips_softmax is only supported for trtllm-gen backend") + torch.manual_seed(42) device = "cuda:0" @@ -306,6 +310,9 @@ def trtllm_batch_decode_mla( workspace_buffer = global_trtllm_gen_fmha_workspace_buffer 
workspace_buffer_ref = global_workspace_buffer + # Threshold 0.0 should give the same output as standard attention + skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None + # Run decode-MLA output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( query=query, @@ -319,6 +326,7 @@ def trtllm_batch_decode_mla( max_seq_len=max_seq_len, bmm1_scale=scale / ((128 + 64) ** 0.5), bmm2_scale=1.0, + skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, enable_pdl=enable_pdl, backend=backend, ) @@ -439,6 +447,7 @@ def trtllm_batch_decode_mla( @pytest.mark.parametrize("dynamic_scale", [False]) @pytest.mark.parametrize("enable_pdl", [True, False, None]) @pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"]) +@pytest.mark.parametrize("skips_softmax", [False, True]) def test_trtllm_batch_decode_mla( batch_size: int, scale: float, @@ -448,6 +457,7 @@ def test_trtllm_batch_decode_mla( dynamic_scale: bool, enable_pdl: bool, backend: str, + skips_softmax: bool, ): trtllm_batch_decode_mla( batch_size, @@ -459,6 +469,7 @@ def test_trtllm_batch_decode_mla( enable_pdl, backend, 1024, + skips_softmax, ) @@ -495,6 +506,7 @@ def test_dsr1_trtllm_mla( enable_pdl, backend, MAX_SEQ_LEN, + False, # skips_softmax ) From df4b9be11318029ebbd1218696ea7a608f5d89fb Mon Sep 17 00:00:00 2001 From: Dom Brown <3886319+DomBrown@users.noreply.github.com> Date: Thu, 12 Feb 2026 17:05:38 +0000 Subject: [PATCH 3/5] test_dsr1_trtllm_mla enabled and passing --- tests/attention/test_trtllm_gen_mla.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index fec54d311e..438d44ccbf 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -485,6 +485,7 @@ def test_trtllm_batch_decode_mla( @pytest.mark.parametrize("enable_pdl", [True, False, None]) @pytest.mark.parametrize("backend", ["trtllm-gen"]) 
@pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960]) +@pytest.mark.parametrize("skips_softmax", [False, True]) def test_dsr1_trtllm_mla( batch_size: int, scale: float, @@ -495,6 +496,7 @@ def test_dsr1_trtllm_mla( enable_pdl: bool, backend: str, MAX_SEQ_LEN: int, + skips_softmax: bool, ): trtllm_batch_decode_mla( batch_size, @@ -506,7 +508,7 @@ def test_dsr1_trtllm_mla( enable_pdl, backend, MAX_SEQ_LEN, - False, # skips_softmax + skips_softmax, ) From 6a88ecc17df8d34f0d7ef566044672dddf0ebe01 Mon Sep 17 00:00:00 2001 From: Dom Brown <3886319+DomBrown@users.noreply.github.com> Date: Thu, 12 Feb 2026 17:41:37 +0000 Subject: [PATCH 4/5] Reject skip-softmax for XQA backend and sparse MLA --- flashinfer/mla.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/flashinfer/mla.py b/flashinfer/mla.py index d79fd055b6..3adda40789 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -611,6 +611,8 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError( f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}" ) + if skip_softmax_threshold_scale_factor is not None: + raise ValueError("skip_softmax is not supported for XQA backend") return xqa_batch_decode_with_kv_cache_mla( query, kv_cache, @@ -641,6 +643,9 @@ def trtllm_batch_decode_with_kv_cache_mla( ): # todo(Yingyi): add support for more block sizes? 
raise ValueError(f"Supported block_size are 32 and 64, got {block_size}") + if skip_softmax_threshold_scale_factor is not None and sparse_mla_top_k != 0: + raise ValueError("skip_softmax is not supported for sparse MLA") + # Validate and normalize to 4D kv_cache = _check_trtllm_gen_mla_shape( query, From bb7ea749a3f3636637fc15c7e33d363c2504fe49 Mon Sep 17 00:00:00 2001 From: Dom Brown <3886319+DomBrown@users.noreply.github.com> Date: Fri, 13 Feb 2026 16:52:49 +0000 Subject: [PATCH 5/5] Short-circuit to normal attention kernels when threshold is zero to reduce overheads --- csrc/trtllm_fmha_kernel_launcher.cu | 9 ++++++--- tests/attention/test_trtllm_gen_attention.py | 12 ++++++------ tests/attention/test_trtllm_gen_mla.py | 4 ++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 663e8a61cb..6aea76ee58 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -299,9 +299,10 @@ void trtllm_paged_attention_decode( ? static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) : nullptr; - bool const skips_softmax = skip_softmax_threshold_scale_factor.has_value(); + // If threshold is zero we can fall back to standard attention to reduce overheads. float const skip_softmax_threshold_scale_factor_value = skip_softmax_threshold_scale_factor.value_or(0.0f); + bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), @@ -388,9 +389,10 @@ void trtllm_paged_attention_context( ? static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) : nullptr; - bool const skips_softmax = skip_softmax_threshold_scale_factor.has_value(); + // If threshold is zero we can fall back to standard attention to reduce overheads. 
float const skip_softmax_threshold_scale_factor_value = skip_softmax_threshold_scale_factor.value_or(0.0f); + bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), @@ -566,9 +568,10 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T ? static_cast(maybe_bmm2_scale_tensor.value().data_ptr()) : nullptr; - bool const skips_softmax = skip_softmax_threshold_scale_factor.has_value(); + // If threshold is zero we can fall back to standard attention to reduce overheads. float const skip_softmax_threshold_scale_factor_value = skip_softmax_threshold_scale_factor.value_or(0.0f); + bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; trtllm_ragged_attention_launcher( out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(), diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index d3b8a4df98..50f5bc9d1e 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -544,8 +544,8 @@ def _test_trtllm_batch_prefill( else: q_input = q.contiguous() - # Using 0.0 threshold should give the same result as normal attention. - skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None + # Using a tiny threshold should give the same result as normal attention. + skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q_input, @@ -959,8 +959,8 @@ def _test_trtllm_batch_decode( else: q_input = q.contiguous() - # Using 0.0 threshold should give the same result as normal attention. - skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None + # Using a tiny threshold should give the same result as normal attention. 
+ skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q_input, @@ -1475,8 +1475,8 @@ def test_trtllm_gen_prefill_deepseek( bmm1_scale = scale bmm2_scale = 1.0 - # Using 0.0 threshold should give the same result as normal attention. - skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None + # Using a tiny threshold should give the same result as normal attention. + skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek( q, diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index 438d44ccbf..bd0ee03684 100644 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -310,8 +310,8 @@ def trtllm_batch_decode_mla( workspace_buffer = global_trtllm_gen_fmha_workspace_buffer workspace_buffer_ref = global_workspace_buffer - # Threshold 0.0 should give the same output as standard attention - skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None + # Using a tiny threshold should give the same output as standard attention + skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None # Run decode-MLA output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(