8 changes: 1 addition & 7 deletions cpp/tensorrt_llm/common/attentionOp.cpp
@@ -2521,8 +2521,7 @@ int AttentionOp::initialize() noexcept
if (mIsMLAEnabled)
{
TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "MLA(Deepseek v2) only support fmha");
TLLM_CHECK_WITH_INFO(
!mFP8ContextFMHA && !mDenseContextFMHA, "MLA(Deepseek v2) currently not support FP8 and dense fmha");
TLLM_CHECK_WITH_INFO(!mDenseContextFMHA, "MLA(Deepseek v2) currently not support dense fmha");
TLLM_CHECK_WITH_INFO(
mPagedKVCache && mUseKVCache && mRemovePadding, "MLA(Deepseek v2) only support paged kv cache");
TLLM_CHECK_WITH_INFO(!mCrossAttention, "MLA(Deepseek v2) do not support cross attention right now");
@@ -2684,11 +2683,6 @@ int AttentionOp::initialize() noexcept
qDataType = DATA_TYPE_E4M3;
kvDataType = DATA_TYPE_E4M3;
}
// When FP8 Context FMHA is enabled, the output data type needs to be E4M3.
if (mFP8ContextFMHA)
{
outputDataType = DATA_TYPE_E4M3;
}

// Instantiate the mTllmGenFMHARunner used for MLA
mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));
14 changes: 7 additions & 7 deletions cpp/tensorrt_llm/common/attentionOp.h
@@ -450,13 +450,13 @@ class AttentionOp
(int8_t) mPositionEmbeddingType, mUseLognScaling, mRemovePadding, (int32_t) mMaskType,
mBlockSparseParams.data(), mPagedKVCache, mTokensPerBlock, mKVCacheQuantMode.value(), mTpSize, mTpRank,
mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable,
mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(),
mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank,
mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode,
mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores,
mAttentionChunkSize.value_or(-1));
mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8ContextMLA, mDenseContextFMHA,
mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree,
mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA,
mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
};

private:
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
@@ -526,7 +526,9 @@ class TllmGenFmhaKernel
int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;

// Debug info.
std::string info = "qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
std::string info = "dtypeQ=" + std::to_string(static_cast<int>(mDtypeQ)) + ", dtypeKv="
+ std::to_string(static_cast<int>(mDtypeKv)) + ", dtypeOut=" + std::to_string(static_cast<int>(mDtypeOut))
+ ", sm=" + std::to_string(mSM) + ", qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
+ ", maskType=" + std::to_string(static_cast<int>(selectKernelParams.mMaskType))
+ ", kernelType=" + std::to_string(static_cast<int>(kernelType))
+ ", tileScheduler=" + std::to_string(static_cast<int>(selectKernelParams.mTileScheduler))
20 changes: 11 additions & 9 deletions cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -489,38 +489,38 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
{
if (is_fp8_out)
{
runner.reset(new Runner<half, __nv_fp8_e4m3>());
runner = std::make_shared<Runner<half, __nv_fp8_e4m3>>();
}
else if (is_fp4_out)
{
runner.reset(new Runner<half, __nv_fp4_e2m1>());
runner = std::make_shared<Runner<half, __nv_fp4_e2m1>>();
}
else
{
TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
runner.reset(new Runner<half>());
runner = std::make_shared<Runner<half>>();
}
}
else if (dtype == nvinfer1::DataType::kFLOAT)
{
TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
runner.reset(new Runner<float>());
runner = std::make_shared<Runner<float>>();
}
#ifdef ENABLE_BF16
else if (dtype == nvinfer1::DataType::kBF16)
{
if (is_fp8_out)
{
runner.reset(new Runner<__nv_bfloat16, __nv_fp8_e4m3>());
runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp8_e4m3>>();
}
else if (is_fp4_out)
{
runner.reset(new Runner<__nv_bfloat16, __nv_fp4_e2m1>());
runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp4_e2m1>>();
}
else
{
TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
runner.reset(new Runner<__nv_bfloat16>());
runner = std::make_shared<Runner<__nv_bfloat16>>();
}
}
#endif
@@ -538,13 +538,13 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
auto op = std::make_shared<AttentionOp>();
op->mType = dtype;
op->mFMHAForceFP32Acc = dtype == nvinfer1::DataType::kBF16;
op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
op->mFP8ContextFMHA = is_fp8_out || is_fp4_out;
op->mLayerIdx = layer_idx;
op->mNumHeads = num_heads;
op->mNumKVHeads = num_kv_heads;
op->mHeadSize = head_size;
op->mMaskType = static_cast<tensorrt_llm::kernels::AttentionMaskType>(int32_t(mask_type));
op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
op->mUseKVCache = use_kv_cache;
op->mPagedKVCache = op->mPagedKVCache && use_kv_cache; // update mPagedKVCache based on use_kv_cache
op->mTokensPerBlock = tokens_per_block.value_or(0);
@@ -587,7 +587,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
static_cast<int>(layer_num)};

op->mFP8ContextMLA = tensorrt_llm::common::getSMVersion() == 120 && op->mKVCacheQuantMode.hasFp8KvCache();
op->mFP8ContextMLA
= (tensorrt_llm::common::getSMVersion() == 100 || tensorrt_llm::common::getSMVersion() == 120)
&& op->mKVCacheQuantMode.hasFp8KvCache();
op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
// only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64
41 changes: 15 additions & 26 deletions tensorrt_llm/_torch/modules/attention.py
@@ -294,6 +294,11 @@ def create_weights(self):
# which could be modified after __init__
self.attn.update_quant_config(self.quant_config)

self.o_proj.create_weights()
self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
or self.o_proj.has_fp8_block_scales
or self.o_proj.has_fp8_rowwise)

def split_qkv(self, q, k=None, v=None):
if k is None and v is None:
q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -313,10 +318,7 @@ def create_output(self, q: torch.Tensor):
out_dtype = q.dtype

if self.attn_backend == "TRTLLM":
has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
or self.o_proj.has_fp8_block_scales
or self.o_proj.has_fp8_rowwise)
if has_quant_scale and self.attn.has_fp8_kv_cache:
if self.has_quant_scale and self.attn.has_fp8_kv_cache:
out_dtype = torch.float8_e4m3fn
output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
return output
@@ -353,10 +355,7 @@ def _attn_impl(

out_scale = None
out_scale_sf = None
has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
or self.o_proj.has_fp8_block_scales
or self.o_proj.has_fp8_rowwise)
if has_quant_scale:
if self.has_quant_scale:
out_scale = self.o_proj.inv_input_scale
if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
out_scale_sf = self.o_proj.input_scale
@@ -840,6 +839,9 @@ def create_weights(self):
self.mha.update_quant_config(self.quant_config)
self.mqa.update_quant_config(self.quant_config)

# Although we use FP8 MLA for context/generation phase, the output is still in BF16
self.out_scale = None

# k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
# which can be modified after __init__
has_fp8_block_scales = (
@@ -1045,17 +1047,14 @@ def forward_context_default(
self.qk_rope_head_dim)
k = k.view(-1, self.num_heads * self.qk_head_dim)

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase

attn_output = self.mha.forward(
q,
k,
v,
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=latent_cache,
out_scale=out_scale,
out_scale=self.out_scale,
output=output,
)

@@ -1110,9 +1109,6 @@ def forward_context_with_cached_kv(
full_kv = None
full_k_nope = None

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase

# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
attn_output = self.mha.forward(
@@ -1122,7 +1118,7 @@ def forward_context_with_cached_kv(
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
out_scale=self.out_scale,
output=output,
)

@@ -1212,7 +1208,6 @@ def forward_context_with_chunked_prefill(
loop_idx]
attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

out_scale = None
# do not apply mask for attention within loop
# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1223,7 +1218,7 @@ def forward_context_with_chunked_prefill(
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
out_scale=self.out_scale,
attention_mask=PredefinedAttentionMask.FULL,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
output=temp_attn_output,
@@ -1262,9 +1257,6 @@ def forward_context_with_chunked_prefill(
num_contexts].sum().item(
)

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Currently we use BF16 MHA for context phase

# latent_cache must be None to differentiate from normal context phase,
# so that we can skip applying RoPE and appending KV cache inside attention op
temp_attn_output = self.mha.forward(
@@ -1274,7 +1266,7 @@ def forward_context_with_chunked_prefill(
attn_metadata,
attention_input_type=AttentionInputType.context_only,
latent_cache=None,
out_scale=out_scale,
out_scale=self.out_scale,
softmax_stats_tensor=self.temp_softmax_stats_tensor,
output=temp_attn_output,
)
@@ -1370,16 +1362,13 @@ def forward_generation(
self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
])

# out_scale = getattr(self.o_proj, "inv_input_scale", None)
out_scale = None # Although we use FP8 MLA for generation phase, the output is still in BF16

attn_out_latent = self.mqa.forward(
fused_q,
None,
None,
attn_metadata,
attention_input_type=AttentionInputType.generation_only,
out_scale=out_scale,
out_scale=self.out_scale,
latent_cache=latent_cache, # kvcache and k_pe
q_pe=q_pe, # used by `invokeMLARopeGeneration`
)
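For reference, a minimal sketch of the pattern the attention.py changes above move to: the output-projection quantization flag is computed once in create_weights instead of on every call, and the constant out_scale = None becomes a member reused by all context/generation paths. The stand-in class below folds both ideas into one hypothetical class; only has_quant_scale, out_scale, and the o_proj.has_* flags come from the diff.

# Hypothetical stand-in illustrating the caching pattern from the diff above.
class _AttentionLike:

    def __init__(self, o_proj):
        self.o_proj = o_proj
        # Although FP8 MLA is used, the attention output is still BF16,
        # so no output scale is needed anywhere.
        self.out_scale = None
        self.has_quant_scale = False

    def create_weights(self):
        # Evaluated once here instead of on every create_output/_attn_impl call.
        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
                                or self.o_proj.has_fp8_block_scales
                                or self.o_proj.has_fp8_rowwise)

    def _attn_impl(self, q, k, v, attn_metadata, mha):
        # Hot path only reads the cached flag and the shared member.
        out_scale = self.o_proj.inv_input_scale if self.has_quant_scale else None
        return mha.forward(q, k, v, attn_metadata, out_scale=self.out_scale)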
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -959,7 +959,7 @@ def init_meta_tensor(t: torch.Tensor):

except Exception:
logger.info(
f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n"
f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n"
)
model = AutoModelForCausalLM.from_config(config)

3 changes: 3 additions & 0 deletions tensorrt_llm/executor/worker.py
@@ -485,6 +485,9 @@ def _deduce_max_tokens(request: GenerationRequest,
raise ValueError(
"`max_tokens` must be set when `default_max_tokens` cannot be deduced"
)
assert (
len(prompt_token_ids) <= executor_config.max_seq_len
), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
splited_prompt_len = int(len(prompt_token_ids) / cp_size)
default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
if default_max_tokens <= 0:
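The assertion added to _deduce_max_tokens above guards the token-budget subtraction that follows it. Below is a condensed, hypothetical sketch of that logic; the wrapper function and argument names are assumptions, and only the assert and the budget arithmetic mirror the diff.

def _deduce_default_max_tokens(prompt_token_ids, max_seq_len, cp_size=1,
                               query_token_len=0):
    # New guard from the diff: the prompt must already fit within max_seq_len,
    # otherwise the subtraction below would fail for a less obvious reason.
    assert len(prompt_token_ids) <= max_seq_len, (
        f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than "
        f"`max_seq_len` ({max_seq_len})")
    splited_prompt_len = int(len(prompt_token_ids) / cp_size)
    default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
    if default_max_tokens <= 0:
        raise ValueError(
            "`max_tokens` must be set when `default_max_tokens` cannot be deduced")
    return default_max_tokens

# Example: a 10-token prompt with max_seq_len=16 leaves 6 tokens to generate.
assert _deduce_default_max_tokens(list(range(10)), max_seq_len=16) == 6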
18 changes: 13 additions & 5 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1165,7 +1165,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_no_hopper
@skip_pre_hopper
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
[(False, False, False, False),
@@ -1189,6 +1189,8 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(
backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
)

if fp8kv:
@@ -1264,7 +1266,7 @@ def test_cute_dsl_fp8_block_scales(
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_device_not_contain(["H100"])
@skip_pre_hopper
@parametrize_with_ids("mtp_nextn", [0, 2])
def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1277,6 +1279,8 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
max_batch_size=512,
enable_padding=True,
),
moe_config=MoeConfig(
backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
)
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
kv_cache_config=kv_cache_config,
@@ -1287,7 +1291,7 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@skip_no_hopper
@skip_pre_hopper
@parametrize_with_ids("mtp_nextn", [0, 2])
@parametrize_with_ids("attention_dp", [False, True])
def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
@@ -1299,6 +1303,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
pytorch_config = dict(
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(enable_padding=True),
moe_config=MoeConfig(
backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
)

with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@@ -1312,7 +1318,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@skip_no_hopper
@skip_pre_hopper
@parametrize_with_ids("torch_compile", [False, True])
@parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
[(False, False, False, False),
@@ -1341,6 +1347,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(
backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
)

if fp8kv:
@@ -1427,7 +1435,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["H100", "H200"])
@skip_pre_hopper
def test_fp8_block_scales_4gpus_static_eplb(self):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)

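The test changes above repeatedly pick the MoE backend from the SM version. A small sketch of that selection as a helper, assuming MoeConfig and get_sm_version as used in the tests; the helper function itself is hypothetical.

def _moe_backend_for(sm_version: int) -> str:
    # Mirrors the repeated selection in the updated tests:
    # DEEPGEMM on SM 100+ (Blackwell), CUTLASS on older architectures.
    return "DEEPGEMM" if sm_version >= 100 else "CUTLASS"

# Usage as in the updated tests, where MoeConfig and get_sm_version come from
# the tensorrt_llm test utilities:
#   moe_config = MoeConfig(backend=_moe_backend_for(get_sm_version()))
assert _moe_backend_for(90) == "CUTLASS"    # Hopper
assert _moe_backend_for(100) == "DEEPGEMM"  # Blackwell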