diff --git a/python/sglang/srt/layers/utils/cp_utils.py b/python/sglang/srt/layers/utils/cp_utils.py index b6981931fd61..f3ee07809d66 100644 --- a/python/sglang/srt/layers/utils/cp_utils.py +++ b/python/sglang/srt/layers/utils/cp_utils.py @@ -179,12 +179,15 @@ def cp_all_gather_reorganized_into_tensor_kv_cache( input_tensor = F.pad(input_tensor, padding, mode="constant", value=0) # Create output tensor with proper shape for all dimensions - input_tensor_full = torch.empty( - max_len * cp_size, - *input_tensor.shape[1:], - device=input_tensor.device, - dtype=input_tensor.dtype, - ) + with use_symmetric_memory( + get_attention_cp_group(), disabled=not is_allocation_symmetric() + ): + input_tensor_full = torch.empty( + max_len * cp_size, + *input_tensor.shape[1:], + device=input_tensor.device, + dtype=input_tensor.dtype, + ) get_attention_cp_group().cp_all_gather_into_tensor_async( input_tensor_full, input_tensor, stream