Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,19 @@ void trtllm_paged_attention_launcher(
void* out, void* out_scale_factor, void* query, void* key_cache, void* value_cache,
void* workspace_buffer, int* block_tables, const void* k_block_scales_ptr,
const void* v_block_scales_ptr, int* seq_lens, int* cum_seq_lens_q, int* cum_seq_lens_kv,
float* attention_sinks, Data_type q_data_type, Data_type kv_data_type, Data_type o_data_type,
TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len, int64_t max_kv_len,
int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk,
int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens, int64_t q_stride_heads,
int64_t kv_stride_keys_values, int64_t kv_stride_heads, int64_t kv_stride_batch,
int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
float* attention_sinks, float* lse, Data_type q_data_type, Data_type kv_data_type,
Data_type o_data_type, TllmPagedAttentionMode mode, int64_t batch_size, int64_t max_q_len,
int64_t max_kv_len, int64_t num_pages_in_mem_pool, int64_t num_qo_heads, int64_t num_kv_heads,
int64_t head_dim_qk, int64_t head_dim_vo, int64_t page_size, int64_t q_stride_tokens,
int64_t q_stride_heads, int64_t kv_stride_keys_values, int64_t kv_stride_heads,
int64_t kv_stride_batch, int64_t max_num_blocks_per_seq, double bmm1_scale, double bmm2_scale,
const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax,
bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
int64_t k_sf_stride_heads, int64_t k_sf_stride_batch, int64_t v_sf_stride_heads,
int64_t v_sf_stride_batch, cudaStream_t stream) {
int64_t v_sf_stride_batch, int64_t lse_stride_tokens, int64_t lse_stride_heads,
cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -190,11 +191,26 @@ void trtllm_paged_attention_launcher(
size_t max_num_qo_heads = 256; // todo(Yingyi): get from dlfw, in total 8MB
size_t num_semaphores =
round_up(max_batch_size * max_num_qo_heads, 8); // max 8MB, should align to 16 bytes
// semaphores be at the first 8MB of workspace buffer: counter | scratch
// todo(Yingyi): add softmax buffer later for lse return
// Keep semaphores at the first 8MB of workspace buffer.
// Workspace layout for generation path: counter | softmax | scratch.
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
// scratch takes the rest of the workspace buffer
}

// Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and
// ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction
// kernel only participates for special multi-CTA modes; single-CTA context kernels still write
// the final stats directly into this buffer.
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
Comment on lines +204 to +206
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The softmaxStatsPtr is now allocated for both Context and Generation modes. In Context mode, runner_params.mSumOfSeqLensQ (the total number of query tokens) can be very large, which may lead to excessive memory allocation in the workspace (e.g., for 128 heads and 128k tokens, this would require ~128MB). Since Context kernels typically do not require this workspace buffer for multi-block reduction (as mMultiCtasKvMode is false), this allocation could cause workspace exhaustion for long sequences. Consider moving this allocation inside the else block (Generation path) or making it conditional on mode == TllmPagedAttentionMode::ForGen.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that trtllm_ragged_attention_launcher on the main branch already allocates softmaxStatsPtr for the Context kernel path (alongside lsePtr), so this PR is consistent with the existing pattern. That said, do you know why the Context kernel needs softmaxStatsPtr even with mMultiCtasKvMode = false? Is it used as an intermediate buffer for computing LSE internally before writing to lsePtr? If so, it might be worth adding a brief comment here explaining that dependency β€” it's not obvious from the code alone.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qsang-nv we need softmaxStatsPtr generally for LSE materialization, not just for multi-CTA reduction. Whenever lsePtr is set, we will call ComputeLSEFromMD. Added a comment in the code.

runner_params.lsePtr = lse;
runner_params.lseStrideTokens = lse_stride_tokens;
runner_params.lseStrideHeads = lse_stride_heads;

// The scratch allocation uses size=0 to consume the rest of the workspace, so it must be the
// last allocation from float_allocator.
if (mode == TllmPagedAttentionMode::ForGen) {
runner_params.multiCtasKvScratchPtr =
float_allocator.aligned_alloc<void>(0, 16, "trtllm_gen_scratch_workspace");
}
Expand Down Expand Up @@ -244,7 +260,7 @@ void trtllm_paged_attention_decode(
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales,
Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor,
Optional<bool> uses_shared_paged_kv_idx) {
Optional<bool> uses_shared_paged_kv_idx, Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -352,19 +368,30 @@ void trtllm_paged_attention_decode(
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

float* lse_ptr = nullptr;
int64_t lse_stride_tokens = 0;
int64_t lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(lse.value().ndim() - 1);
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr,
v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
/*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool,
num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, q_stride_tokens,
q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
/*cum_seq_lens_kv*/ nullptr, attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type,
o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len, max_kv_len,
num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size,
q_stride_tokens, q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads,
stream);
}

void trtllm_paged_attention_context(
Expand All @@ -376,7 +403,8 @@ void trtllm_paged_attention_context(
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> key_block_scales, Optional<TensorView> value_block_scales,
Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx) {
Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx,
Optional<TensorView> lse) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand Down Expand Up @@ -474,20 +502,31 @@ void trtllm_paged_attention_context(
skip_softmax_threshold_scale_factor.value_or(0.0f);
bool const skips_softmax = skip_softmax_threshold_scale_factor_value != 0.0f;

float* lse_ptr = nullptr;
int64_t lse_stride_tokens = 0;
int64_t lse_stride_heads = 0;
if (lse.has_value()) {
TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor";
lse_ptr = static_cast<float*>(lse.value().data_ptr());
lse_stride_tokens = lse.value().stride(0);
lse_stride_heads = lse.value().stride(1);
}

trtllm_paged_attention_launcher(
out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(),
workspace_buffer.data_ptr(), static_cast<int*>(block_tables.data_ptr()), k_block_scales_ptr,
v_block_scales_ptr, static_cast<int*>(seq_lens.data_ptr()),
/*cum_seq_lens_q=*/static_cast<int*>(cum_seq_lens_q.data_ptr()),
/*cum_seq_lens_kv=*/static_cast<int*>(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr,
q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
lse_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size,
max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q,
head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values,
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, k_sf_stride_heads,
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, stream);
k_sf_stride_batch, v_sf_stride_heads, v_sf_stride_batch, lse_stride_tokens, lse_stride_heads,
stream);
}

void trtllm_ragged_attention_launcher(
Expand Down Expand Up @@ -566,6 +605,10 @@ void trtllm_ragged_attention_launcher(
// semaphores be at the first 8MB of workspace buffer: counter | softmax | scratch
runner_params.multiCtasKvCounterPtr = float_allocator.aligned_alloc<int32_t>(
num_semaphores * sizeof(uint32_t), 16, "trtllm_gen_counter_workspace");
// Needed whenever LSE is requested: the FMHA path emits per-row softmax (m, d) stats here, and
// ComputeLSEFromMD later converts that intermediate buffer into lsePtr. The separate reduction
// kernel only participates for special multi-CTA modes; single-CTA context kernels still write
// the final stats directly into this buffer.
runner_params.softmaxStatsPtr = float_allocator.aligned_alloc<float2>(
sizeof(float2) * num_qo_heads * runner_params.mSumOfSeqLensQ, 16,
"trtllm_gen_softmax_workspace");
Expand Down
73 changes: 46 additions & 27 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,15 +1364,11 @@ def run(
if rope_theta is None:
rope_theta = 1e4

if return_lse:
if lse is None:
lse = torch.empty(
(q.size(0), q.size(1)), dtype=torch.float32, device=q.device
)
else:
check_shape_dtype_device(
lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
)
lse_shape = (q.size(0), q.size(1))
if lse is not None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update BatchPrefillWithPagedKVCacheWrapper.run() in prefill.py to have the same checks.

check_shape_dtype_device(lse, lse_shape, torch.float32, q.device, "lse")
elif return_lse:
lse = torch.empty(lse_shape, dtype=torch.float32, device=q.device)

if out is None:
out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
Expand Down Expand Up @@ -1963,21 +1959,13 @@ def run(
out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
)

if return_lse:
if lse is None:
lse = torch.empty(
(q_nope.size(0), q_nope.size(1)),
dtype=torch.float32,
device=device,
)
else:
check_shape_dtype_device(
lse,
(q_nope.size(0), q_nope.size(1)),
q_nope.dtype,
q_nope.device,
"lse",
)
lse_shape = (q_nope.size(0), q_nope.size(1))
if lse is not None:
check_shape_dtype_device(
lse, lse_shape, torch.float32, q_nope.device, "lse"
)
elif return_lse:
lse = torch.empty(lse_shape, dtype=torch.float32, device=device)
self._cached_module.run(
self._float_workspace_buffer,
self._int_workspace_buffer,
Expand Down Expand Up @@ -2036,6 +2024,7 @@ def _paged_run(
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
lse: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(query)
Expand Down Expand Up @@ -2081,6 +2070,7 @@ def _paged_run(
value_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,
)
return out

Expand Down Expand Up @@ -2147,7 +2137,6 @@ def paged_run(
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> None:
assert maybe_lse is None
assert paged_kv_cache is not None
assert num_qo_heads is not None
assert num_kv_heads is not None
Expand Down Expand Up @@ -2176,6 +2165,7 @@ def paged_run(
value_block_scales=value_block_scales,
skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
lse=maybe_lse,
)

@register_fake_op(f"flashinfer::{uri}_paged_run")
Expand Down Expand Up @@ -2259,7 +2249,11 @@ def trtllm_batch_decode_with_kv_cache(
skip_softmax_threshold_scale_factor: Optional[float] = None,
kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
uses_shared_paged_kv_idx: bool = True,
) -> Union[torch.Tensor, FP4Tensor]:
lse: Optional[torch.Tensor] = None,
return_lse: bool = False,
) -> Union[
torch.Tensor, FP4Tensor, Tuple[Union[torch.Tensor, FP4Tensor], torch.Tensor]
]:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"""
Parameters
----------
Expand Down Expand Up @@ -2379,10 +2373,19 @@ def trtllm_batch_decode_with_kv_cache(
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.

lse: Optional[torch.Tensor] = None
The log-sum-exp of attention logits, if not provided, will be allocated internally.
Only supported by trtllm-gen backend.

return_lse: bool = False
Whether to return the logsumexp of attention scores, defaults to ``False``.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
output torch.Tensor or FP4Tensor.
lse: Optional[torch.Tensor]
The log-sum-exp of attention logits, if not provided, will be allocated internally.
"""
enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

Expand Down Expand Up @@ -2430,6 +2433,11 @@ def trtllm_batch_decode_with_kv_cache(
backend = (
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
)
wants_lse = return_lse or lse is not None
if wants_lse and backend != "trtllm-gen":
raise ValueError(
"lse and return_lse are only supported by the trtllm-gen backend"
)

if backend == "xqa":
# xqa backend doesn't support nvfp4 output
Expand Down Expand Up @@ -2578,6 +2586,12 @@ def trtllm_batch_decode_with_kv_cache(

_check_block_tables_shape(block_tables, uses_shared_paged_kv_idx)

lse_shape = (query.size(0), query.size(1))
if lse is not None:
check_shape_dtype_device(lse, lse_shape, torch.float32, query.device, "lse")
elif return_lse:
lse = torch.empty(lse_shape, dtype=torch.float32, device=query.device)

run_func(
out,
out_scale_factor,
Expand Down Expand Up @@ -2606,13 +2620,18 @@ def trtllm_batch_decode_with_kv_cache(
v_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
lse,
)

return (
out = (
out
if out_dtype != "nvfp4"
else FP4Tensor(out, out_scale_factor, o_sf_start_index, query.shape)
)
if return_lse:
return out, lse
else:
return out
else:
raise KeyError(f"Backend {backend} not supported")

Expand Down
Loading
Loading