Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import pickle
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
Expand Down Expand Up @@ -43,6 +44,33 @@
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


# Memory fence for cross-process shared memory visibility.
# Required for correct producer-consumer synchronization when using
# shared memory without locks.
_memory_fence_lock = threading.Lock()


def memory_fence():
"""
Full memory barrier for shared memory synchronization.

Ensures all prior memory writes are visible to other processes before
any subsequent reads. This is critical for lock-free producer-consumer
patterns using shared memory.

Implementation acquires and immediately releases a lock. Python's
threading.Lock provides sequentially consistent memory barrier semantics
across all major platforms (POSIX, Windows). This is a lightweight
operation (~20ns) that guarantees:
- All stores before the barrier are visible to other threads/processes
- All loads after the barrier see the latest values
"""
# Lock acquire/release provides full memory barrier semantics.
# Using context manager ensures lock release even on exceptions.
with _memory_fence_lock:
pass


def to_bytes_big(value: int, size: int) -> bytes:
return value.to_bytes(size, byteorder="big")

Expand Down Expand Up @@ -414,6 +442,10 @@ def acquire_write(self, timeout: float | None = None):
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest read flags from readers.
# Without this, we may read stale flags from our CPU cache and
# spin indefinitely even though readers have completed.
memory_fence()
read_count = sum(metadata_buffer[1:])
written_flag = metadata_buffer[0]
if written_flag and read_count != self.buffer.n_reader:
Expand Down Expand Up @@ -458,6 +490,10 @@ def acquire_write(self, timeout: float | None = None):
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
# Memory fence ensures the write is visible to readers on other cores
# before we proceed. Without this, readers may spin indefinitely
# waiting for a write that's stuck in our CPU's store buffer.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
break

Expand All @@ -473,6 +509,10 @@ def acquire_read(
n_warning = 1
while True:
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
# Memory fence ensures we see the latest writes from the writer.
# Without this, we may read stale flags from our CPU cache
# and spin indefinitely even though writer has updated them.
memory_fence()
read_flag = metadata_buffer[self.local_reader_rank + 1]
written_flag = metadata_buffer[0]
if not written_flag or read_flag:
Expand Down Expand Up @@ -513,6 +553,10 @@ def acquire_read(
# caller has read from the buffer
# set the read flag
metadata_buffer[self.local_reader_rank + 1] = 1
# Memory fence ensures the read flag is visible to the writer.
# Without this, writer may not see our read completion and
# could wait indefinitely for all readers to finish.
memory_fence()
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

self._read_spin_timer.record_activity()
Expand Down