From 6f30929163bb4d2310e3f422cf093bff34edcf17 Mon Sep 17 00:00:00 2001 From: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> Date: Thu, 17 Jul 2025 06:14:23 +0000 Subject: [PATCH] Refactor KVCacheManager: Simplify token availability calculation and add model config assertion for VSWA Signed-off-by: qixiang-99 <203170375+qixiang-99@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/resource_manager.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index c5a9f264b01..df577bc7e89 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -536,16 +536,8 @@ def get_num_kv_blocks(self, num_tokens: int) -> int: return (num_tokens + self.tokens_per_block - 1) // self.tokens_per_block def get_num_available_tokens(self, max_num_draft_tokens: int = 0) -> int: - if self.max_attention_window_vec and len( - self.max_attention_window_vec) > 1: - # VSWA case, the available tokens should the the minimum of the available tokens for each window size - min_free_blocks = min(self.impl.get_kv_cache_stats(). - num_free_blocks_per_window_size.values()) - res = min_free_blocks * self.tokens_per_block - self.num_extra_kv_tokens - max_num_draft_tokens - else: - res = (self.get_num_free_blocks() * self.tokens_per_block - - self.num_extra_kv_tokens - max_num_draft_tokens) - return res + return (self.get_num_free_blocks() * self.tokens_per_block - + self.num_extra_kv_tokens - max_num_draft_tokens) def get_buffers(self, layer_idx: int) -> Optional[torch.Tensor]: layer_offset = self.layer_offsets[layer_idx] @@ -732,6 +724,8 @@ def calculate_max_num_blocks_from_cpp( # VSWA on Torch backend has not supported the cross attention. is_cross_attention = False + # check model config + assert model_config.layer_types is not None, "layer_types have to be set correctly for VSWA" # Construct WorldConfig from self.mapping world_config_cpp = WorldConfig(