Commit 3a96d75

[https://nvbugs/5527956][fix] AutoDeploy: fix IMA due to outdated metadata (#8002)
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 2e5850c commit 3a96d75

File tree: 6 files changed, +51 −39 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 23 additions & 10 deletions
@@ -19,6 +19,7 @@
 from torch.fx import Node
 
 from ...._utils import nvtx_range
+from ..utils.logger import ad_logger
 
 DynamicShape = Dict[int, Dim]  # indicating the dynamic shape in tensor dimension
 DynamicShapeCallback = Callable[[], DynamicShape]
@@ -122,22 +123,28 @@ def __init__(
         # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
         max_seq_len_adjusted = self.max_seq_len + 1
 
-        if max_num_tokens is None or max_num_tokens < 1:
-            self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
-        else:
-            self.max_num_tokens = max_num_tokens
+        # if the provided max_num_tokens is less than the max_batch_size * max_seq_len_adjusted,
+        # we use the provided max_num_tokens. If max_num_tokens provided is more, we still use
+        # max_batch_size * max_seq_len_adjusted since the extra tokens cannot be used.
+        self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted
+        if max_num_tokens is not None and max_num_tokens > 0:
+            self.max_num_tokens = min(self.max_num_tokens, max_num_tokens)
 
-        # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
-        # we use the provided max_num_tokens to calculate the number of pages
-        total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
         # Num pages can not be less than max_batch_size.
         self._num_pages = max(
             self.max_batch_size,
-            (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
+            (self.max_num_tokens) // self.page_size  # floored number of pages
+            + (self.max_num_tokens % self.page_size > 0) * self.max_batch_size,  # +1 per sequence
        )
         # sanity check
         assert self.num_pages >= self.max_batch_size, "num_pages can't be less than max_batch_size"
 
+        # log parameters
+        ad_logger.info(
+            f"[SequenceInfo:] {self.max_seq_len=}, {self.max_batch_size=}, {self.page_size=}, "
+            f"{self.max_num_tokens=} (inferred), {max_num_tokens=} (provided), {self.num_pages=}"
+        )
+
         # indicator if extra args are activated that are needed for cached attention backends
         self._is_cached_attn = False
 
@@ -572,6 +579,12 @@ def _store_arg(
         # pin the memory on the host
         tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=True)
 
+        # check for available space
+        assert tnsr_device.numel() >= tnsr_host.numel(), (
+            f"device tensor {name} is too small, available: {tnsr_device.numel()}, "
+            f"required: {tnsr_host.numel()}"
+        )
+
         # reset/copy to the device in a non-blocking fashion
         if reset:
             tnsr_device.zero_()
@@ -632,8 +645,8 @@ def nest_sequences(
         cache_loc, pages_per_seq = self._get_cache_locations_and_pages_per_sequence(
             page_assignments
         )
-        self._store_arg("cache_loc", cache_loc)
-        self._store_arg("pages_per_seq", pages_per_seq)
+        self._store_arg("cache_loc", cache_loc, reset=True)
+        self._store_arg("pages_per_seq", pages_per_seq, reset=True)
 
         ### UPDATE MAIN INPUTS #####################################################################
         # set new input_ids and make sure to flatten it
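For reference, a minimal standalone sketch of the updated page-count heuristic in SequenceInfo, using made-up numbers (max_batch_size, max_seq_len, page_size, and the provided token cap below are hypothetical, not taken from any real config):

# Sketch of the new max_num_tokens / num_pages arithmetic; values are illustrative only.
max_batch_size = 4
max_seq_len = 128
page_size = 64
provided_max_num_tokens = 300  # user-provided cap; may also be None or <= 0

max_seq_len_adjusted = max_seq_len + 1  # see TensorRT-LLM issue #4504
max_num_tokens = max_batch_size * max_seq_len_adjusted
if provided_max_num_tokens is not None and provided_max_num_tokens > 0:
    # never exceed what the batch can actually consume
    max_num_tokens = min(max_num_tokens, provided_max_num_tokens)

num_pages = max(
    max_batch_size,
    max_num_tokens // page_size  # floored number of full pages
    + (max_num_tokens % page_size > 0) * max_batch_size,  # +1 partial page per sequence
)
print(max_num_tokens, num_pages)  # 300, 8

Compared to the previous global round-up, the remainder term now reserves one extra page per sequence, so every sequence in the batch can hold its own partially filled final page.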

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 3 deletions
@@ -89,9 +89,6 @@ def build_from_config(cls, ad_config: AutoDeployConfig):
         attn_page_size = ad_config.attn_page_size
         max_num_tokens = ad_config.max_num_tokens
         max_beam_width = ad_config.max_beam_width
-        ad_logger.info(
-            f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}, {max_beam_width=}"
-        )
 
         # update device to contain the current default device if it's in cuda
         device = torch.device(ad_config.device)

tensorrt_llm/_torch/auto_deploy/shim/interface.py

Lines changed: 3 additions & 1 deletion
@@ -73,7 +73,9 @@ def resize_cache(self, new_num_pages: int):
         self.info.num_pages = new_num_pages
         for name, cache in self._caches.items():
             # We assume cache is a tensor of shape (max_batch_size, page_size, n_heads, head_dim)
-            if "cache" in name:
+            # TODO: cache resize should ideally be handled via a callback to the AttentionDescriptor
+            # to avoid hard-coding any assumptions about the cache shape or its "pagedness"
+            if "k_cache" in name or "v_cache" in name:
                 current_shape = cache.shape
                 new_shape = (new_num_pages, *current_shape[1:])
                 cache.resize_(new_shape)
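The tighter name filter means only k/v caches get their leading page dimension resized; other buffers that merely contain the substring "cache" are left untouched. A minimal sketch of that behavior in isolation (cache names and shapes below are hypothetical):

import torch

# Hypothetical paged caches keyed by name; the leading dim is the number of pages.
caches = {
    "k_cache_0": torch.zeros(8, 64, 8, 128),
    "v_cache_0": torch.zeros(8, 64, 8, 128),
    "rope_cache": torch.zeros(4096, 128),  # must NOT be resized by the page logic
}

new_num_pages = 16
for name, cache in caches.items():
    if "k_cache" in name or "v_cache" in name:
        # keep the per-page layout, only change the leading page dimension
        cache.resize_((new_num_pages, *cache.shape[1:]))

print(caches["k_cache_0"].shape)   # torch.Size([16, 64, 8, 128])
print(caches["rope_cache"].shape)  # unchanged: torch.Size([4096, 128])

Note that resize_ does not initialize newly allocated elements; the TODO in the diff points toward eventually delegating this to the AttentionDescriptor instead of matching on names.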

tensorrt_llm/_torch/auto_deploy/transform/library/cleanup_input_constraints.py

Lines changed: 5 additions & 1 deletion
@@ -43,7 +43,11 @@ def _apply(
                 raise TypeError(f"Unexpected type {type(s)} in symbolic shape.")
 
         # update the max constraint for each vr
-        max_total = math.prod(vr.upper for vr in vrs)
+        # NOTE: this is more a heuristic anyway than a strict constraint. We just want to make sure
+        # that this never gets triggered. So we multiply by 1000 to be safe. Not that it has to
+        # be a symint (not an int) --> so that's why we use a heuristic based on the existing
+        # symint values instead of just using e.g. max_num_tokens...
+        max_total = math.prod(vr.upper for vr in vrs) * 1000
         for vr in vrs:
             object.__setattr__(vr, "upper", max_total)
 
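A toy illustration of the relaxed upper bound, assuming vrs is a list of value-range objects with an upper field; the real ranges come from torch's symbolic-shape machinery and are frozen, which is why object.__setattr__ is used. The FakeValueRange class below is only a stand-in for illustration:

import math
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeValueRange:
    """Stand-in for the symbolic-shape value range used in the real pass."""
    lower: int
    upper: int

# e.g. ranges for the (batch, seq_len) symbols of the flattened input
vrs = [FakeValueRange(1, 8), FakeValueRange(1, 1024)]

# relaxed upper bound: product of the existing uppers, padded by 1000x so the
# constraint acts as a safety net rather than a tight limit
max_total = math.prod(vr.upper for vr in vrs) * 1000

for vr in vrs:
    # the real ranges are frozen, hence object.__setattr__
    object.__setattr__(vr, "upper", max_total)

print(vrs)  # both uppers are now 8 * 1024 * 1000 = 8192000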

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

Lines changed: 20 additions & 23 deletions
@@ -12,7 +12,6 @@
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
 from ...transformations._graph import add_graph_input
-from ...utils.logger import ad_logger
 from ...utils.node_utils import get_all_input_output_nodes, is_op
 from ..interface import (
     BaseTransform,
@@ -280,34 +279,32 @@ def _get_mem_info_in_mb():
                 skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
             )
 
-        try:
-            # Let's run a forward pass to get the memory usage
-            cm.info.set_max_num_tokens_sample()
-            free_mem_pre, _ = _get_mem_info_in_mb()
-            self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")
+        # TODO: the manual PyTorch workflow respects max_num_tokens if set and does _NOT_ resize
+        # the cache in this case. Should we do the same here?
 
-            self._run_forward(gm, cm)
+        # Let's run a forward pass to get the memory usage
+        cm.info.set_max_num_tokens_sample()
+        free_mem_pre, _ = _get_mem_info_in_mb()
+        self._log_info(f"Free memory before forward pass (MB): {free_mem_pre}")
 
-            free_mem_post, _ = _get_mem_info_in_mb()
-            self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
+        self._run_forward(gm, cm)
 
-            memory_for_forward_pass = free_mem_pre - free_mem_post
-            self._log_info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
+        free_mem_post, _ = _get_mem_info_in_mb()
+        self._log_info(f"Free memory after forward pass (MB): {free_mem_post}")
 
-            new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
-            new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
+        memory_for_forward_pass = free_mem_pre - free_mem_post
+        self._log_info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
 
-            # Need to sync all the GPUs
-            gathered_num_pages = [None] * get_world_size()
-            all_gather_object(gathered_num_pages, new_num_pages)
-            new_num_pages = min(gathered_num_pages)
-            self._log_info(f"After all_gather - new_num_pages: {new_num_pages}")
+        new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
+        new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
 
-            cm.resize_cache(new_num_pages)
-        except Exception as e:
-            ad_logger.warning(
-                f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
-            )
+        # Need to sync all the GPUs
+        gathered_num_pages = [None] * get_world_size()
+        all_gather_object(gathered_num_pages, new_num_pages)
+        new_num_pages = min(gathered_num_pages)
+        self._log_info(f"After all_gather - new_num_pages: {new_num_pages}")
+
+        cm.resize_cache(new_num_pages)
 
         # Free memory
         torch.cuda.empty_cache()
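This change drops the broad try/except around the resize, so failures now surface instead of being silently skipped. Stripped of the graph and distributed plumbing, the page-count arithmetic looks roughly like the sketch below; all numbers are hypothetical, and in the real pass new_num_pages is additionally reduced to the minimum across ranks via all_gather_object:

# Illustrative numbers only; quantities are in MB unless noted otherwise.
free_mem_post = 40_000            # free GPU memory after the trial forward pass (MB)
free_mem_ratio = 0.9              # fraction of remaining free memory handed to the KV cache
current_num_pages = 256
current_cache_size = 2 * 1024**3  # bytes currently held by the bootstrap cache

bytes_per_page = current_cache_size // current_num_pages
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // bytes_per_page)

print(bytes_per_page, new_num_pages)  # 8388608 bytes/page -> 4756 pages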

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -327,7 +327,6 @@ test_e2e.py::test_trtllm_multimodal_benchmark_serving SKIP (https://nvbugs/55233
 accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[MMLU-gen_tp=2-ctx_pp=4] SKIP (https://nvbugs/5434320)
 examples/test_llama.py::test_llm_llama_1gpu_fp8_kv_cache[llama-v2-7b-hf-bfloat16] SKIP (https://nvbugs/5527940)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5528070)
-accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype SKIP (https://nvbugs/5527956)
 test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True] SKIP (https://nvbugs/5509024)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[throughput] SKIP (https://nvbugs/5481198)
 accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale_chunked_prefill[latency] SKIP (https://nvbugs/5481198)
