Skip to content

Commit 42b9f10

Browse files
HuiGao-NVdominicshanshan
authored andcommitted
[https://nvbugs/5474169][fix]Adjust max seq len for kvcache for memory estimation (NVIDIA#7391)
Signed-off-by: Hui Gao <[email protected]>
1 parent 4d33f07 commit 42b9f10

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def __init__(self, *, executor_config: ExecutorConfig,
5252
self._draft_model_engine = draft_model_engine
5353
self._mapping = mapping
5454
self._max_kv_tokens_in = self._executor_config.kv_cache_config.max_tokens
55-
self._dummy_reqs = self._create_dummy_context_requests(net_max_seq_len -
56-
1)
5755
self._kv_connector_manager = kv_connector_manager
56+
self._dummy_reqs = None
57+
self._max_seq_len = net_max_seq_len
5858

5959
@staticmethod
6060
def _get_cache_size_per_token(model_config: ModelConfig,
@@ -177,6 +177,10 @@ def _get_token_num_for_estimation(self) -> int:
177177
if spec_cfg is not None:
178178
num_extra_tokens_per_seq += spec_cfg.max_draft_len
179179
num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)
180+
181+
if self._dummy_reqs is None:
182+
self._dummy_reqs = self._create_dummy_context_requests(
183+
max(1, self._max_seq_len - 1))
180184
for req in self._dummy_reqs:
181185
num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
182186
# Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
@@ -466,6 +470,10 @@ def _create_kv_cache_manager(
466470
if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
467471
executor_config.max_seq_len = kv_cache_manager.max_seq_len
468472

473+
# When SWA is enabled, max_seq_len is updated inside kv_cache_manager.
474+
if kv_cache_manager is not None:
475+
self._max_seq_len = kv_cache_manager.max_seq_len
476+
469477
return kv_cache_manager
470478

471479
def build_managers(self,

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def calculate_max_num_blocks(self,
582582
if kv_cache_config.free_gpu_memory_fraction is not None:
583583
max_tokens = min(kv_cache_config.max_tokens, max_tokens)
584584
logger.warning(
585-
f'Both free_gpu_memory_fraction and max_tokens are set (to {free_mem_fraction} and {kv_cache_config.max_tokens}, respectively). The smaller value will be used.'
585+
f'Both free_gpu_memory_fraction and max_tokens are set (to {free_mem_fraction} and {max_tokens} with free memory {free_mem / (1 << 32)} of total memory {total_mem / (1<<32)}, respectively). The smaller value will be used.'
586586
)
587587
else:
588588
max_tokens = kv_cache_config.max_tokens

0 commit comments

Comments
 (0)