diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp
index 19d0de71fd1..f4ae2073216 100644
--- a/cpp/tensorrt_llm/common/attentionOp.cpp
+++ b/cpp/tensorrt_llm/common/attentionOp.cpp
@@ -778,8 +778,19 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
     {
         fp8_q_buf_size = max_num_tokens * static_cast<size_t>(total_q_dim_all_heads);
-        fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
-        fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
+
+        if (useSparseMLA())
+        {
+            // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
+            // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
+            fp8_k_buf_size = 0;
+            fp8_v_buf_size = 0;
+        }
+        else
+        {
+            fp8_k_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_k_dim_all_heads);
+            fp8_v_buf_size = mChunkPrefillBufferBatchSize * max_num_tokens * static_cast<size_t>(total_v_dim_all_heads);
+        }
     }

     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
@@ -1436,8 +1447,19 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     if (mEnableContextFMHA && mFP8ContextMLA && mFmhaDispatcher->isSeparateQAndKvInput())
     {
         fp8_q_buf_size = params.num_tokens * static_cast<size_t>(total_q_dim_all_heads);
-        fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
-        fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
+
+        if (useSparseMLA())
+        {
+            // Sparse MLA (absorption mode): K and V are stored directly in KV cache during MLA RoPE kernel.
+            // No separate FP8 buffers needed for K/V since they're read from paged KV cache (Q_PAGED_KV layout).
+            fp8_k_buf_size = 0;
+            fp8_v_buf_size = 0;
+        }
+        else
+        {
+            fp8_k_buf_size = params.total_kv_len * static_cast<size_t>(total_k_dim_all_heads);
+            fp8_v_buf_size = params.total_kv_len * static_cast<size_t>(total_v_dim_all_heads);
+        }
     }

     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
@@ -1805,11 +1827,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
            TLLM_CHECK_WITH_INFO(
                mFmhaDispatcher->isSeparateQAndKvInput(), "Separate QKV input is required for fp8 context MLA");
            TLLM_CHECK_WITH_INFO(fp8_q_buf != nullptr, "FP8 q buffer is required for fp8 context MLA");
-            TLLM_CHECK_WITH_INFO(fp8_k_buf != nullptr, "FP8 k buffer is required for fp8 context MLA");
-            TLLM_CHECK_WITH_INFO(fp8_v_buf != nullptr, "FP8 v buffer is required for fp8 context MLA");
+            // In sparse MLA (absorption mode), K and V are stored in KV cache, not as separate FP8 buffers
+            TLLM_CHECK_WITH_INFO(useSparseMLA() || fp8_k_buf != nullptr,
+                "FP8 k buffer is required for fp8 context MLA in non-sparse mode");
+            TLLM_CHECK_WITH_INFO(useSparseMLA() || fp8_v_buf != nullptr,
+                "FP8 v buffer is required for fp8 context MLA in non-sparse mode");
+
            fmhaParams.qPtr = reinterpret_cast<void const*>(fp8_q_buf);
-            fmhaParams.kPtr = reinterpret_cast<void const*>(fp8_k_buf);
-            fmhaParams.vPtr = reinterpret_cast<void const*>(fp8_v_buf);
+            fmhaParams.kPtr = useSparseMLA() ? nullptr : reinterpret_cast<void const*>(fp8_k_buf);
+            fmhaParams.vPtr = useSparseMLA() ? nullptr : reinterpret_cast<void const*>(fp8_v_buf);
        }
        else
        {
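
For reference, the workspace-sizing change in the first two hunks reduces to a single branch: under FP8 context MLA with separate Q and KV inputs, the FP8 Q buffer is always sized, while the FP8 K/V buffers are sized only when sparse MLA (absorption mode) is off, because in absorption mode the FMHA kernel reads K and V from the paged KV cache instead. The sketch below restates that branch in isolation as a standalone function; the struct and function names (Fp8ContextWorkspace, computeFp8ContextWorkspace) and the example dimensions are hypothetical and do not appear in attentionOp.cpp.

// Illustrative sketch only -- not part of the patch above.
#include <cstddef>
#include <cstdio>

struct Fp8ContextWorkspace
{
    size_t qBytes; // FP8 Q buffer, always reserved for FP8 context MLA
    size_t kBytes; // FP8 K buffer, 0 in sparse MLA (K comes from the paged KV cache)
    size_t vBytes; // FP8 V buffer, 0 in sparse MLA (V comes from the paged KV cache)
};

Fp8ContextWorkspace computeFp8ContextWorkspace(
    size_t numTokens, size_t totalKvLen, size_t qDim, size_t kDim, size_t vDim, bool sparseMla)
{
    Fp8ContextWorkspace ws{};
    ws.qBytes = numTokens * qDim; // Q is always quantized into its own buffer
    if (sparseMla)
    {
        // Absorption mode: the FMHA kernel consumes K/V from the paged KV cache
        // (Q_PAGED_KV layout), so no standalone FP8 K/V workspace is reserved.
        ws.kBytes = 0;
        ws.vBytes = 0;
    }
    else
    {
        ws.kBytes = totalKvLen * kDim;
        ws.vBytes = totalKvLen * vDim;
    }
    return ws;
}

int main()
{
    // Arbitrary example values, chosen only to show the two cases side by side.
    auto dense = computeFp8ContextWorkspace(4096, 8192, 1536, 576, 512, /*sparseMla=*/false);
    auto sparse = computeFp8ContextWorkspace(4096, 8192, 1536, 576, 512, /*sparseMla=*/true);
    std::printf("dense:  k=%zu v=%zu bytes\n", dense.kBytes, dense.vBytes);
    std::printf("sparse: k=%zu v=%zu bytes\n", sparse.kBytes, sparse.vBytes);
    return 0;
}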