diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index a0997ddf2d..ac348b3633 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -77,18 +77,19 @@ void trtllm_paged_attention_launcher( void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache, void* workspace_buffer, int* block_tables, const void* k_block_scales_ptr, const void* v_block_scales_ptr, int* seq_lens, int* cum_seq_lens_q, int* cum_seq_lens_kv, - float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, - TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len, - int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, - int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, int64_t q_stride_heads, - int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch, - int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, + float* attention_sinks, float* lse, Data_type q_data_type, Data_type kv_data_type, + Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, + int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, + int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads, + int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax, bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size, int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads, - int64_t v_sf_stride_batch, cudaStream_t stream) { + int64_t v_sf_stride_batch, int64_t lse_stride_tokens, int64_t lse_stride_heads, + cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -190,11 +191,26 @@ void trtllm_paged_attention_launcher( size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB size_t num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes - // semaphores be at the first 8MB of workspace buffer: counter | scratch - // todo(Yingyi): add softmax buffer later for lse return + // Keep semaphores at the first 8MB of workspace buffer. + // Workspace layout for generation path: counter | softmax | scratch. runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc( num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); - // scratch takes the rest of the workspace buffer + } + + // Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and + // ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction + // kernel only participates for special multi-CTA modes; single-CTA context kernels still write + // the final stats directly into this buffer. + runner_params.softmaxStatsPtr = float_allocator.aligned_alloc( + sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, + "trtllm_gen_softmax_workspace"); + runner_params.lsePtr = lse; + runner_params.lseStrideTokens = lse_stride_tokens; + runner_params.lseStrideHeads = lse_stride_heads; + + // The scratch allocation uses size=0 to consume the rest of the workspace, so it must be the + // last allocation from float_allocator. + if (mode == TllmPagedAttentionMode::ForGen) { runner_params.multiCtasKvScratchPtr = float_allocator.aligned_alloc(0, 16, "trtllm_gen_scratch_workspace"); } @@ -244,7 +260,7 @@ void trtllm_paged_attention_decode( int64_t workspace_size, Optional attention_sinks, Optional cum_seq_lens_q, Optional key_block_scales, Optional value_block_scales, Optional skip_softmax_threshold_scale_factor, - Optional uses_shared_paged_kv_idx) { + Optional uses_shared_paged_kv_idx, Optional lse) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -352,19 +368,30 @@ void trtllm_paged_attention_decode( skip_softmax_threshold_scale_factor.value_or(0.0f); bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; + float* lse_ptr = nullptr; + int64_t lse_stride_tokens = 0; + int64_t lse_stride_heads = 0; + if (lse.has_value()) { + TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; + lse_ptr = static_cast(lse.value().data_ptr()); + lse_stride_tokens = lse.value().stride(0); + lse_stride_heads = lse.value().stride(lse.value().ndim() - 1); + } + trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), k_block_scales_ptr, v_block_scales_ptr, static_cast(seq_lens.data_ptr()), cum_seq_lens_q_ptr, - /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, - TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, - num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens, - q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, + /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, + o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, + num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, + q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax, uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, - k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); + k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads, + stream); } void trtllm_paged_attention_context( @@ -376,7 +403,8 @@ void trtllm_paged_attention_context( int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, int64_t workspace_size, Optional attention_sinks, Optional key_block_scales, Optional value_block_scales, - Optional skip_softmax_threshold_scale_factor, Optional uses_shared_paged_kv_idx) { + Optional skip_softmax_threshold_scale_factor, Optional uses_shared_paged_kv_idx, + Optional lse) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); @@ -474,20 +502,31 @@ void trtllm_paged_attention_context( skip_softmax_threshold_scale_factor.value_or(0.0f); bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; + float* lse_ptr = nullptr; + int64_t lse_stride_tokens = 0; + int64_t lse_stride_heads = 0; + if (lse.has_value()) { + TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; + lse_ptr = static_cast(lse.value().data_ptr()); + lse_stride_tokens = lse.value().stride(0); + lse_stride_heads = lse.value().stride(1); + } + trtllm_paged_attention_launcher( out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), k_block_scales_ptr, v_block_scales_ptr, static_cast(seq_lens.data_ptr()), /*cum_seq_lens_q=*/static_cast(cum_seq_lens_q.data_ptr()), /*cum_seq_lens_kv=*/static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, - q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, + lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax, uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, - k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); + k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads, + stream); } void trtllm_ragged_attention_launcher( @@ -566,6 +605,10 @@ void trtllm_ragged_attention_launcher( // semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc( num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); + // Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and + // ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction + // kernel only participates for special multi-CTA modes; single-CTA context kernels still write + // the final stats directly into this buffer. runner_params.softmaxStatsPtr = float_allocator.aligned_alloc( sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, "trtllm_gen_softmax_workspace"); diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 822aca407c..70dc62f1b7 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -1364,15 +1364,11 @@ def run( if rope_theta is None: rope_theta = 1e4 - if return_lse: - if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) - else: - check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" - ) + lse_shape = (q.size(0), q.size(1)) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) if out is None: out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype @@ -1963,21 +1959,13 @@ def run( out, q_nope.shape, q_nope.dtype, q_nope.device, "out" ) - if return_lse: - if lse is None: - lse = torch.empty( - (q_nope.size(0), q_nope.size(1)), - dtype=torch.float32, - device=device, - ) - else: - check_shape_dtype_device( - lse, - (q_nope.size(0), q_nope.size(1)), - q_nope.dtype, - q_nope.device, - "lse", - ) + lse_shape = (q_nope.size(0), q_nope.size(1)) + if lse is not None: + check_shape_dtype_device( + lse, lse_shape, torch.float32, q_nope.device, "lse" + ) + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) self._cached_module.run( self._float_workspace_buffer, self._int_workspace_buffer, @@ -2036,6 +2024,7 @@ def _paged_run( value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, + lse: Optional[torch.Tensor] = None, ) -> torch.Tensor: if out is None: out = torch.empty_like(query) @@ -2081,6 +2070,7 @@ def _paged_run( value_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) return out @@ -2147,7 +2137,6 @@ def paged_run( skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, ) -> None: - assert maybe_lse is None assert paged_kv_cache is not None assert num_qo_heads is not None assert num_kv_heads is not None @@ -2176,6 +2165,7 @@ def paged_run( value_block_scales=value_block_scales, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=maybe_lse, ) @register_fake_op(f"flashinfer::{uri}_paged_run") @@ -2259,7 +2249,11 @@ def trtllm_batch_decode_with_kv_cache( skip_softmax_threshold_scale_factor: Optional[float] = None, kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, uses_shared_paged_kv_idx: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: """ Parameters ---------- @@ -2379,10 +2373,19 @@ def trtllm_batch_decode_with_kv_cache( True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. + lse: Optional[torch.Tensor] = None + The log-sum-exp of attention logits, if not provided, will be allocated internally. + Only supported by trtllm-gen backend. + + return_lse: bool = False + Whether to return the logsumexp of attention scores, defaults to ``False``. + Returns ------- out : Union[torch.Tensor, FP4Tensor] output torch.Tensor or FP4Tensor. + lse: Optional[torch.Tensor] + The log-sum-exp of attention logits, if not provided, will be allocated internally. """ enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl @@ -2430,6 +2433,11 @@ def trtllm_batch_decode_with_kv_cache( backend = ( "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" ) + wants_lse = return_lse or lse is not None + if wants_lse and backend != "trtllm-gen": + raise ValueError( + "lse and return_lse are only supported by the trtllm-gen backend" + ) if backend == "xqa": # xqa backend doesn't support nvfp4 output @@ -2578,6 +2586,12 @@ def trtllm_batch_decode_with_kv_cache( _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) + lse_shape = (query.size(0), query.size(1)) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) + run_func( out, out_scale_factor, @@ -2606,13 +2620,18 @@ def trtllm_batch_decode_with_kv_cache( v_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) - return ( + out = ( out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) ) + if return_lse: + return out, lse + else: + return out else: raise KeyError(f"Backend {backend} not supported") diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 4e8bdd7212..b471862c7c 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -558,13 +558,13 @@ def run( out, q_nope.shape, q_nope.dtype, q_nope.device, "out" ) - if return_lse: - if lse is None: - lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device) - else: - check_shape_dtype_device( - lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse" - ) + lse_shape = (q_nope.shape[0], num_heads) + if lse is not None: + check_shape_dtype_device( + lse, lse_shape, torch.float32, q_nope.device, "lse" + ) + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=device) profiler_args = (profiler_buffer,) if self._use_profiler else () self._cached_module.run( self._float_workspace_buffer, @@ -609,7 +609,9 @@ def trtllm_batch_decode_with_kv_cache_mla( backend: str = "auto", is_var_seq: bool = True, uses_shared_paged_kv_idx: bool = True, -) -> torch.Tensor: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Parameters ---------- @@ -654,6 +656,11 @@ def trtllm_batch_decode_with_kv_cache_mla( True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. False is only supported for trtllm-gen backend. + lse: Optional[torch.Tensor] = None + The log-sum-exp of attention logits, if not provided, will be allocated internally. + Only supported by trtllm-gen backend. + return_lse: bool = False + Whether to return the logsumexp of attention scores, defaults to ``False``. Note ---- @@ -678,6 +685,11 @@ def trtllm_batch_decode_with_kv_cache_mla( backend = ( "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" ) + wants_lse = return_lse or lse is not None + if wants_lse and backend != "trtllm-gen": + raise ValueError( + "lse and return_lse are only supported by the trtllm-gen backend" + ) if isinstance(bmm1_scale, torch.Tensor): assert bmm1_scale.dtype == torch.float32 bmm1_scale = bmm1_scale * log2e @@ -765,6 +777,12 @@ def trtllm_batch_decode_with_kv_cache_mla( batch_size = query.size(0) max_q_len = query.size(1) + num_qo_heads = query.size(2) + lse_shape = (batch_size * max_q_len, num_qo_heads) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) query = query.flatten(0, 1) # [B*S, H, D] run_func( @@ -795,9 +813,13 @@ def trtllm_batch_decode_with_kv_cache_mla( None, # value_block_scales skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) + if return_lse: + return out, lse + else: + return out - return out elif backend == "cute-dsl": enable_pdl = ( device_support_pdl(query.device) if enable_pdl is None else enable_pdl diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index e64d4a73b6..b907edb0b0 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -269,6 +269,7 @@ def _paged_run( value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, + lse: Optional[torch.Tensor] = None, ) -> torch.Tensor: sm_count = get_device_sm_count(query.device) if out is None: @@ -306,6 +307,7 @@ def _paged_run( value_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) return out @@ -679,7 +681,6 @@ def paged_run( uses_shared_paged_kv_idx: bool = True, ) -> None: if backend == "trtllm-gen": - assert maybe_lse is None assert num_qo_heads is not None assert num_kv_heads is not None assert block_tables is not None @@ -714,6 +715,7 @@ def paged_run( value_block_scales=value_block_scales, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=maybe_lse, ) elif backend == "fa2": assert not is_float8(q) @@ -2282,15 +2284,11 @@ def run( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - if return_lse: - if lse is None: - lse = torch.empty( - (q.size(0), q.size(1)), dtype=torch.float32, device=q.device - ) - else: - check_shape_dtype_device( - lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" - ) + lse_shape = (q.size(0), q.size(1)) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) # For NVFP4 KV (uint8 packed), v_cache last dim is head_dim//2; # use q's head_dim for output instead @@ -3864,7 +3862,11 @@ def trtllm_batch_context_with_kv_cache( kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, uses_shared_paged_kv_idx: bool = True, -) -> Union[torch.Tensor, FP4Tensor]: + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, +) -> Union[ + torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] +]: """ Parameters ---------- @@ -3961,10 +3963,17 @@ def trtllm_batch_context_with_kv_cache( Whether the K and V page indices are shared as a unified index. True (default) uses vLLM/FlashInfer layout with a 2D page table. False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. + lse : Optional[torch.Tensor] = None + The log-sum-exp of attention logits, if not provided, will be allocated internally. + Only supported by trtllm-gen backend. + return_lse : bool = False + Whether to return the logsumexp of attention scores, defaults to ``False``. Returns ------- out: Union[torch.Tensor, FP4Tensor] output torch.Tensor or FP4Tensor. + lse: Optional[torch.Tensor] + The log-sum-exp of attention logits, if not provided, will be allocated internally. """ if enable_pdl is None: @@ -4097,6 +4106,11 @@ def trtllm_batch_context_with_kv_cache( assert bmm2_scale.dtype == torch.float32 _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() + lse_shape = (query.size(0), query.size(1)) + if lse is not None: + check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") + elif return_lse: + lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) run_func( out, out_scale_factor, @@ -4125,12 +4139,17 @@ def trtllm_batch_context_with_kv_cache( value_block_scales, skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx, + lse, ) - return ( + out = ( out if out_dtype != "nvfp4" else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) ) + if return_lse: + return out, lse + else: + return out @functools.cache diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index 7945b1fab2..cee2b39de6 100644 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -224,6 +224,10 @@ struct TllmGenFmhaRunnerParams { float2* softmaxStatsPtr; // The LSE buffer. float* lsePtr; + // The stride between different tokens for LSE. + int64_t lseStrideTokens; + // The stride between different heads for LSE. + int64_t lseStrideHeads; // Attention sink float const* ptrAttentionSinks{nullptr}; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index 186f1e5072..7a0e6f1e1d 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -33,6 +33,17 @@ global_workspace_buffer = None # can.be empty initialized global_trtllm_gen_fmha_workspace_buffer = None # must be zero initialized workspace_size = 256 * 1024 * 1024 +TRTLLM_GEN_WORKSPACE_CHECK_BYTES = 8192 * 256 * 4 + + +def get_lse_test_tolerances(q_dtype: str, kv_dtype: str) -> tuple[float, float]: + # TRT-LLM's FP8 prefill/decode LSE is noisier than the high-precision wrapper + # reference even when the output tensors still agree within the existing tolerances. + if kv_dtype == "nvfp4": + return 3e-1, 3e-1 + if q_dtype == "fp8" or kv_dtype == "fp8": + return 2e-2, 2e-2 + return 1e-3, 1e-3 def flip_coin(*args, **kwargs): @@ -481,6 +492,16 @@ def generate_causal_mask( return mask_uint16 +def trtllm_gen_workspace_softmax_end_bytes_context( + workspace_buffer: torch.Tensor, *, num_qo_heads: int, sum_seq_q: int +) -> int: + """Context/prefill uses softmax stats as the first workspace slab.""" + softmax_stats_nbytes = 8 * num_qo_heads * sum_seq_q + base_addr = workspace_buffer.data_ptr() + aligned_addr = (base_addr + 15) // 16 * 16 + return (aligned_addr - base_addr) + softmax_stats_nbytes + + def _test_trtllm_batch_prefill( kv_layout: str, batch_size: int, @@ -566,6 +587,7 @@ def _test_trtllm_batch_prefill( ) sm_scale = float(1.0 / (head_dim**0.5)) + lse_ref = None # Build reference output plan_params = { @@ -589,7 +611,7 @@ def _test_trtllm_batch_prefill( workspace_buffer_ref, kv_layout ) wrapper_ref.plan(**plan_params) - output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + output_ref, lse_ref = wrapper_ref.run(ref_q, ref_kv_cache, return_lse=True) else: # Construct flat K/V via helper k_flat, v_flat, kv_indptr_tokens = flatten_paged_kv( @@ -635,8 +657,22 @@ def _test_trtllm_batch_prefill( # Using a tiny threshold should give the same result as normal attention. skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None + softmax_end_bytes = trtllm_gen_workspace_softmax_end_bytes_context( + workspace_buffer, + num_qo_heads=q_input.size(1), + sum_seq_q=q_input.size(0), + ) + workspace_check_end_bytes = min( + softmax_end_bytes + TRTLLM_GEN_WORKSPACE_CHECK_BYTES, workspace_buffer.numel() + ) + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].zero_() + provided_lse = torch.empty( + (q_input.size(0), q_input.size(1)), + device=GPU_DEVICE, + dtype=torch.float32, + ) - output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( + output, lse = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q_input, kv_cache_kernel, workspace_buffer, @@ -660,10 +696,12 @@ def _test_trtllm_batch_prefill( kv_cache_sf=kv_cache_sf_kernel, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=provided_lse, + return_lse=True, ) - # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero - # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future - assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + assert ( + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].cpu().numpy() == 0 + ).all() if o_dtype == "nvfp4": output, output_ref = unpack_compare_nvfp4( @@ -712,6 +750,15 @@ def _test_trtllm_batch_prefill( f"Block scale factors may be mismatched to FP4 data blocks." ) + expected_lse_shape = (q_input.size(0), q_input.size(1)) + assert lse is provided_lse + assert lse.shape == expected_lse_shape + assert lse.dtype == torch.float32 + assert torch.isfinite(lse).all() + if lse_ref is not None: + lse_rtol, lse_atol = get_lse_test_tolerances(q_dtype, kv_dtype) + torch.testing.assert_close(lse, lse_ref, rtol=lse_rtol, atol=lse_atol) + if ( o_dtype != "nvfp4" and kv_dtype != "nvfp4" and uses_shared_paged_kv_idx ): # wrapper api does not support fp4 output/kv or separate KV page indices yet. @@ -723,6 +770,7 @@ def _test_trtllm_batch_prefill( plan_params["kv_data_type"] = kv_cache.dtype plan_params["o_data_type"] = DTYPE_MAP[o_dtype] wrapper_trtllm_gen.plan(**plan_params) + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].zero_() output_wrapper = wrapper_trtllm_gen.run( q_input, kv_cache, @@ -739,9 +787,10 @@ def _test_trtllm_batch_prefill( torch.testing.assert_close( output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1 ) - # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero - # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future - assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() + assert ( + workspace_buffer[softmax_end_bytes:workspace_check_end_bytes].cpu().numpy() + == 0 + ).all() @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @@ -998,6 +1047,8 @@ def _test_trtllm_batch_decode( ) sm_scale = float(1.0 / (head_dim**0.5)) + should_check_lse = backend == "trtllm-gen" + lse_ref = None # Build reference output plan_params = { @@ -1019,7 +1070,12 @@ def _test_trtllm_batch_decode( workspace_buffer_ref, kv_layout, use_tensor_cores=True ) wrapper_ref.plan(**plan_params) - output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + if should_check_lse: + output_ref, lse_ref = wrapper_ref.run( + ref_q, ref_kv_cache, return_lse=True + ) + else: + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) else: # speculative decoding test @@ -1039,7 +1095,12 @@ def _test_trtllm_batch_decode( } ) wrapper_ref.plan(**plan_params_prefill) - output_ref = wrapper_ref.run(ref_q, ref_kv_cache) + if should_check_lse: + output_ref, lse_ref = wrapper_ref.run( + ref_q, ref_kv_cache, return_lse=True + ) + else: + output_ref = wrapper_ref.run(ref_q, ref_kv_cache) else: # Construct flat K/V via helper k_flat, v_flat, kv_indptr_tokens = flatten_paged_kv( @@ -1088,6 +1149,15 @@ def _test_trtllm_batch_decode( q_input = make_query_non_contiguous(q, num_qo_heads, head_dim) else: q_input = q.contiguous() + provided_lse = ( + torch.empty( + (q_input.size(0), q_input.size(1)), + device=GPU_DEVICE, + dtype=torch.float32, + ) + if should_check_lse + else None + ) # Using a tiny threshold should give the same result as normal attention. skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None @@ -1118,7 +1188,15 @@ def _test_trtllm_batch_decode( skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, kv_cache_sf=kv_cache_sf_kernel, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + lse=provided_lse, + return_lse=should_check_lse, ) + if should_check_lse: + output, lse = output + assert lse is provided_lse + assert lse.shape == (q_input.size(0), q_input.size(1)) + assert lse.dtype == torch.float32 + assert torch.isfinite(lse).all() if backend == "trtllm-gen": # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future @@ -1181,6 +1259,10 @@ def _test_trtllm_batch_decode( f"NVFP4 KV cache attention: cosine similarity {cos:.4f} < 0.86. " f"Block scale factors may be mismatched to FP4 data blocks." ) + if lse_ref is not None: + assert lse is not None, "LSE should be returned when return_lse=True" + lse_rtol, lse_atol = get_lse_test_tolerances(q_dtype, kv_dtype) + torch.testing.assert_close(lse, lse_ref, rtol=lse_rtol, atol=lse_atol) # Only test wrapper with trtllm-gen backend if ( diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py old mode 100755 new mode 100644 index c1cf3d8a50..74208acc09 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -1,7 +1,9 @@ +import random +import math + import pytest import torch import torch.nn.functional as F -import random import flashinfer from flashinfer.mla import ( @@ -389,6 +391,7 @@ def trtllm_batch_decode_mla( # Using a tiny threshold should give the same output as standard attention skip_softmax_threshold_scale_factor = 1e-30 if skips_softmax else None + should_check_lse = backend == "trtllm-gen" # Run decode-MLA output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( @@ -407,9 +410,15 @@ def trtllm_batch_decode_mla( enable_pdl=enable_pdl, backend=backend, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + return_lse=should_check_lse, ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future + if should_check_lse: + output, lse = output + assert lse.shape == (batch_size * q_len_per_request, layer_dimensions.num_heads) + assert lse.dtype == torch.float32 + assert torch.isfinite(lse).all() if backend == "trtllm-gen": assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all() @@ -465,6 +474,22 @@ def trtllm_batch_decode_mla( kpe = kv_cache[..., layer_dimensions.head_dimensions.kv_lora_rank :] o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False) + lse_ref = None + if should_check_lse: + _, lse_ref = sparse_mla_reference_torch( + cache_seqlens=seq_lens_tensor, + block_table=block_tables, + q=query, + blocked_k=kv_cache, + blocked_v=ckv, + page_size=page_size, + is_causal=True, + sm_scale=sm_scale, + ) + # TRT-LLM returns log2(LSE); the torch helper returns natural-log LSE. + lse_ref = lse_ref.permute(0, 2, 1).contiguous().view( + batch_size * q_len_per_request, layer_dimensions.num_heads + ).to(lse.device) / math.log(2.0) # cute-dsl fp8 kernel outputs fp8; cast to bf16 to match trtllm-gen / reference if backend == "cute-dsl" and output.dtype == torch.float8_e4m3fn: @@ -486,6 +511,8 @@ def trtllm_batch_decode_mla( try: torch.testing.assert_close(output, o_ref_view, rtol=rtol, atol=atol) + if lse_ref is not None: + torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) except AssertionError as fa2_err: if backend == "cute-dsl": # fa2 reference may diverge from cute-dsl in some configs;