
Commit 5036ed3

Add memory estimation for attention metadata to solve an OOM issue when CUDA graphs are enabled and the max attention window size is large
Signed-off-by: qixiang-99 <[email protected]>
1 parent d2b1162 commit 5036ed3

1 file changed: +58 −0 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 58 additions & 0 deletions
@@ -96,6 +96,57 @@ def _get_free_gpu_memory_fraction(self) -> float:
             fraction = 0.9
         return fraction
 
+    def _get_num_graphs(self) -> int:
+        return len(self._model_engine._cuda_graph_batch_sizes)
+
+    def _get_extra_memory_for_attention_metadata(
+            self, kv_cache_manager: KVCacheManager) -> int:
+        """
+        `kv_cache_block_offsets` (see `TrtllmAttentionMetadata`) stores the KV-cache
+        block offsets for every request. Its layout is
+        [num_pools, max_num_sequences, 2, max_blocks_per_seq].
+
+        • Estimation phase: we run a dry-run with requests of length
+          `max_num_tokens` (e.g. 8192). Consequently, `max_blocks_per_seq` is small
+          and the tensor's footprint appears modest.
+
+        • Real inference: `max_blocks_per_seq` can increase to
+          `max_seq_len / tokens_per_block`. For long-context models this is
+          orders of magnitude larger, so the tensor consumes significantly more
+          GPU memory.
+
+        • CUDA graphs: when graph capture is enabled the full
+          `kv_cache_block_offsets` tensor must be pre-allocated,
+          making the extra memory grow linearly with the number of graphs.
+        """
+        # get the max_blocks_per_seq in estimation phase
+        est_phase_max_blocks_per_seq = kv_cache_manager.max_blocks_per_seq
+
+        max_batch_size = self._executor_config.max_batch_size
+        if max_batch_size is None:
+            logger.warning(f"max_batch_size is not set, using 1")
+            max_batch_size = 1
+        max_window_size = max(
+            self._executor_config.kv_cache_config.max_attention_window)
+        tokens_per_block = self._executor_config.tokens_per_block
+        num_pools = kv_cache_manager.num_pools
+
+        # calculate the max_blocks_per_seq in real inference phase
+        real_phase_max_blocks_per_seq = int(
+            (max_window_size + tokens_per_block - 1) // tokens_per_block)
+
+        # calculate the extra memory from kv_cache_block_offsets for each graph
+        extra_bytes_per_graph = (
+            real_phase_max_blocks_per_seq -
+            est_phase_max_blocks_per_seq) * num_pools * max_batch_size * 2 * 4
+        # get number of graphs
+        num_graphs = self._get_num_graphs()
+        total_extra_bytes = int(extra_bytes_per_graph * num_graphs)
+        logger.info(
+            f"extra bytes per graph from kv_cache_block_offsets: {extra_bytes_per_graph / (GB):.2f} GiB, total extra bytes: {total_extra_bytes / (GB):.2f} GiB"
+        )
+        return total_extra_bytes
+
     def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
                         alloc_kv_tokens: int) -> int:
         """
@@ -256,6 +307,13 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
         logger.info(
             f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB"
         )
+
+        # get extra memory from attention metadata
+        extra_memory_for_attention_metadata = self._get_extra_memory_for_attention_metadata(
+            py_executor.resource_manager.resource_managers.get(
+                ResourceManagerType.KV_CACHE_MANAGER))
+        peak_memory += extra_memory_for_attention_metadata
+
         kv_stats = py_executor.resource_manager.resource_managers.get(
             ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
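For a sense of the scale this change accounts for, the sketch below replays the same arithmetic as `_get_extra_memory_for_attention_metadata` outside the executor. The helper name and every input value are illustrative, not taken from the repository, and the 4-byte element size is an assumption consistent with the `* 2 * 4` factor in the diff.

# Standalone re-derivation of the commit's estimate with illustrative values.
# Assumes int32 (4-byte) block offsets, matching the "* 2 * 4" factor in the diff.
def estimate_extra_bytes(est_blocks_per_seq: int, max_window_size: int,
                         tokens_per_block: int, num_pools: int,
                         max_batch_size: int, num_graphs: int) -> int:
    # blocks a sequence can actually reach at inference time (ceiling division)
    real_blocks_per_seq = (max_window_size + tokens_per_block - 1) // tokens_per_block
    # per-graph growth of the [num_pools, max_num_sequences, 2, max_blocks_per_seq] tensor
    per_graph_bytes = (real_blocks_per_seq -
                       est_blocks_per_seq) * num_pools * max_batch_size * 2 * 4
    return per_graph_bytes * num_graphs

GB = 1 << 30
# hypothetical setup: 8192-token dry run, 1M-token attention window, 16 captured graphs
extra = estimate_extra_bytes(est_blocks_per_seq=8192 // 32,
                             max_window_size=1_048_576,
                             tokens_per_block=32,
                             num_pools=1,
                             max_batch_size=256,
                             num_graphs=16)
print(f"extra memory unseen during estimation: {extra / GB:.2f} GiB")

With these assumed values the dry-run estimate misses roughly 1 GiB of `kv_cache_block_offsets` storage; that gap is what now gets added onto `peak_memory` before the KV-cache capacity is computed.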
