2 changes: 1 addition & 1 deletion examples/auto_deploy/build_and_run_ad.py
@@ -62,7 +62,7 @@ class PromptConfig(BaseModel):
"apply_chat_template.",
)
sp_kwargs: Dict[str, Any] = Field(
default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0},
default_factory=lambda: {"max_tokens": 100, "top_k": None, "temperature": 1.0},
description="Sampling parameter kwargs passed on the SamplingParams class. "
"Defaults are set to the values used in the original model.",
)
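For context on the sp_kwargs change above: these kwargs are forwarded to SamplingParams, and top_k=None leaves top-k filtering disabled instead of truncating to the 200 most likely tokens. A minimal sketch (the import path is assumed):

from tensorrt_llm import SamplingParams  # assumed import path

# Equivalent of the new default sp_kwargs; top_k=None disables top-k filtering
# so sampling follows the model's own defaults.
params = SamplingParams(max_tokens=100, top_k=None, temperature=1.0)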
68 changes: 54 additions & 14 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -11,12 +11,25 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union
from typing import (
Callable,
Dict,
List,
Literal,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Type,
Union,
)

import torch
from torch._ops import OpOverloadPacket
from torch.export import Dim
from torch.fx import Node
from torch.types import Number

from ...._utils import nvtx_range
from ..utils.logger import ad_logger
@@ -142,10 +155,19 @@ def __init__(
# sanity check
assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"

# cache_loc requires some special treatment due to block reuse. Note that the constraints
# for cache_loc with block reuse are as follows:
# 0 <= cache_loc < num_pages
# len(cache_loc) <= max_num_cache_loc_assignments
max_num_cache_loc_assignments = (
max_seq_len_adjusted // self.page_size + 1
) * self.max_batch_size

# log parameters
ad_logger.info(
f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.page_size=}, "
f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}"
f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}, "
f"{max_num_cache_loc_assignments=}"
)

# indicator if extra args are activated that are needed for cached attention backends
@@ -165,7 +187,7 @@ def __init__(
# TENSOR FIELDS FOR CACHED ATTENTION
"seq_len": torch.empty(self.max_batch_size, dtype=torch.int),
"input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
"cache_loc": torch.empty(self.num_pages, dtype=torch.int),
"cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int),
"pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
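To make the cache_loc sizing above concrete, here is an illustrative calculation with assumed values (not defaults taken from the diff):

# Illustrative values only.
page_size = 64
max_seq_len_adjusted = 1024  # max_seq_len after any adjustment for extra tokens
max_batch_size = 8

# Each sequence needs at most max_seq_len_adjusted // page_size + 1 pages, so with
# block reuse the flattened cache_loc buffer must hold up to this many assignments:
max_num_cache_loc_assignments = (max_seq_len_adjusted // page_size + 1) * max_batch_size
assert max_num_cache_loc_assignments == (16 + 1) * 8 == 136

# Per the constraints noted above, every entry still satisfies
#   0 <= cache_loc < num_pages, and len(cache_loc) <= max_num_cache_loc_assignments.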
@@ -369,7 +391,8 @@ def num_pages(self) -> int:
def num_pages(self, value):
self._num_pages = value
# update the cache_loc tensor
self._args_device["cache_loc"].resize_(value)
if self._args_device["cache_loc"].numel() < value:
self._args_device["cache_loc"].resize_(value)

@property
def is_paged(self) -> bool:
@@ -570,15 +593,15 @@ def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]:
def _store_arg(
self,
name: str,
tnsr_like: List[int | float],
reset: bool = False,
tnsr_like: List[Number],
reset_val: Optional[Number] = None,
) -> None:
"""Store the argument on the host and copy to the device in a non-blocking fashion.

Args:
name: Name of the argument to store.
tnsr_like: List of values to store.
reset: Whether to reset the full tensor on the device to 0 before writing to it.
reset_val: Value to reset/fill the full tensor on the device to before writing to it.
"""
with nvtx_range(f"ad_store_seq_info_arg_{name}"):
tnsr_device = self._args_device[name]
@@ -596,8 +619,8 @@ def _store_arg(
)

# reset/copy to the device in a non-blocking fashion
if reset:
tnsr_device.zero_()
if reset_val is not None:
tnsr_device.fill_(reset_val)
tnsr_device[: len(tnsr_like)].copy_(tnsr_host, non_blocking=True)

def _store_extra_arg(
@@ -616,6 +639,18 @@ def _store_extra_arg(
else:
self._extra_args[name] = None

def _get_unique_value(self, occupied: Set[int], max_val: int) -> int:
"""Get un unoccupied value from the range indicated by max_val.

In addition, this function performs a sanity check to ensure that no value in the occupied
set is out of bounds.
"""
full_range = set(range(max_val))
free_values = full_range - occupied
out_of_range = occupied - full_range
assert not out_of_range, f"Out of range values: {out_of_range}"
return free_values.pop() if free_values else 0

@nvtx_range("ad_nest_sequences")
def nest_sequences(
self,
@@ -636,33 +671,38 @@ def nest_sequences(
slot_idx: List of slot indices for each sequence.
extra_args: Extra arguments to be stored in the interface.

This i/f will ensure that all sequence info args are updated accordingly.
This i/f will ensure that all sequence info args are updated accordingly. Reset values are
chosen as "neutral" values so that, in cases like rounding up the batch size for cudagraph,
we only write to unused buffers/caches.
"""
### UPDATE METADATA ########################################################################
# update metadata first since it's useful for other updates to have up-to-date information

# set new sequence lengths --> resetting the remaining entries to zero is important to help
# us discern the actual number of sequences in the batch.
self._store_arg("seq_len", [len(ids) for ids in input_ids], reset=True)
self._store_arg("seq_len", [len(ids) for ids in input_ids], reset_val=0)

# check for updated input_pos (i.e. cache start position)
if input_pos is not None:
self._store_arg(
"input_pos",
[input_pos] * self.num_sequences if isinstance(input_pos, int) else input_pos,
reset_val=0,
)

# check for updated page_assignments
if page_assignments is not None:
cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence(
page_assignments
)
self._store_arg("cache_loc", cache_loc, reset=True)
self._store_arg("pages_per_seq", pages_per_seq, reset=True)
free_cache_loc = self._get_unique_value(set(cache_loc), self.num_pages)
self._store_arg("cache_loc", cache_loc, reset_val=free_cache_loc)
self._store_arg("pages_per_seq", pages_per_seq, reset_val=1)

# check for updated slot_idx
if slot_idx is not None:
self._store_arg("slot_idx", slot_idx)
free_slot_idx = self._get_unique_value(set(slot_idx), self.max_batch_size)
self._store_arg("slot_idx", slot_idx, reset_val=free_slot_idx)

### UPDATE MAIN INPUTS #####################################################################
# set new input_ids and make sure to flatten it
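As an illustration of the "neutral" reset values used in nest_sequences above, here is a self-contained sketch with toy numbers (the helper mirrors _get_unique_value but is standalone; all specific values are assumptions):

from typing import List, Set

def get_unique_value(occupied: Set[int], max_val: int) -> int:
    """Return a value in [0, max_val) not present in `occupied` (0 as fallback)."""
    free_values = set(range(max_val)) - occupied
    return free_values.pop() if free_values else 0

# Toy batch: 2 real sequences, padded up to max_batch_size=4 (e.g. for cudagraph).
num_pages, max_batch_size = 8, 4
cache_loc: List[int] = [0, 1, 5]  # pages assigned to the real sequences
slot_idx: List[int] = [0, 2]      # slots used by the real sequences

# Unused device-buffer entries are filled with a page/slot that no real sequence
# uses, so padded graph replays only touch unused caches; seq_len is reset to 0
# and pages_per_seq to 1 for the padded entries.
pad_page = get_unique_value(set(cache_loc), num_pages)      # some free page, e.g. 2
pad_slot = get_unique_value(set(slot_idx), max_batch_size)  # some free slot, e.g. 1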
14 changes: 13 additions & 1 deletion tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -8,7 +8,7 @@

from tensorrt_llm.models.modeling_utils import QuantConfig

from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, _ParallelConfig
from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
from ...llmapi.utils import get_type_repr
from .models import ModelFactory, ModelFactoryRegistry
from .utils._config import DynamicYamlMixInForSettings
@@ -116,18 +116,30 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):

device: str = Field(default="cuda", description="The device to use for the model.", frozen=True)

# TODO: see if we can just remove this field and use kv_cache_config.dtype instead?
kv_cache_dtype: str = Field(
default="auto",
description="Data type for KV cache. This is a temporary field until kv_cache_dtype is "
"supported in AutoDeploy.",
)

# NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet
# see https://github.com/NVIDIA/TensorRT-LLM/issues/7142
kv_cache_config: KvCacheConfig = Field(
default_factory=lambda **kwargs: KvCacheConfig(copy_on_partial_reuse=False, **kwargs),
description="KV cache config.",
)

max_beam_width: int = Field(
default=1,
description="The maximum beam width. >1 is not supported by AutoDeploy.",
frozen=True,
)

enable_chunked_prefill: bool = Field(
default=False, description="Enable chunked prefill.", frozen=True
)

### INFERENCE OPTIMIZER CONFIG #################################################################
attn_backend: Literal["flashinfer", "triton", "torch"] = Field(
default="flashinfer", description="Attention backend to use."
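A brief sketch of how the kv_cache_config default above behaves (import path as used in this file; field values are illustrative):

from tensorrt_llm.llmapi.llm_args import KvCacheConfig

# Default produced by the factory above: block reuse remains available, but
# copy_on_partial_reuse is forced off since AutoDeploy does not support it yet.
default_cfg = KvCacheConfig(copy_on_partial_reuse=False)

# A user-supplied config works as long as copy_on_partial_reuse stays disabled.
user_cfg = KvCacheConfig(
    enable_block_reuse=True,
    enable_partial_reuse=False,
    copy_on_partial_reuse=False,
)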
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -235,6 +235,7 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
if not self.skip_loading_weights:
self.prefetch_checkpoint(force=True)
self._load_checkpoint(model, device)
ad_logger.info("Loading and initializing weights. Done.")

@staticmethod
def _to_maybe_random(model: nn.Module, device: DeviceLikeType):
33 changes: 31 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -191,8 +191,15 @@ def _prepare_inputs(
# look at context requests first
for request in context_requests:
# store input ids and pos of first token in sequence
input_ids.append(request.get_tokens(0))
input_pos.append(request.context_current_position)
# NOTE: begin_compute > 0 indicates block reuse
# NOTE: end_compute will be used in the future for chunked prefill
all_prompt_tokens = request.get_tokens(0)
begin_compute = request.context_current_position
end_compute = begin_compute + request.context_chunk_size
prompt_tokens = all_prompt_tokens[begin_compute:end_compute]

input_ids.append(prompt_tokens)
input_pos.append(begin_compute)

request.py_batch_idx = request.seq_slot
last_logit_only.append(True)
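A worked example of the slicing above, with assumed numbers (not taken from the diff):

# A 10-token prompt where the first 4 tokens were recovered via block reuse
# (begin_compute > 0) and the remaining 6 are computed in this step.
all_prompt_tokens = [11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
begin_compute = 4            # request.context_current_position
context_chunk_size = 6       # request.context_chunk_size
end_compute = begin_compute + context_chunk_size

prompt_tokens = all_prompt_tokens[begin_compute:end_compute]  # [15, ..., 20]
input_pos = begin_compute    # the KV cache already holds positions 0..3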
@@ -326,6 +333,28 @@ def create_autodeploy_executor(ad_config: LlmArgs):
# initialize model engine
engine = ADEngine.build_from_config(ad_config=ad_config)

# check kvcache config for partial block reuse
# TODO: copy_on_partial_reuse is not supported yet, see
# https://github.com/NVIDIA/TensorRT-LLM/issues/7142 for more details.
enable_block_reuse = ad_config.kv_cache_config.enable_block_reuse
enable_partial_reuse = ad_config.kv_cache_config.enable_partial_reuse
copy_on_partial_reuse = ad_config.kv_cache_config.copy_on_partial_reuse
if enable_block_reuse and enable_partial_reuse and copy_on_partial_reuse:
raise RuntimeError(
f"partial block reuse with {copy_on_partial_reuse=} set to True is NOT supported"
" in AutoDeploy. Please set it to False via the kv_cache_config.copy_on_partial_reuse "
"field in tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs."
)

# TODO: detect whether SSM layer is present in the model and raise an error or disable block
# reuse with a warning --> see https://github.com/NVIDIA/TensorRT-LLM/issues/7142. For now, we
# just emit a general warning.
if enable_block_reuse:
ad_logger.warning(
f"{enable_block_reuse=} is enabled. Note that this is not supported for SSM layers and"
" may lead to incorrect results if the model contains SSM layers."
)

# resource managers
kv_cache_manager = _CacheManagerWithFakePool(
ad_config.kv_cache_config,
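For reference, a small sketch of which KV-cache reuse configurations pass the check above (the helper is a standalone mirror of the condition, not an API from the codebase):

from tensorrt_llm.llmapi.llm_args import KvCacheConfig

def reuse_config_supported(cfg: KvCacheConfig) -> bool:
    """Full block reuse is fine; partial reuse only without copy_on_partial_reuse."""
    return not (cfg.enable_block_reuse and cfg.enable_partial_reuse
                and cfg.copy_on_partial_reuse)

ok = KvCacheConfig(enable_block_reuse=True, copy_on_partial_reuse=False)
bad = KvCacheConfig(enable_block_reuse=True, enable_partial_reuse=True,
                    copy_on_partial_reuse=True)
assert reuse_config_supported(ok)
assert not reuse_config_supported(bad)  # create_autodeploy_executor raises RuntimeError here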
9 changes: 0 additions & 9 deletions tensorrt_llm/commands/serve.py
@@ -165,15 +165,6 @@ def launch_server(host: str,
elif backend == '_autodeploy':
# AutoDeploy does not support build_config
llm_args.pop("build_config", None)
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142):
# AutoDeploy does not support cache reuse yet.
kv_cache_config = llm_args["kv_cache_config"]
# If the LLM API options YAML contained a portion for `kv_cache_config`, then this will be
# a dict. Otherwise, it will be an instance of the `KvCacheConfig` class, hence the below.
if isinstance(kv_cache_config, dict):
llm_args["kv_cache_config"]["enable_block_reuse"] = False
else:
llm_args["kv_cache_config"].enable_block_reuse = False
llm = AutoDeployLLM(**llm_args)
elif backend == 'tensorrt' or backend == 'trt':
llm_args.pop("backend")
22 changes: 16 additions & 6 deletions tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest

from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
@@ -22,19 +24,27 @@
from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness


def _hf_model_dir_or_hub_id(
hf_model_subdir: str,
hf_hub_id: str,
) -> str:
llm_models_path = llm_models_root()
if llm_models_path and os.path.isdir(
(model_fullpath := os.path.join(llm_models_path, hf_model_subdir))):
return str(model_fullpath)
else:
return hf_hub_id


class TestLlama3_1_8B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B"
MODEL_PATH = _hf_model_dir_or_hub_id("llama-3.1-model/Meta-Llama-3.1-8B",
MODEL_NAME)

def get_default_kwargs(self):
return {
'skip_tokenizer_init': False,
'trust_remote_code': True,
# TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142):
# AutoDeploy does not support cache reuse yet.
'kv_cache_config': {
'enable_block_reuse': False,
},
'max_batch_size': 512,
# 131072 is the max seq len for the model
'max_seq_len': 8192,