
Commit 7961672

FP8 Context MLA integration.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent: 219e955

8 files changed: +56 additions, -59 deletions


cpp/tensorrt_llm/common/attentionOp.cpp (1 addition & 7 deletions)

@@ -2570,8 +2570,7 @@ int AttentionOp::initialize() noexcept
     if (mIsMLAEnabled)
     {
         TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "MLA(Deepseek v2) only support fmha");
-        TLLM_CHECK_WITH_INFO(
-            !mFP8ContextFMHA && !mDenseContextFMHA, "MLA(Deepseek v2) currently not support FP8 and dense fmha");
+        TLLM_CHECK_WITH_INFO(!mDenseContextFMHA, "MLA(Deepseek v2) currently not support dense fmha");
         TLLM_CHECK_WITH_INFO(
             mPagedKVCache && mUseKVCache && mRemovePadding, "MLA(Deepseek v2) only support paged kv cache");
         TLLM_CHECK_WITH_INFO(!mCrossAttention, "MLA(Deepseek v2) do not support cross attention right now");
@@ -2736,11 +2735,6 @@ int AttentionOp::initialize() noexcept
             qDataType = DATA_TYPE_E4M3;
             kvDataType = DATA_TYPE_E4M3;
         }
-        // When FP8 Context FMHA is enabled, the output data type needs to be E4M3.
-        if (mFP8ContextFMHA)
-        {
-            outputDataType = DATA_TYPE_E4M3;
-        }
 
         // Instantiate the mTllmGenFMHARunner used for MLA
         mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));

cpp/tensorrt_llm/common/attentionOp.h (7 additions & 7 deletions)

@@ -466,13 +466,13 @@ class AttentionOp
         (int8_t) mPositionEmbeddingType, mUseLognScaling, mRemovePadding, (int32_t) mMaskType,
         mBlockSparseParams.data(), mPagedKVCache, mTokensPerBlock, mKVCacheQuantMode.value(), mTpSize, mTpRank,
         mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
-        mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
-        mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable,
-        mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(),
-        mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank,
-        mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode,
-        mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores,
-        mAttentionChunkSize.value_or(-1));
+        mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8ContextMLA, mDenseContextFMHA,
+        mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree,
+        mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA,
+        mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
+        mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
+        mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
+        mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };
 
 private:

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h (3 additions & 1 deletion)

@@ -541,7 +541,9 @@ class TllmGenFmhaKernel
         int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;
 
         // Debug info.
-        std::string info = "qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
+        std::string info = "dtypeQ=" + std::to_string(static_cast<int>(mDtypeQ)) + ", dtypeKv="
+            + std::to_string(static_cast<int>(mDtypeKv)) + ", dtypeOut=" + std::to_string(static_cast<int>(mDtypeOut))
+            + ", sm=" + std::to_string(mSM) + ", qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
             + ", maskType=" + std::to_string(static_cast<int>(selectKernelParams.mMaskType))
             + ", kernelType=" + std::to_string(static_cast<int>(kernelType))
             + ", tileScheduler=" + std::to_string(static_cast<int>(selectKernelParams.mTileScheduler))

cpp/tensorrt_llm/thop/attentionOp.cpp (11 additions & 9 deletions)

@@ -529,38 +529,38 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<half, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<half, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<half, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<half, __nv_fp4_e2m1>>();
         }
         else
         {
            TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
-            runner.reset(new Runner<half>());
+            runner = std::make_shared<Runner<half>>();
         }
     }
     else if (dtype == nvinfer1::DataType::kFLOAT)
     {
         TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
-        runner.reset(new Runner<float>());
+        runner = std::make_shared<Runner<float>>();
     }
 #ifdef ENABLE_BF16
     else if (dtype == nvinfer1::DataType::kBF16)
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
-            runner.reset(new Runner<__nv_bfloat16>());
+            runner = std::make_shared<Runner<__nv_bfloat16>>();
         }
     }
 #endif
@@ -578,13 +578,13 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     auto op = std::make_shared<AttentionOp>();
     op->mType = dtype;
     op->mFMHAForceFP32Acc = dtype == nvinfer1::DataType::kBF16;
+    op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
     op->mFP8ContextFMHA = is_fp8_out || is_fp4_out;
     op->mLayerIdx = layer_idx;
     op->mNumHeads = num_heads;
     op->mNumKVHeads = num_kv_heads;
     op->mHeadSize = head_size;
     op->mMaskType = static_cast<tensorrt_llm::kernels::AttentionMaskType>(int32_t(mask_type));
-    op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
     op->mUseKVCache = use_kv_cache;
     op->mPagedKVCache = op->mPagedKVCache && use_kv_cache; // update mPagedKVCache based on use_kv_cache
     op->mTokensPerBlock = tokens_per_block.value_or(0);
@@ -627,7 +627,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
         static_cast<int>(layer_num)};
 
-    op->mFP8ContextMLA = tensorrt_llm::common::getSMVersion() == 120 && op->mKVCacheQuantMode.hasFp8KvCache();
+    op->mFP8ContextMLA
+        = (tensorrt_llm::common::getSMVersion() == 100 || tensorrt_llm::common::getSMVersion() == 120)
+        && op->mKVCacheQuantMode.hasFp8KvCache();
     op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
     op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
     // only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64
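
The last hunk above gates FP8 context MLA on both the GPU architecture and the KV-cache quantization mode. As a rough illustration only (not a TensorRT-LLM API), the same condition can be expressed in Python with torch.cuda.get_device_capability() standing in for getSMVersion(); the helper name is hypothetical:

import torch

def fp8_context_mla_enabled(has_fp8_kv_cache: bool) -> bool:
    # Mirrors the C++ gating above: SM 100 or SM 120, and an FP8 KV cache.
    major, minor = torch.cuda.get_device_capability()
    sm_version = 10 * major + minor
    return sm_version in (100, 120) and has_fp8_kv_cache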

tensorrt_llm/_torch/modules/attention.py (17 additions & 29 deletions)

@@ -295,6 +295,12 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)
 
+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise
+                                or self.o_proj.has_w4a8_nvfp4_fp8)
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -314,12 +320,8 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype
 
         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise
-                               or self.o_proj.has_w4a8_nvfp4_fp8)
-            if has_quant_scale and (self.attn.has_fp8_kv_cache
-                                    or self.attn.has_fp4_kv_cache):
+            if self.has_quant_scale and (self.attn.has_fp8_kv_cache
+                                         or self.attn.has_fp4_kv_cache):
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -350,11 +352,7 @@ def _attn_impl(
 
         out_scale = None
         out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise
-                           or self.o_proj.has_w4a8_nvfp4_fp8)
-        if has_quant_scale:
+        if self.has_quant_scale:
             out_scale = self.o_proj.inv_input_scale
         if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
             out_scale_sf = self.o_proj.input_scale
@@ -847,6 +845,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)
 
+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -1050,17 +1051,14 @@ def forward_context_default(
                                  self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1115,9 +1113,6 @@ def forward_context_with_cached_kv(
         full_kv = None
         full_k_nope = None
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1127,7 +1122,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )
 
@@ -1217,7 +1212,6 @@ def forward_context_with_chunked_prefill(
                 loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens
 
-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1228,7 +1222,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 output=temp_attn_output,
@@ -1267,9 +1261,6 @@ def forward_context_with_chunked_prefill(
                                                     num_contexts].sum().item(
                                                     )
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         temp_attn_output = self.mha.forward(
@@ -1279,7 +1270,7 @@ def forward_context_with_chunked_prefill(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             softmax_stats_tensor=self.temp_softmax_stats_tensor,
             output=temp_attn_output,
         )
@@ -1375,16 +1366,13 @@ def forward_generation(
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])
 
-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
            None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
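
With the quant-scale flags now cached once in create_weights(), the output-dtype choice in create_output() reduces to a single condition. A minimal standalone sketch, reusing the attribute names from the diff (the helper function itself is illustrative, not part of the module):

import torch

def choose_out_dtype(q: torch.Tensor, has_quant_scale: bool,
                     has_fp8_kv_cache: bool, has_fp4_kv_cache: bool) -> torch.dtype:
    # FP8 output is only requested when o_proj carries a quant scale and the
    # KV cache itself is FP8/FP4; otherwise keep the query dtype (typically BF16).
    if has_quant_scale and (has_fp8_kv_cache or has_fp4_kv_cache):
        return torch.float8_e4m3fn
    return q.dtype

For the MLA path, self.out_scale is set to None once in create_weights(), so the attention output stays in BF16 even when FP8 MLA kernels are used, as the comment added in the diff notes.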

tensorrt_llm/_torch/pyexecutor/model_engine.py (1 addition & 1 deletion)

@@ -1008,7 +1008,7 @@ def init_meta_tensor(t: torch.Tensor):
 
         except Exception:
             logger.info(
-                f"Fallback to regular model init: {traceback.format_exc(limit=1)}\n"
+                f"Fallback to regular model init: {traceback.format_exc(limit=10)}\n"
             )
             model = AutoModelForCausalLM.from_config(config)

tensorrt_llm/executor/worker.py (3 additions & 0 deletions)

@@ -512,6 +512,9 @@ def _deduce_max_tokens(request: GenerationRequest,
         else:
             # use max_tokens if can't deduce default_max_tokens
             return max_tokens
+        assert (
+            len(prompt_token_ids) <= executor_config.max_seq_len
+        ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
         splited_prompt_len = int(len(prompt_token_ids) / cp_size)
         default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
         if default_max_tokens <= 0:
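
The new assertion protects the default max-token deduction from prompts that already exceed the engine's sequence budget. A minimal sketch of that calculation with the same variable names as the diff; the standalone function and its signature are illustrative, not the worker's actual API:

def deduce_default_max_tokens(prompt_token_ids, max_seq_len: int,
                              cp_size: int = 1, query_token_len: int = 0) -> int:
    # Fail fast if the prompt alone is longer than max_seq_len.
    assert len(prompt_token_ids) <= max_seq_len, (
        f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than "
        f"`max_seq_len` ({max_seq_len})")
    # The prompt is split across cp_size context-parallel ranks before budgeting.
    splited_prompt_len = int(len(prompt_token_ids) / cp_size)
    return max_seq_len - splited_prompt_len - query_token_len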

tests/integration/defs/accuracy/test_llm_api_pytorch.py (13 additions & 5 deletions)

@@ -1212,7 +1212,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1236,6 +1236,8 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
 
         if fp8kv:
@@ -1311,7 +1313,7 @@ def test_cute_dsl_fp8_block_scales(
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip_device_not_contain(["H100"])
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1324,6 +1326,8 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
                 max_batch_size=512,
                 enable_padding=True,
             ),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                  kv_cache_config=kv_cache_config,
@@ -1334,7 +1338,7 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         task.evaluate(llm)
 
     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     @parametrize_with_ids("attention_dp", [False, True])
     def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
@@ -1346,6 +1350,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         pytorch_config = dict(
             disable_overlap_scheduler=False,
             cuda_graph_config=CudaGraphConfig(enable_padding=True),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
 
         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@@ -1359,7 +1365,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         task.evaluate(llm)
 
     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1388,6 +1394,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
 
         if fp8kv:
@@ -1474,7 +1482,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
         task.evaluate(llm)
 
     @pytest.mark.skip_less_device(4)
-    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @skip_pre_hopper
     def test_fp8_block_scales_4gpus_static_eplb(self):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
 
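
The updated tests also pin the MoE backend by GPU generation instead of relying on the default. A hypothetical helper showing the same selection, with torch.cuda.get_device_capability() standing in for the get_sm_version() test utility:

import torch

def select_moe_backend() -> str:
    # Same condition as the MoeConfig(backend=...) arguments in the tests:
    # DEEPGEMM on SM >= 100, CUTLASS otherwise.
    major, minor = torch.cuda.get_device_capability()
    return "DEEPGEMM" if 10 * major + minor >= 100 else "CUTLASS"

This pairs with the decorator change from @skip_no_hopper to @skip_pre_hopper, which widens these tests from Hopper-only to Hopper and newer architectures.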
