73 changes: 62 additions & 11 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -274,7 +274,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams&
xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr;
xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens;
xqaParams.is_fp8_output = mFP8ContextFMHA;
xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
xqaParams.fp8_out_scale
= ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr);
// Parameters required for FP4 output.
xqaParams.output_sf = generationsParams.context_buf_sf;
xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale;
@@ -728,10 +729,29 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t
size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo;
size_t const qk_buf_float_size
= mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length;
size_t const fp8_qkv_buffer_size
= mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_v_per_head = (mMLAParams.v_head_dim);

// Total dimension per token across all heads for Q, K, and V components respectively
int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
int const total_k_dim_all_heads
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout

int const num_total_qkv_elements
= max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);

size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
if (mFP8ContextMLA)
{
fp8_qkv_buffer_size
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
}

size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens;
// Each token holds (batch_idx, token_idx_in_seq) int2.
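
For intuition, a minimal standalone sketch of the MLA FP8 QKV sizing introduced above (one byte per e4m3 element; Q and K each carry nope + rope dims per head, V carries v_head_dim). The head count and head dimensions in main() are illustrative placeholders, not values taken from this PR:

#include <cstddef>
#include <cstdio>

// Mirrors num_total_qkv_elements from getWorkspaceSizeForContext/enqueueContext:
// per token, every attention head contributes (nope + rope) dims for Q and for K,
// plus v_head_dim for V; FP8 e4m3 stores one byte per element.
static size_t mlaFp8QkvBufferBytes(size_t max_num_tokens, size_t num_heads, size_t qk_nope_head_dim,
    size_t qk_rope_head_dim, size_t v_head_dim)
{
    size_t const dim_q_per_head = qk_nope_head_dim + qk_rope_head_dim;
    size_t const dim_k_per_head = qk_nope_head_dim + qk_rope_head_dim;
    size_t const dim_v_per_head = v_head_dim;
    size_t const dims_per_token = num_heads * (dim_q_per_head + dim_k_per_head + dim_v_per_head);
    return max_num_tokens * dims_per_token; // * 1 byte per e4m3 value
}

int main()
{
    // Illustrative only: 8192 tokens, 128 heads, nope/rope/v dims of 128/64/128.
    std::printf("fp8 qkv workspace: %zu bytes\n", mlaFp8QkvBufferBytes(8192, 128, 128, 64, 128));
    return 0;
}
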
@@ -1341,19 +1361,35 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
size_t const qk_buf_float_size = mEnableContextFMHA
? 0
: sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length;
size_t const fp8_qkv_buffer_size
= mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim);
int const dim_v_per_head = (mMLAParams.v_head_dim);

// Total dimension per token across all heads for Q, K, and V components respectively
int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head;
int const total_k_dim_all_heads
= mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout
int const num_total_qkv_elements
= params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput()
? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
if (mFP8ContextMLA)
{
fp8_qkv_buffer_size
= mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0;
}
size_t const padding_offset_size
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length;
size_t const encoder_padding_offset_size
= mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.cross_kv_length;
// Each token holds (batch_idx, token_idx_in_seq) int2.
size_t const tokens_info_size = sizeof(int2) * params.num_tokens;
size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0;
size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;
size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;

// cp workspace size upper bound
size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1);
@@ -1600,6 +1636,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
params.mla_param->cache_type = cache_type;
params.mla_param->cu_q_seqlens = cu_q_seqlens;
params.mla_param->quant_scale_kv = params.kv_scale_orig_quant;
// Set BMM scales for FP8 context computation
params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr;
params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr;
params.mla_param->host_bmm1_scale = decoder_params.fmhaHostBmm1Scale;
params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr;
// Set additional scales for context phase
params.mla_param->quant_scale_o = params.attention_output_orig_quant;
params.mla_param->dequant_scale_q = params.kv_scale_quant_orig;
params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig;
if (mPagedContextFMHA && mPagedKVCache)
{
TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr,
@@ -1678,8 +1723,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStrea
// TODO: set it correctly for contiguous kv buffer (cross-attention).
fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens;
// Device buffer pointers.
fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(attention_input);
fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(attention_input);
fmhaParams.qPtr = reinterpret_cast<void const*>(q_buf_2_);
// TODO: add contiguous kv buffer (cross-attention).
fmhaParams.kvPtr = nullptr;
@@ -2480,7 +2525,7 @@ int AttentionOp::initialize() noexcept
}

// FP8 FMHA should be used with fp8 workflow together.
if (mFP8ContextFMHA)
if (mFP8ContextFMHA || mFP8ContextMLA)
{
data_type = DATA_TYPE_E4M3;
}
@@ -2513,6 +2558,11 @@ int AttentionOp::initialize() noexcept
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
fmhaParams.dataTypeKv = DATA_TYPE_BF16;
}
if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache())
{
fmhaParams.dataTypeKv = DATA_TYPE_E4M3;
fmhaParams.dataTypeOut = DATA_TYPE_BF16;
}
// TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to
// bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime.
fmhaParams.forceFp32Acc = false;
@@ -2566,7 +2616,7 @@ int AttentionOp::initialize() noexcept
// Deepseek-V2 Generation needs a different fmha with different arguments
if (mIsMLAEnabled)
{
mEnableXQA = (mSM == kSM_120);
mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA;
if (mUseTllmGen)
{
Data_type qDataType = DATA_TYPE_FP32;
@@ -2829,6 +2879,7 @@ std::string AttentionOp::toString() const
ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl;
ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl;
ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl;
ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl;
ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl;
ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl;
ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl;
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -386,6 +386,7 @@ class AttentionOp
bool mPosShiftEnabled = false;
bool mPagedContextFMHA = false;
bool mFP8ContextFMHA = false;
bool mFP8ContextMLA = false;
bool mFP8GenerationMLA = false;
bool mDenseContextFMHA = false;
bool mHasFullAttentionMask = false;
54 changes: 54 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaKernels.cu
@@ -923,6 +923,49 @@ void invokeMLARopeContext(MlaParams<T>& params, KVCacheBuffer kv_cache_buffer, c
<<<grid, 256, 0, stream>>>(params.attention_input_buf, params.latent_cache, kv_cache_buffer,
params.cos_sin_cache, params.head_num, head_size, params.meta.kv_lora_rank, params.cu_q_seqlens,
params.cache_seq_lens, params.max_input_seq_len, params.cache_type, params.quant_scale_kv);
if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr
&& params.cache_type == KvCacheDataType::FP8)
{
TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8");

int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim);
int const dim_v_per_head = (params.meta.v_head_dim);

// Total dimension per token across all heads for Q, K, and V components respectively
int const total_q_dim_all_heads = params.head_num * dim_q_per_head;
int const total_k_dim_all_heads
= params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout
int const total_v_dim_all_heads
= params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout

int const num_total_qkv_elements
= params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads);
size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim;
float const* device_qkv_scale_ptr = params.quant_scale_qkv;

if (num_total_qkv_elements > 0)
{
int const threads_per_block = 256;
int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block;

TLLM_LOG_DEBUG(
"Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d",
num_blocks, threads_per_block, num_total_qkv_elements);

tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<T><<<num_blocks, threads_per_block, 0, stream>>>(
static_cast<T const*>(params.attention_input_buf), // Source
static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination
num_total_qkv_elements, device_qkv_scale_ptr);
sync_check_cuda_error(stream);

cudaStreamSynchronize(stream);
}
else
{
TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization.");
}
}
}

template <typename T, typename KVCacheBuffer>
@@ -1037,6 +1080,17 @@ INSTANTIATE_SET_KVCACHE_MLA(float);
INSTANTIATE_SET_KVCACHE_MLA(half);
INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16);

template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr)
{
uint element_idx = threadIdx.x + blockDim.x * blockIdx.x;
if (element_idx < num_total_elements)
{
float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f;
output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast<float>(input_buffer[element_idx]) * scale_factor);
}
}
} // namespace kernels

} // namespace tensorrt_llm
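
As a usage illustration for the kernel defined above: a standalone CUDA sketch that restates the same one-thread-per-element, scale-then-cast pattern so it compiles on its own (in-tree callers use the QuantizeCopyInputToFp8Kernel template declared in mlaKernels.h). The element count and scale value are made up for the demo; the 256-thread ceil-divide launch mirrors invokeMLARopeContext.

#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cstdio>

// Demo copy of the quantize-copy pattern: multiply by a device-resident
// per-tensor scale, then cast to e4m3. One thread handles one element.
__global__ void quantizeCopyDemoKernel(
    float const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr)
{
    unsigned int const element_idx = threadIdx.x + blockDim.x * blockIdx.x;
    if (element_idx < static_cast<unsigned int>(num_total_elements))
    {
        float const scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f;
        output_fp8_buffer[element_idx] = __nv_fp8_e4m3(input_buffer[element_idx] * scale_factor);
    }
}

int main()
{
    int const num_total_elements = 1024; // made-up element count
    float* d_in = nullptr;
    __nv_fp8_e4m3* d_out = nullptr;
    float* d_scale = nullptr;
    cudaMalloc(&d_in, num_total_elements * sizeof(float));
    cudaMalloc(&d_out, num_total_elements * sizeof(__nv_fp8_e4m3));
    cudaMalloc(&d_scale, sizeof(float));
    cudaMemset(d_in, 0, num_total_elements * sizeof(float));
    float const host_scale = 0.5f; // made-up per-tensor scale
    cudaMemcpy(d_scale, &host_scale, sizeof(float), cudaMemcpyHostToDevice);

    // Same launch shape as invokeMLARopeContext: 256 threads per block, ceil-divide the elements.
    int const threads_per_block = 256;
    int const num_blocks = (num_total_elements + threads_per_block - 1) / threads_per_block;
    quantizeCopyDemoKernel<<<num_blocks, threads_per_block>>>(d_in, d_out, num_total_elements, d_scale);
    cudaDeviceSynchronize();
    std::printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

    cudaFree(d_in);
    cudaFree(d_out);
    cudaFree(d_scale);
    return 0;
}
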
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/kernels/mlaKernels.h
@@ -87,6 +87,8 @@ struct MlaParams
void* context_paged_kv_ptr = nullptr;
void* context_kv_cache_block_offsets_ptr = nullptr;
int32_t context_paged_kv_max_blocks_per_seq = 0;
// for FP8 context qkv quantization
float const* quant_scale_qkv = nullptr;
};

template <typename T, typename KVCacheBuffer>
@@ -111,5 +113,9 @@ void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* late
float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size,
float const* kv_scale_orig_quant_ptr, cudaStream_t stream);

template <typename T_IN>
__global__ void QuantizeCopyInputToFp8Kernel(
T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr);

} // namespace kernels
} // namespace tensorrt_llm
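
The new quant_scale_qkv member is a device pointer to a single float that the FP8 context path multiplies into the QKV activations before casting them to e4m3. How that scale is produced lies outside this diff; a common per-tensor recipe, stated here only as an assumption for illustration, maps the tensor's absolute maximum onto the e4m3 range (largest finite magnitude 448) and keeps the reciprocal as the dequantization factor:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // Toy activations standing in for the QKV tensor that the scale will be applied to.
    std::vector<float> qkv = {0.02f, -1.7f, 3.9f, -0.4f, 2.2f};

    // amax -> quant scale; the reciprocal is the matching dequant scale.
    float amax = 0.0f;
    for (float v : qkv)
    {
        amax = std::max(amax, std::fabs(v));
    }
    float const quant_scale = (amax > 0.0f) ? 448.0f / amax : 1.0f; // candidate value for quant_scale_qkv
    float const dequant_scale = 1.0f / quant_scale;

    // The quantize-copy kernel multiplies by quant_scale before the e4m3 cast; dequant undoes it.
    for (float v : qkv)
    {
        std::printf("in=% .3f scaled=% .3f recovered=% .3f\n", v, v * quant_scale, v * quant_scale * dequant_scale);
    }
    return 0;
}
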