From 0aed9b8480dae9556d37f14ca4083f7cbe6280b0 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Fri, 26 Sep 2025 09:09:49 -0700 Subject: [PATCH] kwargs-first pipeline Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../_torch/auto_deploy/config/default.yaml | 4 - .../auto_deploy/config/transformers.yaml | 5 -- .../custom_ops/attention_interface.py | 25 ++++-- .../_torch/auto_deploy/export/export.py | 4 +- .../_torch/auto_deploy/models/factory.py | 29 ------- tensorrt_llm/_torch/auto_deploy/models/hf.py | 46 ----------- .../_torch/auto_deploy/models/mistral3.py | 21 ----- .../_torch/auto_deploy/shim/ad_executor.py | 2 +- .../transform/library/build_model.py | 10 --- .../transform/library/compile_model.py | 3 +- .../transform/library/export_to_gm.py | 5 +- .../forward_with_cached_sequence_interface.py | 76 ------------------- .../auto_deploy/transform/library/kvcache.py | 21 +---- .../auto_deploy/transformations/_graph.py | 2 +- .../singlegpu/models/test_llama4_vlm_patch.py | 19 +++-- .../unit/singlegpu/shim/test_engine.py | 9 +-- .../transformations/library/test_kv_cache.py | 4 +- 17 files changed, 48 insertions(+), 237 deletions(-) delete mode 100644 tensorrt_llm/_torch/auto_deploy/transform/library/forward_with_cached_sequence_interface.py diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 764e6724753..37f314c0094 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -7,7 +7,6 @@ transforms: build_model: stage: factory device: meta - use_strict_forward: true # nothing to clean up run_graph_cleanup: false requires_clean_graph: false @@ -144,6 +143,3 @@ transforms: ############################################################################################ compile_model: stage: compile - forward_with_cached_sequence_interface: - stage: 
compile - args_only: true diff --git a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml index d7b532fc091..5b32f81672d 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/transformers.yaml @@ -6,7 +6,6 @@ transforms: ############################################################################################ build_and_load_factory_model: stage: factory - use_strict_forward: false ############################################################################################ # MOVE ARGUMENTS TO DEVICE ############################################################################################ @@ -24,10 +23,6 @@ transforms: stage: cache_init resize_kv_cache: stage: cache_init - args_only: false # use kwargs instead of args ############################################################################################ # COMPILE MODEL ############################################################################################ - forward_with_cached_sequence_interface: - stage: compile - args_only: false # use kwargs instead of args diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index d4f5b6dd1b3..d61bb8854e8 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -37,6 +37,8 @@ DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension DynamicShapeCallback = Callable[[], DynamicShape] +Constant = Union[int, float, str, None] + @dataclass class CacheConfig: @@ -310,12 +312,28 @@ def args(self) -> Tuple[torch.Tensor, ...]: return tuple(self.named_args.values()) @property - def const_args_for_prepare_metadata(self) -> Tuple: + def args_for_prepare_metadata(self) -> Tuple[str, ...]: + """Return a tuple of node/tensor arguments for the 
prepare_metadata op. + + The ``prepare_metadata`` interface expects the following arguments: + + 1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors. + 2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args + to the corresponding ``prepare_metadata`` node/op. + + This interface handles the tensor/node arguments part and can be used by compiler passes + like ``insert_cached_attention`` to extract the tensor/node arguments and add them to the + ``prepare_metadata`` node/op. + """ + return tuple(self.named_standard_args.keys()) + + @property + def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]: """Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op. The ``prepare_metadata`` interface expects the following arguments: - 1. ``named_standard_args`` as nodes,i.e., as input-dependent tensors. + 1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors. 2. ``const_args_for_prepare_metadata`` as constants that can directly by passed in as args to the corresponding ``prepare_metadata`` node/op. 
@@ -786,9 +804,6 @@ def add_extra_arg( self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback -Constant = Union[int, float, str, None] - - class MHACallable(Protocol): def __call__( self, diff --git a/tensorrt_llm/_torch/auto_deploy/export/export.py b/tensorrt_llm/_torch/auto_deploy/export/export.py index 9d9af3cf9e8..c239c388a27 100644 --- a/tensorrt_llm/_torch/auto_deploy/export/export.py +++ b/tensorrt_llm/_torch/auto_deploy/export/export.py @@ -197,7 +197,7 @@ def _clean_up_assertions(gm: fx.GraphModule): def torch_export_to_gm( model: nn.Module, - args: Tuple[Any, ...], + args: Optional[Tuple[Any, ...]] = None, kwargs: Optional[Dict[str, Any]] = None, clone: bool = False, # clone or don't clone the model state_dict *, @@ -233,7 +233,7 @@ def torch_export_to_gm( # run export with patches and lifted to meta with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict: # clean up args, kwargs and move to correct device - args, kwargs = tree_to((args, kwargs or {}), device="meta") + args, kwargs = tree_to((args or (), kwargs or {}), device="meta") # NOTE (lucaslie): export is VERY sensitive to the location of the inference_mode # context manager. Do NOT move it unless absolutely necessary. diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index ad1d119842f..f220a49260a 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -106,35 +106,6 @@ def _build_model(self, device: str) -> nn.Module: """Factory-specific model building logic.""" raise NotImplementedError("Subclasses must implement this method.") - def _set_strict_forward(self, model: nn.Module): - """Set the strict (args-only) forward method for the model. - - For some factories, the regular forward is sufficient. For others, this needs to be set. 
- The strict forward method should precisely define a fixed args-only, tensor-only signature - compatible with the model's forward method AND the export behavior, which requires fixed - tensor-only positional arguments. - - The function should overwrite the `model.forward` method. - - The overwritten forward should have `input_ids` and `position_ids` as initial positional - arguments as defined by the sequence interface. Hence the signature should be something like - - .. code-block:: python - - def _strict_forward( - self, input_ids: torch.Tensor, position_ids: torch.Tensor, *extra_args: torch.Tensor - ) -> Sequence[torch.Tensor]: ... - - where `extra_args` are the extra arguments that are defined by the factory and should also - be defined in the `get_extra_inputs` + `get_example_inputs` methods. The actual - `_strict_forward` method should not use `*args` or `**kwargs` but instead use the defined - extra arguments in the order they are defined. - - This is necessary as graph export is going to flatten arguments into a list of tensors and - by using a strict forward convention we simplify the export behavior and subsequent handling - of the arguments in the graph module. 
- """ - def get_quant_config(self) -> Dict: """Returns the quantization config for this model or None if not quantized.""" return {} diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index c0a17dd81ca..b2177fe9209 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -2,7 +2,6 @@ import os import re -import types from abc import abstractmethod from contextlib import contextmanager, nullcontext from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -73,19 +72,6 @@ class AutoModelFactory(ModelFactory): def automodel_cls(self) -> Type[_BaseAutoModelClass]: """Get the AutoModel class for calling from_pretrained and from_config.""" - @staticmethod - @abstractmethod - def _strict_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor): - """A strict (args-only) forward method for the model that precisely defines the signature. - - The function should contain input_ids and position_ids as positional arguments at a - minimum. Other arguments can be added as needed and must follow the correct order. - """ - - def _set_strict_forward(self, model: nn.Module): - """Set the strict (args-only) forward method for the model.""" - model.forward = types.MethodType(self._strict_forward, model) - @ModelFactoryRegistry.register("AutoModelForCausalLM") class AutoModelForCausalLMFactory(AutoModelFactory): @@ -132,16 +118,6 @@ def __init__(self, *args, **kwargs): def automodel_cls(self) -> Type[_BaseAutoModelClass]: return AutoModelForCausalLM - @staticmethod - def _strict_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor): - """A strict (args-only) forward pass for the model to functionalize the args. - - This follows the standard function signature as expected by factory.py. We do _not_ use the - model.forward method directly to create the patch. 
Instead we use the type of the model to - get the forward method to keep the patch composable with other forward patches. - """ - return type(model).forward(model, input_ids=input_ids, position_ids=position_ids) - def _recursive_update_config( self, config: PretrainedConfig, update_dict: Dict[str, Any] ) -> Tuple[PretrainedConfig, Dict[str, Any]]: @@ -542,28 +518,6 @@ def init_processor(self) -> Optional[Any]: return None return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs) - # TODO: in theory the signature could be auto-derived but it would probably require some hefty - # meta-programming to progmatically generate the functions and signature from something like the - # example inputs. And even with that we would still need to figure out how to automatically - # infer the dynamic shapes for the extra inputs. - # Alternatively, we could try to directly use the HF forward again but I am not sure whether - # this will trigger some kind of kwarg-handling inside the graph which I would want to avoid. - @staticmethod - def _strict_forward( - model: nn.Module, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - pixel_values: torch.Tensor, - ): - """A strict (args-only) forward pass for the model to functionalize the args. - - It adds pixel_values as a positional argument as expected by most - AutoModelForImageTextToText in addition to the required input_ids and position_ids. 
- """ - return type(model).forward( - model, input_ids=input_ids, position_ids=position_ids, pixel_values=pixel_values - ) - def get_example_inputs(self) -> Dict[str, torch.Tensor]: """Return a dictionary of example inputs for the model.""" diff --git a/tensorrt_llm/_torch/auto_deploy/models/mistral3.py b/tensorrt_llm/_torch/auto_deploy/models/mistral3.py index 36dd2a1408a..defe2cae52f 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/mistral3.py +++ b/tensorrt_llm/_torch/auto_deploy/models/mistral3.py @@ -28,27 +28,6 @@ def get_extra_inputs( return extra_inputs - @staticmethod - def _strict_forward( - model: torch.nn.Module, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - pixel_values: torch.Tensor, - image_sizes: torch.Tensor, - ): - """A strict (args-only) forward pass for the model to functionalize the args. - - It adds ``pixel_values`` and ``image_sizes`` as a positional argument as expected by - Mistral3Model in addition to the required ``input_ids`` and ``position_ids``. 
- """ - return type(model).forward( - model, - input_ids=input_ids, - position_ids=position_ids, - pixel_values=pixel_values, - image_sizes=image_sizes, - ) - @property def _example_image_dims(self) -> Tuple[int, int]: # The pixtral processor requires a minimum image size, which is larger than the default (16, 16) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index e0dd2a4857e..3ab319f3aa1 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -262,7 +262,7 @@ def _prepare_inputs( @nvtx_range("ad_compute_logits") def _compute_logits(self) -> List[torch.Tensor]: # run the model - logits: torch.Tensor = self.model(self.cache_seq_interface)[0] + logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0] # return a list of tensors return self.cache_seq_interface.info.unnest_sequences(logits) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py index 928e4bb7fb2..96a81dbfec7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py @@ -20,11 +20,6 @@ class BuildModelConfig(TransformConfig): """Configuration for the build model transform.""" device: str = Field(default="meta", description="The device to build the model on.") - use_strict_forward: bool = Field( - default=True, - description="If True, the forward pass will be patched to use a strict positional-only list" - " of arguments. If False, the default with **kwargs can be used.", - ) @TransformRegistry.register("build_model") @@ -51,9 +46,6 @@ def _apply( # build the model model = factory.build_model(self.config.device) - assert self.config.use_strict_forward, "Only strict forward is supported." 
- factory._set_strict_forward(model) - # as wrapper to satisfy the interface we will register the model as a submodule gm.add_module("factory_model", model) @@ -89,8 +81,6 @@ def _apply( # build and load the model model = factory.build_and_load_model(self.config.device) - assert not self.config.use_strict_forward, "Only regular forward is supported." - # as wrapper to satisfy the interface we will register the model as a submodule gm.add_module("factory_model", model) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py index a5f41433195..a77fbd3ac85 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py @@ -51,7 +51,8 @@ def _apply( compiler_cls = CompileBackendRegistry.get(self.config.compile_backend) egm_compiled = compiler_cls( gm, - args=cm.args, + args=(), + kwargs=cm.named_args, max_batch_size=cm.info.max_batch_size, **self.config.model_dump(), ).compile() diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py index 3d2d587adb4..e92d842958f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py @@ -68,8 +68,9 @@ def _apply( # export the model to a graph module gm = torch_export_to_gm( model, - args=cm.args, - dynamic_shapes=cm.dynamic_shapes, + args=(), + kwargs=cm.named_args, + dynamic_shapes=cm.named_dynamic_shapes, clone=self.config.clone_state_dict, strict=self.config.strict, patch_list=self.config.patch_list, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/forward_with_cached_sequence_interface.py b/tensorrt_llm/_torch/auto_deploy/transform/library/forward_with_cached_sequence_interface.py deleted file mode 100644 index 3daa6492a18..00000000000 --- 
a/tensorrt_llm/_torch/auto_deploy/transform/library/forward_with_cached_sequence_interface.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Transform to wrap model forward to accept a CachedSequenceInterface as single argument. - -This creates a small ``nn.Module`` wrapper so callers can run ``model(cache_seq_interface)``. -Under the hood, it invokes the original model with ``*cache_seq_interface.args`` (same as -``ad_executor.py`` does today). This enables future callers to simply pass the interface -and let the wrapper handle the argument unpacking. -""" - -from typing import Tuple, Type - -import torch.nn as nn -from pydantic import Field -from torch.fx import GraphModule - -from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory -from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface -from tensorrt_llm._torch.auto_deploy.transform.interface import ( - BaseTransform, - SharedConfig, - TransformConfig, - TransformInfo, - TransformRegistry, -) - - -class ForwardWithCSIConfig(TransformConfig): - """Configuration for the forward-with-CSI wrapper transform.""" - - args_only: bool = Field( - default=True, - description=( - "If True, the wrapper will call the underlying model with *cm.args. " - "If False, it will pass the cm object directly as the first argument." 
- ), - ) - - -class _ForwardWithCSIWrapper(nn.Module): - """A lightweight wrapper to forward with a CachedSequenceInterface argument.""" - - def __init__(self, model: nn.Module, args_only: bool = True) -> None: - super().__init__() - self.model = model - self.args_only = args_only - - def forward(self, cm: CachedSequenceInterface): # type: ignore[override] - if self.args_only: - return self.model(*cm.args) - # Fallback path with kwargs - return self.model(**cm.named_args) - - -@TransformRegistry.register("forward_with_cached_sequence_interface") -class ForwardWithCachedSequenceInterface(BaseTransform): - """Wrap the model so forward accepts a single ``CachedSequenceInterface`` argument.""" - - config: ForwardWithCSIConfig - - @classmethod - def get_config_class(cls) -> Type[TransformConfig]: - return ForwardWithCSIConfig - - def _apply( - self, - gm: GraphModule, - cm: CachedSequenceInterface, - factory: ModelFactory, - shared_config: SharedConfig, - ) -> Tuple[nn.Module, TransformInfo]: - # ``gm`` is an nn.Module (GraphModule or compiled module). Wrap it so callers can do - # ``wrapped(cm)`` and internally we expand to ``gm(*cm.args)``. - wrapped = _ForwardWithCSIWrapper(gm, args_only=self.config.args_only) - - # No graph mutation; simply return wrapped module. Mark as clean with valid shapes preserved. - info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True) - return wrapped, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 83fb2ff572b..be8bf4c7b3e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -156,12 +156,10 @@ def _apply( if cm.info.is_paged: assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op." - # filtered and sorted for SequenceInfo arguments + constants (input_ids, position_ids, etc.) 
- m_arg_keys = list(cm.info.named_standard_args.keys()) - m_const_args = cm.info.const_args_for_prepare_metadata - # insert metadata computation and extract each argument as a node - metadata_nodes = self._process_get_metadata(gm, m_arg_keys, m_const_args) + metadata_nodes = self._process_get_metadata( + gm, cm.info.args_for_prepare_metadata, cm.info.const_args_for_prepare_metadata + ) buffer_in_lookup: Dict[str, Node] = {} @@ -225,10 +223,6 @@ class ResizeKVCacheConfig(TransformConfig): free_mem_ratio: float = Field( description="The fraction of available memory to occupy.", default=0.8 ) - args_only: bool = Field( - description="Use ``*cm.args`` (default) or use ``**cm.named_args`` for the forward pass.", - default=True, - ) @TransformRegistry.register("resize_kv_cache") @@ -244,13 +238,6 @@ class ResizeKVCache(BaseTransform): def get_config_class(cls) -> Type[TransformConfig]: return ResizeKVCacheConfig - def _run_forward(self, gm: GraphModule, cm: CachedSequenceInterface): - """Run a forward pass to get the memory usage.""" - if self.config.args_only: - gm(*cm.args) - else: - gm(**cm.named_args) - def _apply( self, gm: GraphModule, @@ -287,7 +274,7 @@ def _get_mem_info_in_mb(): free_mem_pre, _ = _get_mem_info_in_mb() self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}") - self._run_forward(gm, cm) + gm(**cm.named_args) free_mem_post, _ = _get_mem_info_in_mb() self._log_info(f"Free memory after forward pass (MB): {free_mem_post}") diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py index f4f87fdb69e..bca2b41f3b7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py @@ -246,7 +246,7 @@ def run_shape_prop( def add_graph_input( gm: GraphModule, name: str, - add_kwargs: bool = False, + add_kwargs: bool = True, val: Union[Optional[torch.Tensor], _NoValType] = _NO_VAL, dynamic_shape=None, ) -> 
Node: diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py index 28b9be5caa6..b37e0af7f0d 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_llama4_vlm_patch.py @@ -19,7 +19,6 @@ def test_build_run_llama4_vlm(): factory = llm_args.create_factory() model = factory.build_model("cuda") - factory._set_strict_forward(model) processor = factory.init_processor() img1 = Image.new("RGB", (16, 16), color=(128, 128, 128)) @@ -61,14 +60,14 @@ def _run_with_and_without_image(model, use_patch=True): with apply_export_patches(patch_list=["hf_llama4_vision"] if use_patch else []): with torch.inference_mode(): out_no_images = model( - input_ids, - position_ids, - torch.zeros_like(pixel_values) if use_patch else None, + input_ids=input_ids, + position_ids=position_ids, + pixel_values=torch.zeros_like(pixel_values) if use_patch else None, ) out_with_images = model( - input_ids, - position_ids, - pixel_values, + input_ids=input_ids, + position_ids=position_ids, + pixel_values=pixel_values, ) return {"no_images": out_no_images.logits, "with_images": out_with_images.logits} @@ -82,7 +81,11 @@ def _run_with_and_without_image(model, use_patch=True): # Export to GM gm = torch_export_to_gm( model, - args=(input_ids, position_ids, pixel_values), + kwargs={ + "input_ids": input_ids, + "position_ids": position_ids, + "pixel_values": pixel_values, + }, patch_list=[ "transformers_sdpa_mask", "autocast_noop", diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py index 6e38c1d4618..a97c60d7337 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/shim/test_engine.py @@ -1,4 +1,4 @@ -from 
typing import Type, Union +from typing import Optional, Type import pytest import torch @@ -8,7 +8,6 @@ from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine from tensorrt_llm._torch.auto_deploy.shim.demollm import DemoEngine -from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface class TransformerLikeModelwithFakeCachePool(nn.Module): @@ -23,11 +22,7 @@ def __init__(self, vocab_size, embed_dim, hidden_dim): ) self.output_projection = nn.Linear(embed_dim, vocab_size) - def forward(self, cm_or_input_ids: Union[CachedSequenceInterface, torch.Tensor]): - if isinstance(cm_or_input_ids, CachedSequenceInterface): - input_ids = cm_or_input_ids.args[0] - else: - input_ids = cm_or_input_ids + def forward(self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None): embeddings = self.embedding(input_ids) hidden_states = self.mlp(embeddings) logits = self.output_projection(hidden_states) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py index b4159bd4828..d7e715398bd 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_kv_cache.py @@ -188,7 +188,7 @@ def _call_and_unnest(x, input_pos): cm.info.nest_sequences(x, input_pos=input_pos) # Use the cm.args as is - it already contains the correct position_ids - y = gm(*cm.args) + y = gm(**cm.named_args) # Unnest the output sequences return torch.stack(cm.info.unnest_sequences(y)) @@ -217,5 +217,5 @@ def _call_and_unnest(x, input_pos): assert all_close(y_model, y_with_cache, atol=atol, rtol=rtol) # Test 4: Exportability of the transformed model - exported_gm = torch_export_to_gm(gm, args=cm.args) + exported_gm = 
torch_export_to_gm(gm, args=(), kwargs=cm.named_args) assert exported_gm is not None