
Commit 97674c3

[TRTLLM-8690][feat] add more tensors to share buffers (#8691)
Signed-off-by: Hui Gao <[email protected]>
1 parent ed297d7 commit 97674c3

File tree: 7 files changed, +152 -77 lines changed


tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 18 additions & 6 deletions

@@ -124,14 +124,22 @@ def positions(self) -> torch.Tensor:
 
     def __post_init__(self) -> None:
         super().__post_init__()
+        self._post_init_with_buffers(self.cuda_graph_buffers)
+
+    def _post_init_with_buffers(self, buffers) -> None:
+        capture_graph = torch.cuda.is_current_stream_capturing()
 
         if self.workspace_buffer is None:
             # Note: even though flashinfer only recommends 128 MB, we have to push it
             # a bit higher to cover all possible CUDA graph cases. If it's too small,
             # warmup will crash.
-            self.workspace_buffer = torch.empty(320 * 1024 * 1024,
-                                                dtype=torch.uint8,
-                                                device="cuda")
+            self.workspace_buffer = self.get_empty(
+                buffers,
+                (320 * 1024 * 1024, ),
+                dtype=torch.uint8,
+                cache_name="workspace_buffer",
+                capture_graph=capture_graph,
+            )
 
         self.paged_kv_indptr_decode = torch.empty((self.max_num_requests + 1, ),
                                                   device='cuda',

@@ -163,9 +171,13 @@ def __post_init__(self) -> None:
 
         if self.kv_cache_manager is not None:
             max_num_pages = self.kv_cache_manager.blocks_in_primary_pool
-            self._paged_kv_indices = torch.empty((max_num_pages, ),
-                                                 device='cuda',
-                                                 dtype=torch.int)
+            self._paged_kv_indices = self.get_empty(
+                buffers,
+                (max_num_pages, ),
+                dtype=torch.int,
+                cache_name="_paged_kv_indices",
+                capture_graph=capture_graph,
+            )
 
     def create_cuda_graph_metadata(self,
                                    max_batch_size: int,
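
The apparent intent of routing the 320 MiB FlashInfer workspace and _paged_kv_indices through self.get_empty is that metadata objects attached to the same buffer pool can share storage instead of each calling torch.empty. Below is a minimal sketch of that idea. ToyBufferPool is illustration only and assumes nothing beyond the get_buffer(shape, dtype, cache_name, capture_graph) call shape visible in this commit; the real pool comes from tensorrt_llm/_torch/memory_buffer_utils, which is not part of this diff.

import torch


class ToyBufferPool:
    """Illustrative stand-in: one flat backing tensor per (cache_name, dtype)."""

    def __init__(self):
        self._cache = {}

    def get_buffer(self, tensor_shape, dtype, cache_name, capture_graph=False):
        numel = 1
        for dim in tensor_shape:
            numel *= dim
        key = (cache_name, dtype)
        buf = self._cache.get(key)
        if buf is None or buf.numel() < numel:
            # Grow (or create) the backing storage; equal or smaller requests
            # under the same key reuse it instead of allocating again.
            buf = torch.zeros(numel, dtype=dtype, device="cuda")
            self._cache[key] = buf
        return buf[:numel].view(tensor_shape)


pool = ToyBufferPool()

# Two requests under the same cache_name resolve to the same storage, so the
# 320 MiB workspace is allocated once rather than once per metadata object.
a = pool.get_buffer((320 * 1024 * 1024, ), torch.uint8, "workspace_buffer")
b = pool.get_buffer((320 * 1024 * 1024, ), torch.uint8, "workspace_buffer")
assert a.data_ptr() == b.data_ptr()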

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 44 additions & 0 deletions

@@ -17,6 +17,7 @@
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
+from ..memory_buffer_utils import Buffers
 from ..metadata import KVCacheParams
 from ..pyexecutor.resource_manager import KVCacheManager
 from ..utils import get_model_extra_attrs

@@ -349,6 +350,49 @@ def update_for_spec_dec(self) -> None:
         Hook to be called during forward when using spec-dec one-model mode.
         """
 
+    @staticmethod
+    def get_empty(buffers: Buffers,
+                  tensor_shape: list[int],
+                  dtype: torch.dtype,
+                  cache_name: str,
+                  capture_graph: bool = False) -> torch.Tensor:
+        """
+        Finds a compatible, reusable buffer from a cache or creates a new one.
+
+        This function searches for a pre-allocated tensor (buffer) that can be
+        reused for an operation involving a tensor with the shape of `tensor_shape`.
+
+        The compatibility rule is: the buffer's total number of elements must be >= that of `tensor_shape`.
+
+        If a compatible buffer is found, it's returned immediately. Otherwise, a new
+        buffer is allocated on the 'cuda' device with the given `tensor_shape` and `dtype`.
+
+        Args:
+            tensor_shape: The required shape.
+            dtype: The required dtype.
+            cache_name: The key for the specific list of buffers to search in.
+        Returns:
+            An existing compatible buffer or a newly created one.
+        """
+        if buffers is None:
+            return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+
+        return buffers.get_buffer(tensor_shape, dtype, cache_name,
+                                  capture_graph)
+
+    @staticmethod
+    def get_empty_like(buffers,
+                       like_tensor: torch.Tensor,
+                       cache_name: str,
+                       capture_graph: bool = False) -> torch.Tensor:
+        return AttentionMetadata.get_empty(
+            buffers,
+            like_tensor.shape,
+            dtype=like_tensor.dtype,
+            cache_name=cache_name,
+            capture_graph=capture_graph,
+        )
+
 
 class PositionalEmbedder(Protocol):
     """

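A small usage sketch of the two new static helpers on AttentionMetadata (the body above itself refers to AttentionMetadata.get_empty), based only on what this hunk shows: with buffers=None they fall back to a plain zero-initialized CUDA allocation, and with a pool they defer entirely to Buffers.get_buffer, whose reuse and capture_graph semantics live in memory_buffer_utils and are not shown in this commit. The shape and cache names below are made up for illustration.

import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata

# No pool attached: get_empty degrades to a fresh zero-initialized CUDA tensor.
indptr = AttentionMetadata.get_empty(
    None, (65, ), dtype=torch.int64, cache_name="example_indptr")

# get_empty_like forwards the shape and dtype of an existing tensor.
indptr_like = AttentionMetadata.get_empty_like(
    None, indptr, cache_name="example_indptr_like")

assert indptr_like.shape == indptr.shape and indptr_like.dtype == indptr.dtype
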
tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 36 additions & 19 deletions

@@ -304,17 +304,12 @@ def __post_init__(self):
 
         capture_graph = torch.cuda.is_current_stream_capturing()
 
-        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
-                      cache_name: str) -> torch.Tensor:
-            if self.cuda_graph_buffers is None:
-                return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
-            return self.cuda_graph_buffers.get_buffer(tensor_shape, dtype,
-                                                      cache_name, capture_graph)
-
-        self.indexer_k_cache_block_offsets = get_empty(
+        self.indexer_k_cache_block_offsets = self.get_empty(
+            self.cuda_graph_buffers,
             [self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
             cache_name="indexer_k_cache_block_offsets",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
         self.host_indexer_k_cache_block_offsets = torch.zeros_like(
             self.indexer_k_cache_block_offsets,

@@ -324,41 +319,49 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
 
         # For mla_rope_append_paged_kv_assign_q
         if not self.enable_context_mla_with_cached_kv:
-            self.ctx_cached_token_indptr = get_empty(
+            self.ctx_cached_token_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.ctx_kv_indptr = get_empty(
+            self.ctx_kv_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_kv_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
                 device='cpu',
                 pin_memory=True,
             )
             # New generation buffers for dsa
-            self.gen_cached_token_indptr = get_empty(
+            self.gen_cached_token_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="gen_cached_token_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_gen_cached_token_indptr = torch.zeros_like(
                 self.gen_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.gen_kv_indptr = get_empty(
+            self.gen_kv_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="gen_kv_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_gen_kv_indptr = torch.zeros_like(
                 self.gen_kv_indptr,

@@ -367,52 +370,66 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
             )
         # Indexer metadata
         # Separate slot mappings for non-interleaved layout (flat byte indices)
-        self.slot_mapping_fp8 = get_empty(
+        self.slot_mapping_fp8 = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="slot_mapping_fp8",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_slot_mapping_fp8 = torch.zeros_like(
             self.slot_mapping_fp8,
             device='cpu',
             pin_memory=True,
         )
-        self.slot_mapping_scale = get_empty(
+        self.slot_mapping_scale = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="slot_mapping_scale",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_slot_mapping_scale = torch.zeros_like(
             self.slot_mapping_scale,
             device='cpu',
             pin_memory=True,
         )
         # Per-token request index buffer for topk_indices conversion
-        self.req_idx_per_token = get_empty(
+        self.req_idx_per_token = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="req_idx_per_token",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
         # Block table for topk_indices conversion (shared for context and generation)
-        self.block_table = get_empty(
+        self.block_table = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
             cache_name="block_table",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.scheduler_metadata_buffer = get_empty(
+        self.scheduler_metadata_buffer = self.get_empty(
+            self.cuda_graph_buffers,
             (self.num_sms + 1, 2),
             cache_name="scheduler_metadata_buffer",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.cu_seqlen_ks = get_empty(
+        self.cu_seqlen_ks = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="cu_seqlen_ks",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.cu_seqlen_ke = get_empty(
+        self.cu_seqlen_ke = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="cu_seqlen_ke",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
 
     def prepare(self):
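
Each device-side buffer above keeps a pinned host twin created with torch.zeros_like(..., device='cpu', pin_memory=True). That pairing is standard PyTorch practice rather than anything specific to this commit: the host copy is filled on the CPU and then pushed to the device with a non-blocking copy, which can overlap with other work because the source memory is pinned. A generic sketch, with illustrative names and sizes:

import torch

# Device-side buffer; in this commit it would come from self.get_empty(...).
gen_kv_indptr = torch.zeros((65, ), dtype=torch.int64, device="cuda")

# Pinned host mirror with the same shape and dtype, as in the diff above.
host_gen_kv_indptr = torch.zeros_like(
    gen_kv_indptr, device="cpu", pin_memory=True)

# Typical per-step refresh: write on the host, then copy asynchronously.
# Pinned memory is what lets non_blocking=True actually be asynchronous.
host_gen_kv_indptr[:5] = torch.arange(5, dtype=torch.int64)
gen_kv_indptr.copy_(host_gen_kv_indptr, non_blocking=True)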

tensorrt_llm/_torch/attention_backend/sparse/rocket.py

Lines changed: 7 additions & 2 deletions

@@ -35,14 +35,19 @@ def __post_init__(self):
         if self.sparse_attention_config is None:
             raise ValueError("Sparse attention config is not set")
         self.prompt_budget = self.sparse_attention_config.prompt_budget
-        self.kt_cache_block_offsets = torch.empty(
+
+        capture_graph = torch.cuda.is_current_stream_capturing()
+        self.kt_cache_block_offsets = self.get_empty(
+            self.cuda_graph_buffers,
             [
                 self.max_num_sequences,
                 self.kv_cache_manager.max_kt_blocks_per_seq
             ],
             dtype=torch.int32,
-            device='cuda',
+            cache_name="kt_cache_block_offsets",
+            capture_graph=capture_graph,
         )
+
         self.host_kt_cache_block_offsets = torch.zeros_like(
             self.kt_cache_block_offsets,
             device='cpu',
