From e290d93289fc7448aa51f77ba2c66352c6b2db3e Mon Sep 17 00:00:00 2001 From: Saatwik Nagpal Date: Wed, 25 Mar 2026 03:45:04 +0000 Subject: [PATCH] Optimize CUDA IPC multimodal transfer with pool handle cache --- python/sglang/srt/environ.py | 1 + .../multimodal/processors/base_processor.py | 19 +- .../srt/utils/cuda_ipc_transport_utils.py | 225 ++++++++++++++---- 3 files changed, 195 insertions(+), 50 deletions(-) diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 4dcf0613bd91..2ff209fc77bf 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -450,6 +450,7 @@ class Envs: # VLM Item CUDA IPC Transport SGLANG_USE_CUDA_IPC_TRANSPORT = EnvBool(False) + SGLANG_USE_IPC_POOL_HANDLE_CACHE = EnvBool(False) SGLANG_MM_FEATURE_CACHE_MB = EnvInt(4 * 1024) SGLANG_MM_ITEM_MEM_POOL_RECYCLE_INTERVAL_SEC = EnvFloat(0.05) diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 0121c97553cd..b3f7bffb4ffa 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -40,6 +40,7 @@ _is_xpu = is_xpu() SGL_USE_CUDA_IPC = envs.SGLANG_USE_CUDA_IPC_TRANSPORT.get() +_IPC_POOL_HANDLE_CACHE = envs.SGLANG_USE_IPC_POOL_HANDLE_CACHE.get() @dataclasses.dataclass @@ -1134,7 +1135,7 @@ def process_and_combine_mm_data( # post-process for item in all_collected_items: if isinstance(item.feature, torch.Tensor) and item.feature.is_cuda: - sync_flag, available_slice = ( + sync_flag, available_slice, byte_offset = ( self.cudaipc_mmfeature_pool.return_a_slice_tensor_with_flag( item.feature ) @@ -1147,6 +1148,13 @@ def process_and_combine_mm_data( data=available_slice, info_data=item.feature, sync_buffer_meta=sync_flag, + pool_ipc_handle=( + self.cudaipc_mmfeature_pool._pool_ipc_handle + if _IPC_POOL_HANDLE_CACHE + else None + ), + pool_byte_offset=byte_offset, + pool_device_index=self.cudaipc_mmfeature_pool._pool_device_index, ) elif not self.server_args.keep_mm_feature_on_device: item.feature = item.feature.cpu() @@ -1155,7 +1163,7 @@ def process_and_combine_mm_data( and item.precomputed_embeddings.is_cuda ): - sync_flag, available_slice = ( + sync_flag, available_slice, byte_offset = ( self.cudaipc_mmfeature_pool.return_a_slice_tensor_with_flag( item.precomputed_embeddings ) @@ -1169,6 +1177,13 @@ def process_and_combine_mm_data( data=available_slice, info_data=item.precomputed_embeddings, sync_buffer_meta=sync_flag, + pool_ipc_handle=( + self.cudaipc_mmfeature_pool._pool_ipc_handle + if _IPC_POOL_HANDLE_CACHE + else None + ), + pool_byte_offset=byte_offset, + pool_device_index=self.cudaipc_mmfeature_pool._pool_device_index, ) elif not self.server_args.keep_mm_feature_on_device: item.precomputed_embeddings = item.precomputed_embeddings.cpu() diff --git a/python/sglang/srt/utils/cuda_ipc_transport_utils.py b/python/sglang/srt/utils/cuda_ipc_transport_utils.py index 6d76242aef32..9c4abf014c5e 100644 --- a/python/sglang/srt/utils/cuda_ipc_transport_utils.py +++ b/python/sglang/srt/utils/cuda_ipc_transport_utils.py @@ -3,7 +3,7 @@ import threading import time from multiprocessing import shared_memory -from typing import Tuple +from typing import Any, Tuple import numpy as np import torch @@ -22,6 +22,49 @@ SHM_LOCK_FILE = "/tmp/shm_wr_lock.lock" +# Cache for pool-level IPC handles on the consumer side. +# Key: the pool CUDA IPC handle tuple. Value: opened UntypedStorage. +_pool_storage_cache: dict = {} +_pool_cache_lock = threading.Lock() + + +def _normalize_pool_cache_key(pool_handle, pool_device_index: int) -> tuple[Any, ...]: + normalized_handle = ( + pool_handle if isinstance(pool_handle, tuple) else tuple(pool_handle) + ) + return (pool_device_index, normalized_handle) + + +def _open_pooled_storage_uncached(pool_handle): + return torch.UntypedStorage._new_shared_cuda(*pool_handle) + + +def _pool_handle_cache_get_or_open(cache_key, pool_handle): + storage = _pool_storage_cache.get(cache_key) + if storage is None: + with _pool_cache_lock: + storage = _pool_storage_cache.get(cache_key) + if storage is None: + storage = _open_pooled_storage_uncached(pool_handle) + _pool_storage_cache[cache_key] = storage + return storage + + +def _pool_handle_cache_set(cache_key, storage): + with _pool_cache_lock: + _pool_storage_cache[cache_key] = storage + + +def _pool_handle_cache_invalidate(cache_key): + with _pool_cache_lock: + _pool_storage_cache.pop(cache_key, None) + + +def _pool_handle_cache_clear(): + with _pool_cache_lock: + _pool_storage_cache.clear() + + class ShmSyncBuffer: def __init__(self, byte_size: int = 4): self.buffer = shared_memory.SharedMemory(create=True, size=byte_size) @@ -80,6 +123,9 @@ def __init__(self, memory_size, recycle_interval): self.memory_pool = torch.empty( memory_size, dtype=torch.int8, device="cuda" ).contiguous() + storage = self.memory_pool.untyped_storage() + self._pool_ipc_handle = storage._share_cuda_() + self._pool_device_index = self.memory_pool.device.index self.sync_flag_list = [] @@ -181,8 +227,9 @@ def return_a_slice_tensor_with_flag(self, src_tensor: torch.Tensor): return ( available_chunk.sync_flag.meta_data, self.memory_pool[available_chunk.start : available_chunk.end], + available_chunk.start, ) - return None, None + return None, None, None def recycle_chunks(self): @@ -229,6 +276,9 @@ def __init__( data: torch.Tensor, info_data: torch.Tensor, sync_buffer_meta, + pool_ipc_handle=None, + pool_byte_offset: int = 0, + pool_device_index: int = 0, ): if (not isinstance(data, torch.Tensor)) or ( @@ -238,7 +288,24 @@ def __init__( f"Input 'data' must be a torch.Tensor, but got {type(data)}" ) - self.proxy_state = self.get_proxy_state(data, info_data) + if pool_ipc_handle is not None: + self.proxy_state = { + "ipc_extra": { + "pool_handle": pool_ipc_handle, + "pool_byte_offset": pool_byte_offset, + "pool_device_index": pool_device_index, + "shape": data.shape, + "dtype": data.dtype, + "stride": data.stride(), + "storage_offset": 0, + "nbytes": data.numel() * data.element_size(), + "recons_shape": info_data.shape, + "recons_dtype": info_data.dtype, + }, + "tensor_data": None, + } + else: + self.proxy_state = self.get_proxy_state(data, info_data) self.reconstruct_tensor = None self.sync_data_meta = sync_buffer_meta self.sync_buffer = None @@ -283,6 +350,62 @@ def get_proxy_state(self, data, info_data): return state + def _reconstruct_from_ipc_extra(self, ipc_extra, *, use_cache: bool): + shape = ipc_extra["shape"] + dtype = ipc_extra["dtype"] + stride = ipc_extra["stride"] + target_device = torch.device(f"cuda:{ipc_extra['pool_device_index']}") + cache_key = _normalize_pool_cache_key( + ipc_extra["pool_handle"], ipc_extra["pool_device_index"] + ) + + with torch.cuda.device(target_device): + if use_cache: + storage = _pool_handle_cache_get_or_open( + cache_key, ipc_extra["pool_handle"] + ) + storage_to_cache = None + else: + storage = _open_pooled_storage_uncached(ipc_extra["pool_handle"]) + storage_to_cache = storage + slice_storage = storage[ + ipc_extra["pool_byte_offset"] : ipc_extra["pool_byte_offset"] + + ipc_extra["nbytes"] + ] + slice_tensor = torch.empty(0, dtype=dtype, device=target_device).set_( + slice_storage, + storage_offset=ipc_extra["storage_offset"], + size=shape, + stride=stride, + ) + + return slice_tensor, target_device, cache_key, storage_to_cache + + def _copy_slice_tensor_to_target( + self, + slice_tensor: torch.Tensor, + rebuild_device: torch.device, + recons_shape, + recons_dtype, + ): + with torch.cuda.device(rebuild_device): + reconstructed_tensor = torch.empty( + recons_shape, dtype=recons_dtype, device=rebuild_device + ).contiguous() + reconstructed_tensor.view(torch.int8).view(-1).copy_(slice_tensor) + + open(SHM_LOCK_FILE, "a").close() + # write the shm_sync_buffer with a file lock + with open(SHM_LOCK_FILE, "w+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + sync_flag = self.get_sync_flag + sync_flag += 1 + fcntl.flock(f, fcntl.LOCK_UN) + + self.close_shm() + + return reconstructed_tensor + def reconstruct_on_target_device(self, rebuild_device_idx): rebuild_device = torch.device(f"cuda:{rebuild_device_idx}") if ( @@ -293,52 +416,58 @@ def reconstruct_on_target_device(self, rebuild_device_idx): if self.proxy_state["ipc_extra"]: ipc_extra = self.proxy_state["ipc_extra"] - ( - handle, - shape, - dtype, - stride, - source_device_index, - s_offset, - recons_shape, - recons_dtype, - ) = ( - ipc_extra["handle"], - ipc_extra["shape"], - ipc_extra["dtype"], - ipc_extra["stride"], - ipc_extra["device_index"], - ipc_extra["storage_offset"], - ipc_extra["recons_shape"], - ipc_extra["recons_dtype"], + recons_shape = ipc_extra["recons_shape"] + recons_dtype = ipc_extra["recons_dtype"] + + if "pool_handle" in ipc_extra: + try: + ( + slice_tensor, + _target_device, + cache_key, + storage_to_cache, + ) = self._reconstruct_from_ipc_extra(ipc_extra, use_cache=True) + except Exception as e: + cache_key = _normalize_pool_cache_key( + ipc_extra["pool_handle"], ipc_extra["pool_device_index"] + ) + logger.info( + "Failed to deserialize from cached pooled CUDA IPC handle (%s). " + "Invalidating cache entry and retrying uncached.", + e, + ) + _pool_handle_cache_invalidate(cache_key) + ( + slice_tensor, + _target_device, + _cache_key, + storage_to_cache, + ) = self._reconstruct_from_ipc_extra(ipc_extra, use_cache=False) + if storage_to_cache is not None: + _pool_handle_cache_set(cache_key, storage_to_cache) + else: + # Non-pooled path: open handle directly (original behavior) + try: + storage = torch.UntypedStorage._new_shared_cuda( + *ipc_extra["handle"] + ) + target_device = torch.device(f"cuda:{ipc_extra['device_index']}") + with torch.cuda.device(target_device): + slice_tensor = torch.empty( + 0, dtype=ipc_extra["dtype"], device=target_device + ).set_( + storage, + storage_offset=ipc_extra["storage_offset"], + size=ipc_extra["shape"], + stride=ipc_extra["stride"], + ) + except Exception as e: + logger.info("Failed to deserialize from CUDA IPC handle (%s).", e) + raise + + reconstructed_tensor = self._copy_slice_tensor_to_target( + slice_tensor, rebuild_device, recons_shape, recons_dtype ) - - try: - target_device = torch.device(f"cuda:{source_device_index}") - with torch.cuda.device(target_device): - storage = torch.UntypedStorage._new_shared_cuda(*handle) - slice_tensor = torch.empty( - 0, dtype=dtype, device=target_device - ).set_(storage, storage_offset=s_offset, size=shape, stride=stride) - - reconstructed_tensor = torch.empty( - recons_shape, dtype=recons_dtype, device=rebuild_device - ).contiguous() - reconstructed_tensor.view(torch.int8).view(-1).copy_(slice_tensor) - - open(SHM_LOCK_FILE, "a").close() - # write the shm_sync_buffer with a file lock - with open(SHM_LOCK_FILE, "w+") as f: - fcntl.flock(f, fcntl.LOCK_EX) - sync_flag = self.get_sync_flag - sync_flag += 1 - fcntl.flock(f, fcntl.LOCK_UN) - - self.close_shm() - - except Exception as e: - logger.info(f"Error: Failed to deserialize from CUDA IPC handle ({e}).") - raise e elif isinstance(self.proxy_state["tensor_data"], torch.Tensor): reconstructed_tensor = self.proxy_state["tensor_data"].to( rebuild_device, non_blocking=True