@@ -304,17 +304,12 @@ def __post_init__(self):
 
         capture_graph = torch.cuda.is_current_stream_capturing()
 
-        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
-                      cache_name: str) -> torch.Tensor:
-            if self.cuda_graph_buffers is None:
-                return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
-            return self.cuda_graph_buffers.get_buffer(tensor_shape, dtype,
-                                                      cache_name, capture_graph)
-
-        self.indexer_k_cache_block_offsets = get_empty(
+        self.indexer_k_cache_block_offsets = self.get_empty(
+            self.cuda_graph_buffers,
             [self.max_num_sequences, self.kv_cache_manager.max_blocks_per_seq],
             cache_name="indexer_k_cache_block_offsets",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
         self.host_indexer_k_cache_block_offsets = torch.zeros_like(
             self.indexer_k_cache_block_offsets,
@@ -324,41 +319,49 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
 
         # For mla_rope_append_paged_kv_assign_q
         if not self.enable_context_mla_with_cached_kv:
-            self.ctx_cached_token_indptr = get_empty(
+            self.ctx_cached_token_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.ctx_kv_indptr = get_empty(
+            self.ctx_kv_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_kv_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
                 device='cpu',
                 pin_memory=True,
             )
             # New generation buffers for dsa
-            self.gen_cached_token_indptr = get_empty(
+            self.gen_cached_token_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="gen_cached_token_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_gen_cached_token_indptr = torch.zeros_like(
                 self.gen_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.gen_kv_indptr = get_empty(
+            self.gen_kv_indptr = self.get_empty(
+                self.cuda_graph_buffers,
                 (self.max_num_requests + 1, ),
                 cache_name="gen_kv_indptr",
                 dtype=torch.int64,
+                capture_graph=capture_graph,
             )
             self.host_gen_kv_indptr = torch.zeros_like(
                 self.gen_kv_indptr,
@@ -367,52 +370,66 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
             )
         # Indexer metadata
         # Separate slot mappings for non-interleaved layout (flat byte indices)
-        self.slot_mapping_fp8 = get_empty(
+        self.slot_mapping_fp8 = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="slot_mapping_fp8",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_slot_mapping_fp8 = torch.zeros_like(
             self.slot_mapping_fp8,
             device='cpu',
             pin_memory=True,
         )
-        self.slot_mapping_scale = get_empty(
+        self.slot_mapping_scale = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="slot_mapping_scale",
             dtype=torch.int64,
+            capture_graph=capture_graph,
         )
         self.host_slot_mapping_scale = torch.zeros_like(
             self.slot_mapping_scale,
             device='cpu',
             pin_memory=True,
         )
         # Per-token request index buffer for topk_indices conversion
-        self.req_idx_per_token = get_empty(
+        self.req_idx_per_token = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="req_idx_per_token",
             dtype=torch.int32,
+            capture_graph=capture_graph,
        )
         # Block table for topk_indices conversion (shared for context and generation)
-        self.block_table = get_empty(
+        self.block_table = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_requests, self.kv_cache_manager.max_blocks_per_seq),
             cache_name="block_table",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.scheduler_metadata_buffer = get_empty(
+        self.scheduler_metadata_buffer = self.get_empty(
+            self.cuda_graph_buffers,
             (self.num_sms + 1, 2),
             cache_name="scheduler_metadata_buffer",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.cu_seqlen_ks = get_empty(
+        self.cu_seqlen_ks = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="cu_seqlen_ks",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
-        self.cu_seqlen_ke = get_empty(
+        self.cu_seqlen_ke = self.get_empty(
+            self.cuda_graph_buffers,
             (self.max_num_tokens, ),
             cache_name="cu_seqlen_ke",
             dtype=torch.int32,
+            capture_graph=capture_graph,
         )
 
     def prepare(self):
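Note: this diff replaces the hunk-local `get_empty` closure with a shared `self.get_empty` method whose definition lies outside the hunks shown here. Based on the removed closure and the new call sites, the shared helper plausibly looks like the sketch below; its exact signature and location (presumably a base metadata class, so other attention backends can reuse it) are assumptions, not confirmed by this diff.

    def get_empty(self,
                  cuda_graph_buffers,  # assumed: the pool the call sites pass in
                  tensor_shape: list[int],
                  cache_name: str,
                  dtype: torch.dtype,
                  capture_graph: bool) -> torch.Tensor:
        # No CUDA-graph buffer pool: plain device allocation, as in the old closure.
        if cuda_graph_buffers is None:
            return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
        # Otherwise reuse (or create) a named buffer from the pool so tensors keep
        # stable device addresses across CUDA-graph capture and replay.
        return cuda_graph_buffers.get_buffer(tensor_shape, dtype, cache_name,
                                             capture_graph)

Passing `capture_graph` explicitly, rather than closing over the local variable, is what allows the helper to live on the class instead of being redefined inside `__post_init__`.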