Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 54 additions & 19 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,19 @@ void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, const void* k_block_scales_ptr,
const void* v_block_scales_ptr, int* seq_lens, int* cum_seq_lens_q, int* cum_seq_lens_kv,
float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type,
TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len,
int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, int64_t q_stride_heads,
int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch,
int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
float* attention_sinks, float* lse, Data_type q_data_type, Data_type kv_data_type,
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens,
int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads,
int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax,
bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads,
int64_t v_sf_stride_batch, cudaStream_t stream) {
int64_t v_sf_stride_batch, int64_t lse_stride_tokens, int64_t lse_stride_heads,
cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -190,11 +191,22 @@ void trtllm_paged_attention_launcher(
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
size_t num_semaphores =
round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes
// semaphores be at the first 8MB of workspace buffer: counter | scratch
// todo(Yingyi): add softmax buffer later for lse return
// Keep semaphores at the first 8MB of workspace buffer.
// Workspace layout for generation path: counter | softmax | scratch.
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
// scratch takes the rest of the workspace buffer
}

runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
Comment on lines +204 to +206

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The softmaxStatsPtr is now allocated for both Context and Generation modes. In Context mode, runner_params.mSumOfSeqLensQ (the total number of query tokens) can be very large, which may lead to excessive memory allocation in the workspace (e.g., for 128 heads and 128k tokens, this would require ~128MB). Since Context kernels typically do not require this workspace buffer for multi-block reduction (as mMultiCtasKvMode is false), this allocation could cause workspace exhaustion for long sequences. Consider moving this allocation inside the else block (Generation path) or making it conditional on mode == TllmPagedAttentionMode::ForGen.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that trtllm_ragged_attention_launcher on the main branch already allocates softmaxStatsPtr for the Context kernel path (alongside lsePtr), so this PR is consistent with the existing pattern. That said, do you know why the Context kernel needs softmaxStatsPtr even with mMultiCtasKvMode = false? Is it used as an intermediate buffer for computing LSE internally before writing to lsePtr? If so, it might be worth adding a brief comment here explaining that dependency β€” it's not obvious from the code alone.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qsang-nv we need softmaxStatsPtr generally for LSE materialization, not just for multi-CTA reduction. Whenever lsePtr is set, we will call ComputeLSEFromMD. Added a comment in the code.

runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;

// The scratch allocation uses size=0 to consume the rest of the workspace, so it must be the
// last allocation from float_allocator.
if (mode == TllmPagedAttentionMode::ForGen) {
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
}
Expand Down Expand Up @@ -244,7 +256,7 @@ void trtllm_paged_attention_decode(
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales,
Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor,
Optional<bool> uses_shared_paged_kv_idx) {
Optional<bool> uses_shared_paged_kv_idx, Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -352,19 +364,30 @@ void trtllm_paged_attention_decode(
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(lse.value().ndim() - 1);
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr,
v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
/*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens,
q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
/*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type,
o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads,
stream);
}

void trtllm_paged_attention_context(
Expand All @@ -376,7 +399,8 @@ void trtllm_paged_attention_context(
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> key_block_scales, Optional<TensorView> value_block_scales,
Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx) {
Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx,
Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand Down Expand Up @@ -474,20 +498,31 @@ void trtllm_paged_attention_context(
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

float* lse_ptr = nullptr;
int lse_stride_tokens = 0;
int lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr,
v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values,
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads,
stream);
}

void trtllm_ragged_attention_launcher(
Expand Down
60 changes: 43 additions & 17 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,13 +1365,12 @@ def run(
rope_theta = 1e4

if return_lse:
lse_shape = (q.size(0), q.size(1))
if lse is None:
lse = torch.empty(
(q.size(0), q.size(1)), dtype=torch.float32, device=q.device
)
lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)
else:
check_shape_dtype_device(
lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
lse, lse_shape, torch.float32, q.device, "lse"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

if out is None:
Expand Down Expand Up @@ -1964,19 +1963,12 @@ def run(
)

if return_lse:
lse_shape = (q_nope.size(0), q_nope.size(1))
if lse is None:
lse = torch.empty(
(q_nope.size(0), q_nope.size(1)),
dtype=torch.float32,
device=device,
)
lse = torch.empty(lse_shape, dtype=torch.float32, device=device)
else:
check_shape_dtype_device(
lse,
(q_nope.size(0), q_nope.size(1)),
q_nope.dtype,
q_nope.device,
"lse",
lse, lse_shape, q_nope.dtype, q_nope.device, "lse"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
self._cached_module.run(
self._float_workspace_buffer,
Expand Down Expand Up @@ -2036,6 +2028,7 @@ def _paged_run(
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
lse: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(query)
Expand Down Expand Up @@ -2081,6 +2074,7 @@ def _paged_run(
value_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,
)
return out

Expand Down Expand Up @@ -2147,7 +2141,6 @@ def paged_run(
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> None:
assert maybe_lse is None
assert paged_kv_cache is not None
assert num_qo_heads is not None
assert num_kv_heads is not None
Expand Down Expand Up @@ -2176,6 +2169,7 @@ def paged_run(
value_block_scales=value_block_scales,
skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
lse=maybe_lse,
)

@register_fake_op(f"flashinfer::{uri}_paged_run")
Expand Down Expand Up @@ -2259,7 +2253,11 @@ def trtllm_batch_decode_with_kv_cache(
skip_softmax_threshold_scale_factor: Optional[float] = None,
kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
uses_shared_paged_kv_idx: bool = True,
) -> Union[torch.Tensor, FP4Tensor]:
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
) -> Union[
torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
]:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"""
Parameters
----------
Expand Down Expand Up @@ -2379,10 +2377,19 @@ def trtllm_batch_decode_with_kv_cache(
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.

lse: Optional[torch.Tensor] = None
The log-sum-exp of attention logits, if not provided, will be allocated internally.
Only supported by trtllm-gen backend.

return_lse: bool = False
Whether to return the logsumexp of attention scores, defaults to ``False``.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
output torch.Tensor or FP4Tensor.
lse: Optional[torch.Tensor]
The log-sum-exp of attention logits, if not provided, will be allocated internally.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

Expand Down Expand Up @@ -2430,6 +2437,11 @@ def trtllm_batch_decode_with_kv_cache(
backend = (
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
)
wants_lse = return_lse or lse is not None
if wants_lse and backend != "trtllm-gen":
raise ValueError(
"lse and return_lse are only supported by the trtllm-gen backend"
)

if backend == "xqa":
# xqa backend doesn't support nvfp4 output
Expand Down Expand Up @@ -2578,6 +2590,15 @@ def trtllm_batch_decode_with_kv_cache(

_check_block_tables_shape(block_tables, uses_shared_paged_kv_idx)

if return_lse:
lse_shape = (query.size(0), query.size(1))
if lse is None:
lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device)
else:
check_shape_dtype_device(
lse, lse_shape, torch.float32, query.device, "lse"
)

run_func(
out,
out_scale_factor,
Expand Down Expand Up @@ -2606,13 +2627,18 @@ def trtllm_batch_decode_with_kv_cache(
v_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,
)

return (
out = (
out
if out_dtype != "nvfp4"
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
)
if return_lse:
return out, lse
else:
return out
else:
raise KeyError(f"Backend {backend} not supported")

Expand Down
34 changes: 30 additions & 4 deletions flashinfer/mla/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,11 +559,12 @@ def run(
)

if return_lse:
lse_shape = (q_nope.shape[0] * num_heads, q_nope.shape[2])
if lse is None:
lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device)
lse = torch.empty(lse_shape, dtype=torch.float32, device=device)
else:
check_shape_dtype_device(
lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse"
lse, lse_shape, torch.float32, q_nope.device, "lse"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
profiler_args = (profiler_buffer,) if self._use_profiler else ()
self._cached_module.run(
Expand Down Expand Up @@ -609,7 +610,9 @@ def trtllm_batch_decode_with_kv_cache_mla(
backend: str = "auto",
is_var_seq: bool = True,
uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Comment on lines +612 to +614

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fail fast when LSE is requested on non-trtllm-gen backends.

The new args are documented as TRTLLM-only, but the xqa and cute-dsl branches still ignore them and return a plain tensor. out, lse = ... will therefore either raise or silently unpack batch slices depending on the output shape.

πŸ› οΈ Suggested guard
     if backend == "auto":
         backend = (
             "trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
         )
+    if backend != "trtllm-gen" and (return_lse or lse is not None):
+        raise ValueError(
+            "lse and return_lse are only supported by the trtllm-gen backend"
+        )

Also applies to: 659-663

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/mla/_core.py` around lines 612 - 614, The code adds TRTLLM-only
args lse and return_lse but the xqa and cute-dsl branches still ignore them and
return a plain tensor; update the function handling to fail fast: at the start
of the method (around the signature with parameters lse and return_lse) add a
guard that checks if (return_lse or lse is not None) and the current backend is
not 'trtllm-gen' (or equivalent backend identifier used in this module), and
raise a clear ValueError/RuntimeError explaining that LSE/return_lse are only
supported for trtllm-gen; also add the same guard inside the xqa and cute-dsl
branch handlers (the code paths around the existing xqa and cute-dsl handling
lines) so they explicitly raise instead of returning plain tensors when LSE is
requested.

"""
Parameters
----------
Expand Down Expand Up @@ -654,6 +657,11 @@ def trtllm_batch_decode_with_kv_cache_mla(
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.
False is only supported for trtllm-gen backend.
lse: Optional[torch.Tensor] = None
The log-sum-exp of attention logits, if not provided, will be allocated internally.
Only supported by trtllm-gen backend.
return_lse: bool = False
Whether to return the logsumexp of attention scores, defaults to ``False``.

Note
----
Expand All @@ -678,6 +686,11 @@ def trtllm_batch_decode_with_kv_cache_mla(
backend = (
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
)
wants_lse = return_lse or lse is not None
if wants_lse and backend != "trtllm-gen":
raise ValueError(
"lse and return_lse are only supported by the trtllm-gen backend"
)
if isinstance(bmm1_scale, torch.Tensor):
assert bmm1_scale.dtype == torch.float32
bmm1_scale = bmm1_scale * log2e
Expand Down Expand Up @@ -765,6 +778,15 @@ def trtllm_batch_decode_with_kv_cache_mla(

batch_size = query.size(0)
max_q_len = query.size(1)
num_qo_heads = query.size(2)
if return_lse:
lse_shape = (batch_size * max_q_len, num_qo_heads)
if lse is None:
lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device)
else:
check_shape_dtype_device(
lse, lse_shape, torch.float32, query.device, "lse"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
query = query.flatten(0, 1) # [B*S, H, D]

run_func(
Expand Down Expand Up @@ -795,9 +817,13 @@ def trtllm_batch_decode_with_kv_cache_mla(
None, # value_block_scales
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The lse tensor should be flattened to 2D (num_tokens, num_heads) before being passed to the kernel, as the TRT-LLM launcher expects the first dimension to be the token dimension for stride calculations.

Suggested change
lse,
lse.flatten(0, 1) if lse is not None else None,

)
if return_lse:
return out, lse
else:
return out

return out
elif backend == "cute-dsl":
enable_pdl = (
device_support_pdl(query.device) if enable_pdl is None else enable_pdl
Expand Down
Loading
Loading