Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,10 @@ void trtllm_paged_attention_decode(
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

bool const skips_softmax = skip_softmax_threshold_scale_factor.has_value();
// If threshold is zero we can fall back to standard attention to reduce overheads.
float const skip_softmax_threshold_scale_factor_value =
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
Expand Down Expand Up @@ -388,9 +389,10 @@ void trtllm_paged_attention_context(
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

bool const skips_softmax = skip_softmax_threshold_scale_factor.has_value();
// If threshold is zero we can fall back to standard attention to reduce overheads.
float const skip_softmax_threshold_scale_factor_value =
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
Expand All @@ -417,7 +419,8 @@ void trtllm_ragged_attention_launcher(
int64_t batch_size, int64_t window_left, int64_t sm_count, bool enable_pdl, bool is_causal,
int64_t k_stride_keys_values, int64_t k_stride_heads, int64_t k_stride_batch,
int64_t v_stride_keys_values, int64_t v_stride_heads, int64_t v_stride_batch,
int64_t workspace_size, cudaStream_t stream) {
float skip_softmax_threshold_scale_factor, bool skips_softmax, int64_t workspace_size,
cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -489,6 +492,9 @@ void trtllm_ragged_attention_launcher(
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");

runner_params.mSkipsSoftmaxWhenPossible = skips_softmax;
runner_params.mSkipSoftmaxThresholdScaleFactor = skip_softmax_threshold_scale_factor;

auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
if (!foundKernels) {
std::ostringstream err_msg;
Expand All @@ -506,7 +512,9 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
int64_t batch_size, int64_t window_left, TensorView cum_seq_lens_q,
TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl,
bool is_causal, int64_t workspace_size,
Optional<TensorView> attention_sinks, Optional<TensorView> lse) {
Optional<TensorView> attention_sinks,
Optional<float> skip_softmax_threshold_scale_factor,
Optional<TensorView> lse) {
float* attention_sinks_ptr = nullptr;
if (attention_sinks.has_value()) {
TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32)
Expand Down Expand Up @@ -559,6 +567,12 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
float* bmm2_scale_ptr = maybe_bmm2_scale_tensor.has_value()
? static_cast<float*>(maybe_bmm2_scale_tensor.value().data_ptr())
: nullptr;

// If threshold is zero we can fall back to standard attention to reduce overheads.
float const skip_softmax_threshold_scale_factor_value =
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

trtllm_ragged_attention_launcher(
out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(seq_lens.data_ptr()),
Expand All @@ -567,7 +581,8 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T
num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale_value,
bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, batch_size, window_left,
sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, k_stride_batch,
v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream);
v_stride_keys_values, v_stride_heads, v_stride_batch,
skip_softmax_threshold_scale_factor_value, skips_softmax, workspace_size, stream);
}

namespace trtllm_cubin_loader {
Expand Down
13 changes: 12 additions & 1 deletion flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm1_scale: Union[float, torch.Tensor] = 1.0,
bmm2_scale: Union[float, torch.Tensor] = 1.0,
sinks: Optional[List[torch.Tensor]] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
enable_pdl: bool = None,
backend: str = "auto",
) -> torch.Tensor:
Expand All @@ -556,6 +557,11 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm2_scale: fused scale for mla bmm2 input.
when using trtllm-gen backend, it can be a torch.Tensor with dtype torch.float32.
sinks: additional value per head in the denominator of the softmax.
skip_softmax_threshold_scale_factor: threshold scale factor for skipping softmax operations.
Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087
If no value is provided, then standard attention is used.
Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation.
The actual threshold value equals the provided threshold_scale_factor divided by the context length.
backend : str = "auto"
The implementation backend, could be ``auto``/``xqa`` or ``trtllm-gen``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
Expand Down Expand Up @@ -605,6 +611,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
raise ValueError(
f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
)
if skip_softmax_threshold_scale_factor is not None:
raise ValueError("skip_softmax is not supported for XQA backend")
return xqa_batch_decode_with_kv_cache_mla(
query,
kv_cache,
Expand Down Expand Up @@ -635,6 +643,9 @@ def trtllm_batch_decode_with_kv_cache_mla(
): # todo(Yingyi): add support for more block sizes?
raise ValueError(f"Supported block_size are 32 and 64, got {block_size}")

if skip_softmax_threshold_scale_factor is not None and sparse_mla_top_k != 0:
raise ValueError("skip_softmax is not supported for sparse MLA")

# Validate and normalize to 4D
kv_cache = _check_trtllm_gen_mla_shape(
query,
Expand Down Expand Up @@ -688,7 +699,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
None, # cum_seq_lens_q
None, # skip_softmax_threshold_scale_factor
skip_softmax_threshold_scale_factor,
)

return out
Expand Down
8 changes: 8 additions & 0 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3434,6 +3434,7 @@ def trtllm_ragged_attention_deepseek(
is_causal: bool,
return_lse: bool,
attention_sinks: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
out: Optional[torch.Tensor] = None,
lse: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Expand Down Expand Up @@ -3476,6 +3477,12 @@ def trtllm_ragged_attention_deepseek(
is causal
attention_sinks : Optional[torch.Tensor]
attention sinks
skip_softmax_threshold_scale_factor : Optional[float]
threshold scale factor for skipping softmax operations.
Providing a value for this parameter enables skip-softmax sparsity as described in: https://arxiv.org/abs/2512.12087
If no value is provided, then standard attention is used.
Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation.
The actual threshold value equals the provided threshold_scale_factor divided by the context length.
out : Optional[torch.Tensor]
output tensor, if not provided, will be allocated with shape [query.shape[0], query.shape[1], value.shape[2]]
lse : Optional[torch.Tensor]
Expand Down Expand Up @@ -3541,6 +3548,7 @@ def trtllm_ragged_attention_deepseek(
is_causal,
workspace_size,
attention_sinks,
skip_softmax_threshold_scale_factor,
lse,
)
if return_lse:
Expand Down
21 changes: 14 additions & 7 deletions tests/attention/test_trtllm_gen_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ def _test_trtllm_batch_prefill(
else:
q_input = q.contiguous()

# Using 0.0 threshold should give the same result as normal attention.
skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None
# Using a tiny threshold should give the same result as normal attention.
skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None

output = flashinfer.prefill.trtllm_batch_context_with_kv_cache(
q_input,
Expand Down Expand Up @@ -959,8 +959,8 @@ def _test_trtllm_batch_decode(
else:
q_input = q.contiguous()

# Using 0.0 threshold should give the same result as normal attention.
skip_softmax_threshold_scale_factor = 0.0 if skips_softmax else None
# Using a tiny threshold should give the same result as normal attention.
skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None

output = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
q_input,
Expand Down Expand Up @@ -1383,8 +1383,9 @@ def test_trtllm_batch_decode_long_sequence_length(
@pytest.mark.parametrize("num_kv_heads", [16, 32])
@pytest.mark.parametrize("head_grp_size", [1, 5, 8])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("skips_softmax", [False, True])
def test_trtllm_gen_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if compute_capability[0] != 10:
Expand Down Expand Up @@ -1473,6 +1474,10 @@ def test_trtllm_gen_prefill_deepseek(

bmm1_scale = scale
bmm2_scale = 1.0

# Using a tiny threshold should give the same result as normal attention.
skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None

output_trtllm, lse_trtllm = flashinfer.prefill.trtllm_ragged_attention_deepseek(
q,
k_cache,
Expand All @@ -1491,6 +1496,7 @@ def test_trtllm_gen_prefill_deepseek(
False,
causal,
True,
skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
out=output,
)
torch.testing.assert_close(
Expand All @@ -1516,11 +1522,12 @@ def test_trtllm_gen_prefill_deepseek(
@pytest.mark.parametrize("num_kv_heads", [128])
@pytest.mark.parametrize("head_grp_size", [1])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("skips_softmax", [False, True])
def test_trtllm_gen_prefill_deepseek_bs1(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax
):
test_trtllm_gen_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal
batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax
)


Expand Down
14 changes: 14 additions & 0 deletions tests/attention/test_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def trtllm_batch_decode_mla(
enable_pdl: bool,
backend: str,
MAX_SEQ_LEN: int,
skips_softmax: bool,
):
compute_capability = get_compute_capability(torch.device(device="cuda"))
if backend == "xqa":
Expand All @@ -234,6 +235,9 @@ def trtllm_batch_decode_mla(
if dynamic_scale and dtype != torch.float8_e4m3fn:
pytest.skip("Dynamic scale is not supported for non-fp8 dtype")

if skips_softmax and backend != "trtllm-gen":
pytest.skip("skips_softmax is only supported for trtllm-gen backend")

torch.manual_seed(42)
device = "cuda:0"

Expand Down Expand Up @@ -306,6 +310,9 @@ def trtllm_batch_decode_mla(
workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
workspace_buffer_ref = global_workspace_buffer

# Using a tiny threshold should give the same output as standard attention
skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None

# Run decode-MLA
output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
query=query,
Expand All @@ -319,6 +326,7 @@ def trtllm_batch_decode_mla(
max_seq_len=max_seq_len,
bmm1_scale=scale / ((128 + 64) ** 0.5),
bmm2_scale=1.0,
skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
enable_pdl=enable_pdl,
backend=backend,
)
Expand Down Expand Up @@ -439,6 +447,7 @@ def trtllm_batch_decode_mla(
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
@pytest.mark.parametrize("skips_softmax", [False, True])
def test_trtllm_batch_decode_mla(
batch_size: int,
scale: float,
Expand All @@ -448,6 +457,7 @@ def test_trtllm_batch_decode_mla(
dynamic_scale: bool,
enable_pdl: bool,
backend: str,
skips_softmax: bool,
):
trtllm_batch_decode_mla(
batch_size,
Expand All @@ -459,6 +469,7 @@ def test_trtllm_batch_decode_mla(
enable_pdl,
backend,
1024,
skips_softmax,
)


Expand All @@ -474,6 +485,7 @@ def test_trtllm_batch_decode_mla(
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("backend", ["trtllm-gen"])
@pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960])
@pytest.mark.parametrize("skips_softmax", [False, True])
def test_dsr1_trtllm_mla(
batch_size: int,
scale: float,
Expand All @@ -484,6 +496,7 @@ def test_dsr1_trtllm_mla(
enable_pdl: bool,
backend: str,
MAX_SEQ_LEN: int,
skips_softmax: bool,
):
trtllm_batch_decode_mla(
batch_size,
Expand All @@ -495,6 +508,7 @@ def test_dsr1_trtllm_mla(
enable_pdl,
backend,
MAX_SEQ_LEN,
skips_softmax,
)


Expand Down
Loading