
fix(shm): Add memory barriers for cross-process shared memory visibility #29819

Closed

kitaekatt wants to merge 2 commits into vllm-project:main from kitaekatt:fix-shm-memory-barriers


Conversation

@kitaekatt
Contributor

Summary

Fixes freeze/hang during sustained concurrent batch inference caused by missing memory barriers in the shared memory ring buffer protocol.

Root Cause

The shm_broadcast.py shared memory IPC uses plain byte writes (metadata_buffer[0] = 1) to signal between writer and reader processes. On multi-core systems, these writes stay in CPU store buffers and may not be visible to other processes running on different cores. This causes indefinite spinning where:

  • Writer waits for readers to set read flags (that are already set but not visible)
  • Readers wait for writer to set written flag (that is already set but not visible)
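For illustration, here is a hedged sketch of the flag protocol described above (not the actual shm_broadcast.py code; the metadata layout is assumed to be one written-flag byte followed by one read-flag byte per reader, and a bytearray stands in for the shared-memory buffer):

# Hedged sketch only -- simplified from the description above, not the real
# shm_broadcast.py implementation.
n_readers = 2
metadata_buffer = bytearray(1 + n_readers)  # [written flag, one read flag per reader]

def acquire_write_spin() -> None:
    # Writer: spin until every reader has flagged the previous block as read.
    while not all(metadata_buffer[1 + i] == 1 for i in range(n_readers)):
        pass  # plain byte reads: a stale view of the flags can spin forever
    metadata_buffer[0] = 1  # plain byte write: may sit in a CPU store buffer

def acquire_read_spin(reader_rank: int) -> None:
    # Reader: spin until the writer has published the block.
    while metadata_buffer[0] != 1:
        pass  # the flag may already be set but not yet visible on this core
    metadata_buffer[1 + reader_rank] = 1  # plain write back toward the writer

Per the analysis above, nothing on either side of this path forces the flag stores to become visible or the flag loads to see fresh values, so both loops can keep spinning on stale state.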

The Fix

Add explicit memory barriers using threading.Lock acquire/release pattern (which provides full memory barrier semantics per POSIX.1-2008) at four critical points:

  1. acquire_write() - before reading flags: Ensures writer sees latest read flags
  2. acquire_write() - after setting written flag: Ensures write is globally visible
  3. acquire_read() - before reading flags: Ensures reader sees latest written flag
  4. acquire_read() - after setting read flag: Ensures read completion is visible to writer

Why threading.Lock?

On POSIX systems, pthread_mutex_lock/unlock provides sequentially consistent memory barrier semantics. The lock acquire/release pattern (~20ns overhead) is the most portable and well-defined way to get memory barriers in Python without requiring platform-specific code.
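A minimal sketch of what such a fence helper could look like (the actual function name and placement in the PR may differ; memory_fence and _memory_fence_lock are illustrative names):

import threading

# Hedged sketch, not the exact code in the PR: a module-level lock whose
# acquire/release pair is used purely for its memory-ordering side effect.
_memory_fence_lock = threading.Lock()

def memory_fence() -> None:
    """Issue a full memory barrier via lock acquire/release semantics.

    Acquiring and releasing the (uncontended) lock forces stores issued
    before the call to become globally visible and loads issued after it
    to observe up-to-date values, on POSIX and Windows alike.
    """
    _memory_fence_lock.acquire()
    _memory_fence_lock.release()

Following the four points above, acquire_write() would call the fence before spinning on the read flags and again after setting the written flag, and acquire_read() would mirror this around the written-flag check and its own read flag.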

Test Results

Before fix: Freeze at batch ~41 (~492 total requests)
After fix: Successfully completed 120 batches (1440 requests) without freeze

Test configuration:

  • Model: Qwen2.5-32B-Instruct-AWQ
  • max_num_seqs: 62
  • Concurrent requests per batch: 12
  • Test repeated multiple times with consistent success

Test Plan

  • Stress test with sustained concurrent load (120 batches, 1440 total requests)
  • Verified fix eliminates the freeze that occurred reliably at ~500 requests
  • Multiple test runs confirm reliability

@chatgpt-codex-connector

Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request correctly identifies and fixes a critical race condition in the shared memory IPC mechanism by introducing memory barriers. The use of threading.Lock for this purpose is a standard and portable approach. The analysis in the pull request description is excellent. I have a couple of suggestions to improve the robustness and documentation of the new memory_fence function.

@njhill
Member

njhill commented Dec 2, 2025

Thanks @kitaekatt! I'm just a bit surprised that this hasn't been encountered more often if it's so easy to reproduce. max_num_seqs=62 isn't very high concurrency.

Are you pinning CPUs / setting some particular NUMA config?

@njhill
Member

njhill commented Dec 4, 2025

@kitaekatt ping :)

(I think the changes look very reasonable, just keen to understand the practical cases where this may arise)

@kitaekatt
Contributor Author

@kitaekatt ping :)

(I think the changes look very reasonable, just keen to understand the practical cases where this may arise)

Hi Nick! Excited to contribute some value here back to vllm.

Usage pattern that triggered this: I'm trying to maximize parallel inference bandwidth for batch workloads, using benchmarking (IFEval, GSM8K, HumanEval, MMLU) as a way to simulate sustained load. Running batch_size=12, concurrency=12 across multiple models - continuous high-throughput requests without pauses between batches. The freeze would occur at a consistent threshold for each configuration, suggesting a deterministic trigger once enough state accumulated. Hardware/config: RTX 5090 (Blackwell sm_120, 32GB), Ubuntu with kernel 6.14. No CPU pinning, no NUMA configuration - completely vanilla setup.

Why others likely haven't hit this: When I added diagnostic instrumentation (just time.monotonic() calls and periodic logger.info() in core.py and multiproc_executor.py), the freeze completely disappeared - processed 3x the normal freeze threshold without any issue. The micro-delays broke the precise timing the race condition requires.

So anyone trying to debug this with profiling tools, extra logging, or py-spy would accidentally mask it. Interactive users also don't sustain high-frequency message passing long enough to hit the threshold. You basically need: (1) sustained batch load without pauses, (2) no observation overhead, and (3) hardware timing that happens to align poorly.

The memory barrier fix directly addresses the visibility issue rather than relying on accidental timing changes.

@kitaekatt
Contributor Author

Thank you for the review feedback. I've addressed both suggestions:

  1. Platform-agnostic documentation: Updated the docstring to emphasize that threading.Lock provides cross-platform barrier semantics (POSIX, Windows), rather than focusing on POSIX-specific implementation details.

  2. Context manager pattern: Replaced explicit acquire()/release() with with _memory_fence_lock: pass to ensure lock release even on unexpected exceptions.
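For reference, the context-manager form described in point 2 would look roughly like this (same illustrative _memory_fence_lock name as assumed above):

import threading

_memory_fence_lock = threading.Lock()

def memory_fence() -> None:
    """Full memory barrier, portable across major platforms (POSIX, Windows)."""
    # The with-block guarantees the lock is released even if an unexpected
    # exception is raised while it is held, unlike a bare acquire()/release().
    with _memory_fence_lock:
        pass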

@mergify
Contributor

mergify Bot commented Dec 9, 2025

Hi @kitaekatt, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@kitaekatt
Contributor Author

Thanks for the review! Both suggestions have been addressed:

  1. Platform-agnostic documentation: Updated docstring to explicitly state that threading.Lock provides memory barrier semantics "across all major platforms (POSIX, Windows)"

  2. Context manager pattern: Changed from explicit acquire()/release() to with _memory_fence_lock: pass for exception safety

@kitaekatt
Contributor Author

Pushed additional commit addressing the pre-commit failures:

  • Fixed line length issue (comment was 91 chars, max is 88)
  • Added DCO sign-off to all commits
  • All pre-commit checks now pass locally

@mergify
Contributor

mergify Bot commented Dec 9, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @kitaekatt.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Dec 9, 2025
@njhill
Member

njhill commented Dec 10, 2025

@kitaekatt it looks like the commits in the branch got a bit messed up. Did you intend to introduce e3ba546 (the previous livelock fix)? And it looks like there's an unrelated commit showing up too.

I'd like to do a quick benchmark of this, but otherwise it looks good to me once cleaned up.

The shared memory ring buffer protocol in shm_broadcast.py uses plain byte
writes to signal between writer and reader processes. On multi-core systems,
these writes may stay in CPU store buffers and not be visible to other
processes running on different cores, causing indefinite spinning/freeze
under sustained concurrent load.

This patch adds explicit memory barriers using threading.Lock acquire/release
(which provides full barrier semantics per POSIX.1-2008) at four critical
points:
- In acquire_write(): before reading flags and after setting written flag
- In acquire_read(): before reading flags and after setting read flag

The memory barrier ensures that:
1. All stores before the barrier are globally visible
2. All loads after the barrier see the latest values

Fixes freeze observed during sustained concurrent batch inference (~500+
requests) where both writer and readers would spin indefinitely waiting
for flags that were updated but not visible across CPU cores.

Signed-off-by: Christina Holland <hey@christinaholland.com>
Signed-off-by: Christina <truffle@gmail.com>
…ager

Signed-off-by: Christina <truffle@gmail.com>
@kitaekatt kitaekatt force-pushed the fix-shm-memory-barriers branch from bbe218d to 0f27680 on December 10, 2025 00:42
@kitaekatt
Contributor Author

@njhill Thanks for catching that! I've cleaned up the branch - it was accidentally polluted with commits from a separate PR (#29813, which I closed in favor of this approach).

The branch now contains only:

  1. The memory barriers fix
  2. Review feedback addressing your earlier comments (platform-agnostic docs, context manager)

The SpinBackoffTimer/livelock fix from PR #29813 has been removed - that was a workaround, not the proper solution.

@mergify mergify Bot removed the needs-rebase label Dec 10, 2025
@njhill
Member

njhill commented Dec 10, 2025

Thanks @kitaekatt

  1. Review feedback addressing your earlier comments (platform-agnostic docs, context manager)

FWIW those comments were from gemini, not me :)

@kitaekatt kitaekatt closed this Dec 10, 2025
@kitaekatt kitaekatt deleted the fix-shm-memory-barriers branch December 10, 2025 17:42
@mergify
Contributor

mergify Bot commented Dec 10, 2025

⚠️ The sha of the head commit of this PR conflicts with #30407. Mergify cannot evaluate rules on this PR. ⚠️

Member

@njhill njhill left a comment


@kitaekatt did you close this on purpose? I had just benchmarked it and was about to approve/merge :)

Performance actually looks slightly better with this!

@njhill
Member

njhill commented Dec 10, 2025

Oh I see you opened another PR, will move to that one.

@kitaekatt
Contributor Author

I apologize, @njhill. I accidentally deleted the branch; it has been resubmitted as PR #30407.
