Commit efa94ee

[TRTLLM-6633][feat] Padding for piecewise cudagraph
Signed-off-by: Jin Li <[email protected]>
1 parent b3e8fa2 commit efa94ee

File tree

10 files changed: +248 -163 lines changed

cpp/tensorrt_llm/thop/mlaPreprocessOp.cpp

Lines changed: 1 addition & 1 deletion

@@ -748,7 +748,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
         "merge_chunked_attention_for_mla("
-        "Tensor merged_attn"
+        "Tensor(a!) merged_attn"
         ", Tensor temp_attn"
         ", Tensor merged_softmax_stats"
         ", Tensor temp_softmax_stats"
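
Note: the `(a!)` alias annotation tells the TorchScript schema that merge_chunked_attention_for_mla mutates merged_attn in place, so functionalization and CUDA-graph capture treat the buffer as written into rather than freshly returned. A minimal sketch of the same annotation on a throwaway Python-registered op (the trtllm_demo library and add_one_ op are made up for illustration, not part of TensorRT-LLM):

```python
import torch
from torch.library import Library

# Hypothetical op used only to illustrate the Tensor(a!) annotation.
_lib = Library("trtllm_demo", "DEF")
_lib.define("add_one_(Tensor(a!) x) -> ()")  # (a!): x is mutated in place


def _add_one_impl(x: torch.Tensor) -> None:
    x.add_(1)  # write into the aliased tensor instead of returning a new one


_lib.impl("add_one_", _add_one_impl, "CompositeExplicitAutograd")

t = torch.zeros(4)
torch.ops.trtllm_demo.add_one_(t)
print(t)  # tensor([1., 1., 1., 1.])
```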

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 0 deletions

@@ -135,6 +135,9 @@ class AttentionMetadata:
     _num_ctx_tokens: int = field(init=False, default=0, repr=False)
     _num_tokens: int = field(init=False, default=0, repr=False)

+    # The number of tokens in the padded sequence.
+    padded_num_tokens: Optional[int] = None
+
     # This buffer is currently only used for TrtllmAttentionMetadata.
     cache_indirection: Optional[torch.Tensor] = None

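Note: padded_num_tokens is meant to be set only when the runner pads a batch up to one of the captured CUDA-graph token counts, while num_tokens keeps the real count so later stages can slice the padding back off. A rough sketch of how a runner might choose the padded size (captured_sizes and pad_to_captured_size are hypothetical names, not the actual model-engine code):

```python
from typing import List, Optional, Tuple

import torch


def pad_to_captured_size(
        input_ids: torch.Tensor,
        captured_sizes: List[int]) -> Tuple[torch.Tensor, Optional[int]]:
    """Pad a flattened token tensor up to the nearest captured graph size.

    Returns the (possibly padded) tensor and the padded length, or None when
    no padding was applied -- mirroring the None default of padded_num_tokens.
    """
    num_tokens = input_ids.shape[0]
    targets = [s for s in sorted(captured_sizes) if s >= num_tokens]
    if not targets or targets[0] == num_tokens:
        return input_ids, None
    pad = input_ids.new_zeros(targets[0] - num_tokens)
    return torch.cat([input_ids, pad]), targets[0]


ids, padded = pad_to_captured_size(torch.arange(6), captured_sizes=[8, 16])
# ids.shape[0] == 8, padded == 8; attn_metadata.padded_num_tokens would be set to 8.
```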

tensorrt_llm/_torch/compilation/piecewise_optimizer.py

Lines changed: 7 additions & 4 deletions

@@ -11,8 +11,9 @@

 from tensorrt_llm.llmapi.utils import enable_llm_debug

-from ..utils import (get_model_extra_attrs, get_piecewise_cuda_graph_flag,
-                     make_weak_ref)
+from ..utils import (get_model_extra_attrs,
+                     get_per_request_piecewise_cuda_graph_flag,
+                     get_piecewise_cuda_graph_flag, make_weak_ref)
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .utils import (get_enable_piecewise_cuda_graph_capture_flag,
                     is_call_function)
@@ -155,8 +156,10 @@ def __call__(self, *args):
         elif isinstance(self.compile_time_num_tokens, int):
             runtime_num_of_token = self.compile_time_num_tokens

-        if runtime_num_of_token is None or runtime_num_of_token not in self.entries or not get_piecewise_cuda_graph_flag(
-        ):
+        if (runtime_num_of_token is None
+                or runtime_num_of_token not in self.entries
+                or not get_piecewise_cuda_graph_flag()
+                or not get_per_request_piecewise_cuda_graph_flag()):
             return self.default_callable(*args)

         entry = self.entries[runtime_num_of_token]
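
Note: with this change a captured piecewise graph is replayed only when both the global flag and the new per-request flag allow it; otherwise the call falls back to default_callable. A toy sketch of such a paired global/per-request toggle (stand-ins for the helpers imported from ..utils, whose real implementation may differ):

```python
from contextlib import contextmanager

_piecewise_cuda_graph_enabled = True
_per_request_piecewise_cuda_graph_enabled = True


def get_piecewise_cuda_graph_flag() -> bool:
    return _piecewise_cuda_graph_enabled


def get_per_request_piecewise_cuda_graph_flag() -> bool:
    return _per_request_piecewise_cuda_graph_enabled


@contextmanager
def piecewise_cuda_graph_disabled_for_request():
    """Temporarily opt the current request out of piecewise graph replay."""
    global _per_request_piecewise_cuda_graph_enabled
    prev = _per_request_piecewise_cuda_graph_enabled
    _per_request_piecewise_cuda_graph_enabled = False
    try:
        yield
    finally:
        _per_request_piecewise_cuda_graph_enabled = prev
```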

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 11 additions & 2 deletions

@@ -1184,6 +1184,9 @@ def forward(
             inputs_embeds=inputs_embeds,
         )

+        if attn_metadata.padded_num_tokens is not None:
+            hidden_states = hidden_states[:attn_metadata.num_tokens]
+
         if spec_metadata and spec_metadata.spec_dec_mode.is_mtp():
             # get logits
             logits = self.logits_processor.forward(
@@ -1192,10 +1195,16 @@ def forward(
                 attn_metadata,
                 True,
             )
+            mtp_input_ids = input_ids
+            mtp_position_ids = position_ids
+            if attn_metadata.padded_num_tokens is not None:
+                mtp_input_ids = input_ids[:attn_metadata.num_tokens]
+                mtp_position_ids = position_ids[:attn_metadata.num_tokens]
+
             # get accepted tokens and next draft tokens
             return self.mtp_worker(
-                input_ids=input_ids,
-                position_ids=position_ids,
+                input_ids=mtp_input_ids,
+                position_ids=mtp_position_ids,
                 hidden_states=hidden_states,
                 logits=logits,
                 lm_head=self.lm_head,
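
Note: attention runs on the padded batch, so hidden_states, input_ids, and position_ids all carry trailing pad rows here; trimming them back to attn_metadata.num_tokens keeps the padding out of the logits and the MTP draft path. A standalone sketch of the same trimming with illustrative shapes:

```python
import torch

num_tokens = 6           # real tokens in the batch
padded_num_tokens = 8    # None when no padding was applied
hidden_size = 16

hidden_states = torch.randn(padded_num_tokens, hidden_size)
input_ids = torch.arange(padded_num_tokens)
position_ids = torch.arange(padded_num_tokens)

# Mirrors the forward() change: drop pad rows before logits / MTP drafting.
if padded_num_tokens is not None:
    hidden_states = hidden_states[:num_tokens]
    input_ids = input_ids[:num_tokens]
    position_ids = position_ids[:num_tokens]

assert hidden_states.shape[0] == input_ids.shape[0] == num_tokens
```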

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 11 additions & 2 deletions

@@ -381,6 +381,9 @@ def forward(
             **kwargs,
         )

+        if attn_metadata.padded_num_tokens is not None:
+            hidden_states = hidden_states[:attn_metadata.num_tokens]
+
         if self.draft_model is not None:
             # get logits
             logits = self.logits_processor.forward(
@@ -389,9 +392,15 @@ def forward(
                 attn_metadata,
                 True,
             )
+            mtp_input_ids = input_ids
+            mtp_position_ids = position_ids
+            if attn_metadata.padded_num_tokens is not None:
+                mtp_input_ids = input_ids[:attn_metadata.num_tokens]
+                mtp_position_ids = position_ids[:attn_metadata.num_tokens]
+
             # get accepted tokens and next draft tokens
-            return self.spec_worker(input_ids=input_ids,
-                                    position_ids=position_ids,
+            return self.spec_worker(input_ids=mtp_input_ids,
+                                    position_ids=mtp_position_ids,
                                     hidden_states=hidden_states,
                                     logits=logits,
                                     attn_metadata=attn_metadata,

tensorrt_llm/_torch/modules/attention.py

Lines changed: 68 additions & 72 deletions

@@ -8,6 +8,7 @@
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.quantization.utils import fp4_utils

 from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                  FlashInferAttentionMetadata, TrtllmAttention,
@@ -336,6 +337,21 @@ def _attn_impl(
         attention_sinks: Optional[torch.Tensor] = None,
     ):

+        padded_num_tokens = attn_metadata.padded_num_tokens
+        num_tokens = attn_metadata.num_tokens
+
+        if padded_num_tokens is not None:
+            assert q.shape[0] == padded_num_tokens
+            q = q[:num_tokens, :]
+            if k is not None:
+                assert k.shape[0] == padded_num_tokens
+                k = k[:num_tokens, :]
+            if v is not None:
+                assert v.shape[0] == padded_num_tokens
+                v = v[:num_tokens, :]
+            assert output is not None
+            assert output_sf is None
+
         out_scale = None
         out_scale_sf = None
         has_quant_scale = (self.o_proj.has_fp8_qdq or self.o_proj.has_nvfp4
@@ -366,14 +382,19 @@ def _attn_impl(
             attention_window_size=attention_window_size,
             attention_mask_data=attention_mask_data,
             enable_attn_nvfp4_output=enable_attn_nvfp4_output,
-            output=output,
+            output=output[:num_tokens, :] if output is not None else None,
             output_sf=output_sf,
             attention_sinks=attention_sinks)
         if isinstance(attn_output, tuple):
             assert len(
                 attn_output
             ) == 2, "attn_output should be a tuple of (output, output_sf)"
             return attn_output[0], attn_output[1]
+        if output is not None and output.shape[0] != num_tokens:
+            output[num_tokens:].fill_(0)
+        if output_sf is not None and output_sf.shape[0] != fp4_utils.pad_up(
+                num_tokens, 128):
+            output_sf[fp4_utils.pad_up(num_tokens, 128):].fill_(0)
         return attn_output, None

     def forward(
@@ -908,11 +929,10 @@ def create_output(self, hidden_states: torch.Tensor):
         return hidden_states.new_empty([num_tokens, hidden_size],
                                        dtype=hidden_states.dtype)

-    def forward_impl(self,
-                     position_ids: Optional[torch.Tensor],
+    def forward_impl(self, position_ids: Optional[torch.Tensor],
                      hidden_states: torch.Tensor,
                      attn_metadata: AttentionMetadata,
-                     output: Optional[torch.Tensor] = None) -> torch.Tensor:
+                     output: torch.Tensor) -> None:
         """
         Forward pass for the MLA module.

@@ -925,6 +945,18 @@ def forward_impl(self,
         Returns:
             torch.Tensor: The output tensor.
         """
+        # split q, k, v into context and gen batches
+        num_contexts = attn_metadata.num_contexts
+        num_generations = attn_metadata.num_generations
+        num_ctx_tokens = attn_metadata.num_ctx_tokens
+        num_tokens = attn_metadata.num_tokens
+        padded_num_tokens = attn_metadata.padded_num_tokens
+
+        if padded_num_tokens is not None:
+            hidden_states = hidden_states[:num_tokens, ...]
+            if position_ids is not None:
+                position_ids = position_ids[:num_tokens, ...]
+
         if self.is_lite:
             compressed_kv, k_pe = self.kv_a_proj_with_mqa(hidden_states).split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], -1)
@@ -952,15 +984,11 @@ def forward_impl(self,
             self.aux_stream,
         )

-        # split q, k, v into context and gen batches
-        num_contexts = attn_metadata.num_contexts
-        num_generations = attn_metadata.num_generations
-        num_ctx_tokens = attn_metadata.num_ctx_tokens
-        num_tokens = attn_metadata.num_tokens
-
         assert q.shape[
             0] == num_tokens, f"Expect q.shape[0] to be {num_tokens}, but got {q.shape[0]}"

+        assert output is not None, "output must be provided"
+
         if num_contexts > 0:
             q_ctx = q[:num_ctx_tokens, ...]
             compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
@@ -970,17 +998,14 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)

-            attn_output_context = self.forward_context(
+            self.forward_context(
                 q_ctx,
                 compressed_kv_ctx,
                 k_pe_ctx,
                 attn_metadata,
+                output[:num_ctx_tokens, :],
                 latent_cache_ctx,
-                output=output if num_generations == 0 else None)
-            if num_generations == 0:
-                return attn_output_context
-            else:
-                attn_output_context = None
+            )

         if num_generations > 0:
             q_gen = q[num_ctx_tokens:, ...]
@@ -991,39 +1016,17 @@ def forward_impl(self,
             assert position_ids is not None
             k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)

-            attn_output_gen = self.forward_generation(
+            self.forward_generation(
                 q_gen,
                 compressed_kv_gen,
                 k_pe_gen,
                 attn_metadata,
+                output[num_ctx_tokens:num_tokens, :],
                 latent_cache_gen,
-                output=output if num_contexts == 0 else None)
-            if num_contexts == 0:
-                return attn_output_gen
-            else:
-                attn_output_gen = None
+            )

-        # release pytorch activation memory
-        q = None
-        compressed_kv = None
-        k_pe = None
-
-        assert attn_output_context is not None and attn_output_gen is not None
-        assert (
-            len(attn_output_context.shape) == 2
-        ), f"attn_output_context must be rank 2, not {len(attn_output_context.shape)}"
-        assert (
-            len(attn_output_gen.shape) == 2
-        ), f"attn_output_gen must be rank 2, not {len(attn_output_gen.shape)}"
-        output = output if output is not None else torch.empty(
-            (num_tokens, attn_output_context.shape[1]),
-            dtype=attn_output_context.dtype,
-            device=attn_output_context.device)
-        output[:attn_output_context.shape[0], :] = attn_output_context
-        output[attn_output_context.shape[0]:, :] = attn_output_gen
-        attn_output_context = None
-        attn_output_gen = None
-        return output
+        if padded_num_tokens is not None:
+            output[num_tokens:].fill_(0)

     def _maybe_concat_qkv(self, q, k, v):
         if k is not None and v is not None and self.support_fused_qkv:
@@ -1032,13 +1035,14 @@ def _maybe_concat_qkv(self, q, k, v):
         return q, k, v

     def forward_context_default(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         kv = self.kv_b_proj(compressed_kv)
         k_nope, v = kv.split(
             [
@@ -1080,7 +1084,7 @@ def forward_context_with_cached_kv(
         q: torch.Tensor,
         latent_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         assert latent_cache is not None
         trtllm_attention = cast(TrtllmAttention, self.mha)
@@ -1166,7 +1170,7 @@ def forward_context_with_chunked_prefill(
         latent_cache: torch.
        Tensor,  # compressed_kv + k_pe [context_tokens, 1, lora_size + rope_size]
         attn_metadata: TrtllmAttentionMetadata,
-        output: Optional[torch.Tensor] = None,
+        output: torch.Tensor,
     ) -> torch.Tensor:
         trtllm_attention = cast(TrtllmAttention, self.mha)
         # apply RoPE, append compressed_kv + k_pe to paged kv cache and assign q_pe to q
@@ -1189,11 +1193,8 @@ def forward_context_with_chunked_prefill(
             dtype=torch.float,
             device='cuda',
         )
-        if output is None:
-            attn_output = q.new_empty(
-                (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)
-        else:
-            attn_output = output
+
+        attn_output = output
         temp_attn_output = q.new_empty(
             (q.size(0), self.num_heads * self.v_head_dim), dtype=q.dtype)

@@ -1325,8 +1326,8 @@ def forward_context(
         compressed_kv: torch.Tensor,
         k_pe: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
         latent_cache: Optional[torch.Tensor] = None,
-        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if isinstance(self.mha, TrtllmAttention):
             assert isinstance(attn_metadata, TrtllmAttentionMetadata)
@@ -1339,16 +1340,17 @@ def forward_context(
             return self.forward_context_with_cached_kv(
                 q, latent_cache, attn_metadata, output)
         return self.forward_context_default(q, compressed_kv, k_pe,
-                                            attn_metadata, latent_cache, output)
+                                            attn_metadata, output, latent_cache)

     def forward_generation(
-            self,
-            q: torch.Tensor,
-            compressed_kv: torch.Tensor,
-            k_pe: torch.Tensor,
-            attn_metadata: AttentionMetadata,
-            latent_cache: Optional[torch.Tensor] = None,
-            output: Optional[torch.Tensor] = None) -> torch.Tensor:
+        self,
+        q: torch.Tensor,
+        compressed_kv: torch.Tensor,
+        k_pe: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        output: torch.Tensor,
+        latent_cache: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         num_tokens = q.shape[0]
         q_nope, q_pe = q.view([-1, self.num_heads, self.qk_head_dim]).split(
             [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
@@ -1420,12 +1422,6 @@ def forward_generation(
         attn_out_latent = attn_out_latent.view(
             [-1, self.num_heads, self.kv_lora_rank])

-        # [seq, num_heads * v_head_dim]
-        output = output if output is not None else torch.empty(
-            [num_tokens, self.num_heads * self.v_head_dim],
-            dtype=attn_out_latent.dtype,
-            device=attn_out_latent.device)
-
         attn_output = output.view([num_tokens, self.num_heads, self.v_head_dim])

         if self.v_b_proj.dtype == torch.bfloat16:
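
Note: taken together, the MLA path now always writes into a caller-allocated output buffer sized for the (possibly padded) batch: the first num_tokens rows receive the attention result and the tail is explicitly zeroed so graph replays never observe stale data, with the NVFP4 scale tensor handled at its 128-row granularity via fp4_utils.pad_up. A minimal sketch of that write-then-zero contract (pad_up below is a local re-implementation for illustration, not the TensorRT-LLM helper):

```python
import torch


def pad_up(x: int, multiple: int) -> int:
    # Round x up to the next multiple (same arithmetic the scale-tensor path relies on).
    return (x + multiple - 1) // multiple * multiple


def write_attention_output(attn_out: torch.Tensor, output: torch.Tensor,
                           num_tokens: int) -> None:
    """Copy valid rows into a preallocated, possibly padded output buffer."""
    assert attn_out.shape[0] == num_tokens and output.shape[0] >= num_tokens
    output[:num_tokens].copy_(attn_out)
    if output.shape[0] > num_tokens:
        output[num_tokens:].fill_(0)  # keep the padded tail well-defined


num_tokens, padded_num_tokens, hidden = 5, 8, 4
out_buf = torch.full((padded_num_tokens, hidden), float("nan"))
write_attention_output(torch.ones(num_tokens, hidden), out_buf, num_tokens)
assert torch.equal(out_buf[num_tokens:],
                   torch.zeros(padded_num_tokens - num_tokens, hidden))
print(pad_up(num_tokens, 128))  # 128: the NVFP4 scale tensor is zeroed from this row on
```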
