Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 48 additions & 46 deletions python/sglang/srt/managers/mm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,45 +1627,28 @@ class ShmPointerMMData:
"""

def __init__(self, tensor: torch.Tensor):
    """Copy ``tensor``'s bytes into a newly created shared-memory segment.

    Only the segment name, shape and dtype are kept on the instance; the
    creating process closes its own mapping before returning and leaves
    ``_shm_handle`` as ``None``.

    Args:
        tensor: Tensor to publish. It is moved to CPU and made contiguous
            first if it is not already.

    Raises:
        Whatever segment creation or the copy raises. If the copy fails,
        the segment is closed and unlinked before the exception propagates.
    """
    if not tensor.is_cpu:
        tensor = tensor.cpu()
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()
    self.shape = tensor.shape
    self.dtype = tensor.dtype
    nbytes = tensor.numel() * tensor.element_size()
    shm = shared_memory.SharedMemory(create=True, size=nbytes)
    try:
        # Bound the view to nbytes: a platform may round the segment up
        # (e.g. to a page size), which would make copy_ see mismatched sizes.
        dst = torch.frombuffer(shm.buf, dtype=torch.uint8, count=nbytes)
        # Flatten before reinterpreting as bytes so that 0-dim tensors,
        # which cannot be viewed as a dtype of another size, also work.
        dst.copy_(tensor.reshape(-1).view(torch.uint8))
    except BaseException:
        # Do not leave an orphaned segment behind when the copy fails.
        shm.close()
        shm.unlink()
        raise
    self.shm_name = shm.name
    shm.close()
    self._shm_handle = None

def __getstate__(self):
    """Return the state to pickle: segment name, shape and dtype.

    The tensor bytes themselves are not part of the returned state; only
    the three attributes below are serialized.
    """
    return {
        "shm_name": self.shm_name,
        "shape": self.shape,
        "dtype": self.dtype,
    }
Expand All @@ -1675,17 +1658,29 @@ def __setstate__(self, state):
self.shape = state["shape"]
self.dtype = state["dtype"]
self.shm = None
self._shm_handle = shared_memory.SharedMemory(name=self.shm_name)
# Zero-copy view into shared memory (no clone, no unlink)
self.tensor = torch.frombuffer(self._shm_handle.buf, dtype=self.dtype).reshape(
self.shape
)

shm_handle = shared_memory.SharedMemory(name=self.shm_name)
try:
self.tensor = (
torch.frombuffer(shm_handle.buf, dtype=self.dtype)
.reshape(self.shape)
.clone()
)
finally:
shm_handle.close()
shm_handle.unlink()
def materialize(self) -> torch.Tensor:
    """Clone the tensor out of shared memory, then release the segment.

    After the first call ``self.tensor`` refers to the owned clone rather
    than to the shared-memory view, so later reads of ``self.tensor`` and
    repeated calls never touch an unmapped buffer.

    Returns:
        A tensor with its own storage holding the same data.
    """
    tensor = self.tensor.clone()
    # Drop the view into the segment before the mapping goes away.
    self.tensor = tensor
    if self._shm_handle is not None:
        self._shm_handle.close()
        try:
            self._shm_handle.unlink()
        except FileNotFoundError:
            pass  # Another rank already unlinked
        self._shm_handle = None
    return tensor

def __del__(self):
    """Close this object's shared-memory handle if one is still set.

    The segment is only closed here, never unlinked; unlinking is left to
    materialize().
    """
    handle = getattr(self, "_shm_handle", None)
    if handle is None:
        return
    handle.close()
    self._shm_handle = None


def _get_is_default_transport():
Expand Down Expand Up @@ -1723,12 +1718,19 @@ def wrap_shm_features(obj):
def unwrap_shm_features(obj):
    """
    Restore ShmPointerMMData wrappers back into standard torch.Tensors.
    Handles both single requests and batch requests.

    Returns ``obj`` itself in every case; wrapped features are replaced in
    place. Nothing is changed when the default transport is in use or
    ``skip_tokenizer_init`` is set.
    """
    if _get_is_default_transport() or get_global_server_args().skip_tokenizer_init:
        return obj
    # Handle batch requests
    if hasattr(obj, "batch"):
        for sub_obj in obj.batch:
            unwrap_shm_features(sub_obj)
        return obj
    # Handle single requests
    if hasattr(obj, "mm_inputs") and obj.mm_inputs:
        mm_items = obj.mm_inputs.get("mm_items", [])
        for item in mm_items:
            if isinstance(item.feature, ShmPointerMMData):
                # materialize() clones the data out of shared memory and
                # releases the segment handle.
                item.feature = item.feature.materialize()
    return obj
8 changes: 7 additions & 1 deletion python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,6 @@ def recv_requests(
if self.recv_limit_reached(len(recv_reqs)):
break
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
recv_req = unwrap_shm_features(recv_req)
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
Expand Down Expand Up @@ -1511,6 +1510,13 @@ def recv_requests(
prepare_abort(req, error_msg, status_code=status_code)
self.stream_output([req], req.return_logprob)

# Unwrap shared memory features AFTER all broadcasts complete,
# so that ShmPointerMMData metadata (not full tensor data) is what
# gets serialized during broadcast_pyobj.
if recv_reqs:
for req in recv_reqs:
unwrap_shm_features(req)

return recv_reqs

def _split_work_and_control_reqs(self, recv_reqs: List):
Expand Down
Loading