From 806e5c684cde55534abeee648c3df8c58714310c Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Tue, 7 Apr 2026 23:55:22 +0000 Subject: [PATCH 1/7] Support lse in trtllm paged attn kernels --- csrc/trtllm_fmha_kernel_launcher.cu | 73 ++++++++++++++----- flashinfer/decode.py | 42 ++++++++++- flashinfer/mla/_core.py | 30 +++++++- flashinfer/prefill.py | 29 +++++++- .../flashinfer/trtllm/fmha/fmhaRunnerParams.h | 4 + tests/attention/test_trtllm_gen_attention.py | 5 +- tests/attention/test_trtllm_gen_mla.py | 2 + 7 files changed, 158 insertions(+), 27 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index a0997ddf2d..888a149e9f 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -77,18 +77,19 @@ void trtllm_paged_attention_launcher( void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache, void* workspace_buffer, int* block_tables, const void* k_block_scales_ptr, const void* v_block_scales_ptr, int* seq_lens, int* cum_seq_lens_q, int* cum_seq_lens_kv, - float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, - TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len, - int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, - int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, int64_t q_stride_heads, - int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch, - int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, + float* attention_sinks, float* lse, Data_type q_data_type, Data_type kv_data_type, + Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, + int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, + int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads, + int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax, bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size, int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads, - int64_t v_sf_stride_batch, cudaStream_t stream) { + int64_t v_sf_stride_batch, int64_t lse_stride_tokens, int64_t lse_stride_heads, + cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -190,11 +191,22 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the first 8MB of workspace buffer: counter | scratch - // todo(Yingyi): add softmax buffer later for lse return + // Keep semaphores at the first 8MB of workspace buffer. + // Workspace layout for generation path: counter | softmax | scratch. runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc( num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); - // scratch takes the rest of the workspace buffer + } + + runner_params.softmaxStatsPtr = float_allocator.aligned_alloc( + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, + "trtllm_gen_softmax_workspace"); + runner_params.lsePtr = lse; + runner_params.lseStrideTokens = lse_stride_tokens; + runner_params.lseStrideHeads = lse_stride_heads; + + // The scratch allocation uses size=0 to consume the rest of the workspace, so it must be the + // last allocation from float_allocator. + if (mode == TllmPagedAttentionMode::ForGen) { runner_params.multiCtasKvScratchPtr = float_allocator.aligned_alloc(0, 16, "trtllm_gen_scratch_workspace"); } @@ -244,7 +256,7 @@ void trtllm_paged_attention_decode( int64_t workspace_size, Optional attention_sinks, Optional cum_seq_lens_q, Optional key_block_scales, Optional value_block_scales, Optional skip_softmax_threshold_scale_factor, - Optional uses_shared_paged_kv_idx) { + Optional uses_shared_paged_kv_idx, Optional lse) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -352,19 +364,30 @@ void trtllm_paged_attention_decode( skip_softmax_threshold_scale_factor.value_or(0.0f); bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; + float* lse_ptr = nullptr; + int lse_stride_tokens = 0; + int lse_stride_heads = 0; + if (lse.has_value()) { + TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; + lse_ptr = static_cast(lse.value().data_ptr()); + lse_stride_tokens = lse.value().stride(0); + lse_stride_heads = lse.value().stride(lse.value().ndim() - 1); + } + trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), k_block_scales_ptr, v_block_scales_ptr, static_cast(seq_lens.data_ptr()), cum_seq_lens_q_ptr, - /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, - TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, - num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens, - q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, + /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, + o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, + num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, + q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax, uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, - k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); + k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads, + stream); } void trtllm_paged_attention_context( @@ -376,7 +399,8 @@ void trtllm_paged_attention_context( int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, int64_t workspace_size, Optional attention_sinks, Optional key_block_scales, Optional value_block_scales, - Optional skip_softmax_threshold_scale_factor, Optional uses_shared_paged_kv_idx) { + Optional skip_softmax_threshold_scale_factor, Optional uses_shared_paged_kv_idx, + Optional lse) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); @@ -474,20 +498,31 @@ void trtllm_paged_attention_context( skip_softmax_threshold_scale_factor.value_or(0.0f); bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; + float* lse_ptr = nullptr; + int lse_stride_tokens = 0; + int lse_stride_heads = 0; + if (lse.has_value()) { + TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; + lse_ptr = static_cast(lse.value().data_ptr()); + lse_stride_tokens = lse.value().stride(0); + lse_stride_heads = lse.value().stride(1); + } + trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), k_block_scales_ptr, v_block_scales_ptr, static_cast(seq_lens.data_ptr()), /*cum_seq_lens_q=*/static_cast(cum_seq_lens_q.data_ptr()), /*cum_seq_lens_kv=*/static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, - q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, + lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax, uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, - k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); + k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads, + stream); } void trtllm_ragged_attention_launcher( diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 822aca407c..0f61a30f66 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -2036,6 +2036,7 @@ def _paged_run( value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, + lse: Optional[torch.Tensor] = None, ) -> torch.Tensor: if out is None: out = torch.empty_like(query) @@ -2081,6 +2082,7 @@ def _paged_run( value_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) return out @@ -2147,7 +2149,6 @@ def paged_run( skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, ) -> None: - assert maybe_lse is None assert paged_kv_cache is not None assert num_qo_heads is not None assert num_kv_heads is not None @@ -2176,6 +2177,7 @@ def paged_run( value_block_scales=value_block_scales, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=maybe_lse, ) @register_fake_op(f"flashinfer::{uri}_paged_run") @@ -2259,7 +2261,11 @@ def trtllm_batch_decode_with_kv_cache( skip_softmax_threshold_scale_factor: Optional[float] = None, kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, uses_shared_paged_kv_idx: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: """ Parameters ---------- @@ -2379,10 +2385,19 @@ def trtllm_batch_decode_with_kv_cache( True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. + lse: Optional[torch.Tensor] = None + The log-sum-exp of attention logits, if not provided, will be allocated internally. + Only supported by trtllm-gen backend. + + return_lse: bool = False + Whether to return the logsumexp of attention scores, defaults to ``False``. + Returns ------- out : Union[torch.Tensor, FP4Tensor] output torch.Tensor or FP4Tensor. + lse: Optional[torch.Tensor] + The log-sum-exp of attention logits, if not provided, will be allocated internally. """ enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl @@ -2578,6 +2593,22 @@ def trtllm_batch_decode_with_kv_cache( _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) + if return_lse: + if lse is None: + lse = torch.empty( + (query.size(0), query.size(1)), + dtype=torch.float32, + device=query.device, + ) + else: + check_shape_dtype_device( + lse, + (query.size(0), query.size(1)), + torch.float32, + query.device, + "lse", + ) + run_func( out, out_scale_factor, @@ -2606,13 +2637,18 @@ def trtllm_batch_decode_with_kv_cache( v_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) - return ( + out = ( out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) ) + if return_lse: + return out, lse + else: + return out else: raise KeyError(f"Backend {backend} not supported") diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 4e8bdd7212..4feb3d32ac 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -609,7 +609,9 @@ def trtllm_batch_decode_with_kv_cache_mla( backend: str = "auto", is_var_seq: bool = True, uses_shared_paged_kv_idx: bool = True, -) -> torch.Tensor: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Parameters ---------- @@ -654,6 +656,11 @@ def trtllm_batch_decode_with_kv_cache_mla( True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. False is only supported for trtllm-gen backend. + lse: Optional[torch.Tensor] = None + The log-sum-exp of attention logits, if not provided, will be allocated internally. + Only supported by trtllm-gen backend. + return_lse: bool = False + Whether to return the logsumexp of attention scores, defaults to ``False``. Note ---- @@ -765,6 +772,21 @@ def trtllm_batch_decode_with_kv_cache_mla( batch_size = query.size(0) max_q_len = query.size(1) + if return_lse: + if lse is None: + lse = torch.empty( + (query.size(0), query.size(1)), + dtype=torch.float32, + device=query.device, + ) + else: + check_shape_dtype_device( + lse, + (query.size(0), query.size(1)), + torch.float32, + query.device, + "lse", + ) query = query.flatten(0, 1) # [B*S, H, D] run_func( @@ -795,9 +817,13 @@ def trtllm_batch_decode_with_kv_cache_mla( None, # value_block_scales skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) + if return_lse: + return out, lse + else: + return out - return out elif backend == "cute-dsl": enable_pdl = ( device_support_pdl(query.device) if enable_pdl is None else enable_pdl diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index e64d4a73b6..cfe4c68a04 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -3864,7 +3864,11 @@ def trtllm_batch_context_with_kv_cache( kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: """ Parameters ---------- @@ -3961,10 +3965,17 @@ def trtllm_batch_context_with_kv_cache( Whether the K and V page indices are shared as a unified index. True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. + lse : Optional[torch.Tensor] = None + The log-sum-exp of attention logits, if not provided, will be allocated internally. + Only supported by trtllm-gen backend. + return_lse : bool = False + Whether to return the logsumexp of attention scores, defaults to ``False``. Returns ------- out: Union[torch.Tensor, FP4Tensor] output torch.Tensor or FP4Tensor. + lse: Optional[torch.Tensor] + The log-sum-exp of attention logits, if not provided, will be allocated internally. """ if enable_pdl is None: @@ -4097,6 +4108,15 @@ def trtllm_batch_context_with_kv_cache( assert bmm2_scale.dtype == torch.float32 _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() + if return_lse: + if lse is None: + lse = torch.empty( + query.size(0), query.size(1), dtype=torch.float32, device=query.device + ) + else: + check_shape_dtype_device( + lse, (query.size(0), query.size(1)), torch.float32, query.device, "lse" + ) run_func( out, out_scale_factor, @@ -4125,12 +4145,17 @@ def trtllm_batch_context_with_kv_cache( value_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) - return ( + out = ( out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) ) + if return_lse: + return out, lse + else: + return out @functools.cache diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index 7945b1fab2..156233fb1f 100644 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -224,6 +224,10 @@ struct TllmGenFmhaRunnerParams { float2* softmaxStatsPtr; // The LSE buffer. float* lsePtr; + // The stride between different tokens for LSE. + int lseStrideTokens; + // The stride between different heads for LSE. + int lseStrideHeads; // Attention sink float const* ptrAttentionSinks{nullptr}; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 186f1e5072..c6ca42db8f 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -636,7 +636,7 @@ def _test_trtllm_batch_prefill( # Using a tiny threshold should give the same result as normal attention. skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None - output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( + output, lse = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q_input, kv_cache_kernel, workspace_buffer, @@ -660,6 +660,7 @@ def _test_trtllm_batch_prefill( kv_cache_sf=kv_cache_sf_kernel, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + return_lse=True, ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future @@ -1118,8 +1119,10 @@ def _test_trtllm_batch_decode( skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, kv_cache_sf=kv_cache_sf_kernel, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + return_lse=True, ) if backend == "trtllm-gen": + output, lse = output # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index c1cf3d8a50..c80d991c29 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -407,10 +407,12 @@ def trtllm_batch_decode_mla( enable_pdl=enable_pdl, backend=backend, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + return_lse=True, ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future if backend == "trtllm-gen": + output, lse = output assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() # Run reference attention and align output From 88e36932e521edba2067c923f268a5785c1aae61 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Tue, 7 Apr 2026 23:55:22 +0000 Subject: [PATCH 2/7] Support lse in trtllm paged attn kernels --- flashinfer/decode.py | 38 ++++++++------------ flashinfer/mla/_core.py | 24 ++++++------- flashinfer/prefill.py | 17 ++++----- tests/attention/test_trtllm_gen_attention.py | 33 ++++++++++++++--- tests/attention/test_trtllm_gen_mla.py | 16 +++++++-- 5 files changed, 76 insertions(+), 52 deletions(-) diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 0f61a30f66..b92802ed7d 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1365,13 +1365,12 @@ def run( rope_theta = 1e4 if return_lse: + lse_shape = (q.size(0), q.size(1)) if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) + lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) else: check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + lse, lse_shape, torch.float32, q.device, "lse" ) if out is None: @@ -1964,19 +1963,12 @@ def run( ) if return_lse: + lse_shape = (q_nope.size(0), q_nope.size(1)) if lse is None: - lse = torch.empty( - (q_nope.size(0), q_nope.size(1)), - dtype=torch.float32, - device=device, - ) + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) else: check_shape_dtype_device( - lse, - (q_nope.size(0), q_nope.size(1)), - q_nope.dtype, - q_nope.device, - "lse", + lse, lse_shape, q_nope.dtype, q_nope.device, "lse" ) self._cached_module.run( self._float_workspace_buffer, @@ -2445,6 +2437,11 @@ def trtllm_batch_decode_with_kv_cache( backend = ( "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" ) + wants_lse = return_lse or lse is not None + if wants_lse and backend != "trtllm-gen": + raise ValueError( + "lse and return_lse are only supported by the trtllm-gen backend" + ) if backend == "xqa": # xqa backend doesn't support nvfp4 output @@ -2594,19 +2591,12 @@ def trtllm_batch_decode_with_kv_cache( _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) if return_lse: + lse_shape = (query.size(0), query.size(1)) if lse is None: - lse = torch.empty( - (query.size(0), query.size(1)), - dtype=torch.float32, - device=query.device, - ) + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) else: check_shape_dtype_device( - lse, - (query.size(0), query.size(1)), - torch.float32, - query.device, - "lse", + lse, lse_shape, torch.float32, query.device, "lse" ) run_func( diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 4feb3d32ac..6036d53b6f 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -559,11 +559,12 @@ def run( ) if return_lse: + lse_shape = (q_nope.shape[0] * num_heads, q_nope.shape[2]) if lse is None: - lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device) + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) else: check_shape_dtype_device( - lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse" + lse, lse_shape, torch.float32, q_nope.device, "lse" ) profiler_args = (profiler_buffer,) if self._use_profiler else () self._cached_module.run( @@ -685,6 +686,11 @@ def trtllm_batch_decode_with_kv_cache_mla( backend = ( "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" ) + wants_lse = return_lse or lse is not None + if wants_lse and backend != "trtllm-gen": + raise ValueError( + "lse and return_lse are only supported by the trtllm-gen backend" + ) if isinstance(bmm1_scale, torch.Tensor): assert bmm1_scale.dtype == torch.float32 bmm1_scale = bmm1_scale * log2e @@ -772,20 +778,14 @@ def trtllm_batch_decode_with_kv_cache_mla( batch_size = query.size(0) max_q_len = query.size(1) + num_qo_heads = query.size(2) if return_lse: + lse_shape = (batch_size * max_q_len, num_qo_heads) if lse is None: - lse = torch.empty( - (query.size(0), query.size(1)), - dtype=torch.float32, - device=query.device, - ) + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) else: check_shape_dtype_device( - lse, - (query.size(0), query.size(1)), - torch.float32, - query.device, - "lse", + lse, lse_shape, torch.float32, query.device, "lse" ) query = query.flatten(0, 1) # [B*S, H, D] diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index cfe4c68a04..f7b669fb3c 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -679,7 +679,6 @@ def paged_run( uses_shared_paged_kv_idx: bool = True, ) -> None: if backend == "trtllm-gen": - assert maybe_lse is None assert num_qo_heads is not None assert num_kv_heads is not None assert block_tables is not None @@ -714,6 +713,7 @@ def paged_run( value_block_scales=value_block_scales, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=maybe_lse, ) elif backend == "fa2": assert not is_float8(q) @@ -4108,15 +4108,12 @@ def trtllm_batch_context_with_kv_cache( assert bmm2_scale.dtype == torch.float32 _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() - if return_lse: - if lse is None: - lse = torch.empty( - query.size(0), query.size(1), dtype=torch.float32, device=query.device - ) - else: - check_shape_dtype_device( - lse, (query.size(0), query.size(1)), torch.float32, query.device, "lse" - ) + lse_shape = (query.size(0), query.size(1)) + if lse is None: + if return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) + else: + check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") run_func( out, out_scale_factor, diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index c6ca42db8f..094d4caa8f 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -566,6 +566,7 @@ def _test_trtllm_batch_prefill( ) sm_scale = float(1.0 / (head_dim**0.5)) + lse_ref = None # Build reference output plan_params = { @@ -589,7 +590,7 @@ def _test_trtllm_batch_prefill( workspace_buffer_ref, kv_layout ) wrapper_ref.plan(**plan_params) - output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + output_ref, lse_ref = wrapper_ref.run(ref_q, ref_kv_cache, return_lse=True) else: # Construct flat K/V via helper k_flat, v_flat, kv_indptr_tokens = flatten_paged_kv( @@ -713,6 +714,13 @@ def _test_trtllm_batch_prefill( f"Block scale factors may be mismatched to FP4 data blocks." ) + expected_lse_shape = (q_input.size(0), q_input.size(1)) + assert lse.shape == expected_lse_shape + assert lse.dtype == torch.float32 + assert torch.isfinite(lse).all() + if lse_ref is not None: + torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + if ( o_dtype != "nvfp4" and kv_dtype != "nvfp4" and uses_shared_paged_kv_idx ): # wrapper api does not support fp4 output/kv or separate KV page indices yet. @@ -999,6 +1007,8 @@ def _test_trtllm_batch_decode( ) sm_scale = float(1.0 / (head_dim**0.5)) + should_check_lse = backend == "trtllm-gen" + lse_ref = None # Build reference output plan_params = { @@ -1020,7 +1030,12 @@ def _test_trtllm_batch_decode( workspace_buffer_ref, kv_layout, use_tensor_cores=True ) wrapper_ref.plan(**plan_params) - output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + if should_check_lse: + output_ref, lse_ref = wrapper_ref.run( + ref_q, ref_kv_cache, return_lse=True + ) + else: + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) else: # speculative decoding test @@ -1040,7 +1055,12 @@ def _test_trtllm_batch_decode( } ) wrapper_ref.plan(**plan_params_prefill) - output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + if should_check_lse: + output_ref, lse_ref = wrapper_ref.run( + ref_q, ref_kv_cache, return_lse=True + ) + else: + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) else: # Construct flat K/V via helper k_flat, v_flat, kv_indptr_tokens = flatten_paged_kv( @@ -1119,8 +1139,10 @@ def _test_trtllm_batch_decode( skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, kv_cache_sf=kv_cache_sf_kernel, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, - return_lse=True, + return_lse=should_check_lse, ) + if should_check_lse: + output, lse = output if backend == "trtllm-gen": output, lse = output # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero @@ -1184,6 +1206,9 @@ def _test_trtllm_batch_decode( f"NVFP4 KV cache attention: cosine similarity {cos:.4f} < 0.86. " f"Block scale factors may be mismatched to FP4 data blocks." ) + if lse_ref is not None: + assert lse is not None, "LSE should be returned when return_lse=True" + torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) # Only test wrapper with trtllm-gen backend if ( diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index c80d991c29..bb6ec9206a 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -389,6 +389,7 @@ def trtllm_batch_decode_mla( # Using a tiny threshold should give the same output as standard attention skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None + should_check_lse = backend == "trtllm-gen" # Run decode-MLA output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( @@ -407,10 +408,15 @@ def trtllm_batch_decode_mla( enable_pdl=enable_pdl, backend=backend, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, - return_lse=True, + return_lse=should_check_lse, ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future + if should_check_lse: + output, lse = output + assert lse.shape == (batch_size * q_len_per_request, layer_dimensions.num_heads) + assert lse.dtype == torch.float32 + assert torch.isfinite(lse).all() if backend == "trtllm-gen": output, lse = output assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() @@ -466,7 +472,11 @@ def trtllm_batch_decode_mla( ckv = kv_cache[..., : layer_dimensions.head_dimensions.kv_lora_rank] kpe = kv_cache[..., layer_dimensions.head_dimensions.kv_lora_rank :] - o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) + lse_ref = None + if should_check_lse: + o_ref, lse_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) + else: + o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) # cute-dsl fp8 kernel outputs fp8; cast to bf16 to match trtllm-gen / reference if backend == "cute-dsl" and output.dtype == torch.float8_e4m3fn: @@ -488,6 +498,8 @@ def trtllm_batch_decode_mla( try: torch.testing.assert_close(output, o_ref_view, rtol=rtol, atol=atol) + if lse_ref is not None: + torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) except AssertionError as fa2_err: if backend == "cute-dsl": # fa2 reference may diverge from cute-dsl in some configs; From a8923616abfa73fa5a192722df2de93d02484cbe Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Tue, 14 Apr 2026 18:19:39 +0000 Subject: [PATCH 3/7] fix test case workspace zero check --- tests/attention/test_trtllm_gen_attention.py | 35 ++++++++++++++++---- tests/attention/test_trtllm_gen_mla.py | 1 - 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 094d4caa8f..5757e5eb06 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -33,6 +33,7 @@ global_workspace_buffer = None # can.be empty initialized global_trtllm_gen_fmha_workspace_buffer = None # must be zero initialized workspace_size = 256 * 1024 * 1024 +TRTLLM_GEN_WORKSPACE_CHECK_BYTES = 8192 * 256 * 4 def flip_coin(*args, **kwargs): @@ -481,6 +482,16 @@ def generate_causal_mask( return mask_uint16 +def trtllm_gen_workspace_softmax_end_bytes_context( + workspace_buffer: torch.Tensor, *, num_qo_heads: int, sum_seq_q: int +) -> int: + """Context/prefill uses softmax stats as the first workspace slab.""" + softmax_stats_nbytes = 8 * num_qo_heads * sum_seq_q + base_addr = workspace_buffer.data_ptr() + aligned_addr = (base_addr + 15) // 16 * 16 + return (aligned_addr - base_addr) + softmax_stats_nbytes + + def _test_trtllm_batch_prefill( kv_layout: str, batch_size: int, @@ -636,6 +647,15 @@ def _test_trtllm_batch_prefill( # Using a tiny threshold should give the same result as normal attention. skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None + softmax_end_bytes = trtllm_gen_workspace_softmax_end_bytes_context( + workspace_buffer, + num_qo_heads=q_input.size(1), + sum_seq_q=q_input.size(0), + ) + workspace_check_end_bytes = min( + softmax_end_bytes + TRTLLM_GEN_WORKSPACE_CHECK_BYTES, workspace_buffer.numel() + ) + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].zero_() output, lse = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q_input, @@ -663,9 +683,9 @@ def _test_trtllm_batch_prefill( uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, return_lse=True, ) - # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero - # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future - assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + assert ( + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].cpu().numpy() == 0 + ).all() if o_dtype == "nvfp4": output, output_ref = unpack_compare_nvfp4( @@ -732,6 +752,7 @@ def _test_trtllm_batch_prefill( plan_params["kv_data_type"] = kv_cache.dtype plan_params["o_data_type"] = DTYPE_MAP[o_dtype] wrapper_trtllm_gen.plan(**plan_params) + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].zero_() output_wrapper = wrapper_trtllm_gen.run( q_input, kv_cache, @@ -748,9 +769,10 @@ def _test_trtllm_batch_prefill( torch.testing.assert_close( output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1 ) - # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero - # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future - assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + assert ( + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].cpu().numpy() + == 0 + ).all() @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @@ -1144,7 +1166,6 @@ def _test_trtllm_batch_decode( if should_check_lse: output, lse = output if backend == "trtllm-gen": - output, lse = output # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index bb6ec9206a..b3aef3b5bc 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -418,7 +418,6 @@ def trtllm_batch_decode_mla( assert lse.dtype == torch.float32 assert torch.isfinite(lse).all() if backend == "trtllm-gen": - output, lse = output assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() # Run reference attention and align output From c6f46843e0d14913dc9767760c92de3e253cff71 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Tue, 14 Apr 2026 22:46:30 +0000 Subject: [PATCH 4/7] test fix --- csrc/trtllm_fmha_kernel_launcher.cu | 8 ++-- flashinfer/decode.py | 41 ++++++++----------- flashinfer/mla/_core.py | 28 ++++++------- flashinfer/prefill.py | 2 + .../flashinfer/trtllm/fmha/fmhaRunnerParams.h | 4 +- tests/attention/test_trtllm_gen_attention.py | 37 ++++++++++++++++- tests/attention/test_trtllm_gen_mla.py | 22 ++++++++-- 7 files changed, 90 insertions(+), 52 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 888a149e9f..8403e712ff 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -365,8 +365,8 @@ void trtllm_paged_attention_decode( bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; float* lse_ptr = nullptr; - int lse_stride_tokens = 0; - int lse_stride_heads = 0; + int64_t lse_stride_tokens = 0; + int64_t lse_stride_heads = 0; if (lse.has_value()) { TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; lse_ptr = static_cast(lse.value().data_ptr()); @@ -499,8 +499,8 @@ void trtllm_paged_attention_context( bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; float* lse_ptr = nullptr; - int lse_stride_tokens = 0; - int lse_stride_heads = 0; + int64_t lse_stride_tokens = 0; + int64_t lse_stride_heads = 0; if (lse.has_value()) { TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; lse_ptr = static_cast(lse.value().data_ptr()); diff --git a/flashinfer/decode.py b/flashinfer/decode.py index b92802ed7d..70dc62f1b7 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1364,14 +1364,11 @@ def run( if rope_theta is None: rope_theta = 1e4 - if return_lse: - lse_shape = (q.size(0), q.size(1)) - if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) - else: - check_shape_dtype_device( - lse, lse_shape, torch.float32, q.device, "lse" - ) + lse_shape = (q.size(0), q.size(1)) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) if out is None: out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype @@ -1962,14 +1959,13 @@ def run( out, q_nope.shape, q_nope.dtype, q_nope.device, "out" ) - if return_lse: - lse_shape = (q_nope.size(0), q_nope.size(1)) - if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=device) - else: - check_shape_dtype_device( - lse, lse_shape, q_nope.dtype, q_nope.device, "lse" - ) + lse_shape = (q_nope.size(0), q_nope.size(1)) + if lse is not None: + check_shape_dtype_device( + lse, lse_shape, torch.float32, q_nope.device, "lse" + ) + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) self._cached_module.run( self._float_workspace_buffer, self._int_workspace_buffer, @@ -2590,14 +2586,11 @@ def trtllm_batch_decode_with_kv_cache( _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) - if return_lse: - lse_shape = (query.size(0), query.size(1)) - if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) - else: - check_shape_dtype_device( - lse, lse_shape, torch.float32, query.device, "lse" - ) + lse_shape = (query.size(0), query.size(1)) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) run_func( out, diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 6036d53b6f..b471862c7c 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -558,14 +558,13 @@ def run( out, q_nope.shape, q_nope.dtype, q_nope.device, "out" ) - if return_lse: - lse_shape = (q_nope.shape[0] * num_heads, q_nope.shape[2]) - if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=device) - else: - check_shape_dtype_device( - lse, lse_shape, torch.float32, q_nope.device, "lse" - ) + lse_shape = (q_nope.shape[0], num_heads) + if lse is not None: + check_shape_dtype_device( + lse, lse_shape, torch.float32, q_nope.device, "lse" + ) + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) profiler_args = (profiler_buffer,) if self._use_profiler else () self._cached_module.run( self._float_workspace_buffer, @@ -779,14 +778,11 @@ def trtllm_batch_decode_with_kv_cache_mla( batch_size = query.size(0) max_q_len = query.size(1) num_qo_heads = query.size(2) - if return_lse: - lse_shape = (batch_size * max_q_len, num_qo_heads) - if lse is None: - lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) - else: - check_shape_dtype_device( - lse, lse_shape, torch.float32, query.device, "lse" - ) + lse_shape = (batch_size * max_q_len, num_qo_heads) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) query = query.flatten(0, 1) # [B*S, H, D] run_func( diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index f7b669fb3c..cd611400ee 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -269,6 +269,7 @@ def _paged_run( value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, + lse: Optional[torch.Tensor] = None, ) -> torch.Tensor: sm_count = get_device_sm_count(query.device) if out is None: @@ -306,6 +307,7 @@ def _paged_run( value_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) return out diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index 156233fb1f..cee2b39de6 100644 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -225,9 +225,9 @@ struct TllmGenFmhaRunnerParams { // The LSE buffer. float* lsePtr; // The stride between different tokens for LSE. - int lseStrideTokens; + int64_t lseStrideTokens; // The stride between different heads for LSE. - int lseStrideHeads; + int64_t lseStrideHeads; // Attention sink float const* ptrAttentionSinks{nullptr}; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 5757e5eb06..7a0e6f1e1d 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -36,6 +36,16 @@ TRTLLM_GEN_WORKSPACE_CHECK_BYTES = 8192 * 256 * 4 +def get_lse_test_tolerances(q_dtype: str, kv_dtype: str) -> tuple[float, float]: + # TRT-LLM's FP8 prefill/decode LSE is noisier than the high-precision wrapper + # reference even when the output tensors still agree within the existing tolerances. + if kv_dtype == "nvfp4": + return 3e-1, 3e-1 + if q_dtype == "fp8" or kv_dtype == "fp8": + return 2e-2, 2e-2 + return 1e-3, 1e-3 + + def flip_coin(*args, **kwargs): # Use any test parameters to deterministically decide branch # This makes test configurations go through different paths @@ -656,6 +666,11 @@ def _test_trtllm_batch_prefill( softmax_end_bytes + TRTLLM_GEN_WORKSPACE_CHECK_BYTES, workspace_buffer.numel() ) workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].zero_() + provided_lse = torch.empty( + (q_input.size(0), q_input.size(1)), + device=GPU_DEVICE, + dtype=torch.float32, + ) output, lse = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q_input, @@ -681,6 +696,7 @@ def _test_trtllm_batch_prefill( kv_cache_sf=kv_cache_sf_kernel, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=provided_lse, return_lse=True, ) assert ( @@ -735,11 +751,13 @@ def _test_trtllm_batch_prefill( ) expected_lse_shape = (q_input.size(0), q_input.size(1)) + assert lse is provided_lse assert lse.shape == expected_lse_shape assert lse.dtype == torch.float32 assert torch.isfinite(lse).all() if lse_ref is not None: - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + lse_rtol, lse_atol = get_lse_test_tolerances(q_dtype, kv_dtype) + torch.testing.assert_close(lse, lse_ref, rtol=lse_rtol, atol=lse_atol) if ( o_dtype != "nvfp4" and kv_dtype != "nvfp4" and uses_shared_paged_kv_idx @@ -1131,6 +1149,15 @@ def _test_trtllm_batch_decode( q_input = make_query_non_contiguous(q, num_qo_heads, head_dim) else: q_input = q.contiguous() + provided_lse = ( + torch.empty( + (q_input.size(0), q_input.size(1)), + device=GPU_DEVICE, + dtype=torch.float32, + ) + if should_check_lse + else None + ) # Using a tiny threshold should give the same result as normal attention. skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None @@ -1161,10 +1188,15 @@ def _test_trtllm_batch_decode( skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, kv_cache_sf=kv_cache_sf_kernel, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=provided_lse, return_lse=should_check_lse, ) if should_check_lse: output, lse = output + assert lse is provided_lse + assert lse.shape == (q_input.size(0), q_input.size(1)) + assert lse.dtype == torch.float32 + assert torch.isfinite(lse).all() if backend == "trtllm-gen": # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future @@ -1229,7 +1261,8 @@ def _test_trtllm_batch_decode( ) if lse_ref is not None: assert lse is not None, "LSE should be returned when return_lse=True" - torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) + lse_rtol, lse_atol = get_lse_test_tolerances(q_dtype, kv_dtype) + torch.testing.assert_close(lse, lse_ref, rtol=lse_rtol, atol=lse_atol) # Only test wrapper with trtllm-gen backend if ( diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index b3aef3b5bc..74208acc09 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -1,7 +1,9 @@ +import random +import math + import pytest import torch import torch.nn.functional as F -import random import flashinfer from flashinfer.mla import ( @@ -471,11 +473,23 @@ def trtllm_batch_decode_mla( ckv = kv_cache[..., : layer_dimensions.head_dimensions.kv_lora_rank] kpe = kv_cache[..., layer_dimensions.head_dimensions.kv_lora_rank :] + o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) lse_ref = None if should_check_lse: - o_ref, lse_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) - else: - o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) + _, lse_ref = sparse_mla_reference_torch( + cache_seqlens=seq_lens_tensor, + block_table=block_tables, + q=query, + blocked_k=kv_cache, + blocked_v=ckv, + page_size=page_size, + is_causal=True, + sm_scale=sm_scale, + ) + # TRT-LLM returns log2(LSE); the torch helper returns natural-log LSE. + lse_ref = lse_ref.permute(0, 2, 1).contiguous().view( + batch_size * q_len_per_request, layer_dimensions.num_heads + ).to(lse.device) / math.log(2.0) # cute-dsl fp8 kernel outputs fp8; cast to bf16 to match trtllm-gen / reference if backend == "cute-dsl" and output.dtype == torch.float8_e4m3fn: From 903d007039cded36ad43a7a6e299d8bad5863580 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Tue, 14 Apr 2026 22:50:49 +0000 Subject: [PATCH 5/7] consistency --- flashinfer/prefill.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index cd611400ee..add2fb3b25 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -4111,11 +4111,10 @@ def trtllm_batch_context_with_kv_cache( _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() lse_shape = (query.size(0), query.size(1)) - if lse is None: - if return_lse: - lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) - else: + if lse is not None: check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) run_func( out, out_scale_factor, From 3a0256cc146c5891ca34a50d50c782d57e3f424d Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Tue, 14 Apr 2026 23:06:42 +0000 Subject: [PATCH 6/7] fix mla test file mode Clear the executable bit on the TRT-LLM MLA pytest module so Ruff EXE002 stops flagging it as a script without a shebang. Made-with: Cursor --- tests/attention/test_trtllm_gen_mla.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 tests/attention/test_trtllm_gen_mla.py diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py old mode 100755 new mode 100644 From a8528be2d698a597ee61d66ef466542ae8a726f6 Mon Sep 17 00:00:00 2001 From: Matt Murphy Date: Wed, 15 Apr 2026 05:49:26 +0000 Subject: [PATCH 7/7] address comments --- csrc/trtllm_fmha_kernel_launcher.cu | 8 ++++++++ flashinfer/prefill.py | 14 +++++--------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 8403e712ff..ac348b3633 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -197,6 +197,10 @@ void trtllm_paged_attention_launcher( num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); } + // Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and + // ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction + // kernel only participates for special multi-CTA modes; single-CTA context kernels still write + // the final stats directly into this buffer. runner_params.softmaxStatsPtr = float_allocator.aligned_alloc( sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, "trtllm_gen_softmax_workspace"); @@ -601,6 +605,10 @@ void trtllm_ragged_attention_launcher( // semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc( num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); + // Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and + // ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction + // kernel only participates for special multi-CTA modes; single-CTA context kernels still write + // the final stats directly into this buffer. runner_params.softmaxStatsPtr = float_allocator.aligned_alloc( sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, "trtllm_gen_softmax_workspace"); diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index add2fb3b25..b907edb0b0 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -2284,15 +2284,11 @@ def run( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - if return_lse: - if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) - else: - check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" - ) + lse_shape = (q.size(0), q.size(1)) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) # For NVFP4 KV (uint8 packed), v_cache last dim is head_dim//2; # use q's head_dim for output instead