8 changes: 8 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -180,6 +180,8 @@ struct KvCacheStats
SizeType32 missedBlocks;
// Measuring the KV Cache reuse rate. cacheHitRate = reusedBlocks / (reusedBlocks + missedBlocks).
float cacheHitRate;
// Number of free blocks for every configured attention-window size.
std::map<SizeType32, SizeType32> numFreeBlocksPerWindowSize;
};

// Basic building block of a paged KV cache - a single
@@ -1457,6 +1459,11 @@ class KVCacheManager : public BaseKVCacheManager
return mBlockManager.getNumMissedBlocks();
}

[[nodiscard]] std::map<SizeType32, SizeType32> getNumFreeBlocksPerWindowSize() const
{
return mBlockManager.getNumFreeBlocksPerWindowSize();
}

[[nodiscard]] KvCacheStats getKvCacheStats() const override
{
KvCacheStats kvCacheStats;
@@ -1471,6 +1478,7 @@ class KVCacheManager : public BaseKVCacheManager
kvCacheStats.cacheHitRate = kvCacheStats.reusedBlocks == 0 ? 0
: static_cast<float>(kvCacheStats.reusedBlocks)
/ static_cast<float>(kvCacheStats.reusedBlocks + kvCacheStats.missedBlocks);
kvCacheStats.numFreeBlocksPerWindowSize = getNumFreeBlocksPerWindowSize();
return kvCacheStats;
}

3 changes: 2 additions & 1 deletion cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
@@ -298,7 +298,8 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
.def_readwrite("alloc_new_blocks", &tbk::KvCacheStats::allocNewBlocks)
.def_readwrite("reused_blocks", &tbk::KvCacheStats::reusedBlocks)
.def_readwrite("missed_blocks", &tbk::KvCacheStats::missedBlocks)
.def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate);
.def_readwrite("cache_hit_rate", &tbk::KvCacheStats::cacheHitRate)
.def_readwrite("num_free_blocks_per_window_size", &tbk::KvCacheStats::numFreeBlocksPerWindowSize);

py::class_<tbk::TempAttentionWindowInputs>(m, "TempAttentionWindowInputs")
.def(py::init<>())
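For illustration, here is a minimal Python sketch of how the newly bound num_free_blocks_per_window_size field can be read once these bindings are in place. The kv_cache_manager handle below is hypothetical; only get_kv_cache_stats() and the field name come from this change.

```python
# Hypothetical handle: any object exposing get_kv_cache_stats() via the bindings above.
stats = kv_cache_manager.get_kv_cache_stats()

# num_free_blocks_per_window_size maps attention-window size -> free block count,
# e.g. {512: 120, 32768: 40} for a Gemma3-style window pattern (illustrative values).
for window_size, free_blocks in sorted(stats.num_free_blocks_per_window_size.items()):
    print(f"window {window_size}: {free_blocks} free blocks")

# A conservative capacity estimate across all window sizes is the minimum.
min_free_blocks = min(stats.num_free_blocks_per_window_size.values())
```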
19 changes: 17 additions & 2 deletions tensorrt_llm/_torch/model_config.py
@@ -305,13 +305,21 @@ def get_bindings_model_config(self,
hidden_size=hidden_size,
data_type=torch_dtype_to_binding(
self.pretrained_config.torch_dtype))

# For kv cache size calculation: set tokens_per_block
if tokens_per_block is None:
logger.warning(
f"tokens_per_block is not set, using default value {model_config_cpp.tokens_per_block}"
)
else:
model_config_cpp.tokens_per_block = tokens_per_block

# For kv cache size calculation: set num_kv_heads
num_kv_heads = getattr(
self.pretrained_config, "num_key_value_heads",
num_heads) // (self.mapping.tp_size * self.mapping.cp_size)
model_config_cpp.set_num_kv_heads(num_kv_heads)

mlp_hidden_size = None
if self.pretrained_config.intermediate_size is not None:
mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
@@ -333,9 +341,16 @@
f"Failed to infer mlp hidden size for model: {self.pretrained_config.model_type}"
)

if "head_size" in self.pretrained_config:
head_size = self.pretrained_config.head_size
# For kv cache size calculation: set size_per_head
head_dim_names = ["head_size", "head_dim"]
for head_dim_name in head_dim_names:
if head_dim_name in self.pretrained_config:
head_size = getattr(self.pretrained_config, head_dim_name)
break
else:
logger.warning(
f"head_size/head_dim is not set, using default value {hidden_size // num_heads}"
)
head_size = hidden_size // num_heads

model_config_cpp.mlp_hidden_size = mlp_hidden_size
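To make the "for kv cache size calculation" comments above concrete, here is a rough sketch of how num_kv_heads, head_size, and tokens_per_block combine into a per-block KV-cache size. The formula and numbers are illustrative assumptions, not taken from the implementation.

```python
# Illustrative only: rough KV-cache bytes per block for a single layer.
num_kv_heads = 8       # num_key_value_heads // (tp_size * cp_size)
head_size = 128        # head_size / head_dim, or hidden_size // num_heads as fallback
tokens_per_block = 32
bytes_per_elem = 2     # e.g. fp16 / bf16 KV cache

# Factor of 2 accounts for storing both keys and values.
kv_bytes_per_block_per_layer = (
    2 * num_kv_heads * head_size * tokens_per_block * bytes_per_elem)
print(kv_bytes_per_block_per_layer)  # 131072 bytes = 128 KiB
```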
25 changes: 20 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -193,10 +193,10 @@ def __init__(
else 0)

# Determine if this is VSWA (Variable Sliding Window Attention)
is_vswa = len(self.max_attention_window_vec) > 1
self.is_vswa = len(self.max_attention_window_vec) > 1

# Calculate blocks per window using appropriate method
if is_vswa:
if self.is_vswa:
# VSWA case: use C++ implementation for variable window sizes
# model config check
if model_config is None:
@@ -523,14 +523,29 @@ def get_batch_cache_indices(
return result

def get_num_free_blocks(self) -> int:
return self.impl.get_kv_cache_stats().free_num_blocks
if self.is_vswa:
logger.info(
f"For VSWA case, we return the minimum of the number of free blocks for each window size: {self.impl.get_kv_cache_stats().num_free_blocks_per_window_size}"
)
return min(self.impl.get_kv_cache_stats().
num_free_blocks_per_window_size.values())
else:
return self.impl.get_kv_cache_stats().free_num_blocks

def get_num_kv_blocks(self, num_tokens: int) -> int:
return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block

def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int:
return (self.get_num_free_blocks() * self.tokens_per_block -
self.num_extra_kv_tokens - max_num_draft_tokens)
if self.max_attention_window_vec and len(
self.max_attention_window_vec) > 1:
# VSWA case: the available tokens should be the minimum of the available tokens across window sizes
min_free_blocks = min(self.impl.get_kv_cache_stats().
num_free_blocks_per_window_size.values())
res = min_free_blocks * self.tokens_per_block - self.num_extra_kv_tokens - max_num_draft_tokens
else:
res = (self.get_num_free_blocks() * self.tokens_per_block -
self.num_extra_kv_tokens - max_num_draft_tokens)
return res

def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]:
layer_offset = self.layer_offsets[layer_idx]
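A worked example of the VSWA token-budget arithmetic in get_num_available_tokens above, using illustrative numbers only:

```python
# Illustrative values: the budget is bounded by the most constrained window size.
num_free_blocks_per_window_size = {512: 10, 32768: 4}
tokens_per_block = 32
num_extra_kv_tokens = 0
max_num_draft_tokens = 0

min_free_blocks = min(num_free_blocks_per_window_size.values())  # 4
available_tokens = (min_free_blocks * tokens_per_block
                    - num_extra_kv_tokens - max_num_draft_tokens)
assert available_tokens == 128  # 4 blocks * 32 tokens per block
```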
10 changes: 9 additions & 1 deletion tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -70,7 +70,6 @@ def test_nvfp4_streaming(self, stream_interval):
task.evaluate(llm, streaming=True)


@skip_post_blackwell # TODO: remove this skip after this nvbug is fixed: https://nvbugspro.nvidia.com/bug/5295470
class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
MODEL_PATH = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct"
@@ -537,6 +536,15 @@ def test_auto_dtype_vswa(self):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip(
reason=
"remove this skip after the kernel support mentioned in this nvbug is fixed: https://nvbugspro.nvidia.com/bug/5338620"
)
def test_auto_dtype_chunked_prefill(self):
# NOTE: Test with VSWA kv cache config.
self.kv_cache_config.max_attention_window = [
512, 512, 512, 512, 512, 32768
] # Gemma3 1B attention window size pattern
# Chunked prefill case (and any additional features) on top of the VSWA config
extra_llm_config = dict(
enable_chunked_prefill=True,
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
@@ -24,6 +24,8 @@ l0_h100:
- unittest/disaggregated/test_router.py
- unittest/disaggregated/test_remoteDictionary.py
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_vswa
- accuracy/test_llm_api_pytorch.py::TestGemma3_1BInstruct::test_auto_dtype_chunked_prefill
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_chunked_prefill[attn_backend=TRTLLM] TIMEOUT (90)