tensorrt_llm/_torch/pyexecutor/_util.py (4 additions & 0 deletions)

@@ -472,6 +472,7 @@ def _create_kv_cache_manager(
                 kv_connector_manager=self._kv_connector_manager
                 if not estimating_kv_cache else None,
                 sparse_attn_config=sparse_attn_config,
+                is_estimating_kv_cache=estimating_kv_cache,
             )
         elif is_nemotron_hybrid(config):
             if self._max_beam_width > 1:
@@ -518,6 +519,7 @@ def _create_kv_cache_manager(
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
+                is_estimating_kv_cache=estimating_kv_cache,
             )
         elif is_qwen3_next(config):
             if self._max_beam_width > 1:
@@ -568,6 +570,7 @@ def _create_kv_cache_manager(
                 mapping=mapping,
                 dtype=kv_cache_dtype,
                 spec_config=spec_config,
+                is_estimating_kv_cache=estimating_kv_cache,
             )
         else:
             # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCacheManager
@@ -595,6 +598,7 @@ def _create_kv_cache_manager(
                 kv_connector_manager=self._kv_connector_manager
                 if not estimating_kv_cache else None,
                 sparse_attn_config=sparse_attn_config,
+                is_estimating_kv_cache=estimating_kv_cache,
             )
         # KVCacheManager (non-draft) modifies the max_seq_len field; propagate the updated value back to self
         if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
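All four call sites thread the new flag the same way. Below is a minimal self-contained sketch of the pattern, assuming a stub class (KVCacheManagerStub is illustrative, not a class from this PR): each constructor gains the keyword with a False default, so existing callers that never pass it keep the old behavior.

class KVCacheManagerStub:
    """Illustrative stand-in for the managers extended in this PR."""

    def __init__(self, tokens_per_block: int,
                 is_estimating_kv_cache: bool = False) -> None:
        self.tokens_per_block = tokens_per_block
        # True only during the KV-cache estimation dry run.
        self.is_estimating_kv_cache = is_estimating_kv_cache


estimating_kv_cache = True  # the dry-run phase sets this
manager = KVCacheManagerStub(tokens_per_block=32,
                             is_estimating_kv_cache=estimating_kv_cache)
print(manager.is_estimating_kv_cache)  # True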
tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (2 additions & 0 deletions)

@@ -195,6 +195,7 @@ def __init__(
         mapping: Mapping,
         dtype: DataType = DataType.HALF,
         spec_config: Optional["DecodingBaseConfig"] = None,
+        is_estimating_kv_cache: bool = False,
     ) -> None:

         # mamba hybrid cache requires block reuse to be disabled in KV cache config
@@ -231,6 +232,7 @@ def __init__(
             dtype=dtype,
             spec_config=spec_config,
             layer_mask=layer_mask,
+            is_estimating_kv_cache=is_estimating_kv_cache,
         )

     def prepare_resources(self, scheduled_batch: ScheduledRequests):
tensorrt_llm/_torch/pyexecutor/resource_manager.py (58 additions & 31 deletions)

@@ -175,6 +175,7 @@ def __init__(
         enable_indexer_k_cache: bool = False,
         indexer_k_cache_quant_block_size: int = 128,
         indexer_k_cache_index_head_dim: int = 0,
+        is_estimating_kv_cache: bool = False,
         **kwargs,
     ) -> None:
         self.mapping = mapping
@@ -269,37 +270,61 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
         # Determine if this is VSWA (Variable Sliding Window Attention)
         self.is_vswa = len(set(self.max_attention_window_vec)) > 1

-        # Calculate blocks per window using appropriate method
-        if self.is_vswa:
-            # VSWA case: use C++ implementation for variable window sizes
-            # model config check
-            if model_config is None:
-                raise ValueError(
-                    "model_config is required for VSWA (Variable Sliding Window Attention)"
-                )
-            # kv cache config check
-            assert isinstance(
-                kv_cache_config, KvCacheConfig
-            ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
-            blocks_per_window = self.calculate_max_num_blocks_from_cpp(
-                kv_cache_config=kv_cache_config,
-                model_config=model_config,
-                extra_cost_memory=0,
-            )
-        else:
-            # Standard case: use original Python implementation
-            self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
-                kv_cache_config=kv_cache_config,
-                head_dim=head_dim,
-                tokens_per_block=tokens_per_block,
-                mapping=mapping,
-                dtype=dtype,
-                kv_factor=self.kv_factor,
+        # Calculate KV cache blocks for each window size.
+        # FIXME: flashinfer.py accesses kv_cache_manager.blocks_in_primary_pool.
+        # This dependency should be adjusted, as it only covers the single-window
+        # case and not the VSWA scheme.
+        if is_estimating_kv_cache:
+            # For an estimation dry run, max_tokens has already been computed
+            # in _util.py::try_prepare_estimation. Since this is a dry run, it
+            # is enough to assign the same max_tokens capacity to every window
+            # size, as if all of them were full attention.
+            self.blocks_in_primary_pool = int(kv_cache_config.max_tokens //
+                                              tokens_per_block)
+
+            host_cache_size = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0
+            max_tokens_secondary = host_cache_size // self.get_cache_bytes_per_token(
             )
+            self.blocks_in_secondary_pool = int(max_tokens_secondary //
+                                                tokens_per_block)
+
             blocks_per_window = {
-                self.max_attention_window_vec[0]:
+                window_size:
                 (self.blocks_in_primary_pool, self.blocks_in_secondary_pool)
+                for window_size in set(self.max_attention_window_vec)
             }
+            logger.info(
+                f"[kv cache manager] Primary/secondary blocks for window sizes set to {blocks_per_window} for estimation dry run"
+            )
+        else:
+            if self.is_vswa:
+                # VSWA case: use the C++ implementation for variable window sizes
+                if model_config is None:
+                    raise ValueError(
+                        "model_config is required for VSWA (Variable Sliding Window Attention)"
+                    )
+                assert isinstance(
+                    kv_cache_config, KvCacheConfig
+                ), "calculate_max_num_blocks_from_cpp only accepts KvCacheConfig"
+                blocks_per_window = self.calculate_max_num_blocks_from_cpp(
+                    kv_cache_config=kv_cache_config,
+                    model_config=model_config,
+                    extra_cost_memory=0,
+                )
+            else:
+                # Standard case: use the original Python implementation
+                self.blocks_in_primary_pool, self.blocks_in_secondary_pool = self.calculate_max_num_blocks(
+                    kv_cache_config=kv_cache_config,
+                    head_dim=head_dim,
+                    tokens_per_block=tokens_per_block,
+                    mapping=mapping,
+                    dtype=dtype,
+                    kv_factor=self.kv_factor,
+                )
+                blocks_per_window = {
+                    self.max_attention_window_vec[0]:
+                    (self.blocks_in_primary_pool, self.blocks_in_secondary_pool)
+                }

         # Validate and adjust attention windows against their upper bounds if needed
         blocks_per_window, self.max_seq_len, self.max_attention_window_vec = self._validate_and_adjust_attention_windows(
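To make the dry-run arithmetic in the added branch concrete, here is a standalone recomputation with illustrative numbers; none of these values come from the PR, and cache_bytes_per_token stands in for self.get_cache_bytes_per_token().

# Sketch of the estimation dry-run block math (all values are assumptions).
max_tokens = 16384             # as measured by _util.py::try_prepare_estimation
tokens_per_block = 32
cache_bytes_per_token = 98304  # e.g. 24 layers * 8 KV heads * 128 head_dim * 2 (K+V) * 2 bytes
host_cache_size = 1 << 30      # 1 GiB host (secondary) cache

blocks_in_primary_pool = int(max_tokens // tokens_per_block)              # 512
max_tokens_secondary = host_cache_size // cache_bytes_per_token           # 10922
blocks_in_secondary_pool = int(max_tokens_secondary // tokens_per_block)  # 341

# Every window size gets the same full-attention capacity during the dry run.
max_attention_window_vec = [4096, 8192, 4096]
blocks_per_window = {
    window_size: (blocks_in_primary_pool, blocks_in_secondary_pool)
    for window_size in set(max_attention_window_vec)
}
print(blocks_per_window)  # {8192: (512, 341), 4096: (512, 341)} (order may vary)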
@@ -736,11 +761,13 @@ def calculate_max_num_blocks(self,
         max_tokens = mpi_comm().allreduce(max_tokens, op=MPI.MIN)

         # get number of blocks
-        blocks_in_primary_pool = math.ceil(max_tokens / tokens_per_block)
+        blocks_in_primary_pool = int(max_tokens // tokens_per_block)

         host_cache_size = kv_cache_config.host_cache_size if kv_cache_config.host_cache_size else 0
-        max_tokens_secondary = host_cache_size / cache_size_bytes_per_token
-        blocks_in_secondary_pool = max(
-            0, int(max_tokens_secondary / tokens_per_block))
+        max_tokens_secondary = host_cache_size // self.get_cache_bytes_per_token(
+        )
+        blocks_in_secondary_pool = int(max_tokens_secondary // tokens_per_block)

         return blocks_in_primary_pool, blocks_in_secondary_pool

     def get_max_atten_window_upper_bound(self, blocks_in_primary_pool,
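One behavioral detail of this last hunk: the primary pool now floors instead of rounding up. Presumably this keeps the pool within the measured token budget, since math.ceil could mint a final block whose tail tokens were never accounted for; the PR does not state the rationale, so this is an inference. A quick comparison with made-up numbers:

import math

max_tokens, tokens_per_block = 1000, 32

# Old: rounds up to 32 blocks = 1024 token slots, 24 more than budgeted.
print(math.ceil(max_tokens / tokens_per_block))  # 32

# New: floors to 31 blocks = 992 token slots, always within budget.
print(int(max_tokens // tokens_per_block))       # 31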