Commit 70af66a

FP8 Context MLA integration.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 1fbea49 commit 70af66a

8 files changed (+54, -58 lines)

cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 1 addition & 7 deletions
@@ -2569,8 +2569,7 @@ int AttentionOp::initialize() noexcept
     if (mIsMLAEnabled)
     {
         TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "MLA(Deepseek v2) only support fmha");
-        TLLM_CHECK_WITH_INFO(
-            !mFP8ContextFMHA && !mDenseContextFMHA, "MLA(Deepseek v2) currently not support FP8 and dense fmha");
+        TLLM_CHECK_WITH_INFO(!mDenseContextFMHA, "MLA(Deepseek v2) currently not support dense fmha");
         TLLM_CHECK_WITH_INFO(
             mPagedKVCache && mUseKVCache && mRemovePadding, "MLA(Deepseek v2) only support paged kv cache");
         TLLM_CHECK_WITH_INFO(!mCrossAttention, "MLA(Deepseek v2) do not support cross attention right now");
@@ -2736,11 +2735,6 @@ int AttentionOp::initialize() noexcept
             qDataType = DATA_TYPE_E4M3;
             kvDataType = DATA_TYPE_E4M3;
         }
-        // When FP8 Context FMHA is enabled, the output data type needs to be E4M3.
-        if (mFP8ContextFMHA)
-        {
-            outputDataType = DATA_TYPE_E4M3;
-        }

         // Instantiate the mTllmGenFMHARunner used for MLA
         mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));
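For orientation, a minimal Python sketch of the data-type selection this hunk leaves in place for the MLA TllmGen runner: with FP8 Context FMHA enabled, the Q and KV inputs move to E4M3, while the output stays in the model dtype (the branch that forced an E4M3 output is removed). The function name and dtype strings below are illustrative only, not part of the codebase.

def select_mla_fmha_dtypes(fp8_context_fmha: bool, model_dtype: str = "bf16"):
    # Assumed simplification of AttentionOp::initialize(): FP8 context FMHA
    # switches the Q/KV inputs to E4M3 ...
    q_dtype = kv_dtype = "e4m3" if fp8_context_fmha else model_dtype
    # ... but the MLA output dtype is no longer overridden to E4M3.
    out_dtype = model_dtype
    return q_dtype, kv_dtype, out_dtype

assert select_mla_fmha_dtypes(True) == ("e4m3", "e4m3", "bf16")
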

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 7 additions & 7 deletions
@@ -468,13 +468,13 @@ class AttentionOp
             (int8_t) mPositionEmbeddingType, mUseLognScaling, mRemovePadding, (int32_t) mMaskType,
             mBlockSparseParams.data(), mPagedKVCache, mTokensPerBlock, mKVCacheQuantMode.value(), mTpSize, mTpRank,
             mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
-            mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mChunkPrefillBufferBatchSize, mFP8AttenOutput,
-            mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree,
-            mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA,
-            mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
-            mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
-            mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
-            mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
+            mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8AttenOutput, mFP8ContextMLA, mFP8GenerationMLA,
+            mChunkPrefillBufferBatchSize, mDenseContextFMHA, mHasFullAttentionMask, mIsSpecDecodingEnabled,
+            mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength,
+            mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup,
+            mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank,
+            mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache,
+            mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };

 private:

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 3 additions & 1 deletion
@@ -549,7 +549,9 @@ class TllmGenFmhaKernel
         int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;

         // Debug info.
-        std::string info = "qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
+        std::string info = "dtypeQ=" + std::to_string(static_cast<int>(mDtypeQ)) + ", dtypeKv="
+            + std::to_string(static_cast<int>(mDtypeKv)) + ", dtypeOut=" + std::to_string(static_cast<int>(mDtypeOut))
+            + ", sm=" + std::to_string(mSM) + ", qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
             + ", maskType=" + std::to_string(static_cast<int>(selectKernelParams.mMaskType))
             + ", kernelType=" + std::to_string(static_cast<int>(kernelType))
             + ", tileScheduler=" + std::to_string(static_cast<int>(selectKernelParams.mTileScheduler))

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 9 additions & 8 deletions
@@ -530,38 +530,38 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<half, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<half, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<half, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<half, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
-            runner.reset(new Runner<half>());
+            runner = std::make_shared<Runner<half>>();
         }
     }
     else if (dtype == nvinfer1::DataType::kFLOAT)
     {
         TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
-        runner.reset(new Runner<float>());
+        runner = std::make_shared<Runner<float>>();
     }
 #ifdef ENABLE_BF16
     else if (dtype == nvinfer1::DataType::kBF16)
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
-            runner.reset(new Runner<__nv_bfloat16>());
+            runner = std::make_shared<Runner<__nv_bfloat16>>();
         }
     }
 #endif
@@ -629,7 +629,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
         static_cast<int>(layer_num)};

-    op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 120 || tensorrt_llm::common::getSMVersion() == 90)
+    op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
+            || tensorrt_llm::common::getSMVersion() == 120)
         && op->mKVCacheQuantMode.hasFp8KvCache();
     op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
     op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
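To make the new gating easy to read at a glance, here is a small standalone Python sketch of the same condition (the function name is illustrative; the real check is the C++ above): FP8 context MLA now turns on for SM 90, 100, or 120 whenever the KV cache is FP8.

def fp8_context_mla_enabled(sm_version: int, has_fp8_kv_cache: bool) -> bool:
    # Mirrors the updated condition: Hopper (SM 90) and Blackwell (SM 100, SM 120)
    # qualify, and the KV cache must be quantized to FP8.
    return sm_version in (90, 100, 120) and has_fp8_kv_cache

assert fp8_context_mla_enabled(100, True)       # SM 100 is newly covered by this commit
assert not fp8_context_mla_enabled(100, False)  # still requires an FP8 KV cache
assert not fp8_context_mla_enabled(80, True)    # pre-Hopper GPUs stay disabled
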

tensorrt_llm/_torch/modules/attention.py

Lines changed: 17 additions & 29 deletions
@@ -301,6 +301,12 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)

+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise
+                                or self.o_proj.has_w4a8_nvfp4_fp8)
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -320,12 +326,8 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype

         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise
-                               or self.o_proj.has_w4a8_nvfp4_fp8)
-            if has_quant_scale and (self.attn.has_fp8_kv_cache
-                                    or self.attn.has_fp4_kv_cache):
+            if self.has_quant_scale and (self.attn.has_fp8_kv_cache
+                                         or self.attn.has_fp4_kv_cache):
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -356,11 +358,7 @@ def _attn_impl(

         out_scale = None
         out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise
-                           or self.o_proj.has_w4a8_nvfp4_fp8)
-        if has_quant_scale:
+        if self.has_quant_scale:
             out_scale = self.o_proj.inv_input_scale
         if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
             out_scale_sf = self.o_proj.input_scale
@@ -858,6 +856,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)

+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -1061,17 +1062,14 @@ def forward_context_default(
                    self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )

@@ -1126,9 +1124,6 @@ def forward_context_with_cached_kv(
         full_kv = None
         full_k_nope = None

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1138,7 +1133,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )

@@ -1232,7 +1227,6 @@ def forward_context_with_chunked_prefill(
                 loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1243,7 +1237,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 chunked_prefill_buffer_batch_size=attn_metadata.
@@ -1284,9 +1278,6 @@ def forward_context_with_chunked_prefill(
                                                        num_contexts].sum().item(
                                                        )

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         temp_attn_output = self.mha.forward(
@@ -1296,7 +1287,7 @@ def forward_context_with_chunked_prefill(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             softmax_stats_tensor=self.temp_softmax_stats_tensor,
             chunked_prefill_buffer_batch_size=attn_metadata.runtime_features.
             chunked_prefill_buffer_batch_size,
@@ -1394,16 +1385,13 @@ def forward_generation(
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
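The net effect in this module is a small refactor: the o_proj quantization flags are computed once in create_weights() and cached as self.has_quant_scale, and the MLA forward paths share a single self.out_scale (kept at None because FP8 MLA still writes BF16 output). A minimal, self-contained Python sketch of that caching pattern, using hypothetical class and attribute names rather than the real modules:

class _AttentionSketch:
    """Illustrative only: compute quantization flags once, not on every forward."""

    def __init__(self, o_proj):
        self.o_proj = o_proj
        self.has_quant_scale = False

    def create_weights(self):
        # Evaluate the o_proj quantization flags once, after its weights exist.
        self.has_quant_scale = any(
            getattr(self.o_proj, name, False)
            for name in ("has_fp8_qdq", "has_nvfp4", "has_fp8_block_scales",
                         "has_fp8_rowwise", "has_w4a8_nvfp4_fp8"))

    def attention_kwargs(self):
        # Hot path only reads the cached flag; the MLA module instead pins
        # out_scale to None because its FP8 kernels still return BF16 output.
        out_scale = self.o_proj.inv_input_scale if self.has_quant_scale else None
        return {"out_scale": out_scale}
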

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 1 addition & 1 deletion
@@ -1008,7 +1008,7 @@ def init_meta_tensor(t: torch.Tensor):

         except Exception:
             logger.info(
-                f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n"
+                f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n"
             )
             model = AutoModelForCausalLM.from_config(config)

tensorrt_llm/executor/worker.py

Lines changed: 3 additions & 0 deletions
@@ -512,6 +512,9 @@ def _deduce_max_tokens(request: GenerationRequest,
         else:
             # use max_tokens if can't deduce default_max_tokens
             return max_tokens
+        assert (
+            len(prompt_token_ids) <= executor_config.max_seq_len
+        ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
         splited_prompt_len = int(len(prompt_token_ids) / cp_size)
         default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
         if default_max_tokens <= 0:
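Stripped of the executor plumbing, the deduction this hunk guards works roughly as below; the helper is a simplified, hypothetical stand-in for _deduce_max_tokens, added only to show where the new assertion catches prompts that already exceed the sequence-length budget.

def deduce_default_max_tokens(prompt_len: int, max_seq_len: int,
                              cp_size: int = 1, query_token_len: int = 0) -> int:
    # New guard from this commit: the prompt itself must fit in max_seq_len.
    assert prompt_len <= max_seq_len, (
        f"`prompt_token_ids` length ({prompt_len}) is greater than "
        f"`max_seq_len` ({max_seq_len})")
    split_prompt_len = prompt_len // cp_size
    return max_seq_len - split_prompt_len - query_token_len

# A 100-token prompt with max_seq_len=128 leaves a budget of 28 new tokens.
assert deduce_default_max_tokens(100, 128) == 28
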

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 13 additions & 5 deletions
@@ -1214,7 +1214,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1238,6 +1238,8 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1313,7 +1315,7 @@ def test_cute_dsl_fp8_block_scales(
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @pytest.mark.skip_device_not_contain(["H100"])
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1326,6 +1328,8 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
                 max_batch_size=512,
                 enable_padding=True,
             ),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                  kv_cache_config=kv_cache_config,
@@ -1336,7 +1340,7 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     @parametrize_with_ids("attention_dp", [False, True])
     def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
@@ -1348,6 +1352,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         pytorch_config = dict(
             disable_overlap_scheduler=False,
             cuda_graph_config=CudaGraphConfig(enable_padding=True),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@@ -1361,7 +1367,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1390,6 +1396,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1476,7 +1484,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @skip_pre_hopper
     def test_fp8_block_scales_4gpus_static_eplb(self):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
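The recurring moe_config addition in these tests picks the MoE backend from the GPU architecture. A short, illustrative Python sketch of that selection (the helper name below is hypothetical; the tests call get_sm_version() directly):

def select_moe_backend(sm_version: int) -> str:
    # Blackwell-class GPUs (SM >= 100) use the DEEPGEMM backend;
    # Hopper and earlier fall back to CUTLASS.
    return "DEEPGEMM" if sm_version >= 100 else "CUTLASS"

assert select_moe_backend(90) == "CUTLASS"    # H100/H200
assert select_moe_backend(100) == "DEEPGEMM"  # B200-class
assert select_moe_backend(120) == "DEEPGEMM"  # SM 120 Blackwell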
