
Commit 76c92ed

[None][fix] Remove overwrite of kv_cache_config.max_tokens
The existing code overwrites kv_cache_config.max_tokens, which prevents a user-provided kv_cache_config.max_tokens from being passed to the kv_cache_manager. This is not correct; this commit fixes it. Additionally, we now have `max_gpu_total_bytes` from NVIDIA#5933 to estimate GPU memory. The next step is to remove the `max_tokens` concept, as it is confusing under a VSWA scheme and overlaps with `max_gpu_total_bytes` under a full-attention scheme. Signed-off-by: eopXD <[email protected]>
1 parent 27677a3 commit 76c92ed
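
To illustrate the intent, here is a minimal sketch (not part of this commit, assuming the tensorrt_llm LLM API and its KvCacheConfig field names):

# The user caps the KV cache at an explicit token budget. Before this fix, the
# estimation path overwrote kv_cache_config.max_tokens with its own estimate,
# so the value below never reached the kv_cache_manager; after the fix it is
# passed through unchanged.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig

kv_cache_config = KvCacheConfig(max_tokens=16384)
llm = LLM(model="/path/to/model", kv_cache_config=kv_cache_config)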

File tree: 1 file changed (+2, -40 lines)


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 2 additions & 40 deletions
@@ -188,50 +188,12 @@ def _create_dummy_context_requests(
         requests = requests * self._mapping.tp_size
         return requests
 
-    def _get_token_num_for_estimation(self) -> int:
-        """Compute KV cache capacity required for estimate_max_kv_cache_tokens to succeed."""
-        if 'cp_type' in self._mapping.cp_config:
-            raise ValueError(
-                "KV cache size estimation not supported with context parallelism."
-            )
-        # estimate_max_kv_cache_tokens submits self._dummy_reqs
-        num_cache_blocks = 0
-        num_extra_tokens_per_seq = 1  # account for generated tokens
-        pytorch_backend_config = self._pytorch_backend_config
-        spec_cfg = self._speculative_config
-        if not pytorch_backend_config.disable_overlap_scheduler:
-            num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
-            if spec_cfg is not None:
-                num_extra_tokens_per_seq += spec_cfg.max_draft_len
-
-        if spec_cfg is not None:
-            num_extra_tokens_per_seq += spec_cfg.max_draft_len
-            num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)
-
-        if self._dummy_reqs is None:
-            self._dummy_reqs = self._create_dummy_context_requests(
-                max(1, self._net_max_seq_len - 1))
-        for req in self._dummy_reqs:
-            num_req_tokens = len(req.input_token_ids) + num_extra_tokens_per_seq
-            # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
-            num_cache_blocks += (num_req_tokens + self._tokens_per_block -
-                                 1) // self._tokens_per_block
-        # Multiply by beam width, to prevent rescaling of the max_seq_len caused by the influence of beam width during the preparation for kv_cache_estimation
-        return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
-            0].sampling_config.beam_width
-
     def try_prepare_estimation(self) -> bool:
         """Prepare for possible KV cache capacity estimation.
 
-        This updates `kv_cache_config` and returns a boolean indicating whether KV cache
-        estimation is to be performend.
+        Returns a boolean indicating whether KV cache estimation is to be performed.
         """
-        estimating_kv_cache = False
-        if 'cp_type' not in self._mapping.cp_config:
-            estimating_kv_cache = True
-            self._kv_cache_config.max_tokens = self._get_token_num_for_estimation(
-            )
-        return estimating_kv_cache
+        return 'cp_type' not in self._mapping.cp_config
 
     def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
         """Perform KV cache capacity estimation.
