diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py
index 9c8bf3ad165c..803ecb023bfd 100644
--- a/vllm/distributed/device_communicators/shm_broadcast.py
+++ b/vllm/distributed/device_communicators/shm_broadcast.py
@@ -49,7 +49,7 @@
 # Memory fence for cross-process shared memory visibility.
 # Required for correct producer-consumer synchronization when using
 # shared memory without locks.
-_memory_fence_lock = threading.Lock()
+_memory_fence_lock = threading.Lock()  # kept for backward compat, no longer used
 
 
 def memory_fence():
@@ -60,17 +60,15 @@ def memory_fence():
     any subsequent reads. This is critical for lock-free producer-consumer
     patterns using shared memory.
 
-    Implementation acquires and immediately releases a lock. Python's
-    threading.Lock provides sequentially consistent memory barrier semantics
-    across all major platforms (POSIX, Windows). This is a lightweight
-    operation (~20ns) that guarantees:
-    - All stores before the barrier are visible to other threads/processes
-    - All loads after the barrier see the latest values
+    Implementation calls sched_yield (a kernel scheduling boundary acts
+    as a sequentially consistent memory barrier on all major
+    architectures). Roughly the same overhead as the previous
+    threading.Lock acquire/release (~20ns), but avoids the
+    `_thread.lock.__enter__` C method descriptor that can be corrupted
+    when other vLLM components JIT-load C extensions at runtime
+    (FlashInfer / Triton autotune, torch.compile). See vllm-project/vllm#35104.
     """
-    # Lock acquire/release provides full memory barrier semantics.
-    # Using context manager ensures lock release even on exceptions.
-    with _memory_fence_lock:
-        pass
+    sched_yield()
 
 
 def to_bytes_big(value: int, size: int) -> bytes:
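
To make the ordering described in the docstring concrete, here is a minimal sketch of the lock-free producer-consumer pattern the fence is meant to support. It is illustrative only: the one-byte "ready" flag, the payload offset, and the `producer`/`consumer` helpers are invented for this example and do not reflect the actual buffer layout in `shm_broadcast.py`; `os.sched_yield()` stands in for the unqualified `sched_yield()` in the patch and is POSIX-only.

```python
# Hypothetical single-producer/single-consumer exchange over shared memory,
# showing where memory_fence() sits.  The flag/payload layout is invented for
# illustration and is NOT the buffer format vLLM uses in shm_broadcast.py.
# POSIX-only (os.sched_yield is unavailable on Windows).
import os
from multiprocessing import Process, shared_memory


def memory_fence():
    # Same approach as the patch above: rely on a kernel scheduling boundary
    # to act as a full memory barrier.
    os.sched_yield()


def producer(shm_name: str, payload: bytes):
    shm = shared_memory.SharedMemory(name=shm_name)
    try:
        shm.buf[1:1 + len(payload)] = payload  # 1) write the data
        memory_fence()                         # 2) publish the writes before the flag
        shm.buf[0] = 1                         # 3) set the "ready" flag last
    finally:
        shm.close()


def consumer(shm_name: str, size: int) -> bytes:
    shm = shared_memory.SharedMemory(name=shm_name)
    try:
        while shm.buf[0] != 1:                 # 1) spin until the flag is set
            os.sched_yield()
        memory_fence()                         # 2) fence before touching the data
        return bytes(shm.buf[1:1 + size])      # 3) data written before the flag is visible
    finally:
        shm.close()


if __name__ == "__main__":
    msg = b"hello"
    shm = shared_memory.SharedMemory(create=True, size=1 + len(msg))
    shm.buf[0] = 0
    try:
        p = Process(target=producer, args=(shm.name, msg))
        p.start()
        print(consumer(shm.name, len(msg)))    # b'hello'
        p.join()
    finally:
        shm.close()
        shm.unlink()
```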