
Commit 9fecb35

fix:
- expose `num_free_blocks_per_window_size` via kv_cache_stats.
- with `num_free_blocks_per_window_size`, update `get_num_free_blocks()` and `get_num_available_tokens()`

Signed-off-by: qixiang-99 <[email protected]>
1 parent c16f048 commit 9fecb35

File tree

4 files changed (+32, -6 lines)


cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 8 additions & 0 deletions
@@ -180,6 +180,8 @@ struct KvCacheStats
     SizeType32 missedBlocks;
     // Measuring the KV Cache reuse rate. cacheHitRate = reusedBlocks / (reusedBlocks + missedBlocks).
     float cacheHitRate;
+    // Number of free blocks for every configured attention-window size.
+    std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
 };

 // Basic building block of a paged KV cache - a single
@@ -1454,6 +1456,11 @@ class KVCacheManager : public BaseKVCacheManager
         return mBlockManager.getNumMissedBlocks();
     }

+    [[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
+    {
+        return mBlockManager.getNumFreeBlocksPerWindowSize();
+    }
+
     [[nodiscard]] KvCacheStats getKvCacheStats() const override
     {
         KvCacheStats kvCacheStats;
@@ -1468,6 +1475,7 @@ class KVCacheManager : public BaseKVCacheManager
         kvCacheStats.cacheHitRate = kvCacheStats.reusedBlocks == 0 ? 0
             : static_cast<float>(kvCacheStats.reusedBlocks)
                 / static_cast<float>(kvCacheStats.reusedBlocks + kvCacheStats.missedBlocks);
+        kvCacheStats.numFreeBlocksPerWindowSize = getNumFreeBlocksPerWindowSize();
         return kvCacheStats;
     }
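For context, VSWA (Variable Sliding Window Attention) configures more than one attention-window size, and the new numFreeBlocksPerWindowSize field reports the current free-block count for each configured window size. A minimal Python sketch of the shape this map takes once surfaced as a stat; the window sizes and counts below are hypothetical, not taken from this commit:

# Hypothetical snapshot of the per-window free-block map,
# keyed by attention-window size (in tokens).
num_free_blocks_per_window_size = {
    512: 120,    # free blocks backing the 512-token sliding-window layers
    32768: 45,   # free blocks backing the 32768-token window layers
}

# The most constrained window size bounds what can still be allocated.
min_free_blocks = min(num_free_blocks_per_window_size.values())  # 45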

cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 1 deletion
@@ -298,7 +298,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
         .def_readwrite("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks)
         .def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
         .def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks)
-        .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate);
+        .def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
+        .def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize);

     py::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
         .def(py::init<>())
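With the binding above, the per-window map is readable from Python next to the existing counters. A minimal sketch, assuming an already-constructed KVCacheManager binding object (named kv_cache_manager here for illustration):

stats = kv_cache_manager.get_kv_cache_stats()

# Existing aggregate counter.
print(stats.free_num_blocks)

# New per-window map added by this commit: {window_size: free_blocks}.
for window_size, free_blocks in stats.num_free_blocks_per_window_size.items():
    print(f"window {window_size}: {free_blocks} free blocks")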

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 20 additions & 5 deletions
@@ -193,10 +193,10 @@ def __init__(
             else 0)

         # Determine if this is VSWA (Variable Sliding Window Attention)
-        is_vswa = len(self.max_attention_window_vec) > 1
+        self.is_vswa = len(self.max_attention_window_vec) > 1

         # Calculate blocks per window using appropriate method
-        if is_vswa:
+        if self.is_vswa:
             # VSWA case: use C++ implementation for variable window sizes
             # model config check
             if model_config is None:
@@ -523,14 +523,29 @@ def get_batch_cache_indices(
         return result

     def get_num_free_blocks(self) -> int:
-        return self.impl.get_kv_cache_stats().free_num_blocks
+        if self.is_vswa:
+            logger.info(
+                f"For VSWA case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}"
+            )
+            return min(self.impl.get_kv_cache_stats().
+                       num_free_blocks_per_window_size.values())
+        else:
+            return self.impl.get_kv_cache_stats().free_num_blocks

     def get_num_kv_blocks(self, num_tokens: int) -> int:
         return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block

     def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
-        return (self.get_num_free_blocks() * self.tokens_per_block -
-                self.num_extra_kv_tokens - max_num_draft_tokens)
+        if self.max_attention_window_vec and len(
+                self.max_attention_window_vec) > 1:
+            # VSWA case: the available tokens should be the minimum of the available tokens for each window size
+            min_free_blocks = min(self.impl.get_kv_cache_stats().
+                                  num_free_blocks_per_window_size.values())
+            res = min_free_blocks * self.tokens_per_block - self.num_extra_kv_tokens - max_num_draft_tokens
+        else:
+            res = (self.get_num_free_blocks() * self.tokens_per_block -
+                   self.num_extra_kv_tokens - max_num_draft_tokens)
+        return res

     def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
         layer_offset = self.layer_offsets[layer_idx]
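To make the VSWA branch of get_num_available_tokens concrete, a small worked sketch; the block counts, tokens_per_block value, and draft-token budget below are hypothetical:

# Hypothetical inputs.
num_free_blocks_per_window_size = {512: 80, 32768: 30}
tokens_per_block = 64
num_extra_kv_tokens = 0
max_num_draft_tokens = 4

# VSWA path: the most constrained window size determines the usable blocks.
min_free_blocks = min(num_free_blocks_per_window_size.values())  # 30

# Convert blocks to tokens and subtract the reserved budgets.
available_tokens = (min_free_blocks * tokens_per_block
                    - num_extra_kv_tokens - max_num_draft_tokens)
print(available_tokens)  # 30 * 64 - 0 - 4 = 1916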

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ l0_h100:
   - unittest/disaggregated/test_router.py
   - unittest/disaggregated/test_remoteDictionary.py
   - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa
+  - accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_chunked_prefill
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)
