4 changes: 0 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -7,7 +7,6 @@ transforms:
build_model:
stage: factory
device: meta
use_strict_forward: true
# nothing to clean up
run_graph_cleanup: false
requires_clean_graph: false
@@ -144,6 +143,3 @@
############################################################################################
compile_model:
stage: compile
forward_with_cached_sequence_interface:
stage: compile
args_only: true
5 changes: 0 additions & 5 deletions tensorrt_llm/_torch/auto_deploy/config/transformers.yaml
@@ -6,7 +6,6 @@ transforms:
############################################################################################
build_and_load_factory_model:
stage: factory
use_strict_forward: false
############################################################################################
# MOVE ARGUMENTS TO DEVICE
############################################################################################
@@ -24,10 +23,6 @@ transforms:
stage: cache_init
resize_kv_cache:
stage: cache_init
args_only: false # use kwargs instead of args
############################################################################################
# COMPILE MODEL
############################################################################################
forward_with_cached_sequence_interface:
stage: compile
args_only: false # use kwargs instead of args
@@ -37,6 +37,8 @@
DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension
DynamicShapeCallback = Callable[[], DynamicShape]

Constant = Union[int, float, str, None]


@dataclass
class CacheConfig:
@@ -310,12 +312,28 @@ def args(self) -> Tuple[torch.Tensor, ...]:
return tuple(self.named_args.values())

@property
def const_args_for_prepare_metadata(self) -> Tuple:
def args_for_prepare_metadata(self) -> Tuple[str, ...]:
"""Return a tuple of node/tensor arguments for the prepare_metadata op.

The ``prepare_metadata`` interface expects the following arguments:

1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors.
2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args
to the corresponding ``prepare_metadata`` node/op.

This interface handles the tensor/node arguments part and can be used by compiler passes
like ``insert_cached_attention`` to extract the constant arguments and add them to the
``prepare_metadata`` node/op.
"""
return tuple(self.named_standard_args.keys())

@property
def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
"""Return a tuple of extra (const, non-tensor) arguments for the prepare_metadata op.

The ``prepare_metadata`` interface expects the following arguments:

1. ``named_standard_args`` as nodes,i.e., as input-dependent tensors.
1. ``args_for_prepare_metadata`` as nodes, i.e., as input-dependent tensors.
2. ``const_args_for_prepare_metadata`` as constants that can directly be passed in as args
to the corresponding ``prepare_metadata`` node/op.

Expand Down Expand Up @@ -786,9 +804,6 @@ def add_extra_arg(
self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback


Constant = Union[int, float, str, None]


class MHACallable(Protocol):
def __call__(
self,
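The `args_for_prepare_metadata` / `const_args_for_prepare_metadata` properties above split the `prepare_metadata` inputs into node (tensor) arguments and plain constants. A minimal sketch of how a pass like `insert_cached_attention` could combine them when wiring up the op; `cm` (the sequence interface instance) and `prepare_metadata_op` are illustrative stand-ins, not names defined in this diff:

```python
from typing import Callable

from torch import fx


def add_prepare_metadata_node(gm: fx.GraphModule, cm, prepare_metadata_op: Callable) -> fx.Node:
    """Sketch only: build a prepare_metadata call from the two interface properties."""
    graph = gm.graph
    placeholders = {n.target: n for n in graph.nodes if n.op == "placeholder"}

    # tensor arguments are referenced as graph nodes (input-dependent tensors) ...
    node_args = tuple(placeholders[name] for name in cm.args_for_prepare_metadata)

    # ... while the constants are passed straight through as plain values
    const_args = cm.const_args_for_prepare_metadata

    # insert the op after the last placeholder so all of its inputs are already defined
    with graph.inserting_after(list(placeholders.values())[-1]):
        return graph.call_function(prepare_metadata_op, args=node_args + const_args)
```

The actual pass in the repo may differ in its details; the point is only that tensor arguments become node inputs while constants are baked directly into the node's args.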
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/export/export.py
@@ -197,7 +197,7 @@ def _clean_up_assertions(gm: fx.GraphModule):

def torch_export_to_gm(
model: nn.Module,
args: Tuple[Any, ...],
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
clone: bool = False, # clone or don't clone the model state_dict
*,
@@ -233,7 +233,7 @@ def torch_export_to_gm(
# run export with patches and lifted to meta
with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict:
# clean up args, kwargs and move to correct device
args, kwargs = tree_to((args, kwargs or {}), device="meta")
args, kwargs = tree_to((args or (), kwargs or {}), device="meta")

# NOTE (lucaslie): export is VERY sensitive to the location of the inference_mode
# context manager. Do NOT move it unless absolutely necessary.
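With `args` now optional, `torch_export_to_gm` can be driven purely by keyword arguments. A small usage sketch; the toy model and inputs are made up, and the import path is assumed from the file location shown above:

```python
import torch
import torch.nn as nn

from tensorrt_llm._torch.auto_deploy.export.export import torch_export_to_gm


class ToyModel(nn.Module):
    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        return input_ids + position_ids


# kwargs-only call: ``args`` may be omitted and defaults to an empty tuple internally
gm = torch_export_to_gm(
    ToyModel(),
    kwargs={
        "input_ids": torch.zeros(1, 8, dtype=torch.long),
        "position_ids": torch.arange(8, dtype=torch.long).unsqueeze(0),
    },
)
```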
29 changes: 0 additions & 29 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -106,35 +106,6 @@ def _build_model(self, device: str) -> nn.Module:
"""Factory-specific model building logic."""
raise NotImplementedError("Subclasses must implement this method.")

def _set_strict_forward(self, model: nn.Module):
"""Set the strict (args-only) forward method for the model.

For some factories, the regular forward is sufficient. For others, this needs to be set.
The strict forward method should precisely define a fixed args-only, tensor-only signature
compatible with the model's forward method AND the export behavior, which requires fixed
tensor-only positional arguments.

The function should overwrite the `model.forward` method.

The overwritten forward should have `input_ids` and `position_ids` as initial positional
arguments as defined by the sequence interface. Hence the signature should be something like

.. code-block:: python

def _strict_forward(
self, input_ids: torch.Tensor, position_ids: torch.Tensor, *extra_args: torch.Tensor
) -> Sequence[torch.Tensor]: ...

where `extra_args` are the extra arguments that are defined by the factory and should also
be defined in the `get_extra_inputs` + `get_example_inputs` methods. The actual
`_strict_forward` method should not use `*args` or `**kwargs` but instead use the defined
extra arguments in the order they are defined.

This is necessary as graph export is going to flatten arguments into a list of tensors and
by using a strict forward convention we simplify the export behavior and subsequent handling
of the arguments in the graph module.
"""

def get_quant_config(self) -> Dict:
"""Returns the quantization config for this model or None if not quantized."""
return {}
46 changes: 0 additions & 46 deletions tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -2,7 +2,6 @@

import os
import re
import types
from abc import abstractmethod
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, List, Optional, Tuple, Type, Union
@@ -73,19 +72,6 @@ class AutoModelFactory(ModelFactory):
def automodel_cls(self) -> Type[_BaseAutoModelClass]:
"""Get the AutoModel class for calling from_pretrained and from_config."""

@staticmethod
@abstractmethod
def _strict_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor):
"""A strict (args-only) forward method for the model that precisely defines the signature.

The function should contain input_ids and position_ids as positional arguments at a
minimum. Other arguments can be added as needed and must follow the correct order.
"""

def _set_strict_forward(self, model: nn.Module):
"""Set the strict (args-only) forward method for the model."""
model.forward = types.MethodType(self._strict_forward, model)


@ModelFactoryRegistry.register("AutoModelForCausalLM")
class AutoModelForCausalLMFactory(AutoModelFactory):
@@ -132,16 +118,6 @@ def __init__(self, *args, **kwargs):
def automodel_cls(self) -> Type[_BaseAutoModelClass]:
return AutoModelForCausalLM

@staticmethod
def _strict_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: torch.Tensor):
"""A strict (args-only) forward pass for the model to functionalize the args.

This follows the standard function signature as expected by factory.py. We do _not_ use the
model.forward method directly to create the patch. Instead we use the type of the model to
get the forward method to keep the patch composable with other forward patches.
"""
return type(model).forward(model, input_ids=input_ids, position_ids=position_ids)

def _recursive_update_config(
self, config: PretrainedConfig, update_dict: Dict[str, Any]
) -> Tuple[PretrainedConfig, Dict[str, Any]]:
@@ -542,28 +518,6 @@ def init_processor(self) -> Optional[Any]:
return None
return AutoProcessor.from_pretrained(self.tokenizer, **self.tokenizer_kwargs)

# TODO: in theory the signature could be auto-derived but it would probably require some hefty
# meta-programming to programmatically generate the functions and signature from something like the
# example inputs. And even with that we would still need to figure out how to automatically
# infer the dynamic shapes for the extra inputs.
# Alternatively, we could try to directly use the HF forward again but I am not sure whether
# this will trigger some kind of kwarg-handling inside the graph which I would want to avoid.
@staticmethod
def _strict_forward(
model: nn.Module,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
pixel_values: torch.Tensor,
):
"""A strict (args-only) forward pass for the model to functionalize the args.

It adds pixel_values as a positional argument as expected by most
AutoModelForImageTextToText in addition to the required input_ids and position_ids.
"""
return type(model).forward(
model, input_ids=input_ids, position_ids=position_ids, pixel_values=pixel_values
)

def get_example_inputs(self) -> Dict[str, torch.Tensor]:
"""Return a dictionary of example inputs for the model."""

21 changes: 0 additions & 21 deletions tensorrt_llm/_torch/auto_deploy/models/mistral3.py
@@ -28,27 +28,6 @@ def get_extra_inputs(

return extra_inputs

@staticmethod
def _strict_forward(
model: torch.nn.Module,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
pixel_values: torch.Tensor,
image_sizes: torch.Tensor,
):
"""A strict (args-only) forward pass for the model to functionalize the args.

It adds ``pixel_values`` and ``image_sizes`` as a positional argument as expected by
Mistral3Model in addition to the required ``input_ids`` and ``position_ids``.
"""
return type(model).forward(
model,
input_ids=input_ids,
position_ids=position_ids,
pixel_values=pixel_values,
image_sizes=image_sizes,
)

@property
def _example_image_dims(self) -> Tuple[int, int]:
# The pixtral processor requires a minimum image size, which is larger than the default (16, 16)
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -262,7 +262,7 @@ def _prepare_inputs(
@nvtx_range("ad_compute_logits")
def _compute_logits(self) -> List[torch.Tensor]:
# run the model
logits: torch.Tensor = self.model(self.cache_seq_interface)[0]
logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]

# return a list of tensors
return self.cache_seq_interface.info.unnest_sequences(logits)
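For context, `named_args` maps argument names to tensors, so the new call is an ordinary keyword invocation of the compiled model. A toy illustration; both the model stand-in and the values are made up:

```python
import torch


# stand-in for the compiled model held by the executor
def model(input_ids: torch.Tensor, position_ids: torch.Tensor):
    return (input_ids + position_ids,)  # first element plays the role of the logits


# stand-in for self.cache_seq_interface.named_args (name -> tensor mapping)
named_args = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "position_ids": torch.tensor([[0, 1, 2]]),
}

logits = model(**named_args)[0]  # same calling pattern as _compute_logits above
```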
@@ -20,11 +20,6 @@ class BuildModelConfig(TransformConfig):
"""Configuration for the build model transform."""

device: str = Field(default="meta", description="The device to build the model on.")
use_strict_forward: bool = Field(
default=True,
description="If True, the forward pass will be patched to use a strict positional-only list"
" of arguments. If False, the default with **kwargs can be used.",
)


@TransformRegistry.register("build_model")
@@ -51,9 +46,6 @@ def _apply(
# build the model
model = factory.build_model(self.config.device)

assert self.config.use_strict_forward, "Only strict forward is supported."
factory._set_strict_forward(model)

# as wrapper to satisfy the interface we will register the model as a submodule
gm.add_module("factory_model", model)

@@ -89,8 +81,6 @@ def _apply(
# build and load the model
model = factory.build_and_load_model(self.config.device)

assert not self.config.use_strict_forward, "Only regular forward is supported."

# as wrapper to satisfy the interface we will register the model as a submodule
gm.add_module("factory_model", model)

@@ -51,7 +51,8 @@ def _apply(
compiler_cls = CompileBackendRegistry.get(self.config.compile_backend)
egm_compiled = compiler_cls(
gm,
args=cm.args,
args=(),
kwargs=cm.named_args,
max_batch_size=cm.info.max_batch_size,
**self.config.model_dump(),
).compile()
@@ -68,8 +68,9 @@ def _apply(
# export the model to a graph module
gm = torch_export_to_gm(
model,
args=cm.args,
dynamic_shapes=cm.dynamic_shapes,
args=(),
kwargs=cm.named_args,
dynamic_shapes=cm.named_dynamic_shapes,
clone=self.config.clone_state_dict,
strict=self.config.strict,
patch_list=self.config.patch_list,

This file was deleted.
