
Commit e09e409

Fix: Enhance ModelConfig for kv cache size calculations (#5868)
Signed-off-by: qixiang-99 <[email protected]>
1 parent: fa34cb7
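
Summary of changes: KvCacheStats gains a numFreeBlocksPerWindowSize map (free blocks per configured attention-window size), exposed through KVCacheManager::getNumFreeBlocksPerWindowSize() and the Python binding field num_free_blocks_per_window_size. ModelConfig now sets num_kv_heads (sharded over TP x CP ranks) and resolves head_size from either head_size or head_dim so the C++ side can size the KV cache correctly. The PyTorch resource manager budgets free blocks and available tokens by the per-window minimum when Variable Sliding Window Attention (VSWA) is active, and a (currently skipped) Gemma3 VSWA chunked-prefill accuracy test is added to the H100 test list.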

6 files changed, +58 -9 lines changed


cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 8 additions & 0 deletions

@@ -180,6 +180,8 @@ struct KvCacheStats
     SizeType32 missedBlocks;
     // Measuring the KV Cache reuse rate. cacheHitRate = reusedBlocks / (reusedBlocks + missedBlocks).
     float cacheHitRate;
+    // Number of free blocks for every configured attention-window size.
+    std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
 };

 // Basic building block of a paged KV cache - a single
@@ -1457,6 +1459,11 @@ class KVCacheManager : public BaseKVCacheManager
         return mBlockManager.getNumMissedBlocks();
     }

+    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
+    {
+        return mBlockManager.getNumFreeBlocksPerWindowSize();
+    }
+
     [[nodiscard]] KvCacheStats getKvCacheStats() const override
     {
         KvCacheStats kvCacheStats;
@@ -1471,6 +1478,7 @@ class KVCacheManager : public BaseKVCacheManager
         kvCacheStats.cacheHitRate = kvCacheStats.reusedBlocks == 0 ? 0
             : static_cast<float>(kvCacheStats.reusedBlocks)
                 / static_cast<float>(kvCacheStats.reusedBlocks + kvCacheStats.missedBlocks);
+        kvCacheStats.numFreeBlocksPerWindowSize = getNumFreeBlocksPerWindowSize();
         return kvCacheStats;
     }
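
The new field maps each configured attention-window size to the number of free blocks in that window's pool. A minimal sketch of the reduction the rest of this commit performs on it (the sample values are invented):

# Hypothetical snapshot of numFreeBlocksPerWindowSize for a model
# with two configured window sizes (invented numbers).
num_free_blocks_per_window_size = {512: 96, 32768: 40}

# Under VSWA every window pool must have a block available, so the
# usable free-block count is the minimum across pools.
usable_free_blocks = min(num_free_blocks_per_window_size.values())
print(usable_free_blocks)  # 40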

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 1 deletion

@@ -298,7 +298,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
         .def_readwrite("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks)
         .def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
         .def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks)
-        .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate);
+        .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
+        .def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize);

     py::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
         .def(py::init<>())
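
With the binding in place, the per-window counts surface on the stats object next to the existing fields. A sketch of reading it from Python, assuming `kv_cache_manager` is an already-constructed binding-level KVCacheManager (the same object the resource manager holds as `self.impl`):

# get_kv_cache_stats() is the accessor used below in resource_manager.py.
stats = kv_cache_manager.get_kv_cache_stats()
print(stats.free_num_blocks)                   # existing aggregate count
print(stats.cache_hit_rate)                    # existing field
print(stats.num_free_blocks_per_window_size)   # new: {window_size: free_blocks}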

tensorrt_llm/_torch/model_config.py

Lines changed: 17 additions & 2 deletions

@@ -305,13 +305,21 @@ def get_bindings_model_config(self,
             hidden_size=hidden_size,
             data_type=torch_dtype_to_binding(
                 self.pretrained_config.torch_dtype))
+
+        # For kv cache size calculation: set tokens_per_block
         if tokens_per_block is None:
             logger.warning(
                 f"tokens_per_block is not set, using default value {model_config_cpp.tokens_per_block}"
             )
         else:
             model_config_cpp.tokens_per_block = tokens_per_block

+        # For kv cache size calculation: set num_kv_heads
+        num_kv_heads = getattr(
+            self.pretrained_config, "num_key_value_heads",
+            num_heads) // (self.mapping.tp_size * self.mapping.cp_size)
+        model_config_cpp.set_num_kv_heads(num_kv_heads)
+
         mlp_hidden_size = None
         if self.pretrained_config.intermediate_size is not None:
             mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
@@ -333,9 +341,16 @@ def get_bindings_model_config(self,
                 f"Failed to infer mlp hidden size for model: {self.pretrained_config.model_type}"
             )

-        if "head_size" in self.pretrained_config:
-            head_size = self.pretrained_config.head_size
+        # For kv cache size calculation: set size_per_head
+        head_dim_names = ["head_size", "head_dim"]
+        for head_dim_name in head_dim_names:
+            if head_dim_name in self.pretrained_config:
+                head_size = getattr(self.pretrained_config, head_dim_name)
+                break
         else:
+            logger.warning(
+                f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
+            )
             head_size = hidden_size // num_heads

         model_config_cpp.mlp_hidden_size = mlp_hidden_size
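
These fields feed the usual paged-KV sizing arithmetic. A worked sketch under assumed shapes (Llama-3.1-8B-like values; only the num_kv_heads sharding mirrors the new ModelConfig logic, the rest is illustrative):

num_layers = 32
num_key_value_heads = 8      # from pretrained_config
head_size = 128              # resolved from head_size/head_dim
tokens_per_block = 32
tp_size, cp_size = 2, 1
bytes_per_elem = 2           # fp16/bf16 KV cache

# KV heads are sharded across tensor- and context-parallel ranks,
# as in get_bindings_model_config above.
num_kv_heads = num_key_value_heads // (tp_size * cp_size)  # 4

# Standard paged-KV estimate: K and V for every layer, head, and
# token slot in a block.
bytes_per_block = (2 * num_layers * num_kv_heads * head_size
                   * tokens_per_block * bytes_per_elem)
print(bytes_per_block)  # 2097152 bytes = 2 MiB per block per rank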

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 20 additions & 5 deletions

@@ -193,10 +193,10 @@ def __init__(
             else 0)

         # Determine if this is VSWA (Variable Sliding Window Attention)
-        is_vswa = len(self.max_attention_window_vec) > 1
+        self.is_vswa = len(self.max_attention_window_vec) > 1

         # Calculate blocks per window using appropriate method
-        if is_vswa:
+        if self.is_vswa:
             # VSWA case: use C++ implementation for variable window sizes
             # model config check
             if model_config is None:
@@ -523,14 +523,29 @@ def get_batch_cache_indices(
         return result

     def get_num_free_blocks(self) -> int:
-        return self.impl.get_kv_cache_stats().free_num_blocks
+        if self.is_vswa:
+            logger.info(
+                f"For the VSWA case, return the minimum number of free blocks across window sizes: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}"
+            )
+            return min(self.impl.get_kv_cache_stats().
+                       num_free_blocks_per_window_size.values())
+        else:
+            return self.impl.get_kv_cache_stats().free_num_blocks

     def get_num_kv_blocks(self, num_tokens: int) -> int:
         return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block

     def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
-        return (self.get_num_free_blocks() * self.tokens_per_block -
-                self.num_extra_kv_tokens - max_num_draft_tokens)
+        if self.max_attention_window_vec and len(
+                self.max_attention_window_vec) > 1:
+            # VSWA case: the available tokens should be the minimum of the available tokens for each window size
+            min_free_blocks = min(self.impl.get_kv_cache_stats().
+                                  num_free_blocks_per_window_size.values())
+            res = min_free_blocks * self.tokens_per_block - self.num_extra_kv_tokens - max_num_draft_tokens
+        else:
+            res = (self.get_num_free_blocks() * self.tokens_per_block -
+                   self.num_extra_kv_tokens - max_num_draft_tokens)
+        return res

     def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
         layer_offset = self.layer_offsets[layer_idx]
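
The VSWA branch above reduces the per-window free-block counts to a single conservative token budget. A standalone sketch of the same arithmetic, with invented inputs:

def num_available_tokens(free_blocks_per_window, tokens_per_block,
                         num_extra_kv_tokens=0, max_num_draft_tokens=0):
    # Budget by the window pool with the fewest free blocks, as the
    # VSWA branch of get_num_available_tokens does.
    min_free_blocks = min(free_blocks_per_window.values())
    return (min_free_blocks * tokens_per_block - num_extra_kv_tokens
            - max_num_draft_tokens)

# The 32768-token window is the bottleneck here even though the
# local 512-token windows have more blocks free.
print(num_available_tokens({512: 96, 32768: 40}, tokens_per_block=32))  # 1280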

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 9 additions & 1 deletion

@@ -70,7 +70,6 @@ def test_nvfp4_streaming(self, stream_interval):
         task.evaluate(llm, streaming=True)


-@skip_post_blackwell  # TODO: remove this skip after this nvbug is fixed: https://nvbugspro.nvidia.com/bug/5295470
 class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
     MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
@@ -537,6 +536,15 @@ def test_auto_dtype_vswa(self):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

+    @pytest.mark.skip(
+        reason=
+        "remove this skip after the kernel support mentioned in this nvbug is fixed: https://nvbugspro.nvidia.com/bug/5338620"
+    )
+    def test_auto_dtype_chunked_prefill(self):
+        # NOTE: Test with VSWA kv cache config.
+        self.kv_cache_config.max_attention_window = [
+            512, 512, 512, 512, 512, 32768
+        ]  # Gemma3 1B attention window size pattern
         # chunked prefill case or more features
         extra_llm_config = dict(
             enable_chunked_prefill=True,
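
Outside the test harness, the same VSWA pattern can be requested through the LLM API. A sketch only: the model id and LLM arguments are assumed from how the accuracy tests construct their LLM instances, not confirmed by this diff:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

# Five 512-token local windows followed by one 32768-token global
# window, applied cyclically over the model's layers (Gemma3 1B pattern).
kv_cache_config = KvCacheConfig(
    max_attention_window=[512, 512, 512, 512, 512, 32768])

llm = LLM(model="google/gemma-3-1b-it",
          kv_cache_config=kv_cache_config,
          enable_chunked_prefill=True)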

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,8 @@ l0_h100:
   - unittest/disaggregated/test_router.py
   - unittest/disaggregated/test_remoteDictionary.py
   - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_chunked_prefill
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
