Commit 2f1bd9c

FP8 Context MLA integration.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 6a5806b commit 2f1bd9c

File tree

7 files changed: +61 additions, -63 deletions


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 1 addition & 7 deletions
@@ -2521,8 +2521,7 @@ int AttentionOp::initialize() noexcept
     if (mIsMLAEnabled)
     {
         TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "MLA(Deepseek v2) only support fmha");
-        TLLM_CHECK_WITH_INFO(
-            !mFP8ContextFMHA && !mDenseContextFMHA, "MLA(Deepseek v2) currently not support FP8 and dense fmha");
+        TLLM_CHECK_WITH_INFO(!mDenseContextFMHA, "MLA(Deepseek v2) currently not support dense fmha");
         TLLM_CHECK_WITH_INFO(
             mPagedKVCache && mUseKVCache && mRemovePadding, "MLA(Deepseek v2) only support paged kv cache");
         TLLM_CHECK_WITH_INFO(!mCrossAttention, "MLA(Deepseek v2) do not support cross attention right now");
@@ -2684,11 +2683,6 @@ int AttentionOp::initialize() noexcept
             qDataType = DATA_TYPE_E4M3;
             kvDataType = DATA_TYPE_E4M3;
         }
-        // When FP8 Context FMHA is enabled, the output data type needs to be E4M3.
-        if (mFP8ContextFMHA)
-        {
-            outputDataType = DATA_TYPE_E4M3;
-        }

         // Instantiate the mTllmGenFMHARunner used for MLA
         mTllmGenFMHARunner.reset(new TllmGenFmhaRunner(qDataType, kvDataType, outputDataType));
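
The relaxed checks above mean MLA (Deepseek v2) now accepts FP8 context FMHA; only dense FMHA remains rejected, and the MLA runner's output dtype is no longer forced to E4M3. A minimal Python sketch of the relaxed validation, using illustrative flag names that mirror the C++ members rather than any real binding:

def check_mla_flags(enable_context_fmha: bool, dense_context_fmha: bool,
                    fp8_context_fmha: bool) -> None:
    # FMHA is still mandatory for MLA (Deepseek v2).
    assert enable_context_fmha, "MLA(Deepseek v2) only support fmha"
    # Dense FMHA is still unsupported; FP8 context FMHA is now accepted.
    assert not dense_context_fmha, "MLA(Deepseek v2) currently not support dense fmha"
    _ = fp8_context_fmha  # no longer rejected for MLA


check_mla_flags(enable_context_fmha=True, dense_context_fmha=False,
                fp8_context_fmha=True)  # passes after this change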

cpp/tensorrt_llm/common/attentionOp.h

Lines changed: 7 additions & 7 deletions
@@ -450,13 +450,13 @@ class AttentionOp
         (int8_t) mPositionEmbeddingType, mUseLognScaling, mRemovePadding, (int32_t) mMaskType,
         mBlockSparseParams.data(), mPagedKVCache, mTokensPerBlock, mKVCacheQuantMode.value(), mTpSize, mTpRank,
         mUnfuseQkvGemm, (int32_t) mType, mMaxContextLength, mQKVBiasEnabled, mCrossAttention, mMaxDistance,
-        mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mDenseContextFMHA, mHasFullAttentionMask,
-        mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree, mSpecDecodingIsGenerationLengthVariable,
-        mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA, mUseGenFlashMLA, mMLAParams.data(),
-        mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads, mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank,
-        mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast, mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode,
-        mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant, mNbMultiBlockSemaphores,
-        mAttentionChunkSize.value_or(-1));
+        mPosShiftEnabled, mPagedContextFMHA, mFP8ContextFMHA, mFP8ContextMLA, mDenseContextFMHA,
+        mHasFullAttentionMask, mIsSpecDecodingEnabled, mUseSpecDecoding, mIsSpecDecTree,
+        mSpecDecodingIsGenerationLengthVariable, mSpecDecodingMaxGenerationLength, mIsMLAEnabled, mIsGenerationMLA,
+        mUseGenFlashMLA, mMLAParams.data(), mCpSize, mCpRank, mCpGroup, mNumAttnHeads, mNumAttnKVHeads,
+        mNumKVHeadsOrigin, mAttnTpSize, mAttnTpRank, mAttnCpSize, mAttnCpRank, mUlyssesMQABroadcast,
+        mEnableContextFMHA, mFMHAForceFP32Acc, mMultiBlockMode, mEnableXQA, mUseKVCache, mSkipAttn, mFuseFp4Quant,
+        mNbMultiBlockSemaphores, mAttentionChunkSize.value_or(-1));
     };

 private:

cpp/tensorrt_llm/kernels/trtllmGenKernels/fmha/fmhaKernels.h

Lines changed: 3 additions & 1 deletion
@@ -526,7 +526,9 @@ class TllmGenFmhaKernel
         int numTokensPerPage = (!isPagedKv(params.mQkvLayout)) ? 0 : params.mNumTokensPerPage;

         // Debug info.
-        std::string info = "qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
+        std::string info = "dtypeQ=" + std::to_string(static_cast<int>(mDtypeQ)) + ", dtypeKv="
+            + std::to_string(static_cast<int>(mDtypeKv)) + ", dtypeOut=" + std::to_string(static_cast<int>(mDtypeOut))
+            + ", sm=" + std::to_string(mSM) + ", qkvLayout=" + std::to_string(static_cast<int>(params.mQkvLayout))
             + ", maskType=" + std::to_string(static_cast<int>(selectKernelParams.mMaskType))
             + ", kernelType=" + std::to_string(static_cast<int>(kernelType))
             + ", tileScheduler=" + std::to_string(static_cast<int>(selectKernelParams.mTileScheduler))

cpp/tensorrt_llm/thop/attentionOp.cpp

Lines changed: 11 additions & 9 deletions
@@ -489,38 +489,38 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<half, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<half, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<half, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<half, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat16);
-            runner.reset(new Runner<half>());
+            runner = std::make_shared<Runner<half>>();
         }
     }
     else if (dtype == nvinfer1::DataType::kFLOAT)
     {
         TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kFloat32);
-        runner.reset(new Runner<float>());
+        runner = std::make_shared<Runner<float>>();
     }
 #ifdef ENABLE_BF16
     else if (dtype == nvinfer1::DataType::kBF16)
     {
         if (is_fp8_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp8_e4m3>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp8_e4m3>>();
         }
         else if (is_fp4_out)
         {
-            runner.reset(new Runner<__nv_bfloat16, __nv_fp4_e2m1>());
+            runner = std::make_shared<Runner<__nv_bfloat16, __nv_fp4_e2m1>>();
         }
         else
         {
             TLLM_CHECK(!out_dtype.has_value() || out_dtype.value() == torch::kBFloat16);
-            runner.reset(new Runner<__nv_bfloat16>());
+            runner = std::make_shared<Runner<__nv_bfloat16>>();
         }
     }
 #endif
@@ -538,13 +538,13 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
     auto op = std::make_shared<AttentionOp>();
     op->mType = dtype;
     op->mFMHAForceFP32Acc = dtype == nvinfer1::DataType::kBF16;
+    op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
     op->mFP8ContextFMHA = is_fp8_out || is_fp4_out;
     op->mLayerIdx = layer_idx;
     op->mNumHeads = num_heads;
     op->mNumKVHeads = num_kv_heads;
     op->mHeadSize = head_size;
     op->mMaskType = static_cast<tensorrt_llm::kernels::AttentionMaskType>(int32_t(mask_type));
-    op->mKVCacheQuantMode = tensorrt_llm::common::QuantMode(uint32_t(quant_mode));
     op->mUseKVCache = use_kv_cache;
     op->mPagedKVCache = op->mPagedKVCache && use_kv_cache; // update mPagedKVCache based on use_kv_cache
     op->mTokensPerBlock = tokens_per_block.value_or(0);
@@ -587,7 +587,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
         static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
         static_cast<int>(layer_num)};

-    op->mFP8ContextMLA = tensorrt_llm::common::getSMVersion() == 120 && op->mKVCacheQuantMode.hasFp8KvCache();
+    op->mFP8ContextMLA
+        = (tensorrt_llm::common::getSMVersion() == 100 || tensorrt_llm::common::getSMVersion() == 120)
+        && op->mKVCacheQuantMode.hasFp8KvCache();
     op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
     op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
     // only enable flash mla on sm90 and head_size == 576 and tokens_per_block == 64
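
With this change, the op enables FP8 context MLA on SM 100 and SM 120 devices whenever the KV cache is FP8-quantized. A small Python restatement of the gating condition; the helper below is a hypothetical stand-in, not an actual binding:

def fp8_context_mla_enabled(sm_version: int, has_fp8_kv_cache: bool) -> bool:
    # Mirrors the C++ condition: only SM 100 or SM 120, and only with an FP8 KV cache.
    return sm_version in (100, 120) and has_fp8_kv_cache


assert fp8_context_mla_enabled(100, True)      # newly enabled on SM 100
assert fp8_context_mla_enabled(120, True)      # already enabled on SM 120
assert not fp8_context_mla_enabled(90, True)   # Hopper keeps the BF16 context path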

tensorrt_llm/_torch/modules/attention.py

Lines changed: 23 additions & 34 deletions
@@ -285,6 +285,7 @@ def __init__(

         self.support_fused_qkv = self.attn.support_fused_qkv()
         self.support_nvfp4_output = self.attn.support_nvfp4_output()
+        self.enable_attn_nvfp4_output = True

         if not config.skip_create_weights_in_init:
             self.create_weights()
@@ -294,6 +295,17 @@ def create_weights(self):
         # which could be modified after __init__
         self.attn.update_quant_config(self.quant_config)

+        self.out_scale = None
+        self.out_scale_sf = None
+        self.o_proj.create_weights()
+        self.has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
+                                or self.o_proj.has_fp8_block_scales
+                                or self.o_proj.has_fp8_rowwise)
+        if self.has_quant_scale:
+            self.out_scale = self.o_proj.inv_input_scale.data
+        if self.o_proj.has_nvfp4 and self.support_nvfp4_output and self.enable_attn_nvfp4_output:
+            self.out_scale_sf = self.o_proj.input_scale.data
+
     def split_qkv(self, q, k=None, v=None):
         if k is None and v is None:
             q, k, v = q.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -313,10 +325,7 @@ def create_output(self, q: torch.Tensor):
         out_dtype = q.dtype

         if self.attn_backend == "TRTLLM":
-            has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                               or self.o_proj.has_fp8_block_scales
-                               or self.o_proj.has_fp8_rowwise)
-            if has_quant_scale and self.attn.has_fp8_kv_cache:
+            if self.has_quant_scale and self.attn.has_fp8_kv_cache:
                 out_dtype = torch.float8_e4m3fn
         output = q.new_empty([num_tokens, hidden_size], dtype=out_dtype)
         return output
@@ -351,16 +360,6 @@ def _attn_impl(
             assert v.shape[0] == padded_num_tokens
             v = v[:num_tokens, :]

-        out_scale = None
-        out_scale_sf = None
-        has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
-                           or self.o_proj.has_fp8_block_scales
-                           or self.o_proj.has_fp8_rowwise)
-        if has_quant_scale:
-            out_scale = self.o_proj.inv_input_scale
-        if self.o_proj.has_nvfp4 and self.support_nvfp4_output and enable_attn_nvfp4_output:
-            out_scale_sf = self.o_proj.input_scale
-
         mrope_config = None
         if mrope_rotary_cos_sin is not None or mrope_position_deltas is not None:
             mrope_config = dict()
@@ -374,8 +373,8 @@ def _attn_impl(
             k,
             v,
             attn_metadata,
-            out_scale=out_scale,
-            out_scale_sf=out_scale_sf,
+            out_scale=self.out_scale,
+            out_scale_sf=self.out_scale_sf,
             attention_mask=attention_mask,
             mrope_config=mrope_config,
             attention_window_size=attention_window_size,
@@ -840,6 +839,9 @@ def create_weights(self):
         self.mha.update_quant_config(self.quant_config)
         self.mqa.update_quant_config(self.quant_config)

+        # Although we use FP8 MLA for context/generation phase, the output is still in BF16
+        self.out_scale = None
+
         # k_b_proj_trans's dtype must be consistent with self.kv_b_proj,
         # which can be modified after __init__
         has_fp8_block_scales = (
@@ -1045,17 +1047,14 @@ def forward_context_default(
                    self.qk_rope_head_dim)
         k = k.view(-1, self.num_heads * self.qk_head_dim)

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         attn_output = self.mha.forward(
             q,
             k,
             v,
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=latent_cache,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )

@@ -1110,9 +1109,6 @@ def forward_context_with_cached_kv(
         full_kv = None
         full_k_nope = None

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Currently we use BF16 MHA for context phase
-
         # latent_cache must be None to differentiate from normal context phase,
         # so that we can skip applying RoPE and appending KV cache inside attention op
         attn_output = self.mha.forward(
@@ -1122,7 +1118,7 @@ def forward_context_with_cached_kv(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             output=output,
         )

@@ -1212,7 +1208,6 @@ def forward_context_with_chunked_prefill(
                 loop_idx]
             attn_metadata.host_total_kv_lens[0] = total_ctx_chunked_tokens

-            out_scale = None
             # do not apply mask for attention within loop
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
@@ -1223,7 +1218,7 @@ def forward_context_with_chunked_prefill(
                 attn_metadata,
                 attention_input_type=AttentionInputType.context_only,
                 latent_cache=None,
-                out_scale=out_scale,
+                out_scale=self.out_scale,
                 attention_mask=PredefinedAttentionMask.FULL,
                 softmax_stats_tensor=self.temp_softmax_stats_tensor,
                 output=temp_attn_output,
@@ -1262,9 +1257,6 @@ def forward_context_with_chunked_prefill(
                 num_contexts].sum().item(
                 )

-            # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-            out_scale = None  # Currently we use BF16 MHA for context phase
-
             # latent_cache must be None to differentiate from normal context phase,
             # so that we can skip applying RoPE and appending KV cache inside attention op
             temp_attn_output = self.mha.forward(
@@ -1274,7 +1266,7 @@ def forward_context_with_chunked_prefill(
             attn_metadata,
             attention_input_type=AttentionInputType.context_only,
             latent_cache=None,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             softmax_stats_tensor=self.temp_softmax_stats_tensor,
             output=temp_attn_output,
         )
@@ -1370,16 +1362,13 @@ def forward_generation(
             self.num_heads * (self.kv_lora_rank + self.qk_rope_head_dim)
         ])

-        # out_scale = getattr(self.o_proj, "inv_input_scale", None)
-        out_scale = None  # Although we use FP8 MLA for generation phase, the output is still in BF16
-
         attn_out_latent = self.mqa.forward(
             fused_q,
             None,
             None,
             attn_metadata,
             attention_input_type=AttentionInputType.generation_only,
-            out_scale=out_scale,
+            out_scale=self.out_scale,
             latent_cache=latent_cache,  # kvcache and k_pe
             q_pe=q_pe,  # used by `invokeMLARopeGeneration`
         )
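
The Python-side refactor hoists the output-scale computation out of the per-call paths (`_attn_impl`, `forward_context_*`, `forward_generation`) and into `create_weights`, so `self.out_scale` and `self.out_scale_sf` are computed once and reused on every forward; for MLA the cached `out_scale` stays `None` because the attention output remains BF16. A compact sketch of the pattern with stand-in names (not the real module API):

class CachedScaleAttention:
    def __init__(self, inv_input_scale=None, has_quant_scale=False):
        self.has_quant_scale = has_quant_scale
        self._inv_input_scale = inv_input_scale
        self.out_scale = None
        self.out_scale_sf = None

    def create_weights(self):
        # Derive the output scale from the projection once, after its weights
        # exist, instead of recomputing it inside every attention call.
        if self.has_quant_scale:
            self.out_scale = self._inv_input_scale

    def forward(self, q):
        # Context and generation paths simply reuse the cached scales.
        return {"q": q, "out_scale": self.out_scale,
                "out_scale_sf": self.out_scale_sf}


attn = CachedScaleAttention(inv_input_scale=0.5, has_quant_scale=True)
attn.create_weights()
assert attn.forward("tokens")["out_scale"] == 0.5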

tensorrt_llm/executor/worker.py

Lines changed: 3 additions & 0 deletions
@@ -485,6 +485,9 @@ def _deduce_max_tokens(request: GenerationRequest,
                 raise ValueError(
                     "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                 )
+            assert (
+                len(prompt_token_ids) <= executor_config.max_seq_len
+            ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
             splited_prompt_len = int(len(prompt_token_ids) / cp_size)
             default_max_tokens = executor_config.max_seq_len - splited_prompt_len - query_token_len
             if default_max_tokens <= 0:
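
The new assert surfaces an oversized prompt immediately, with an explicit message, instead of letting `default_max_tokens` silently go non-positive further down. An illustrative, simplified version of the deduction (the real function also handles context-parallel splitting and query tokens; names follow the diff):

def deduce_default_max_tokens(prompt_token_ids, max_seq_len, cp_size=1,
                              query_token_len=0):
    # Fail fast with a clear message if the prompt alone exceeds max_seq_len.
    assert len(prompt_token_ids) <= max_seq_len, (
        f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than "
        f"`max_seq_len` ({max_seq_len})")
    splited_prompt_len = int(len(prompt_token_ids) / cp_size)
    return max_seq_len - splited_prompt_len - query_token_len


# Example: a 6-token prompt with max_seq_len=8 leaves a budget of 2 new tokens.
assert deduce_default_max_tokens([1, 2, 3, 4, 5, 6], max_seq_len=8) == 2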

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 13 additions & 5 deletions
@@ -1165,7 +1165,7 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1189,6 +1189,8 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1264,7 +1266,7 @@ def test_cute_dsl_fp8_block_scales(
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @pytest.mark.skip_device_not_contain(["H100"])
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
@@ -1277,6 +1279,8 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
                 max_batch_size=512,
                 enable_padding=True,
             ),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )
         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                  kv_cache_config=kv_cache_config,
@@ -1287,7 +1291,7 @@ def test_fp8_block_scales_cuda_graph_padding(self, mtp_nextn):
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("mtp_nextn", [0, 2])
     @parametrize_with_ids("attention_dp", [False, True])
     def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
@@ -1299,6 +1303,8 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         pytorch_config = dict(
             disable_overlap_scheduler=False,
             cuda_graph_config=CudaGraphConfig(enable_padding=True),
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@@ -1312,7 +1318,7 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_hopper
     @parametrize_with_ids("torch_compile", [False, True])
     @parametrize_with_ids("fp8kv,attention_dp,cuda_graph,overlap_scheduler",
                           [(False, False, False, False),
@@ -1341,6 +1347,8 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
             disable_overlap_scheduler=not overlap_scheduler,
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
+            moe_config=MoeConfig(
+                backend="DEEPGEMM" if get_sm_version() >= 100 else "CUTLASS"),
         )

         if fp8kv:
@@ -1427,7 +1435,7 @@ def test_cute_dsl_fp8_block_scales_4gpus(
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @pytest.mark.skip_device_not_contain(["H100", "H200"])
+    @skip_pre_hopper
     def test_fp8_block_scales_4gpus_static_eplb(self):
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)

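The test updates swap the Hopper-only markers for `@skip_pre_hopper` and pin the MoE backend by SM version so the same cases also run on newer GPUs. A tiny helper restating the backend choice (illustrative only; the tests inline the ternary directly):

def pick_moe_backend(sm_version: int) -> str:
    # DEEPGEMM is used on SM >= 100 (Blackwell); CUTLASS elsewhere.
    return "DEEPGEMM" if sm_version >= 100 else "CUTLASS"


assert pick_moe_backend(100) == "DEEPGEMM"
assert pick_moe_backend(90) == "CUTLASS"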
