@@ -96,6 +96,57 @@ def _get_free_gpu_memory_fraction(self) -> float:
         fraction = 0.9
         return fraction

+    def _get_num_graphs(self) -> int:
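+        # Assumes one CUDA graph is captured per entry in
+        # `_cuda_graph_batch_sizes`.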
+        return len(self._model_engine._cuda_graph_batch_sizes)
+
+    def _get_extra_memory_for_attention_metadata(
+            self, kv_cache_manager: KVCacheManager) -> int:
+        """
+        `kv_cache_block_offsets` (see `TrtllmAttentionMetadata`) stores the
+        KV-cache block offsets for every request. Its layout is
+        [num_pools, max_num_sequences, 2, max_blocks_per_seq].
+
+        • Estimation phase: we run a dry run with requests of length
+          `max_num_tokens` (e.g. 8192). Consequently, `max_blocks_per_seq` is
+          small and the tensor's footprint appears modest.
+
+        • Real inference: `max_blocks_per_seq` can grow to
+          `max_seq_len / tokens_per_block`. For long-context models this is
+          orders of magnitude larger, so the tensor consumes significantly
+          more GPU memory.
+
+        • CUDA graphs: when graph capture is enabled, the full
+          `kv_cache_block_offsets` tensor must be pre-allocated for each
+          graph, so the extra memory grows linearly with the number of
+          graphs.
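+
+          For illustration (assumed numbers, not from a real run): with a
+          131072-token attention window and `tokens_per_block` = 32, real
+          inference needs 131072 / 32 = 4096 offsets per sequence, while an
+          8192-token dry run allocates only 256, so each graph
+          under-accounts by (4096 - 256) * num_pools * max_batch_size * 2 * 4
+          bytes.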
121+ """
+        # max_blocks_per_seq observed during the estimation phase
+        est_phase_max_blocks_per_seq = kv_cache_manager.max_blocks_per_seq
+
+        max_batch_size = self._executor_config.max_batch_size
+        if max_batch_size is None:
+            logger.warning("max_batch_size is not set, using 1")
+            max_batch_size = 1
+        max_window_size = max(
+            self._executor_config.kv_cache_config.max_attention_window)
+        tokens_per_block = self._executor_config.tokens_per_block
+        num_pools = kv_cache_manager.num_pools
+
+        # max_blocks_per_seq needed in the real inference phase (ceil division)
+        real_phase_max_blocks_per_seq = int(
+            (max_window_size + tokens_per_block - 1) // tokens_per_block)
+
+        # calculate the extra memory from kv_cache_block_offsets for each graph
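+        # (the factor 2 matches the third dimension of the layout documented
+        # above; 4 bytes is assumed to be the width of each int32 offset)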
+        extra_bytes_per_graph = (
+            real_phase_max_blocks_per_seq -
+            est_phase_max_blocks_per_seq) * num_pools * max_batch_size * 2 * 4
+        # number of CUDA graphs to be captured
+        num_graphs = self._get_num_graphs()
+        total_extra_bytes = int(extra_bytes_per_graph * num_graphs)
+        logger.info(
+            f"Extra memory per graph from kv_cache_block_offsets: "
+            f"{extra_bytes_per_graph / (GB):.2f} GiB, "
+            f"total extra memory: {total_extra_bytes / (GB):.2f} GiB")
+        return total_extra_bytes
+
     def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
                         alloc_kv_tokens: int) -> int:
         """
@@ -256,6 +307,13 @@ def configure_kv_cache_capacity(self, py_executor: PyExecutor) -> None:
         logger.info(
             f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB"
         )
+
+        # get extra memory from attention metadata
+        extra_memory_for_attention_metadata = self._get_extra_memory_for_attention_metadata(
+            py_executor.resource_manager.resource_managers.get(
+                ResourceManagerType.KV_CACHE_MANAGER))
+        peak_memory += extra_memory_for_attention_metadata
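+        # Inflating peak_memory presumably shrinks the KV-cache capacity that
+        # _cal_max_memory derives from it, leaving headroom for the larger
+        # block-offset tensors at runtime.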
+
         kv_stats = py_executor.resource_manager.resource_managers.get(
             ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
