diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index be646731224..db1ea76a7a7 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -274,7 +274,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr; xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens; xqaParams.is_fp8_output = mFP8ContextFMHA; - xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr); + xqaParams.fp8_out_scale + = ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr); // Parameters required for FP4 output. xqaParams.output_sf = generationsParams.context_buf_sf; xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale; @@ -728,10 +729,29 @@ size_t AttentionOp::getWorkspaceSizeForContext(nvinfer1::DataType type, int32_t size_t const qkv_buf_2_size = mEnableContextFMHA ? 0 : size * max_num_tokens * local_hidden_units_qo; size_t const qk_buf_float_size = mEnableContextFMHA ? 0 : sizeof(float) * batch_size * mNumHeads * input_seq_length * kv_seq_length; - size_t const fp8_qkv_buffer_size - = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + + int const num_total_qkv_elements + = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + + size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } + size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; // Each token holds (batch_idx, token_idx_in_seq) int2. @@ -1341,10 +1361,26 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea size_t const qk_buf_float_size = mEnableContextFMHA ? 0 : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length; - size_t const fp8_qkv_buffer_size - = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + int const num_total_qkv_elements + = params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length; size_t const encoder_padding_offset_size @@ -1352,8 +1388,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea // Each token holds (batch_idx, token_idx_in_seq) int2. size_t const tokens_info_size = sizeof(int2) * params.num_tokens; size_t const fmha_scheduler_counter = mEnableContextFMHA ? sizeof(uint32_t) : 0; - size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0; - size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0; + size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0; + size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0; // cp workspace size upper bound size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1); @@ -1600,6 +1636,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea params.mla_param->cache_type = cache_type; params.mla_param->cu_q_seqlens = cu_q_seqlens; params.mla_param->quant_scale_kv = params.kv_scale_orig_quant; + // Set BMM scales for FP8 context computation + params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr; + params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr; + params.mla_param->host_bmm1_scale = decoder_params.fmhaHostBmm1Scale; + params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr; + // Set additional scales for context phase + params.mla_param->quant_scale_o = params.attention_output_orig_quant; + params.mla_param->dequant_scale_q = params.kv_scale_quant_orig; + params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig; if (mPagedContextFMHA && mPagedKVCache) { TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr, @@ -1678,8 +1723,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea // TODO: set it correctly for contiguous kv buffer (cross-attention). fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens; // Device buffer pointers. - fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast(fp8_qkv_buffer) - : reinterpret_cast(attention_input); + fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast(fp8_qkv_buffer) + : reinterpret_cast(attention_input); fmhaParams.qPtr = reinterpret_cast(q_buf_2_); // TODO: add contiguous kv buffer (cross-attention). fmhaParams.kvPtr = nullptr; @@ -2480,7 +2525,7 @@ int AttentionOp::initialize() noexcept } // FP8 FMHA should be used with fp8 workflow together. - if (mFP8ContextFMHA) + if (mFP8ContextFMHA || mFP8ContextMLA) { data_type = DATA_TYPE_E4M3; } @@ -2513,6 +2558,11 @@ int AttentionOp::initialize() noexcept fmhaParams.dataTypeOut = DATA_TYPE_BF16; fmhaParams.dataTypeKv = DATA_TYPE_BF16; } + if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache()) + { + fmhaParams.dataTypeKv = DATA_TYPE_E4M3; + fmhaParams.dataTypeOut = DATA_TYPE_BF16; + } // TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to // bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime. fmhaParams.forceFp32Acc = false; @@ -2566,7 +2616,7 @@ int AttentionOp::initialize() noexcept // Deepseek-V2 Generation needs a differ fmha with different argumments if (mIsMLAEnabled) { - mEnableXQA = (mSM == kSM_120); + mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA; if (mUseTllmGen) { Data_type qDataType = DATA_TYPE_FP32; @@ -2829,6 +2879,7 @@ std::string AttentionOp::toString() const ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl; ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl; ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl; + ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl; ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl; ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl; ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl; diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index b738fdaf2fd..01609c225cf 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -386,6 +386,7 @@ class AttentionOp bool mPosShiftEnabled = false; bool mPagedContextFMHA = false; bool mFP8ContextFMHA = false; + bool mFP8ContextMLA = false; bool mFP8GenerationMLA = false; bool mDenseContextFMHA = false; bool mHasFullAttentionMask = false; diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.cu b/cpp/tensorrt_llm/kernels/mlaKernels.cu index 2849eba71d3..cac0e8f0513 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.cu +++ b/cpp/tensorrt_llm/kernels/mlaKernels.cu @@ -923,6 +923,49 @@ void invokeMLARopeContext(MlaParams& params, KVCacheBuffer kv_cache_buffer, c <<>>(params.attention_input_buf, params.latent_cache, kv_cache_buffer, params.cos_sin_cache, params.head_num, head_size, params.meta.kv_lora_rank, params.cu_q_seqlens, params.cache_seq_lens, params.max_input_seq_len, params.cache_type, params.quant_scale_kv); + if (params.attention_input_buf != nullptr && params.quant_attention_input_buf != nullptr + && params.cache_type == KvCacheDataType::FP8) + { + TLLM_LOG_DEBUG("MLA RoPE Context: Quantizing attention_input_buf to FP8"); + + int const dim_q_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim); + int const dim_k_per_head = (params.meta.qk_nope_head_dim + params.meta.qk_rope_head_dim); + int const dim_v_per_head = (params.meta.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = params.head_num * dim_q_per_head; + int const total_k_dim_all_heads + = params.head_num * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = params.head_num * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + + int const num_total_qkv_elements + = params.acc_q_len * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + size_t headDim = params.meta.kv_lora_rank + params.meta.qk_rope_head_dim; + float const* device_qkv_scale_ptr = params.quant_scale_qkv; + + if (num_total_qkv_elements > 0) + { + int const threads_per_block = 256; + int const num_blocks = (num_total_qkv_elements + threads_per_block - 1) / threads_per_block; + + TLLM_LOG_DEBUG( + "Launching QuantizeCopyInputToFp8Kernel with num_blocks: %d, threads_per_block: %d, elements: %d", + num_blocks, threads_per_block, num_total_qkv_elements); + + tensorrt_llm::kernels::QuantizeCopyInputToFp8Kernel<<>>( + static_cast(params.attention_input_buf), // Source + static_cast<__nv_fp8_e4m3*>(params.quant_attention_input_buf), // Destination + num_total_qkv_elements, device_qkv_scale_ptr); + sync_check_cuda_error(stream); + + cudaStreamSynchronize(stream); + } + else + { + TLLM_LOG_WARNING("MLA RoPE Context: num_total_qkv_elements is 0, skipping quantization."); + } + } } template @@ -1037,6 +1080,17 @@ INSTANTIATE_SET_KVCACHE_MLA(float); INSTANTIATE_SET_KVCACHE_MLA(half); INSTANTIATE_SET_KVCACHE_MLA(__nv_bfloat16); +template +__global__ void QuantizeCopyInputToFp8Kernel( + T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr) +{ + uint element_idx = threadIdx.x + blockDim.x * blockIdx.x; + if (element_idx < num_total_elements) + { + float scale_factor = (device_scale_ptr != nullptr) ? *device_scale_ptr : 1.0f; + output_fp8_buffer[element_idx] = __nv_fp8_e4m3(static_cast(input_buffer[element_idx]) * scale_factor); + } +} } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/kernels/mlaKernels.h b/cpp/tensorrt_llm/kernels/mlaKernels.h index 3d5aa4f148d..e3472c13d8c 100644 --- a/cpp/tensorrt_llm/kernels/mlaKernels.h +++ b/cpp/tensorrt_llm/kernels/mlaKernels.h @@ -87,6 +87,8 @@ struct MlaParams void* context_paged_kv_ptr = nullptr; void* context_kv_cache_block_offsets_ptr = nullptr; int32_t context_paged_kv_max_blocks_per_seq = 0; + // for FP8 context qkv quantization + float const* quant_scale_qkv = nullptr; }; template @@ -111,5 +113,9 @@ void invokeMLARopeAppendPagedKVAssignQ(KVBlockArray& kv_cache, T* q_ptr, T* late float2 const* cos_sin_cache, size_t head_num, int nope_size, int rope_size, int lora_size, float const* kv_scale_orig_quant_ptr, cudaStream_t stream); +template +__global__ void QuantizeCopyInputToFp8Kernel( + T_IN const* input_buffer, __nv_fp8_e4m3* output_fp8_buffer, int num_total_elements, float const* device_scale_ptr); + } // namespace kernels } // namespace tensorrt_llm diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp index 90173696271..b753f2f93ea 100644 --- a/cpp/tensorrt_llm/thop/attentionOp.cpp +++ b/cpp/tensorrt_llm/thop/attentionOp.cpp @@ -151,7 +151,10 @@ class Runner : public RunnerBase { rotary_inv_freq_ptr = rotary_inv_freq.value().data_ptr(); } - rotary_cos_sin_ptr = static_cast(rotary_cos_sin.value().data_ptr()); + if (rotary_cos_sin.has_value()) + { + rotary_cos_sin_ptr = static_cast(rotary_cos_sin.value().data_ptr()); + } } void* workspace_ptr = workspace.data_ptr(); @@ -206,22 +209,30 @@ class Runner : public RunnerBase // Commonly, cyclic_attention_window_size, and max_attention_window_size will be the same // unless each layer has different attention window sizes. // the kv_cache capacity. - int const max_attention_window_size - = beam_width == 1 ? attention_window_size : cache_indirection.value().size(2); + int const max_attention_window_size = beam_width == 1 ? attention_window_size + : cache_indirection.has_value() ? cache_indirection.value().size(2) + : attention_window_size; // The cyclic_attention_window_size will determine the cyclic kv cache position of new tokens. // Note that this cyclic_attention_window_size might be smaller than the actual kv cache capactity. int const cyclic_attention_window_size = attention_window_size; bool const can_use_one_more_block = beam_width > 1; - int max_blocks_per_sequence = op.useKVCache() ? kv_cache_block_offsets.value().size(-1) : 0; - int32_t const pool_index - = op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item() : 0; - int32_t const layer_idx_in_cache_pool - = op.useKVCache() ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item() : 0; - KVBlockArray::DataType* block_offsets = static_cast( - op.useKVCache() ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr); - KVBlockArray::DataType* host_block_offsets = static_cast( - op.useKVCache() ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() : nullptr); + int max_blocks_per_sequence + = op.useKVCache() && kv_cache_block_offsets.has_value() ? kv_cache_block_offsets.value().size(-1) : 0; + int32_t const pool_index = op.useKVCache() && host_kv_cache_pool_mapping.has_value() + ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 0}).item() + : 0; + int32_t const layer_idx_in_cache_pool = op.useKVCache() && host_kv_cache_pool_mapping.has_value() + ? host_kv_cache_pool_mapping.value().index({op.mLayerIdx, 1}).item() + : 0; + KVBlockArray::DataType* block_offsets + = static_cast(op.useKVCache() && kv_cache_block_offsets.has_value() + ? kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() + : nullptr); + KVBlockArray::DataType* host_block_offsets + = static_cast(op.useKVCache() && host_kv_cache_block_offsets.has_value() + ? host_kv_cache_block_offsets.value().index({pool_index, seq_offset}).data_ptr() + : nullptr); auto const cache_elem_size = (op.mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T)); auto const block_size = op.mTokensPerBlock * op.mNumKVHeads * op.mHeadSize; @@ -229,12 +240,12 @@ class Runner : public RunnerBase int32_t const kv_factor = op.isMLAEnabled() ? 1 : 2; auto const intra_pool_offset = layer_idx_in_cache_pool * kv_factor * bytes_per_block; - void* host_primary_pool_pointer = op.useKVCache() + void* host_primary_pool_pointer = op.useKVCache() && host_kv_cache_pool_pointers.has_value() ? reinterpret_cast( reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 0}).item()) + intra_pool_offset) : nullptr; - void* host_secondary_pool_pointer = op.useKVCache() + void* host_secondary_pool_pointer = op.useKVCache() && host_kv_cache_pool_pointers.has_value() ? reinterpret_cast( reinterpret_cast(host_kv_cache_pool_pointers.value().index({pool_index, 1}).item()) + intra_pool_offset) @@ -242,16 +253,19 @@ class Runner : public RunnerBase float const* kv_scale_orig_quant_ptr = nullptr; float const* kv_scale_quant_orig_ptr = nullptr; - if (op.mKVCacheQuantMode.hasKvCacheQuant()) + if (op.mKVCacheQuantMode.hasKvCacheQuant() && kv_scale_orig_quant.has_value() + && kv_scale_quant_orig.has_value()) { kv_scale_orig_quant_ptr = kv_scale_orig_quant.value().data_ptr(); kv_scale_quant_orig_ptr = kv_scale_quant_orig.value().data_ptr(); } // For FP8 output, out_scale represents the output scale. - float const* out_scale_ptr - = (op.mFP8ContextFMHA && !op.mFuseFp4Quant) ? out_scale.value().data_ptr() : nullptr; + float const* out_scale_ptr = (op.mFP8ContextFMHA && !op.mFuseFp4Quant && out_scale.has_value()) + ? out_scale.value().data_ptr() + : nullptr; // For NVFP4 output, out_scale holds the global scale for scaling factors. - float const* out_sf_scale_ptr = op.mFuseFp4Quant ? out_scale.value().data_ptr() : nullptr; + float const* out_sf_scale_ptr + = op.mFuseFp4Quant && out_scale.has_value() ? out_scale.value().data_ptr() : nullptr; AttentionOp::EnqueueParams common_enqueue_params; common_enqueue_params.attention_input = attention_input; @@ -317,7 +331,9 @@ class Runner : public RunnerBase AttentionOp::EnqueueGenerationParams enqueue_params{common_enqueue_params}; enqueue_params.beam_width = beam_width; enqueue_params.num_requests = num_requests; - enqueue_params.cache_indir = beam_width == 1 ? nullptr : cache_indirection.value().data_ptr(); + enqueue_params.cache_indir = beam_width == 1 + ? nullptr + : (cache_indirection.has_value() ? cache_indirection.value().data_ptr() : nullptr); enqueue_params.semaphores = op.multiBlockSemaphores(); enqueue_params.host_past_key_value_lengths = host_past_key_value_lengths.data_ptr(); enqueue_params.start_token_idx_sf = token_offset; @@ -543,6 +559,7 @@ void attention_inplace(torch::Tensor q, torch::optional k, torch: static_cast(v_head_dim.value()), static_cast(predicted_tokens_per_seq), static_cast(layer_num)}; + op->mFP8ContextMLA = tensorrt_llm::common::getSMVersion() == 120 && op->mKVCacheQuantMode.hasFp8KvCache(); op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim; op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache(); // only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64 diff --git a/tests/unittest/_torch/test_attention_mla.py b/tests/unittest/_torch/test_attention_mla.py index a61975ddf8a..182c80e34ef 100644 --- a/tests/unittest/_torch/test_attention_mla.py +++ b/tests/unittest/_torch/test_attention_mla.py @@ -339,7 +339,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: accuracy_dict = { torch.bfloat16: (3e-2, 3e-3), - torch.float8_e4m3fn: (4e-1, 4e-2), + torch.float8_e4m3fn: (4.075e-1, 4.075e-2), }