
Commit 5c6ff1d

feat: Add FP8 context MLA support for SM120
Signed-off-by: peaceh <[email protected]>
Parent: 8c82ee2 · Commit: 5c6ff1d

5 files changed: +159 -30 lines

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 62 additions & 11 deletions
@@ -274,7 +274,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr;
     xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens;
     xqaParams.is_fp8_output = mFP8ContextFMHA;
-    xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
+    xqaParams.fp8_out_scale
+        = ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr);
     // Parameters required for FP4 output.
     xqaParams.output_sf = generationsParams.context_buf_sf;
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
@@ -728,10 +729,29 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
     size_t const qk_buf_float_size
         = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_v_per_head = (mMLAParams.v_head_dim);
+
+    // Total dimension per token across all heads for Q, K, and V components respectively
+    int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
+    int const total_k_dim_all_heads
+        = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
+    int const total_v_dim_all_heads
+        = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
+
+    int const num_total_qkv_elements
+        = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
+
+    size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
+    if (mFP8ContextMLA)
+    {
+        fp8_qkv_buffer_size
+            = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
+    }
+
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     // Each token holds (batch_idx, token_idx_in_seq) int2.
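For a rough feel of what this sizing means in practice, here is a minimal standalone sketch that plugs assumed DeepSeek-V2-style MLA dimensions (qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128, 128 heads, 8192-token budget; illustrative values, not taken from this diff) into the same arithmetic:

// Hedged sketch: reproduces the workspace arithmetic above with assumed,
// DeepSeek-V2-like MLA dimensions. Names mirror the diff; the numbers are illustrative only.
#include <cstdio>

int main()
{
    int const num_heads = 128;        // assumed mNumAttnHeads
    int const qk_nope_head_dim = 128; // assumed mMLAParams.qk_nope_head_dim
    int const qk_rope_head_dim = 64;  // assumed mMLAParams.qk_rope_head_dim
    int const v_head_dim = 128;       // assumed mMLAParams.v_head_dim
    int const max_num_tokens = 8192;  // assumed workspace budget

    int const dim_q_per_head = qk_nope_head_dim + qk_rope_head_dim; // 192
    int const dim_k_per_head = dim_q_per_head;                      // 192
    int const dim_v_per_head = v_head_dim;                          // 128

    // Per-token QKV width across all heads, as in the diff.
    long long const per_token = 1LL * num_heads * (dim_q_per_head + dim_k_per_head + dim_v_per_head);
    long long const num_total_qkv_elements = 1LL * max_num_tokens * per_token;

    // One byte per element once quantized to FP8 (E4M3).
    std::printf("per-token elements: %lld, fp8_qkv_buffer_size: %lld bytes\n", per_token, num_total_qkv_elements);
    return 0;
}

Under those assumed numbers each token contributes 128 * (192 + 192 + 128) = 65,536 FP8 elements, so an 8,192-token workspace budget adds roughly 512 MiB for the quantized QKV buffer.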
@@ -1341,19 +1361,35 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     size_t const qk_buf_float_size = mEnableContextFMHA
         ? 0
         : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_v_per_head = (mMLAParams.v_head_dim);
+
+    // Total dimension per token across all heads for Q, K, and V components respectively
+    int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
+    int const total_k_dim_all_heads
+        = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
+    int const total_v_dim_all_heads
+        = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
+    int const num_total_qkv_elements
+        = params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
+    size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
        ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
        : 0;
+    if (mFP8ContextMLA)
+    {
+        fp8_qkv_buffer_size
+            = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
+    }
     size_t const padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
     size_t const encoder_padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.cross_kv_length;
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;

     // cp workspace size upper bound
     size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
@@ -1600,6 +1636,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         params.mla_param->cache_type = cache_type;
         params.mla_param->cu_q_seqlens = cu_q_seqlens;
         params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
+        // Set BMM scales for FP8 context computation
+        params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr;
+        params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr;
+        params.mla_param->host_bmm1_scale = decoder_params.fmhaHostBmm1Scale;
+        params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr;
+        // Set additional scales for context phase
+        params.mla_param->quant_scale_o = params.attention_output_orig_quant;
+        params.mla_param->dequant_scale_q = params.kv_scale_quant_orig;
+        params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
         if (mPagedContextFMHA && mPagedKVCache)
         {
             TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
@@ -1678,8 +1723,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         // TODO: set it correctly for contiguous kv buffer (cross-attention).
         fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
         // Device buffer pointers.
-        fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
-                                            : reinterpret_cast<void const*>(attention_input);
+        fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer)
+                                                                : reinterpret_cast<void const*>(attention_input);
         fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
         // TODO: add contiguous kv buffer (cross-attention).
         fmhaParams.kvPtr = nullptr;
@@ -2480,7 +2525,7 @@ int AttentionOp::initialize() noexcept
    }

    // FP8 FMHA should be used with fp8 workflow together.
-   if (mFP8ContextFMHA)
+   if (mFP8ContextFMHA || mFP8ContextMLA)
    {
        data_type = DATA_TYPE_E4M3;
    }
@@ -2513,6 +2558,11 @@ int AttentionOp::initialize() noexcept
        fmhaParams.dataTypeOut = DATA_TYPE_BF16;
        fmhaParams.dataTypeKv = DATA_TYPE_BF16;
    }
+   if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache())
+   {
+       fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
+       fmhaParams.dataTypeOut = DATA_TYPE_BF16;
+   }
    // TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
    // bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime.
    fmhaParams.forceFp32Acc = false;
@@ -2566,7 +2616,7 @@ int AttentionOp::initialize() noexcept
    // Deepseek-V2 Generation needs a differ fmha with different argumments
    if (mIsMLAEnabled)
    {
-       mEnableXQA = (mSM == kSM_120);
+       mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA;
        if (mUseTllmGen)
        {
            Data_type qDataType = DATA_TYPE_FP32;
@@ -2829,6 +2879,7 @@ std::string AttentionOp::toString() const
    ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl;
    ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl;
    ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl;
+   ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl;
    ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl;
    ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl;
    ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl;
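The predicate (mFP8ContextFMHA || mFP8ContextMLA) now guards several spots in this file: the XQA output scale, the FP8 QKV workspace sizing, the qkvPtr selection, and the E4M3 data-type selection in initialize(). A hypothetical helper, not part of this commit, is sketched below purely to name that repeated condition:

// Hypothetical convenience accessor (not in the commit); shown only to make the
// repeated "context phase consumes FP8-quantized QKV" condition explicit.
bool AttentionOp::useFp8ContextQkv() const noexcept
{
    return mFP8ContextFMHA || mFP8ContextMLA;
}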

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 1 addition & 0 deletions
@@ -386,6 +386,7 @@ class AttentionOp
    bool mPosShiftEnabled = false;
    bool mPagedContextFMHA = false;
    bool mFP8ContextFMHA = false;
+   bool mFP8ContextMLA = false;
    bool mFP8GenerationMLA = false;
    bool mDenseContextFMHA = false;
    bool mHasFullAttentionMask = false;

cpp/tensorrt_llm/kernels/mlaKernels.cu

Lines changed: 54 additions & 0 deletions
@@ -923,6 +923,49 @@ void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, c
        <<<grid, 256, 0, stream>>>(params.attention_input_buf, params.latent_cache, kv_cache_buffer,
            params.cos_sin_cache, params.head_num, head_size, params.meta.kv_lora_rank, params.cu_q_seqlens,
            params.cache_seq_lens, params.max_input_seq_len, params.cache_type, params.quant_scale_kv);
+   if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr
+       && params.cache_type == KvCacheDataType::FP8)
+   {
+       TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8");
+
+       int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
+       int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
+       int const dim_v_per_head = (params.meta.v_head_dim);
+
+       // Total dimension per token across all heads for Q, K, and V components respectively
+       int const total_q_dim_all_heads = params.head_num * dim_q_per_head;
+       int const total_k_dim_all_heads
+           = params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
+       int const total_v_dim_all_heads
+           = params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
+
+       int const num_total_qkv_elements
+           = params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
+       size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim;
+       float const* device_qkv_scale_ptr = params.quant_scale_qkv;
+
+       if (num_total_qkv_elements > 0)
+       {
+           int const threads_per_block = 256;
+           int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block;
+
+           TLLM_LOG_DEBUG(
+               "Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d",
+               num_blocks, threads_per_block, num_total_qkv_elements);
+
+           tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
+               static_cast<T const*>(params.attention_input_buf),             // Source
+               static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination
+               num_total_qkv_elements, device_qkv_scale_ptr);
+           sync_check_cuda_error(stream);
+
+           cudaStreamSynchronize(stream);
+       }
+       else
+       {
+           TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization.");
+       }
+   }
 }

 template <typename T, typename KVCacheBuffer>
@@ -1037,6 +1080,17 @@ INSTANTIATE_SET_KVCACHE_MLA(float);
 INSTANTIATE_SET_KVCACHE_MLA(half);
 INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16);

+template <typename T_IN>
+__global__ void QuantizeCopyInputToFp8Kernel(
+    T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr)
+{
+    uint element_idx = threadIdx.x + blockDim.x * blockIdx.x;
+    if (element_idx < num_total_elements)
+    {
+        float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f;
+        output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast<float>(input_buffer[element_idx]) * scale_factor);
+    }
+}
 } // namespace kernels

 } // namespace tensorrt_llm
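Outside the TensorRT-LLM tree, the kernel above boils down to a per-tensor quantize-copy: scale each element by one device-resident factor and cast it to E4M3. The following self-contained CUDA sketch reproduces that pattern with a simplified standalone kernel (it is not the TensorRT-LLM implementation, and error checking is trimmed for brevity):

#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Standalone analogue of the quantize-copy kernel: one thread per element,
// multiply by a device-side per-tensor scale, cast to FP8 E4M3.
__global__ void quantizeToFp8(float const* in, __nv_fp8_e4m3* out, int n, float const* scale)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        float s = (scale != nullptr) ? *scale : 1.0f;
        out[i] = __nv_fp8_e4m3(in[i] * s);
    }
}

int main()
{
    int const n = 1024;
    std::vector<float> h_in(n);
    for (int i = 0; i < n; ++i)
    {
        h_in[i] = 0.01f * i; // arbitrary test data
    }

    float* d_in;
    __nv_fp8_e4m3* d_out;
    float* d_scale;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(__nv_fp8_e4m3));
    cudaMalloc(&d_scale, sizeof(float));

    float const h_scale = 1.0f; // per-tensor scale; a real pipeline derives it from calibration/amax
    cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_scale, &h_scale, sizeof(float), cudaMemcpyHostToDevice);

    int const block = 256;
    int const grid = (n + block - 1) / block;
    quantizeToFp8<<<grid, block>>>(d_in, d_out, n, d_scale);
    cudaDeviceSynchronize();

    std::printf("quantized %d elements to FP8 E4M3\n", n);
    cudaFree(d_in);
    cudaFree(d_out);
    cudaFree(d_scale);
    return 0;
}

The committed kernel is the same idea templated on the input type; the host side in attentionOp.cpp then hands the resulting buffer to the FP8 context FMHA as fmhaParams.qkvPtr.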

cpp/tensorrt_llm/kernels/mlaKernels.h

Lines changed: 6 additions & 0 deletions
@@ -87,6 +87,8 @@ struct MlaParams
    void* context_paged_kv_ptr = nullptr;
    void* context_kv_cache_block_offsets_ptr = nullptr;
    int32_t context_paged_kv_max_blocks_per_seq = 0;
+   // for FP8 context qkv quantization
+   float const* quant_scale_qkv = nullptr;
 };

 template <typename T, typename KVCacheBuffer>
@@ -111,5 +113,9 @@ void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* late
    float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
    float const* kv_scale_orig_quant_ptr, cudaStream_t stream);

+template <typename T_IN>
+__global__ void QuantizeCopyInputToFp8Kernel(
+    T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr);
+
 } // namespace kernels
 } // namespace tensorrt_llm
