From 62a03a5d65cd2401dc8f986a9033e77864905a78 Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Thu, 26 Mar 2026 08:49:55 +0000 Subject: [PATCH 1/5] upd --- python/sglang/srt/managers/mm_utils.py | 98 +++++++++++++------------ python/sglang/srt/managers/scheduler.py | 8 +- 2 files changed, 59 insertions(+), 47 deletions(-) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 7cc37176a40c..7100153976bc 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -1624,48 +1624,33 @@ class ShmPointerMMData: """ Wraps a tensor to be sent via a shared memory handle. This acts as a "pointer" to the tensor data across process boundaries. + + Designed to be pickle-serialized multiple times (e.g., ZMQ send then + dist.broadcast) without re-creating the shared memory segment. + Only the shm name/shape/dtype metadata is serialized; actual tensor + data lives in POSIX shared memory. """ def __init__(self, tensor: torch.Tensor): - self.cpu_tensor = tensor.cpu().contiguous() - self.shape = self.cpu_tensor.shape - self.dtype = self.cpu_tensor.dtype - - nbytes = self.cpu_tensor.numel() * self.cpu_tensor.element_size() - + cpu_tensor = ( + tensor + if (tensor.is_cpu and tensor.is_contiguous()) + else tensor.cpu().contiguous() + ) + self.shape = cpu_tensor.shape + self.dtype = cpu_tensor.dtype + nbytes = cpu_tensor.numel() * cpu_tensor.element_size() self.shm = shared_memory.SharedMemory(create=True, size=nbytes) - - try: - shm_view = np.ndarray((nbytes,), dtype=np.uint8, buffer=self.shm.buf) - - shm_view[:] = self.cpu_tensor.view(torch.uint8).numpy().flatten() - finally: - self.shm.close() + self.shm_name = self.shm.name + # Direct copy: torch → shm buffer (no numpy intermediate) + dst = torch.frombuffer(self.shm.buf, dtype=torch.uint8) + dst.copy_(cpu_tensor.view(torch.uint8).reshape(-1)) + self.shm.close() + self._shm_handle = None def __getstate__(self): - if not hasattr(self, "shm") or self.shm is None: - tensor = getattr(self, "cpu_tensor", None) - if tensor is None: - tensor = getattr(self, "tensor", None) - if tensor is None: - raise RuntimeError( - "ShmPointerMMData cannot recreate shared memory without tensor" - ) - - cpu_tensor = tensor.cpu().contiguous() - self.shape = cpu_tensor.shape - self.dtype = cpu_tensor.dtype - - nbytes = cpu_tensor.numel() * cpu_tensor.element_size() - self.shm = shared_memory.SharedMemory(create=True, size=nbytes) - try: - shm_view = np.ndarray((nbytes,), dtype=np.uint8, buffer=self.shm.buf) - shm_view[:] = cpu_tensor.view(torch.uint8).numpy().flatten() - finally: - self.shm.close() - return { - "shm_name": self.shm.name, + "shm_name": self.shm_name, "shape": self.shape, "dtype": self.dtype, } @@ -1675,17 +1660,31 @@ def __setstate__(self, state): self.shape = state["shape"] self.dtype = state["dtype"] self.shm = None + self._shm_handle = shared_memory.SharedMemory(name=self.shm_name) + # Zero-copy view into shared memory (no clone, no unlink) + self.tensor = torch.frombuffer( + self._shm_handle.buf, dtype=self.dtype + ).reshape(self.shape) + + def materialize(self) -> torch.Tensor: + """Clone tensor from shm to owned memory, then release shm handle.""" + tensor = self.tensor.clone() + if self._shm_handle is not None: + self._shm_handle.close() + try: + self._shm_handle.unlink() + except FileNotFoundError: + pass # Another rank already unlinked + self._shm_handle = None + return tensor - shm_handle = shared_memory.SharedMemory(name=self.shm_name) - try: - self.tensor = ( - torch.frombuffer(shm_handle.buf, dtype=self.dtype) - .reshape(self.shape) - .clone() - ) - finally: - shm_handle.close() - shm_handle.unlink() + def __del__(self): + if getattr(self, "_shm_handle", None) is not None: + self._shm_handle.close() + try: + self._shm_handle.unlink() + except FileNotFoundError: + pass def _get_is_default_transport(): @@ -1723,12 +1722,19 @@ def wrap_shm_features(obj): def unwrap_shm_features(obj): """ Restore ShmPointerMMData wrappers back into standard torch.Tensors. + Handles both single requests and batch requests. """ if _get_is_default_transport() or get_global_server_args().skip_tokenizer_init: return obj + # Handle batch requests + if hasattr(obj, "batch"): + for sub_obj in obj.batch: + unwrap_shm_features(sub_obj) + return obj + # Handle single requests if hasattr(obj, "mm_inputs") and obj.mm_inputs: mm_items = obj.mm_inputs.get("mm_items", []) for item in mm_items: if isinstance(item.feature, ShmPointerMMData): - item.feature = item.feature.tensor + item.feature = item.feature.materialize() return obj diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 3e6924807ce1..6fe0c6a2aa71 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1424,7 +1424,6 @@ def recv_requests( if self.recv_limit_reached(len(recv_reqs)): break recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - recv_req = unwrap_shm_features(recv_req) except zmq.ZMQError: break recv_reqs.append(recv_req) @@ -1511,6 +1510,13 @@ def recv_requests( prepare_abort(req, error_msg, status_code=status_code) self.stream_output([req], req.return_logprob) + # Unwrap shared memory features AFTER all broadcasts complete, + # so that ShmPointerMMData metadata (not full tensor data) is what + # gets serialized during broadcast_pyobj. + if recv_reqs: + for req in recv_reqs: + unwrap_shm_features(req) + return recv_reqs def _split_work_and_control_reqs(self, recv_reqs: List): From 0f3a2cf9aa0d832ae03c0db88daea934571cd169 Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Thu, 26 Mar 2026 08:51:52 +0000 Subject: [PATCH 2/5] upd --- python/sglang/srt/managers/mm_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 7100153976bc..3cfc7c419db4 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -1662,9 +1662,9 @@ def __setstate__(self, state): self.shm = None self._shm_handle = shared_memory.SharedMemory(name=self.shm_name) # Zero-copy view into shared memory (no clone, no unlink) - self.tensor = torch.frombuffer( - self._shm_handle.buf, dtype=self.dtype - ).reshape(self.shape) + self.tensor = torch.frombuffer(self._shm_handle.buf, dtype=self.dtype).reshape( + self.shape + ) def materialize(self) -> torch.Tensor: """Clone tensor from shm to owned memory, then release shm handle.""" From 2635b03c260c24df43e678855b5c4805fee72800 Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Thu, 26 Mar 2026 08:54:04 +0000 Subject: [PATCH 3/5] upd --- python/sglang/srt/managers/mm_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 3cfc7c419db4..f1616f632ffe 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -1624,11 +1624,6 @@ class ShmPointerMMData: """ Wraps a tensor to be sent via a shared memory handle. This acts as a "pointer" to the tensor data across process boundaries. - - Designed to be pickle-serialized multiple times (e.g., ZMQ send then - dist.broadcast) without re-creating the shared memory segment. - Only the shm name/shape/dtype metadata is serialized; actual tensor - data lives in POSIX shared memory. """ def __init__(self, tensor: torch.Tensor): From 60a8007fea294806d41b6e4eb9d03742ceedca7f Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Thu, 26 Mar 2026 09:04:04 +0000 Subject: [PATCH 4/5] fix --- python/sglang/srt/managers/mm_utils.py | 31 ++++++++++++++------------ 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index f1616f632ffe..25a0528fbf1b 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -1627,20 +1627,23 @@ class ShmPointerMMData: """ def __init__(self, tensor: torch.Tensor): - cpu_tensor = ( - tensor - if (tensor.is_cpu and tensor.is_contiguous()) - else tensor.cpu().contiguous() - ) - self.shape = cpu_tensor.shape - self.dtype = cpu_tensor.dtype - nbytes = cpu_tensor.numel() * cpu_tensor.element_size() - self.shm = shared_memory.SharedMemory(create=True, size=nbytes) - self.shm_name = self.shm.name - # Direct copy: torch → shm buffer (no numpy intermediate) - dst = torch.frombuffer(self.shm.buf, dtype=torch.uint8) - dst.copy_(cpu_tensor.view(torch.uint8).reshape(-1)) - self.shm.close() + if not tensor.is_cpu: + tensor = tensor.cpu() + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + self.shape = tensor.shape + self.dtype = tensor.dtype + nbytes = tensor.numel() * tensor.element_size() + shm = shared_memory.SharedMemory(create=True, size=nbytes) + try: + dst = torch.frombuffer(shm.buf, dtype=torch.uint8) + dst.copy_(tensor.view(torch.uint8).reshape(-1)) + except BaseException: + shm.close() + shm.unlink() + raise + self.shm_name = shm.name + shm.close() self._shm_handle = None def __getstate__(self): From 51544946eecb0d4a438392554f53454863de9a2e Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Fri, 27 Mar 2026 05:59:16 +0000 Subject: [PATCH 5/5] FIX --- python/sglang/srt/managers/mm_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py index 25a0528fbf1b..1549b6112504 100644 --- a/python/sglang/srt/managers/mm_utils.py +++ b/python/sglang/srt/managers/mm_utils.py @@ -1677,12 +1677,10 @@ def materialize(self) -> torch.Tensor: return tensor def __del__(self): + # Only close; never unlink. Unlinking is materialize()'s job. if getattr(self, "_shm_handle", None) is not None: self._shm_handle.close() - try: - self._shm_handle.unlink() - except FileNotFoundError: - pass + self._shm_handle = None def _get_is_default_transport():