cpp/tensorrt_llm/common/attentionOp.cpp: 42 changes (34 additions, 8 deletions)
@@ -778,8 +778,19 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
     {
         fp8_q_buf_size = max_num_tokens * static_cast<size_t>(total_q_dim_all_heads);
-        fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
-        fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
+
+        if (useSparseMLA())
+        {
+            // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
+            // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
+            fp8_k_buf_size = 0;
+            fp8_v_buf_size = 0;
+        }
+        else
+        {
+            fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
+            fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
+        }
     }
 
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
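
The workspace-sizing hunk above only zeroes the FP8 K/V staging buffers; the Q buffer is still reserved in both modes. As a rough illustration of how zero-sized buffers drop out of the total workspace, here is a minimal, self-contained sketch of buffer-size accounting. It is not the actual accounting in attentionOp.cpp: the 128-byte alignment, the helper names, and the parameter list are assumptions made for the example.

    #include <array>
    #include <cstddef>

    // Hypothetical sketch: zero-sized FP8 K/V buffers contribute nothing to the workspace total.
    constexpr std::size_t kAlignment = 128; // assumed per-buffer alignment, not taken from the PR

    constexpr std::size_t alignUp(std::size_t bytes)
    {
        return (bytes + kAlignment - 1) / kAlignment * kAlignment;
    }

    std::size_t fp8MlaWorkspaceBytes(bool useSparseMla, std::size_t maxNumTokens, std::size_t qDimAllHeads,
        std::size_t kDimAllHeads, std::size_t vDimAllHeads, std::size_t chunkPrefillBatchSize)
    {
        std::size_t const qBytes = maxNumTokens * qDimAllHeads;
        // Sparse MLA (absorption mode) reads K/V from the paged KV cache, so no staging buffers.
        std::size_t const kBytes = useSparseMla ? 0 : chunkPrefillBatchSize * maxNumTokens * kDimAllHeads;
        std::size_t const vBytes = useSparseMla ? 0 : chunkPrefillBatchSize * maxNumTokens * vDimAllHeads;

        std::size_t total = 0;
        for (std::size_t const bytes : std::array<std::size_t, 3>{qBytes, kBytes, vBytes})
        {
            total += alignUp(bytes); // alignUp(0) == 0, so skipped buffers cost nothing
        }
        return total;
    }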
@@ -1436,8 +1447,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
     {
         fp8_q_buf_size = params.num_tokens * static_cast<size_t>(total_q_dim_all_heads);
-        fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
-        fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
+
+        if (useSparseMLA())
+        {
+            // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
+            // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
+            fp8_k_buf_size = 0;
+            fp8_v_buf_size = 0;
+        }
+        else
+        {
+            fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
+            fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
+        }
     }
     size_t const padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
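
This hunk applies the same rule at enqueue time, with params.total_kv_len taking the place of the mChunkPrefillBufferBatchSize * max_num_tokens bound used for workspace sizing. Since both call sites must agree on whether the staging buffers exist, one possible way to express the shared rule is a small helper. The sketch below is purely hypothetical and not part of this PR; kvTokenCount stands in for whichever token count the caller uses.

    #include <cstddef>
    #include <utility>

    // Hypothetical helper, not part of the PR: one place to decide whether FP8 K/V staging
    // buffers are needed. Sparse MLA reads K/V from the paged KV cache (Q_PAGED_KV layout),
    // so it returns zero sizes; otherwise the caller's token count and head dimensions apply.
    std::pair<std::size_t, std::size_t> fp8KvStagingSizes(
        bool useSparseMla, std::size_t kvTokenCount, std::size_t kDimAllHeads, std::size_t vDimAllHeads)
    {
        if (useSparseMla)
        {
            return {0, 0};
        }
        return {kvTokenCount * kDimAllHeads, kvTokenCount * vDimAllHeads};
    }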
@@ -1805,11 +1827,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
             TLLM_CHECK_WITH_INFO(
                 mFmhaDispatcher->isSeparateQAndKvInput(), "Separate QKV input is required for fp8 context MLA");
             TLLM_CHECK_WITH_INFO(fp8_q_buf != nullptr, "FP8 q buffer is required for fp8 context MLA");
-            TLLM_CHECK_WITH_INFO(fp8_k_buf != nullptr, "FP8 k buffer is required for fp8 context MLA");
-            TLLM_CHECK_WITH_INFO(fp8_v_buf != nullptr, "FP8 v buffer is required for fp8 context MLA");
+            // In sparse MLA (absorption mode), K and V are stored in KV cache, not as separate FP8 buffers
+            TLLM_CHECK_WITH_INFO(useSparseMLA() || fp8_k_buf != nullptr,
+                "FP8 k buffer is required for fp8 context MLA in non-sparse mode");
+            TLLM_CHECK_WITH_INFO(useSparseMLA() || fp8_v_buf != nullptr,
+                "FP8 v buffer is required for fp8 context MLA in non-sparse mode");
+
             fmhaParams.qPtr = reinterpret_cast<void const*>(fp8_q_buf);
-            fmhaParams.kPtr = reinterpret_cast<void const*>(fp8_k_buf);
-            fmhaParams.vPtr = reinterpret_cast<void const*>(fp8_v_buf);
+            fmhaParams.kPtr = useSparseMLA() ? nullptr : reinterpret_cast<void const*>(fp8_k_buf);
+            fmhaParams.vPtr = useSparseMLA() ? nullptr : reinterpret_cast<void const*>(fp8_v_buf);
         }
         else
         {
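
The final hunk relaxes the null checks and leaves kPtr/vPtr null when sparse MLA is active, which is how the FMHA path is told to gather K/V from the paged KV cache rather than from contiguous FP8 staging buffers. The sketch below illustrates that contract with a simplified stand-in; FmhaInputs, selectKvSource, and the enum are invented for the example and are not the real fmhaParams or dispatcher types.

    #include <cassert>

    // Illustrative stand-in for the pointer contract set up above (names are hypothetical).
    enum class KvSource
    {
        ContiguousFp8Buffers, // dense path: fp8_k_buf / fp8_v_buf are staged separately
        PagedKvCache          // sparse MLA: K/V come from the paged KV cache (Q_PAGED_KV layout)
    };

    struct FmhaInputs
    {
        void const* qPtr = nullptr;
        void const* kPtr = nullptr;
        void const* vPtr = nullptr;
    };

    KvSource selectKvSource(FmhaInputs const& inputs)
    {
        assert(inputs.qPtr != nullptr); // the FP8 Q buffer is required in both modes
        bool const hasContiguousKv = inputs.kPtr != nullptr && inputs.vPtr != nullptr;
        return hasContiguousKv ? KvSource::ContiguousFp8Buffers : KvSource::PagedKvCache;
    }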