diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp
index 866a0d71c8f..ffd80b5bf7e 100644
--- a/cpp/tensorrt_llm/common/attentionOp.cpp
+++ b/cpp/tensorrt_llm/common/attentionOp.cpp
@@ -2521,8 +2521,7 @@ int AttentionOp::initialize() noexcept
     if (mIsMLAEnabled)
     {
         TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "MLA(Deepseek v2) only support fmha");
-        TLLM_CHECK_WITH_INFO(
-            !mFP8ContextFMHA && !mDenseContextFMHA, "MLA(Deepseek v2) currently not support FP8 and dense fmha");
+        TLLM_CHECK_WITH_INFO(!mDenseContextFMHA, "MLA(Deepseek v2) currently not support dense fmha");
         TLLM_CHECK_WITH_INFO(
             mPagedKVCache && mUseKVCache && mRemovePadding, "MLA(Deepseek v2) only support paged kv cache");
         TLLM_CHECK_WITH_INFO(!mCrossAttention, "MLA(Deepseek v2) do not support cross attention right now");
@@ -2684,11 +2683,6 @@ int AttentionOp::initialize() noexcept
             qDataType = DATA_TYPE_E4M3;
             kvDataType = DATA_TYPE_E4M3;
         }
-        // When FP8 Context FMHA is enabled, the output data type needs to be E4M3.
-        if (mFP8ContextFMHA)
-        {
-            outputDataType = DATA_TYPE_E4M3;
-        }
         // Instantiate the mTllmGenFMHARunner used for MLA
         mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));
diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h
index 39745a3957e..77d00252e96 100644
--- a/cpp/tensorrt_llm/common/attentionOp.h
+++ b/cpp/tensorrt_llm/common/attentionOp.h
@@ -450,13 +450,13 @@ class AttentionOp
             (int8_t) mPositionEmbeddingType, mUseLognScaling, mRemovePadding, (int32_t) mMaskType,
             mBlockSparseParams.data(), mPagedKVCache, mTokensPerBlock, mKVCacheQuantMode.value(), mTpSize, mTpRank,
             mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
-            mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
-            mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable,
-            mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(),
-            mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank,
-            mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode,
-            mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores,
-            mAttentionChunkSize.value_or(-1));
+            mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8ContextMLA, mDenseContextFMHA,
+            mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree,
+            mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA,
+            mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
+            mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
+            mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
+            mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };

 private:
diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
index 2e9964b8b77..967dfe5834d 100644
--- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
+++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
@@ -526,7 +526,9 @@ class TllmGenFmhaKernel
         int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;

         // Debug info.
-        std::string info = "qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
+        std::string info = "dtypeQ=" + std::to_string(static_cast<int>(mDtypeQ)) + ", dtypeKv="
+            + std::to_string(static_cast<int>(mDtypeKv)) + ", dtypeOut=" + std::to_string(static_cast<int>(mDtypeOut))
+            + ", sm=" + std::to_string(mSM) + ", qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
             + ", maskType=" + std::to_string(static_cast<int>(selectKernelParams.mMaskType))
             + ", kernelType=" + std::to_string(static_cast<int>(kernelType))
             + ", tileScheduler=" + std::to_string(static_cast<int>(selectKernelParams.mTileScheduler))
diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp
index 8dcc9c1443a..a5a857432fd 100644
--- a/cpp/tensorrt_llm/thop/attentionOp.cpp
+++ b/cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -489,38 +489,38 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v,
         if (is_fp8_out)
         {
-            runner.reset(new Runner<half, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<half, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<half, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<half, __nv_fp4_e2m1>>();
        }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
-            runner.reset(new Runner<half>());
+            runner = std::make_shared<Runner<half>>();
         }
     }
     else if (dtype == nvinfer1::DataType::kFLOAT)
     {
         TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
-        runner.reset(new Runner<float>());
+        runner = std::make_shared<Runner<float>>();
     }
 #ifdef ENABLE_BF16
     else if (dtype == nvinfer1::DataType::kBF16)
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
-            runner.reset(new Runner<__nv_bfloat16>());
+            runner = std::make_shared<Runner<__nv_bfloat16>>();
         }
     }
 #endif
@@ -538,13 +538,13 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v,
     auto op = std::make_shared<AttentionOp>();
     op->mType = dtype;
     op->mFMHAForceFP32Acc = dtype == nvinfer1::DataType::kBF16;
+    op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
     op->mFP8ContextFMHA = is_fp8_out || is_fp4_out;
     op->mLayerIdx = layer_idx;
     op->mNumHeads = num_heads;
     op->mNumKVHeads = num_kv_heads;
     op->mHeadSize = head_size;
     op->mMaskType = static_cast<AttentionMaskType>(int32_t(mask_type));
-    op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
     op->mUseKVCache = use_kv_cache;
     op->mPagedKVCache = op->mPagedKVCache && use_kv_cache; // update mPagedKVCache based on use_kv_cache
     op->mTokensPerBlock = tokens_per_block.value_or(0);
@@ -587,7 +587,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<torch::Tensor> v,
            static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq), static_cast<int>(layer_num)};
-        op->mFP8ContextMLA = tensorrt_llm::common::getSMVersion() == 120 && op->mKVCacheQuantMode.hasFp8KvCache();
+        op->mFP8ContextMLA
+            = (tensorrt_llm::common::getSMVersion() == 100 || tensorrt_llm::common::getSMVersion() == 120)
+            && op->mKVCacheQuantMode.hasFp8KvCache();
         op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
         op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
         // only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 9b89ffc1934..9fa0261daf9 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -294,6 +294,11 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)

+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise)
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -313,10 +318,7 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype

         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise)
-            if has_quant_scale and self.attn.has_fp8_kv_cache:
+            if self.has_quant_scale and self.attn.has_fp8_kv_cache:
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -353,10 +355,7 @@ def _attn_impl(

         out_scale = None
         out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise)
-        if has_quant_scale:
+        if self.has_quant_scale:
             out_scale = self.o_proj.inv_input_scale
             if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
                 out_scale_sf = self.o_proj.input_scale
@@ -840,6 +839,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)

+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -1045,9 +1047,6 @@ def forward_context_default(
                                self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
@@ -1055,7 +1054,7 @@ def forward_context_default(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )

@@ -1110,9 +1109,6 @@ def forward_context_with_cached_kv(
             full_kv = None
             full_k_nope = None

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1122,7 +1118,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
@@ -1212,7 +1208,6 @@ def forward_context_with_chunked_prefill(
                     loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1223,7 +1218,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 output=temp_attn_output,
             )
@@ -1262,9 +1257,6 @@ def forward_context_with_chunked_prefill(
                 num_contexts].sum().item(
                 )

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         temp_attn_output = self.mha.forward(
@@ -1274,7 +1266,7 @@ def forward_context_with_chunked_prefill(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             softmax_stats_tensor=self.temp_softmax_stats_tensor,
             output=temp_attn_output,
         )
@@ -1370,16 +1362,13 @@ def forward_generation(
                 self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
             ])

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 2726e1d652c..f22fb0a3b8f 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -959,7 +959,7 @@ def init_meta_tensor(t: torch.Tensor):

         except Exception:
             logger.info(
-                f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n"
+                f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n"
             )
             model = AutoModelForCausalLM.from_config(config)

diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py
index 1520648ae3a..05c7544a793 100644
--- a/tensorrt_llm/executor/worker.py
+++ b/tensorrt_llm/executor/worker.py
@@ -485,6 +485,9 @@ def _deduce_max_tokens(request: GenerationRequest,
                 raise ValueError(
                     "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                 )
+            assert (
+                len(prompt_token_ids) <= executor_config.max_seq_len
+            ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
             default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens <= 0:
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 23a470320ef..c2a4e6af295 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1165,7 +1165,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1189,6 +1189,8 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1264,7 +1266,7 @@ def test_cute_dsl_fp8_block_scales(
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @pytest.mark.skip_device_not_contain(["H100"])
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1277,6 +1279,8 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
                 max_batch_size=512,
                 enable_padding=True,
             ),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                  kv_cache_config=kv_cache_config,
@@ -1287,7 +1291,7 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
             task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     @parametrize_with_ids("attention_dp", [False, True])
     def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
@@ -1299,6 +1303,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         pytorch_config = dict(
             disable_overlap_scheduler=False,
             cuda_graph_config=CudaGraphConfig(enable_padding=True),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@@ -1312,7 +1318,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
             task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1341,6 +1347,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1427,7 +1435,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
             task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @skip_pre_hopper
     def test_fp8_block_scales_4gpus_static_eplb(self):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
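
Note: as a reading aid for the attention.py hunks above, the following is a minimal sketch (a hypothetical class, not the actual TensorRT-LLM module) of the refactor they implement: the o_proj quantization flags are computed once when weights are created and cached as self.has_quant_scale, and the per-call out_scale locals are replaced by a single cached self.out_scale, which stays None because the FP8 MLA path still produces BF16 output.

    # Minimal sketch, assuming simplified o_proj/mha interfaces; not the repository API.
    class MLAAttentionSketch:

        def __init__(self, o_proj, mha):
            self.o_proj = o_proj          # output projection (quantized linear)
            self.mha = mha                # attention backend wrapper
            self.has_quant_scale = False
            # FP8 MLA still returns BF16, so no output scale is applied.
            self.out_scale = None

        def create_weights(self):
            self.o_proj.create_weights()
            # Computed once here instead of on every forward call.
            self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
                                    or self.o_proj.has_fp8_block_scales
                                    or self.o_proj.has_fp8_rowwise)

        def forward_context(self, q, k, v, attn_metadata, output):
            # All context/generation paths now pass the cached scale.
            return self.mha.forward(q, k, v, attn_metadata,
                                    out_scale=self.out_scale, output=output)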
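
Note: the worker.py guard added above can be exercised in isolation with a small standalone helper; the function name and flattened parameters below are illustrative only, not the repository API.

    # Standalone sketch of the max-token deduction plus the new prompt-length guard.
    def deduce_default_max_tokens(prompt_token_ids, max_seq_len, cp_size=1, query_token_len=0):
        # Guard from the diff: the prompt itself must fit within max_seq_len,
        # otherwise the deduced budget below would be non-positive or misleading.
        assert len(prompt_token_ids) <= max_seq_len, (
            f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than "
            f"`max_seq_len` ({max_seq_len})")
        splited_prompt_len = int(len(prompt_token_ids) / cp_size)
        return max_seq_len - splited_prompt_len - query_token_len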