From b1ada48446ac885f22f38eb50a52342beff7d084 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:17:14 -0700 Subject: [PATCH 1/5] fill seq info data with valid dummy data Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/attention_interface.py | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index fc666fd40a1..27d5f588787 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -11,12 +11,25 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Callable, Dict, List, Literal, Optional, Protocol, Sequence, Tuple, Type, Union +from typing import ( + Callable, + Dict, + List, + Literal, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Type, + Union, +) import torch from torch._ops import OpOverloadPacket from torch.export import Dim from torch.fx import Node +from torch.types import Number from ...._utils import nvtx_range from ..utils.logger import ad_logger @@ -570,15 +583,15 @@ def _flatten(nested_seqs: Sequence[Sequence[int]]) -> List[int]: def _store_arg( self, name: str, - tnsr_like: List[int | float], - reset: bool = False, + tnsr_like: List[Number], + reset_val: Optional[Number] = None, ) -> None: """Store the argument on the host and copy to the device in a non-blocking fashion. Args: name: Name of the argument to store. tnsr_like: List of values to store. - reset: Whether to reset the full tensor on the device to 0 before writing to it. + reset_val: Value to reset/fill the full tensor on the device to before writing to it. 
""" with nvtx_range(f"ad_store_seq_info_arg_{name}"): tnsr_device = self._args_device[name] @@ -596,8 +609,8 @@ def _store_arg( ) # reset/copy to the device in a non-blocking fashion - if reset: - tnsr_device.zero_() + if reset_val is not None: + tnsr_device.fill_(reset_val) tnsr_device[: len(tnsr_like)].copy_(tnsr_host, non_blocking=True) def _store_extra_arg( @@ -616,6 +629,11 @@ def _store_extra_arg( else: self._extra_args[name] = None + def _get_unique_value(self, occupied: Set[int], max_val: int) -> int: + """Get un unoccupied value from the range indicated by max_val.""" + free_values = set(range(max_val)) - occupied + return free_values.pop() if free_values else 0 + @nvtx_range("ad_nest_sequences") def nest_sequences( self, @@ -636,20 +654,23 @@ def nest_sequences( slot_idx: List of slot indices for each sequence. extra_args: Extra arguments to be stored in the interface. - This i/f will ensure that all sequence info args are updated accordingly. + This i/f will ensure that all sequence info args are updated accordingly. Reset values are + chosen as "neutral" values so that for cases like rounding up batch sizes for cudagraph we + only write to unused buffers/caches. """ ### UPDATE METADATA ######################################################################## # update metadata first since it's useful for other updates to have up-to-date information # set new sequence lengths --> resetting the remaining entries to zero is important to help # us discern the actual number of sequences in the batch. - self._store_arg("seq_len", [len(ids) for ids in input_ids], reset=True) + self._store_arg("seq_len", [len(ids) for ids in input_ids], reset_val=0) # check for updated input_pos (i.e. 
cache start position) if input_pos is not None: self._store_arg( "input_pos", [input_pos] * self.num_sequences if isinstance(input_pos, int) else input_pos, + reset_val=0, ) # check for updated page_assignments @@ -657,12 +678,14 @@ def nest_sequences( cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence( page_assignments ) - self._store_arg("cache_loc", cache_loc, reset=True) - self._store_arg("pages_per_seq", pages_per_seq, reset=True) + free_cache_loc = self._get_unique_value(set(cache_loc), self.num_pages) + self._store_arg("cache_loc", cache_loc, reset_val=free_cache_loc) + self._store_arg("pages_per_seq", pages_per_seq, reset_val=1) # check for updated slot_idx if slot_idx is not None: - self._store_arg("slot_idx", slot_idx) + free_slot_idx = self._get_unique_value(set(slot_idx), self.max_batch_size) + self._store_arg("slot_idx", slot_idx, reset_val=free_slot_idx) ### UPDATE MAIN INPUTS ##################################################################### # set new input_ids and make sure to flatten it From ec49c9df63630177115a78fa3d8d8f77d6cd9222 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:35:58 -0700 Subject: [PATCH 2/5] disable kv block reuse Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/llm_args.py | 10 +++++++++- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py | 9 +++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index e4ac0db0752..ee9861c6746 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -8,7 +8,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig -from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, _ParallelConfig +from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, 
_ParallelConfig from ...llmapi.utils import get_type_repr from .models import ModelFactory, ModelFactoryRegistry from .utils._config import DynamicYamlMixInForSettings @@ -304,6 +304,14 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings): _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) + # NOTE: we do not support enable_block_reuse and enable_partial_reuse in AutoDeploy yet! + kv_cache_config: KvCacheConfig = Field( + default_factory=lambda **kwargs: KvCacheConfig( + enable_block_reuse=False, enable_partial_reuse=False, **kwargs + ), + description="KV cache config.", + ) + @property def quant_config(self) -> QuantConfig: if self._quant_config is None: diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 7b7b52e584c..815348d24f2 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -326,6 +326,15 @@ def create_autodeploy_executor(ad_config: LlmArgs): # initialize model engine engine = ADEngine.build_from_config(ad_config=ad_config) + # check kvcache config + enable_block_reuse = ad_config.kv_cache_config.enable_block_reuse + enable_partial_reuse = ad_config.kv_cache_config.enable_partial_reuse + if enable_block_reuse or enable_partial_reuse: + raise RuntimeError( + f"Setting {enable_block_reuse=} and/or {enable_partial_reuse=} to True is NOT supported" + " in AutoDeploy. Please set them to False." 
+ ) + # resource managers kv_cache_manager = _CacheManagerWithFakePool( ad_config.kv_cache_config, From 0d2704ca062aca4e76089aea8a79e49ca8acd711 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:36:23 -0700 Subject: [PATCH 3/5] fix sampling params for build_and_run_ad.py Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- examples/auto_deploy/build_and_run_ad.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/auto_deploy/build_and_run_ad.py b/examples/auto_deploy/build_and_run_ad.py index 0768b635887..d1faf8fdc6d 100644 --- a/examples/auto_deploy/build_and_run_ad.py +++ b/examples/auto_deploy/build_and_run_ad.py @@ -62,7 +62,7 @@ class PromptConfig(BaseModel): "apply_chat_template.", ) sp_kwargs: Dict[str, Any] = Field( - default_factory=lambda: {"max_tokens": 100, "top_k": 200, "temperature": 1.0}, + default_factory=lambda: {"max_tokens": 100, "top_k": None, "temperature": 1.0}, description="Sampling parameter kwargs passed on the SamplingParams class. 
" "Defaults are set to the values used in the original model.", ) From 0fa77dd46319900e2aa54ace23717971f50b5b91 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:56:45 -0700 Subject: [PATCH 4/5] support enable_block_reuse Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/attention_interface.py | 27 +++++++++++++++---- tensorrt_llm/_torch/auto_deploy/llm_args.py | 6 ++--- .../_torch/auto_deploy/models/factory.py | 1 + .../_torch/auto_deploy/shim/ad_executor.py | 26 ++++++++++++++---- tensorrt_llm/commands/serve.py | 9 ------- .../defs/accuracy/test_llm_api_autodeploy.py | 22 ++++++++++----- 6 files changed, 62 insertions(+), 29 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 27d5f588787..38f7c5f0815 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -155,10 +155,19 @@ def __init__( # sanity check assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size" + # cache_loc requires some special treatment due to block reuse. 
Note that the constraint for + # cache_loc with block_reuse is as follows: + # 0<= max(cache_loc) < num_pages + # len(cache_loc) <= max_num_cache_loc_assignments + max_num_cache_loc_assignments = ( + max_seq_len_adjusted // self.page_size + 1 + ) * self.max_batch_size + # log parameters ad_logger.info( f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.page_size=}, " - f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}" + f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}, " + f"{max_num_cache_loc_assignments=}" ) # indicator if extra args are activated that are needed for cached attention backends @@ -178,7 +187,7 @@ def __init__( # TENSOR FIELDS FOR CACHED ATTENTION "seq_len": torch.empty(self.max_batch_size, dtype=torch.int), "input_pos": torch.empty(self.max_batch_size, dtype=torch.int), - "cache_loc": torch.empty(self.num_pages, dtype=torch.int), + "cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int), "pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int), "slot_idx": torch.empty(self.max_batch_size, dtype=torch.int), # OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER @@ -382,7 +391,8 @@ def num_pages(self) -> int: def num_pages(self, value): self._num_pages = value # update the cache_loc tensor - self._args_device["cache_loc"].resize_(value) + if self._args_device["cache_loc"].numel() < value: + self._args_device["cache_loc"].resize_(value) @property def is_paged(self) -> bool: @@ -630,8 +640,15 @@ def _store_extra_arg( self._extra_args[name] = None def _get_unique_value(self, occupied: Set[int], max_val: int) -> int: - """Get un unoccupied value from the range indicated by max_val.""" - free_values = set(range(max_val)) - occupied + """Get an unoccupied value from the range indicated by max_val. + + In addition, this function performs a sanity check to ensure that no value in the occupied + set is out of bounds. 
+ """ + full_range = set(range(max_val)) + free_values = full_range - occupied + out_of_range = occupied - full_range + assert not out_of_range, f"Out of range values: {out_of_range}" return free_values.pop() if free_values else 0 @nvtx_range("ad_nest_sequences") diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index ee9861c6746..e439435179b 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -304,11 +304,9 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings): _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) - # NOTE: we do not support enable_block_reuse and enable_partial_reuse in AutoDeploy yet! + # NOTE: we do not support enable_partial_reuse in AutoDeploy yet! kv_cache_config: KvCacheConfig = Field( - default_factory=lambda **kwargs: KvCacheConfig( - enable_block_reuse=False, enable_partial_reuse=False, **kwargs - ), + default_factory=lambda **kwargs: KvCacheConfig(enable_partial_reuse=False, **kwargs), description="KV cache config.", ) diff --git a/tensorrt_llm/_torch/auto_deploy/models/factory.py b/tensorrt_llm/_torch/auto_deploy/models/factory.py index 8e0ed29cebb..ad1d119842f 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/factory.py +++ b/tensorrt_llm/_torch/auto_deploy/models/factory.py @@ -235,6 +235,7 @@ def load_or_random_init(self, model: nn.Module, device: DeviceLikeType): if not self.skip_loading_weights: self.prefetch_checkpoint(force=True) self._load_checkpoint(model, device) + ad_logger.info("Loading and initializing weights. 
Done.") @staticmethod def _to_maybe_random(model: nn.Module, device: DeviceLikeType): diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 815348d24f2..5b8022a2c1c 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -191,7 +191,8 @@ def _prepare_inputs( # look at context requests first for request in context_requests: # store input ids and pos of first token in sequence - input_ids.append(request.get_tokens(0)) + # NOTE: input_pos > 0 may indicate block reuse --> hence we need to slice the input_ids + input_ids.append(request.get_tokens(0)[request.context_current_position :]) input_pos.append(request.context_current_position) request.py_batch_idx = request.seq_slot @@ -326,16 +327,31 @@ def create_autodeploy_executor(ad_config: LlmArgs): # initialize model engine engine = ADEngine.build_from_config(ad_config=ad_config) - # check kvcache config + # check kvcache config for partial block reuse + # NOTE: partial reuse is not supported since the KVCacheManager will copy the partial block to a + # new block internally. Since AD only uses the page assignment functionality of the + # KVCacheManager with an internally-managed cache pool, this copy to a new block has no effect + # in AD. Full block reuse on the other hand is supported since the KVCacheManager will just + # point to the existing block. enable_block_reuse = ad_config.kv_cache_config.enable_block_reuse enable_partial_reuse = ad_config.kv_cache_config.enable_partial_reuse - if enable_block_reuse or enable_partial_reuse: + if enable_block_reuse and enable_partial_reuse: raise RuntimeError( - f"Setting {enable_block_reuse=} and/or {enable_partial_reuse=} to True is NOT supported" - " in AutoDeploy. Please set them to False." + f"enable_block_reuse with {enable_partial_reuse=} set to True is NOT supported" + " in AutoDeploy. Please set it to False." 
+ ) + + # TODO: detect whether SSM layer is present in the model and raise an error or disable block + # reuse with a warning --> see https://github.com/NVIDIA/TensorRT-LLM/issues/7142. For now, we + # just emit a general warning. + if enable_block_reuse: + ad_logger.warning( + f"{enable_block_reuse=} is enabled. Note that this is not supported for SSM layers and" + " may lead to incorrect results if the model contains SSM layers." ) # resource managers + ad_logger.info(f"{ad_config.kv_cache_config=}") kv_cache_manager = _CacheManagerWithFakePool( ad_config.kv_cache_config, num_blocks=engine.cache_seq_interface.info.num_pages, diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 77da0eac5b6..7da0930264b 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -165,15 +165,6 @@ def launch_server(host: str, elif backend == '_autodeploy': # AutoDeploy does not support build_config llm_args.pop("build_config", None) - # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142): - # AutoDeploy does not support cache reuse yet. - kv_cache_config = llm_args["kv_cache_config"] - # If the LLM API options YAML contained a portion for `kv_cache_config`, then this will be - # a dict. Otherwise, it will be an instance of the `KvCacheConfig` class, hence the below. 
- if isinstance(kv_cache_config, dict): - llm_args["kv_cache_config"]["enable_block_reuse"] = False - else: - llm_args["kv_cache_config"].enable_block_reuse = False llm = AutoDeployLLM(**llm_args) elif backend == 'tensorrt' or backend == 'trt': llm_args.pop("backend") diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 82a9f54a118..58d5ee067d4 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import pytest from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM @@ -22,19 +24,27 @@ from .accuracy_core import MMLU, CnnDailymail, LlmapiAccuracyTestHarness +def _hf_model_dir_or_hub_id( + hf_model_subdir: str, + hf_hub_id: str, +) -> str: + llm_models_path = llm_models_root() + if llm_models_path and os.path.isdir( + (model_fullpath := os.path.join(llm_models_path, hf_model_subdir))): + return str(model_fullpath) + else: + return hf_hub_id + + class TestLlama3_1_8B(LlmapiAccuracyTestHarness): MODEL_NAME = "meta-llama/Llama-3.1-8B" - MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Meta-Llama-3.1-8B" + MODEL_PATH = _hf_model_dir_or_hub_id("llama-3.1-model/Meta-Llama-3.1-8B", + MODEL_NAME) def get_default_kwargs(self): return { 'skip_tokenizer_init': False, 'trust_remote_code': True, - # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142): - # AutoDeploy does not support cache reuse yet. 
- 'kv_cache_config': { - 'enable_block_reuse': False, - }, 'max_batch_size': 512, # 131072 is the max seq len for the model 'max_seq_len': 8192, From fc23f260089a39cf195f35871dd7e4ef68110b4e Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Thu, 2 Oct 2025 08:43:18 -0700 Subject: [PATCH 5/5] addressing reviewer feedback Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/attention_interface.py | 2 +- tensorrt_llm/_torch/auto_deploy/llm_args.py | 18 ++++++++---- .../_torch/auto_deploy/shim/ad_executor.py | 28 +++++++++++-------- 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 38f7c5f0815..2f12112710e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -157,7 +157,7 @@ def __init__( # cache_loc requires some special treatment due to block reuse. Note that the constraint for # cache_loc with block_reuse is as follows: - # 0<= max(cache_loc) < num_pages + # 0 <= cache_loc < num_pages # len(cache_loc) <= max_num_cache_loc_assignments max_num_cache_loc_assignments = ( max_seq_len_adjusted // self.page_size + 1 diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index e439435179b..7cd235cefd4 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -116,18 +116,30 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): device: str = Field(default="cuda", description="The device to use for the model.", frozen=True) + # TODO: see if we can just remove this field and use kv_cache_config.dtype instead? kv_cache_dtype: str = Field( default="auto", description="Data type for KV cache. 
This is a temporary field until kv_cache_dtype is " "supported in AutoDeploy.", ) + # NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet + # see https://github.com/NVIDIA/TensorRT-LLM/issues/7142 + kv_cache_config: KvCacheConfig = Field( + default_factory=lambda **kwargs: KvCacheConfig(copy_on_partial_reuse=False, **kwargs), + description="KV cache config.", + ) + max_beam_width: int = Field( default=1, description="The maximum beam width. >1 is not supported by AutoDeploy.", frozen=True, ) + enable_chunked_prefill: bool = Field( + default=False, description="Enable chunked prefill.", frozen=True + ) + ### INFERENCE OPTIMIZER CONFIG ################################################################# attn_backend: Literal["flashinfer", "triton", "torch"] = Field( default="flashinfer", description="Attention backend to use." @@ -304,12 +316,6 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings): _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) - # NOTE: we do not support enable_partial_reuse in AutoDeploy yet! 
- kv_cache_config: KvCacheConfig = Field( - default_factory=lambda **kwargs: KvCacheConfig(enable_partial_reuse=False, **kwargs), - description="KV cache config.", - ) - @property def quant_config(self) -> QuantConfig: if self._quant_config is None: diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 5b8022a2c1c..e0dd2a4857e 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -191,9 +191,15 @@ def _prepare_inputs( # look at context requests first for request in context_requests: # store input ids and pos of first token in sequence - # NOTE: input_pos > 0 may indicate block reuse --> hence we need to slice the input_ids - input_ids.append(request.get_tokens(0)[request.context_current_position :]) - input_pos.append(request.context_current_position) + # NOTE: begin_compute > 0 indicates block reuse + # NOTE: end_compute will be used in the future for chunked prefill + all_prompt_tokens = request.get_tokens(0) + begin_compute = request.context_current_position + end_compute = begin_compute + request.context_chunk_size + prompt_tokens = all_prompt_tokens[begin_compute:end_compute] + + input_ids.append(prompt_tokens) + input_pos.append(begin_compute) request.py_batch_idx = request.seq_slot last_logit_only.append(True) @@ -328,17 +334,16 @@ def create_autodeploy_executor(ad_config: LlmArgs): engine = ADEngine.build_from_config(ad_config=ad_config) # check kvcache config for partial block reuse - # NOTE: partial reuse is not supported since the KVCacheManager will copy the partial block to a - # new block internally. Since AD only uses the page assignment functionality of the - # KVCacheManager with an internally-managed cache pool, this copy to a new block has no effect - # in AD. Full block reuse on the other hand is supported since the KVCacheManager will just - # point to the existing block. 
+ # TODO: copy_on_partial_reuse is not supported yet, see + # https://github.com/NVIDIA/TensorRT-LLM/issues/7142 for more details. enable_block_reuse = ad_config.kv_cache_config.enable_block_reuse enable_partial_reuse = ad_config.kv_cache_config.enable_partial_reuse - if enable_block_reuse and enable_partial_reuse: + copy_on_partial_reuse = ad_config.kv_cache_config.copy_on_partial_reuse + if enable_block_reuse and enable_partial_reuse and copy_on_partial_reuse: raise RuntimeError( - f"enable_block_reuse with {enable_partial_reuse=} set to True is NOT supported" - " in AutoDeploy. Please set it to False." + f"partial block reuse with {copy_on_partial_reuse=} set to True is NOT supported" + " in AutoDeploy. Please set it to False via the kv_cache_config.copy_on_partial_reuse " + "field in tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs." ) # TODO: detect whether SSM layer is present in the model and raise an error or disable block @@ -351,7 +356,6 @@ def create_autodeploy_executor(ad_config: LlmArgs): ) # resource managers - ad_logger.info(f"{ad_config.kv_cache_config=}") kv_cache_manager = _CacheManagerWithFakePool( ad_config.kv_cache_config, num_blocks=engine.cache_seq_interface.info.num_pages,