12 changes: 8 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -6,16 +6,14 @@ transforms:
############################################################################################
build_model:
stage: factory
run_per_gm: false
device: meta
# nothing to clean up
run_graph_cleanup: false
requires_clean_graph: false
export_to_gm:
stage: export
clone_state_dict: false
strict: false
# nothing to clean up
run_graph_cleanup: false
run_per_gm: false
requires_clean_graph: false
cleanup_noop_slice:
stage: post_export
@@ -35,6 +33,7 @@ transforms:
run_shape_prop: true
match_eager_attention:
stage: pattern_matcher
requires_shape_prop: true
match_grouped_attention:
stage: pattern_matcher
match_attention_layout:
@@ -87,8 +86,10 @@ transforms:
############################################################################################
load_weights:
stage: weight_load
run_per_gm: false
move_inputs_to_device:
stage: weight_load
run_per_gm: false
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
@@ -138,10 +139,13 @@ transforms:
attn_backend: cuda_causal_conv
initialize_cache:
stage: cache_init
run_per_gm: false
resize_kv_cache:
stage: cache_init
run_per_gm: false
############################################################################################
# COMPILE MODEL
############################################################################################
compile_model:
stage: compile
run_per_gm: false
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/transformers.yaml
@@ -6,23 +6,29 @@ transforms:
############################################################################################
build_and_load_factory_model:
stage: factory
run_per_gm: false
############################################################################################
# MOVE ARGUMENTS TO DEVICE
############################################################################################
move_inputs_to_device:
stage: weight_load
run_per_gm: false
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
detect_hf_attn_layers:
stage: cache_init
run_per_gm: false
transformers_replace_cached_attn:
stage: cache_init
attn_backend: flashinfer
run_per_gm: false
initialize_cache:
stage: cache_init
run_per_gm: false
resize_kv_cache:
stage: cache_init
run_per_gm: false
############################################################################################
# COMPILE MODEL
############################################################################################
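Across both configs, several transforms (factory build, export, weight loading, device moves, cache init/resize, and compile) are now tagged with `run_per_gm: false`, and `match_eager_attention` additionally gains `requires_shape_prop: true`, presumably signaling that shape propagation must run before that pattern matcher. The following is a minimal sketch of how a per-transform `run_per_gm` flag could gate whether a transform runs once on the whole model or once per captured `torch.fx.GraphModule`; the names below are hypothetical and not the actual auto_deploy API:

```python
# Hypothetical sketch (not the actual auto_deploy API): how a per-transform
# `run_per_gm` flag could decide whether a transform runs once on the whole
# model or once per captured torch.fx.GraphModule inside it.
from dataclasses import dataclass
from typing import Callable

import torch.nn as nn
from torch.fx import GraphModule


@dataclass
class TransformConfig:  # stand-in for one entry of the YAML files above
    stage: str
    run_per_gm: bool = True  # assumption: per-GraphModule application is the default


def apply_transform(
    model: nn.Module,
    cfg: TransformConfig,
    transform_fn: Callable[[nn.Module], None],
) -> None:
    """Apply `transform_fn` according to the transform's `run_per_gm` setting."""
    if cfg.run_per_gm:
        # graph-level transforms visit every captured GraphModule
        for submod in model.modules():
            if isinstance(submod, GraphModule):
                transform_fn(submod)
    else:
        # factory / weight-load / cache-init / compile style transforms
        # operate on the model as a whole
        transform_fn(model)
```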
155 changes: 19 additions & 136 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -11,32 +11,16 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Type,
Union,
)
from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union

import torch
from torch._ops import OpOverloadPacket
from torch.export import Dim
from torch.fx import Node
from torch.types import Number

from ...._utils import nvtx_range
from ..utils.logger import ad_logger

DynamicShape = Dict[int, Dim] # indicating the dynamic shape in tensor dimension
DynamicShapeCallback = Callable[[], DynamicShape]

Constant = Union[int, float, str, None]


@@ -67,12 +51,6 @@ class SequenceInfo:
### EXTRA ARGUMENTS PROVIDED TO THE INTERFACE ##################################################
Those are extra arguments that can be provided to the interface and they are stored as follows:
- _extra_args: dictionary of extra arguments with currently active values.
- _extra_none_inputs: dictionary of none inputs to the extra arguments.
NOTE: we assume that extra arguments are *optional* arguments to the model. However, we
cannot represent them via `None` since fx graphs require a fixed input type. Instead,
we require a special placeholder tensor to represent the `None` input.
- _extra_dynamic_shapes_callbacks: dictionary of callbacks to initialize the dynamic shapes of
the extra arguments.

### CACHE ARGUMENTS NEEDED FOR ATTENTION OPERATORS FOR FLATTENED SEQUENCES + CACHES ############
- seq_len: [s_0, s_1, ..., s_{b-1}] such that s_total = sum(s_i)
@@ -175,12 +153,6 @@ def __init__(
# indicator if extra args are activated that are needed for cached attention backends
self._is_cached_attn = False

# indicator how to handle the "None" input for extra args
self._use_strict_args = True

# container for dynamic shapes
self._dynamic_shapes: Optional[Dict[str, DynamicShape]] = None

# TENSOR FIELDS ############################################################################
self._args_device: Dict[str, torch.Tensor] = {
# TENSOR FIELDS FOR UNCACHED ATTENTION
@@ -206,9 +178,6 @@ def __init__(

# EXTRA TENSOR FIELDS ######################################################################
self._extra_args: Dict[str, Optional[torch.Tensor]] = {}
self._extra_none_inputs: Dict[str, torch.Tensor] = {}
self._extra_dynamic_shapes: Optional[Dict[str, DynamicShape]] = None
self._extra_dynamic_shapes_callbacks: Dict[str, DynamicShapeCallback] = {}
############################################################################################

# call reset once to set a consistent initial state
@@ -218,33 +187,6 @@ def __init__(
def device(self) -> torch.device:
return self._args_device["input_ids"].device

@property
def use_strict_args(self) -> bool:
return self._use_strict_args

@use_strict_args.setter
def use_strict_args(self, val: bool) -> None:
"""Configure whether to use strict graph arguments only.

Args:
val: strict graph arguments only or not.

In strict arguments mode,
* only stock arguments (like input_ids, position_ids, etc.) or extra
arguments that are explicitly added via the ``add_extra_arg`` interface are allowed.
Other arguments that are provided in ``nest_sequences`` will be rejected and throw an
error.
* registered extra arguments that are not provided to ``nest_sequences`` will be added to
the argument list automatically using the registered None-like tensor.

In non-strict argument mode,
* all arguments, including all **kwargs provided to ``nest_sequences``, will simply be
passed to the model in the order received.
* registered extra arguments that are not provided to ``nest_sequences`` will _not_ be added
to the argument list.
"""
self._use_strict_args = val

def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:
"""Shape the tensor for the forward pass based on the current attention mode.

@@ -325,7 +267,11 @@ def args_for_prepare_metadata(self) -> Tuple[str, ...]:
like ``insert_cached_attention`` to extract the constant arguments and add them to the
``prepare_metadata`` node/op.
"""
return tuple(self.named_standard_args.keys())
# NOTE: for now we do _not_ include input_ids since we are not guaranteed that input_ids
# is part of the graph, e.g., in situations where the graph is a submodule of the overall
# model. In such instances, the graph usually sees inputs_embeds. However, we assume for
# now that position_ids is always part of the graph.
return ("position_ids",) + self._cached_arg_names

@property
def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
@@ -343,36 +289,6 @@ def const_args_for_prepare_metadata(self) -> Tuple[Constant, ...]:
"""
return tuple(getattr(self, k) for k in self._cached_constants)

@property
def named_dynamic_shapes(self) -> Dict[str, DynamicShape]:
"""Return dynamic shapes of sequence info tensors.

NOTE: will be lazily initialized since the Dim object is not picklable for multi-processing.
"""
# lazy initialization of dynamic shapes with Dim objects
if self._dynamic_shapes is None:
# set up shape for uncached args (same for all, i.e., batch_size and seq_len)
bs_seq_len_shape: DynamicShape = {}
if self.max_batch_size > 1:
bs_seq_len_shape[0] = Dim("batch_size", max=self.max_batch_size)
bs_seq_len_shape[1] = Dim("seq_len", max=self.max_seq_len)
# bs_seq_len_shape[1] = Dim.AUTO
self._dynamic_shapes = {k: bs_seq_len_shape for k in self._uncached_arg_names}
# cached args are static
self._dynamic_shapes.update({k: {} for k in self._cached_arg_names})

for k, callback in self._extra_dynamic_shapes_callbacks.items():
if k not in self._dynamic_shapes:
self._dynamic_shapes[k] = callback()

# return dynamic shapes according to currently active named_args with consistent order
return {k: self._dynamic_shapes[k] for k in self.named_args.keys()}

@property
def dynamic_shapes(self) -> Tuple[DynamicShape, ...]:
"""Return dynamic shapes of sequence info tensors."""
return tuple(self.named_dynamic_shapes.values())

@property
def seq_len(self) -> List[int]:
return self._args_host["seq_len"].copy()
@@ -466,7 +382,9 @@ def _get_cache_locations_and_pages_per_sequence(
return cache_loc_flat, pages_per_seq

@classmethod
def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
def _get_sanitized_seq_len(
cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> torch.Tensor:
"""Sanitize sequence lengths.

We want to cover the following scenarios with this function:
@@ -499,22 +417,24 @@ def _get_sanitized_seq_len(cls, input_ids: torch.Tensor, seq_len: torch.Tensor)
# valid cache location in the batch. This would ensure that the dummy sequences just
# repeat valid computation...
"""
_, s = input_ids.shape[:2]
num_seq = cls._get_sanitized_num_sequences(input_ids, seq_len)
_, s = input_or_position_ids.shape[:2]
num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
if s > 1:
return seq_len[:num_seq].detach().clone()
else:
return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)

@staticmethod
def _get_sanitized_num_sequences(input_ids: torch.Tensor, seq_len: torch.Tensor) -> int:
def _get_sanitized_num_sequences(
input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
) -> int:
"""Get number of sequences.

We make sure that this function is compatible with both torch graph capture and cudagraph.
Both can be a bit temperamental when trying to extract the number of sequences from a tensor
with max_batch_size or max_batch_size*max_seq_len.
"""
b, s = input_ids.shape[:2]
b, s = input_or_position_ids.shape[:2]
if s > 1:
num_seq = torch.sum(seq_len > 0)
assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
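Both sanitization helpers now accept either `input_ids` or `position_ids` (anything shaped `[b, s, ...]`). For the context phase (`s > 1`), the number of active sequences is recovered by counting the non-zero entries of the zero-padded `seq_len`, and the sanitized lengths are simply the first `num_seq` entries; in the generate phase, `_get_sanitized_seq_len` above falls back to a vector of ones. A small self-contained example of the `s > 1` path:

```python
# Worked example of the context-phase (s > 1) path shown above: seq_len is
# zero-padded to max_batch_size and the active sequences are counted first.
import torch

position_ids = torch.arange(7).unsqueeze(0)   # flattened sequences, shape (1, 7)
seq_len = torch.tensor([3, 4, 0, 0])          # zero-padded to max_batch_size = 4

_, s = position_ids.shape[:2]
assert s > 1
num_seq = int(torch.sum(seq_len > 0))         # -> 2
assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"

seq_len_sanitized = seq_len[:num_seq].detach().clone()
print(num_seq, seq_len_sanitized)             # 2 tensor([3, 4])
```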
@@ -547,12 +467,11 @@ def _move_dict(d: Dict[str, torch.Tensor]) -> None:

_move_dict(self._args_device)
_move_dict(self._extra_args)
_move_dict(self._extra_none_inputs)

def set_example_sequence(
self,
input_ids: Sequence[Sequence[int]] = None,
position_ids: Optional[torch.Tensor] = None,
input_ids: Optional[Sequence[Sequence[int]]] = None,
position_ids: Optional[Sequence[Sequence[int]]] = None,
**extra_args,
) -> None:
"""Set an example sequence useful for testing and export purposes without cache history."""
@@ -652,8 +571,6 @@ def _store_extra_arg(
else:
tnsr_like = tnsr_like[0]
self._extra_args[name] = tnsr_like.to(self.device, non_blocking=True)
elif self.use_strict_args:
self._extra_args[name] = self._extra_none_inputs[name]
else:
self._extra_args[name] = None

@@ -736,15 +653,8 @@ def nest_sequences(

### UPDATE EXTRA INPUTS ####################################################################
self._extra_args = {}
# in strict argument mode, we only accept registered extra arguments
if self.use_strict_args:
for name in self._extra_none_inputs.keys():
self._store_extra_arg(name, extra_args.pop(name, None))
assert not extra_args, f"Extra arguments {extra_args.keys()} not found"
# otherwise, we simply pass in all extra arguments
else:
for key, value in extra_args.items():
self._store_extra_arg(key, value)
for key, value in extra_args.items():
self._store_extra_arg(key, value)
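With the strict-argument mode gone, `nest_sequences` simply forwards every extra keyword it receives to `_store_extra_arg`: tensor-like values are moved to the device, anything else is recorded as `None`, and no placeholder tensor is substituted anymore. A standalone sketch of that dispatch, with a simplified signature and hypothetical argument names for illustration:

```python
# Standalone sketch of the simplified extra-argument dispatch (assumption: it
# mirrors the _store_extra_arg hunk above, minus the handling of nested
# sequences of tensors, which is elided here).
from typing import Dict, Optional

import torch


def store_extra_arg(
    extra_args: Dict[str, Optional[torch.Tensor]],
    name: str,
    value,
    device: torch.device,
) -> None:
    if isinstance(value, torch.Tensor):
        # tensor-like extra inputs are moved to the target device asynchronously
        extra_args[name] = value.to(device, non_blocking=True)
    else:
        # no registered None-placeholder anymore; missing/non-tensor values stay None
        extra_args[name] = None


extra: Dict[str, Optional[torch.Tensor]] = {}
store_extra_arg(extra, "pixel_values", torch.randn(1, 3, 8, 8), torch.device("cpu"))  # hypothetical name
store_extra_arg(extra, "image_sizes", None, torch.device("cpu"))                      # stored as None
```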

@nvtx_range("ad_rescatter_input_ids")
def rescatter_input_ids(
@@ -778,31 +688,6 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
return list(torch.split(t_squeezed, self.seq_len))

def add_extra_arg(
self,
name: str,
none_input: torch.Tensor,
dynamic_shape_callback: Optional[DynamicShapeCallback] = None,
) -> None:
"""Add an extra argument to the sequence info object.

Args:
name: The name of the extra argument.
none_input: None input value of the extra argument.
dynamic_shape_callback: The callback to get the dynamic shape of the extra argument.

Note that the extra argument is expected to be a tensor.
"""
assert name not in self._named_args().keys(), f"Extra argument {name} already exists"

self._extra_args[name] = none_input.to(self.device)
self._extra_none_inputs[name] = self._extra_args[name]

if dynamic_shape_callback is None:
self._extra_dynamic_shapes_callbacks[name] = lambda: {}
else:
self._extra_dynamic_shapes_callbacks[name] = dynamic_shape_callback


class MHACallable(Protocol):
def __call__(
@@ -814,7 +699,6 @@ def __call__(
class PrepareMetadataCallable(Protocol):
def __call__(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -901,7 +785,6 @@ def get_prepare_metadata_op(cls) -> Tuple[PrepareMetadataCallable, int]:

```
def prepare_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -54,7 +54,6 @@ def _build_conv_state_from_sequence(input_bt_c: torch.Tensor, kernel_size: int)
# ---------------------------------------------------------------
@torch.library.custom_op("auto_deploy::cuda_causal_conv_prepare_metadata", mutates_args=())
def cuda_causal_conv_prepare_metadata(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
seq_len: torch.Tensor,
input_pos: torch.Tensor,
@@ -67,7 +66,7 @@ def cuda_causal_conv_prepare_metadata(

Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized).
"""
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)

seq_start = torch.zeros_like(seq_len_sanitized)
@@ -81,9 +80,9 @@

@cuda_causal_conv_prepare_metadata.register_fake
def cuda_causal_conv_prepare_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
seq_len_sanitized = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
num_seq = len(seq_len_sanitized)
return (
torch.empty_like(seq_len_sanitized),
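The causal-conv metadata op follows the same change: sequence lengths are now sanitized from `position_ids`, and the fake registration only has to reproduce output shapes and dtypes so the op can be traced and exported. Below is a minimal, self-contained sketch of that custom-op/`register_fake` pattern (illustrative op name and simplified logic; the real operator also truncates to the active sequences and returns slot indices):

```python
# Minimal sketch of the custom-op + register_fake pattern used above.
# The op name and logic here are illustrative, not the library's operator.
from typing import Tuple

import torch


@torch.library.custom_op("example::prepare_metadata_sketch", mutates_args=())
def prepare_metadata_sketch(
    position_ids: torch.Tensor, seq_len: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # exclusive prefix sum of the (possibly zero-padded) sequence lengths gives
    # the start offset of each sequence in the flattened token dimension
    seq_start = torch.zeros_like(seq_len)
    seq_start[1:] = torch.cumsum(seq_len[:-1], 0)
    return seq_len.clone(), seq_start


@prepare_metadata_sketch.register_fake
def prepare_metadata_sketch_fake(position_ids, seq_len):
    # the fake only needs to report shapes/dtypes consistent with the real op
    return torch.empty_like(seq_len), torch.empty_like(seq_len)
```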