diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp
index 51bccb7e6e7..13288eabe55 100644
--- a/cpp/tensorrt_llm/common/attentionOp.cpp
+++ b/cpp/tensorrt_llm/common/attentionOp.cpp
@@ -30,6 +30,7 @@
 #include "tensorrt_llm/runtime/utils/mpiUtils.h"
 #include
 #include
+#include
 #include

 using namespace tensorrt_llm::kernels;
@@ -1831,8 +1832,28 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea
         fmhaParams.chunkedAttentionSize = *mAttentionChunkSize;
     }
-    // Run the fmha kernel.
-    mFmhaDispatcher->run(fmhaParams);
+    if (mFP8FmhaForEagle3 && !mFmhaDispatcher->useTllmGen() && !mFP8AttenOutput)
+    {
+        auto origin_attn_output_dtype = std::is_same_v<T, half> ? torch::kFloat16
+            : std::is_same_v<T, __nv_bfloat16>                  ? torch::kBFloat16
+                                                                : torch::kFloat32;
+        torch::Tensor fp8_attn_output = torch::empty(
+            {params.output_tensor_numel}, torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
+        auto* origin_attn_output_ptr = fmhaParams.outputPtr;
+        torch::Tensor origin_attn_tensor
+            = torch::from_blob(origin_attn_output_ptr, {params.output_tensor_numel}, origin_attn_output_dtype);
+        fmhaParams.outputPtr = fp8_attn_output.data_ptr();
+        // Run the fmha kernel.
+        mFmhaDispatcher->run(fmhaParams);
+        // Convert the fp8 output to the original dtype.
+        auto temp_tensor = fp8_attn_output.to(origin_attn_output_dtype);
+        origin_attn_tensor.copy_(temp_tensor);
+    }
+    else
+    {
+        // Run the fmha kernel.
+        mFmhaDispatcher->run(fmhaParams);
+    }

     sync_check_cuda_error(stream);

     if (mCpSize > 1 && mAttnTpSize > 1 && mAttnCpSize == 1)
@@ -2702,6 +2723,16 @@ int AttentionOp::initialize() noexcept
         fmhaParams.attnLogitSoftcappingScale = mAttnLogitSoftcappingScale;
         fmhaParams.hasAlibi = isALiBi();
         fmhaParams.scaleAlibi = isAliBiWithScale();
+        if (mFP8FmhaForEagle3)
+        {
+            // Use FP8 FMHA for Eagle3 with an FP8 target model and a BF16/FP16 draft model.
+            FmhaDispatcher tempFmhaDispatcher(fmhaParams);
+            // Use FP8 output only for non-TllmGen kernels, because FP8 TllmGen supports BF16/FP16 output.
+            if (!tempFmhaDispatcher.useTllmGen())
+            {
+                fmhaParams.dataTypeOut = DATA_TYPE_E4M3;
+            }
+        }

         // Load kernels from the pre-compiled cubins.
 mFmhaDispatcher.reset(new FmhaDispatcher(fmhaParams));
diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h
index f33194c02fa..680cba3502d 100644
--- a/cpp/tensorrt_llm/common/attentionOp.h
+++ b/cpp/tensorrt_llm/common/attentionOp.h
@@ -137,6 +137,9 @@ class AttentionOp
         T const* k_ptr = nullptr;
         T const* v_ptr = nullptr;

+        // optional for mFP8FmhaForEagle3
+        int64_t output_tensor_numel = 0;
+
         std::string enqueueContextParamsToString() const
         {
             // variables from the params coming from the runtime
@@ -190,6 +193,7 @@ class AttentionOp
             ss << "softmaxStatsPtr: " << this->softmax_stats << std::endl;
             ss << "k_ptr: " << this->k_ptr << std::endl;
             ss << "v_ptr: " << this->v_ptr << std::endl;
+            ss << "output_tensor_numel: " << this->output_tensor_numel << std::endl;
             return ss.str();
         }
     };
@@ -422,6 +426,7 @@ class AttentionOp
     bool mIsSpecDecodingEnabled = false;
     bool mUseSpecDecoding = false;
     bool mIsSpecDecTree = true;
+    bool mFP8FmhaForEagle3 = false;
     bool mSpecDecodingIsGenerationLengthVariable = false;
     int32_t mSpecDecodingMaxGenerationLength = 1;
     bool mIsMLAEnabled = false;
diff --git a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
index 79e9694b95f..c2eb6257d67 100644
--- a/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
+++ b/cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp
@@ -124,6 +124,14 @@ bool DecoderXQAImplJIT::shouldUse(XQAParams const& umbrellaXQAParams, bool forCo
         bool hasPerfGain = mayHavePerfGain(xqaParams);
         if (!hasPerfGain)
         {
+            if (!xqaParams.is_fp8_output && xqaParams.kv_cache_data_type == DATA_TYPE_E4M3
+                && (xqaParams.data_type == DATA_TYPE_BF16 || xqaParams.data_type == DATA_TYPE_FP16))
+            {
+                TLLM_LOG_DEBUG(
+                    "JIT XQA is selected in the generation phase for fp16/bf16 input and e4m3 kv cache because MMHA "
+                    "does not support this combination.");
+                return true;
+            }
             TLLM_LOG_DEBUG("JIT XQA is not used: maybe no performance gain");
             return false;
         }
diff --git a/cpp/tensorrt_llm/kernels/fmhaDispatcher.h b/cpp/tensorrt_llm/kernels/fmhaDispatcher.h
index f79c55d3805..64f6e78e249 100644
--- a/cpp/tensorrt_llm/kernels/fmhaDispatcher.h
+++ b/cpp/tensorrt_llm/kernels/fmhaDispatcher.h
@@ -40,6 +40,12 @@ class FmhaDispatcher
     // Check if any fmha kernel meets the requirements.
     bool isSupported();

+    // Whether to use trtllm-gen kernels.
+    bool useTllmGen() const
+    {
+        return mUseTllmGen;
+    }
+
     // Does FMHA need a separate Q and Kv input ?
     bool isSeparateQAndKvInput() const
     {
diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp
index d6a64d733b8..9a05b297574 100644
--- a/cpp/tensorrt_llm/thop/attentionOp.cpp
+++ b/cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -387,6 +387,7 @@ class Runner : public RunnerBase
         enqueue_params.batch_size = num_seqs;
         enqueue_params.k_ptr = k_ptr;
         enqueue_params.v_ptr = v_ptr;
+        enqueue_params.output_tensor_numel = output.numel();

         if (op.isMLAEnabled())
         {
@@ -621,17 +622,20 @@ void attention(torch::Tensor q, std::optional k, std::optional
     op->mRotaryEmbeddingLongMscale = rotary_embedding_long_m_scale;
     op->mRotaryEmbeddingMaxPositions = rotary_embedding_max_positions;
     op->mRotaryEmbeddingOriginalMaxPositions = rotary_embedding_original_max_positions;

-    op->mFP8ContextFMHA = is_fp8_out || is_fp4_out || (op->mKVCacheQuantMode.hasFp8KvCache() && use_paged_context_fmha);
+    op->mFP8ContextFMHA = is_fp8_out || is_fp4_out || (op->mKVCacheQuantMode.hasFp8KvCache() && use_paged_context_fmha)
+        || op->mFP8FmhaForEagle3;
     op->mFP8AttenOutput = is_fp8_out;
     op->mPagedContextFMHA = use_paged_context_fmha;
     op->mAttentionChunkSize = attention_chunk_size;

-    TORCH_CHECK(spec_decoding_bool_params.size() == 3,
-        "Expecting 3 bools for spec-dec mode, is_spec_decoding_enabled, use_spec_decoding, and is_spec_dec_tree.");
+    TORCH_CHECK(spec_decoding_bool_params.size() == 4,
+        "Expecting 4 bools for spec-dec mode, is_spec_decoding_enabled, use_spec_decoding, is_spec_dec_tree, and "
+        "fp8_fmha_for_eagle3.");
     op->mIsSpecDecodingEnabled = spec_decoding_bool_params[0]; // is_spec_decoding_enabled
     op->mUseSpecDecoding = spec_decoding_bool_params[1];       // use_spec_decoding
     op->mIsSpecDecTree = spec_decoding_bool_params[2];         // is_spec_dec_tree
+    op->mFP8FmhaForEagle3 = spec_decoding_bool_params[3];      // fp8_fmha_for_eagle3

     if (is_mla_enable)
     {
diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py
index 90bc6df7848..b607206c521 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -190,6 +190,7 @@ def plan(
         spec_decoding_generation_lengths: Optional[torch.Tensor] = None,
         attention_sinks: Optional[torch.Tensor] = None,
         chunked_prefill_buffer_batch_size: int = 1,
+        fp8_fmha_for_eagle3: bool = False,
         **kwargs,
     ):
         """
@@ -229,6 +230,7 @@ def plan(
             helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
             attention_sinks (torch.Tensor): The attention sinks (additional value in the denominator of the softmax) with shape of (num_heads_q) on GPU.
             chunked_prefill_buffer_batch_size (int): used for malloc buffer for k and v in fp8 context mla. the max input kv length is not max_num_tokens in this case. It is chunked_prefill_buffer_batch_size * max_num_tokens.
+            fp8_fmha_for_eagle3 (bool): Whether to use FP8 FMHA for Eagle3 + FP8 target model + BF16/FP16 draft model.
""" self.layer_idx = layer_idx self.tokens_per_block = tokens_per_block @@ -278,6 +280,7 @@ def plan( self.spec_decoding_packed_mask = spec_decoding_packed_mask self.spec_decoding_generation_lengths = spec_decoding_generation_lengths self.chunked_prefill_buffer_batch_size = chunked_prefill_buffer_batch_size + self.fp8_fmha_for_eagle3 = fp8_fmha_for_eagle3 self.kwargs.update(kwargs) def create_output(self, q: torch.Tensor, out_dtype: torch.dtype): @@ -417,7 +420,7 @@ def run( ] spec_decoding_bool_params = [ self.is_spec_decoding_enabled, self.use_spec_decoding, - self.is_spec_dec_tree + self.is_spec_dec_tree, self.fp8_fmha_for_eagle3 ] spec_decoding_tensor_params = [ self.spec_decoding_generation_lengths, @@ -1211,6 +1214,7 @@ def forward( output_sf: Optional[torch.Tensor] = None, attention_sinks: Optional[torch.Tensor] = None, chunked_prefill_buffer_batch_size: int = 1, + fp8_fmha_for_eagle3: bool = False, **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: assert isinstance( @@ -1287,6 +1291,7 @@ def forward( spec_decoding_generation_lengths, attention_sinks=attention_sinks, chunked_prefill_buffer_batch_size=chunked_prefill_buffer_batch_size, + fp8_fmha_for_eagle3=fp8_fmha_for_eagle3, ) out_dtype = None if out_scale is not None: diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 31d52791f6b..50665cdb5ae 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -65,6 +65,7 @@ def __init__( skip_create_weights_in_init=model_config. skip_create_weights_in_init, ) + self.is_eagle3 = True class Eagle3DecoderLayer(DecoderLayer): diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 8e2f0423233..13eb7241f54 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -318,6 +318,7 @@ def __init__( self.support_fused_qkv = self.attn.support_fused_qkv() self.support_nvfp4_output = self.attn.support_nvfp4_output() + self.is_eagle3 = False if not config.skip_create_weights_in_init: self.create_weights() @@ -404,6 +405,10 @@ def _attn_impl( if mrope_position_deltas is not None: mrope_config["mrope_position_deltas"] = mrope_position_deltas + # Be forced to use FP8 FMHA for BF16/FP16 model with FP8 KV cache (e.g. 
+        fp8_fmha_for_eagle3 = self.is_eagle3 and not self.has_quant_scale and self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_kv_cache(
+        ) and attn_metadata.num_contexts != 0
         attn_output = self.attn.forward(
             q,
             k,
@@ -420,7 +425,8 @@ def _attn_impl(
             enable_attn_nvfp4_output=enable_attn_nvfp4_output,
             output=output[:num_tokens, :] if output is not None else None,
             output_sf=output_sf,
-            attention_sinks=attention_sinks)
+            attention_sinks=attention_sinks,
+            fp8_fmha_for_eagle3=fp8_fmha_for_eagle3)
         if isinstance(attn_output, tuple):
             assert len(
                 attn_output
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index ad4aa5a8aa9..923a520c625 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -23,40 +23,68 @@ def enforce_single_worker(monkeypatch):


 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,fp8_target",
     [
-        [True, "TRTLLM", True, False, False, False, True, False, False],
-        [True, "TRTLLM", True, False, False, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True, False, False],
-        [False, "TRTLLM", True, False, False, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "FLASHINFER", True, False, False, False, True, False, False],
-        [False, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", False, True, True, False, True, False, False],
-        [True, "TRTLLM", True, False, True, True, True, False, False],
-        [True, "TRTLLM", True, False, True, False, True, False, False],
+        [True, "TRTLLM", True, False, False, False, True, False, False, False],
+        [True, "TRTLLM", True, False, False, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False, False, False],
+        [
+            False, "TRTLLM", True, False, False, False, False, False, False,
+            False
+        ],
+        [
+            True, "FLASHINFER", True, False, False, False, True, False, False,
+            False
+        ],
+        [
+            False, "FLASHINFER", True, False, False, False, True, False, False,
+            False
+        ],
+        [False, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", False, True, True, False, True, False, False, False],
+        [True, "TRTLLM", True, False, True, True, True, False, False, False],
+        [True, "TRTLLM", True, False, True, False, True, False, False, False],
         # TODO: nvbugs/5461761
-        # [True, "TRTLLM", True, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, True, False, False],
-        [False, "TRTLLM", False, False, False, False, True, False, False],
-        [True, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, False, True, True],
-        [False, "TRTLLM", False, False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, True, True, False],
-        [False, "TRTLLM", False, False, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True, False, False],
-        [True, "TRTLLM", False, False, False, True, False, False, False],
-        [True, "FLASHINFER", False, False, False, False, True, False, False],
-        [False, "FLASHINFER", False, False, False, False, True, False, False],
+        # [True, "TRTLLM", True, False, False, True, True, False, False, False],
+        [True, "TRTLLM", False, False, False, False, True, False, False, False],
+        [
+            False, "TRTLLM", False, False, False, False, True, False, False,
+            False
+        ],
+        [True, "TRTLLM", False, False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, True, False],
+        [
+            False, "TRTLLM", False, False, False, False, False, True, False,
+            False
+        ],
+        [True, "TRTLLM", False, False, False, False, True, True, False, False],
+        [False, "TRTLLM", False, False, False, False, True, True, False, False],
+        [
+            True, "TRTLLM", False, False, False, False, False, False, False,
+            False
+        ],
+        [
+            False, "TRTLLM", False, False, False, False, False, False, False,
+            False
+        ],
+        [True, "TRTLLM", False, False, False, True, True, False, False, False],
+        [True, "TRTLLM", False, False, False, True, False, False, False, False],
+        [
+            True, "FLASHINFER", False, False, False, False, True, False, False,
+            False
+        ],
+        [
+            False, "FLASHINFER", False, False, False, False, True, False, False,
+            False
+        ],
+        [True, "TRTLLM", False, True, True, True, True, True, True, True],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
                       use_chain_drafter: bool, multi_batch: bool,
-                      attention_dp: bool, request):
+                      attention_dp: bool, fp8_target: bool, request):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
@@ -65,6 +93,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     models_path = llm_models_root()
     eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
     target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+    if fp8_target:
+        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"

     # bs > 1 gives non-deterministic when doing IFB. There are slight chances
     # that ref and spec does not match 100%
@@ -72,6 +102,8 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     max_draft_len = 4
     kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
                                     max_tokens=8192)
+    if fp8_target:
+        kv_cache_config.dtype = 'fp8'
     cuda_graph_config = CudaGraphConfig(
         batch_sizes=[i for i in range(1, max_batch_size + 1)]
     ) if use_cuda_graph else None
@@ -151,9 +183,10 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
     generated_text_ref = [result.outputs[0].text for result in results_ref]
     llm_ref.shutdown()

-    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
-        # The spec decode algorithm currently guarantees identical results
-        assert text_spec == text_ref
+    if not fp8_target:
+        for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
+            # The spec decode algorithm currently guarantees identical results
+            assert text_spec == text_ref


 def test_deepseek_eagle3():
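For reference, the sketch below restates, outside the patch, how the new flag is derived and handed to the attention op: FP8 FMHA is forced only for an Eagle3 draft attention that has no quantization output scales of its own but runs against an FP8 KV cache, and only when the batch contains context-phase requests; the flag then travels as the fourth entry of `spec_decoding_bool_params`. This is a minimal, standalone illustration, not TensorRT-LLM code: the helper names (`should_force_fp8_fmha_for_eagle3`, `build_spec_decoding_bool_params`) and the `AttnState` container are hypothetical, while the field and parameter names mirror the ones used in `Attention._attn_impl` and the `trtllm.py` backend above.

```python
# Illustrative sketch only. Helper names and the AttnState container are
# hypothetical; attribute/parameter names mirror those used in
# tensorrt_llm/_torch/modules/attention.py and
# tensorrt_llm/_torch/attention_backend/trtllm.py in the diff above.
from dataclasses import dataclass
from typing import List


@dataclass
class AttnState:
    is_eagle3: bool        # True only on the Eagle3 draft attention (modeling_speculative.py)
    has_quant_scale: bool  # True when the attention output itself is quantized (FP8/FP4 scales present)
    has_fp8_kv_cache: bool # from quant_config.layer_quant_mode.has_fp8_kv_cache()
    num_contexts: int      # number of context-phase requests in the current batch


def should_force_fp8_fmha_for_eagle3(s: AttnState) -> bool:
    """Mirror of the gating condition computed in Attention._attn_impl."""
    return (s.is_eagle3 and not s.has_quant_scale and s.has_fp8_kv_cache
            and s.num_contexts != 0)


def build_spec_decoding_bool_params(is_spec_decoding_enabled: bool,
                                    use_spec_decoding: bool,
                                    is_spec_dec_tree: bool,
                                    fp8_fmha_for_eagle3: bool) -> List[bool]:
    """The thop attention op now expects four booleans in exactly this order."""
    return [
        is_spec_decoding_enabled, use_spec_decoding, is_spec_dec_tree,
        fp8_fmha_for_eagle3
    ]


if __name__ == "__main__":
    state = AttnState(is_eagle3=True, has_quant_scale=False,
                      has_fp8_kv_cache=True, num_contexts=2)
    flag = should_force_fp8_fmha_for_eagle3(state)
    print(build_spec_decoding_bool_params(True, True, False, flag))
```

On the C++ side, when this flag is set and the non-TllmGen FP8 context FMHA path is taken, `AttentionOp::enqueueContext` redirects the kernel output into a temporary FP8 (E4M3) tensor and then converts it back into the original FP16/BF16 output buffer, since that kernel path cannot emit FP16/BF16 directly; the TllmGen path keeps its native BF16/FP16 output and needs no conversion.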