6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -121,6 +121,12 @@ transforms:
insert_cached_mla_attention:
stage: cache_init
attn_backend: MultiHeadLatentAttention
insert_cached_ssm_attention:
stage: cache_init
attn_backend: torch_ssm
insert_cached_causal_conv:
stage: cache_init
attn_backend: torch_causal_conv
initialize_cache:
stage: cache_init
resize_kv_cache:
24 changes: 8 additions & 16 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
@@ -1,18 +1,10 @@
"""Custom ops and make sure they are all registered."""

from ._triton_attention_internal import *
from .dist import *
from .flashinfer_attention import *
from .flashinfer_rope import *
from .linear import *
from .mla import *
from .quant import *
from .rms_norm import *
from .torch_attention import *
from .torch_backend_attention import *
from .torch_moe import *
from .torch_quant import *
from .torch_rope import *
from .triton_attention import *
from .triton_rope import *
from .trtllm_moe import *
import importlib
import pkgutil

__all__ = []

# Auto-discover every custom-op module in this package and import it so that all ops
# register themselves, instead of maintaining an explicit import list.
for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
    __all__.append(module_name)
    importlib.import_module(f"{__name__}.{module_name}")
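This swaps the hand-maintained import list for dynamic discovery: any new custom-op module dropped into the package (for example, modules backing the `torch_ssm` and `torch_causal_conv` backends configured above) registers its ops without further edits to this file.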
22 changes: 20 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -71,6 +71,8 @@ class SequenceInfo:
- pages_per_seq: [ps_0, ps_1, ..., ps_{b-1}] where ps_i is the number of pages allocated for
sequence i. Note that, for example, cache_loc[p_0:p_1] will correspond to the pages associated
with sequence 1 in the batch.
- slot_idx: [s_0, s_1, ..., s_{b-1}]
What is a slot?

The sequence slot from the request object: a unique ID in the range [0, max_batch_size) assigned by the runtime.

Paged attention doesn't care about the sequence mapping; it only cares about which pages hold cache for a particular sequence.

For SSM, there's no notion of a page; you need the whole state.

Corresponds to the slot index of each sequence in the batch.

################################################################################################
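To make the distinction concrete, here is a minimal sketch (shapes and names are illustrative, not the actual TensorRT-LLM cache layout) of how a paged KV cache is addressed via `cache_loc`/`pages_per_seq`, while an SSM state cache is addressed directly by `slot_idx`:

```
import torch

max_batch_size, num_pages, page_size = 8, 64, 16
num_heads, head_dim, ssm_state_size = 4, 32, 16

# Paged KV cache: a sequence's cache lives in pages scattered across the pool.
kv_cache = torch.zeros(num_pages, page_size, num_heads, head_dim)
cache_loc = torch.tensor([5, 9, 2])     # pages backing the active sequences
pages_per_seq = torch.tensor([2, 1])    # seq 0 owns pages [5, 9], seq 1 owns page [2]

# SSM state cache: one whole state per sequence slot; no paging involved.
ssm_cache = torch.zeros(max_batch_size, num_heads, head_dim, ssm_state_size)
slot_idx = torch.tensor([3, 0])         # runtime-assigned slot of each active sequence
ssm_states = ssm_cache[slot_idx]        # gather the full state for every sequence
```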

@@ -157,6 +159,7 @@ def __init__(
"input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
"cache_loc": torch.empty(self.num_pages, dtype=torch.int),
"pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
"_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
}
@@ -165,7 +168,8 @@
}
# NOTE: order of keys is relevant here!
self._uncached_arg_names = ("input_ids", "position_ids")
self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq")
self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
self._cached_constants = ("page_size",)
############################################################################################

# EXTRA TENSOR FIELDS ######################################################################
@@ -289,7 +293,7 @@ def const_args_for_prepare_metadata(self) -> Tuple:
``insert_cached_attention`` to extract the constant arguments and add them to the
``prepare_metadata`` node/op.
"""
return (self.page_size,)
return tuple(getattr(self, k) for k in self._cached_constants)

@property
def named_dynamic_shapes(self) -> Dict[str, Dict[str, Dim]]:
@@ -516,11 +520,15 @@ def set_example_sequence(
cache_loc = list(range(sum(pages_per_seq)))
page_assignments = self._get_page_assignments(cache_loc, pages_per_seq)

# vanilla slot indices
slot_idx = list(range(len(input_ids)))

self.nest_sequences(
input_ids,
position_ids, # will be auto-inferred if None
input_pos=0, # no cache history
page_assignments=page_assignments, # vanilla page assignments
slot_idx=slot_idx, # vanilla slot indices
**extra_args,
)

@@ -601,6 +609,7 @@ def nest_sequences(
position_ids: Optional[Sequence[Sequence[int]]] = None,
input_pos: Optional[Union[Sequence[int], int]] = None,
page_assignments: Optional[Sequence[Sequence[int]]] = None,
slot_idx: Optional[Sequence[int]] = None,
**extra_args: Dict[str, Union[torch.Tensor, Sequence[torch.Tensor]]],
) -> None:
"""Create and store sequence information for the next forward pass.
@@ -610,6 +619,7 @@
position_ids: List of sequences of position_ids for each token.
input_pos: Absolute starting position in the cache for each sequence.
page_assignments: List of sequences of page assignments for each sequence.
slot_idx: List of slot indices for each sequence.
extra_args: Extra arguments to be stored in the interface.

This i/f will ensure that all sequence info args are updated accordingly.
@@ -636,6 +646,10 @@
self._store_arg("cache_loc", cache_loc)
self._store_arg("pages_per_seq", pages_per_seq)

# check for updated slot_idx
if slot_idx is not None:
self._store_arg("slot_idx", slot_idx)

### UPDATE MAIN INPUTS #####################################################################
# set new input_ids and make sure to flatten it
self._store_arg("input_ids", self._flatten(input_ids))
@@ -737,6 +751,7 @@ def __call__(
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
) -> List[torch.Tensor]: ...

@@ -822,6 +837,9 @@ def prepare_metadata(
seq_len: torch.Tensor,
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
) -> List[torch.Tensor]: ...
```
The metadata should contain all necessary global information required for the underlying
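For illustration, a minimal sketch of what a slot-based `prepare_metadata` might look like (the function name and return layout are hypothetical, not the op registered by this PR): a state-based backend mostly needs the per-sequence lengths, positions, and slot indices, and can ignore the paging arguments.

```
from typing import List

import torch


def example_ssm_prepare_metadata(
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    seq_len: torch.Tensor,
    input_pos: torch.Tensor,
    cache_loc: torch.Tensor,
    pages_per_seq: torch.Tensor,
    slot_idx: torch.Tensor,
    page_size: int,
) -> List[torch.Tensor]:
    # Only the first num_seq entries of the pre-allocated buffers are valid for this batch
    # (simplified here; the real code uses SequenceInfo helpers to sanitize the inputs).
    num_seq = (seq_len > 0).sum().item()
    return [seq_len[:num_seq], input_pos[:num_seq], slot_idx[:num_seq]]
```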
@@ -161,6 +161,7 @@ def prepare_flashinfer_metadata(
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for flashinfer attention.
@@ -213,7 +214,7 @@
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
@prepare_flashinfer_metadata.register_fake
def prepare_flashinfer_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
seq_len = SequenceInfo._get_sanitized_seq_len(input_ids, seq_len)
qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
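Note that the `register_fake` variants are called with the same arguments as the real op, which is why `slot_idx` is threaded through each fake implementation even when it does not affect the fake outputs.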
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py
@@ -181,6 +181,7 @@ def prepare_fused_mla_metadata(
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
) -> List[torch.Tensor]:
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
@@ -362,6 +362,7 @@ def torch_backend_prepare_metadata(
input_pos: torch.Tensor,
cache_loc: torch.Tensor,
pages_per_seq: torch.Tensor,
slot_idx: torch.Tensor,
page_size: int,
) -> List[torch.Tensor]:
"""Prepare metadata for torch backend attention (similar to triton backend)."""
@@ -378,7 +379,7 @@

@torch_backend_prepare_metadata.register_fake
def torch_backend_prepare_metadata_fake(
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, page_size
input_ids, position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
):
num_seq = SequenceInfo._get_sanitized_num_sequences(input_ids, seq_len)
return (