
Commit 4bd6884

subgraph pipeline
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 98b3af4 commit 4bd6884

24 files changed: +281 additions, -267 deletions

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 9 additions & 2 deletions

@@ -6,6 +6,7 @@ transforms:
   ############################################################################################
   build_model:
     stage: factory
+    run_per_gm: false
     device: meta
     # nothing to clean up
     run_graph_cleanup: false
@@ -14,8 +15,8 @@ transforms:
     stage: export
     clone_state_dict: false
     strict: false
-    # nothing to clean up
-    run_graph_cleanup: false
+    run_per_gm: false
+    run_graph_cleanup: true
     requires_clean_graph: false
   cleanup_noop_slice:
     stage: post_export
@@ -35,6 +36,7 @@ transforms:
     run_shape_prop: true
   match_eager_attention:
     stage: pattern_matcher
+    requires_shape_prop: true
   match_grouped_attention:
     stage: pattern_matcher
   match_attention_layout:
@@ -87,8 +89,10 @@ transforms:
   ############################################################################################
   load_weights:
     stage: weight_load
+    run_per_gm: false
   move_inputs_to_device:
     stage: weight_load
+    run_per_gm: false
   ############################################################################################
   # RUN POST-LOAD FUSION AND OPTIMIZATIONS
   ############################################################################################
@@ -138,10 +142,13 @@ transforms:
     attn_backend: cuda_causal_conv
   initialize_cache:
     stage: cache_init
+    run_per_gm: false
   resize_kv_cache:
     stage: cache_init
+    run_per_gm: false
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################
   compile_model:
     stage: compile
+    run_per_gm: false
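
Note: `run_per_gm` is new in this commit. From the config and the commit title ("subgraph pipeline"), the factory, weight-load, cache-init, and compile steps appear to be marked `run_per_gm: false` so they run once against the whole model rather than once per captured GraphModule. The sketch below is only an assumption of how such a flag could be dispatched, not the actual TensorRT-LLM pipeline code; `TransformConfig` and `apply_transform` are hypothetical names.

# Hypothetical sketch (not TensorRT-LLM code): dispatch a transform either once
# on the whole model or once per nested fx.GraphModule, driven by a
# `run_per_gm` flag like the one added to the YAML above.
from dataclasses import dataclass
from typing import Callable

import torch.fx as fx
import torch.nn as nn


@dataclass
class TransformConfig:  # hypothetical stand-in for one YAML transform entry
    stage: str
    run_per_gm: bool = True  # assumed default: apply the transform per subgraph


def apply_transform(
    model: nn.Module,
    transform: Callable[[nn.Module], None],
    cfg: TransformConfig,
) -> None:
    if cfg.run_per_gm:
        # graph-level passes visit every captured subgraph of the model
        for _, submod in model.named_modules():
            if isinstance(submod, fx.GraphModule):
                transform(submod)
    else:
        # factory / weight-load / cache-init style steps touch the model once
        transform(model)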

tensorrt_llm/_torch/auto_deploy/config/transformers.yaml

Lines changed: 6 additions & 0 deletions

@@ -6,23 +6,29 @@ transforms:
   ############################################################################################
   build_and_load_factory_model:
     stage: factory
+    run_per_gm: false
   ############################################################################################
   # MOVE ARGUMENTS TO DEVICE
   ############################################################################################
   move_inputs_to_device:
     stage: weight_load
+    run_per_gm: false
   ############################################################################################
   # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
   ############################################################################################
   detect_hf_attn_layers:
     stage: cache_init
+    run_per_gm: false
   transformers_replace_cached_attn:
     stage: cache_init
     attn_backend: flashinfer
+    run_per_gm: false
   initialize_cache:
     stage: cache_init
+    run_per_gm: false
   resize_kv_cache:
     stage: cache_init
+    run_per_gm: false
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 14 additions & 8 deletions

@@ -325,7 +325,11 @@ def args_for_prepare_metadata(self) -> Tuple[str, ...]:
         like ``insert_cached_attention`` to extract the constant arguments and add them to the
         ``prepare_metadata`` node/op.
         """
-        return tuple(self.named_standard_args.keys())
+        # NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids
+        # is part of the graph, e.g., in situations where the graph is a submodule of the overall
+        # model. In such instances, the graph usually sees inputs_embeds. However, we assume for
+        # now that position_ids is always part of the graph.
+        return ("position_ids",) + self._cached_arg_names

     @property
     def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
@@ -466,7 +470,9 @@ def _get_cache_locations_and_pages_per_sequence(
         return cache_loc_flat, pages_per_seq

     @classmethod
-    def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
+    def _get_sanitized_seq_len(
+        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
+    ) -> torch.Tensor:
         """Sanitize sequence lengths.

         We want to cover the following scenarios with this function:
@@ -499,22 +505,24 @@ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor)
         # valid cache location in the batch. This would ensure that the dummy sequences just
         # repeats valid computation...
         """
-        _, s = input_ids.shape[:2]
-        num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len)
+        _, s = input_or_position_ids.shape[:2]
+        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
         if s > 1:
             return seq_len[:num_seq].detach().clone()
         else:
             return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)

     @staticmethod
-    def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
+    def _get_sanitized_num_sequences(
+        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
+    ) -> int:
         """Get number of sequences.

         We makes sure that this function is compatible with both torch graph capture and cudagraph.
         Both can be a bit temparamental when trying to extract the number of sequences from a tensor
         with max_batch_size or max_batch_size*max_seq_len.
         """
-        b, s = input_ids.shape[:2]
+        b, s = input_or_position_ids.shape[:2]
         if s > 1:
             num_seq = torch.sum(seq_len > 0)
             assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
@@ -814,7 +822,6 @@ def __call__(
 class PrepareMetadataCallable(Protocol):
     def __call__(
         self,
-        input_ids: torch.Tensor,
         position_ids: torch.Tensor,
         seq_len: torch.Tensor,
         input_pos: torch.Tensor,
@@ -901,7 +908,6 @@ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:

         ```
         def prepare_metadata(
-            input_ids: torch.Tensor,
             position_ids: torch.Tensor,
             seq_len: torch.Tensor,
             input_pos: torch.Tensor,
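
Note: the sanitization helpers above only look at the leading (batch, seq) shape of the tensor they receive and at the zero-padded seq_len; token values are never read. That is why position_ids can stand in for input_ids when a subgraph only receives inputs_embeds. Below is a standalone sketch of that logic, not an import from TensorRT-LLM; the single-token decode branch (num_seq equal to the batch size) is an assumption, since that branch is not shown in the hunk.

import torch


def sanitized_num_sequences(input_or_position_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
    b, s = input_or_position_ids.shape[:2]
    if s > 1:
        # prefill: active sequences are the non-zero entries of the padded seq_len
        num_seq = int(torch.sum(seq_len > 0))
        assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
    else:
        # decode: one token per sequence (assumed branch, not shown in the diff)
        num_seq = b
    return num_seq


position_ids = torch.arange(6).reshape(2, 3)           # [batch=2, seq=3]; values are irrelevant
seq_len = torch.tensor([3, 2, 0, 0])                   # zero-padded to the max batch size
print(sanitized_num_sequences(position_ids, seq_len))  # -> 2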

tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py

Lines changed: 3 additions & 4 deletions

@@ -54,7 +54,6 @@ def _build_conv_state_from_sequence(input_bt_c: torch.Tensor, kernel_size: int)
 # ---------------------------------------------------------------
 @torch.library.custom_op("auto_deploy::cuda_causal_conv_prepare_metadata", mutates_args=())
 def cuda_causal_conv_prepare_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -67,7 +66,7 @@ def cuda_causal_conv_prepare_metadata(

     Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
     """
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)

     seq_start = torch.zeros_like(seq_len_sanitized)
@@ -81,9 +80,9 @@ def cuda_causal_conv_prepare_metadata(

 @cuda_causal_conv_prepare_metadata.register_fake
 def cuda_causal_conv_prepare_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)
     return (
         torch.empty_like(seq_len_sanitized),

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 3 additions & 4 deletions

@@ -155,7 +155,6 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):

 @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
 def prepare_flashinfer_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -174,7 +173,7 @@ def prepare_flashinfer_metadata(
     _GlobalFlashInferPlanner.reset()

     # retrieve sanitzed metadata
-    seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len)

     # prepare flashinfer-style metadata
@@ -214,9 +213,9 @@ def prepare_flashinfer_metadata(
 # As SequenceInfo._get_sanitized_num_sequences could break in fake mode
 @prepare_flashinfer_metadata.register_fake
 def prepare_flashinfer_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
-    seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
     return (
         qo_indptr,  # qo_indptr
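
Note: each backend pairs the real custom op with a register_fake implementation that must keep the exact same signature, which is why every fake above is updated together with its op. A minimal, self-contained version of that pattern is sketched below; the `demo::` namespace and the toy op body are made up for illustration, and it assumes PyTorch 2.4+ for `torch.library.custom_op` / `register_fake`.

from typing import List

import torch


@torch.library.custom_op("demo::prepare_metadata", mutates_args=())
def prepare_metadata(position_ids: torch.Tensor, seq_len: torch.Tensor) -> List[torch.Tensor]:
    # real implementation: reads values, computes sanitized lengths and offsets
    num_seq = int(torch.sum(seq_len > 0))
    seq_start = torch.zeros_like(seq_len[:num_seq])
    seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
    return [seq_len[:num_seq].clone(), seq_start]


@prepare_metadata.register_fake
def prepare_metadata_fake(position_ids, seq_len):
    # fake/meta implementation: only describes output shapes and dtypes,
    # never reads tensor values, so it stays safe under fake-tensor tracing
    return [torch.empty_like(seq_len), torch.empty_like(seq_len)]


out = torch.ops.demo.prepare_metadata(torch.arange(6).reshape(2, 3), torch.tensor([3, 2]))
print([t.tolist() for t in out])  # [[3, 2], [0, 3]]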

tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py

Lines changed: 2 additions & 3 deletions

@@ -175,7 +175,6 @@ def fused_flattened_mla_with_cache_fake(
     "auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=()
 )
 def prepare_fused_mla_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -184,7 +183,7 @@ def prepare_fused_mla_metadata(
     slot_idx: torch.Tensor,
     page_size: int,
 ) -> List[torch.Tensor]:
-    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
+    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
     seq_start = torch.zeros_like(seq_len[:num_seq])
     seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
     return (
@@ -197,7 +196,7 @@ def prepare_fused_mla_metadata(

 @prepare_fused_mla_metadata.register_fake
 def prepare_fused_mla_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
     return (
         torch.empty_like(seq_len),
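
Note: two things stand out in this hunk. The fake signature previously omitted slot_idx and now matches the real op's parameter list, and seq_start is an exclusive cumulative sum over the sanitized lengths, mapping each sequence to its start offset in the flattened token layout. A tiny worked example of that offset computation:

import torch

seq_len = torch.tensor([3, 2, 4])             # sanitized per-sequence lengths
seq_start = torch.zeros_like(seq_len)
seq_start[1:] = torch.cumsum(seq_len[:-1], 0)
print(seq_start.tolist())                     # [0, 3, 5]: start offset of each sequence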

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py

Lines changed: 3 additions & 4 deletions

@@ -356,7 +356,6 @@ def torch_backend_mha_with_cache_fake(

 @torch.library.custom_op("auto_deploy::torch_cached_attention_prepare_metadata", mutates_args=())
 def torch_backend_prepare_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -366,7 +365,7 @@ def torch_backend_prepare_metadata(
     page_size: int,
 ) -> List[torch.Tensor]:
     """Prepare metadata for torch backend attention (similar to triton backend)."""
-    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
+    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
     seq_start = torch.zeros_like(seq_len[:num_seq])
     seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
     return (
@@ -379,9 +378,9 @@ def torch_backend_prepare_metadata(

 @torch_backend_prepare_metadata.register_fake
 def torch_backend_prepare_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
-    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
+    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
     return (
         torch.empty_like(seq_len[:num_seq]),
         torch.empty_like(input_pos[:num_seq]),

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py

Lines changed: 3 additions & 4 deletions

@@ -140,7 +140,6 @@ def _torch_causal_conv1d_decode(

 @torch.library.custom_op("auto_deploy::torch_causal_conv_prepare_metadata", mutates_args=())
 def torch_causal_conv_prepare_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -153,7 +152,7 @@ def torch_causal_conv_prepare_metadata(

     Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
     """
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)

     seq_start = torch.zeros_like(seq_len_sanitized)
@@ -167,9 +166,9 @@ def torch_causal_conv_prepare_metadata(

 @torch_causal_conv_prepare_metadata.register_fake
 def torch_causal_conv_prepare_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)
     return (
         torch.empty_like(seq_len_sanitized),

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py

Lines changed: 3 additions & 4 deletions

@@ -113,7 +113,6 @@ def _update_ssm_state_cache(ssm_cache: torch.Tensor, ssm_state: torch.Tensor) ->

 @torch.library.custom_op("auto_deploy::torch_ssm_prepare_metadata", mutates_args=())
 def _torch_ssm_prepare_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -127,7 +126,7 @@ def _torch_ssm_prepare_metadata(
     Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
     """
     # Determine number of active sequences and compute seq_start boundaries
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)

     seq_start = torch.zeros_like(seq_len_sanitized)
@@ -142,10 +141,10 @@ def _torch_ssm_prepare_metadata(

 @_torch_ssm_prepare_metadata.register_fake
 def _torch_ssm_prepare_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
     # Use the same sanitization logic to determine sizes in fake mode
-    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
+    seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
     num_seq = len(seq_len_sanitized)
     return (
         torch.empty_like(seq_len_sanitized),

tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py

Lines changed: 3 additions & 4 deletions

@@ -284,7 +284,6 @@ def flattened_mha_fake(
     "auto_deploy::triton_attention_prepare_fused_mha_metadata", mutates_args=()
 )
 def prepare_fused_mha_metadata(
-    input_ids: torch.Tensor,
     position_ids: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
@@ -294,7 +293,7 @@ def prepare_fused_mha_metadata(
     page_size: int,
 ) -> List[torch.Tensor]:
     # TODO: maybe use slot_idx instead of pages_per_seq??
-    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
+    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
     seq_start = torch.zeros_like(seq_len[:num_seq])
     seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0)
     return (
@@ -309,9 +308,9 @@ def prepare_fused_mha_metadata(
 # SequenceInfo._get_sanitized_num_sequences could break in fake mode
 @prepare_fused_mha_metadata.register_fake
 def prepare_fused_mha_metadata_fake(
-    input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
+    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
 ):
-    num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
+    num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len)
     return (
         torch.empty_like(seq_len[:num_seq]),
         torch.empty_like(input_pos[:num_seq]),

0 commit comments
