From 43965a2acab7407e6251f64daab1bf93895af29d Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Tue, 30 Sep 2025 08:01:14 -0700 Subject: [PATCH 1/4] subgraph pipeline Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 11 +- .../auto_deploy/config/transformers.yaml | 6 + .../custom_ops/attention_interface.py | 22 +- .../custom_ops/cuda_backend_causal_conv.py | 7 +- .../custom_ops/flashinfer_attention.py | 7 +- .../_torch/auto_deploy/custom_ops/mla.py | 5 +- .../custom_ops/torch_backend_attention.py | 7 +- .../custom_ops/torch_backend_causal_conv.py | 7 +- .../custom_ops/torch_backend_mamba.py | 7 +- .../custom_ops/triton_attention.py | 7 +- .../_torch/auto_deploy/transform/interface.py | 218 ++++++++++-------- .../transform/library/build_model.py | 26 +-- .../transform/library/compile_model.py | 14 +- .../transform/library/export_to_gm.py | 16 +- .../auto_deploy/transform/library/kvcache.py | 21 +- .../transform/library/kvcache_transformers.py | 49 ++-- .../transform/library/load_weights.py | 22 +- .../_torch/auto_deploy/transform/optimizer.py | 28 +-- .../auto_deploy/transformations/_graph.py | 53 ++--- .../test_cuda_causal_conv_cached_op.py | 3 +- .../custom_ops/test_torch_attention_op.py | 4 +- .../test_torch_causal_conv_cached_op.py | 3 +- .../custom_ops/test_torch_mamba_cached_op.py | 3 +- .../singlegpu/test_ad_build_small_single.py | 3 + .../transformations/library/test_kv_cache.py | 2 + 25 files changed, 284 insertions(+), 267 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 1684fe08c9a..74675b725c6 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -6,6 +6,7 @@ transforms: ############################################################################################ build_model: stage: factory + run_per_gm: false device: meta # nothing to clean up run_graph_cleanup: false @@ -14,8 +15,8 @@ transforms: stage: export clone_state_dict: false strict: false - # nothing to clean up - run_graph_cleanup: false + run_per_gm: false + run_graph_cleanup: true requires_clean_graph: false cleanup_noop_slice: stage: post_export @@ -35,6 +36,7 @@ transforms: run_shape_prop: true match_eager_attention: stage: pattern_matcher + requires_shape_prop: true match_grouped_attention: stage: pattern_matcher match_attention_layout: @@ -87,8 +89,10 @@ transforms: ############################################################################################ load_weights: stage: weight_load + run_per_gm: false move_inputs_to_device: stage: weight_load + run_per_gm: false ############################################################################################ # RUN POST-LOAD FUSION AND OPTIMIZATIONS ############################################################################################ @@ -138,10 +142,13 @@ transforms: attn_backend: cuda_causal_conv initialize_cache: stage: cache_init + run_per_gm: false resize_kv_cache: stage: cache_init + run_per_gm: false ############################################################################################ # COMPILE MODEL ############################################################################################ compile_model: stage: compile + run_per_gm: false diff --git a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml 
b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml index 5b32f81672d..529a3bb5879 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml @@ -6,23 +6,29 @@ transforms: ############################################################################################ build_and_load_factory_model: stage: factory + run_per_gm: false ############################################################################################ # MOVE ARGUMENTS TO DEVICE ############################################################################################ move_inputs_to_device: stage: weight_load + run_per_gm: false ############################################################################################ # SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES ############################################################################################ detect_hf_attn_layers: stage: cache_init + run_per_gm: false transformers_replace_cached_attn: stage: cache_init attn_backend: flashinfer + run_per_gm: false initialize_cache: stage: cache_init + run_per_gm: false resize_kv_cache: stage: cache_init + run_per_gm: false ############################################################################################ # COMPILE MODEL ############################################################################################ diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index d61bb8854e8..7c2d28fbce1 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -325,7 +325,11 @@ def args_for_prepare_metadata(self) -> Tuple[str, ...]: like ``insert_cached_attention`` to extract the constant arguments and add them to the ``prepare_metadata`` node/op. """ - return tuple(self.named_standard_args.keys()) + # NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids + # is part of the graph, e.g., in situations where the graph is a submodule of the overall + # model. In such instances, the graph usually sees inputs_embeds. However, we assume for + # now that position_ids is always part of the graph. + return ("position_ids",) + self._cached_arg_names @property def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]: @@ -466,7 +470,9 @@ def _get_cache_locations_and_pages_per_sequence( return cache_loc_flat, pages_per_seq @classmethod - def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor: + def _get_sanitized_seq_len( + cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor + ) -> torch.Tensor: """Sanitize sequence lengths. We want to cover the following scenarios with this function: @@ -499,22 +505,24 @@ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) # valid cache location in the batch. This would ensure that the dummy sequences just # repeats valid computation... 
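As a rough illustration of the sanitization behavior described above (a minimal sketch; the tensor values are made up and the import path is inferred from the file path in this diff):

```
import torch
from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo

# prefill-style call: the sequence dim of position_ids is > 1, so only the
# non-zero (active) entries of the zero-padded seq_len are returned.
position_ids = torch.arange(6).expand(4, -1)                # shape [4, 6]
seq_len = torch.tensor([4, 2, 0, 0], dtype=torch.int32)
SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)  # -> tensor([4, 2])

# decode-style call: the sequence dim is 1, so each active sequence advances by
# exactly one token and a tensor of ones is returned instead.
position_ids_decode = torch.zeros(4, 1, dtype=torch.long)
SequenceInfo._get_sanitized_seq_len(position_ids_decode, seq_len)  # -> tensor of ones
```

Note that the helper (like the rest of the cached-attention interface in this patch) now keys off ``position_ids`` rather than ``input_ids``, since a subgraph of the overall model may only see ``inputs_embeds``.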
""" - _, s = input_ids.shape[:2] - num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len) + _, s = input_or_position_ids.shape[:2] + num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len) if s > 1: return seq_len[:num_seq].detach().clone() else: return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device) @staticmethod - def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int: + def _get_sanitized_num_sequences( + input_or_position_ids: torch.Tensor, seq_len: torch.Tensor + ) -> int: """Get number of sequences. We makes sure that this function is compatible with both torch graph capture and cudagraph. Both can be a bit temparamental when trying to extract the number of sequences from a tensor with max_batch_size or max_batch_size*max_seq_len. """ - b, s = input_ids.shape[:2] + b, s = input_or_position_ids.shape[:2] if s > 1: num_seq = torch.sum(seq_len > 0) assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded" @@ -814,7 +822,6 @@ def __call__( class PrepareMetadataCallable(Protocol): def __call__( self, - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -901,7 +908,6 @@ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]: ``` def prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py index fbfc2fad614..b8e134be19f 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py @@ -54,7 +54,6 @@ def _build_conv_state_from_sequence(input_bt_c: torch.Tensor, kernel_size: int) # --------------------------------------------------------------- @torch.library.custom_op("auto_deploy::cuda_causal_conv_prepare_metadata", mutates_args=()) def cuda_causal_conv_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -67,7 +66,7 @@ def cuda_causal_conv_prepare_metadata( Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). 
""" - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) seq_start = torch.zeros_like(seq_len_sanitized) @@ -81,9 +80,9 @@ def cuda_causal_conv_prepare_metadata( @cuda_causal_conv_prepare_metadata.register_fake def cuda_causal_conv_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) return ( torch.empty_like(seq_len_sanitized), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 3200a21937d..63a8c7b1547 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -155,7 +155,6 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper): @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=()) def prepare_flashinfer_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -174,7 +173,7 @@ def prepare_flashinfer_metadata( _GlobalFlashInferPlanner.reset() # retrieve sanitzed metadata - seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len) # prepare flashinfer-style metadata @@ -214,9 +213,9 @@ def prepare_flashinfer_metadata( # As SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_flashinfer_metadata.register_fake def prepare_flashinfer_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device) return ( qo_indptr, # qo_indptr diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py index 20100531973..ea68da9e508 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py @@ -175,7 +175,6 @@ def fused_flattened_mla_with_cache_fake( "auto_deploy::triton_attention_prepare_fused_mla_metadata", mutates_args=() ) def prepare_fused_mla_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -184,7 +183,7 @@ def prepare_fused_mla_metadata( slot_idx: torch.Tensor, page_size: int, ) -> List[torch.Tensor]: - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) seq_start = torch.zeros_like(seq_len[:num_seq]) seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0) return ( @@ -197,7 +196,7 @@ def prepare_fused_mla_metadata( @prepare_fused_mla_metadata.register_fake def prepare_fused_mla_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, 
page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): return ( torch.empty_like(seq_len), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index 6eadb4b4466..df2d4b24c59 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -356,7 +356,6 @@ def torch_backend_mha_with_cache_fake( @torch.library.custom_op("auto_deploy::torch_cached_attention_prepare_metadata", mutates_args=()) def torch_backend_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -366,7 +365,7 @@ def torch_backend_prepare_metadata( page_size: int, ) -> List[torch.Tensor]: """Prepare metadata for torch backend attention (similar to triton backend).""" - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) seq_start = torch.zeros_like(seq_len[:num_seq]) seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0) return ( @@ -379,9 +378,9 @@ def torch_backend_prepare_metadata( @torch_backend_prepare_metadata.register_fake def torch_backend_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) return ( torch.empty_like(seq_len[:num_seq]), torch.empty_like(input_pos[:num_seq]), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py index 522779fd183..6aaf5ecb405 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_causal_conv.py @@ -140,7 +140,6 @@ def _torch_causal_conv1d_decode( @torch.library.custom_op("auto_deploy::torch_causal_conv_prepare_metadata", mutates_args=()) def torch_causal_conv_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -153,7 +152,7 @@ def torch_causal_conv_prepare_metadata( Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). 
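Since ``input_ids`` was dropped from all ``prepare_metadata`` signatures, callers now pass ``position_ids`` first. A sketch of the updated calling convention, mirroring the unit tests further down in this patch (tensor contents are illustrative and assume the auto_deploy custom ops have been imported/registered):

```
import torch

b, s, device = 4, 6, "cuda"
position_ids = torch.arange(s, device=device).expand(b, -1)
seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32)
input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32)
cache_loc = torch.arange(b, device=device, dtype=torch.int32)
pages_per_seq = torch.ones(b, device=device, dtype=torch.int32)
slot_idx = torch.arange(b, device=device, dtype=torch.int32)

# returns (seq_len_sanitized, seq_start, slot_idx_sanitized) as documented above
seq_len_s, seq_start, slot_idx_s = torch.ops.auto_deploy.torch_causal_conv_prepare_metadata(
    position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128
)
```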
""" - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) seq_start = torch.zeros_like(seq_len_sanitized) @@ -167,9 +166,9 @@ def torch_causal_conv_prepare_metadata( @torch_causal_conv_prepare_metadata.register_fake def torch_causal_conv_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) return ( torch.empty_like(seq_len_sanitized), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py index 4ac148e815e..6bf7eb84d14 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_mamba.py @@ -113,7 +113,6 @@ def _update_ssm_state_cache(ssm_cache: torch.Tensor, ssm_state: torch.Tensor) -> @torch.library.custom_op("auto_deploy::torch_ssm_prepare_metadata", mutates_args=()) def _torch_ssm_prepare_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -127,7 +126,7 @@ def _torch_ssm_prepare_metadata( Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). """ # Determine number of active sequences and compute seq_start boundaries - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) seq_start = torch.zeros_like(seq_len_sanitized) @@ -142,10 +141,10 @@ def _torch_ssm_prepare_metadata( @_torch_ssm_prepare_metadata.register_fake def _torch_ssm_prepare_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): # Use the same sanitization logic to determine sizes in fake mode - seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len) + seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len) num_seq = len(seq_len_sanitized) return ( torch.empty_like(seq_len_sanitized), diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 56aad993a3c..34e0c5a988d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -284,7 +284,6 @@ def flattened_mha_fake( "auto_deploy::triton_attention_prepare_fused_mha_metadata", mutates_args=() ) def prepare_fused_mha_metadata( - input_ids: torch.Tensor, position_ids: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, @@ -294,7 +293,7 @@ def prepare_fused_mha_metadata( page_size: int, ) -> List[torch.Tensor]: # TODO: maybe use slot_idx instead of pages_per_seq?? 
- num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) seq_start = torch.zeros_like(seq_len[:num_seq]) seq_start[1:] = torch.cumsum(seq_len[: num_seq - 1], 0) return ( @@ -309,9 +308,9 @@ def prepare_fused_mha_metadata( # SequenceInfo._get_sanitized_num_sequences could break in fake mode @prepare_fused_mha_metadata.register_fake def prepare_fused_mha_metadata_fake( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size ): - num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len) + num_seq = SequenceInfo._get_sanitized_num_sequences(position_ids, seq_len) return ( torch.empty_like(seq_len[:num_seq]), torch.empty_like(input_pos[:num_seq]), diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index a0895b61d14..6915dac8540 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -4,7 +4,7 @@ """ import time -from abc import ABC, abstractmethod +from abc import ABC from contextlib import nullcontext from enum import Enum from functools import total_ordering, wraps @@ -82,7 +82,7 @@ class TransformConfig(BaseModel): ### OPTIONAL CONFIG ########################################################################### run_per_gm: bool = Field( description="Whether to run the transform per graph (sub)module or on whole module.", - default=False, + default=True, ) enabled: bool = Field( default=True, @@ -126,9 +126,11 @@ class TransformInfo(BaseModel): } skipped: bool = Field( + default=True, description="Whether the transform was skipped.", ) num_matches: int = Field( + default=0, description="Number of matches found.", ) is_clean: bool = Field( @@ -145,6 +147,32 @@ class TransformInfo(BaseModel): "tensors in the graph and it preserves the has_valid_shapes flag of the last transform.", ) + @classmethod + def from_last_info(cls, info: "TransformInfo") -> "TransformInfo": + """Create a new TransformInfo from the last transform info.""" + return cls( + is_clean=info.is_clean, + has_valid_shapes=info.has_valid_shapes, + ) + + def __or__(self, other: "TransformInfo") -> "TransformInfo": + """Merge two TransformInfo objects.""" + return TransformInfo( + skipped=self.skipped and other.skipped, # we only count skipped if both were skipped + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean or other.is_clean, + has_valid_shapes=self.has_valid_shapes or other.has_valid_shapes, + ) + + def __and__(self, other: "TransformInfo") -> "TransformInfo": + """Merge two TransformInfo objects.""" + return TransformInfo( + skipped=self.skipped and other.skipped, # we only count skipped if both were skipped + num_matches=self.num_matches + other.num_matches, + is_clean=self.is_clean and other.is_clean, + has_valid_shapes=self.has_valid_shapes and other.has_valid_shapes, + ) + TransformHistory = Dict[str, TransformInfo] @@ -248,7 +276,7 @@ def from_kwargs(cls, **kwargs) -> "BaseTransform": @final def __call__( self, - gm: nn.Module, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, @@ -256,13 +284,13 @@ def __call__( """Apply the transform to the graph. Args: - gm: The graph module to apply the transform to. + mod: The model to apply the transform to. 
cm: The cached sequence interface defining the sequence interface. factory: The model factory used to build the model. shared_config: Global info shared between multiple transforms. Returns: - GraphModule: The transformed graph module. + nn.Module: The transformed model. NOTE: The transform can/should modify the graph module in place if possible. Returning the graph is mostly to standardize the interface for transforms that cannot modify the graph @@ -276,13 +304,16 @@ def __call__( t_name = self.get_transform_key() # retrieve autodeploy metadata from the graphmodule - autodeploy_meta = self._get_autodeploy_meta(gm) + autodeploy_meta = self._get_autodeploy_meta(mod) # retrieve transform history and last transform info history: TransformHistory = autodeploy_meta.get(self._history_key, {}) h_keys = list(history.keys()) # preserves order of insertion/transform execution info_last = history[h_keys[-1]] if h_keys else TransformInfo(skipped=False, num_matches=0) + # initialize new info object + info = TransformInfo.from_last_info(info_last) + # show debug info for debug config ad_logger.debug(f"{t_name} config: {self.config}") @@ -294,42 +325,47 @@ def __call__( # run or skip the transform if self.config.enabled: - # run graph pre-cleanup + # run graph pre-cleanup and update info object elapsed_time_pre_cleanup = -time.time() - is_clean_pre, has_valid_shapes_pre = self._run_pre_cleanup(gm, info_last) + info = info | self._run_cleanup( + mod, + self.config.requires_clean_graph, + self.config.requires_shape_prop, + info.is_clean, + info.has_valid_shapes, + ) elapsed_time_pre_cleanup += time.time() # run the transform in a error-handling wrapper if desired elapsed_time_apply = -time.time() if self.config.skip_on_error: try: - gm, info = self._apply_per_gm(gm, cm, factory, shared_config) + mod, info_apply = self._apply_per_gm_or_whole_model( + mod, cm, factory, shared_config + ) except Exception as e: error_msg = f"Transform {t_name} failed" ad_logger.warning(f"{error_msg}: {e}") - info = TransformInfo(skipped=True, num_matches=0) + info_apply = TransformInfo(skipped=True, num_matches=0) else: # handle this here normally to improve debugging and error message - gm, info = self._apply_per_gm(gm, cm, factory, shared_config) + mod, info_apply = self._apply_per_gm_or_whole_model(mod, cm, factory, shared_config) elapsed_time_apply += time.time() # we cannot say it's clean if the previous wasn't clean even if this one is # create new info object with updated cleanup status - info_dict = info.model_dump() - info_dict["is_clean"] &= is_clean_pre - info_dict["has_valid_shapes"] &= has_valid_shapes_pre - info = TransformInfo(**info_dict) + info = info & info_apply # run graph post-cleanup elapsed_time_post_cleanup = -time.time() - info = self._run_post_cleanup(gm, info) + info = info | self._run_cleanup( + mod, + self.config.run_graph_cleanup, + self.config.run_shape_prop, + info.is_clean, + info.has_valid_shapes, + ) elapsed_time_post_cleanup += time.time() - else: - # skip the transform and set info object using the last transform info - info_dict = info_last.model_dump() - info_dict["skipped"] = True - info_dict["num_matches"] = 0 - info = TransformInfo(**info_dict) elapsed_time_total += time.time() @@ -348,36 +384,37 @@ def __call__( f"post_cleanup={elapsed_time_post_cleanup:.3f}s", ] self._log_info(", ".join(log_msgs_timing)) - ad_logger.debug(f"Graph after {t_name}: {gm}") + ad_logger.debug(f"Model after {t_name}: {mod}") # update + store new meta data history[t_name] = info 
autodeploy_meta[self._history_key] = history - self._set_autodeploy_meta(gm, autodeploy_meta) + self._set_autodeploy_meta(mod, autodeploy_meta) # return the graph module - return gm + return mod @final - def _apply_per_gm( + def _apply_per_gm_or_whole_model( self, - gm: nn.Module, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: if not self.config.run_per_gm: - return self._apply(gm, cm, factory, shared_config) + return self._apply_to_full_model(mod, cm, factory, shared_config) # just run it on first graph module we are encountering for now... - for k, graph_sub in named_graphmodules(gm): - graph_sub, info = self._apply(graph_sub, cm, factory, shared_config) + info = TransformInfo() + for k, graph_sub in named_graphmodules(mod): + graph_sub, info_apply = self._apply(graph_sub, cm, factory, shared_config) if k == "": - gm = graph_sub + mod = graph_sub else: - gm.set_submodule(k, graph_sub) - break - return gm, info + mod.set_submodule(k, graph_sub) + info = info & info_apply + return mod, info @final def _log_info(self, *args: any): @@ -385,88 +422,56 @@ def _log_info(self, *args: any): ad_logger.info(*args) @final - def _get_autodeploy_meta(self, gm: GraphModule) -> AutodeployMeta: + def _get_autodeploy_meta(self, mod: nn.Module) -> AutodeployMeta: """Get the autodeploy metadata from the graphmodule.""" - if not hasattr(gm, "meta"): - gm.meta = {} - return gm.meta.get(self._autodeploy_meta_key, {}) + if not hasattr(mod, "meta"): + mod.meta = {} + return mod.meta.get(self._autodeploy_meta_key, {}) @final - def _set_autodeploy_meta(self, gm: GraphModule, autodeploy_meta: AutodeployMeta) -> None: + def _set_autodeploy_meta(self, mod: nn.Module, autodeploy_meta: AutodeployMeta) -> None: """Set the autodeploy metadata in the graphmodule.""" - if not hasattr(gm, "meta"): - gm.meta = {} - gm.meta[self._autodeploy_meta_key] = autodeploy_meta + if not hasattr(mod, "meta"): + mod.meta = {} + mod.meta[self._autodeploy_meta_key] = autodeploy_meta @final - def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool, bool]: + def _run_cleanup( + self, + mod: nn.Module, + clean_graph: bool, + clean_shape: bool, + is_clean: bool, + has_valid_shapes: bool, + ) -> TransformInfo: """Run graph cleanup before the transform. Args: - gm: The graph module to run cleanup on. - info: The last transform info. + mod: The model to run cleanup on. + clean_graph: Whether we want a clean graph after the transform. + clean_shape: Whether we want clean shapes after the transform. + is_clean: The current cleanup status. + has_valid_shapes: The current shape propagation status. Returns: - A tuple of (is_clean, has_valid_shapes) indicating the cleanup status after the - pre-cleanup. - - This is used to ensure the transform is applied to a clean graph as needed by the transform. + An info object indicating the cleanup status after this function is called. 
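The new ``__or__``/``__and__`` operators on ``TransformInfo`` (added above) govern how the status flags accumulate inside ``__call__``: the strict ``&`` merge is applied to the result of ``_apply``, while the lenient ``|`` merge is applied to cleanup results. A small worked example (import path inferred from the file path in this diff):

```
from tensorrt_llm._torch.auto_deploy.transform.interface import TransformInfo

info_a = TransformInfo(skipped=False, num_matches=2, is_clean=True, has_valid_shapes=False)
info_b = TransformInfo(skipped=True, num_matches=1, is_clean=False, has_valid_shapes=True)

info_a & info_b  # skipped=False, num_matches=3, is_clean=False, has_valid_shapes=False
info_a | info_b  # skipped=False, num_matches=3, is_clean=True,  has_valid_shapes=True
```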
""" - if not self.config.requires_clean_graph: - return info.is_clean, info.has_valid_shapes - - is_clean = info.is_clean - has_valid_shapes = is_clean and info.has_valid_shapes - - use_meta = isinstance(gm, GraphModule) and placeholders_on_meta(gm) - # check if run cleanup depending on the config and info - if self.config.requires_shape_prop and not has_valid_shapes: - self._log_info("running pre-cleanup with shape_prop") - canonicalize_graph(gm) - with lift_to_meta(gm) if use_meta else nullcontext(): - run_shape_prop(gm) + if clean_shape and not (is_clean and has_valid_shapes): + self._log_info("running graph cleanup (with shape_prop)") + canonicalize_graph(mod) + with lift_to_meta(mod) if placeholders_on_meta(mod) else nullcontext(): + run_shape_prop(mod) is_clean = True has_valid_shapes = True - elif self.config.requires_clean_graph and not is_clean: - self._log_info("running pre-cleanup (no shape_prop)") - canonicalize_graph(gm) + elif clean_graph and not is_clean: + self._log_info("running graph cleanup (no shape_prop)") + canonicalize_graph(mod) is_clean = True + has_valid_shapes = False - return is_clean, has_valid_shapes - - @final - def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo: - """Run graph cleanup after the transform. + return TransformInfo(is_clean=is_clean, has_valid_shapes=has_valid_shapes) - Cleanup is done as requested in the config and we will update the graph module and info - accordingly. - - Returns: - Updated TransformInfo with cleanup status. - """ - if not self.config.run_graph_cleanup: - return info - - use_meta = isinstance(gm, GraphModule) and placeholders_on_meta(gm) - - # check if run cleanup depending on the config and info - if self.config.run_shape_prop and not (info.is_clean and info.has_valid_shapes): - self._log_info("running post-cleanup with shape_prop") - canonicalize_graph(gm) - with lift_to_meta(gm) if use_meta else nullcontext(): - run_shape_prop(gm) - elif self.config.run_graph_cleanup and not info.is_clean: - self._log_info("running post-cleanup (no shape_prop)") - canonicalize_graph(gm) - - # create new info object with updated cleanup status - info_dict = info.model_dump() - info_dict["is_clean"] |= self.config.run_graph_cleanup - info_dict["has_valid_shapes"] |= self.config.run_shape_prop - return TransformInfo(**info_dict) - - @abstractmethod def _apply( self, gm: GraphModule, @@ -478,6 +483,21 @@ def _apply( This is the core method that should be implemented by subclasses. """ + raise NotImplementedError( + f"Transform {self.get_transform_key()} only supports `run_per_gm=False`." + ) + + def _apply_to_full_model( + self, + model: nn.Module, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[nn.Module, TransformInfo]: + """Apply the transform to the full model.""" + raise NotImplementedError( + f"Transform {self.get_transform_key()} only supports `run_per_gm=True`." 
+ ) class TransformRegistry: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py index 96a81dbfec7..b166c2acd6e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py @@ -2,8 +2,8 @@ from typing import Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...models import ModelFactory, hf from ...shim.interface import CachedSequenceInterface @@ -36,23 +36,20 @@ class BuildModel(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return BuildModelConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: # build the model model = factory.build_model(self.config.device) - # as wrapper to satisfy the interface we will register the model as a submodule - gm.add_module("factory_model", model) - - # by convention, we say this fake graph module is always clean + # by convention, we say the model is always clean info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) - return gm, info + return model, info @TransformRegistry.register("build_and_load_factory_model") @@ -68,22 +65,19 @@ class BuildAndLoadFactoryModel(BuildModel): config: BuildModelConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: # load model with auto sharding assert isinstance(factory, hf.AutoModelFactory), "Only HF models are supported." 
# build and load the model model = factory.build_and_load_model(self.config.device) - # as wrapper to satisfy the interface we will register the model as a submodule - gm.add_module("factory_model", model) - # this ensures that extra_args are passed in as they are received instead of enforcing the # registered extra_args cm.info.use_strict_args = False @@ -95,4 +89,4 @@ def _apply( # by convention, we say this fake graph module is always clean info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) - return gm, info + return model, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py index a77fbd3ac85..d9dff807e18 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py @@ -1,7 +1,7 @@ from typing import List, Literal, Optional, Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...compile import CompileBackendRegistry from ...models.factory import ModelFactory @@ -39,18 +39,18 @@ class CompileModel(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return CompileModelConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: cm.info.set_generate_only_batch() compiler_cls = CompileBackendRegistry.get(self.config.compile_backend) - egm_compiled = compiler_cls( - gm, + mod_compiled = compiler_cls( + mod, args=(), kwargs=cm.named_args, max_batch_size=cm.info.max_batch_size, @@ -62,4 +62,4 @@ def _apply( # store info object about the transform info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) - return egm_compiled, info + return mod_compiled, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index e92d842958f..fffb8a25a0e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -2,8 +2,8 @@ from typing import List, Optional, Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...export import torch_export_to_gm from ...models.factory import ModelFactory @@ -49,25 +49,19 @@ class ExportToGM(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return ExportToGMConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: - # at this point we assume the gm is just a dummy graph module - assert len(gm.graph.nodes) == 0, "Expected empty graph module." 
- - # retrieve the actual model from the dummy graph module - model = gm.get_submodule("factory_model") - + ) -> Tuple[nn.Module, TransformInfo]: # set the example sequence cm.info.set_example_sequence(**factory.get_example_inputs()) # export the model to a graph module gm = torch_export_to_gm( - model, + mod, args=(), kwargs=cm.named_args, dynamic_shapes=cm.named_dynamic_shapes, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index be8bf4c7b3e..945b375139d 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Tuple, Type import torch +import torch.nn as nn from pydantic import Field from torch.fx import GraphModule, Node @@ -238,13 +239,13 @@ class ResizeKVCache(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return ResizeKVCacheConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: free_mem_ratio = self.config.free_mem_ratio def _get_mem_info_in_mb(): @@ -262,7 +263,7 @@ def _get_mem_info_in_mb(): if free_mem_ratio == 0.0: self._log_info(f"Skipping cache resize for {free_mem_ratio=}") - return gm, TransformInfo( + return mod, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) @@ -274,7 +275,7 @@ def _get_mem_info_in_mb(): free_mem_pre, _ = _get_mem_info_in_mb() self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}") - gm(**cm.named_args) + mod(**cm.named_args) free_mem_post, _ = _get_mem_info_in_mb() self._log_info(f"Free memory after forward pass (MB): {free_mem_post}") @@ -310,18 +311,18 @@ def _get_mem_info_in_mb(): has_valid_shapes=True, ) - return gm, info + return mod, info @TransformRegistry.register("initialize_cache") class InitializeCache(BaseTransform): - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: num_caches = cm.initialize_caches() self._log_info(f"Initialized {num_caches} caches for cached attention") @@ -329,4 +330,4 @@ def _apply( skipped=False, num_matches=num_caches, is_clean=True, has_valid_shapes=True ) - return gm, info + return mod, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py index ed218061ef4..bcc3db30256 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py @@ -6,7 +6,8 @@ import torch import torch.fx as fx -from torch.fx import GraphModule, Node +import torch.nn as nn +from torch.fx import Graph, GraphModule, Node from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -106,23 +107,24 @@ class DetectHFAttnLayers(BaseTransform): This is achieved by running a single forward pass to profile the model. 
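Like the class below, transforms that opt out of per-GraphModule execution (``run_per_gm: false`` in the YAML configs) override ``_apply_to_full_model`` instead of ``_apply``. A minimal hypothetical skeleton of that pattern (the registry key and body are illustrative only, not part of this patch):

```
import torch.nn as nn
from tensorrt_llm._torch.auto_deploy.transform.interface import (
    BaseTransform,
    TransformInfo,
    TransformRegistry,
)


@TransformRegistry.register("my_whole_model_noop")  # hypothetical registry key
class MyWholeModelNoop(BaseTransform):
    def _apply_to_full_model(self, mod: nn.Module, cm, factory, shared_config):
        # whole-model transforms receive (and may replace) the plain nn.Module
        return mod, TransformInfo(
            skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True
        )
```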
""" - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - model = gm.factory_model - # Register profiler attn operator ALL_ATTENTION_FUNCTIONS.register("ad_profile_mha", fake_profiler_mha) + # let's start a fake graph module for making tracing/profiling easier + mod._gm = GraphModule(nn.Module(), Graph()) + # run the forward pass with the profiling function - with switch_attn_implementation(model.config, "ad_profile_mha"): + with switch_attn_implementation(mod.config, "ad_profile_mha"): # update the graph module with the fake attn nodes during the profiling run - profiling_metadata = {"gm": gm, "num_matches": 0} - model.forward(**cm.named_args, profiling_metadata=profiling_metadata) + profiling_metadata = {"gm": mod._gm, "num_matches": 0} + mod.forward(**cm.named_args, profiling_metadata=profiling_metadata) info = TransformInfo( skipped=False, @@ -131,7 +133,7 @@ def _apply( has_valid_shapes=True, ) - return gm, info + return mod, info def get_cached_attn( @@ -188,9 +190,9 @@ def cached_attn( return cached_attn -def forward_with_prepare_metadata(gm: GraphModule, **cm_kwargs): +def forward_with_prepare_metadata(mod: nn.Module, **cm_kwargs): """Run prepare_metadata as pre-processing step, add to kwargs, and then run regular forward.""" - + gm = mod._gm if hasattr(gm, "_prepare_metadata_info"): # collect args+constant args args = [cm_kwargs[k] for k in gm._prepare_metadata_info["arg_names"]] @@ -201,7 +203,7 @@ def forward_with_prepare_metadata(gm: GraphModule, **cm_kwargs): return_names = gm._prepare_metadata_info["return_names"] cm_kwargs.update({k: v for k, v in zip(return_names, metadata)}) - return gm.factory_model.forward(**cm_kwargs) + return mod._original_forward(**cm_kwargs) # TODO: how running different kv cache transforms look like? 
This one below wouldn't work if we @@ -242,28 +244,29 @@ def _insert_cached_attn_node( attn_node.meta["metadata_cache_buffer_keys"] = (*meta_nodes, *cache_nodes, *buffer_nodes) attn_node.meta["constants"] = constants - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: # switch to cached attn inputs from now cm.info.switch_to_cached_attn_inputs() - # run actual insert cached attn transform - gm, info = super()._apply(gm, cm, factory, shared_config) + # run actual insert cached attn transform with fake graph module + mod._gm, info = super()._apply(mod._gm, cm, factory, shared_config) # register cached attn operator and switch to cached forward function ALL_ATTENTION_FUNCTIONS.register("ad_cached_mha", get_cached_attn(self.attn_descriptor)) - gm.forward = types.MethodType(forward_with_prepare_metadata, gm) + mod._original_forward = mod.forward + mod.forward = types.MethodType(forward_with_prepare_metadata, mod) - # switch to cached attn implementation but _only_ for modules/configs that have a cached + # switch to cached attn implementation but _only_ for submodules/configs that have a cached # attn node (we don't want to switch to cached attn implementation for all modules) - for mod in gm.factory_model.modules(): - if hasattr(mod, "_node_ref"): - mod.config._attn_implementation = "ad_cached_mha" + for submod in mod.modules(): + if hasattr(submod, "_node_ref"): + submod.config._attn_implementation = "ad_cached_mha" # we assume graph is clean again by definition info_dict = info.model_dump() @@ -271,4 +274,4 @@ def _apply( info_dict["has_valid_shapes"] = True info = TransformInfo(**info_dict) - return gm, info + return mod, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py index 4967e6638e9..fb229bdd56c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/load_weights.py @@ -2,8 +2,8 @@ from typing import Optional, Tuple, Type +import torch.nn as nn from pydantic import Field -from torch.fx import GraphModule from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface @@ -36,22 +36,22 @@ class LoadWeightsToDevice(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return MoveDeviceConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, TransformInfo]: factory.load_or_random_init( - gm, + mod, device=self.config.adconfig_checkpoint_device or self.config.device, ) - move_to_device(gm, self.config.device) + move_to_device(mod, self.config.device) info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True) - return gm, info + return mod, info @TransformRegistry.register("move_inputs_to_device") @@ -64,15 +64,15 @@ class LoadFactoryModelWeights(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return MoveDeviceConfig - def _apply( + def _apply_to_full_model( self, - gm: GraphModule, + mod: nn.Module, cm: CachedSequenceInterface, factory: ModelFactory, shared_config: SharedConfig, - ) -> Tuple[GraphModule, TransformInfo]: + ) -> Tuple[nn.Module, 
TransformInfo]: cm.to(self.config.device) info = TransformInfo(skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True) - return gm, info + return mod, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py index 53659bf8140..31087dbd436 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/optimizer.py @@ -4,7 +4,6 @@ import torch.distributed as dist import torch.nn as nn -from torch.fx import Graph, GraphModule from ..distributed import common as dist_ad from ..models.factory import ModelFactory @@ -44,41 +43,32 @@ def _clean_config(self, config: InferenceOptimizerConfig) -> StrictInferenceOpti # return strict config return strict_config - @staticmethod - def _init_gm() -> GraphModule: - """Initialize a fake graph module. - - This is a dummy graph module that will be used to kick off the transforms. - """ - return GraphModule(nn.Module(), Graph()) - - def __call__( - self, cm: CachedSequenceInterface, gm: Optional[GraphModule] = None - ) -> GraphModule: + def __call__(self, cm: CachedSequenceInterface, mod: Optional[nn.Module] = None) -> nn.Module: """Transform a model into an optimized inference model. Args: cm: The cached sequence interface defining the sequence interface. + mod: The model to transform. Returns: - A GraphModule representing the optimized inference model. + A nn.Module representing the optimized inference model. """ ############################################################################################ # RUN THROUGH CONFIGURED TRANSFORMATIONS ############################################################################################ - # start with an empty fake graph module if not provided - if gm is None: - gm = self._init_gm() + # start with an empty model if not provided + if mod is None: + mod = nn.Module() # iterate over all transforms sorted by stage in the config for t_name, t_config in self.config.items(): # instantiate transform transform = TransformRegistry.get(t_name)(t_config) # run transform - gm = transform(gm, cm, self.factory, self.shared_config) + mod = transform(mod, cm, self.factory, self.shared_config) ############################################################################################ - # RETURN OPTIMIZED GRAPH + # RETURN OPTIMIZED MODEL ############################################################################################ - return gm + return mod diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py index 273e509714c..cea1d80219a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py @@ -132,12 +132,12 @@ def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None: gm.recompile() -def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> None: +def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None: """Move the entire graph module and all sub-GraphModules to the specified device.""" # get device device = torch.device(device) - for _, subgm in reversed(list(named_graphmodules(gm))): + for _, subgm in reversed(list(named_graphmodules(mod))): # recompile graph to update self generated codes in subgraph _move_single_gm_to_device(subgm, device) @@ -171,20 +171,20 @@ def _canonicalize_single_gm(gm: GraphModule) -> None: gm.graph.lint() -def canonicalize_graph(gm: GraphModule) -> None: +def 
canonicalize_graph(mod: nn.Module) -> None: """Canonicalize the graph of the given GraphModule. Args: - gm: The GraphModule to canonicalize. + mod: The model containing GraphModules to canonicalize. Returns: - The canonicalized (cleaned-up) GraphModule. + The canonicalized (cleaned-up) model. """ - ad_logger.debug(f"Before canonicalizing: {gm}") + ad_logger.debug(f"Before canonicalizing: {mod}") - for _, subgm in reversed(list(named_graphmodules(gm))): + for _, subgm in reversed(list(named_graphmodules(mod))): _canonicalize_single_gm(subgm) - ad_logger.debug(f"After canonicalizing: {gm}") + ad_logger.debug(f"After canonicalizing: {mod}") def _run_shape_prop_single_gm( @@ -216,7 +216,7 @@ def _run_shape_prop_single_gm( def run_shape_prop( - gm: GraphModule, + mod: nn.Module, args_static: Optional[Tuple[Any, ...]] = None, ) -> None: """Run FakeTensor-based shape propagation on the given GraphModule and its submodules. @@ -228,19 +228,19 @@ def run_shape_prop( are synthesized from the static arguments. Args: - gm: The top-level GraphModule on which to run shape propagation. All nested - GraphModules are processed in reverse topological order. + mod: The top-level model containing GraphModules on which to run shape propagation. All + nested GraphModules are processed in reverse topological order. args_static: Optional tuple of concrete tensors used to create FakeTensors when placeholder metadata is missing. Only applied to the top-level GraphModule; submodules reuse their existing placeholder metadata. """ - ad_logger.debug(f"Before running shape propagation: {gm}") + ad_logger.debug(f"Before running shape propagation: {mod}") - for _, subgm in reversed(list(named_graphmodules(gm))): - _run_shape_prop_single_gm(subgm, args_static=args_static if subgm is gm else None) + for _, subgm in reversed(list(named_graphmodules(mod))): + _run_shape_prop_single_gm(subgm, args_static=args_static if subgm is mod else None) - ad_logger.debug(f"After running shape propagation: {gm}") + ad_logger.debug(f"After running shape propagation: {mod}") def add_graph_input( @@ -309,7 +309,7 @@ def call_post_init(spec): return in_node -def placeholders_on_meta(gm: GraphModule) -> bool: +def placeholders_on_meta(mod: nn.Module) -> bool: """ Return True if every placeholder node in the graph is on the meta device. 
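These helpers now accept an arbitrary ``nn.Module`` and walk every nested ``GraphModule`` via ``named_graphmodules``. The exact helper is not shown in this patch, but it presumably behaves roughly like the following sketch:

```
import torch.nn as nn
from torch.fx import GraphModule


def named_graphmodules_sketch(mod: nn.Module):
    # named_modules() yields ("", mod) first, so the root itself is included
    # when it happens to be a GraphModule; nested GraphModules follow.
    for name, submod in mod.named_modules():
        if isinstance(submod, GraphModule):
            yield name, submod
```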
""" @@ -324,17 +324,18 @@ def _is_meta_tensor(t) -> bool: # Fallback for objects with .is_meta attribute return bool(getattr(t, "is_meta", False)) - for n in gm.graph.nodes: - if n.op != "placeholder": - continue - val = n.meta.get("val", None) + for _, subgm in reversed(list(named_graphmodules(mod))): + for n in subgm.graph.nodes: + if n.op != "placeholder": + continue + val = n.meta.get("val", None) - # If placeholder packs multiple values, find the first tensor-like leaf - t = val - if isinstance(val, (list, tuple)): - t = next((x for x in val if hasattr(x, "device") or hasattr(x, "is_meta")), None) + # If placeholder packs multiple values, find the first tensor-like leaf + t = val + if isinstance(val, (list, tuple)): + t = next((x for x in val if hasattr(x, "device") or hasattr(x, "is_meta")), None) - if not _is_meta_tensor(t): - return False + if not _is_meta_tensor(t): + return False return True diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py index 05ac3c70d22..7ffb1709cb6 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py @@ -176,7 +176,7 @@ def test_prepare_metadata_cuda(conv_env): device = conv_env["device"] b, s = 4, 6 - input_ids = torch.randint(0, 1000, (b, s), device=device) + # input_ids = torch.randint(0, 1000, (b, s), device=device) position_ids = torch.arange(s, device=device).expand(b, -1) seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32) input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32) @@ -186,7 +186,6 @@ def test_prepare_metadata_cuda(conv_env): page_size = 128 out = torch.ops.auto_deploy.cuda_causal_conv_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py index 58811d6cd70..e68dfa4f24b 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py @@ -469,7 +469,7 @@ def test_metadata_preparation(self): batch_size, seq_len_val = 4, 8 device = self.device - input_ids = torch.randint(0, 1000, (batch_size, seq_len_val), device=device) + # input_ids = torch.randint(0, 1000, (batch_size, seq_len_val), device=device) position_ids = torch.arange(seq_len_val, device=device).expand(batch_size, -1) seq_len = torch.full((batch_size,), seq_len_val, device=device, dtype=torch.int32) input_pos = torch.zeros(batch_size, device=device, dtype=torch.int32) @@ -479,7 +479,7 @@ def test_metadata_preparation(self): # Test metadata preparation result = torch.ops.auto_deploy.torch_cached_attention_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128 + position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, 128 ) # Verify result structure diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py index 502eb634dc3..4090821e252 100644 --- 
a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py @@ -168,7 +168,7 @@ def test_prepare_metadata(conv_env): device = conv_env["device"] b, s = 4, 6 - input_ids = torch.randint(0, 1000, (b, s), device=device) + # input_ids = torch.randint(0, 1000, (b, s), device=device) position_ids = torch.arange(s, device=device).expand(b, -1) seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32) input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32) @@ -178,7 +178,6 @@ def test_prepare_metadata(conv_env): page_size = 128 out = torch.ops.auto_deploy.torch_causal_conv_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py index 601c61e5f45..3000880d435 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py @@ -183,7 +183,7 @@ def test_prepare_metadata(mamba_env): device = mamba_env["device"] b, s = 4, 6 - input_ids = torch.randint(0, 1000, (b, s), device=device) + # input_ids = torch.randint(0, 1000, (b, s), device=device) position_ids = torch.arange(s, device=device).expand(b, -1) seq_len = torch.tensor([2, 1, 0, 0], device=device, dtype=torch.int32) input_pos = torch.tensor([0, 3, 0, 0], device=device, dtype=torch.int32) @@ -193,7 +193,6 @@ def test_prepare_metadata(mamba_env): page_size = 128 out = torch.ops.auto_deploy.torch_ssm_prepare_metadata( - input_ids, position_ids, seq_len, input_pos, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 65faaa0263a..07cb79270c8 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -92,6 +92,9 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): ), get_small_model_config_pytest_param( "mistralai/Mistral-Small-3.1-24B-Instruct-2503", + pytest_param_kwargs={ + "marks": pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5505835") + }, attn_backend="flashinfer", compile_backend="torch-simple", ), diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index fef458e1f7a..f154f00e408 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -163,6 +163,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): { "build_model": { "stage": "factory", + "run_per_gm": False, "device": "cuda", "run_graph_cleanup": False, "requires_clean_graph": False, @@ -170,6 +171,7 @@ def test_sdpa_with_kv_cache(dtype, attn_backend, gqa_config): "export_to_gm": { "stage": "export", "strict": False, + "run_per_gm": False, "clone_state_dict": True, "run_graph_cleanup": False, "requires_clean_graph": False, From c3465d0be2fb4ed02e01526399f5364f881b71c4 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein 
<11156568+lucaslie@users.noreply.github.com> Date: Wed, 8 Oct 2025 11:30:01 -0700 Subject: [PATCH 2/4] factory<>export interface for subgraphs Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 3 - .../custom_ops/attention_interface.py | 133 +--------------- .../_torch/auto_deploy/export/export.py | 78 ++++++++-- .../_torch/auto_deploy/export/interface.py | 20 +++ .../_torch/auto_deploy/models/__init__.py | 2 +- .../_torch/auto_deploy/models/factory.py | 98 ++++++++++-- tensorrt_llm/_torch/auto_deploy/models/hf.py | 124 +++++++++++---- .../_torch/auto_deploy/models/mistral3.py | 36 ----- .../auto_deploy/models/patches/llama4.py | 14 +- .../auto_deploy/models/patches/mistral3.py | 11 +- .../auto_deploy/models/patches/pixtral.py | 34 +++- .../auto_deploy/models/patches/starcoder.py | 2 +- .../_torch/auto_deploy/shim/ad_executor.py | 4 - .../_torch/auto_deploy/shim/interface.py | 14 +- .../transform/library/build_model.py | 4 - .../transform/library/export_to_gm.py | 147 ++++++++++++++++-- .../auto_deploy/transformations/_graph.py | 8 +- .../_utils_test/_graph_test_helpers.py | 9 +- .../_utils_test/_model_test_utils.py | 2 +- .../singlegpu/models/test_llama4_vlm_patch.py | 20 +-- .../unit/singlegpu/models/test_mistral3.py | 15 -- .../singlegpu/models/test_mistral3_patches.py | 27 ++-- .../singlegpu/test_ad_build_small_single.py | 11 +- .../transformations/library/test_kv_cache.py | 12 +- .../library/test_quantization.py | 11 +- 25 files changed, 514 insertions(+), 325 deletions(-) delete mode 100644 tensorrt_llm/_torch/auto_deploy/models/mistral3.py delete mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3.py diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 74675b725c6..ead86ade2f6 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -8,15 +8,12 @@ transforms: stage: factory run_per_gm: false device: meta - # nothing to clean up - run_graph_cleanup: false requires_clean_graph: false export_to_gm: stage: export clone_state_dict: false strict: false run_per_gm: false - run_graph_cleanup: true requires_clean_graph: false cleanup_noop_slice: stage: post_export diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 7c2d28fbce1..7cd496ce2a8 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -11,32 +11,16 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import ( - Callable, - Dict, - List, - Literal, - Optional, - Protocol, - Sequence, - Set, - Tuple, - Type, - Union, -) +from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union import torch from torch._ops import OpOverloadPacket -from torch.export import Dim from torch.fx import Node from torch.types import Number from ...._utils import nvtx_range from ..utils.logger import ad_logger -DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension -DynamicShapeCallback = Callable[[], DynamicShape] - Constant = Union[int, float, str, None] @@ -67,12 +51,6 @@ class SequenceInfo: ### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ################################################## Those are extra arguments that can be 
provided to the interface and they are stored as follows: - _extra_args: dictionary of extra arguments with currently active values. - - _extra_none_inputs: dictionary of none inputs to the extra arguments. - NOTE: we assume that extra arguments are *optional* arguments to the model. However, we - cannot represent them via `None` since fx graphs require a fixed input type. Instead, - we require a special placeholder tensor to represent the `None` input. - - _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of - the extra arguments. ### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############ - seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i) @@ -175,12 +153,6 @@ def __init__( # indicator if extra args are activated that are needed for cached attention backends self._is_cached_attn = False - # indicator how to handle the "None" input for extra args - self._use_strict_args = True - - # container for dynamic shapes - self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None - # TENSOR FIELDS ############################################################################ self._args_device: Dict[str, torch.Tensor] = { # TENSOR FIELDS FOR UNCACHED ATTENTION @@ -206,9 +178,6 @@ def __init__( # EXTRA TENSOR FIELDS ###################################################################### self._extra_args: Dict[str, Optional[torch.Tensor]] = {} - self._extra_none_inputs: Dict[str, torch.Tensor] = {} - self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None - self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {} ############################################################################################ # call reset once to set a consistent initial state @@ -218,33 +187,6 @@ def __init__( def device(self) -> torch.device: return self._args_device["input_ids"].device - @property - def use_strict_args(self) -> bool: - return self._use_strict_args - - @use_strict_args.setter - def use_strict_args(self, val: bool) -> None: - """Configure whether to use strict graph arguments only. - - Args: - val: strict graph arguments only or not. - - In strict arguments mode, - * only stock arguments (like input_ids, position_ids, etc.) or extra - arguments that are explicitly added via the ``add_extra_arg`` interface are allowed. - Other arguments that are provided in ``nest_sequences`` will be rejected and throw an - error. - * registered extra arguments that are not provided to ``nest_sequences`` will be added to - the argument list automatically using the registered None-like tensor. - - In non-strict argument mode, - * all arguments including all **kwargs that are provided to ``nest_sequences`` and will - simply be passed to the model in the order received. - * registered extra arguments that are not provided to ``nest_sequences`` will be added - _not_ be added to the argument list. - """ - self._use_strict_args = val - def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor: """Shape the tensor for the forward pass based on the current attention mode. @@ -347,36 +289,6 @@ def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]: """ return tuple(getattr(self, k) for k in self._cached_constants) - @property - def named_dynamic_shapes(self) -> Dict[str, DynamicShape]: - """Return dynamic shapes of sequence info tensors. - - NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing. 
- """ - # lazy initialization of dynamic shapes with Dim objects - if self._dynamic_shapes is None: - # set up shape for uncached args (same for all, i.e., batch_size and seq_len) - bs_seq_len_shape: DynamicShape = {} - if self.max_batch_size > 1: - bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size) - bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len) - # bs_seq_len_shape[1] = Dim.AUTO - self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names} - # cached args are static - self._dynamic_shapes.update({k: {} for k in self._cached_arg_names}) - - for k, callback in self._extra_dynamic_shapes_callbacks.items(): - if k not in self._dynamic_shapes: - self._dynamic_shapes[k] = callback() - - # return dynamic shapes according to currently active named_args with consistent order - return {k: self._dynamic_shapes[k] for k in self.named_args.keys()} - - @property - def dynamic_shapes(self) -> Tuple[DynamicShape, ...]: - """Return dynamic shapes of sequence info tensors.""" - return tuple(self.named_dynamic_shapes.values()) - @property def seq_len(self) -> List[int]: return self._args_host["seq_len"].copy() @@ -555,12 +467,11 @@ def _move_dict(d: Dict[str, torch.Tensor]) -> None: _move_dict(self._args_device) _move_dict(self._extra_args) - _move_dict(self._extra_none_inputs) def set_example_sequence( self, - input_ids: Sequence[Sequence[int]] = None, - position_ids: Optional[torch.Tensor] = None, + input_ids: Optional[Sequence[Sequence[int]]] = None, + position_ids: Optional[Sequence[Sequence[int]]] = None, **extra_args, ) -> None: """Set an example sequence useful for testing and export purposes without cache history.""" @@ -660,8 +571,6 @@ def _store_extra_arg( else: tnsr_like = tnsr_like[0] self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True) - elif self.use_strict_args: - self._extra_args[name] = self._extra_none_inputs[name] else: self._extra_args[name] = None @@ -744,15 +653,8 @@ def nest_sequences( ### UPDATE EXTRA INPUTS #################################################################### self._extra_args = {} - # in strict argument mode, we only accept registered extra arguments - if self.use_strict_args: - for name in self._extra_none_inputs.keys(): - self._store_extra_arg(name, extra_args.pop(name, None)) - assert not extra_args, f"Extra arguments {extra_args.keys()} not found" - # otherwise, we simply pass in all extra arguments - else: - for key, value in extra_args.items(): - self._store_extra_arg(key, value) + for key, value in extra_args.items(): + self._store_extra_arg(key, value) @nvtx_range("ad_rescatter_input_ids") def rescatter_input_ids( @@ -786,31 +688,6 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) return list(torch.split(t_squeezed, self.seq_len)) - def add_extra_arg( - self, - name: str, - none_input: torch.Tensor, - dynamic_shape_callback: Optional[DynamicShapeCallback] = None, - ) -> None: - """Add an extra argument to the sequence info object. - - Args: - name: The name of the extra argument. - none_input: None input value of the extra argument. - dynamic_shape_callback: The callback to get the dynamic shape of the extra argument. - - Note that the extra argument is expected to be a tensor. 
- """ - assert name not in self._named_args().keys(), f"Extra argument {name} already exists" - - self._extra_args[name] = none_input.to(self.device) - self._extra_none_inputs[name] = self._extra_args[name] - - if dynamic_shape_callback is None: - self._extra_dynamic_shapes_callbacks[name] = lambda: {} - else: - self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback - class MHACallable(Protocol): def __call__( diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py index c239c388a27..514ed1e2250 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/export.py +++ b/tensorrt_llm/_torch/auto_deploy/export/export.py @@ -3,7 +3,7 @@ from collections import defaultdict from contextlib import nullcontext from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.export as te @@ -195,6 +195,61 @@ def _clean_up_assertions(gm: fx.GraphModule): canonicalize_graph(gm) +def run_forward_for_capture( + model: nn.Module, + capture_fn: Optional[Callable[..., nn.Module]] = None, + args: Optional[Tuple[Any, ...]] = None, + kwargs: Optional[Dict[str, Any]] = None, + clone: bool = False, # clone or don't clone the model state_dict + *, + patch_configs: Optional[Dict[str, Union[dict, Any]]] = None, + patch_list: Optional[List[str]] = None, +) -> nn.Module: + """A wrapper to run the provided closure over the model on the meta device with patches. + + This utility automates the following steps for running a closure (``capture_fn``): + + 1. Provide patches for certain corner cases that might not be supported by the closure. + 2. Standardize the execution of the closure to properly run on the meta device. + + Args: + model: The model to run the closure on + capture_fn: The closure to run the model with. If not provided, run a forward pass. + args: Arguments for the model + kwargs: Keyword arguments for the model + clone: Whether to clone the model state_dict when the closure returns a different module + patch_configs: Optional patch configurations. If None, all registered patches + will be applied with default settings. + patch_list: Optional list of patch names to apply with default settings. + """ + # run capture with patches and lifted to meta + with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict: + # clean up args, kwargs and move to correct device + args, kwargs = tree_to((args or (), kwargs or {}), device="meta") + + # NOTE (lucaslie): export is VERY sensitive to the location of the inference_mode + # context manager. Do NOT move it unless absolutely necessary. + with torch.inference_mode(): + if capture_fn is None: + model(*args, **kwargs) + mod_after_capture = model + else: + mod_after_capture = capture_fn(model, args, kwargs) + + # load state_dict into egm + # NOTE: export might have removed unused params/buffers (hence we allow unexpected keys) + if mod_after_capture is not model: + load_buffers_and_params( + mod_after_capture, + state_dict, + strict_missing=True, + strict_unexpected=False, + clone=clone, + ) + + return mod_after_capture + + def torch_export_to_gm( model: nn.Module, args: Optional[Tuple[Any, ...]] = None, @@ -230,23 +285,16 @@ def torch_export_to_gm( Cannot be used together with patch_configs. 
""" - # run export with patches and lifted to meta - with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict: - # clean up args, kwargs and move to correct device - args, kwargs = tree_to((args or (), kwargs or {}), device="meta") - - # NOTE (lucaslie): export is VERY sensitive to the location of the inference_mode - # context manager. Do NOT move it unless absolutely necessary. - with torch.inference_mode(): - ep = te.export(model, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict) + def _capture_fn(model, args, kwargs): + ep = te.export(model, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict) egm = ep.module() assert isinstance(egm, fx.GraphModule) + return egm - # load state_dict into egm - # NOTE: export might have removed unused params/buffers (hence we allow unexpected keys) - load_buffers_and_params( - egm, state_dict, strict_missing=True, strict_unexpected=False, clone=clone - ) + # run capture with export + egm = run_forward_for_capture( + model, _capture_fn, args, kwargs, clone, patch_list=patch_list, patch_configs=patch_configs + ) # Export strips away all methods not traced during forward. The model could have # load hooks that contain logic for correct state_dict loading. We need to add those diff --git a/tensorrt_llm/_torch/auto_deploy/export/interface.py b/tensorrt_llm/_torch/auto_deploy/export/interface.py index db0cbbd94d6..2986ae2b1e8 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/export/interface.py @@ -35,6 +35,15 @@ class ExportPatchConfig(BaseModel): ) +class DisabledExportPatchConfig(ExportPatchConfig): + """Standard configuration for an export patch that is disabled by default.""" + + enabled: bool = Field( + default=False, + description="Whether to enable this patch.", + ) + + class BaseExportPatch(ABC): """Base class for all export patches. @@ -130,6 +139,17 @@ def _revert_patch(self): pass +class DisabledBaseExportPatch(BaseExportPatch): + """A base class for export patches that are disabled by default.""" + + config: DisabledExportPatchConfig + + @classmethod + def get_config_class(cls) -> Type[ExportPatchConfig]: + """Get the configuration class for the patch.""" + return DisabledExportPatchConfig + + class ContextManagerPatch(BaseExportPatch): """A patch that wraps an existing context manager. diff --git a/tensorrt_llm/_torch/auto_deploy/models/__init__.py b/tensorrt_llm/_torch/auto_deploy/models/__init__.py index a0e08950963..a004f7a8b13 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/models/__init__.py @@ -1,2 +1,2 @@ -from . import hf, mistral3, patches +from . 
import hf, patches from .factory import * diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index f220a49260a..84f4e307228 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -3,15 +3,19 @@ import copy from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Type, final import torch import torch.nn as nn from torch._prims_common import DeviceLikeType +from torch.export import Dim +from torch.fx import GraphModule -from ..custom_ops.attention_interface import CacheConfig, DynamicShapeCallback +from ..custom_ops.attention_interface import CacheConfig from ..utils.logger import ad_logger +DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension + class ShardingConfigSource(Enum): """Enum for factory source.""" @@ -20,6 +24,73 @@ class ShardingConfigSource(Enum): UNKNOWN = "unknown" +class SubModuleExportInfo: + """Information+configuration for exporting a submodule from a factory model.""" + + def __init__(self, submodule_name: str): + self.submodule_name = submodule_name + self._dynamic_shape_lookup: Optional[Dict[str, DynamicShape]] = None + + @property + @final + def dynamic_shape_lookup(self) -> Dict[str, DynamicShape]: + """Return the lookup for the dynamic shapes of keyword arguments. + + This property is lazy-initialized and will be computed on the first access. This is useful + for two reasons: + 1. dynamic shape objects are not picklable, so we need to compute them on the first + access in the respective subprocess. + 2. On the other hand, we do not want to initialize the symbolic integers defining the + dynamic shapes more than once to ensure that the **same** symbolic integers are used + for semantically identical shape information (e.g. there is only one dynamic batch + size). This property is used during the export process to compare the dynamic shapes of + multiple inputs. + """ + if self._dynamic_shape_lookup is None: + self._dynamic_shape_lookup = self._init_dynamic_shape_lookup() + return self._dynamic_shape_lookup + + @abstractmethod + def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]: + """Initialize the lookup for the dynamic shapes of keyword arguments.""" + raise NotImplementedError("Subclasses must implement this method.") + + @abstractmethod + def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): + """Post-process the subgraph module. + + Args: + sub_mod: The submodule from which the graph was captured+exported. + sub_gm: The graph module that was exported. + + This method can be useful if there are certain properties of the original submodule + that we may want to preserve in the exported graph module. + """ + raise NotImplementedError("Subclasses must implement this method.") + + +class FullModelExportInfo(SubModuleExportInfo): + """An export configuration for the full model. + + This is the most "trivial" export configuration as it is simply the instructions to export the + full model with the standard forward function signature of ``input_ids, position_ids``.
+ """ + + def __init__(self): + super().__init__("") + + def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]: + batch_size_dyn = Dim.DYNAMIC + seq_len_dyn = Dim.DYNAMIC + return { + "input_ids": {0: batch_size_dyn, 1: seq_len_dyn}, + "position_ids": {0: batch_size_dyn, 1: seq_len_dyn}, + } + + def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): + pass + + class ModelFactory(ABC): """An interface to return and correctly initialize a model from a desired source. @@ -248,22 +319,19 @@ def get_example_inputs(self) -> Dict[str, torch.Tensor]: """ return {} - def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]: - """Return a dictionary of extra model inputs that behave like optional forward arguments. + @abstractmethod + def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: + """Specify and return a dictionary of export configurations for the model. + + Args: + model: The model to get the export configurations for. Returns: - A dictionary of extra inputs for the model where the key corresponds to the argument - name and the value corresponds to a tuple of (none_input, dynamic_shape_callback): - - `none_input`: The none input value of the extra input indicating the tensor - value corresponding to the equivalent of the None input. `None` is not supported - as we require the input to be a tensor. Hence, this none_input acts as a - placeholder for the None input. We assume that the "optional" behavior of these - arguments can be represented via a placeholder tensor and and an appropriate - check within the forward function using ``torch.cond``. - - `dynamic_shape_callback`: A function that returns the dynamic shape of the extra - input. Simply set to `None` if the extra input is not dynamic. + A list of export configurations for the model. + + Each export config describes which submodules of the factory model should be exported and + contains relevant configuration+information for the export process. 
""" - return {} class ModelFactoryRegistry: diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index b2177fe9209..acfb117d473 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -2,6 +2,7 @@ import os import re +import types from abc import abstractmethod from contextlib import contextmanager, nullcontext from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -14,6 +15,8 @@ from huggingface_hub.utils import HFValidationError, filter_repo_objects, validate_repo_id from PIL import Image from torch._prims_common import DeviceLikeType +from torch.export import Dim +from torch.fx import GraphModule from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -30,10 +33,17 @@ WEIGHTS_NAME, ) -from ..custom_ops.attention_interface import CacheConfig, Dim, DynamicShapeCallback +from ..custom_ops.attention_interface import CacheConfig from ..utils._config import deep_merge_dicts from ..utils.logger import ad_logger -from .factory import ModelFactory, ModelFactoryRegistry, ShardingConfigSource +from .factory import ( + DynamicShape, + FullModelExportInfo, + ModelFactory, + ModelFactoryRegistry, + ShardingConfigSource, + SubModuleExportInfo, +) from .quant_config_reader import QuantConfigReader, autodetect_quant_config_reader @@ -440,6 +450,9 @@ def _remap_param_names_load_hook(self, model, state_dict, *args, **kwargs) -> No if new_key != key: state_dict[new_key] = state_dict.pop(key) + def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: + return [FullModelExportInfo()] + class _StateDictParamNameConverter: """Helper class for applying param name conversions to a state dict. @@ -480,6 +493,78 @@ def __call__(self, module, state_dict, *args, **kwargs) -> None: state_dict[new_key] = state_dict.pop(key) +class TextModelExportInfo(SubModuleExportInfo): + """An export configuration for the text model portion of a VLM.""" + + def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule): + """Post-process the subgraph module and make sure the embedding remains available.""" + # make sure get_input_embeddings function is available in the graph module + embed_tokens = sub_mod.get_input_embeddings() + sub_gm.get_input_embeddings = types.MethodType( + sub_mod.get_input_embeddings.__func__, sub_gm + ) + + # retrieve+replicate expected submodule hierarchy for where the embedding module is located + for embed_name, subsubmod in sub_mod.named_modules(): + if subsubmod is embed_tokens: + break + else: + raise RuntimeError( + "Could not find embedding module in model. Expected embedding module to be a " + "submodule of the text submodule." + ) + sub_gm.set_submodule(embed_name, embed_tokens) + + # add a dummy node to the graph for making the embedding module impure --> impure nodes + # won't be deleted from the graph during cleanup and this way we ensure that the embedding + # module is not deleted from the GraphModule either. + # TODO (lucaslie): is there a better way to make the embedding module "sticky"? 
+ n_embed_tokens = sub_gm.graph.get_attr(f"{embed_name}.weight") + sub_gm.graph.call_function( + torch._assert, args=(n_embed_tokens, "Avoid embedding getting deleted from graph.") + ) + + def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]: + batch_size_dyn = Dim.DYNAMIC + seq_len_dyn = Dim.DYNAMIC + return { + "input_ids": {0: batch_size_dyn, 1: seq_len_dyn}, + "inputs_embeds": {0: batch_size_dyn, 1: seq_len_dyn}, + "position_ids": {0: batch_size_dyn, 1: seq_len_dyn}, + } + + @classmethod + def from_autoinferred(cls, model: nn.Module) -> "TextModelExportInfo": + """Create an export configuration from the model by auto-inferring the text submodule. + + model: + The full model (AutoModelForImageTextToText) + + Returns: + An export configuration for the text submodule with the right submodule key. + + The text submodule is being auto-discovered by looking at the first submodule that contains + the ``text_config`` instead of the full config object. + """ + # retrieve expected text_config class + text_config_cls = type(model.config.text_config) + + # heuristic to identify the text submodule + submodule_key = None + for name, submodule in model.named_modules(): + if isinstance(getattr(submodule, "config", None), text_config_cls): + submodule_key = name + break + + if submodule_key is None: + raise ValueError( + "Could not find text submodule in model. Expected text submodule to have a config " + f"object of type {text_config_cls}." + ) + + return cls(submodule_key) + + @ModelFactoryRegistry.register("AutoModelForImageTextToText") class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory): _model_defaults = { @@ -518,8 +603,11 @@ def init_processor(self) -> Optional[Any]: return None return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) - def get_example_inputs(self) -> Dict[str, torch.Tensor]: - """Return a dictionary of example inputs for the model.""" + # NOTE: for now we only export text_model - hence using the default example_inputs is + # sufficient. Leaving the logic below for future reference as a special function called + # `get_example_inputs_with_images`. It's also used in unit tests at the moment. + def get_example_inputs_with_images(self) -> Dict[str, torch.Tensor]: + """Return a dictionary of example inputs for the model with images.""" def _prep_seq(text, img1, img2): return [ @@ -549,7 +637,7 @@ def _prep_seq(text, img1, img2): ), ] - processor = AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) + processor = self.init_processor() inputs = processor.apply_chat_template( batch_messages, add_generation_prompt=True, @@ -570,30 +658,12 @@ def _prep_seq(text, img1, img2): # values still need to be returned by `get_example_inputs`. return {**inputs} - def get_extra_inputs(self) -> Dict[str, Tuple[torch.Tensor, Optional[DynamicShapeCallback]]]: - """Return a dictionary of extra inputs for the model. - - Returns: - A dictionary of extra inputs for the model where the key corresponds to the argument - name and the value corresponds to a tuple of (example_input, dynamic_shape_callback). - The dynamic shape callback is a function that returns the dynamic shape of the extra - input. Simply set to `None` if the extra input is not dynamic. - """ - - def _get_dynamic_shape(): - return { - # TODO (lucaslie): how to set default values for dynamic shapes? 
- 0: Dim("img_batch_size", max=10), - 2: Dim("img_height", min=32, max=2048), - 3: Dim("img_width", min=32, max=2048), - } - - none_pixel_values = torch.zeros(0, 3, 336, 336) - return {"pixel_values": (none_pixel_values, _get_dynamic_shape)} - @property def _example_image_dims(self) -> Tuple[int, int]: # Some specializations (children) of this class may override this if their models have # assumptions on the image dimensions. For example, they may have a lower bound due to # the patch size they use. - return (16, 16) + return (64, 64) + + def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: + return [TextModelExportInfo.from_autoinferred(model)] diff --git a/tensorrt_llm/_torch/auto_deploy/models/mistral3.py b/tensorrt_llm/_torch/auto_deploy/models/mistral3.py deleted file mode 100644 index defe2cae52f..00000000000 --- a/tensorrt_llm/_torch/auto_deploy/models/mistral3.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Auto-deploy model factory for Mistral3 models.""" - -from typing import Dict, Tuple - -import torch - -from tensorrt_llm._torch.auto_deploy.custom_ops import attention_interface -from tensorrt_llm._torch.auto_deploy.models import factory, hf - - -@factory.ModelFactoryRegistry.register("Mistral3VLM") -class Mistral3VLM(hf.AutoModelForImageTextToTextFactory): - def get_extra_inputs( - self, - ) -> Dict[str, Tuple[torch.Tensor, attention_interface.DynamicShapeCallback]]: - """Return a dictionary of extra inputs for the model. - - Returns: - A dictionary of extra inputs for the model where the key corresponds to the argument - name and the value corresponds to a tuple of (example_input, dynamic_shape_callback). - The dynamic shape callback is a function that returns the dynamic shape of the extra - input. - """ - extra_inputs = super().get_extra_inputs() - # Reuse the same dynamic batch dimension for `image_sizes`. - batch_dim = extra_inputs["pixel_values"][1]()[0] - extra_inputs["image_sizes"] = (torch.zeros(0, 2, dtype=torch.long), lambda: {0: batch_dim}) - - return extra_inputs - - @property - def _example_image_dims(self) -> Tuple[int, int]: - # The pixtral processor requires a minimum image size, which is larger than the default (16, 16) - # in the parent class. - # TODO: figure this out on the model config somehow (patch size value, etc.). - return (64, 64) diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py index 2d2785bc296..f6cfb5a6365 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/llama4.py @@ -1,4 +1,8 @@ -"""A patch to handle vision branch in Llama4ForConditionalGeneration.""" +"""A patch to handle vision branch in Llama4ForConditionalGeneration. + +NOTE: most patches are not used at the moment since only text submodule is exported. Keeping it here +for future reference in case we decide to also export the image model. 
+""" from typing import List, Optional, Tuple, Union @@ -7,7 +11,7 @@ from transformers import Llama4ForConditionalGeneration from transformers.models.llama4.modeling_llama4 import Llama4CausalLMOutputWithPast, Llama4TextMoe -from ...export.interface import BaseExportPatch, ExportPatchRegistry +from ...export.interface import BaseExportPatch, DisabledBaseExportPatch, ExportPatchRegistry # Copy from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama4.py#L1651 @@ -168,8 +172,9 @@ def _no_vision_branch(inputs_embeds, pixel_values, input_ids): ) +# NOTE: registered as patch that is disabled by default since it is not used at the moment @ExportPatchRegistry.register("hf_llama4_vision") -class Llama4VisionPatch(BaseExportPatch): +class Llama4VisionPatch(DisabledBaseExportPatch): """Patch for Llama4ForConditionalGeneration to make it compatible with torch.export. This patch replaces the forward method of Llama4ForConditionalGeneration with @@ -214,7 +219,8 @@ def _moe_forward_with_transpose(self, hidden_states): # TODO: remove this patch once https://github.com/huggingface/transformers/pull/40609 is merged, -# gets released, and TRT-LLM updates to the relevant transformers version +# gets released, and TRT-LLM updates to the relevant transformers version --> this is part of +# 4.56.1 onwards. @ExportPatchRegistry.register("hf_llama4_moe") class Llama4MoEPatch(BaseExportPatch): """Patch for Llama4 MoE routing to fix its current accuracy issue.""" diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/mistral3.py b/tensorrt_llm/_torch/auto_deploy/models/patches/mistral3.py index 5b61b5a4173..092079e284e 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/mistral3.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/mistral3.py @@ -1,4 +1,8 @@ -"""A patch for the Mistral3Model to make it compatible with torch.export.""" +"""A patch for the Mistral3Model to make it compatible with torch.export. + +NOTE: most patches are not used at the moment since only text submodule is exported. Keeping it here +for future reference in case we decide to also export the image model. +""" from typing import List, Optional, Union @@ -8,7 +12,7 @@ Mistral3ModelOutputWithPast, ) -from ...export.interface import BaseExportPatch, ExportPatchRegistry +from ...export.interface import DisabledBaseExportPatch, ExportPatchRegistry def _get_image_features_flat( @@ -160,8 +164,9 @@ def _vision_branch( ) +# NOTE: registered as patch that is disabled by default since it is not used at the moment @ExportPatchRegistry.register("hf_mistral3") -class Mistral3ModelPatch(BaseExportPatch): +class Mistral3ModelPatch(DisabledBaseExportPatch): """Patch for `Mistral3Model`.""" def _apply_patch(self): diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py b/tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py index b2eec1f4c2f..3b4c16c2610 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/pixtral.py @@ -4,6 +4,9 @@ forward pass that are not compatible with `torch.export`. Note that the `register_fake` portion of the ops needs to return the shape (and dtype) of the output tensor(s) without accessing the values in the input tensors, which is where things get tricky, and why so many custom ops / patches are needed. + +NOTE: most patches are not used at the moment since only text submodule is exported. Keeping it here +for future reference in case we decide to also export the image model. 
""" import torch @@ -14,7 +17,7 @@ position_ids_in_meshgrid, ) -from ...export.interface import BaseExportPatch, ExportPatchRegistry +from ...export.interface import DisabledBaseExportPatch, ExportPatchRegistry # NOTES: # 1. Everything decorated by a `custom_op` must be type annotated. @@ -53,7 +56,7 @@ def _process_patch_embeds_meta( image_sizes: torch.Tensor, patch_size: int, hidden_size: int, - max_widht: int, + max_width: int, ): B = (image_sizes // patch_size).prod(dim=1).sum() device = patch_embeds.device @@ -210,8 +213,9 @@ def _pixtral_rms_norm_forward(self, hidden_states): return self.weight * hidden_states.to(input_dtype) +# NOTE: registered as patch that is disabled by default since it is not used at the moment @ExportPatchRegistry.register("hf_pixtral_vit") -class PixtralVisionModelPatch(BaseExportPatch): +class PixtralVisionModelPatch(DisabledBaseExportPatch): """Patch for `PixtralVisionModel`.""" def _apply_patch(self): @@ -229,3 +233,27 @@ def _revert_patch(self): PixtralVisionModel.forward = self.original_values["PixtralVisionModel.forward"] Mistral3PatchMerger.forward = self.original_values["Mistral3PatchMerger.forward"] PixtralRMSNorm.forward = self.original_values["PixtralRMSNorm.forward"] + + +# NOTE: registered as patch that is disabled by default since it is applied globally... +@ExportPatchRegistry.register("hf_pixtral_dtype") +class PixtralDtypePatch(DisabledBaseExportPatch): + """Patch for `PixtralVisionModel`.""" + + def _apply_patch(self): + """Fix the dtype of pixel_values to align with pixtral weights dtype.""" + + def _forward(mod: PixtralVisionModel, pixel_values: torch.Tensor, *args, **kwargs): + pixel_values = pixel_values.to(mod.patch_conv.weight.dtype) + return self.original_values["forward"](mod, pixel_values, *args, **kwargs) + + self.original_values["forward"] = PixtralVisionModel.forward + PixtralVisionModel.forward = _forward + + def _revert_patch(self): + """Revert the PixtralVisionModel patch.""" + PixtralVisionModel.forward = self.original_values["forward"] + + +# TODO: figure out how to properly register and apply patches like this that are global +ExportPatchRegistry.create_patch("hf_pixtral_dtype", {"enabled": True}).__enter__() diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py b/tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py index 4d28bec3d17..e888dbd0ad3 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/starcoder.py @@ -1,4 +1,4 @@ from transformers.models.starcoder2.configuration_starcoder2 import Starcoder2Config -# Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57 +# TODO: Remove this patch after TRT-LLM upgrades to the HF transformers version >= 4.57 Starcoder2Config.base_model_tp_plan["layers.*.mlp.c_proj"] = "rowwise" diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 3ab319f3aa1..dfbacf3b0e1 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -106,10 +106,6 @@ def build_from_config(cls, ad_config: AutoDeployConfig): factory = ad_config.create_factory() - # pass in extra arguments defined by the model factory - for name, (none_input, dynamic_shape_callback) in factory.get_extra_inputs().items(): - seq_info.add_extra_arg(name, none_input, dynamic_shape_callback) - # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__, # 
ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm. diff --git a/tensorrt_llm/_torch/auto_deploy/shim/interface.py b/tensorrt_llm/_torch/auto_deploy/shim/interface.py index 566fca66257..16355a9970a 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/interface.py @@ -4,7 +4,7 @@ import torch.nn as nn from torch._prims_common import DeviceLikeType -from ..custom_ops.attention_interface import DynamicShape, GetCacheCallable, SequenceInfo +from ..custom_ops.attention_interface import GetCacheCallable, SequenceInfo @final @@ -34,18 +34,6 @@ def all_future_arg_names(self) -> List[str]: """Return all the argument names owned by this interface including uninitialized caches.""" return list(self.info.named_args.keys()) + list(self._cache_initializers.keys()) - @property - def dynamic_shapes(self) -> Tuple[DynamicShape, ...]: - """Return the dynamic shapes of all graph arguments owned by this interface (all static).""" - return tuple(self.named_dynamic_shapes.values()) - - @property - def named_dynamic_shapes(self) -> Dict[str, DynamicShape]: - """Return the dynamic shapes of all graph arguments owned by this interface (all static).""" - named_dynamic_shapes = self.info.named_dynamic_shapes - named_dynamic_shapes.update({k: {} for k in self._caches}) - return named_dynamic_shapes - def to(self, *args, **kwargs) -> None: self.info.to(*args, **kwargs) if self._caches: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py index b166c2acd6e..d9c3df18182 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py @@ -78,10 +78,6 @@ def _apply_to_full_model( # build and load the model model = factory.build_and_load_model(self.config.device) - # this ensures that extra_args are passed in as they are received instead of enforcing the - # registered extra_args - cm.info.use_strict_args = False - # we set the standard example sequence WITHOUT extra_args to set them to None so that # only the text portion of the model gets called. 
cm.info.set_example_sequence() diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index fffb8a25a0e..8f7532bc8d6 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -1,11 +1,15 @@ """A simple wrapper transform to export a model to a graph module.""" -from typing import List, Optional, Tuple, Type +import inspect +from contextlib import contextmanager +from inspect import Parameter, Signature +from typing import Any, Dict, List, Optional, Tuple, Type +import torch import torch.nn as nn from pydantic import Field -from ...export import torch_export_to_gm +from ...export import run_forward_for_capture, torch_export_to_gm from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ..interface import ( @@ -39,6 +43,75 @@ class ExportToGMConfig(TransformConfig): ) + +@contextmanager +def capture_forward_kwargs(mod: nn.Module): + """Context manager to capture the keyword arguments of the forward pass of a module.""" + captured_kwargs = {} + + def _capture_kwargs(mod: nn.Module, args, kwargs) -> None: + assert not args, "positional arguments are not supported for capture" + captured_kwargs.clear() + captured_kwargs.update(kwargs) + return None + + try: + # make sure to prepend the hook so that it is called before any other hooks so we can + # capture the original inputs before other potential hooks have a chance to modify them. + # NOTE: this simulates the behavior during torch.export. + handle = mod.register_forward_pre_hook(_capture_kwargs, prepend=True, with_kwargs=True) + yield captured_kwargs + finally: + handle.remove() + + +@contextmanager +def set_exact_signature(mod: nn.Module, kwargs: Dict[str, Any]): + """Temporarily set a signature for the forward function corresponding to provided kwargs. + + Args: + mod: The module to set the signature for. + kwargs: The keyword arguments to set the signature for. + + + Within this context, the module's forward function will have a signature corresponding to only taking the provided kwargs as + keyword-only parameters (+self if it is a method).
+ """ + is_method = inspect.ismethod(mod.forward) + if is_method: + forward_func = mod.forward.__func__ + elif inspect.isfunction(mod.forward): + forward_func = mod.forward + else: + raise ValueError(f"Unsupported forward function type: {type(mod.forward)}") + + signature_inspected = inspect.signature(forward_func) + + reset_signature = False + if hasattr(forward_func, "__signature__"): + signature_attribute = mod.forward.__signature__ + reset_signature = True + + # construct signature object from kwargs + params_list = [] + if is_method: + # heuristic to identify the self parameter + param_keys = list(signature_inspected.parameters.keys()) + self_key = "self" if "self" in param_keys else param_keys[0] + params_list.append(signature_inspected.parameters[self_key].replace()) + # the rest of the parameters as keyword only + params_list.extend( + [Parameter(k, kind=Parameter.KEYWORD_ONLY, annotation=type(v)) for k, v in kwargs.items()] + ) + forward_func.__signature__ = Signature(parameters=params_list) + try: + yield + finally: + if reset_signature: + forward_func.__signature__ = signature_attribute + else: + del forward_func.__signature__ + + @TransformRegistry.register("export_to_gm") class ExportToGM(BaseTransform): """A simple wrapper transform to export a model to a graph module.""" @@ -59,18 +132,66 @@ def _apply_to_full_model( # set the example sequence cm.info.set_example_sequence(**factory.get_example_inputs()) - # export the model to a graph module - gm = torch_export_to_gm( - mod, - args=(), - kwargs=cm.named_args, - dynamic_shapes=cm.named_dynamic_shapes, - clone=self.config.clone_state_dict, - strict=self.config.strict, - patch_list=self.config.patch_list, + export_infos = factory.get_export_infos(mod) + + # check if any submodules to be exported are children of other submodules that need to be + # exported. We don't allow for this since this may imply that the submodules are not + # independent, which would conflict with graph capture logic, i.e., you cannot graph-capture + # "model" and "model.text_model" for example. However, you can export "model.text_model" and + # "model.vision_model" separately. + sub_keys = [info.submodule_name for info in export_infos] + assert all(not k1.startswith(k2) for k1 in sub_keys for k2 in sub_keys if k1 != k2), ( + f"Cannot export submodules of already exported submodules, {sub_keys=}" ) + for e_info in export_infos: + sub_mod = mod.get_submodule(e_info.submodule_name) + + # start by capturing the kwargs that are passed to the submodule for export + with capture_forward_kwargs(sub_mod) as captured_kwargs: + run_forward_for_capture( + mod, + args=(), + kwargs=cm.named_args, + clone=self.config.clone_state_dict, + patch_list=self.config.patch_list, + ) + + # construct dynamic shapes based on the captured kwargs and the dynamic shape lookup + dynamic_shapes = { + k: e_info.dynamic_shape_lookup[k] if isinstance(v, torch.Tensor) else None + for k, v in captured_kwargs.items() + } + + # export the model to a graph module. We temporarily overwrite the signature of the + # forward function to exactly match the kwargs we pass in. This is to ensure that + # torch.export's graph capture can correctly handle all inputs. Specifically, + # torch.export can get confused by keyword arguments that are not explicitly defined in + # the signature but are captured through generic **kwargs. By overwriting the signature, + # we ensure each argument is explicitly defined in the signature. 
+ with set_exact_signature(sub_mod, captured_kwargs): + sub_gm = torch_export_to_gm( + sub_mod, + args=(), + kwargs=captured_kwargs, + dynamic_shapes=dynamic_shapes, + clone=self.config.clone_state_dict, + strict=self.config.strict, + patch_list=self.config.patch_list, + ) + + # post process the sub graph module + e_info.post_process(sub_mod, sub_gm) + + # set the sub graph module + if e_info.submodule_name == "": + mod = sub_gm + else: + mod.set_submodule(e_info.submodule_name, sub_gm) + # this is a clean graph by definition since it was just exported - info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) + info = TransformInfo( + skipped=False, num_matches=len(sub_keys), is_clean=True, has_valid_shapes=True + ) - return gm, info + return mod, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py index cea1d80219a..8571ed031c8 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py @@ -311,7 +311,7 @@ def call_post_init(spec): def placeholders_on_meta(mod: nn.Module) -> bool: """ - Return True if every placeholder node in the graph is on the meta device. + Return True if any placeholder node in the graph is on the meta device. """ def _is_meta_tensor(t) -> bool: @@ -335,7 +335,7 @@ def _is_meta_tensor(t) -> bool: if isinstance(val, (list, tuple)): t = next((x for x in val if hasattr(x, "device") or hasattr(x, "is_meta")), None) - if not _is_meta_tensor(t): - return False + if _is_meta_tensor(t): + return True - return True + return False diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index b00ee2bb97a..13e8d4d0040 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -10,7 +10,11 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory +from tensorrt_llm._torch.auto_deploy.models.factory import ( + FullModelExportInfo, + ModelFactory, + SubModuleExportInfo, +) from tensorrt_llm._torch.auto_deploy.transform.library.sharding import ShardingTransformInfo @@ -37,6 +41,9 @@ def get_cache_config(self): def get_quant_config(self): return self.quant_config + def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: + return [FullModelExportInfo()] + class SequenceEmbeddingInfo(SequenceInfo): """A sequence info object for testing that replaces the input_ids with an embedding tensor. 
diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index 0c1fcc21827..95e6d1161ba 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -439,7 +439,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): }, "mistralai/Mistral-Small-3.1-24B-Instruct-2503": { "llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503", - "model_factory": "Mistral3VLM", + "model_factory": "AutoModelForImageTextToText", "compile_backend": "torch-simple", "model_kwargs": { "text_config": {"num_hidden_layers": 2}, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py index b37e0af7f0d..c8888252a56 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py @@ -57,7 +57,7 @@ def test_build_run_llama4_vlm(): pixel_values = inputs["pixel_values"] def _run_with_and_without_image(model, use_patch=True): - with apply_export_patches(patch_list=["hf_llama4_vision"] if use_patch else []): + with apply_export_patches(patch_configs={"hf_llama4_vision": {"enabled": use_patch}}): with torch.inference_mode(): out_no_images = model( input_ids=input_ids, @@ -86,15 +86,15 @@ def _run_with_and_without_image(model, use_patch=True): "position_ids": position_ids, "pixel_values": pixel_values, }, - patch_list=[ - "transformers_sdpa_mask", - "autocast_noop", - "torch_where", - "tensor_meta_device", - "sdpa_kernel_noop", - "sdpa", - "hf_llama4_vision", - ], + patch_configs={ + "transformers_sdpa_mask": {}, + "autocast_noop": {}, + "torch_where": {}, + "tensor_meta_device": {}, + "sdpa_kernel_noop": {}, + "sdpa": {}, + "hf_llama4_vision": {"enabled": True}, + }, ) move_to_device(gm, model.device) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3.py deleted file mode 100644 index 5e97ee075b0..00000000000 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3.py +++ /dev/null @@ -1,15 +0,0 @@ -from tensorrt_llm._torch.auto_deploy.models import mistral3 - - -def test_get_extra_inputs_includes_image_sizes(): - factory = mistral3.Mistral3VLM(model="test-model") - extra_inputs = factory.get_extra_inputs() - - pixel_values = extra_inputs["pixel_values"] - image_sizes = extra_inputs["image_sizes"] - - pixel_values_dynamic_shape = pixel_values[1]() - image_sizes_dynamic_shape = image_sizes[1]() - - # Unfortunately, direct object comparisons do not work. 
- assert pixel_values_dynamic_shape[0].__dict__ == image_sizes_dynamic_shape[0].__dict__ diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3_patches.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3_patches.py index 2b6df250f93..b4f8731733d 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3_patches.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_mistral3_patches.py @@ -15,7 +15,7 @@ def test_build_run_mistral3_vlm(): factory = llm_args.create_factory() model = factory.build_model("cuda") - inputs = factory.get_example_inputs() + inputs = factory.get_example_inputs_with_images() for key, value in inputs.items(): if isinstance(value, torch.Tensor): dtype = torch.bfloat16 if isinstance(value, torch.FloatTensor) else None @@ -31,7 +31,10 @@ def test_build_run_mistral3_vlm(): def _run_with_and_without_image(model, use_patch=True): with apply_export_patches( - patch_list=["hf_mistral3", "hf_pixtral_vit"] if use_patch else [] + patch_configs={ + "hf_mistral3": {"enabled": use_patch}, + "hf_pixtral_vit": {"enabled": use_patch}, + } ): with torch.inference_mode(): out_no_images = model( @@ -65,16 +68,16 @@ def _run_with_and_without_image(model, use_patch=True): "pixel_values": pixel_values, "image_sizes": image_sizes, }, - patch_list=[ - "transformers_sdpa_mask", - "autocast_noop", - "torch_where", - "tensor_meta_device", - "sdpa_kernel_noop", - "sdpa", - "hf_mistral3", - "hf_pixtral_vit", - ], + patch_configs={ + "transformers_sdpa_mask": {}, + "autocast_noop": {}, + "torch_where": {}, + "tensor_meta_device": {}, + "sdpa_kernel_noop": {}, + "sdpa": {}, + "hf_mistral3": {"enabled": True}, + "hf_pixtral_vit": {"enabled": True}, + }, ) move_to_device(gm, model.device) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py index 07cb79270c8..0af9490f613 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_build_small_single.py @@ -71,14 +71,10 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): attn_backend="torch", compile_backend="torch-simple", ), - # disabled due to https://nvbugspro.nvidia.com/bug/5505835 get_small_model_config_pytest_param( "meta-llama/Llama-4-Scout-17B-16E-Instruct", - pytest_param_kwargs={ - "marks": pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5505835") - }, attn_backend="flashinfer", - compile_backend="torch-simple", + compile_backend="torch-opt", ), get_small_model_config_pytest_param( "deepseek-ai/DeepSeek-V3", @@ -92,11 +88,8 @@ def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): ), get_small_model_config_pytest_param( "mistralai/Mistral-Small-3.1-24B-Instruct-2503", - pytest_param_kwargs={ - "marks": pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5505835") - }, attn_backend="flashinfer", - compile_backend="torch-simple", + compile_backend="torch-cudagraph", ), get_small_model_config_pytest_param( "nvidia/NVIDIA-Nemotron-Nano-12B-v2", diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index f154f00e408..f76bcf3c2ca 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -1,14 +1,19 @@ -from typing import Optional +from typing import List, Optional import pytest import torch +import torch.nn as nn from _graph_test_helpers import SequenceEmbeddingInfo from _model_test_utils import GQA from _torch_test_utils import all_close from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import CacheConfig from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory +from tensorrt_llm._torch.auto_deploy.models.factory import ( + FullModelExportInfo, + ModelFactory, + SubModuleExportInfo, +) from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer @@ -32,6 +37,9 @@ def _load_checkpoint(self, model, device): def get_cache_config(self): return self.cache_config + def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: + return [FullModelExportInfo()] + # Class that uses SDPA directly instead of the regular attention mechanism class GQAWithSdpa(GQA): diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py index 42a7fa5079b..38caf1c5dc9 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py @@ -2,6 +2,8 @@ Tests for basic graph sharding. """ +from typing import List + import pytest import torch import torch.nn as nn @@ -10,7 +12,11 @@ from _torch_test_utils import fp4_compatible, fp8_compatible from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory +from tensorrt_llm._torch.auto_deploy.models.factory import ( + FullModelExportInfo, + ModelFactory, + SubModuleExportInfo, +) from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp8_scale, pack_int4_in_uint8 @@ -31,6 +37,9 @@ def _load_checkpoint(self, model, device): def get_quant_config(self): return self.quant_config + def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: + return [FullModelExportInfo()] + @pytest.mark.parametrize( "quant_config,atol,rtol,num_p_og", From 64fca5d77d5ac6fbfd523e03217a243ae059698e Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Thu, 9 Oct 2025 12:51:25 -0700 Subject: [PATCH 3/4] reviewer feedback Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/models/hf.py | 10 +++++----- .../auto_deploy/transform/library/export_to_gm.py | 9 +++++++-- .../transform/library/kvcache_transformers.py | 2 +- .../_torch/auto_deploy/transformations/_graph.py | 3 +++ 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index acfb117d473..e88328b9796 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -525,12 +525,12 @@ def post_process(self, sub_mod: nn.Module, sub_gm: 
From 64fca5d77d5ac6fbfd523e03217a243ae059698e Mon Sep 17 00:00:00 2001
From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Date: Thu, 9 Oct 2025 12:51:25 -0700
Subject: [PATCH 3/4] reviewer feedback

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
---
 tensorrt_llm/_torch/auto_deploy/models/hf.py           | 10 +++++-----
 .../auto_deploy/transform/library/export_to_gm.py      |  9 +++++++--
 .../transform/library/kvcache_transformers.py          |  2 +-
 .../_torch/auto_deploy/transformations/_graph.py       |  3 +++
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py
index acfb117d473..e88328b9796 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/hf.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -525,12 +525,12 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule):
         )
 
     def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]:
-        batch_size_dyn = Dim.DYNAMIC
-        seq_len_dyn = Dim.DYNAMIC
+        batch_size_dynamic = Dim.DYNAMIC
+        seq_len_dynamic = Dim.DYNAMIC
         return {
-            "input_ids": {0: batch_size_dyn, 1: seq_len_dyn},
-            "inputs_embeds": {0: batch_size_dyn, 1: seq_len_dyn},
-            "position_ids": {0: batch_size_dyn, 1: seq_len_dyn},
+            "input_ids": {0: batch_size_dynamic, 1: seq_len_dynamic},
+            "inputs_embeds": {0: batch_size_dynamic, 1: seq_len_dynamic},
+            "position_ids": {0: batch_size_dynamic, 1: seq_len_dynamic},
         }
 
     @classmethod
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
index 8f7532bc8d6..5cb152f3da6 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py
@@ -88,7 +88,7 @@ def set_exact_signature(mod: nn.Module, kwargs: Dict[str, Any]):
 
     reset_signature = False
     if hasattr(forward_func, "__signature__"):
-        signature_attribute = mod.forward.__signature__
+        signature_attribute = forward_func.__signature__
         reset_signature = True
 
     # construct signature object from kwargs
@@ -139,8 +139,13 @@ def _apply_to_full_model(
         # independent, which would conflict with graph capture logic, i.e., you cannot graph-capture
         # "model" and "model.text_model" for example. However, you can export "model.text_model" and
        # "model.vision_model" separately.
+        def _is_child(child: str, parent: str) -> bool:
+            """Check if ``child`` is a child of ``parent``."""
+            # covers "a.b.c" being a child of "a.b"; parent == "" is the root, a parent of all!
+            return parent == "" or child.startswith(f"{parent}.")
+
         sub_keys = [info.submodule_name for info in export_infos]
-        assert all(not k1.startswith(k2) for k1 in sub_keys for k2 in sub_keys if k1 != k2), (
+        assert all(not _is_child(k1, k2) for k1 in sub_keys for k2 in sub_keys if k1 != k2), (
             f"Cannot export submodules of already exported submodules, {sub_keys=}"
         )
 
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
index bcc3db30256..233eb3f0d40 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
@@ -113,7 +113,7 @@ def _apply_to_full_model(
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
-    ) -> Tuple[GraphModule, TransformInfo]:
+    ) -> Tuple[nn.Module, TransformInfo]:
         # Register profiler attn operator
         ALL_ATTENTION_FUNCTIONS.register("ad_profile_mha", fake_profiler_mha)
 
diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py
index 8571ed031c8..d7c9090d0af 100644
--- a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py
+++ b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py
@@ -137,6 +137,9 @@ def move_to_device(mod: nn.Module, device: DeviceLikeType) -> None:
     # get device
     device = torch.device(device)
 
+    # move the model to the device
+    mod.to(device)
+
     for _, subgm in reversed(list(named_graphmodules(mod))):
         # recompile graph to update self generated codes in subgraph
         _move_single_gm_to_device(subgm, device)
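
The tightened duplicate-submodule check in export_to_gm.py is easiest to see in isolation. Below is a small standalone sketch: `_is_child` is copied from the hunk above, while the example submodule names are made up for illustration. The trailing dot is what prevents sibling names that merely share a prefix (e.g. "model.text_model" vs. "model.text_model2") from being flagged, which the plain `startswith` check would have done.

def _is_child(child: str, parent: str) -> bool:
    # "" denotes the root module, which is a parent of every submodule.
    return parent == "" or child.startswith(f"{parent}.")


# Hypothetical submodule names as they could appear in export_infos.
siblings = ["model.text_model", "model.vision_model"]      # exportable side by side
nested = ["model.text_model", "model.text_model.encoder"]  # rejected by the assert

for sub_keys in (siblings, nested):
    conflict = any(_is_child(k1, k2) for k1 in sub_keys for k2 in sub_keys if k1 != k2)
    print(sub_keys, "-> nested export conflict:", conflict)
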
From ea6bee7918eba57bd3c672abdb91df279e997380 Mon Sep 17 00:00:00 2001
From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Date: Thu, 9 Oct 2025 13:43:24 -0700
Subject: [PATCH 4/4] ensuring small memory footprint for parallel test execution on A30

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
---
 .../_utils_test/_model_test_utils.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
index 95e6d1161ba..aa1a144cc54 100644
--- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
+++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
@@ -435,6 +435,10 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         "llm_models_subdir": "Qwen2.5-3B-Instruct",
         "model_kwargs": {
             "num_hidden_layers": 2,
+            "hidden_size": 64,
+            "intermediate_size": 128,
+            "num_attention_heads": 4,
+            "num_key_value_heads": 2,
         },
     },
     "mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
@@ -442,8 +446,22 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
         "model_factory": "AutoModelForImageTextToText",
         "compile_backend": "torch-simple",
         "model_kwargs": {
-            "text_config": {"num_hidden_layers": 2},
-            "vision_config": {"num_hidden_layers": 2},
+            "text_config": {
+                "num_hidden_layers": 2,
+                "head_dim": 64,
+                "hidden_size": 64,
+                "intermediate_size": 128,
+                "num_attention_heads": 4,
+                "num_key_value_heads": 2,
+            },
+            "vision_config": {
+                "num_hidden_layers": 1,
+                "hidden_size": 64,
+                "head_dim": 32,
+                "image_size": 128,
+                "intermediate_size": 128,
+                "num_attention_heads": 2,
+            },
         },
     },
     "ibm-ai-platform/Bamba-9B-v2": {