
Commit 0aed9b8

kwargs-first pipeline
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent ca82911 commit 0aed9b8
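
In short: rather than patching every factory model to a strict, args-only forward and threading positional tensors through export, cache init, and compile, the pipeline now passes named keyword arguments end to end. A minimal sketch of the calling-convention change (`cm` stands in for the cached sequence interface from the diffs below; the exact argument names depend on the model):

```python
# before: strict args-only convention, interface passed positionally
logits = model(cm)[0]

# after: kwargs-first, names bind the tensors
# cm.named_args is e.g. {"input_ids": ..., "position_ids": ..., **extras}
logits = model(**cm.named_args)[0]
```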

17 files changed (+48, -237 lines)

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 0 additions & 4 deletions
@@ -7,7 +7,6 @@ transforms:
   build_model:
     stage: factory
     device: meta
-    use_strict_forward: true
     # nothing to clean up
     run_graph_cleanup: false
     requires_clean_graph: false
@@ -144,6 +143,3 @@ transforms:
   ############################################################################################
   compile_model:
     stage: compile
-  forward_with_cached_sequence_interface:
-    stage: compile
-    args_only: true

tensorrt_llm/_torch/auto_deploy/config/transformers.yaml

Lines changed: 0 additions & 5 deletions
@@ -6,7 +6,6 @@ transforms:
   ############################################################################################
   build_and_load_factory_model:
     stage: factory
-    use_strict_forward: false
   ############################################################################################
   # MOVE ARGUMENTS TO DEVICE
   ############################################################################################
@@ -24,10 +23,6 @@ transforms:
     stage: cache_init
   resize_kv_cache:
     stage: cache_init
-    args_only: false # use kwargs instead of args
   ############################################################################################
   # COMPILE MODEL
   ############################################################################################
-  forward_with_cached_sequence_interface:
-    stage: compile
-    args_only: false # use kwargs instead of args

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 20 additions & 5 deletions
@@ -37,6 +37,8 @@
 DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension
 DynamicShapeCallback = Callable[[], DynamicShape]

+Constant = Union[int, float, str, None]
+

 @dataclass
 class CacheConfig:
@@ -310,12 +312,28 @@ def args(self) -> Tuple[torch.Tensor, ...]:
         return tuple(self.named_args.values())

     @property
-    def const_args_for_prepare_metadata(self) -> Tuple:
+    def args_for_prepare_metadata(self) -> Tuple[str, ...]:
+        """Return a tuple of node/tensor arguments for the prepare_metadata op.
+
+        The ``prepare_metadata`` interface expects the following arguments:
+
+        1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors.
+        2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args
+           to the corresponding ``prepare_metadata`` node/op.
+
+        This property handles the tensor/node arguments part and can be used by compiler passes
+        like ``insert_cached_attention`` to extract the constant arguments and add them to the
+        ``prepare_metadata`` node/op.
+        """
+        return tuple(self.named_standard_args.keys())
+
+    @property
+    def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
         """Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op.

         The ``prepare_metadata`` interface expects the following arguments:

-        1. ``named_standard_args`` as nodes,i.e., as input-dependent tensors.
+        1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors.
         2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args
            to the corresponding ``prepare_metadata`` node/op.
@@ -786,9 +804,6 @@ def add_extra_arg(
         self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback


-Constant = Union[int, float, str, None]
-
-
 class MHACallable(Protocol):
     def __call__(
         self,
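
The two properties split the `prepare_metadata` signature into graph-dependent tensor arguments and plain constants. A hedged sketch of how a compiler pass such as `insert_cached_attention` might combine them when building the node (the helper and its inputs are illustrative, not the actual pass code):

```python
from torch import fx


def add_prepare_metadata_node(graph: fx.Graph, interface, prepare_metadata_op):
    """Hypothetical helper: assemble a prepare_metadata call_function node.

    args_for_prepare_metadata yields the *names* of the standard tensor args;
    we resolve them to placeholder nodes and append the constants verbatim.
    """
    placeholders = {n.target: n for n in graph.nodes if n.op == "placeholder"}
    node_args = tuple(placeholders[name] for name in interface.args_for_prepare_metadata)
    const_args = interface.const_args_for_prepare_metadata
    return graph.call_function(prepare_metadata_op, args=node_args + const_args)
```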

tensorrt_llm/_torch/auto_deploy/export/export.py

Lines changed: 2 additions & 2 deletions
@@ -197,7 +197,7 @@ def _clean_up_assertions(gm: fx.GraphModule):

 def torch_export_to_gm(
     model: nn.Module,
-    args: Tuple[Any, ...],
+    args: Optional[Tuple[Any, ...]] = None,
     kwargs: Optional[Dict[str, Any]] = None,
     clone: bool = False,  # clone or don't clone the model state_dict
     *,
@@ -233,7 +233,7 @@ def torch_export_to_gm(
     # run export with patches and lifted to meta
     with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict:
         # clean up args, kwargs and move to correct device
-        args, kwargs = tree_to((args, kwargs or {}), device="meta")
+        args, kwargs = tree_to((args or (), kwargs or {}), device="meta")

         # NOTE (lucaslie): export is VERY sensitive to the location of the inference_mode
         # context manager. Do NOT move it unless absolutely necessary.
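
With `args` now optional and normalized to `()` internally, a kwargs-only export becomes the natural call pattern. A minimal sketch, assuming `torch_export_to_gm` is imported from this module; the model and inputs are illustrative:

```python
import torch
from torch import nn

model = nn.Linear(16, 16)

# no positional example inputs required anymore; named tensors suffice
gm = torch_export_to_gm(model, kwargs={"input": torch.randn(2, 16)})
```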

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 0 additions & 29 deletions
@@ -106,35 +106,6 @@ def _build_model(self, device: str) -> nn.Module:
         """Factory-specific model building logic."""
         raise NotImplementedError("Subclasses must implement this method.")

-    def _set_strict_forward(self, model: nn.Module):
-        """Set the strict (args-only) forward method for the model.
-
-        For some factories, the regular forward is sufficient. For others, this needs to be set.
-        The strict forward method should precisely define a fixed args-only, tensor-only signature
-        compatible with the model's forward method AND the export behavior, which requires fixed
-        tensor-only positional arguments.
-
-        The function should overwrite the `model.forward` method.
-
-        The overwritten forward should have `input_ids` and `position_ids` as initial positional
-        arguments as defined by the sequence interface. Hence the signature should be something like
-
-        .. code-block:: python
-
-            def _strict_forward(
-                self, input_ids: torch.Tensor, position_ids: torch.Tensor, *extra_args: torch.Tensor
-            ) -> Sequence[torch.Tensor]: ...
-
-        where `extra_args` are the extra arguments that are defined by the factory and should also
-        be defined in the `get_extra_inputs` + `get_example_inputs` methods. The actual
-        `_strict_forward` method should not use `*args` or `**kwargs` but instead use the defined
-        extra arguments in the order they are defined.
-
-        This is necessary as graph export is going to flatten arguments into a list of tensors and
-        by using a strict forward convention we simplify the export behavior and subsequent handling
-        of the arguments in the graph module.
-        """
-
     def get_quant_config(self) -> Dict:
         """Returns the quantization config for this model or None if not quantized."""
         return {}
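
For contrast with the strict args-only convention documented in the removed docstring above, a hypothetical sketch of the native, kwargs-friendly forward that the pipeline now exports directly (toy module, not factory code):

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(32, 8)

    # the native forward keeps named parameters; no strict args-only patch
    # is applied anymore, so export receives these as keyword arguments
    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor):
        return (self.emb(input_ids) + position_ids.unsqueeze(-1),)


model = TinyModel()
out = model(
    input_ids=torch.zeros(1, 4, dtype=torch.long),
    position_ids=torch.arange(4).unsqueeze(0),
)
```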

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 0 additions & 46 deletions
@@ -2,7 +2,6 @@

 import os
 import re
-import types
 from abc import abstractmethod
 from contextlib import contextmanager, nullcontext
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
@@ -73,19 +72,6 @@ class AutoModelFactory(ModelFactory):
     def automodel_cls(self) -> Type[_BaseAutoModelClass]:
         """Get the AutoModel class for calling from_pretrained and from_config."""

-    @staticmethod
-    @abstractmethod
-    def _strict_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor):
-        """A strict (args-only) forward method for the model that precisely defines the signature.
-
-        The function should contain input_ids and position_ids as positional arguments at a
-        minimum. Other arguments can be added as needed and must follow the correct order.
-        """
-
-    def _set_strict_forward(self, model: nn.Module):
-        """Set the strict (args-only) forward method for the model."""
-        model.forward = types.MethodType(self._strict_forward, model)
-

 @ModelFactoryRegistry.register("AutoModelForCausalLM")
 class AutoModelForCausalLMFactory(AutoModelFactory):
@@ -132,16 +118,6 @@ def __init__(self, *args, **kwargs):
     def automodel_cls(self) -> Type[_BaseAutoModelClass]:
         return AutoModelForCausalLM

-    @staticmethod
-    def _strict_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor):
-        """A strict (args-only) forward pass for the model to functionalize the args.
-
-        This follows the standard function signature as expected by factory.py. We do _not_ use the
-        model.forward method directly to create the patch. Instead we use the type of the model to
-        get the forward method to keep the patch composable with other forward patches.
-        """
-        return type(model).forward(model, input_ids=input_ids, position_ids=position_ids)
-
     def _recursive_update_config(
         self, config: PretrainedConfig, update_dict: Dict[str, Any]
     ) -> Tuple[PretrainedConfig, Dict[str, Any]]:
@@ -542,28 +518,6 @@ def init_processor(self) -> Optional[Any]:
             return None
         return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)

-    # TODO: in theory the signature could be auto-derived but it would probably require some hefty
-    # meta-programming to programmatically generate the functions and signature from something like
-    # the example inputs. And even with that we would still need to figure out how to automatically
-    # infer the dynamic shapes for the extra inputs.
-    # Alternatively, we could try to directly use the HF forward again but I am not sure whether
-    # this will trigger some kind of kwarg-handling inside the graph which I would want to avoid.
-    @staticmethod
-    def _strict_forward(
-        model: nn.Module,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        pixel_values: torch.Tensor,
-    ):
-        """A strict (args-only) forward pass for the model to functionalize the args.
-
-        It adds pixel_values as a positional argument as expected by most
-        AutoModelForImageTextToText models in addition to the required input_ids and position_ids.
-        """
-        return type(model).forward(
-            model, input_ids=input_ids, position_ids=position_ids, pixel_values=pixel_values
-        )
-
     def get_example_inputs(self) -> Dict[str, torch.Tensor]:
         """Return a dictionary of example inputs for the model."""
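
Note that the removed helpers called `type(model).forward(model, ...)` rather than `model.forward(...)`; as the removed docstring explains, this kept the patch composable with other forward patches. A small illustration with a hypothetical module:

```python
from torch import nn


class Base(nn.Module):
    def forward(self, x):
        return x + 1


m = Base()

# rebinding the *instance* attribute leaves type(m).forward untouched, so the
# patch wraps the original class method instead of capturing another patch
m.forward = lambda x: type(m).forward(m, x) * 2

print(m(3))  # nn.Module.__call__ -> patched instance forward -> (3 + 1) * 2 = 8
```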

tensorrt_llm/_torch/auto_deploy/models/mistral3.py

Lines changed: 0 additions & 21 deletions
@@ -28,27 +28,6 @@ def get_extra_inputs(

         return extra_inputs

-    @staticmethod
-    def _strict_forward(
-        model: torch.nn.Module,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        pixel_values: torch.Tensor,
-        image_sizes: torch.Tensor,
-    ):
-        """A strict (args-only) forward pass for the model to functionalize the args.
-
-        It adds ``pixel_values`` and ``image_sizes`` as positional arguments as expected by
-        Mistral3Model in addition to the required ``input_ids`` and ``position_ids``.
-        """
-        return type(model).forward(
-            model,
-            input_ids=input_ids,
-            position_ids=position_ids,
-            pixel_values=pixel_values,
-            image_sizes=image_sizes,
-        )
-
     @property
     def _example_image_dims(self) -> Tuple[int, int]:
         # The pixtral processor requires a minimum image size, which is larger than the default (16, 16)

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ def _prepare_inputs(
     @nvtx_range("ad_compute_logits")
     def _compute_logits(self) -> List[torch.Tensor]:
         # run the model
-        logits: torch.Tensor = self.model(self.cache_seq_interface)[0]
+        logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]

         # return a list of tensors
         return self.cache_seq_interface.info.unnest_sequences(logits)

tensorrt_llm/_torch/auto_deploy/transform/library/build_model.py

Lines changed: 0 additions & 10 deletions
@@ -20,11 +20,6 @@ class BuildModelConfig(TransformConfig):
     """Configuration for the build model transform."""

     device: str = Field(default="meta", description="The device to build the model on.")
-    use_strict_forward: bool = Field(
-        default=True,
-        description="If True, the forward pass will be patched to use a strict positional-only list"
-        " of arguments. If False, the default with **kwargs can be used.",
-    )


 @TransformRegistry.register("build_model")
@@ -51,9 +46,6 @@ def _apply(
         # build the model
         model = factory.build_model(self.config.device)

-        assert self.config.use_strict_forward, "Only strict forward is supported."
-        factory._set_strict_forward(model)
-
         # as a wrapper to satisfy the interface we will register the model as a submodule
         gm.add_module("factory_model", model)
@@ -89,8 +81,6 @@ def _apply(
         # build and load the model
         model = factory.build_and_load_model(self.config.device)

-        assert not self.config.use_strict_forward, "Only regular forward is supported."
-
         # as a wrapper to satisfy the interface we will register the model as a submodule
         gm.add_module("factory_model", model)

tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py

Lines changed: 2 additions & 1 deletion
@@ -51,7 +51,8 @@ def _apply(
         compiler_cls = CompileBackendRegistry.get(self.config.compile_backend)
         egm_compiled = compiler_cls(
             gm,
-            args=cm.args,
+            args=(),
+            kwargs=cm.named_args,
             max_batch_size=cm.info.max_batch_size,
             **self.config.model_dump(),
         ).compile()
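
Downstream, the compile backend now receives an empty positional tuple plus the named tensors. A hedged sketch of a backend constructor compatible with this call site (the class and its fields are illustrative, not the actual CompileBackendRegistry API):

```python
from typing import Any, Dict, Tuple

import torch


class EchoBackend:
    """Hypothetical backend: records args/kwargs and returns the module as-is."""

    def __init__(self, gm, args: Tuple[Any, ...], kwargs: Dict[str, torch.Tensor],
                 max_batch_size: int, **backend_config):
        self.gm, self.args, self.kwargs = gm, args, kwargs
        self.max_batch_size = max_batch_size

    def compile(self):
        # a real backend would trace/optimize gm against the kwargs-first inputs
        return self.gm
```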
