2 changes: 1 addition & 1 deletion examples/llm-api/quickstart_advanced.py
@@ -149,7 +149,6 @@ def setup_llm(args, **kwargs):
kv_cache_config = KvCacheConfig(
enable_block_reuse=not args.disable_kv_cache_reuse,
free_gpu_memory_fraction=args.kv_cache_fraction,
dtype=args.kv_cache_dtype,
)

spec_decode_algo = args.spec_decode_algo.upper(
@@ -195,6 +194,7 @@ def setup_llm(args, **kwargs):
model=args.model_dir,
backend='pytorch',
disable_overlap_scheduler=args.disable_overlap_scheduler,
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_config=kv_cache_config,
attn_backend=args.attention_backend,
cuda_graph_config=cuda_graph_config,
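Taken together, the two hunks above move the dtype selection out of the KvCacheConfig constructor and into a top-level argument of the PyTorch-backend LLM. A minimal sketch of the resulting call, assuming the post-PR argument surface; the model path and the fp8 value are only placeholders, not taken from the quickstart:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# KvCacheConfig no longer carries the dtype after this change.
kv_cache_config = KvCacheConfig(
    enable_block_reuse=True,
    free_gpu_memory_fraction=0.9,
)

# The dtype is now a top-level argument of the PyTorch-backend LLM.
llm = LLM(
    model="/path/to/model",           # hypothetical model directory
    backend="pytorch",
    kv_cache_dtype="fp8",             # e.g. "auto" or "fp8"
    kv_cache_config=kv_cache_config,
)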
12 changes: 4 additions & 8 deletions tensorrt_llm/bench/benchmark/utils/general.py
@@ -88,14 +88,12 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
enable_chunked_prefill = params.get("enable_chunked_prefill", False)

kv_cache_dtype = "auto"
kv_cache_config = {}
if extra_llm_api_options:
with open(extra_llm_api_options, 'r') as f:
llm_args_dict = yaml.safe_load(f)
kv_cache_config = llm_args_dict.get("kv_cache_config", {
"dtype": "auto",
})
kv_cache_dtype = kv_cache_config.get("dtype", "auto")

if "kv_cache_dtype" in llm_args_dict:
kv_cache_dtype = llm_args_dict["kv_cache_dtype"]

enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
enable_chunked_prefill)
@@ -160,11 +158,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
"max_batch_size": max_batch_size
}

kv_cache_config["dtype"] = kv_cache_dtype

pyt_options = {
"cuda_graph_config": cuda_graph_config,
"kv_cache_config": kv_cache_config,
"kv_cache_dtype": kv_cache_dtype,
}

backend = params.get("backend", "pytorch")
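For trtllm-bench users, the practical effect is that the dtype key in the extra LLM API options file moves from the nested kv_cache_config block to the top level. An illustrative sketch of the new lookup with assumed file contents (the fp8 value is only an example):

import yaml

# Old-style options file: dtype nested under kv_cache_config.
old_yaml = "kv_cache_config:\n  dtype: fp8\n"
# New-style options file: dtype as a top-level key, matching the updated parser above.
new_yaml = "kv_cache_dtype: fp8\n"

llm_args_dict = yaml.safe_load(new_yaml)
kv_cache_dtype = "auto"
if "kv_cache_dtype" in llm_args_dict:
    kv_cache_dtype = llm_args_dict["kv_cache_dtype"]
print(kv_cache_dtype)  # fp8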
1 change: 1 addition & 0 deletions tensorrt_llm/bench/dataclasses/configuration.py
@@ -114,6 +114,7 @@ def get_pytorch_perf_config(self) -> PyTorchConfig:
def get_autodeploy_perf_config(self) -> Dict:
AutoDeployPerfConfig = dict
ad_config = AutoDeployPerfConfig()
ad_config["kv_cache_dtype"] = "auto"
ad_config["attn_backend"] = "flashinfer"
return ad_config

14 changes: 2 additions & 12 deletions tensorrt_llm/bench/dataclasses/reporting.py
@@ -11,7 +11,6 @@
from tensorrt_llm.bench.dataclasses.statistics import (BenchmarkStatistics,
PercentileStats,
RequestRecord)
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.logger import Logger
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode

@@ -276,17 +275,8 @@ def get_statistics_dict(self) -> Dict[str, Any]:
model = self.rt_cfg.model_path or self.rt_cfg.model
model_config = ModelConfig.from_pretrained(model,
trust_remote_code=True)
kv_cache_config = self.kwargs.get("kv_cache_config",
KvCacheConfig())
if isinstance(kv_cache_config, KvCacheConfig):
kv_cache_dtype = kv_cache_config.dtype
elif isinstance(kv_cache_config, dict):
kv_cache_dtype = kv_cache_config.get("dtype", "auto")
else:
raise ValueError(
f"Invalid kv_cache_config type: {type(kv_cache_config)}.")

validate_and_set_kv_cache_quant(model_config, kv_cache_dtype)
validate_and_set_kv_cache_quant(model_config,
self.kwargs["kv_cache_dtype"])
Comment on lines +278 to +279
💡 Verification agent

🧩 Analysis chain

Simplification looks good, but verify kwargs consistency.

The direct access to self.kwargs["kv_cache_dtype"] simplifies the logic by removing the intermediate KvCacheConfig extraction. However, ensure that all callers consistently provide kv_cache_dtype in kwargs to avoid KeyError exceptions.


🏁 Script executed:

#!/bin/bash
# Description: Verify that kv_cache_dtype is consistently provided in kwargs across the codebase
# Expected: All ReportUtility instantiations should include kv_cache_dtype in kwargs

# Search for ReportUtility instantiations
echo "=== ReportUtility instantiations ==="
rg -A 10 -B 5 "ReportUtility\(" --type py

# Search for places where kwargs might be passed to ReportUtility
echo -e "\n=== Potential kwargs sources ==="
rg -A 5 -B 5 "kv_cache_dtype.*kwargs\|kwargs.*kv_cache_dtype" --type py

Length of output: 3274


🏁 Script executed:

#!/bin/bash
set -e

echo "=== ReportUtility context in low_latency.py ==="
rg -n -C 5 "report_utility" tensorrt_llm/bench/benchmark/low_latency.py || true

echo -e "\n=== 'kwargs' declarations/usages in low_latency.py ==="
rg -n "kwargs" tensorrt_llm/bench/benchmark/low_latency.py || true

echo -e "\n=== ReportUtility context in throughput.py ==="
rg -n -C 5 "report_utility" tensorrt_llm/bench/benchmark/throughput.py || true

echo -e "\n=== 'kwargs' declarations/usages in throughput.py ==="
rg -n "kwargs" tensorrt_llm/bench/benchmark/throughput.py || true

Length of output: 4469


🏁 Script executed:

#!/bin/bash
set -e

echo "=== Searching for get_llm_args definition and usage ==="
rg -n "get_llm_args" -t py

echo -e "\n=== Inspecting RuntimeConfig class for kv_cache_dtype ==="
rg -n -C 3 "class .*RuntimeConfig" -t py || true

Length of output: 1901


🏁 Script executed:

#!/bin/bash
set -e

echo "=== Extracting RuntimeConfig class and get_llm_args method ==="
rg -n -C 50 "class RuntimeConfig" tensorrt_llm/bench/dataclasses/configuration.py

echo -e "\n=== Extracting get_llm_args implementation ==="
rg -n -C 50 "def get_llm_args" tensorrt_llm/bench/dataclasses/configuration.py

echo -e "\n=== Searching for kv_cache_dtype in configuration ==="
rg -n "kv_cache_dtype" -t py

Length of output: 35393


🏁 Script executed:

#!/bin/bash
set -e

echo "=== validate_and_set_kv_cache_quant definition ==="
rg -n -C 20 "def validate_and_set_kv_cache_quant" tensorrt_llm/bench/benchmark/utils/general.py || true

Length of output: 224


Add kv_cache_dtype to kwargs before calling ReportUtility

The new call to

validate_and_set_kv_cache_quant(model_config, self.kwargs["kv_cache_dtype"])

will raise a KeyError unless every ReportUtility instantiation’s kwargs includes "kv_cache_dtype". Please update both benchmark entry-points:

• In tensorrt_llm/bench/benchmark/low_latency.py (around line 353):

     # merge runtime_config args
     kwargs = kwargs | runtime_config.get_llm_args()
     kwargs["backend"] = backend
+    # ensure kv_cache_dtype is passed through for reporting
+    kwargs["kv_cache_dtype"] = runtime_config.settings_config.get_kvcache_config().kv_cache_dtype
     report_utility = ReportUtility(statistics, metadata, runtime_config,
-                                   logger, kwargs, True)
+                                   logger, kwargs, True)

• In tensorrt_llm/bench/benchmark/throughput.py (around line 436):

     kwargs = kwargs | runtime_config.get_llm_args()
     kwargs["backend"] = backend
+    # ensure kv_cache_dtype is passed through for reporting
+    kwargs["kv_cache_dtype"] = runtime_config.settings_config.get_kvcache_config().kv_cache_dtype
     report_utility = ReportUtility(statistics, metadata, runtime_config,
-                                   logger, kwargs, streaming)
+                                   logger, kwargs, streaming)

With these additions, self.kwargs["kv_cache_dtype"] will always be defined and validate_and_set_kv_cache_quant will receive the expected dtype.

🤖 Prompt for AI Agents
In tensorrt_llm/bench/dataclasses/reporting.py at lines 278-279, the code
accesses self.kwargs["kv_cache_dtype"] which can cause a KeyError if
"kv_cache_dtype" is not present in kwargs. To fix this, update the benchmark
entry points in tensorrt_llm/bench/benchmark/low_latency.py around line 353 and
tensorrt_llm/bench/benchmark/throughput.py around line 436 to include
"kv_cache_dtype" in the kwargs passed to ReportUtility. This ensures
self.kwargs["kv_cache_dtype"] is always defined before calling
validate_and_set_kv_cache_quant.


stats_dict["engine"] |= {
"backend":
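If guaranteeing the key at every ReportUtility call site proves awkward, a defensive fallback inside reporting.py is an alternative to the reviewer's suggested fix. A sketch only, under the assumption that kwargs is the plain dict handed to ReportUtility; this is not part of the PR:

def resolve_kv_cache_dtype(kwargs: dict) -> str:
    # Fall back to "auto" when a caller did not forward the key,
    # instead of raising KeyError.
    return kwargs.get("kv_cache_dtype", "auto")

print(resolve_kv_cache_dtype({"backend": "pytorch"}))     # auto
print(resolve_kv_cache_dtype({"kv_cache_dtype": "fp8"}))  # fp8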
65 changes: 15 additions & 50 deletions tensorrt_llm/llmapi/llm_args.py
@@ -822,10 +822,6 @@ class KvCacheConfig(BaseModel, PybindMirror):
use_uvm: bool = Field(default=False,
description="Whether to use UVM for the KV cache.")

# This is a pure python field, not a pybind field. It is only for the Pytorch backend.
dtype: str = Field(default="auto",
description="The data type to use for the KV cache.")

def _to_pybind(self):
return _KvCacheConfig(
enable_block_reuse=self.enable_block_reuse,
@@ -1037,6 +1033,10 @@ class BaseLlmArgs(BaseModel):
lora_config: Optional[LoraConfig] = Field(
default=None, description="LoRA configuration for the model.")

# Quantization and calibration configurations
quant_config: Optional[QuantConfig] = Field(
default=None, description="Quantization config.", validate_default=True)

# Several options from ExecutorConfig, expanded here for less hierarchy
kv_cache_config: KvCacheConfig = Field(default_factory=KvCacheConfig,
description="KV cache config.")
@@ -1217,6 +1217,13 @@ def validate_dtype(cls, v, info):
raise RuntimeError("Pre SM 80 GPUs do not support bfloat16")
return v

@field_validator("quant_config", mode='before')
@classmethod
def validate_quant_config(cls, v, info):
if v is None:
v = QuantConfig()
return v

@field_validator("gpus_per_node", mode='before')
@classmethod
def validate_gpus_per_node(cls, v, info):
@@ -1668,10 +1675,6 @@ class TrtLlmArgs(BaseLlmArgs):
calib_config: Optional[CalibConfig] = Field(
default=None, description="Calibration config.", validate_default=True)

# Quantization and calibration configurations
quant_config: Optional[QuantConfig] = Field(
default=None, description="Quantization config.", validate_default=True)

embedding_parallel_mode: str = Field(
default='SHARDING_ALONG_VOCAB',
description="The embedding parallel mode.")
@@ -1709,13 +1712,6 @@ def init_calib_config(cls, v):
return CalibConfig()
return v

@field_validator("quant_config", mode='before')
@classmethod
def validate_quant_config(cls, v, info):
if v is None:
v = QuantConfig()
return v

@model_validator(mode="after")
def setup_embedding_parallel_mode(self):
if self.embedding_parallel_mode == 'NONE':
@@ -1760,11 +1756,6 @@ def validate_enable_build_cache(self):
f"Invalid build_cache_config: {self.enable_build_cache}")
return self

@model_validator(mode="after")
def validate_kv_cache_dtype(self):
assert self.kv_cache_config.dtype == "auto", "KvCacheConfig.dtype is not supported by the TensorRT backend."
return self


class LoadFormat(Enum):
AUTO = 0
@@ -1838,6 +1829,9 @@ class TorchLlmArgs(BaseLlmArgs):
"If true, will use the TRTLLM sampler instead of the PyTorch sampler. The TRTLLM sampler has a wide coverage of sampling strategies."
)

kv_cache_dtype: str = Field(default="auto",
description="Data type for KV cache.")

enable_iter_perf_stats: bool = Field(
default=False, description="Enable iteration performance statistics.")

@@ -1903,19 +1897,6 @@ class TorchLlmArgs(BaseLlmArgs):
description="The format of the provided checkpoint.",
)

# PrivateVars
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)

@property
def quant_config(self) -> QuantConfig:
if self._quant_config is None:
self._quant_config = QuantConfig()
return self._quant_config

@quant_config.setter
def quant_config(self, value: QuantConfig):
self._quant_config = value

# TODO: remove backend later
@field_validator('backend', mode='before')
def init_backend(cls, v):
@@ -2059,22 +2040,6 @@ def validate_cuda_graph_config(self) -> 'TorchLlmArgs':

return self

@model_validator(mode='after')
def sync_quant_config_with_kv_cache_config_dtype(self) -> 'TorchLlmArgs':
if self.kv_cache_config is None:
return self

assert self.quant_config is not None
if self.kv_cache_config.dtype == "auto":
return self
elif self.kv_cache_config.dtype == 'fp8':
self.quant_config.kv_cache_quant_algo = QuantAlgo.FP8
else:
logger.warning(
f"Cannot sync quant_config.kv_cache_quant_algo with kv_cache_config.dtype of {self.kv_cache_config.dtype}, "
"please update the validator")
return self

# TODO: Remove this after the PyTorch backend is fully migrated to TorchLlmArgs from ExecutorConfig
def get_pytorch_backend_config(self) -> "PyTorchConfig":
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@@ -2098,7 +2063,7 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
moe_backend=self.moe_config.backend,
enable_mixed_sampler=self.enable_mixed_sampler,
enable_trtllm_sampler=self.enable_trtllm_sampler,
kv_cache_dtype=self.kv_cache_config.dtype,
kv_cache_dtype=self.kv_cache_dtype,
enable_iter_perf_stats=self.enable_iter_perf_stats,
enable_iter_req_stats=self.enable_iter_req_stats,
print_iter_log=self.print_iter_log,
3 changes: 0 additions & 3 deletions tensorrt_llm/llmapi/llm_utils.py
@@ -405,9 +405,6 @@ def _update_from_hf_quant_config(self) -> bool:
logger.info(f"Setting {key}={value} from HF quant config.")
setattr(quant_config, key, value)

# Update the quant_config in llm_args for pytorch
self.llm_args.quant_config = quant_config

return True

hf_config_path = f"{self._model_dir}/config.json"