Commit d6ebcf7

[TRTLLM-6994][feat] FP8 Context MLA integration (Cherry-pick #6059 from release/1.1.0rc2) (#7610)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 420f0fb commit d6ebcf7

File tree: 8 files changed, +56 −58 lines

cpp/tensorrt_llm/common/attentionOp.cpp
Lines changed: 1 addition & 7 deletions

@@ -2572,8 +2572,7 @@ int AttentionOp::initialize() noexcept
     if (mIsMLAEnabled)
     {
         TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "MLA(Deepseek v2) only support fmha");
-        TLLM_CHECK_WITH_INFO(
-            !mFP8ContextFMHA && !mDenseContextFMHA, "MLA(Deepseek v2) currently not support FP8 and dense fmha");
+        TLLM_CHECK_WITH_INFO(!mDenseContextFMHA, "MLA(Deepseek v2) currently not support dense fmha");
         TLLM_CHECK_WITH_INFO(
             mPagedKVCache && mUseKVCache && mRemovePadding, "MLA(Deepseek v2) only support paged kv cache");
         TLLM_CHECK_WITH_INFO(!mCrossAttention, "MLA(Deepseek v2) do not support cross attention right now");
@@ -2739,11 +2738,6 @@ int AttentionOp::initialize() noexcept
             qDataType = DATA_TYPE_E4M3;
             kvDataType = DATA_TYPE_E4M3;
         }
-        // When FP8 Context FMHA is enabled, the output data type needs to be E4M3.
-        if (mFP8ContextFMHA)
-        {
-            outputDataType = DATA_TYPE_E4M3;
-        }

         // Instantiate the mTllmGenFMHARunner used for MLA
         mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));

cpp/tensorrt_llm/common/attentionOp.h
Lines changed: 7 additions & 7 deletions

@@ -470,13 +470,13 @@ class AttentionOp
             (int8_t) mPositionEmbeddingType, mUseLognScaling, mRemovePadding, (int32_t) mMaskType,
             mBlockSparseParams.data(), mPagedKVCache, mTokensPerBlock, mKVCacheQuantMode.value(), mTpSize, mTpRank,
             mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
-            mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mChunkPrefillBufferBatchSize, mFP8AttenOutput,
-            mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree,
-            mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA,
-            mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
-            mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
-            mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
-            mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+            mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
+            mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
+            mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
+            mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
+            mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
+            mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
+            mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };

 private:

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h
Lines changed: 3 additions & 1 deletion

@@ -549,7 +549,9 @@ class TllmGenFmhaKernel
         int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;

         // Debug info.
-        std::string info = "qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
+        std::string info = "dtypeQ=" + std::to_string(static_cast<int>(mDtypeQ)) + ", dtypeKv="
+            + std::to_string(static_cast<int>(mDtypeKv)) + ", dtypeOut=" + std::to_string(static_cast<int>(mDtypeOut))
+            + ", sm=" + std::to_string(mSM) + ", qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
             + ", maskType=" + std::to_string(static_cast<int>(selectKernelParams.mMaskType))
             + ", kernelType=" + std::to_string(static_cast<int>(kernelType))
             + ", tileScheduler=" + std::to_string(static_cast<int>(selectKernelParams.mTileScheduler))

cpp/tensorrt_llm/thop/attentionOp.cpp
Lines changed: 9 additions & 8 deletions

@@ -538,38 +538,38 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<half, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<half, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<half, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<half, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
-            runner.reset(new Runner<half>());
+            runner = std::make_shared<Runner<half>>();
         }
     }
     else if (dtype == nvinfer1::DataType::kFLOAT)
     {
         TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
-        runner.reset(new Runner<float>());
+        runner = std::make_shared<Runner<float>>();
     }
 #ifdef ENABLE_BF16
     else if (dtype == nvinfer1::DataType::kBF16)
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
-            runner.reset(new Runner<__nv_bfloat16>());
+            runner = std::make_shared<Runner<__nv_bfloat16>>();
         }
     }
 #endif
@@ -637,7 +637,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
         static_cast<int>(layer_num)};

-    op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 120 || tensorrt_llm::common::getSMVersion() == 90)
+    op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
+            || tensorrt_llm::common::getSMVersion() == 120)
         && op->mKVCacheQuantMode.hasFp8KvCache();
     op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
     op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
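
A quick illustration of the gating introduced above: mFP8ContextMLA is now turned on for SM 90, SM 100, and SM 120 whenever the KV cache uses FP8 quantization. The sketch below is an illustrative Python rendering of that condition, not part of the commit; sm_version and has_fp8_kv_cache are hypothetical stand-ins for tensorrt_llm::common::getSMVersion() and KVCacheQuantMode::hasFp8KvCache().

    # Illustrative sketch only; not TensorRT-LLM code.
    FP8_CONTEXT_MLA_SMS = {90, 100, 120}  # SM versions accepted by the updated check

    def fp8_context_mla_enabled(sm_version: int, has_fp8_kv_cache: bool) -> bool:
        # Mirrors: (sm == 90 || sm == 100 || sm == 120) && hasFp8KvCache()
        return sm_version in FP8_CONTEXT_MLA_SMS and has_fp8_kv_cache

    assert fp8_context_mla_enabled(100, True)       # newly covered by this change
    assert not fp8_context_mla_enabled(89, True)    # unsupported SM
    assert not fp8_context_mla_enabled(120, False)  # FP8 KV cache required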

tensorrt_llm/_torch/modules/attention.py
Lines changed: 17 additions & 29 deletions

@@ -301,6 +301,12 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)

+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise
+                                or self.o_proj.has_w4a8_nvfp4_fp8)
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -320,12 +326,8 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype

         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise
-                               or self.o_proj.has_w4a8_nvfp4_fp8)
-            if has_quant_scale and (self.attn.has_fp8_kv_cache
-                                    or self.attn.has_fp4_kv_cache):
+            if self.has_quant_scale and (self.attn.has_fp8_kv_cache
+                                         or self.attn.has_fp4_kv_cache):
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -356,11 +358,7 @@ def _attn_impl(

         out_scale = None
         out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise
-                           or self.o_proj.has_w4a8_nvfp4_fp8)
-        if has_quant_scale:
+        if self.has_quant_scale:
             out_scale = self.o_proj.inv_input_scale
         if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
             out_scale_sf = self.o_proj.input_scale
@@ -858,6 +856,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)

+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -1061,17 +1062,14 @@ def forward_context_default(
                                    self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )

@@ -1126,9 +1124,6 @@ def forward_context_with_cached_kv(
         full_kv = None
         full_k_nope = None

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1138,7 +1133,7 @@
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )

@@ -1232,7 +1227,6 @@ def forward_context_with_chunked_prefill(
                 loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1243,7 +1237,7 @@
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 chunked_prefill_buffer_batch_size=attn_metadata.
@@ -1284,9 +1278,6 @@
                 num_contexts].sum().item(
                 )

-            # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-            out_scale = None  # Currently we use BF16 MHA for context phase
-
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
             temp_attn_output = self.mha.forward(
@@ -1296,7 +1287,7 @@
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
                 chunked_prefill_buffer_batch_size,
@@ -1394,16 +1385,13 @@
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
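
The attention.py change above computes the o_proj quantization flag once in create_weights() and fixes out_scale = None for the MLA paths (FP8 MLA still produces BF16 output), instead of rebuilding both on every forward call. A minimal sketch of that pattern, with hypothetical stand-in classes (FakeLinear, MiniAttention) in place of the real modules:

    # Illustrative sketch only; class names are placeholders, not TensorRT-LLM APIs.
    class FakeLinear:
        has_fp8_qdq = True
        has_nvfp4 = False
        has_fp8_block_scales = False
        has_fp8_rowwise = False
        has_w4a8_nvfp4_fp8 = False
        inv_input_scale = 0.5  # placeholder scale value

    class MiniAttention:
        def __init__(self):
            self.o_proj = FakeLinear()

        def create_weights(self):
            # Cache the flag once; it depends only on o_proj's quantization recipe,
            # which is fixed after weight creation.
            self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
                                    or self.o_proj.has_fp8_block_scales
                                    or self.o_proj.has_fp8_rowwise
                                    or self.o_proj.has_w4a8_nvfp4_fp8)
            # Output stays BF16 even with FP8 MLA, so no output scale is passed.
            self.out_scale = None

        def forward(self):
            # Hot path: reuse the cached flag instead of recomputing it per call.
            return self.o_proj.inv_input_scale if self.has_quant_scale else None

    attn = MiniAttention()
    attn.create_weights()
    print(attn.forward())  # 0.5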

tensorrt_llm/_torch/pyexecutor/model_engine.py
Lines changed: 1 addition & 1 deletion

@@ -1002,7 +1002,7 @@ def init_meta_tensor(t: torch.Tensor):

         except Exception:
             logger.info(
-                f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n"
+                f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n"
             )
             model = AutoModelForCausalLM.from_config(config)

tensorrt_llm/executor/worker.py
Lines changed: 5 additions & 0 deletions

@@ -512,6 +512,10 @@ def _deduce_max_tokens(request: GenerationRequest,
         else:
             # use max_tokens if can't deduce default_max_tokens
             return max_tokens
+        if executor_config is not None:
+            assert (
+                len(prompt_token_ids) <= executor_config.max_seq_len
+            ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
         splited_prompt_len = int(len(prompt_token_ids) / cp_size)
         default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
         if default_max_tokens <= 0:
@@ -892,6 +896,7 @@ def notify_proxy_threads_to_quit():
                 worker.submit(req)
             except RequestError as e:
                 logger.error(f"submit request failed: {e}")
+                logger.error(traceback.format_exc())
                 worker._await_response_helper.temp_error_responses.put(
                     ErrorResponse(req.id, e, req.id))
             else:
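
The worker.py change above guards the default max_tokens deduction with an assertion that the prompt fits within executor_config.max_seq_len. A simplified, self-contained sketch of that deduction (plain arguments replace the GenerationRequest/executor_config objects; only the assertion and the arithmetic follow the diff, the rest is illustrative):

    # Illustrative sketch only; the real logic lives in _deduce_max_tokens.
    def deduce_max_tokens(prompt_token_ids, max_tokens, max_seq_len,
                          cp_size=1, query_token_len=0, executor_max_seq_len=None):
        if executor_max_seq_len is not None:
            # Guard added by this commit: the prompt must fit in the executor's max_seq_len.
            assert len(prompt_token_ids) <= executor_max_seq_len, (
                f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than "
                f"`max_seq_len` ({executor_max_seq_len})")
        splited_prompt_len = int(len(prompt_token_ids) / cp_size)
        default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
        if default_max_tokens <= 0:
            # Nothing can be deduced; fall back to the caller-provided max_tokens.
            return max_tokens
        return default_max_tokens

    print(deduce_max_tokens(list(range(10)), max_tokens=64, max_seq_len=32))  # 22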

tests/integration/defs/accuracy/test_llm_api_pytorch.py
Lines changed: 13 additions & 5 deletions

@@ -1212,7 +1212,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1236,6 +1236,8 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1311,7 +1313,7 @@ def test_cute_dsl_fp8_block_scales(
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @pytest.mark.skip_device_not_contain(["H100"])
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1324,6 +1326,8 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
                 max_batch_size=512,
                 enable_padding=True,
             ),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                  kv_cache_config=kv_cache_config,
@@ -1334,7 +1338,7 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     @parametrize_with_ids("attention_dp", [False, True])
     def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
@@ -1346,6 +1350,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         pytorch_config = dict(
             disable_overlap_scheduler=False,
             cuda_graph_config=CudaGraphConfig(enable_padding=True),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@@ -1359,7 +1365,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1388,6 +1394,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1474,7 +1482,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @skip_pre_hopper
     def test_fp8_block_scales_4gpus_static_eplb(self):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
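
The test updates above widen the hardware skip markers to @skip_pre_hopper and pick the MoE backend by SM version. A small illustrative sketch of that selection (get_sm_version and MoeConfig are the names used in the tests; the helper below is hypothetical):

    # Illustrative sketch only.
    def pick_moe_backend(sm_version: int) -> str:
        # DEEPGEMM on SM >= 100, CUTLASS otherwise, matching the tests above.
        return "DEEPGEMM" if sm_version >= 100 else "CUTLASS"

    assert pick_moe_backend(90) == "CUTLASS"
    assert pick_moe_backend(100) == "DEEPGEMM"
    # Used as: moe_config=MoeConfig(backend=pick_moe_backend(get_sm_version()))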
