From 089e11e42af91983750096a7a1e5c1e98399676c Mon Sep 17 00:00:00 2001
From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
Date: Wed, 9 Jul 2025 01:40:16 +0000
Subject: [PATCH 1/5] Fix: Enhance ModelConfig for kv cache size calculations

- Added support for setting num_kv_heads in ModelConfig.
- Implemented logic to determine head_size from head_size or head_dim in pretrained_config.
- Added warnings for default values when configurations are not set.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
---
 tensorrt_llm/_torch/model_config.py | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 830cd5bda6a..671564baadc 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -305,6 +305,8 @@ def get_bindings_model_config(self,
             hidden_size=hidden_size,
             data_type=torch_dtype_to_binding(
                 self.pretrained_config.torch_dtype))
+
+        # For kv cache size calculation: set tokens_per_block
         if tokens_per_block is None:
             logger.warning(
                 f"tokens_per_block is not set, using default value {model_config_cpp.tokens_per_block}"
@@ -312,6 +314,12 @@
         else:
             model_config_cpp.tokens_per_block = tokens_per_block
 
+        # For kv cache size calculation: set num_kv_heads
+        num_kv_heads = getattr(
+            self.pretrained_config, "num_key_value_heads",
+            num_heads) // (self.mapping.tp_size * self.mapping.cp_size)
+        model_config_cpp.set_num_kv_heads(num_kv_heads)
+
         mlp_hidden_size = None
         if self.pretrained_config.intermediate_size is not None:
             mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
@@ -333,9 +341,16 @@
                 f"Failed to infer mlp hidden size for model: {self.pretrained_config.model_type}"
             )
 
-        if "head_size" in self.pretrained_config:
-            head_size = self.pretrained_config.head_size
+        # For kv cache size calculation: set size_per_head
+        head_dim_names = ["head_size", "head_dim"]
+        for head_dim_name in head_dim_names:
+            if head_dim_name in self.pretrained_config:
+                head_size = getattr(self.pretrained_config, head_dim_name)
+                break
         else:
+            logger.warning(
+                f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
+            )
            head_size = hidden_size // num_heads
 
         model_config_cpp.mlp_hidden_size = mlp_hidden_size
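The head-size resolution introduced above can be exercised in isolation. The standalone sketch below is not part of the patch: it mirrors the fallback order (head_size, then head_dim, then hidden_size // num_heads) using a hypothetical SimpleNamespace in place of pretrained_config and hasattr() in place of the config's membership test; the Gemma3-1B-like numbers are illustrative only.

    from types import SimpleNamespace

    def resolve_head_size(pretrained_config, hidden_size, num_heads):
        # Mirror of the fallback order in the patch: prefer head_size, then
        # head_dim, then fall back to hidden_size // num_heads (with a warning).
        for name in ("head_size", "head_dim"):
            if hasattr(pretrained_config, name):
                return getattr(pretrained_config, name)
        print(f"head_size/head_dim is not set, using default value {hidden_size // num_heads}")
        return hidden_size // num_heads

    # Hypothetical Gemma3-1B-like config: head_dim is set, head_size is not.
    cfg = SimpleNamespace(head_dim=256, num_key_value_heads=1)
    print(resolve_head_size(cfg, hidden_size=1152, num_heads=4))                # 256, taken from head_dim
    print(resolve_head_size(SimpleNamespace(), hidden_size=1152, num_heads=4))  # 288, the fallback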
From 2186dd4617a3a935c244f3955aafcd2964823255 Mon Sep 17 00:00:00 2001
From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
Date: Wed, 9 Jul 2025 05:55:47 +0000
Subject: [PATCH 2/5] Remove chunked prefill test case due to kernel incompatibility.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
---
 .../integration/defs/accuracy/test_llm_api_pytorch.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 701461b19d1..5d4926f5fc8 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -537,17 +537,6 @@ def test_auto_dtype_vswa(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
-        # chunked prefill case or more features
-        extra_llm_config = dict(
-            enable_chunked_prefill=True,
-            max_num_tokens=1024,
-        )
-        with LLM(self.MODEL_PATH,
-                 kv_cache_config=self.kv_cache_config,
-                 **extra_llm_config) as llm:
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
-
 
 class TestMixtral8x7B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"

From 93b81bbde2d6039039c813d9a1f1cfc01f1253ef Mon Sep 17 00:00:00 2001
From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
Date: Wed, 9 Jul 2025 05:57:12 +0000
Subject: [PATCH 3/5] Remove accidentally added skip.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
---
 tests/integration/defs/accuracy/test_llm_api_pytorch.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 5d4926f5fc8..41e99812c03 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -70,7 +70,6 @@ def test_nvfp4_streaming(self, stream_interval):
         task.evaluate(llm, streaming=True)
 
 
-@skip_post_blackwell  # TODO: remove this skip after this nvbug is fixed: https://nvbugspro.nvidia.com/bug/5295470
 class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"

From 14d95b207650de0bd938d55d1a7a92057ec9bf3b Mon Sep 17 00:00:00 2001
From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
Date: Wed, 9 Jul 2025 20:07:05 +0000
Subject: [PATCH 4/5] Skip chunked prefill test case due to kernel incompatibility.

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
---
 .../defs/accuracy/test_llm_api_pytorch.py | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 41e99812c03..5088f901c50 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -536,6 +536,26 @@ def test_auto_dtype_vswa(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.skip(
+        reason=
+        "remove this skip after the kernel support mentioned in this nvbug is fixed: https://nvbugspro.nvidia.com/bug/5338620"
+    )
+    def test_auto_dtype_chunked_prefill(self):
+        # NOTE: Test with VSWA kv cache config.
+        self.kv_cache_config.max_attention_window = [
+            512, 512, 512, 512, 512, 32768
+        ]  # Gemma3 1B attention window size pattern
+        # chunked prefill case or more features
+        extra_llm_config = dict(
+            enable_chunked_prefill=True,
+            max_num_tokens=1024,
+        )
+        with LLM(self.MODEL_PATH,
+                 kv_cache_config=self.kv_cache_config,
+                 **extra_llm_config) as llm:
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
 
 class TestMixtral8x7B(LlmapiAccuracyTestHarness):
     MODEL_NAME = "mistralai/Mixtral-8x7B-v0.1"

From b9855b9e0751326378b076ec6263e66eba3d477a Mon Sep 17 00:00:00 2001
From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
Date: Mon, 14 Jul 2025 23:51:43 +0000
Subject: [PATCH 5/5] fix:

- expose `num_free_blocks_per_window_size` via kv_cache_stats.
- with `num_free_blocks_per_window_size`, update `get_num_free_blocks()` and `get_num_available_tokens()`

Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com>
---
 .../batch_manager/kvCacheManager.h          |  8 ++++++
 .../pybind/batch_manager/kvCacheManager.cpp |  3 ++-
 .../_torch/pyexecutor/resource_manager.py   | 25 +++++++++++++++----
 .../test_lists/test-db/l0_h100.yml          |  2 ++
 4 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
index caac72744f3..d0daf9e4350 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -180,6 +180,8 @@ struct KvCacheStats
     SizeType32 missedBlocks;
     // Measuring the KV Cache reuse rate. cacheHitRate = reusedBlocks / (reusedBlocks + missedBlocks).
     float cacheHitRate;
+    // Number of free blocks for every configured attention-window size.
+    std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
 };
 
 // Basic building block of a paged KV cache - a single
@@ -1457,6 +1459,11 @@ class KVCacheManager : public BaseKVCacheManager
         return mBlockManager.getNumMissedBlocks();
     }
 
+    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
+    {
+        return mBlockManager.getNumFreeBlocksPerWindowSize();
+    }
+
     [[nodiscard]] KvCacheStats getKvCacheStats() const override
     {
         KvCacheStats kvCacheStats;
@@ -1471,6 +1478,7 @@
         kvCacheStats.cacheHitRate = kvCacheStats.reusedBlocks == 0 ?
             0 : static_cast<float>(kvCacheStats.reusedBlocks)
                 / static_cast<float>(kvCacheStats.reusedBlocks + kvCacheStats.missedBlocks);
+        kvCacheStats.numFreeBlocksPerWindowSize = getNumFreeBlocksPerWindowSize();
         return kvCacheStats;
     }
 
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
index a75db66eaee..e31269d1fd9 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -298,7 +298,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
         .def_readwrite("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks)
         .def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
         .def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks)
-        .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate);
+        .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
+        .def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize);
 
     py::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
         .def(py::init<>())

diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index ffa8ce4bdae..c5a9f264b01 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -193,10 +193,10 @@ def __init__(
             else 0)
 
         # Determine if this is VSWA (Variable Sliding Window Attention)
-        is_vswa = len(self.max_attention_window_vec) > 1
+        self.is_vswa = len(self.max_attention_window_vec) > 1
 
         # Calculate blocks per window using appropriate method
-        if is_vswa:
+        if self.is_vswa:
             # VSWA case: use C++ implementation for variable window sizes
             # model config check
             if model_config is None:
@@ -523,14 +523,29 @@ def get_batch_cache_indices(
         return result
 
     def get_num_free_blocks(self) -> int:
-        return self.impl.get_kv_cache_stats().free_num_blocks
+        if self.is_vswa:
+            logger.info(
+                f"For VSWA case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}"
+            )
+            return min(self.impl.get_kv_cache_stats().
+                       num_free_blocks_per_window_size.values())
+        else:
+            return self.impl.get_kv_cache_stats().free_num_blocks
 
     def get_num_kv_blocks(self, num_tokens: int) -> int:
         return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block
 
     def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
-        return (self.get_num_free_blocks() * self.tokens_per_block -
-                self.num_extra_kv_tokens - max_num_draft_tokens)
+        if self.max_attention_window_vec and len(
+                self.max_attention_window_vec) > 1:
+            # VSWA case, the available tokens should be the minimum of the available tokens for each window size
+            min_free_blocks = min(self.impl.get_kv_cache_stats().
+                                  num_free_blocks_per_window_size.values())
+            res = min_free_blocks * self.tokens_per_block - self.num_extra_kv_tokens - max_num_draft_tokens
+        else:
+            res = (self.get_num_free_blocks() * self.tokens_per_block -
+                   self.num_extra_kv_tokens - max_num_draft_tokens)
+        return res
 
     def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
         layer_offset = self.layer_offsets[layer_idx]

diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 3e9f0d3995b..ca678f13ef5 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -24,6 +24,8 @@ l0_h100:
   - unittest/disaggregated/test_router.py
   - unittest/disaggregated/test_remoteDictionary.py
   - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_chunked_prefill
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
  - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
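With patch 5 applied, the per-window free-block counts surface on the Python stats object, and both helpers take the minimum across window sizes. The standalone sketch below is not part of the patch; it only reproduces that token-budget arithmetic with hypothetical window sizes, block counts, and block size.

    def available_tokens(num_free_blocks_per_window_size, tokens_per_block,
                         num_extra_kv_tokens=0, max_num_draft_tokens=0):
        # Mirrors the VSWA branch of get_num_available_tokens(): the budget is
        # bounded by the attention-window size with the fewest free blocks.
        min_free_blocks = min(num_free_blocks_per_window_size.values())
        return (min_free_blocks * tokens_per_block - num_extra_kv_tokens -
                max_num_draft_tokens)

    # Hypothetical stats: the 512-token windows still have 12 free blocks each,
    # but the 32768-token window is down to 3, with 32 tokens per block.
    print(available_tokens({512: 12, 32768: 3}, tokens_per_block=32))  # -> 96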