-
Notifications
You must be signed in to change notification settings - Fork 1k
Support lse in trtllm paged attn kernels #3058
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
806e5c6
88e3693
a892361
c6f4684
903d007
3a0256c
a8528be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -77,18 +77,19 @@ void trtllm_paged_attention_launcher( | |
| void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache, | ||
| void* workspace_buffer, int* block_tables, const void* k_block_scales_ptr, | ||
| const void* v_block_scales_ptr, int* seq_lens, int* cum_seq_lens_q, int* cum_seq_lens_kv, | ||
| float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type, | ||
| TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len, | ||
| int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, | ||
| int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, int64_t q_stride_heads, | ||
| int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch, | ||
| int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, | ||
| float* attention_sinks, float* lse, Data_type q_data_type, Data_type kv_data_type, | ||
| Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, | ||
| int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, | ||
| int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, | ||
| int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads, | ||
| int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale, | ||
| const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, | ||
| int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, | ||
| int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax, | ||
| bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size, | ||
| int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads, | ||
| int64_t v_sf_stride_batch, cudaStream_t stream) { | ||
| int64_t v_sf_stride_batch, int64_t lse_stride_tokens, int64_t lse_stride_heads, | ||
| cudaStream_t stream) { | ||
| if (num_qo_heads % num_kv_heads != 0) { | ||
| std::ostringstream err_msg; | ||
| err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads | ||
|
|
@@ -190,11 +191,26 @@ void trtllm_paged_attention_launcher( | |
| size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB | ||
| size_t num_semaphores = | ||
| round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes | ||
| // semaphores be at the first 8MB of workspace buffer: counter | scratch | ||
| // todo(Yingyi): add softmax buffer later for lse return | ||
| // Keep semaphores at the first 8MB of workspace buffer. | ||
| // Workspace layout for generation path: counter | softmax | scratch. | ||
| runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>( | ||
| num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); | ||
| // scratch takes the rest of the workspace buffer | ||
| } | ||
|
|
||
| // Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and | ||
| // ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction | ||
| // kernel only participates for special multi-CTA modes; single-CTA context kernels still write | ||
| // the final stats directly into this buffer. | ||
| runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>( | ||
| sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, | ||
| "trtllm_gen_softmax_workspace"); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I noticed that trtllm_ragged_attention_launcher on the main branch already allocates softmaxStatsPtr for the Context kernel path (alongside lsePtr), so this PR is consistent with the existing pattern. That said, do you know why the Context kernel needs softmaxStatsPtr even with mMultiCtasKvMode = false? Is it used as an intermediate buffer for computing LSE internally before writing to lsePtr? If so, it might be worth adding a brief comment here explaining that dependency β it's not obvious from the code alone.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @qsang-nv we need |
||
| runner_params.lsePtr = lse; | ||
| runner_params.lseStrideTokens = lse_stride_tokens; | ||
| runner_params.lseStrideHeads = lse_stride_heads; | ||
|
|
||
| // The scratch allocation uses size=0 to consume the rest of the workspace, so it must be the | ||
| // last allocation from float_allocator. | ||
| if (mode == TllmPagedAttentionMode::ForGen) { | ||
| runner_params.multiCtasKvScratchPtr = | ||
| float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace"); | ||
| } | ||
|
|
@@ -244,7 +260,7 @@ void trtllm_paged_attention_decode( | |
| int64_t workspace_size, Optional<TensorView> attention_sinks, | ||
| Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales, | ||
| Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor, | ||
| Optional<bool> uses_shared_paged_kv_idx) { | ||
| Optional<bool> uses_shared_paged_kv_idx, Optional<TensorView> lse) { | ||
| auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); | ||
| auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); | ||
| TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); | ||
|
|
@@ -352,19 +368,30 @@ void trtllm_paged_attention_decode( | |
| skip_softmax_threshold_scale_factor.value_or(0.0f); | ||
| bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; | ||
|
|
||
| float* lse_ptr = nullptr; | ||
| int64_t lse_stride_tokens = 0; | ||
| int64_t lse_stride_heads = 0; | ||
| if (lse.has_value()) { | ||
| TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; | ||
| lse_ptr = static_cast<float*>(lse.value().data_ptr()); | ||
| lse_stride_tokens = lse.value().stride(0); | ||
| lse_stride_heads = lse.value().stride(lse.value().ndim() - 1); | ||
| } | ||
|
|
||
| trtllm_paged_attention_launcher( | ||
| out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), | ||
| workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr, | ||
| v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr, | ||
| /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, | ||
| TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, | ||
| num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens, | ||
| q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, | ||
| /*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, | ||
| o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, | ||
| num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, | ||
| q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, | ||
| max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, | ||
| bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, | ||
| sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax, | ||
| uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, | ||
| k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); | ||
| k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads, | ||
| stream); | ||
| } | ||
|
|
||
| void trtllm_paged_attention_context( | ||
|
|
@@ -376,7 +403,8 @@ void trtllm_paged_attention_context( | |
| int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, | ||
| bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks, | ||
| Optional<TensorView> key_block_scales, Optional<TensorView> value_block_scales, | ||
| Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx) { | ||
| Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx, | ||
| Optional<TensorView> lse) { | ||
| auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); | ||
| auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); | ||
| auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); | ||
|
|
@@ -474,20 +502,31 @@ void trtllm_paged_attention_context( | |
| skip_softmax_threshold_scale_factor.value_or(0.0f); | ||
| bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f; | ||
|
|
||
| float* lse_ptr = nullptr; | ||
| int64_t lse_stride_tokens = 0; | ||
| int64_t lse_stride_heads = 0; | ||
| if (lse.has_value()) { | ||
| TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; | ||
| lse_ptr = static_cast<float*>(lse.value().data_ptr()); | ||
| lse_stride_tokens = lse.value().stride(0); | ||
| lse_stride_heads = lse.value().stride(1); | ||
| } | ||
|
|
||
| trtllm_paged_attention_launcher( | ||
| out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), | ||
| workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr, | ||
| v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()), | ||
| /*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()), | ||
| /*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, | ||
| q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, | ||
| lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, | ||
| max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, | ||
| head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values, | ||
| kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, | ||
| bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, | ||
| sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax, | ||
| uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads, | ||
| k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream); | ||
| k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads, | ||
| stream); | ||
| } | ||
|
|
||
| void trtllm_ragged_attention_launcher( | ||
|
|
@@ -566,6 +605,10 @@ void trtllm_ragged_attention_launcher( | |
| // semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch | ||
| runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>( | ||
| num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace"); | ||
| // Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and | ||
| // ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction | ||
| // kernel only participates for special multi-CTA modes; single-CTA context kernels still write | ||
| // the final stats directly into this buffer. | ||
| runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>( | ||
| sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16, | ||
| "trtllm_gen_softmax_workspace"); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1364,15 +1364,11 @@ def run( | |
| if rope_theta is None: | ||
| rope_theta = 1e4 | ||
|
|
||
| if return_lse: | ||
| if lse is None: | ||
| lse = torch.empty( | ||
| (q.size(0), q.size(1)), dtype=torch.float32, device=q.device | ||
| ) | ||
| else: | ||
| check_shape_dtype_device( | ||
| lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" | ||
| ) | ||
| lse_shape = (q.size(0), q.size(1)) | ||
| if lse is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please update |
||
| check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse") | ||
| elif return_lse: | ||
| lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device) | ||
|
|
||
| if out is None: | ||
| out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype | ||
|
|
@@ -1963,21 +1959,13 @@ def run( | |
| out, q_nope.shape, q_nope.dtype, q_nope.device, "out" | ||
| ) | ||
|
|
||
| if return_lse: | ||
| if lse is None: | ||
| lse = torch.empty( | ||
| (q_nope.size(0), q_nope.size(1)), | ||
| dtype=torch.float32, | ||
| device=device, | ||
| ) | ||
| else: | ||
| check_shape_dtype_device( | ||
| lse, | ||
| (q_nope.size(0), q_nope.size(1)), | ||
| q_nope.dtype, | ||
| q_nope.device, | ||
| "lse", | ||
| ) | ||
| lse_shape = (q_nope.size(0), q_nope.size(1)) | ||
| if lse is not None: | ||
| check_shape_dtype_device( | ||
| lse, lse_shape, torch.float32, q_nope.device, "lse" | ||
| ) | ||
| elif return_lse: | ||
| lse = torch.empty(lse_shape, dtype=torch.float32, device=device) | ||
| self._cached_module.run( | ||
| self._float_workspace_buffer, | ||
| self._int_workspace_buffer, | ||
|
|
@@ -2036,6 +2024,7 @@ def _paged_run( | |
| value_block_scales: Optional[torch.Tensor] = None, | ||
| skip_softmax_threshold_scale_factor: Optional[float] = None, | ||
| uses_shared_paged_kv_idx: bool = True, | ||
| lse: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| if out is None: | ||
| out = torch.empty_like(query) | ||
|
|
@@ -2081,6 +2070,7 @@ def _paged_run( | |
| value_block_scales, | ||
| skip_softmax_threshold_scale_factor, | ||
| uses_shared_paged_kv_idx, | ||
| lse, | ||
| ) | ||
| return out | ||
|
|
||
|
|
@@ -2147,7 +2137,6 @@ def paged_run( | |
| skip_softmax_threshold_scale_factor: Optional[float] = None, | ||
| uses_shared_paged_kv_idx: bool = True, | ||
| ) -> None: | ||
| assert maybe_lse is None | ||
| assert paged_kv_cache is not None | ||
| assert num_qo_heads is not None | ||
| assert num_kv_heads is not None | ||
|
|
@@ -2176,6 +2165,7 @@ def paged_run( | |
| value_block_scales=value_block_scales, | ||
| skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, | ||
| uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, | ||
| lse=maybe_lse, | ||
| ) | ||
|
|
||
| @register_fake_op(f"flashinfer::{uri}_paged_run") | ||
|
|
@@ -2259,7 +2249,11 @@ def trtllm_batch_decode_with_kv_cache( | |
| skip_softmax_threshold_scale_factor: Optional[float] = None, | ||
| kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, | ||
| uses_shared_paged_kv_idx: bool = True, | ||
| ) -> Union[torch.Tensor, FP4Tensor]: | ||
| lse: Optional[torch.Tensor] = None, | ||
| return_lse: bool = False, | ||
| ) -> Union[ | ||
| torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor] | ||
| ]: | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| """ | ||
| Parameters | ||
| ---------- | ||
|
|
@@ -2379,10 +2373,19 @@ def trtllm_batch_decode_with_kv_cache( | |
| True (default) uses vLLM/FlashInfer layout with a 2D page table. | ||
| False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. | ||
|
|
||
| lse: Optional[torch.Tensor] = None | ||
| The log-sum-exp of attention logits, if not provided, will be allocated internally. | ||
| Only supported by trtllm-gen backend. | ||
|
|
||
| return_lse: bool = False | ||
| Whether to return the logsumexp of attention scores, defaults to ``False``. | ||
|
|
||
| Returns | ||
| ------- | ||
| out : Union[torch.Tensor, FP4Tensor] | ||
| output torch.Tensor or FP4Tensor. | ||
| lse: Optional[torch.Tensor] | ||
| The log-sum-exp of attention logits, if not provided, will be allocated internally. | ||
| """ | ||
| enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl | ||
|
|
||
|
|
@@ -2430,6 +2433,11 @@ def trtllm_batch_decode_with_kv_cache( | |
| backend = ( | ||
| "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa" | ||
| ) | ||
| wants_lse = return_lse or lse is not None | ||
| if wants_lse and backend != "trtllm-gen": | ||
| raise ValueError( | ||
| "lse and return_lse are only supported by the trtllm-gen backend" | ||
| ) | ||
|
|
||
| if backend == "xqa": | ||
| # xqa backend doesn't support nvfp4 output | ||
|
|
@@ -2578,6 +2586,12 @@ def trtllm_batch_decode_with_kv_cache( | |
|
|
||
| _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) | ||
|
|
||
| lse_shape = (query.size(0), query.size(1)) | ||
| if lse is not None: | ||
| check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse") | ||
| elif return_lse: | ||
| lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device) | ||
|
|
||
| run_func( | ||
| out, | ||
| out_scale_factor, | ||
|
|
@@ -2606,13 +2620,18 @@ def trtllm_batch_decode_with_kv_cache( | |
| v_block_scales, | ||
| skip_softmax_threshold_scale_factor, | ||
| uses_shared_paged_kv_idx, | ||
| lse, | ||
| ) | ||
|
|
||
| return ( | ||
| out = ( | ||
| out | ||
| if out_dtype != "nvfp4" | ||
| else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape) | ||
| ) | ||
| if return_lse: | ||
| return out, lse | ||
| else: | ||
| return out | ||
| else: | ||
| raise KeyError(f"Backend {backend} not supported") | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
softmaxStatsPtris now allocated for bothContextandGenerationmodes. InContextmode,runner_params.mSumOfSeqLensQ(the total number of query tokens) can be very large, which may lead to excessive memory allocation in the workspace (e.g., for 128 heads and 128k tokens, this would require ~128MB). SinceContextkernels typically do not require this workspace buffer for multi-block reduction (asmMultiCtasKvModeisfalse), this allocation could cause workspace exhaustion for long sequences. Consider moving this allocation inside theelseblock (Generation path) or making it conditional onmode == TllmPagedAttentionMode::ForGen.