Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 19 additions & 62 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,18 @@ void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, const void* k_block_scales_ptr,
const void* v_block_scales_ptr, int* seq_lens, int* cum_seq_lens_q, int* cum_seq_lens_kv,
float* attention_sinks, float* lse, Data_type q_data_type, Data_type kv_data_type,
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens,
int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads,
int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type,
TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len,
int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, int64_t q_stride_heads,
int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch,
int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax,
bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads,
int64_t v_sf_stride_batch, int64_t lse_stride_tokens, int64_t lse_stride_heads,
cudaStream_t stream) {
int64_t v_sf_stride_batch, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -191,26 +190,11 @@ void trtllm_paged_attention_launcher(
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
size_t num_semaphores =
round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes
// Keep semaphores at the first 8MB of workspace buffer.
// Workspace layout for generation path: counter | softmax | scratch.
// semaphores be at the first 8MB of workspace buffer: counter | scratch
// todo(Yingyi): add softmax buffer later for lse return
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
}

// Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and
// ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction
// kernel only participates for special multi-CTA modes; single-CTA context kernels still write
// the final stats directly into this buffer.
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;

// The scratch allocation uses size=0 to consume the rest of the workspace, so it must be the
// last allocation from float_allocator.
if (mode == TllmPagedAttentionMode::ForGen) {
// scratch takes the rest of the workspace buffer
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
}
Expand Down Expand Up @@ -260,7 +244,7 @@ void trtllm_paged_attention_decode(
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales,
Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor,
Optional<bool> uses_shared_paged_kv_idx, Optional<TensorView> lse) {
Optional<bool> uses_shared_paged_kv_idx) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -368,30 +352,19 @@ void trtllm_paged_attention_decode(
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

float* lse_ptr = nullptr;
int64_t lse_stride_tokens = 0;
int64_t lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(lse.value().ndim() - 1);
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr,
v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
/*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type,
o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
/*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens,
q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads,
stream);
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
}

void trtllm_paged_attention_context(
Expand All @@ -403,8 +376,7 @@ void trtllm_paged_attention_context(
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> key_block_scales, Optional<TensorView> value_block_scales,
Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx,
Optional<TensorView> lse) {
Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand Down Expand Up @@ -502,31 +474,20 @@ void trtllm_paged_attention_context(
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

float* lse_ptr = nullptr;
int64_t lse_stride_tokens = 0;
int64_t lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr,
v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values,
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads,
stream);
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
}

void trtllm_ragged_attention_launcher(
Expand Down Expand Up @@ -605,10 +566,6 @@ void trtllm_ragged_attention_launcher(
// semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
// Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and
// ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction
// kernel only participates for special multi-CTA modes; single-CTA context kernels still write
// the final stats directly into this buffer.
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
Expand Down
73 changes: 27 additions & 46 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,11 +1364,15 @@ def run(
if rope_theta is None:
rope_theta = 1e4

lse_shape = (q.size(0), q.size(1))
if lse is not None:
check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse")
elif return_lse:
lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)
if return_lse:
if lse is None:
lse = torch.empty(
(q.size(0), q.size(1)), dtype=torch.float32, device=q.device
)
else:
check_shape_dtype_device(
lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
)
Comment on lines +1367 to +1375
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -e
sed -n '1218,1525p' flashinfer/decode.py
echo '---'
sed -n '2096,2160p' flashinfer/decode.py

Repository: flashinfer-ai/flashinfer

Length of output: 16188


🏁 Script executed:

# Check for existing guards on lse/return_lse with trtllm-gen
rg -n "trtllm-gen" flashinfer/decode.py | head -20

Repository: flashinfer-ai/flashinfer

Length of output: 1599


🏁 Script executed:

# Check the complete argument assembly and how lse is passed to the backend
sed -n '1350,1450p' flashinfer/decode.py | cat -n

Repository: flashinfer-ai/flashinfer

Length of output: 4692


🏁 Script executed:

# Verify the custom op assert and surrounding context
sed -n '2145,2165p' flashinfer/decode.py

Repository: flashinfer-ai/flashinfer

Length of output: 975


Block lse/return_lse on the TRT-LLM decode wrapper instead of asserting internally.

The public wrapper accepts return_lse=True and explicit lse tensors without checking the backend, but passes them to the custom op's paged_run() which asserts maybe_lse is None. This causes an AssertionError instead of a stable user-facing error, and fails silently under python -O.

♻️ Suggested fix
+        if self._backend == "trtllm-gen" and (return_lse or lse is not None):
+            raise ValueError(
+                "trtllm-gen backend does not support lse/return_lse"
+            )
         if return_lse:
             if lse is None:
                 lse = torch.empty(
                     (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
                 )
             else:
                 check_shape_dtype_device(
                     lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1367 - 1375, The TRT-LLM decode wrapper
must refuse use of return_lse/lse instead of letting the custom op's paged_run
assert; add an explicit check in the wrapper (the block handling return_lse and
lse) that if the backend is TRT-LLM (or when calling paged_run) and (return_lse
is True or lse is not None) raise a clear ValueError with a user-facing message;
reference the existing symbols return_lse, lse, paged_run and maybe_lse so you
locate the code path and replace the silent assertion with this explicit check.


if out is None:
out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
Expand Down Expand Up @@ -1959,13 +1963,21 @@ def run(
out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
)

lse_shape = (q_nope.size(0), q_nope.size(1))
if lse is not None:
check_shape_dtype_device(
lse, lse_shape, torch.float32, q_nope.device, "lse"
)
elif return_lse:
lse = torch.empty(lse_shape, dtype=torch.float32, device=device)
if return_lse:
if lse is None:
lse = torch.empty(
(q_nope.size(0), q_nope.size(1)),
dtype=torch.float32,
device=device,
)
else:
check_shape_dtype_device(
lse,
(q_nope.size(0), q_nope.size(1)),
q_nope.dtype,
q_nope.device,
"lse",
)
Comment on lines +1974 to +1980
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check_shape_dtype_device for lse is using q_nope.dtype as the expected data type. The log-sum-exp tensor (lse) should have a high-precision float type, typically torch.float32, regardless of the query's data type. Using q_nope.dtype could lead to incorrect type checks when q_nope is a lower precision type like float16 or bfloat16. This appears to be a reintroduction of a bug that might have been fixed in the reverted changes.

Suggested change
check_shape_dtype_device(
lse,
(q_nope.size(0), q_nope.size(1)),
q_nope.dtype,
q_nope.device,
"lse",
)
check_shape_dtype_device(
lse,
(q_nope.size(0), q_nope.size(1)),
torch.float32,
q_nope.device,
"lse",
)

Comment on lines +1966 to +1980
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate caller-supplied MLA lse buffers as float32.

This branch allocates lse as torch.float32, but the explicit-buffer path validates against q_nope.dtype. A correctly preallocated float32 lse tensor will fail here for bf16/fp16 inputs.

♻️ Suggested fix
                 check_shape_dtype_device(
                     lse,
                     (q_nope.size(0), q_nope.size(1)),
-                    q_nope.dtype,
+                    torch.float32,
                     q_nope.device,
                     "lse",
                 )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/decode.py` around lines 1966 - 1980, The code allocates lse as
torch.float32 when return_lse is true but validates a caller-supplied lse
against q_nope.dtype (which may be bf16/fp16), causing valid float32 buffers to
fail; update the validation to expect torch.float32 instead of q_nope.dtype by
calling check_shape_dtype_device(lse, (q_nope.size(0), q_nope.size(1)),
torch.float32, q_nope.device, "lse") (ensure you import/qualify torch.float32 if
needed) while keeping the shape and device checks the same.

self._cached_module.run(
self._float_workspace_buffer,
self._int_workspace_buffer,
Expand Down Expand Up @@ -2024,7 +2036,6 @@ def _paged_run(
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
lse: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(query)
Expand Down Expand Up @@ -2070,7 +2081,6 @@ def _paged_run(
value_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,
)
return out

Expand Down Expand Up @@ -2137,6 +2147,7 @@ def paged_run(
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> None:
assert maybe_lse is None
assert paged_kv_cache is not None
assert num_qo_heads is not None
assert num_kv_heads is not None
Expand Down Expand Up @@ -2165,7 +2176,6 @@ def paged_run(
value_block_scales=value_block_scales,
skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
lse=maybe_lse,
)

@register_fake_op(f"flashinfer::{uri}_paged_run")
Expand Down Expand Up @@ -2249,11 +2259,7 @@ def trtllm_batch_decode_with_kv_cache(
skip_softmax_threshold_scale_factor: Optional[float] = None,
kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
uses_shared_paged_kv_idx: bool = True,
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
) -> Union[
torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
]:
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
----------
Expand Down Expand Up @@ -2373,19 +2379,10 @@ def trtllm_batch_decode_with_kv_cache(
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.

lse: Optional[torch.Tensor] = None
The log-sum-exp of attention logits, if not provided, will be allocated internally.
Only supported by trtllm-gen backend.

return_lse: bool = False
Whether to return the logsumexp of attention scores, defaults to ``False``.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
output torch.Tensor or FP4Tensor.
lse: Optional[torch.Tensor]
The log-sum-exp of attention logits, if not provided, will be allocated internally.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

Expand Down Expand Up @@ -2433,11 +2430,6 @@ def trtllm_batch_decode_with_kv_cache(
backend = (
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
)
wants_lse = return_lse or lse is not None
if wants_lse and backend != "trtllm-gen":
raise ValueError(
"lse and return_lse are only supported by the trtllm-gen backend"
)

if backend == "xqa":
# xqa backend doesn't support nvfp4 output
Expand Down Expand Up @@ -2586,12 +2578,6 @@ def trtllm_batch_decode_with_kv_cache(

_check_block_tables_shape(block_tables, uses_shared_paged_kv_idx)

lse_shape = (query.size(0), query.size(1))
if lse is not None:
check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse")
elif return_lse:
lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device)

run_func(
out,
out_scale_factor,
Expand Down Expand Up @@ -2620,18 +2606,13 @@ def trtllm_batch_decode_with_kv_cache(
v_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,
)

out = (
return (
out
if out_dtype != "nvfp4"
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
)
if return_lse:
return out, lse
else:
return out
else:
raise KeyError(f"Backend {backend} not supported")

Expand Down
Loading
Loading