@@ -274,7 +274,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
     xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr;
     xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens;
     xqaParams.is_fp8_output = mFP8ContextFMHA;
-    xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
+    xqaParams.fp8_out_scale
+        = ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr);
     // Parameters required for FP4 output.
     xqaParams.output_sf = generationsParams.context_buf_sf;
     xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
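
A minimal device-side sketch (not part of this change) of how an output scale like xqaParams.fp8_out_scale is typically consumed when the attention output is quantized to FP8, under the assumption that the pointer references a single orig-to-quant float on the device:

#include <cuda_fp8.h>

__device__ __nv_fp8_e4m3 quantize_attention_output(float acc, float const* fp8_out_scale)
{
    // When no scale is provided (the nullptr branch in the hunk above), the caller is
    // expected to keep the non-FP8 output path instead of quantizing.
    float const scale = (fp8_out_scale != nullptr) ? *fp8_out_scale : 1.0f;
    return __nv_fp8_e4m3(acc * scale);
}
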
@@ -728,10 +729,29 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
     size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
     size_t const qk_buf_float_size
         = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_v_per_head = (mMLAParams.v_head_dim);
+
+    // Total dimension per token across all heads for Q, K, and V components respectively
+    int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
+    int const total_k_dim_all_heads
+        = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
+    int const total_v_dim_all_heads
+        = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
+
+    int const num_total_qkv_elements
+        = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
+
+    size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
+    if (mFP8ContextMLA)
+    {
+        fp8_qkv_buffer_size
+            = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
+    }
+
     size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
     // Each token holds (batch_idx, token_idx_in_seq) int2.
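
The MLA sizing math above (repeated in enqueueContext in the next hunk) can be checked with a small standalone program. The head dimensions, head count, and token count below are illustrative DeepSeek-V2-style values, not taken from this diff; since FP8 elements are one byte, the element count is also the byte count of the buffer.

#include <cstddef>
#include <cstdio>

int main()
{
    // Illustrative MLA dimensions (assumed for this example only).
    int const qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128;
    int const num_attn_heads = 128;
    std::size_t const max_num_tokens = 8192;

    int const dim_q_per_head = qk_rope_head_dim + qk_nope_head_dim; // 192
    int const dim_k_per_head = qk_rope_head_dim + qk_nope_head_dim; // 192
    int const dim_v_per_head = v_head_dim;                          // 128

    // Q, K, and V are all laid out with num_attn_heads heads, as the diff's comments assume.
    std::size_t const qkv_dim_per_token
        = std::size_t(num_attn_heads) * (dim_q_per_head + dim_k_per_head + dim_v_per_head); // 128 * 512 = 65536
    std::size_t const num_total_qkv_elements = max_num_tokens * qkv_dim_per_token;          // 536870912

    // One byte per FP8 element, so this is also fp8_qkv_buffer_size in bytes (512 MiB here).
    std::printf("fp8_qkv_buffer_size = %zu bytes\n", num_total_qkv_elements);
    return 0;
}
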
@@ -1341,19 +1361,35 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
     size_t const qk_buf_float_size = mEnableContextFMHA
         ? 0
         : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
-    size_t const fp8_qkv_buffer_size
-        = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
+    int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
+    int const dim_v_per_head = (mMLAParams.v_head_dim);
+
+    // Total dimension per token across all heads for Q, K, and V components respectively
+    int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
+    int const total_k_dim_all_heads
+        = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
+    int const total_v_dim_all_heads
+        = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
+    int const num_total_qkv_elements
+        = params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
+    size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
         ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
         : 0;
+    if (mFP8ContextMLA)
+    {
+        fp8_qkv_buffer_size
+            = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
+    }
     size_t const padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
     size_t const encoder_padding_offset_size
         = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.cross_kv_length;
     // Each token holds (batch_idx, token_idx_in_seq) int2.
     size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
     size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
-    size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
-    size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
+    size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
+    size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;

     // cp workspace size upper bound
     size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
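
The two-float BMM1 scale versus one-float BMM2 scale split above suggests the common layout in which the softmax scale is stored both in natural form and pre-multiplied by log2(e) for an exp2-based softmax; that reading is an assumption about the buffer contents, not something stated in this diff. A host-side sketch of filling such buffers, with hypothetical helper and parameter names:

#include <cmath>

// Hypothetical helper, only to illustrate the assumed 2-float / 1-float layout.
void fill_fmha_scales(float* bmm1_scale /* 2 floats */, float* bmm2_scale /* 1 float */,
    float dequant_qk, float quant_out, int head_dim)
{
    float const softmax_scale = dequant_qk / std::sqrt(static_cast<float>(head_dim));
    bmm1_scale[0] = softmax_scale;                       // natural scale
    bmm1_scale[1] = softmax_scale * 1.4426950408889634f; // scale * log2(e), for exp2-based softmax
    bmm2_scale[0] = quant_out;                           // single scale applied after the second GEMM
}
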
@@ -1600,6 +1636,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         params.mla_param->cache_type = cache_type;
         params.mla_param->cu_q_seqlens = cu_q_seqlens;
         params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
+        // Set BMM scales for FP8 context computation
+        params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr;
+        params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr;
+        params.mla_param->host_bmm1_scale = decoder_params.fmhaHostBmm1Scale;
+        params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr;
+        // Set additional scales for context phase
+        params.mla_param->quant_scale_o = params.attention_output_orig_quant;
+        params.mla_param->dequant_scale_q = params.kv_scale_quant_orig;
+        params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
         if (mPagedContextFMHA && mPagedKVCache)
         {
            TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
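
In the hunk above, dequant_scale_q and dequant_scale_kv are both wired to params.kv_scale_quant_orig, so Q shares the KV-cache dequantization scale, and quant_attention_input_buf only points at the FP8 QKV buffer when mFP8ContextMLA is set. A conceptual sketch of the *_orig_quant / *_quant_orig naming convention, assuming the usual reciprocal-scale meaning:

// Quantize with the orig->quant scale, dequantize with the quant->orig scale;
// for a well-formed pair, quant_orig == 1.0f / orig_quant (up to FP8 rounding).
inline float quantize(float x, float orig_quant)
{
    return x * orig_quant;
}

inline float dequantize(float q, float quant_orig)
{
    return q * quant_orig;
}
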
@@ -1678,8 +1723,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
         // TODO: set it correctly for contiguous kv buffer (cross-attention).
         fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
         // Device buffer pointers.
-        fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
-                                            : reinterpret_cast<void const*>(attention_input);
+        fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer)
+                                                                : reinterpret_cast<void const*>(attention_input);
         fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
         // TODO: add contiguous kv buffer (cross-attention).
         fmhaParams.kvPtr = nullptr;
@@ -2480,7 +2525,7 @@ int AttentionOp::initialize() noexcept
     }

     // FP8 FMHA should be used together with the fp8 workflow.
-    if (mFP8ContextFMHA)
+    if (mFP8ContextFMHA || mFP8ContextMLA)
     {
         data_type = DATA_TYPE_E4M3;
     }
@@ -2513,6 +2558,11 @@ int AttentionOp::initialize() noexcept
         fmhaParams.dataTypeOut = DATA_TYPE_BF16;
         fmhaParams.dataTypeKv = DATA_TYPE_BF16;
     }
+    if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache())
+    {
+        fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
+        fmhaParams.dataTypeOut = DATA_TYPE_BF16;
+    }
     // TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
     // bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value at runtime.
     fmhaParams.forceFp32Acc = false;
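
Read together with the earlier initialize() hunk, the added block selects FP8 K/V with BF16 output for the FP8 MLA context path when the KV cache is FP8. A tiny standalone restatement of that selection, using local stand-in names rather than the real fmhaParams fields:

enum class DType { kBF16, kE4M3 };

struct KvOutDtypes { DType kv; DType out; };

KvOutDtypes chooseContextMlaDtypes(bool fp8ContextMLA, bool hasFp8KvCache, DType modelDtype)
{
    KvOutDtypes chosen{modelDtype, modelDtype};
    if (fp8ContextMLA && hasFp8KvCache)
    {
        chosen.kv = DType::kE4M3;  // read the FP8 KV cache directly
        chosen.out = DType::kBF16; // keep BF16 context output
    }
    return chosen;
}
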
@@ -2566,7 +2616,7 @@ int AttentionOp::initialize() noexcept
     // Deepseek-V2 generation needs a different fmha with different arguments.
     if (mIsMLAEnabled)
     {
-        mEnableXQA = (mSM == kSM_120);
+        mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA;
         if (mUseTllmGen)
         {
             Data_type qDataType = DATA_TYPE_FP32;
@@ -2829,6 +2879,7 @@ std::string AttentionOp::toString() const
     ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl;
     ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl;
     ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl;
+    ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl;
     ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl;
     ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl;
     ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl;